diff --git a/eval_model.py b/eval_model.py index 2732140..a3d0eaf 100644 --- a/eval_model.py +++ b/eval_model.py @@ -123,7 +123,7 @@ def main(): parser.add_argument('--history_cnt', default=0, type=int) parser.add_argument('--stream', default=True, type=bool) parser.add_argument('--load', default=0, type=int, help="0: 原生torch权重,1: transformers加载") - parser.add_argument('--model_mode', default=0, type=int, + parser.add_argument('--model_mode', default=1, type=int, help="0: 预训练模型,1: SFT-Chat模型,2: RLHF-Chat模型,3: Reason模型") args = parser.parse_args() @@ -133,6 +133,8 @@ def main(): test_mode = int(input('[0] 自动测试\n[1] 手动输入\n')) messages = [] for idx, prompt in enumerate(prompts if test_mode == 0 else iter(lambda: input('👶: '), '')): + setup_seed(random.randint(0, 2048)) + # setup_seed(2025) # 如需固定每次输出则换成【固定】的随机种子 if test_mode == 0: print(f'👶: {prompt}') messages = messages[-args.history_cnt:] if args.history_cnt else [] @@ -177,6 +179,4 @@ def main(): if __name__ == "__main__": - torch.backends.cudnn.deterministic = True - random.seed(random.randint(0, 2048)) main()