update generate

This commit is contained in:
jingyaogong 2025-04-05 15:53:55 +08:00
parent 399d526fbd
commit 9e67798397

View File

@ -357,9 +357,7 @@ class MiniMindLM(PreTrainedModel):
for seq in generated
]
output = torch.cat(generated, dim=0)
res = output.view(input_ids.size(0), num_return_sequences, -1)
res = res.squeeze(0) if input_ids.size(0) == 1 else res
res = res.squeeze(1) if num_return_sequences == 1 else res
res = output.view(input_ids.size(0) * num_return_sequences, -1)
return res
def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):