diff --git a/model/model.py b/model/model.py index 2a0ba39..06f791d 100644 --- a/model/model.py +++ b/model/model.py @@ -325,6 +325,7 @@ class MiniMindLM(PreTrainedModel): slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.output(self.norm(h)[:, slice_indices, :]) aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward)) + self.OUT.__setitem__('last_hidden_state', h) self.OUT.__setitem__('logits', logits) self.OUT.__setitem__('aux_loss', aux_loss) self.OUT.__setitem__('past_key_values', past_kvs)