update generate args
This commit is contained in:
parent
5b65bc767e
commit
19b388cd87
@ -329,13 +329,15 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
|
stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
|
||||||
# 流式生成
|
# 流式生成
|
||||||
if stream:
|
if stream:
|
||||||
return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
|
return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache,
|
||||||
|
**args)
|
||||||
|
|
||||||
# 直接生成
|
# 直接生成
|
||||||
generated = []
|
generated = []
|
||||||
for i in range(input_ids.size(0)):
|
for i in range(input_ids.size(0)):
|
||||||
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
|
non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
|
||||||
out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
|
out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache,
|
||||||
|
**args)
|
||||||
tokens_list = [tokens[:, -1:] for tokens in out]
|
tokens_list = [tokens[:, -1:] for tokens in out]
|
||||||
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
|
gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
|
||||||
full_sequence = torch.cat([non_pad, gen], dim=-1)
|
full_sequence = torch.cat([non_pad, gen], dim=-1)
|
||||||
@ -353,10 +355,10 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
start, first_seq, past_kvs = input_ids.shape[1], True, None
|
start, first_seq, past_kvs = input_ids.shape[1], True, None
|
||||||
while input_ids.shape[1] < max_new_tokens - 1:
|
while input_ids.shape[1] < max_new_tokens - 1:
|
||||||
if first_seq or not use_cache:
|
if first_seq or not use_cache:
|
||||||
out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache), False
|
out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
|
||||||
else:
|
else:
|
||||||
out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
|
out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
|
||||||
start_pos=input_ids.shape[1] - 1)
|
start_pos=input_ids.shape[1] - 1, **args)
|
||||||
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
|
logits, past_kvs = out.logits[:, -1, :], out.past_key_values
|
||||||
logits[:, list(set(input_ids.tolist()[0]))] /= rp
|
logits[:, list(set(input_ids.tolist()[0]))] /= rp
|
||||||
logits /= (temperature + 1e-9)
|
logits /= (temperature + 1e-9)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user