update others
This commit is contained in:
parent
48ea6a4cbf
commit
c81c17dab7
@ -131,7 +131,7 @@ if __name__ == "__main__":
|
||||
epochs = 20
|
||||
batch_size = 64
|
||||
learning_rate = 2e-4
|
||||
device = 'cuda:0'
|
||||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
dtype = 'bfloat16'
|
||||
save_dir = os.path.join(out_dir)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
@ -110,13 +110,13 @@ def init_model(lm_config):
|
||||
if model_from == 1:
|
||||
model = Transformer(lm_config)
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
# ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
|
||||
# state_dict = torch.load(ckp, map_location=device)
|
||||
# unwanted_prefix = '_orig_mod.'
|
||||
# for k, v in list(state_dict.items()):
|
||||
# if k.startswith(unwanted_prefix):
|
||||
# state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
|
||||
# model.load_state_dict(state_dict, strict=False)
|
||||
ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
|
||||
state_dict = torch.load(ckp, map_location=device)
|
||||
unwanted_prefix = '_orig_mod.'
|
||||
for k, v in list(state_dict.items()):
|
||||
if k.startswith(unwanted_prefix):
|
||||
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
else:
|
||||
model = AutoModel.from_pretrained('./minimind', trust_remote_code=True)
|
||||
|
||||
@ -148,7 +148,7 @@ if __name__ == "__main__":
|
||||
gradient_accumulation_steps = 1
|
||||
batch_size = 40
|
||||
learning_rate = 1e-4
|
||||
device = 'cuda:0'
|
||||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
dtype = 'bfloat16'
|
||||
# dtype = 'float16'
|
||||
save_dir = os.path.join(out_dir)
|
||||
|
Loading…
x
Reference in New Issue
Block a user