add hidden state

This commit is contained in:
jingyaogong 2025-04-05 14:39:56 +08:00
parent 885661f47d
commit 399d526fbd

View File

@ -325,6 +325,7 @@ class MiniMindLM(PreTrainedModel):
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
self.OUT.__setitem__('last_hidden_state', h)
self.OUT.__setitem__('logits', logits)
self.OUT.__setitem__('aux_loss', aux_loss)
self.OUT.__setitem__('past_key_values', past_kvs)