update generate
This commit is contained in:
parent
399d526fbd
commit
9e67798397
@ -357,9 +357,7 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
for seq in generated
|
for seq in generated
|
||||||
]
|
]
|
||||||
output = torch.cat(generated, dim=0)
|
output = torch.cat(generated, dim=0)
|
||||||
res = output.view(input_ids.size(0), num_return_sequences, -1)
|
res = output.view(input_ids.size(0) * num_return_sequences, -1)
|
||||||
res = res.squeeze(0) if input_ids.size(0) == 1 else res
|
|
||||||
res = res.squeeze(1) if num_return_sequences == 1 else res
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
|
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user