"""LoRA supervised fine-tuning script for MiniMind.

Loads a pretrained MiniMind checkpoint, injects LoRA adapters, freezes every
non-LoRA parameter and trains only the adapters on an SFT dataset.
"""
import os
import platform
import argparse
import random
import time
import math
import warnings
from contextlib import nullcontext

import torch
import torch.distributed as dist
from torch import nn, optim
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM

from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import SFTDataset
from model.model_lora import *

warnings.filterwarnings('ignore')


def _is_main_process():
    """Return True when running single-process or on global rank 0."""
    return not ddp or dist.get_rank() == 0


def _str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this parser interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Logger function
def Logger(content):
    """Print ``content`` on the main process only."""
    if _is_main_process():
        print(content)


def get_lr(current_step, total_steps, lr):
    """Return the learning rate for ``current_step``.

    Half-cosine schedule over ``total_steps`` with a floor of ``lr / 10``:
    the value starts at ``1.1 * lr`` and decays towards ``0.1 * lr``.
    """
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))


# The code is "almost" identical to full_sft
def train_epoch(epoch, wandb):
    """Run one training epoch over the module-level ``train_loader``.

    Relies on the globals created in the ``__main__`` section (``args``,
    ``model``, ``optimizer``, ``scaler``, ``ctx``, ``lora_params``,
    ``iter_per_epoch``, ``lm_config``, ``ddp``).

    Args:
        epoch: zero-based epoch index, used for the LR schedule and logging.
        wandb: the imported ``wandb`` module, or ``None`` to disable logging.
    """
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    start_time = time.time()
    for step, (X, Y, loss_mask) in enumerate(train_loader):
        X = X.to(args.device)
        Y = Y.to(args.device)
        loss_mask = loss_mask.to(args.device)

        # The schedule is indexed by the global step across all epochs.
        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with ctx:
            res = model(X)
            # Per-token loss, reshaped back to Y's shape so the mask can be applied.
            loss = loss_fct(
                res.logits.view(-1, res.logits.size(-1)),
                Y.view(-1)
            ).view(Y.size())
            # Average over the positions selected by loss_mask only.
            loss = (loss * loss_mask).sum() / loss_mask.sum()
            loss += res.aux_loss
            # Scale so that accumulated gradients average over the window.
            loss = loss / args.accumulation_steps

        scaler.scale(loss).backward()

        if (step + 1) % args.accumulation_steps == 0:
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)

            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad(set_to_none=True)

        if step % args.log_interval == 0:
            spend_time = time.time() - start_time
            # Undo the accumulation scaling so the reported loss is comparable
            # across different --accumulation_steps settings.
            reported_loss = loss.item() * args.accumulation_steps
            current_lr = optimizer.param_groups[-1]['lr']
            # Estimated minutes remaining in this epoch.
            remaining_min = spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
            Logger(
                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
                    epoch + 1,
                    args.epochs,
                    step,
                    iter_per_epoch,
                    reported_loss,
                    current_lr,
                    remaining_min))

            if (wandb is not None) and _is_main_process():
                wandb.log({"loss": reported_loss,
                           "lr": current_lr,
                           "epoch_Time": remaining_min})

        if (step + 1) % args.save_interval == 0 and _is_main_process():
            model.eval()
            # [Difference 1] Only the LoRA weights need to be saved.
            save_lora(model, f'{args.save_dir}/lora/{args.lora_name}_{lm_config.dim}.pth')
            model.train()


def init_model(lm_config):
    """Build the tokenizer and a MiniMindLM initialised from the RLHF checkpoint.

    Returns:
        ``(model, tokenizer)`` with the model moved to ``args.device``.
    """
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config)
    moe_path = '_moe' if lm_config.use_moe else ''
    ckp = f'./out/rlhf_{lm_config.dim}{moe_path}.pth'
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(ckp, map_location=args.device)
    # strict=False: keys missing from / unexpected in the checkpoint are tolerated.
    model.load_state_dict(state_dict, strict=False)
    return model.to(args.device), tokenizer


def init_distributed_mode():
    """Initialise the NCCL process group and bind this process to its GPU.

    No-op when ``ddp`` is false. Updates the module-level ``ddp_local_rank``
    and ``DEVICE`` from the ``LOCAL_RANK`` environment variable.
    """
    if not ddp:
        return
    global ddp_local_rank, DEVICE

    dist.init_process_group(backend="nccl")
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind SFT with LoRA")
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", action="store_true")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA-SFT")
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--ddp", action="store_true")
    parser.add_argument("--accumulation_steps", type=int, default=1)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=100)
    parser.add_argument("--save_interval", type=int, default=1)
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--dim', default=512, type=int)
    parser.add_argument('--n_layers', default=8, type=int)
    parser.add_argument('--max_seq_len', default=512, type=int)
    # type=bool would turn "--use_moe False" into True; parse the text instead.
    parser.add_argument('--use_moe', default=False, type=_str2bool)
    parser.add_argument("--data_path", type=str, default="./dataset/lora_identity.jsonl")
    parser.add_argument("--lora_name", type=str, default="lora_identity", help="根据任务保存成lora_(英文/医学/心理...)")
    args = parser.parse_args()

    lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
    args.save_dir = os.path.join(args.out_dir)
    os.makedirs(args.save_dir, exist_ok=True)
    # train_epoch writes checkpoints to "<save_dir>/lora/", so create it up front.
    os.makedirs(os.path.join(args.save_dir, 'lora'), exist_ok=True)
    device_type = "cuda" if "cuda" in args.device else "cpu"

    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
    ddp_local_rank, DEVICE = 0, "cuda:0"
    base_seed = 1337
    torch.manual_seed(base_seed)
    torch.cuda.manual_seed(base_seed)

    if ddp:
        init_distributed_mode()
        args.device = torch.device(DEVICE)
        rank = dist.get_rank()
        # Give each rank its own seed for both the CPU and CUDA generators.
        torch.manual_seed(base_seed + rank)
        torch.cuda.manual_seed(base_seed + rank)

    args.wandb_run_name = f"MiniMind-Lora-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
    # Initialise wandb on the same process that logs to it (global rank 0).
    if args.use_wandb and _is_main_process():
        import wandb

        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
    else:
        wandb = None

    model, tokenizer = init_model(lm_config)
    apply_lora(model)
    # NOTE(review): the model is not wrapped in DistributedDataParallel here, so
    # under DDP gradients do not appear to be synchronised across ranks — confirm
    # whether this is intended before relying on multi-GPU runs.

    total_params = sum(p.numel() for p in model.parameters())  # total parameter count
    lora_params_count = sum(p.numel() for name, p in model.named_parameters() if 'lora' in name)  # LoRA parameter count
    if _is_main_process():
        print(f"LLM 总参数量: {total_params}")
        print(f"LoRA 参数量: {lora_params_count}")
        print(f"LoRA 参数占比: {lora_params_count / total_params * 100:.2f}%")

    # Freeze everything except the LoRA parameters and collect those for training.
    lora_params = []
    for name, param in model.named_parameters():
        if 'lora' in name:
            lora_params.append(param)
        else:
            param.requires_grad = False

    # Optimise the LoRA parameters only.
    optimizer = optim.AdamW(lora_params, lr=args.learning_rate)
    train_ds = SFTDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if ddp else None
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
        num_workers=args.num_workers,
        sampler=train_sampler
    )

    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
    iter_per_epoch = len(train_loader)

    for epoch in range(args.epochs):
        if train_sampler is not None:
            # Without set_epoch the sampler reuses the same shuffle every epoch.
            train_sampler.set_epoch(epoch)
        train_epoch(epoch, wandb)