update eval-chat

This commit is contained in:
gongjy 2024-10-28 12:57:25 +08:00
parent 6d7a988365
commit 0bce9e5a31

View File

@@ -39,7 +39,7 @@ def init_model(lm_config):
         # 加载到模型中
         model.load_state_dict(state_dict, strict=False)
     else:
-        model = AutoModelForCausalLM.from_pretrained('minimind', trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained('./minimind-v1-small', trust_remote_code=True)
     model = model.to(device)
     print(f'模型参数: {count_parameters(model) / 1e6} 百万 = {count_parameters(model) / 1e9} B (Billion)')
@@ -55,13 +55,13 @@ def setup_seed(seed):
     torch.backends.cudnn.deterministic = True  # 确保每次返回的卷积算法是确定的
     torch.backends.cudnn.benchmark = False  # 关闭 cuDNN 的自动调优,避免不确定性

 if __name__ == "__main__":
     # -----------------------------------------------------------------------------
     out_dir = 'out'
     start = ""
-    temperature = 0.5
+    temperature = 0.7
     top_k = 16
-    setup_seed(1337)
     # device = 'cpu'
     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
     dtype = 'bfloat16'
@@ -84,8 +84,8 @@ if __name__ == "__main__":
     stream = True

     prompt_datas = [
-        '你叫什么名字啊?',
-        '叫什么名字?',
+        '你叫什么名字',
+        '是谁',
         '中国有哪些比较好的大学?',
         '全世界最好的大学是什么?',
         '你知道光速是多少吗?',
@@ -115,6 +115,9 @@ if __name__ == "__main__":
     i = 0
     while i < len(prompt_datas):
+        # Generate a random seed
+        random_seed = random.randint(0, 2 ** 32 - 1)
+        setup_seed(random_seed)
         if not contain_history_chat:
             messages = messages_origin.copy()
@@ -125,6 +128,7 @@ if __name__ == "__main__":
         print(f'[Q]: {prompt}')

         i += 1
+        prompt = '请问,' + prompt
         messages.append({"role": "user", "content": prompt})
         # print(messages)