From c81c17dab791eafac6edc8095025e3309d41483a Mon Sep 17 00:00:00 2001
From: gongjy <2474590974@qq.com>
Date: Thu, 19 Sep 2024 09:35:02 +0800
Subject: [PATCH] update others

---
 1-pretrain.py |  2 +-
 3-full_sft.py | 16 ++++++++--------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/1-pretrain.py b/1-pretrain.py
index ce427aa..fabb445 100644
--- a/1-pretrain.py
+++ b/1-pretrain.py
@@ -131,7 +131,7 @@ if __name__ == "__main__":
     epochs = 20
     batch_size = 64
     learning_rate = 2e-4
-    device = 'cuda:0'
+    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
     dtype = 'bfloat16'
     save_dir = os.path.join(out_dir)
     os.makedirs(save_dir, exist_ok=True)
diff --git a/3-full_sft.py b/3-full_sft.py
index ce39557..3c94597 100644
--- a/3-full_sft.py
+++ b/3-full_sft.py
@@ -110,13 +110,13 @@ def init_model(lm_config):
     if model_from == 1:
         model = Transformer(lm_config)
         moe_path = '_moe' if lm_config.use_moe else ''
-        # ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
-        # state_dict = torch.load(ckp, map_location=device)
-        # unwanted_prefix = '_orig_mod.'
-        # for k, v in list(state_dict.items()):
-        #     if k.startswith(unwanted_prefix):
-        #         state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
-        # model.load_state_dict(state_dict, strict=False)
+        ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
+        state_dict = torch.load(ckp, map_location=device)
+        unwanted_prefix = '_orig_mod.'
+        for k, v in list(state_dict.items()):
+            if k.startswith(unwanted_prefix):
+                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+        model.load_state_dict(state_dict, strict=False)
     else:
         model = AutoModel.from_pretrained('./minimind', trust_remote_code=True)
 
@@ -148,7 +148,7 @@ if __name__ == "__main__":
     gradient_accumulation_steps = 1
     batch_size = 40
     learning_rate = 1e-4
-    device = 'cuda:0'
+    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
     dtype = 'bfloat16'
     # dtype = 'float16'
     save_dir = os.path.join(out_dir)
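
Note (not part of the patch): the re-enabled block in 3-full_sft.py loads the pretrain checkpoint and strips the '_orig_mod.' prefix that torch.compile adds to parameter names before wrapping the model. A minimal, standalone sketch of that pattern is below; the helper name `load_pretrained_state` and the example path are assumptions for illustration, not code from this repository.

```python
import torch


def load_pretrained_state(model, ckp_path, device='cpu'):
    """Load a checkpoint, dropping the '_orig_mod.' prefix left by torch.compile."""
    state_dict = torch.load(ckp_path, map_location=device)
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            # Re-key the entry without the compile-wrapper prefix so it matches
            # the uncompiled model's parameter names.
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    # strict=False tolerates keys missing from (or extra to) the current model.
    model.load_state_dict(state_dict, strict=False)
    return model


# Example usage (hypothetical path mirroring the patch's f-string):
# model = load_pretrained_state(model, './out/pretrain_512.pth', device='cpu')
```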