update eval

This commit is contained in:
jingyaogong 2025-04-09 16:56:57 +08:00
parent 4a758564e4
commit d503093ec4

View File

@@ -16,7 +16,7 @@ def init_model(args):
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer') tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
if args.load == 0: if args.load == 0:
moe_path = '_moe' if args.use_moe else '' moe_path = '_moe' if args.use_moe else ''
modes = {0: 'pretrain', 1: 'full_sft', 2: 'rlhf', 3: 'reason'} modes = {0: 'pretrain', 1: 'full_sft', 2: 'rlhf', 3: 'reason', 4: 'grpo'}
ckp = f'./{args.out_dir}/{modes[args.model_mode]}_{args.dim}{moe_path}.pth' ckp = f'./{args.out_dir}/{modes[args.model_mode]}_{args.dim}{moe_path}.pth'
model = MiniMindLM(LMConfig( model = MiniMindLM(LMConfig(
@@ -123,7 +123,7 @@ def main():
parser.add_argument('--stream', default=True, type=bool) parser.add_argument('--stream', default=True, type=bool)
parser.add_argument('--load', default=0, type=int, help="0: 原生torch权重1: transformers加载") parser.add_argument('--load', default=0, type=int, help="0: 原生torch权重1: transformers加载")
parser.add_argument('--model_mode', default=1, type=int, parser.add_argument('--model_mode', default=1, type=int,
help="0: 预训练模型1: SFT-Chat模型2: RLHF-Chat模型3: Reason模型") help="0: 预训练模型1: SFT-Chat模型2: RLHF-Chat模型3: Reason模型4: RLAIF-Chat模型")
args = parser.parse_args() args = parser.parse_args()
model, tokenizer = init_model(args) model, tokenizer = init_model(args)
@@ -143,7 +143,7 @@ def main():
messages, messages,
tokenize=False, tokenize=False,
add_generation_prompt=True add_generation_prompt=True
)[-args.max_seq_len + 1:] if args.model_mode != 0 else (tokenizer.bos_token + prompt) )[-args.max_seq_len - 1:] if args.model_mode != 0 else (tokenizer.bos_token + prompt)
answer = new_prompt answer = new_prompt
with torch.no_grad(): with torch.no_grad():