diff --git a/2-eval.py b/2-eval.py index 4ce7759..c1cc938 100644 --- a/2-eval.py +++ b/2-eval.py @@ -18,11 +18,11 @@ def count_parameters(model): def init_model(lm_config): tokenizer = AutoTokenizer.from_pretrained('./model', trust_remote_code=True, use_fast=False) - model_from = 2 # 1从权重,2用transformers + model_from = 1 # 1从权重,2用transformers if model_from == 1: moe_path = '_moe' if lm_config.use_moe else '' - ckp = f'./out/multi_chat/full_sft_{lm_config.dim}{moe_path}.pth' + ckp = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth' model = Transformer(lm_config) state_dict = torch.load(ckp, map_location=device) @@ -40,9 +40,9 @@ def init_model(lm_config): # 加载到模型中 model.load_state_dict(state_dict, strict=False) else: - model = AutoModelForCausalLM.from_pretrained("minimind-small", trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained("minimind", trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained('minimind-small', + tokenizer = AutoTokenizer.from_pretrained('minimind', trust_remote_code=True, use_fast=False) model = model.to(device)