update dpo_loss

jingyaogong 2025-04-01 17:32:50 +08:00
parent 4f95e23a98
commit 278ec760a1


@@ -40,11 +40,12 @@ def logits_to_probs(logits, labels):
     return probs


-def dpo_loss(ref_probs, probs, beta):
+def dpo_loss(ref_probs, probs, mask, beta):
     # ref_probs and probs both have shape: (batch_size, seq_len)
-    # compute each sample's average probability
-    ref_probs = ref_probs.mean(dim=1)
-    probs = probs.mean(dim=1)
+    # https://github.com/jingyaogong/minimind/issues/298
+    seq_lengths = mask.sum(dim=1, keepdim=True)  # (batch_size, 1)
+    ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze()
+    probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze()

     # split the chosen and rejected data
     batch_size = ref_probs.shape[0]
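Not part of the commit: a minimal sketch of why the masked mean matters. With the old probs.mean(dim=1), padded positions (already zeroed out by the mask) still count toward the denominator, so heavily padded samples get artificially small averages; the new code divides by the number of valid tokens instead. The tensors below are made up for illustration.

import torch

# Two samples of length 6; sample 0 has 6 valid tokens, sample 1 only 3.
probs = torch.tensor([[0.9, 0.8, 0.7, 0.9, 0.8, 0.7],
                      [0.9, 0.8, 0.7, 0.0, 0.0, 0.0]])
mask = torch.tensor([[1., 1., 1., 1., 1., 1.],
                     [1., 1., 1., 0., 0., 0.]])

# Old behaviour: padding dilutes sample 1's average.
old_mean = probs.mean(dim=1)                        # tensor([0.8000, 0.4000])

# New behaviour: divide only by the number of valid tokens.
seq_lengths = mask.sum(dim=1)                       # tensor([6., 3.])
new_mean = (probs * mask).sum(dim=1) / seq_lengths  # tensor([0.8000, 0.8000])

print(old_mean, new_mean)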
@@ -87,7 +88,7 @@ def train_epoch(epoch, wandb):
         logits = outputs.logits
         probs = logits_to_probs(logits, y)
         probs = probs * mask
-        loss = dpo_loss(ref_probs, probs, beta=0.1)
+        loss = dpo_loss(ref_probs, probs, mask, beta=0.1)
         loss = loss / args.accumulation_steps

         scaler.scale(loss).backward()
@@ -183,7 +184,7 @@ if __name__ == "__main__":
     parser.add_argument('--local_rank', type=int, default=-1)
     parser.add_argument('--dim', default=512, type=int)
     parser.add_argument('--n_layers', default=8, type=int)
-    parser.add_argument('--max_seq_len', default=3000, type=int)
+    parser.add_argument('--max_seq_len', default=1024, type=int)
     parser.add_argument('--use_moe', default=False, type=bool)
     parser.add_argument("--data_path", type=str, default="./dataset/dpo.jsonl")