fix bug
parent b5d10d9a7d
commit e7ed05834b
@@ -123,7 +123,7 @@ def main():
     parser.add_argument('--history_cnt', default=0, type=int)
     parser.add_argument('--stream', default=True, type=bool)
     parser.add_argument('--load', default=0, type=int, help="0: native torch weights, 1: load via transformers")
-    parser.add_argument('--model_mode', default=0, type=int,
+    parser.add_argument('--model_mode', default=1, type=int,
                         help="0: pretrained model, 1: SFT-Chat model, 2: RLHF-Chat model, 3: Reason model")
     args = parser.parse_args()

@@ -133,6 +133,8 @@ def main():
     test_mode = int(input('[0] automatic test\n[1] manual input\n'))
     messages = []
     for idx, prompt in enumerate(prompts if test_mode == 0 else iter(lambda: input('👶: '), '')):
+        setup_seed(random.randint(0, 2048))
+        # setup_seed(2025)  # switch to a fixed random seed if every run should produce identical output
         if test_mode == 0: print(f'👶: {prompt}')

         messages = messages[-args.history_cnt:] if args.history_cnt else []
@@ -177,6 +179,4 @@ def main():


 if __name__ == "__main__":
-    torch.backends.cudnn.deterministic = True
-    random.seed(random.randint(0, 2048))
     main()
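Note: the commit moves seeding from module scope into the prompt loop, so each prompt is generated under a fresh seed. The setup_seed helper called above is defined elsewhere in the file and its body is not part of this diff; a minimal sketch, assuming it seeds Python's random, NumPy, and PyTorch together and re-enables deterministic cuDNN (mirroring the flag removed from __main__), might look like:

import random

import numpy as np
import torch


def setup_seed(seed: int) -> None:
    # Hypothetical sketch; the actual helper in the repository may differ.
    random.seed(seed)                     # Python built-in RNG
    np.random.seed(seed)                  # NumPy RNG
    torch.manual_seed(seed)               # PyTorch CPU RNG
    torch.cuda.manual_seed_all(seed)      # all visible GPUs
    # Deterministic cuDNN kernels, replacing the removed module-level setting.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False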