update readme's error

This commit is contained in:
gongjy 2024-08-28 20:23:37 +08:00
parent fa6dd96e93
commit c555741029

View File

@ -18,11 +18,11 @@ def count_parameters(model):
def init_model(lm_config):
tokenizer = AutoTokenizer.from_pretrained('./model',
trust_remote_code=True, use_fast=False)
model_from = 2 # 1从权重2用transformers
model_from = 1 # 1从权重2用transformers
if model_from == 1:
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'./out/multi_chat/full_sft_{lm_config.dim}{moe_path}.pth'
ckp = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth'
model = Transformer(lm_config)
state_dict = torch.load(ckp, map_location=device)
@ -40,9 +40,9 @@ def init_model(lm_config):
# 加载到模型中
model.load_state_dict(state_dict, strict=False)
else:
model = AutoModelForCausalLM.from_pretrained("minimind-small", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("minimind", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('minimind-small',
tokenizer = AutoTokenizer.from_pretrained('minimind',
trust_remote_code=True, use_fast=False)
model = model.to(device)