update my_openai_api.py
This commit is contained in:
parent
13e791e516
commit
b043ec996b
@ -24,9 +24,9 @@ DEVICE_NAME = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
DEVICE = torch.device(DEVICE_NAME)
|
||||
MODEL_PATH = "./minimind-small-T"
|
||||
TOKENIZE_PATH = MODEL_PATH
|
||||
max_new_tokens = 2048
|
||||
max_new_tokens = 1024
|
||||
temperature = 0.7
|
||||
top_k = 8
|
||||
top_k = 16
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------------------------------------------
|
||||
@ -61,7 +61,11 @@ class Transformers():
|
||||
res_y = self.model.generate(input_ids, tokenizer.eos_token_id, max_new_tokens=max_new_tokens,
|
||||
temperature=temperature, top_k=top_k, stream=True)
|
||||
|
||||
y = next(res_y)
|
||||
try:
|
||||
y = next(res_y)
|
||||
except:
|
||||
print("No answer")
|
||||
return 'No answer'
|
||||
|
||||
history_idx = 0
|
||||
while y != None:
|
||||
@ -92,7 +96,7 @@ class Transformers():
|
||||
def chat_no_stream(self, tokenizer, messages: List[dict]):
|
||||
input_ids, eos_token_id, new_prompt = self.build_chat_input(tokenizer, messages)
|
||||
res_y = self.model.generate(input_ids, tokenizer.eos_token_id, max_new_tokens=max_new_tokens,
|
||||
temperature=temperature, top_k=top_k, stream=False)
|
||||
temperature=temperature, top_k=top_k, stream=False)
|
||||
y = next(res_y)
|
||||
answer = tokenizer.decode(y[0].tolist())
|
||||
return answer
|
||||
|
Loading…
x
Reference in New Issue
Block a user