Minimind/0-eval_pretrain.py
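
# Evaluation script for the *pretrained* (not instruction-tuned) MiniMind model:
# it loads a local Transformer checkpoint and streams text completions for a fixed
# list of test prompts (or for questions typed interactively).
# Assumed layout, matching the defaults below: tokenizer files under ./model and a
# checkpoint at ./out/pretrain_<dim>[_moe].pth. Run as: python 0-eval_pretrain.py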

import random
import time
import numpy as np
import torch
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import Transformer
from model.LMConfig import LMConfig
warnings.filterwarnings('ignore')


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def init_model(lm_config):
    tokenizer = AutoTokenizer.from_pretrained('./model',
                                              trust_remote_code=True, use_fast=False)
    model_from = 1  # 1: load the local .pth weights, 2: load via transformers
    if model_from == 1:
        moe_path = '_moe' if lm_config.use_moe else ''
        ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'

        model = Transformer(lm_config)
        # `device` is a module-level global set in the __main__ block before this call
        state_dict = torch.load(ckp, map_location=device)

        # Strip the '_orig_mod.' prefix that torch.compile adds to parameter names
        unwanted_prefix = '_orig_mod.'
        for k, v in list(state_dict.items()):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

        # Drop any keys containing 'mask' from the checkpoint
        for k, v in list(state_dict.items()):
            if 'mask' in k:
                del state_dict[k]

        # Load the cleaned weights into the model
        model.load_state_dict(state_dict, strict=False)
    else:
        model = AutoModelForCausalLM.from_pretrained('minimind', trust_remote_code=True)

    model = model.to(device)

    print(f'Model parameters: {count_parameters(model) / 1e6} million = {count_parameters(model) / 1e9} B (billion)')
    return model, tokenizer


def setup_seed(seed):
    random.seed(seed)  # Seed Python's built-in random module
    np.random.seed(seed)  # Seed NumPy
    torch.manual_seed(seed)  # Seed PyTorch (CPU)
    torch.cuda.manual_seed(seed)  # Seed the current GPU (if available)
    torch.cuda.manual_seed_all(seed)  # Seed all GPUs (if available)
    torch.backends.cudnn.deterministic = True  # Always pick deterministic convolution algorithms
    torch.backends.cudnn.benchmark = False  # Disable cuDNN auto-tuning to avoid nondeterminism


if __name__ == "__main__":
    # -----------------------------------------------------------------------------
    out_dir = 'out'
    start = ""
    temperature = 0.7
    top_k = 8
    setup_seed(1337)
    # device = 'cpu'
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    dtype = 'bfloat16'
    max_seq_len = 512
    lm_config = LMConfig()
    lm_config.max_seq_len = max_seq_len
    # Whether to carry the chat history between turns (the current model is too weak:
    # a longer history context mostly produces gibberish)
    contain_history_chat = False
    # -----------------------------------------------------------------------------

    model, tokenizer = init_model(lm_config)

    model = model.eval()
    # Push to the Hugging Face Hub
    # model.push_to_hub("minimind")
    # tokenizer.push_to_hub("minimind")

    # answer_way = int(input('Enter 0 for the automatic test set, 1 to type your own questions: '))
    answer_way = 0
    stream = True

    prompt_datas = [
        '椭圆和圆的区别',
        '中国关于马克思主义基本原理',
        '人类大脑的主要功能是',
        '万有引力是',
        '世界上人口最多的国家是',
        'DNA的全称是',
        '数学中π的值大约是',
        '世界上最高的山峰是',
        '太阳系中最大的行星是',
        '二氧化碳的化学分子式是',
        '地球上最大的动物是',
        '地球自转一圈大约需要',
        '杭州市的美食有',
        '江苏省的最好的大学',
    ]
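    # The prompts above are Chinese completion prefixes matching the Chinese pretraining
    # corpus (e.g. 'the difference between an ellipse and a circle', 'the largest planet
    # in the solar system is'); they are kept untranslated because they are the literal
    # model inputs.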

    messages_origin = []
    messages = messages_origin

    qa_index = 0

    while True:
        start = time.time()
        if not contain_history_chat:
            messages = messages_origin.copy()

        if answer_way == 1:
            # Interactive mode: read the question from stdin
            prompt = input('User: ')
        else:
            if qa_index >= len(prompt_datas):
                break
            prompt = prompt_datas[qa_index]
            print('Question:', prompt)
            qa_index += 1

        messages.append({"role": "user", "content": prompt})
        # print(messages)
        new_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )[-(max_seq_len - 1):]
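
        # Note: new_prompt is the chat-templated conversation (truncated to its last
        # max_seq_len - 1 characters), but the raw prompt tokens are what actually get
        # fed to the model below, since a pretrain-only checkpoint is evaluated as
        # plain text completion.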
        x = tokenizer(prompt).data['input_ids']
        x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])  # add a batch dimension

        answer = new_prompt
        with torch.no_grad():
            res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=max_seq_len, temperature=temperature,
                                   top_k=top_k, stream=stream)
            print('Answer: ', end='')
            try:
                y = next(res_y)
            except StopIteration:
                print("No answer")
                continue
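
            # Streaming decode: each `y` yielded by the generator decodes to the text
            # produced so far, and only the newly decoded suffix (answer[history_idx:])
            # is printed. A chunk ending in the Unicode replacement character ('\ufffd')
            # means a multi-byte character is still incomplete, so the loop fetches more
            # tokens before printing it.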
            history_idx = 0
            while y is not None:
                answer = tokenizer.decode(y[0].tolist())
                if answer and answer[-1] == '\ufffd':
                    # Last character is incomplete; fetch more tokens before printing
                    try:
                        y = next(res_y)
                    except StopIteration:
                        break
                    continue

                # print(answer)
                if not len(answer):
                    try:
                        y = next(res_y)
                    except StopIteration:
                        break
                    continue

                print(answer[history_idx:], end='', flush=True)
                try:
                    y = next(res_y)
                except StopIteration:
                    break
                history_idx = len(answer)

                if not stream:
                    break
            print('\n')

        if contain_history_chat:
            assistant_answer = answer.replace(new_prompt, "")
            messages.append({"role": "assistant", "content": assistant_answer})

        end = time.time()
        print(end - start, 's')