From 19b388cd8718e5a447933c2cfc7954cc135bd3e6 Mon Sep 17 00:00:00 2001
From: gongjy <2474590974@qq.com>
Date: Sat, 15 Feb 2025 23:55:10 +0800
Subject: [PATCH] update generate args

---
 model/model.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/model/model.py b/model/model.py
index 3f49236..f6005b3 100644
--- a/model/model.py
+++ b/model/model.py
@@ -329,13 +329,15 @@ class MiniMindLM(PreTrainedModel):
                  stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
         # Streaming generation
         if stream:
-            return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
+            return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache,
+                                         **args)
 
         # Direct (non-streaming) generation
         generated = []
         for i in range(input_ids.size(0)):
             non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
-            out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
+            out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache,
+                                        **args)
             tokens_list = [tokens[:, -1:] for tokens in out]
             gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
             full_sequence = torch.cat([non_pad, gen], dim=-1)
@@ -353,10 +355,10 @@ class MiniMindLM(PreTrainedModel):
         start, first_seq, past_kvs = input_ids.shape[1], True, None
         while input_ids.shape[1] < max_new_tokens - 1:
             if first_seq or not use_cache:
-                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache), False
+                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
             else:
                 out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
-                           start_pos=input_ids.shape[1] - 1)
+                           start_pos=input_ids.shape[1] - 1, **args)
             logits, past_kvs = out.logits[:, -1, :], out.past_key_values
             logits[:, list(set(input_ids.tolist()[0]))] /= rp
             logits /= (temperature + 1e-9)
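
For reference, a minimal usage sketch of the patched call path. The LMConfig import path, prompt token ids, and sampling values below are assumptions for illustration (following the repository's model/ layout), not part of the patch; any extra keyword passed through **args must be one that the model's forward() actually accepts.

import torch
from model.model import MiniMindLM
from model.LMConfig import LMConfig  # assumed config location, per the repo's model/ layout

# Hypothetical setup: a small randomly initialized model, for illustration only.
model = MiniMindLM(LMConfig()).eval()

prompt = torch.tensor([[1, 2, 3]])  # toy prompt token ids
with torch.no_grad():
    # After this patch, extra keyword arguments (**args) are forwarded through
    # generate() and _generate_stream() into every forward() call, so a kwarg
    # understood by forward() can now be supplied at generation time.
    out = model.generate(prompt,
                         eos_token_id=2,
                         max_new_tokens=32,
                         temperature=0.8,
                         top_p=0.9,
                         stream=False,
                         use_cache=True)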