From a87f6284007a2c4062016041c612fd0a86c2fafd Mon Sep 17 00:00:00 2001
From: gongjy <2474590974@qq.com>
Date: Sun, 29 Sep 2024 16:58:48 +0800
Subject: [PATCH] update model (fix loss bug)

---
 model/model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/model/model.py b/model/model.py
index e858ec7..20535fb 100644
--- a/model/model.py
+++ b/model/model.py
@@ -369,7 +369,8 @@ class Transformer(PreTrainedModel):
 
         if targets is not None:
             logits = self.output(h)
-            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
+                                             ignore_index=0, reduction='none')
         else:
             logits = self.output(h[:, [-1], :])
             self.last_loss = None