diff --git a/train_dpo.py b/train_dpo.py index 2c857ff..dc9054a 100644 --- a/train_dpo.py +++ b/train_dpo.py @@ -134,7 +134,7 @@ def init_model(lm_config): tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer') model = MiniMindLM(lm_config) moe_path = '_moe' if lm_config.use_moe else '' - ckp = f'./out/full_dist_{lm_config.dim}{moe_path}.pth' + ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth' state_dict = torch.load(ckp, map_location=args.device) model.load_state_dict(state_dict, strict=False) # 初始化参考模型