From 7c67ba0b927d34cb78e89fd594ee1efdcdd8a084 Mon Sep 17 00:00:00 2001
From: gongjy <2474590974@qq.com>
Date: Wed, 30 Oct 2024 15:26:28 +0800
Subject: [PATCH] update fast_inference

---
 fast_infenence.py => fast_inference.py | 70 ++++++++++++++------------
 1 file changed, 37 insertions(+), 33 deletions(-)
 rename fast_infenence.py => fast_inference.py (57%)

diff --git a/fast_infenence.py b/fast_inference.py
similarity index 57%
rename from fast_infenence.py
rename to fast_inference.py
index 5ab4eea..ee3a119 100644
--- a/fast_infenence.py
+++ b/fast_inference.py
@@ -1,19 +1,16 @@
 import json
+import random
+import numpy as np
 import streamlit as st
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.generation.utils import GenerationConfig
 
-st.set_page_config(page_title="minimind-v1(108M)")
-st.title("minimind-v1(108M)")
+st.set_page_config(page_title="MiniMind-V1")
+st.title("MiniMind-V1")
 
-model_id = "minimind-v1"
+model_id = "./minimind-v1"
 
-# -----------------------------------------------------------------------------
-temperature = 0.7
-top_k = 8
-max_seq_len = 1 * 1024
-# -----------------------------------------------------------------------------
 
 @st.cache_resource
 def load_model_tokenizer():
@@ -33,28 +30,41 @@ def load_model_tokenizer():
 
 def clear_chat_messages():
     del st.session_state.messages
+    del st.session_state.chat_messages
 
 
 def init_chat_messages():
     with st.chat_message("assistant", avatar='🤖'):
-        st.markdown("您好,我是由JingyaoGong创造的MiniMind,很高兴为您服务😄")
+        st.markdown("我是由JingyaoGong创造的MiniMind,很高兴为您服务😄  \n"
+                    "注:所有AI生成内容的准确性和立场无法保证,不代表我们的态度或观点。")
 
     if "messages" in st.session_state:
         for message in st.session_state.messages:
-            avatar = "🧑‍💻" if message["role"] == "user" else "🤖"
+            avatar = "🫡" if message["role"] == "user" else "🤖"
            with st.chat_message(message["role"], avatar=avatar):
                st.markdown(message["content"])
     else:
         st.session_state.messages = []
+        st.session_state.chat_messages = []
 
     return st.session_state.messages
 
 
-# max_new_tokens = st.sidebar.slider("max_new_tokens", 0, 1024, 512, step=1)
-# top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
-# top_k = st.sidebar.slider("top_k", 0, 100, 0, step=1)
-# temperature = st.sidebar.slider("temperature", 0.0, 2.0, 1.0, step=0.01)
-# do_sample = st.sidebar.checkbox("do_sample", value=False)
+st.sidebar.title("设定调整")
+st.session_state.history_chat_num = st.sidebar.slider("携带历史对话条数", 0, 6, 0, step=2)
+st.session_state.max_new_tokens = st.sidebar.slider("最大输入/生成长度", 256, 768, 512, step=1)
+st.session_state.top_k = st.sidebar.slider("top_k", 0, 16, 14, step=1)
+st.session_state.temperature = st.sidebar.slider("temperature", 0.3, 1.3, 0.5, step=0.01)
+
+
+def setup_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
 
 
 def main():
@@ -65,32 +75,30 @@ def main():
         with st.chat_message("user", avatar='🧑‍💻'):
             st.markdown(prompt)
         messages.append({"role": "user", "content": prompt})
+        st.session_state.chat_messages.append({"role": "user", "content": '请问,' + prompt + '?'})
 
         with st.chat_message("assistant", avatar='🤖'):
             placeholder = st.empty()
+            # Generate a random seed
+            random_seed = random.randint(0, 2 ** 32 - 1)
+            setup_seed(random_seed)
 
-            chat_messages = []
-            chat_messages.append({"role": "user", "content": '请问,' + prompt})
-            # print(messages)
             new_prompt = tokenizer.apply_chat_template(
-                chat_messages,
+                st.session_state.chat_messages[-(st.session_state.history_chat_num + 1):],
                 tokenize=False,
                 add_generation_prompt=True
-            )[-(max_seq_len - 1):]
+            )[-(st.session_state.max_new_tokens - 1):]
 
             x = tokenizer(new_prompt).data['input_ids']
             x = (torch.tensor(x, dtype=torch.long)[None, ...])
-
-            response = ''
-            with torch.no_grad():
-                res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=max_seq_len, temperature=temperature,
-                                       top_k=top_k, stream=True)
+            res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=st.session_state.max_new_tokens,
+                                   temperature=st.session_state.temperature,
+                                   top_k=st.session_state.top_k, stream=True)
             try:
                 y = next(res_y)
             except StopIteration:
                 return
 
-            history_idx = 0
             while y != None:
                 answer = tokenizer.decode(y[0].tolist())
                 if answer and answer[-1] == '�':
@@ -99,7 +107,6 @@ def main():
                     except:
                         break
                     continue
-                # print(answer)
                 if not len(answer):
                     try:
                         y = next(res_y)
@@ -107,17 +114,14 @@ def main():
                     except:
                         break
                     continue
                 placeholder.markdown(answer)
-                response = answer
                 try:
                     y = next(res_y)
                 except:
                     break
 
-            # if contain_history_chat:
-            #     assistant_answer = answer.replace(new_prompt, "")
-            #     messages.append({"role": "assistant", "content": assistant_answer})
-
-            messages.append({"role": "assistant", "content": response})
+            assistant_answer = answer.replace(new_prompt, "")
+            messages.append({"role": "assistant", "content": assistant_answer})
+            st.session_state.chat_messages.append({"role": "assistant", "content": assistant_answer})
 
     st.button("清空对话", on_click=clear_chat_messages)
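
Note on the history window introduced by this patch: apply_chat_template now receives only the tail of st.session_state.chat_messages, so the model sees at most history_chat_num earlier messages plus the newest user turn. Below is a minimal standalone sketch of that slicing; the function name window_messages and the sample data are illustrative only, not part of the repository.

    # Illustrative sketch of the slicing used above:
    # chat_messages[-(history_chat_num + 1):] keeps the last
    # `history_chat_num` earlier messages plus the newest user turn.

    def window_messages(chat_messages, history_chat_num):
        return chat_messages[-(history_chat_num + 1):]

    chat = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hi!"},
        {"role": "user", "content": "what is MiniMind?"},
    ]

    print(window_messages(chat, 0))  # only the newest user message
    print(window_messages(chat, 2))  # previous user/assistant pair carried along

Because the "携带历史对话条数" slider steps by 2, the window grows in whole user/assistant pairs, and the final string slice [-(st.session_state.max_new_tokens - 1):] additionally caps the rendered prompt length in characters.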