import os
import platform
import argparse
import time
import math
import warnings

import pandas as pd
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM

from model.model import Transformer
from model.LMConfig import LMConfig
from model.dataset import SFTDataset

warnings.filterwarnings('ignore')


def Logger(content):
    # Only print from rank 0 to avoid duplicated logs under DDP.
    if not ddp or dist.get_rank() == 0:
        print(content)


def get_lr(it, total_iters):
    # Cosine learning-rate schedule with linear warmup, decaying to lr/10.
    warmup_iters = args.warmup_iters
    lr_decay_iters = total_iters
    min_lr = args.learning_rate / 10

    if it < warmup_iters:
        return args.learning_rate * it / warmup_iters
    if it > lr_decay_iters:
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes 1 -> 0
    return min_lr + coeff * (args.learning_rate - min_lr)


def train_epoch(epoch, wandb):
    start_time = time.time()
    for step, (X, Y, loss_mask) in enumerate(train_loader):
        X = X.to(args.device)
        Y = Y.to(args.device)
        loss_mask = loss_mask.to(args.device)

        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with ctx:
            logits = model(X, Y).logits
            # Per-token loss, then mask out prompt/padding positions so that
            # only response tokens contribute to the averaged SFT loss.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1),
                                   ignore_index=0, reduction='none')
            loss_mask = loss_mask.view(-1)
            loss = torch.sum(loss * loss_mask) / loss_mask.sum()

        scaler.scale(loss).backward()

        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad(set_to_none=True)

        if step % args.log_interval == 0:
            spend_time = time.time() - start_time
            Logger(
                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min'.format(
                    epoch,
                    args.epochs,
                    step,
                    iter_per_epoch,
                    loss.item(),
                    optimizer.param_groups[-1]['lr'],
                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))  # ETA in minutes

            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
                wandb.log({"loss": loss.item(),
                           "lr": optimizer.param_groups[-1]['lr'],
                           "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
            model.eval()
            moe_path = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/full_sft_{lm_config.dim}{moe_path}.pth'

            # Unwrap DDP before saving so checkpoint keys carry no 'module.' prefix.
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            torch.save(state_dict, ckp)
            model.train()


def init_model():
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model_from = 1  # 1: load local pretrain weights, 2: load via transformers

    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    if model_from == 1:
        model = Transformer(lm_config)
        moe_path = '_moe' if lm_config.use_moe else ''
        ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'

        state_dict = torch.load(ckp, map_location=args.device)
        unwanted_prefix = '_orig_mod.'
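        # Note: checkpoints saved from a torch.compile()-wrapped model prefix
        # every key with '_orig_mod.'; the loop below strips that prefix so the
        # keys line up with the plain (uncompiled) module.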
        for k, v in list(state_dict.items()):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
        model.load_state_dict(state_dict, strict=False)
    else:
        model = AutoModelForCausalLM.from_pretrained('./minimind-v1-small', trust_remote_code=True)

    Logger(f'Total LLM parameters: {count_parameters(model) / 1e6:.3f} million')
    model = model.to(args.device)
    return model, tokenizer


def init_distributed_mode():
    if not ddp:
        return
    global ddp_local_rank, DEVICE

    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind Full SFT")
    parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
    parser.add_argument("--epochs", type=int, default=19, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu",
                        help="Device to use")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
    parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="Weights & Biases project name")
    parser.add_argument("--num_workers", type=int, default=1, help="Number of workers for data loading")
    parser.add_argument("--ddp", action="store_true", help="Use DistributedDataParallel")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
    parser.add_argument("--warmup_iters", type=int, default=0, help="Number of warmup iterations")
    parser.add_argument("--log_interval", type=int, default=100, help="Logging interval")
    parser.add_argument("--save_interval", type=int, default=1000, help="Model saving interval")
    parser.add_argument('--local_rank', type=int, default=-1, help='local rank for distributed training')

    args = parser.parse_args()

    lm_config = LMConfig()
    max_seq_len = lm_config.max_seq_len
    args.save_dir = os.path.join(args.out_dir)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)
    tokens_per_iter = args.batch_size * max_seq_len
    torch.manual_seed(1337)
    device_type = "cuda" if "cuda" in args.device else "cpu"

    args.wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"

    # Autocast only applies on GPU; on CPU fall back to a no-op context.
    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()

    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
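    # (torchrun exports RANK, LOCAL_RANK and WORLD_SIZE to every worker, so a
    # non-empty RANK is a reliable signal of a distributed launch;
    # init_distributed_mode() reads those same variables.)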
    ddp_local_rank, DEVICE = 0, "cuda:0"
    if ddp:
        init_distributed_mode()
        args.device = torch.device(DEVICE)

    if args.use_wandb and (not ddp or ddp_local_rank == 0):
        import wandb

        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
    else:
        wandb = None

    model, tokenizer = init_model()

    df = pd.read_csv('./dataset/sft_data_single.csv')
    df = df.sample(frac=1.0)  # shuffle rows once up front (DataLoader shuffle is off)
    train_ds = SFTDataset(df, tokenizer, max_length=max_seq_len)
    train_sampler = DistributedSampler(train_ds) if ddp else None
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
        num_workers=args.num_workers,
        sampler=train_sampler
    )

    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    # torch.compile is deliberately disabled here ('if False'); flip the flag to
    # try it on PyTorch >= 2.0 (non-MoE, non-Windows only).
    if False and not lm_config.use_moe and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
        Logger("compiling the model... (takes a ~minute)")
        unoptimized_model = model
        model = torch.compile(model)

    if ddp:
        # 'pos_cis' is a precomputed positional buffer that DDP should not try to sync.
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])

    iter_per_epoch = len(train_loader)
    for epoch in range(args.epochs):
        if ddp:
            train_sampler.set_epoch(epoch)  # reshuffle the per-rank shards each epoch
        train_epoch(epoch, wandb)
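# Example launch commands (a sketch; the script filename 'train_full_sft.py'
# and the GPU count are assumptions, substitute your actual values):
#   single GPU: python train_full_sft.py --epochs 19 --batch_size 32
#   multi GPU:  torchrun --nproc_per_node 2 train_full_sft.py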