diff --git a/train_dpo.py b/train_dpo.py index dc9054a..e0b67af 100644 --- a/train_dpo.py +++ b/train_dpo.py @@ -40,11 +40,12 @@ def logits_to_probs(logits, labels): return probs -def dpo_loss(ref_probs, probs, beta): +def dpo_loss(ref_probs, probs, mask, beta): # ref_probs 和 probs 都是 shape: (batch_size, seq_len) - # 计算每个样本的平均概率 - ref_probs = ref_probs.mean(dim=1) - probs = probs.mean(dim=1) + # https://github.com/jingyaogong/minimind/issues/298 + seq_lengths = mask.sum(dim=1, keepdim=True) # (batch_size, 1) + ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze() + probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze() # 将 chosen 和 rejected 数据分开 batch_size = ref_probs.shape[0] @@ -87,7 +88,7 @@ def train_epoch(epoch, wandb): logits = outputs.logits probs = logits_to_probs(logits, y) probs = probs * mask - loss = dpo_loss(ref_probs, probs, beta=0.1) + loss = dpo_loss(ref_probs, probs, mask, beta=0.1) loss = loss / args.accumulation_steps scaler.scale(loss).backward() @@ -183,7 +184,7 @@ if __name__ == "__main__": parser.add_argument('--local_rank', type=int, default=-1) parser.add_argument('--dim', default=512, type=int) parser.add_argument('--n_layers', default=8, type=int) - parser.add_argument('--max_seq_len', default=3000, type=int) + parser.add_argument('--max_seq_len', default=1024, type=int) parser.add_argument('--use_moe', default=False, type=bool) parser.add_argument("--data_path", type=str, default="./dataset/dpo.jsonl")