This commit is contained in:
gongjy 2025-02-13 21:07:43 +08:00
parent b5d10d9a7d
commit e7ed05834b

View File

@ -123,7 +123,7 @@ def main():
parser.add_argument('--history_cnt', default=0, type=int)
parser.add_argument('--stream', default=True, type=bool)
parser.add_argument('--load', default=0, type=int, help="0: 原生torch权重1: transformers加载")
parser.add_argument('--model_mode', default=0, type=int,
parser.add_argument('--model_mode', default=1, type=int,
help="0: 预训练模型1: SFT-Chat模型2: RLHF-Chat模型3: Reason模型")
args = parser.parse_args()
@ -133,6 +133,8 @@ def main():
test_mode = int(input('[0] 自动测试\n[1] 手动输入\n'))
messages = []
for idx, prompt in enumerate(prompts if test_mode == 0 else iter(lambda: input('👶: '), '')):
setup_seed(random.randint(0, 2048))
# setup_seed(2025) # 如需固定每次输出则换成【固定】的随机种子
if test_mode == 0: print(f'👶: {prompt}')
messages = messages[-args.history_cnt:] if args.history_cnt else []
@ -177,6 +179,4 @@ def main():
if __name__ == "__main__":
torch.backends.cudnn.deterministic = True
random.seed(random.randint(0, 2048))
main()