From ed01c5d84ad84c81157245920f4784accf8b540b Mon Sep 17 00:00:00 2001
From: jingyaogong
Date: Sat, 5 Apr 2025 12:03:04 +0800
Subject: [PATCH] update inference

---
 eval_model.py  |  4 ++--
 model/model.py | 27 ++++++++++++++++++---------
 2 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/eval_model.py b/eval_model.py
index fc71257..f67de60 100644
--- a/eval_model.py
+++ b/eval_model.py
@@ -120,7 +120,7 @@ def main():
     # history_cnt需要设为偶数,即【用户问题, 模型回答】为1组;设置为0时,即当前query不携带历史上文
     # 模型未经过外推微调时,在更长的上下文的chat_template时难免出现性能的明显退化,因此需要注意此处设置
     parser.add_argument('--history_cnt', default=0, type=int)
-    parser.add_argument('--stream', default=True, type=bool)
+    parser.add_argument('--stream', default=False, type=bool)
     parser.add_argument('--load', default=0, type=int, help="0: 原生torch权重,1: transformers加载")
     parser.add_argument('--model_mode', default=1, type=int,
                         help="0: 预训练模型,1: SFT-Chat模型,2: RLHF-Chat模型,3: Reason模型")
@@ -154,7 +154,7 @@ def main():
                 max_new_tokens=args.max_seq_len,
                 temperature=args.temperature,
                 top_p=args.top_p,
-                stream=True,
+                stream=args.stream,
                 pad_token_id=tokenizer.pad_token_id
             )
 
diff --git a/model/model.py b/model/model.py
index 6316a04..2a0ba39 100644
--- a/model/model.py
+++ b/model/model.py
@@ -4,7 +4,7 @@ import inspect
 import time
 
 from .LMConfig import LMConfig
-from typing import Any, Optional, Tuple, List
+from typing import Any, Optional, Tuple, List, Union
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -307,6 +307,7 @@ class MiniMindLM(PreTrainedModel):
                 input_ids: Optional[torch.Tensor] = None,
                 past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                 use_cache: bool = False,
+                logits_to_keep: Union[int, torch.Tensor] = 0,
                 **args):
         past_key_values = past_key_values or [None] * len(self.layers)
         start_pos = args.get('start_pos', 0)
@@ -320,7 +321,9 @@
                 use_cache=use_cache
             )
             past_kvs.append(past_kv)
-        logits = self.output(self.norm(h))
+
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.output(self.norm(h)[:, slice_indices, :])
         aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
         self.OUT.__setitem__('logits', logits)
         self.OUT.__setitem__('aux_loss', aux_loss)
@@ -329,7 +332,7 @@
 
     @torch.inference_mode()
     def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
-                 stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
+                 stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
         # 流式生成
         if stream:
             return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
@@ -338,11 +341,13 @@ # 直接生成
         generated = []
         for i in range(input_ids.size(0)):
             non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
-            out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
-            tokens_list = [tokens[:, -1:] for tokens in out]
-            gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
-            full_sequence = torch.cat([non_pad, gen], dim=-1)
-            generated.append(full_sequence)
+            for _ in range(num_return_sequences):
+                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
+                tokens_list = [tokens[:, -1:] for tokens in out]
+                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
+                full_sequence = torch.cat([non_pad, gen], dim=-1)
+                generated.append(full_sequence)
+
         max_length = max(seq.size(1) for seq in generated)
         generated = [
             torch.cat(
@@ -350,7 +355,11 @@
                 dim=-1)
             for seq in generated
         ]
-        return torch.cat(generated, dim=0)
+        output = torch.cat(generated, dim=0)
+        res = output.view(input_ids.size(0), num_return_sequences, -1)
+        res = res.squeeze(0) if input_ids.size(0) == 1 else res
+        res = res.squeeze(1) if num_return_sequences == 1 else res
+        return res
 
     def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
         start, first_seq, past_kvs = input_ids.shape[1], True, None
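
Note on the new logits_to_keep argument: the following is a standalone sketch (not part of the patch) of the slicing semantics introduced in MiniMindLM.forward above. With an int, the default 0 keeps logits for every position and N > 0 keeps only the last N positions; with a tensor, the given positions are selected along the sequence dimension. The dummy tensor shapes are illustrative only.

import torch

h = torch.randn(2, 10, 512)  # dummy (batch, seq_len, hidden) hidden states

for logits_to_keep in (0, 1, torch.tensor([0, 4, 9])):
    # Same expression as in the patched forward():
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    print(h[:, slice_indices, :].shape)
# torch.Size([2, 10, 512])  -> 0 keeps every position
# torch.Size([2, 1, 512])   -> 1 keeps only the final position
# torch.Size([2, 3, 512])   -> a tensor selects explicit positions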
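
For completeness, a hedged usage sketch of the updated non-stream generate() path with num_return_sequences. This is not part of the patch: it assumes MiniMindLM can be constructed from a default LMConfig, uses random weights and dummy token ids purely to show the returned shape, and the real entry point in the repo remains eval_model.py.

import torch

from model.LMConfig import LMConfig
from model.model import MiniMindLM

# Randomly initialised model -- assumed default LMConfig; no checkpoint is loaded.
model = MiniMindLM(LMConfig()).eval()
prompt = torch.randint(3, 100, (1, 8))  # (batch=1, prompt_len=8) dummy token ids

out = model.generate(
    prompt,
    eos_token_id=2,
    max_new_tokens=16,
    temperature=0.85,
    top_p=0.85,
    stream=False,
    pad_token_id=0,
    num_return_sequences=3,
)

# With a single prompt the batch dimension is squeezed away, so the result is
# (num_return_sequences, prompt_len + generated_len), e.g. torch.Size([3, 24])
# when no sequence stops early at eos_token_id.
print(out.shape)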