From 844e79148ca5cd80470b41408f6a391e6e6d6d9a Mon Sep 17 00:00:00 2001
From: gongjy <2474590974@qq.com>
Date: Sat, 15 Feb 2025 23:56:09 +0800
Subject: [PATCH] update generate args

---
 model/model.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/model/model.py b/model/model.py
index f6005b3..1dc6478 100644
--- a/model/model.py
+++ b/model/model.py
@@ -329,15 +329,13 @@ class MiniMindLM(PreTrainedModel):
                  stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
         # streaming generation
         if stream:
-            return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache,
-                                         **args)
+            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
 
         # direct (non-streaming) generation
         generated = []
         for i in range(input_ids.size(0)):
             non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
-            out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache,
-                                        **args)
+            out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
             tokens_list = [tokens[:, -1:] for tokens in out]
             gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
             full_sequence = torch.cat([non_pad, gen], dim=-1)
@@ -351,7 +349,7 @@ class MiniMindLM(PreTrainedModel):
         ]
         return torch.cat(generated, dim=0)
 
-    def _generate_stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
+    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
         start, first_seq, past_kvs = input_ids.shape[1], True, None
         while input_ids.shape[1] < max_new_tokens - 1:
             if first_seq or not use_cache:
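
For context, a minimal usage sketch of the two code paths touched by this patch. The model/tokenizer setup and the argument values below are illustrative assumptions, not part of the patch itself:

    import torch
    from model.model import MiniMindLM

    # Hypothetical setup: `config` stands in for a real model config and the
    # prompt ids are toy values -- neither comes from this patch.
    model = MiniMindLM(config)
    prompt = torch.tensor([[1, 2, 3]])

    # Non-streaming path: generate() strips pad tokens per sample and drives
    # the renamed _stream() helper to completion internally.
    out = model.generate(prompt, eos_token_id=2, max_new_tokens=64,
                         temperature=0.75, top_p=0.9, rp=1.,
                         use_cache=True, pad_token_id=0)

    # Streaming path: generate(stream=True) now returns _stream()'s generator
    # directly, yielding the growing sequence step by step.
    for tokens in model.generate(prompt, eos_token_id=2, max_new_tokens=64,
                                 temperature=0.75, top_p=0.9, stream=True):
        print(tokens[:, -1:])  # newest token at each step, as in the diff body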