update dpo_loss

parent 4f95e23a98
commit 278ec760a1

train_dpo.py (13 changes)
@@ -40,11 +40,12 @@ def logits_to_probs(logits, labels):
     return probs


-def dpo_loss(ref_probs, probs, beta):
+def dpo_loss(ref_probs, probs, mask, beta):
     # ref_probs and probs both have shape: (batch_size, seq_len)
-    # compute the per-sample mean probability
-    ref_probs = ref_probs.mean(dim=1)
-    probs = probs.mean(dim=1)
+    # https://github.com/jingyaogong/minimind/issues/298
+    seq_lengths = mask.sum(dim=1, keepdim=True)  # (batch_size, 1)
+    ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze()
+    probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze()

     # split the chosen and rejected samples
     batch_size = ref_probs.shape[0]
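The fix replaces the plain per-position mean with a length-normalized masked mean: the caller zeroes probs at padding positions, so the old mean(dim=1) divided by the full padded length and systematically shrank the average log-probability of short responses. Below is a minimal self-contained sketch of the updated loss, assuming (as in common DPO implementations) that probs and ref_probs hold per-token log-probabilities of the target labels; the chosen-first/rejected-second batch layout and the final objective are assumptions here, since that part of the function lies outside this hunk:

import torch
import torch.nn.functional as F

def dpo_loss_sketch(ref_probs, probs, mask, beta):
    # ref_probs, probs, mask: (batch_size, seq_len);
    # probs are per-token log-probabilities of the labels (assumption).
    seq_lengths = mask.sum(dim=1).clamp(min=1)  # true token counts; clamp is defensive, not in the diff
    ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths
    probs = (probs * mask).sum(dim=1) / seq_lengths

    # Assumed layout: first half of the batch is chosen, second half rejected.
    half = ref_probs.shape[0] // 2
    pi_margin = probs[:half] - probs[half:]
    ref_margin = ref_probs[:half] - ref_probs[half:]

    # Standard DPO objective: -log sigmoid(beta * (policy margin - reference margin)).
    return -F.logsigmoid(beta * (pi_margin - ref_margin)).mean()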
@@ -87,7 +88,7 @@ def train_epoch(epoch, wandb):
         logits = outputs.logits
         probs = logits_to_probs(logits, y)
         probs = probs * mask
-        loss = dpo_loss(ref_probs, probs, beta=0.1)
+        loss = dpo_loss(ref_probs, probs, mask, beta=0.1)
         loss = loss / args.accumulation_steps

         scaler.scale(loss).backward()
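At the call site only the signature changes: probs is still pre-masked (harmless, since dpo_loss re-applies the same 0/1 mask), and the mask is now passed through so the loss can divide by true sequence lengths. A toy comparison of the old and new normalization, runnable on its own:

import torch

# Two samples with identical real tokens; the first is padded out to seq_len 4.
probs = torch.tensor([[-1.0, -1.0, 0.0, 0.0],     # 2 real tokens, 2 pad (already zeroed by the mask)
                      [-1.0, -1.0, -1.0, -1.0]])  # 4 real tokens
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                     [1.0, 1.0, 1.0, 1.0]])

old = probs.mean(dim=1)                            # tensor([-0.5, -1.0]): padding dilutes sample 0
new = (probs * mask).sum(dim=1) / mask.sum(dim=1)  # tensor([-1.0, -1.0]): both normalized correctly
print(old, new)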
@@ -183,7 +184,7 @@ if __name__ == "__main__":
     parser.add_argument('--local_rank', type=int, default=-1)
     parser.add_argument('--dim', default=512, type=int)
     parser.add_argument('--n_layers', default=8, type=int)
-    parser.add_argument('--max_seq_len', default=3000, type=int)
+    parser.add_argument('--max_seq_len', default=1024, type=int)
     parser.add_argument('--use_moe', default=False, type=bool)
     parser.add_argument("--data_path", type=str, default="./dataset/dpo.jsonl")

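Two notes on the argument block: the max_seq_len default drops from 3000 to 1024, and use_moe is declared with type=bool, which argparse handles unintuitively (bool("False") is True, so any non-empty string enables the flag). A small illustration; the store_true variant is a suggested alternative, not the repo's code:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--max_seq_len', default=1024, type=int)
parser.add_argument('--use_moe', action='store_true')  # parses flag presence, avoiding the bool("False") pitfall

args = parser.parse_args(['--max_seq_len', '1024'])
print(args.max_seq_len, args.use_moe)  # -> 1024 False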