update eval-chat
This commit is contained in:
parent
6d7a988365
commit
0bce9e5a31
14
2-eval.py
14
2-eval.py
@ -39,7 +39,7 @@ def init_model(lm_config):
|
||||
# 加载到模型中
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained('minimind', trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained('./minimind-v1-small', trust_remote_code=True)
|
||||
model = model.to(device)
|
||||
|
||||
print(f'模型参数: {count_parameters(model) / 1e6} 百万 = {count_parameters(model) / 1e9} B (Billion)')
|
||||
@ -55,13 +55,13 @@ def setup_seed(seed):
|
||||
torch.backends.cudnn.deterministic = True # 确保每次返回的卷积算法是确定的
|
||||
torch.backends.cudnn.benchmark = False # 关闭 cuDNN 的自动调优,避免不确定性
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# -----------------------------------------------------------------------------
|
||||
out_dir = 'out'
|
||||
start = ""
|
||||
temperature = 0.5
|
||||
temperature = 0.7
|
||||
top_k = 16
|
||||
setup_seed(1337)
|
||||
# device = 'cpu'
|
||||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
dtype = 'bfloat16'
|
||||
@ -84,8 +84,8 @@ if __name__ == "__main__":
|
||||
stream = True
|
||||
|
||||
prompt_datas = [
|
||||
'你叫什么名字啊?',
|
||||
'你叫什么名字?',
|
||||
'你叫什么名字',
|
||||
'你是谁',
|
||||
'中国有哪些比较好的大学?',
|
||||
'全世界最好的大学是什么?',
|
||||
'你知道光速是多少吗?',
|
||||
@ -115,6 +115,9 @@ if __name__ == "__main__":
|
||||
|
||||
i = 0
|
||||
while i < len(prompt_datas):
|
||||
# Generate a random seed
|
||||
random_seed = random.randint(0, 2 ** 32 - 1)
|
||||
setup_seed(random_seed)
|
||||
if not contain_history_chat:
|
||||
messages = messages_origin.copy()
|
||||
|
||||
@ -125,6 +128,7 @@ if __name__ == "__main__":
|
||||
print(f'[Q]: {prompt}')
|
||||
i += 1
|
||||
|
||||
prompt = '请问,' + prompt
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# print(messages)
|
||||
|
Loading…
x
Reference in New Issue
Block a user