From 9e6779839751b4646892633293d2adcb1cc97455 Mon Sep 17 00:00:00 2001 From: jingyaogong Date: Sat, 5 Apr 2025 15:53:55 +0800 Subject: [PATCH] update generate --- model/model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/model/model.py b/model/model.py index 06f791d..d53b57a 100644 --- a/model/model.py +++ b/model/model.py @@ -357,9 +357,7 @@ class MiniMindLM(PreTrainedModel): for seq in generated ] output = torch.cat(generated, dim=0) - res = output.view(input_ids.size(0), num_return_sequences, -1) - res = res.squeeze(0) if input_ids.size(0) == 1 else res - res = res.squeeze(1) if num_return_sequences == 1 else res + res = output.view(input_ids.size(0) * num_return_sequences, -1) return res def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):