update my_openai_api.py
commit b043ec996b
parent 13e791e516
```diff
@@ -24,9 +24,9 @@ DEVICE_NAME = "cuda:0" if torch.cuda.is_available() else "cpu"
 DEVICE = torch.device(DEVICE_NAME)
 MODEL_PATH = "./minimind-small-T"
 TOKENIZE_PATH = MODEL_PATH
-max_new_tokens = 2048
+max_new_tokens = 1024
 temperature = 0.7
-top_k = 8
+top_k = 16


 # ------------------------------------------------------------------------------------------------------------------
```
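The generation defaults change here: `max_new_tokens` drops from 2048 to 1024 and `top_k` rises from 8 to 16, i.e. shorter completions but a wider candidate pool at each sampling step. For context, top-k sampling keeps only the k highest-probability tokens before drawing. A minimal sketch of the idea (illustrative only; minimind's actual `generate()` may implement it differently):

```python
import torch

def sample_top_k(logits: torch.Tensor, top_k: int = 16, temperature: float = 0.7) -> torch.Tensor:
    """Draw one token id from `logits`, restricted to the top_k candidates.

    Illustrative sketch only -- not minimind's actual sampling code.
    """
    logits = logits / temperature
    # Threshold at the k-th largest logit; everything below it is masked out.
    kth_value = torch.topk(logits, top_k).values[..., -1, None]
    logits = logits.masked_fill(logits < kth_value, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)
```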
```diff
@@ -61,7 +61,11 @@ class Transformers():
         res_y = self.model.generate(input_ids, tokenizer.eos_token_id, max_new_tokens=max_new_tokens,
                                     temperature=temperature, top_k=top_k, stream=True)

-        y = next(res_y)
+        try:
+            y = next(res_y)
+        except:
+            print("No answer")
+            return 'No answer'

         history_idx = 0
         while y != None:
```
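The `stream=True` path returns a generator, so `next(res_y)` raises `StopIteration` when the model yields nothing; the commit wraps that first pull in a `try`/`except` so the streaming chat returns `'No answer'` instead of crashing. The bare `except:` also swallows unrelated errors, though. A narrower variant (hypothetical helper, not part of the commit) would catch only the empty-generator case:

```python
def first_chunk(res_y):
    """Return the first chunk from a generation stream, or None if it is empty.

    Hypothetical helper; it narrows the commit's bare `except:` to
    StopIteration so genuine errors still propagate.
    """
    try:
        return next(res_y)
    except StopIteration:
        print("No answer")
        return None
```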
```diff
@@ -92,7 +96,7 @@ class Transformers():
     def chat_no_stream(self, tokenizer, messages: List[dict]):
         input_ids, eos_token_id, new_prompt = self.build_chat_input(tokenizer, messages)
         res_y = self.model.generate(input_ids, tokenizer.eos_token_id, max_new_tokens=max_new_tokens,
                                     temperature=temperature, top_k=top_k, stream=False)
         y = next(res_y)
         answer = tokenizer.decode(y[0].tolist())
         return answer
```
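For reference, the non-streaming path can be exercised roughly like this (hypothetical usage; how `bot` is constructed is not shown in this commit, and the tokenizer setup is an assumption):

```python
from transformers import AutoTokenizer

# Hypothetical driver code -- assumes `bot` is an initialized Transformers
# wrapper from my_openai_api.py; its constructor is not shown in this commit.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZE_PATH)

messages = [{"role": "user", "content": "Hello, who are you?"}]
answer = bot.chat_no_stream(tokenizer, messages)
print(answer)
```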