import os
import platform
import argparse
import time
import math
import warnings
import pandas as pd
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import DPODataset

warnings.filterwarnings('ignore')


def Logger(content):
    # Only print from rank 0 when running under DDP
    if not ddp or dist.get_rank() == 0:
        print(content)


def get_lr(current_step, total_steps, lr):
    # Cosine decay from lr down to lr / 10
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))


def logits_to_probs(logits, labels):
    # logits shape: (batch_size, seq_len, vocab_size)
    # labels shape: (batch_size, seq_len)
    # probs shape: (batch_size, seq_len)
    log_probs = F.log_softmax(logits, dim=2)
    probs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
    return probs


def dpo_loss(ref_probs, probs, mask, beta):
    # ref_probs and probs both have shape: (batch_size, seq_len)
    # https://github.com/jingyaogong/minimind/issues/298
    seq_lengths = mask.sum(dim=1, keepdim=True)  # (batch_size, 1)
    ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze()
    probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze()

    # Split the batch back into chosen and rejected halves
    # (train_epoch concatenates chosen first, then rejected)
    batch_size = ref_probs.shape[0]
    chosen_ref_probs = ref_probs[:batch_size // 2]
    reject_ref_probs = ref_probs[batch_size // 2:]
    chosen_probs = probs[:batch_size // 2]
    reject_probs = probs[batch_size // 2:]

    pi_logratios = chosen_probs - reject_probs
    ref_logratios = chosen_ref_probs - reject_ref_probs
    logits = pi_logratios - ref_logratios
    loss = -F.logsigmoid(beta * logits)
    return loss.mean()


def train_epoch(epoch, wandb):
    start_time = time.time()
    for step, batch in enumerate(train_loader):
        x_chosen = batch['x_chosen'].to(args.device)
        x_rejected = batch['x_rejected'].to(args.device)
        y_chosen = batch['y_chosen'].to(args.device)
        y_rejected = batch['y_rejected'].to(args.device)
        mask_chosen = batch['mask_chosen'].to(args.device)
        mask_rejected = batch['mask_rejected'].to(args.device)
        # Concatenate chosen and rejected along the batch dim; dpo_loss splits them by half
        x = torch.cat([x_chosen, x_rejected], dim=0)
        y = torch.cat([y_chosen, y_rejected], dim=0)
        mask = torch.cat([mask_chosen, mask_rejected], dim=0)

        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with ctx:
            # Frozen reference model forward pass (no gradients)
            with torch.no_grad():
                ref_outputs = ref_model(x)
                ref_logits = ref_outputs.logits
            ref_probs = logits_to_probs(ref_logits, y)
            ref_probs = ref_probs * mask

            # Policy model forward pass
            outputs = model(x)
            logits = outputs.logits
            probs = logits_to_probs(logits, y)
            probs = probs * mask

            loss = dpo_loss(ref_probs, probs, mask, beta=0.1)
            loss = loss / args.accumulation_steps

        scaler.scale(loss).backward()

        # Optimizer step with gradient accumulation and clipping
        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        if step % args.log_interval == 0:
            spend_time = time.time() - start_time
            Logger(
                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
                    epoch + 1,
                    args.epochs,
                    step,
                    iter_per_epoch,
                    loss.item(),
                    optimizer.param_groups[-1]['lr'],
                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
                wandb.log({"loss": loss,
                           "lr": optimizer.param_groups[-1]['lr'],
                           "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

        # Save a checkpoint (rank 0 only under DDP)
        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
            model.eval()
            moe_path = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/rlhf_{lm_config.dim}{moe_path}.pth'

            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            torch.save(state_dict, ckp)
            model.train()


def init_model(lm_config):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config)
    # Load the full-SFT checkpoint as the starting point for DPO
    moe_path = '_moe' if lm_config.use_moe else ''
    ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
    state_dict = torch.load(ckp, map_location=args.device)
    model.load_state_dict(state_dict, strict=False)

    # Initialize the frozen reference model from the same SFT weights
    ref_model = MiniMindLM(lm_config)
    ref_model.load_state_dict(state_dict, strict=False)
    ref_model.eval()
    ref_model.requires_grad_(False)

    Logger(f'LLM total parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
    model = model.to(args.device)
    ref_model = ref_model.to(args.device)

    return model, ref_model, tokenizer


def init_distributed_mode():
    if not ddp: return
    global ddp_local_rank, DEVICE

    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind RLHF")
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=8)
    # The SFT stage used lr 5e-6 -> 5e-7 at seq length 512; for this offline chosen/rejected
    # "probability" preference-alignment stage at length ~3000, keep lr <= 1e-8, otherwise the
    # model easily forgets and degrades.
    parser.add_argument("--learning_rate", type=float, default=1e-8)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-RLHF-SFT")
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--ddp", action="store_true")
    parser.add_argument("--accumulation_steps", type=int, default=1)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=100)
    parser.add_argument("--save_interval", type=int, default=100)
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=1024, type=int)
    parser.add_argument('--use_moe', default=False, type=bool)
    parser.add_argument("--data_path", type=str, default="./dataset/dpo.jsonl")

    args = parser.parse_args()

    lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
    args.save_dir = os.path.join(args.out_dir)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)
    tokens_per_iter = args.batch_size * lm_config.max_seq_len
    device_type = "cuda" if "cuda" in args.device else "cpu"

    args.wandb_run_name = f"MiniMind-Full-DPO-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"

    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()

    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
    ddp_local_rank, DEVICE = 0, "cuda:0"

    base_seed = 1337
    torch.manual_seed(base_seed)
    torch.cuda.manual_seed(base_seed)

    if ddp:
        init_distributed_mode()
        args.device = torch.device(DEVICE)
        rank = dist.get_rank()
        torch.manual_seed(base_seed + rank)
        # Also give the CUDA RNG a per-rank seed
        torch.cuda.manual_seed(base_seed + rank)

    if args.use_wandb and (not ddp or ddp_local_rank == 0):
        import wandb

        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
    else:
        wandb = None

    model, ref_model, tokenizer = init_model(lm_config)

    train_ds = DPODataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if ddp else None
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
        num_workers=args.num_workers,
        sampler=train_sampler
    )

    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    if ddp:
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])

    iter_per_epoch = len(train_loader)
    for epoch in range(args.epochs):
        train_epoch(epoch, wandb)
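
# Example launch commands (the script filename and dataset path below are assumptions; adjust to your checkout):
#   single GPU : python train_dpo.py --epochs 2 --batch_size 8 --data_path ./dataset/dpo.jsonl
#   multi GPU  : torchrun --nproc_per_node 2 train_dpo.py
# torchrun sets the RANK/LOCAL_RANK/WORLD_SIZE environment variables that this script reads
# to detect a DDP run and initialize the process group.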