update generate args

This commit is contained in:
gongjy 2025-02-15 23:56:09 +08:00
parent 19b388cd87
commit 844e79148c

View File

@ -329,15 +329,13 @@ class MiniMindLM(PreTrainedModel):
stream=False, rp=1., use_cache=True, pad_token_id=0, **args): stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
# 流式生成 # 流式生成
if stream: if stream:
return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
**args)
# 直接生成 # 直接生成
generated = [] generated = []
for i in range(input_ids.size(0)): for i in range(input_ids.size(0)):
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0) non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
**args)
tokens_list = [tokens[:, -1:] for tokens in out] tokens_list = [tokens[:, -1:] for tokens in out]
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
full_sequence = torch.cat([non_pad, gen], dim=-1) full_sequence = torch.cat([non_pad, gen], dim=-1)
@ -351,7 +349,7 @@ class MiniMindLM(PreTrainedModel):
] ]
return torch.cat(generated, dim=0) return torch.cat(generated, dim=0)
def _generate_stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args): def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
start, first_seq, past_kvs = input_ids.shape[1], True, None start, first_seq, past_kvs = input_ids.shape[1], True, None
while input_ids.shape[1] < max_new_tokens - 1: while input_ids.shape[1] < max_new_tokens - 1:
if first_seq or not use_cache: if first_seq or not use_cache: