update eval-chat
This commit is contained in:
parent
6d7a988365
commit
0bce9e5a31
14
2-eval.py
14
2-eval.py
@ -39,7 +39,7 @@ def init_model(lm_config):
|
|||||||
# 加载到模型中
|
# 加载到模型中
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
else:
|
else:
|
||||||
model = AutoModelForCausalLM.from_pretrained('minimind', trust_remote_code=True)
|
model = AutoModelForCausalLM.from_pretrained('./minimind-v1-small', trust_remote_code=True)
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
print(f'模型参数: {count_parameters(model) / 1e6} 百万 = {count_parameters(model) / 1e9} B (Billion)')
|
print(f'模型参数: {count_parameters(model) / 1e6} 百万 = {count_parameters(model) / 1e9} B (Billion)')
|
||||||
@ -55,13 +55,13 @@ def setup_seed(seed):
|
|||||||
torch.backends.cudnn.deterministic = True # 确保每次返回的卷积算法是确定的
|
torch.backends.cudnn.deterministic = True # 确保每次返回的卷积算法是确定的
|
||||||
torch.backends.cudnn.benchmark = False # 关闭 cuDNN 的自动调优,避免不确定性
|
torch.backends.cudnn.benchmark = False # 关闭 cuDNN 的自动调优,避免不确定性
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
out_dir = 'out'
|
out_dir = 'out'
|
||||||
start = ""
|
start = ""
|
||||||
temperature = 0.5
|
temperature = 0.7
|
||||||
top_k = 16
|
top_k = 16
|
||||||
setup_seed(1337)
|
|
||||||
# device = 'cpu'
|
# device = 'cpu'
|
||||||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||||
dtype = 'bfloat16'
|
dtype = 'bfloat16'
|
||||||
@ -84,8 +84,8 @@ if __name__ == "__main__":
|
|||||||
stream = True
|
stream = True
|
||||||
|
|
||||||
prompt_datas = [
|
prompt_datas = [
|
||||||
'你叫什么名字啊?',
|
'你叫什么名字',
|
||||||
'你叫什么名字?',
|
'你是谁',
|
||||||
'中国有哪些比较好的大学?',
|
'中国有哪些比较好的大学?',
|
||||||
'全世界最好的大学是什么?',
|
'全世界最好的大学是什么?',
|
||||||
'你知道光速是多少吗?',
|
'你知道光速是多少吗?',
|
||||||
@ -115,6 +115,9 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
while i < len(prompt_datas):
|
while i < len(prompt_datas):
|
||||||
|
# Generate a random seed
|
||||||
|
random_seed = random.randint(0, 2 ** 32 - 1)
|
||||||
|
setup_seed(random_seed)
|
||||||
if not contain_history_chat:
|
if not contain_history_chat:
|
||||||
messages = messages_origin.copy()
|
messages = messages_origin.copy()
|
||||||
|
|
||||||
@ -125,6 +128,7 @@ if __name__ == "__main__":
|
|||||||
print(f'[Q]: {prompt}')
|
print(f'[Q]: {prompt}')
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
prompt = '请问,' + prompt
|
||||||
messages.append({"role": "user", "content": prompt})
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
|
||||||
# print(messages)
|
# print(messages)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user