update inference

parent 7fcc46b39a
commit ed01c5d84a
@@ -120,7 +120,7 @@ def main():
     # history_cnt must be an even number, i.e. one [user question, model answer] pair counts as one group; when set to 0, the current query carries no conversation history
     # a model without length-extrapolation fine-tuning will inevitably show a clear performance drop once the chat_template context grows longer, so set this carefully
     parser.add_argument('--history_cnt', default=0, type=int)
-    parser.add_argument('--stream', default=True, type=bool)
+    parser.add_argument('--stream', default=False, type=bool)
     parser.add_argument('--load', default=0, type=int, help="0: native torch weights, 1: load via transformers")
     parser.add_argument('--model_mode', default=1, type=int,
                         help="0: pretrained model, 1: SFT-Chat model, 2: RLHF-Chat model, 3: Reason model")
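As a rough illustration of the --history_cnt semantics described in the comments above (the helper name build_messages and the message format are assumptions for illustration, not code from this commit): an even history_cnt keeps history_cnt // 2 complete [user, assistant] pairs in front of the current query.

# Hypothetical sketch only; build_messages/history are illustrative names.
def build_messages(history, query, history_cnt=0):
    # history is a flat, chronological list of chat turns:
    # [{"role": "user", ...}, {"role": "assistant", ...}, ...]
    kept = history[-history_cnt:] if history_cnt > 0 else []
    return kept + [{"role": "user", "content": query}]

# history_cnt=2 keeps exactly one prior [user, assistant] pair; 0 sends the query alone.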
@@ -154,7 +154,7 @@ def main():
                 max_new_tokens=args.max_seq_len,
                 temperature=args.temperature,
                 top_p=args.top_p,
-                stream=True,
+                stream=args.stream,
                 pad_token_id=tokenizer.pad_token_id
             )
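Since stream is now taken from the CLI flag instead of being hard-coded, the caller has to handle both return types of generate(): a generator of token tensors when streaming, a tensor of finished sequences otherwise. A hedged sketch of that dispatch (model, tokenizer, input_ids and args are assumed to exist in the surrounding script):

# Illustrative only; mirrors the call shown in the hunk above.
res = model.generate(input_ids,
                     eos_token_id=tokenizer.eos_token_id,
                     max_new_tokens=args.max_seq_len,
                     temperature=args.temperature,
                     top_p=args.top_p,
                     stream=args.stream,
                     pad_token_id=tokenizer.pad_token_id)
if args.stream:
    # streaming: iterate a generator that yields the tokens generated so far
    for tokens in res:
        print(tokenizer.decode(tokens[0], skip_special_tokens=True), end='\r')
    print()
else:
    # non-streaming: a single tensor holding prompt + completion
    print(tokenizer.decode(res[0], skip_special_tokens=True))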
@@ -4,7 +4,7 @@ import inspect
 import time

 from .LMConfig import LMConfig
-from typing import Any, Optional, Tuple, List
+from typing import Any, Optional, Tuple, List, Union
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -307,6 +307,7 @@ class MiniMindLM(PreTrainedModel):
                 input_ids: Optional[torch.Tensor] = None,
                 past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                 use_cache: bool = False,
+                logits_to_keep: Union[int, torch.Tensor] = 0,
                 **args):
         past_key_values = past_key_values or [None] * len(self.layers)
         start_pos = args.get('start_pos', 0)
@@ -320,7 +321,9 @@ class MiniMindLM(PreTrainedModel):
                 use_cache=use_cache
             )
             past_kvs.append(past_kv)
-        logits = self.output(self.norm(h))
+
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.output(self.norm(h)[:, slice_indices, :])
         aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
         self.OUT.__setitem__('logits', logits)
         self.OUT.__setitem__('aux_loss', aux_loss)
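The new logits_to_keep argument appears to follow the transformers-style convention of projecting only the last few hidden states through the output head. With the default 0, slice(-0, None) equals slice(0, None), so every position is kept; with 1, only the final position is projected, which is all that incremental decoding needs; a tensor is used directly as a positional index. A standalone sketch of the slicing rule (tensor shapes are made up for illustration):

import torch

h = torch.randn(2, 16, 512)  # (batch, seq_len, hidden) hidden states, illustrative only

def take(h, logits_to_keep=0):
    # same rule as in the diff: an int keeps the last N positions (0 keeps all),
    # a tensor is used directly as an index along the sequence dimension
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return h[:, slice_indices, :]

print(take(h, 0).shape)                      # torch.Size([2, 16, 512]) -- slice(-0, None) keeps every position
print(take(h, 1).shape)                      # torch.Size([2, 1, 512])  -- only the last position
print(take(h, torch.tensor([0, 15])).shape)  # torch.Size([2, 2, 512])  -- explicit positions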
@@ -329,7 +332,7 @@ class MiniMindLM(PreTrainedModel):

     @torch.inference_mode()
     def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
-                 stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
+                 stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
         # streaming generation
         if stream:
             return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
@@ -338,11 +341,13 @@ class MiniMindLM(PreTrainedModel):
         generated = []
         for i in range(input_ids.size(0)):
             non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
+            for _ in range(num_return_sequences):
                 out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
                 tokens_list = [tokens[:, -1:] for tokens in out]
                 gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
                 full_sequence = torch.cat([non_pad, gen], dim=-1)
                 generated.append(full_sequence)

         max_length = max(seq.size(1) for seq in generated)
         generated = [
             torch.cat(
@@ -350,7 +355,11 @@ class MiniMindLM(PreTrainedModel):
                 dim=-1)
             for seq in generated
         ]
-        return torch.cat(generated, dim=0)
+        output = torch.cat(generated, dim=0)
+        res = output.view(input_ids.size(0), num_return_sequences, -1)
+        res = res.squeeze(0) if input_ids.size(0) == 1 else res
+        res = res.squeeze(1) if num_return_sequences == 1 else res
+        return res

     def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
         start, first_seq, past_kvs = input_ids.shape[1], True, None
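A hedged usage sketch of the new num_return_sequences path (the model/tokenizer setup is assumed, not part of this diff): with a single prompt and num_return_sequences=2, the non-streaming call returns one padded row per sampled continuation, prompt tokens included.

# Illustrative only: assumes `model` is a loaded MiniMindLM and `tokenizer` its tokenizer.
prompt_ids = tokenizer("hello", return_tensors="pt")["input_ids"]
out = model.generate(prompt_ids,
                     eos_token_id=tokenizer.eos_token_id,
                     max_new_tokens=64,
                     temperature=0.75,
                     top_p=0.90,
                     stream=False,
                     pad_token_id=tokenizer.pad_token_id,
                     num_return_sequences=2)
# With batch size 1 and num_return_sequences=2, `out` has shape (2, seq_len).
for seq in out:
    print(tokenizer.decode(seq, skip_special_tokens=True))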