update fast_inference

gongjy 2024-10-30 15:26:28 +08:00
parent db39571493
commit 7c67ba0b92


@@ -1,19 +1,16 @@
 import json
+import random
+import numpy as np
 import streamlit as st
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.generation.utils import GenerationConfig
 
-st.set_page_config(page_title="minimind-v1(108M)")
-st.title("minimind-v1(108M)")
+st.set_page_config(page_title="MiniMind-V1")
+st.title("MiniMind-V1")
 
-model_id = "minimind-v1"
+model_id = "./minimind-v1"
 
-# -----------------------------------------------------------------------------
-temperature = 0.7
-top_k = 8
-max_seq_len = 1 * 1024
-# -----------------------------------------------------------------------------
 
 @st.cache_resource
 def load_model_tokenizer():
@@ -33,28 +30,41 @@ def load_model_tokenizer():
 
 def clear_chat_messages():
     del st.session_state.messages
+    del st.session_state.chat_messages
 
 
 def init_chat_messages():
     with st.chat_message("assistant", avatar='🤖'):
-        st.markdown("您好我是由JingyaoGong创造的MiniMind很高兴为您服务😄")
+        st.markdown("我是由JingyaoGong创造的MiniMind很高兴为您服务😄 \n"
+                    "所有AI生成内容的准确性和立场无法保证不代表我们的态度或观点。")
 
     if "messages" in st.session_state:
         for message in st.session_state.messages:
-            avatar = "🧑‍💻" if message["role"] == "user" else "🤖"
+            avatar = "🫡" if message["role"] == "user" else "🤖"
             with st.chat_message(message["role"], avatar=avatar):
                 st.markdown(message["content"])
     else:
         st.session_state.messages = []
+        st.session_state.chat_messages = []
 
     return st.session_state.messages
 
 
-# max_new_tokens = st.sidebar.slider("max_new_tokens", 0, 1024, 512, step=1)
-# top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
-# top_k = st.sidebar.slider("top_k", 0, 100, 0, step=1)
-# temperature = st.sidebar.slider("temperature", 0.0, 2.0, 1.0, step=0.01)
-# do_sample = st.sidebar.checkbox("do_sample", value=False)
+st.sidebar.title("设定调整")
+st.session_state.history_chat_num = st.sidebar.slider("携带历史对话条数", 0, 6, 0, step=2)
+st.session_state.max_new_tokens = st.sidebar.slider("最大输入/生成长度", 256, 768, 512, step=1)
+st.session_state.top_k = st.sidebar.slider("top_k", 0, 16, 14, step=1)
+st.session_state.temperature = st.sidebar.slider("temperature", 0.3, 1.3, 0.5, step=0.01)
+
+
+def setup_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
 
 
 def main():
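Note: the commented-out sliders are replaced by live ones that write straight into st.session_state, and the new setup_seed() pins every RNG the app touches (Python's random, NumPy, Torch CPU and CUDA, plus deterministic cuDNN). A quick standalone check of the seeding idea (illustrative only, not part of the commit):

import torch

# Re-seeding before sampling reproduces the same draw; setup_seed() above
# does the same across random, numpy and CUDA as well.
torch.manual_seed(42)
a = torch.multinomial(torch.softmax(torch.randn(8), dim=-1), num_samples=1)
torch.manual_seed(42)
b = torch.multinomial(torch.softmax(torch.randn(8), dim=-1), num_samples=1)
assert torch.equal(a, b)  # identical seed -> identical sampled index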
@@ -65,32 +75,30 @@ def main():
         with st.chat_message("user", avatar='🧑‍💻'):
             st.markdown(prompt)
         messages.append({"role": "user", "content": prompt})
+        st.session_state.chat_messages.append({"role": "user", "content": '请问,' + prompt + ''})
 
         with st.chat_message("assistant", avatar='🤖'):
             placeholder = st.empty()
+            # Generate a random seed
+            random_seed = random.randint(0, 2 ** 32 - 1)
+            setup_seed(random_seed)
 
-            chat_messages = []
-            chat_messages.append({"role": "user", "content": '请问,' + prompt})
-            # print(messages)
             new_prompt = tokenizer.apply_chat_template(
-                chat_messages,
+                st.session_state.chat_messages[-(st.session_state.history_chat_num + 1):],
                 tokenize=False,
                 add_generation_prompt=True
-            )[-(max_seq_len - 1):]
+            )[-(st.session_state.max_new_tokens - 1):]
 
             x = tokenizer(new_prompt).data['input_ids']
             x = (torch.tensor(x, dtype=torch.long)[None, ...])
 
-            response = ''
             with torch.no_grad():
-                res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=max_seq_len, temperature=temperature,
-                                       top_k=top_k, stream=True)
+                res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=st.session_state.max_new_tokens,
+                                       temperature=st.session_state.temperature,
+                                       top_k=st.session_state.top_k, stream=True)
                 try:
                     y = next(res_y)
                 except StopIteration:
                     return
 
-                history_idx = 0
                 while y != None:
                     answer = tokenizer.decode(y[0].tolist())
                     if answer and answer[-1] == '�':
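Note: the prompt is now built from a sliding window over st.session_state.chat_messages instead of a throwaway single-turn list: with the history slider at N, the slice keeps the last N history messages plus the question just appended. A plain-Python sketch of that indexing (data hypothetical):

chat_messages = [
    {"role": "user", "content": "q1"},
    {"role": "assistant", "content": "a1"},
    {"role": "user", "content": "q2"},
    {"role": "assistant", "content": "a2"},
    {"role": "user", "content": "q3"},  # the turn appended just above
]
history_chat_num = 2  # slider steps by 2, i.e. whole question/answer pairs
window = chat_messages[-(history_chat_num + 1):]
assert [m["content"] for m in window] == ["q2", "a2", "q3"]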
@@ -99,7 +107,6 @@ def main():
                         except:
                             break
                         continue
-                    # print(answer)
                     if not len(answer):
                         try:
                             y = next(res_y)
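Note: the retry logic in this hunk backs the '�' guard above: when a multi-byte UTF-8 character is split across streamed token chunks, decoding the prefix yields U+FFFD, so the loop fetches more tokens instead of rendering the fragment. A standalone illustration:

# '你好' is six UTF-8 bytes; stopping after four leaves the second character incomplete.
data = "你好".encode("utf-8")
partial = data[:4].decode("utf-8", errors="replace")
assert partial == "你\ufffd"  # trailing U+FFFD is the '�' the loop checks for
assert data.decode("utf-8") == "你好"  # complete bytes decode cleanly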
@@ -107,17 +114,14 @@ def main():
                             break
                         continue
                     placeholder.markdown(answer)
-                    response = answer
                     try:
                         y = next(res_y)
                     except:
                         break
 
-            # if contain_history_chat:
-            #     assistant_answer = answer.replace(new_prompt, "")
-            #     messages.append({"role": "assistant", "content": assistant_answer})
+            assistant_answer = answer.replace(new_prompt, "")
+            messages.append({"role": "assistant", "content": assistant_answer})
+            st.session_state.chat_messages.append({"role": "assistant", "content": assistant_answer})
 
-            messages.append({"role": "assistant", "content": response})
 
     st.button("清空对话", on_click=clear_chat_messages)
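Note: the sidebar defaults (top_k=14, temperature=0.5) are passed straight to model.generate. For reference, a sketch of what these knobs conventionally do in top-k/temperature sampling (not MiniMind's generate() internals):

import torch

def sample_top_k(logits: torch.Tensor, temperature: float, top_k: int) -> int:
    # Temperature < 1 sharpens the distribution; top_k masks all but the
    # k highest-scoring tokens before one id is sampled.
    scaled = logits / temperature
    kth = torch.topk(scaled, top_k).values[-1]
    scaled[scaled < kth] = float("-inf")
    probs = torch.softmax(scaled, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

logits = torch.randn(32000)  # hypothetical vocabulary size
next_id = sample_top_k(logits, temperature=0.5, top_k=14)  # sidebar defaults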