update inference

jingyaogong 2025-04-05 12:03:04 +08:00
parent 7fcc46b39a
commit ed01c5d84a
2 changed files with 20 additions and 11 deletions


@@ -120,7 +120,7 @@ def main():
# history_cnt must be set to an even number, since one [user question, model answer] pair counts as one group; setting it to 0 means the current query carries no conversation history
# A model without length-extrapolation fine-tuning will inevitably show clear performance degradation on longer chat_template contexts, so set this with care
parser.add_argument('--history_cnt', default=0, type=int)
- parser.add_argument('--stream', default=True, type=bool)
+ parser.add_argument('--stream', default=False, type=bool)
parser.add_argument('--load', default=0, type=int, help="0: native torch weights, 1: load via transformers")
parser.add_argument('--model_mode', default=1, type=int,
help="0: 预训练模型1: SFT-Chat模型2: RLHF-Chat模型3: Reason模型")
@@ -154,7 +154,7 @@ def main():
max_new_tokens=args.max_seq_len,
temperature=args.temperature,
top_p=args.top_p,
- stream=True,
+ stream=args.stream,
pad_token_id=tokenizer.pad_token_id
)
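
Since `generate` is now called with `stream=args.stream` rather than a hard-coded `True`, the caller has to handle both return types. A minimal sketch of that dispatch, assuming `model`, `tokenizer`, `input_ids`, and `args` are already set up as in the script above (the decode/print lines are illustrative, not the script's exact output logic):

```python
out = model.generate(
    input_ids,
    max_new_tokens=args.max_seq_len,
    temperature=args.temperature,
    top_p=args.top_p,
    stream=args.stream,                  # False by default after this commit
    pad_token_id=tokenizer.pad_token_id,
)

if args.stream:
    # stream=True returns a generator; each yielded tensor's last column is the newest token
    for tokens in out:
        print(tokenizer.decode(tokens[0, -1:]), end="", flush=True)
else:
    # stream=False returns one tensor of finished sequences (prompt + generated tokens)
    print(tokenizer.decode(out[0]))
```

Note that argparse's `type=bool` converts any non-empty string to `True`, so passing `--stream False` on the command line still enables streaming; only omitting the flag picks up the new `False` default.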


@@ -4,7 +4,7 @@ import inspect
import time
from .LMConfig import LMConfig
- from typing import Any, Optional, Tuple, List
+ from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
@@ -307,6 +307,7 @@ class MiniMindLM(PreTrainedModel):
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
**args):
past_key_values = past_key_values or [None] * len(self.layers)
start_pos = args.get('start_pos', 0)
@@ -320,7 +321,9 @@ class MiniMindLM(PreTrainedModel):
use_cache=use_cache
)
past_kvs.append(past_kv)
- logits = self.output(self.norm(h))
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
self.OUT.__setitem__('logits', logits)
self.OUT.__setitem__('aux_loss', aux_loss)
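
The new `logits_to_keep` argument follows the convention of an integer `n` keeping only the last `n` positions (with `0` keeping all of them, since `slice(-0, None)` is just `slice(0, None)`), while a tensor is used directly as position indices. A small self-contained sketch of the same selection rule on a dummy hidden-state tensor, independent of the model code:

```python
import torch

h = torch.randn(2, 10, 512)  # (batch, seq_len, hidden) dummy hidden states

def take_positions(h, logits_to_keep):
    # identical selection logic to the diff above
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return h[:, slice_indices, :]

print(take_positions(h, 0).shape)                        # torch.Size([2, 10, 512]) -> keep everything
print(take_positions(h, 1).shape)                        # torch.Size([2, 1, 512])  -> last position only
print(take_positions(h, torch.tensor([0, 4, 9])).shape)  # torch.Size([2, 3, 512])  -> explicit indices
```

Slicing before the output projection means callers that only need the next-token distribution (or the label positions) avoid projecting every hidden state through `self.output`.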
@@ -329,7 +332,7 @@ class MiniMindLM(PreTrainedModel):
@torch.inference_mode()
def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
- stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
+ stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
# streaming generation
if stream:
return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
@@ -338,11 +341,13 @@ class MiniMindLM(PreTrainedModel):
generated = []
for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
- out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
- tokens_list = [tokens[:, -1:] for tokens in out]
- gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
- full_sequence = torch.cat([non_pad, gen], dim=-1)
- generated.append(full_sequence)
+ for _ in range(num_return_sequences):
+     out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
+     tokens_list = [tokens[:, -1:] for tokens in out]
+     gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
+     full_sequence = torch.cat([non_pad, gen], dim=-1)
+     generated.append(full_sequence)
max_length = max(seq.size(1) for seq in generated)
generated = [
torch.cat(
@@ -350,7 +355,11 @@ class MiniMindLM(PreTrainedModel):
dim=-1)
for seq in generated
]
- return torch.cat(generated, dim=0)
+ output = torch.cat(generated, dim=0)
+ res = output.view(input_ids.size(0), num_return_sequences, -1)
+ res = res.squeeze(0) if input_ids.size(0) == 1 else res
+ res = res.squeeze(1) if num_return_sequences == 1 else res
+ return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None
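
With the reshape-and-squeeze logic above, the shape of the non-streaming result depends on both the batch size and `num_return_sequences`. A rough usage sketch, assuming `model` is an already-loaded `MiniMindLM` (the prompt token ids are made up for illustration):

```python
import torch

input_ids = torch.tensor([[1, 42, 7]])  # a single made-up prompt, shape (1, 3)

out = model.generate(input_ids, max_new_tokens=16, stream=False)
print(out.shape)   # (1, prompt_len + generated_len): one prompt, one sequence

out = model.generate(input_ids, max_new_tokens=16, stream=False, num_return_sequences=3)
print(out.shape)   # (3, prompt_len + generated_len): batch dim squeezed away, one row per candidate
```

Because the loop simply reruns `_stream` on the same prompt, the extra candidates only differ through the stochastic temperature/top_p sampling, and all rows are padded to the length of the longest sequence in the call so they can be concatenated.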