update model (fix loss bug)

This commit is contained in:
gongjy 2024-09-29 16:58:48 +08:00
parent 4ef9c41563
commit a87f628400

View File

@ -369,7 +369,8 @@ class Transformer(PreTrainedModel):
if targets is not None:
logits = self.output(h)
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
ignore_index=0, reduction='none')
else:
logits = self.output(h[:, [-1], :])
self.last_loss = None