diff --git a/model/model.py b/model/model.py index 06f791d..d53b57a 100644 --- a/model/model.py +++ b/model/model.py @@ -357,9 +357,7 @@ class MiniMindLM(PreTrainedModel): for seq in generated ] output = torch.cat(generated, dim=0) - res = output.view(input_ids.size(0), num_return_sequences, -1) - res = res.squeeze(0) if input_ids.size(0) == 1 else res - res = res.squeeze(1) if num_return_sequences == 1 else res + res = output.view(input_ids.size(0) * num_return_sequences, -1) return res def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):