update inference

jingyaogong 2025-04-05 12:03:04 +08:00
parent 7fcc46b39a
commit ed01c5d84a
2 changed files with 20 additions and 11 deletions

View File

@@ -120,7 +120,7 @@ def main():
     # history_cnt must be even, i.e. one [user question, model answer] pair per group; 0 means the current query carries no prior history
     # a model without length-extrapolation fine-tuning will degrade noticeably on longer chat_template contexts, so set this with care
     parser.add_argument('--history_cnt', default=0, type=int)
-    parser.add_argument('--stream', default=True, type=bool)
+    parser.add_argument('--stream', default=False, type=bool)
     parser.add_argument('--load', default=0, type=int, help="0: native torch weights, 1: load via transformers")
     parser.add_argument('--model_mode', default=1, type=int,
                         help="0: pretrained model, 1: SFT-Chat model, 2: RLHF-Chat model, 3: Reason model")
@@ -154,7 +154,7 @@ def main():
             max_new_tokens=args.max_seq_len,
             temperature=args.temperature,
             top_p=args.top_p,
-            stream=True,
+            stream=args.stream,
             pad_token_id=tokenizer.pad_token_id
         )
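Not part of the diff: with this change the call site respects the --stream flag instead of always streaming, and the return type of generate() differs between the two modes (see the model-side diff below): the streaming branch returns a generator, the non-streaming branch a padded tensor. A rough sketch of consuming both, assuming `model`, `tokenizer`, `args`, and a (1, seq_len) `input_ids` prompt tensor as set up elsewhere in this script:

    out = model.generate(
        input_ids,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=args.max_seq_len,
        temperature=args.temperature,
        top_p=args.top_p,
        stream=args.stream,
        pad_token_id=tokenizer.pad_token_id,
    )
    if args.stream:
        # streaming: out is a generator; each item holds the tokens generated so far
        answer = ""
        for tokens in out:
            answer = tokenizer.decode(tokens[0].tolist())
        print(answer)
    else:
        # non-streaming: a single tensor of shape (batch, padded_seq_len)
        print(tokenizer.decode(out[0].tolist()))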

View File

@@ -4,7 +4,7 @@ import inspect
 import time
 from .LMConfig import LMConfig
-from typing import Any, Optional, Tuple, List
+from typing import Any, Optional, Tuple, List, Union
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -307,6 +307,7 @@ class MiniMindLM(PreTrainedModel):
                 input_ids: Optional[torch.Tensor] = None,
                 past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                 use_cache: bool = False,
+                logits_to_keep: Union[int, torch.Tensor] = 0,
                 **args):
         past_key_values = past_key_values or [None] * len(self.layers)
         start_pos = args.get('start_pos', 0)
@@ -320,7 +321,9 @@ class MiniMindLM(PreTrainedModel):
                 use_cache=use_cache
             )
             past_kvs.append(past_kv)
-        logits = self.output(self.norm(h))
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.output(self.norm(h)[:, slice_indices, :])
         aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
         self.OUT.__setitem__('logits', logits)
         self.OUT.__setitem__('aux_loss', aux_loss)
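Not part of the diff: the new logits_to_keep argument restricts which sequence positions are projected to logits. An int N keeps the last N positions (0 keeps all of them, since slice(0, None) spans the whole axis), while a tensor is used directly as an index along the sequence axis. A small self-contained sketch of that slicing rule on a dummy hidden-state tensor:

    import torch

    h = torch.randn(2, 8, 16)    # dummy hidden states: (batch=2, seq_len=8, dim=16)

    def keep(hidden, logits_to_keep):
        # same rule as in forward(): an int keeps the last N positions,
        # a tensor selects explicit positions along the sequence axis
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        return hidden[:, slice_indices, :]

    print(keep(h, 0).shape)                      # torch.Size([2, 8, 16]) -> all positions
    print(keep(h, 1).shape)                      # torch.Size([2, 1, 16]) -> last position only
    print(keep(h, torch.tensor([0, 7])).shape)   # torch.Size([2, 2, 16]) -> selected positions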
@@ -329,7 +332,7 @@ class MiniMindLM(PreTrainedModel):
     @torch.inference_mode()
     def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
-                 stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
+                 stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
         # streaming generation
         if stream:
             return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
@@ -338,11 +341,13 @@ class MiniMindLM(PreTrainedModel):
         generated = []
         for i in range(input_ids.size(0)):
             non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
-            out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
-            tokens_list = [tokens[:, -1:] for tokens in out]
-            gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
-            full_sequence = torch.cat([non_pad, gen], dim=-1)
-            generated.append(full_sequence)
+            for _ in range(num_return_sequences):
+                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
+                tokens_list = [tokens[:, -1:] for tokens in out]
+                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
+                full_sequence = torch.cat([non_pad, gen], dim=-1)
+                generated.append(full_sequence)
         max_length = max(seq.size(1) for seq in generated)
         generated = [
             torch.cat(
@@ -350,7 +355,11 @@ class MiniMindLM(PreTrainedModel):
                 dim=-1)
             for seq in generated
         ]
-        return torch.cat(generated, dim=0)
+        output = torch.cat(generated, dim=0)
+        res = output.view(input_ids.size(0), num_return_sequences, -1)
+        res = res.squeeze(0) if input_ids.size(0) == 1 else res
+        res = res.squeeze(1) if num_return_sequences == 1 else res
+        return res

     def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
         start, first_seq, past_kvs = input_ids.shape[1], True, None
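Not part of the diff: with num_return_sequences the non-streaming branch re-samples each prompt that many times, pads all sequences to a common length, and reshapes the result to (batch, num_return_sequences, seq_len), squeezing the batch dim when it is 1 and the sample dim when only one sample is requested. A usage sketch, assuming a loaded MiniMindLM `model` and its Hugging Face `tokenizer`; the prompt text is arbitrary:

    prompt = tokenizer("hello", return_tensors="pt")["input_ids"]   # shape (1, prompt_len)
    out = model.generate(
        prompt,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=64,
        stream=False,                       # only the non-streaming branch batches samples
        pad_token_id=tokenizer.pad_token_id,
        num_return_sequences=3,
    )
    # batch size 1 with 3 samples: the leading dim is squeezed, so out is (3, seq_len);
    # each row is the pad-stripped prompt plus one sampled continuation, padded to the longest sample
    print(out.shape)
    for seq in out:
        print(tokenizer.decode(seq.tolist()))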