update others

This commit is contained in:
gongjy 2024-09-19 09:35:02 +08:00
parent 48ea6a4cbf
commit c81c17dab7
2 changed files with 9 additions and 9 deletions

@ -131,7 +131,7 @@ if __name__ == "__main__":
epochs = 20
batch_size = 64
learning_rate = 2e-4
device = 'cuda:0'
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16'
save_dir = os.path.join(out_dir)
os.makedirs(save_dir, exist_ok=True)

@ -110,13 +110,13 @@ def init_model(lm_config):
if model_from == 1:
model = Transformer(lm_config)
moe_path = '_moe' if lm_config.use_moe else ''
# ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
# state_dict = torch.load(ckp, map_location=device)
# unwanted_prefix = '_orig_mod.'
# for k, v in list(state_dict.items()):
# if k.startswith(unwanted_prefix):
# state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
# model.load_state_dict(state_dict, strict=False)
ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
state_dict = torch.load(ckp, map_location=device)
unwanted_prefix = '_orig_mod.'
for k, v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict, strict=False)
else:
model = AutoModel.from_pretrained('./minimind', trust_remote_code=True)
@ -148,7 +148,7 @@ if __name__ == "__main__":
gradient_accumulation_steps = 1
batch_size = 40
learning_rate = 1e-4
device = 'cuda:0'
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16'
# dtype = 'float16'
save_dir = os.path.join(out_dir)