update inference
parent 7fcc46b39a
commit ed01c5d84a
@@ -120,7 +120,7 @@ def main():
     # history_cnt must be an even number, since one [user question, model answer] pair counts as one group; 0 means the current query carries no conversation history
     # a model that has not been fine-tuned for length extrapolation will inevitably degrade noticeably on longer chat_template contexts, so set this with care
     parser.add_argument('--history_cnt', default=0, type=int)
-    parser.add_argument('--stream', default=True, type=bool)
+    parser.add_argument('--stream', default=False, type=bool)
     parser.add_argument('--load', default=0, type=int, help="0: native torch weights, 1: load via transformers")
     parser.add_argument('--model_mode', default=1, type=int,
                         help="0: pretrained model, 1: SFT-Chat model, 2: RLHF-Chat model, 3: Reason model")
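The history_cnt comment describes the intended grouping, but this diff does not show how the script assembles the prompt. The helper below is only a hypothetical sketch of those semantics (build_messages, history, and the message dicts are invented names, not part of this commit):

# Hypothetical helper illustrating the history_cnt comment: keep only the last
# history_cnt messages (an even count = whole [user, assistant] pairs) before
# appending the current query; history_cnt == 0 drops all history.
def build_messages(history, query, history_cnt=0):
    kept = history[-history_cnt:] if history_cnt > 0 else []
    return kept + [{"role": "user", "content": query}]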
@@ -154,7 +154,7 @@ def main():
             max_new_tokens=args.max_seq_len,
             temperature=args.temperature,
             top_p=args.top_p,
-            stream=True,
+            stream=args.stream,
             pad_token_id=tokenizer.pad_token_id
         )
 
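With this change the call site honors --stream instead of always streaming. A minimal, hedged sketch of how the two return types could be consumed, assuming (as the generate hunks below indicate) that stream=True returns a generator yielding the tokens generated so far and stream=False returns a tensor of token ids; model stands in for the loaded model object and the decoding details are illustrative:

# Illustrative consumption of generate() for both --stream settings.
out = model.generate(input_ids, stream=args.stream, max_new_tokens=args.max_seq_len,
                     temperature=args.temperature, top_p=args.top_p,
                     pad_token_id=tokenizer.pad_token_id)
if args.stream:
    for seq in out:                                    # generator of generated-so-far tokens
        print(tokenizer.decode(seq[0, -1:].tolist()), end='', flush=True)
else:
    print(tokenizer.decode(out[0].tolist()))           # full sequence, prompt included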
@@ -4,7 +4,7 @@ import inspect
 import time
 
 from .LMConfig import LMConfig
-from typing import Any, Optional, Tuple, List
+from typing import Any, Optional, Tuple, List, Union
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -307,6 +307,7 @@ class MiniMindLM(PreTrainedModel):
                 input_ids: Optional[torch.Tensor] = None,
                 past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                 use_cache: bool = False,
+                logits_to_keep: Union[int, torch.Tensor] = 0,
                 **args):
         past_key_values = past_key_values or [None] * len(self.layers)
         start_pos = args.get('start_pos', 0)
@@ -320,7 +321,9 @@ class MiniMindLM(PreTrainedModel):
                 use_cache=use_cache
             )
             past_kvs.append(past_kv)
-        logits = self.output(self.norm(h))
+
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.output(self.norm(h)[:, slice_indices, :])
         aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
         self.OUT.__setitem__('logits', logits)
         self.OUT.__setitem__('aux_loss', aux_loss)
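The new logits_to_keep argument controls which positions get logits: an int k keeps only the last k positions (0 keeps all, because slice(-0, None) is slice(0, None)), while a tensor is used as an explicit positional index. A small self-contained sketch of that slicing, with illustrative shapes:

import torch

h = torch.randn(2, 8, 16)   # (batch, seq_len, hidden), illustrative only

def take(h, logits_to_keep):
    # same expression as in the hunk above
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return h[:, slice_indices, :]

print(take(h, 0).shape)                     # torch.Size([2, 8, 16]), every position
print(take(h, 1).shape)                     # torch.Size([2, 1, 16]), last position only
print(take(h, torch.tensor([0, 7])).shape)  # torch.Size([2, 2, 16]), explicit positions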
@@ -329,7 +332,7 @@ class MiniMindLM(PreTrainedModel):
 
     @torch.inference_mode()
     def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
-                 stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
+                 stream=False, rp=1., use_cache=True, pad_token_id=0, num_return_sequences=1, **args):
         # streaming generation
         if stream:
             return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
@@ -338,11 +341,13 @@ class MiniMindLM(PreTrainedModel):
         generated = []
         for i in range(input_ids.size(0)):
             non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
-            out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
-            tokens_list = [tokens[:, -1:] for tokens in out]
-            gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
-            full_sequence = torch.cat([non_pad, gen], dim=-1)
-            generated.append(full_sequence)
+            for _ in range(num_return_sequences):
+                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
+                tokens_list = [tokens[:, -1:] for tokens in out]
+                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
+                full_sequence = torch.cat([non_pad, gen], dim=-1)
+                generated.append(full_sequence)
+
         max_length = max(seq.size(1) for seq in generated)
         generated = [
             torch.cat(
@@ -350,7 +355,11 @@ class MiniMindLM(PreTrainedModel):
                 dim=-1)
             for seq in generated
         ]
-        return torch.cat(generated, dim=0)
+        output = torch.cat(generated, dim=0)
+        res = output.view(input_ids.size(0), num_return_sequences, -1)
+        res = res.squeeze(0) if input_ids.size(0) == 1 else res
+        res = res.squeeze(1) if num_return_sequences == 1 else res
+        return res
 
     def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
         start, first_seq, past_kvs = input_ids.shape[1], True, None
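After this reshape, the non-streaming output shape depends on both the batch size and num_return_sequences. A hedged usage sketch (model and input_ids stand in for a loaded MiniMindLM and a padded prompt batch; the shapes simply follow the view/squeeze logic above, with L the padded output length):

# batch=1, num_return_sequences=1 -> (1, L)     squeeze(0) applies, squeeze(1) is a no-op
# batch=1, num_return_sequences=3 -> (3, L)     squeeze(0) applies
# batch=4, num_return_sequences=1 -> (4, L)     squeeze(1) applies
# batch=4, num_return_sequences=3 -> (4, 3, L)
out = model.generate(input_ids, max_new_tokens=64, num_return_sequences=3, stream=False)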