update generate
This commit is contained in:
parent
399d526fbd
commit
9e67798397
@ -357,9 +357,7 @@ class MiniMindLM(PreTrainedModel):
|
||||
for seq in generated
|
||||
]
|
||||
output = torch.cat(generated, dim=0)
|
||||
res = output.view(input_ids.size(0), num_return_sequences, -1)
|
||||
res = res.squeeze(0) if input_ids.size(0) == 1 else res
|
||||
res = res.squeeze(1) if num_return_sequences == 1 else res
|
||||
res = output.view(input_ids.size(0) * num_return_sequences, -1)
|
||||
return res
|
||||
|
||||
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
|
||||
|
Loading…
x
Reference in New Issue
Block a user