Minimind/train_full_sft.py

import os
# Set environment variables
os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
import platform
import argparse
import time
import math
import warnings
import pandas as pd
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import SFTDataset

warnings.filterwarnings('ignore')


# Logging helper: print training info (only on rank 0 when running with DDP).
def Logger(content):
    if not ddp or dist.get_rank() == 0:
        print(content)


# Cosine learning-rate schedule: decays from lr/10 + lr at step 0 down to lr/10 at the final step.
def get_lr(current_step, total_steps, lr):
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
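
# A quick numeric check of the schedule above, using the default learning_rate of 5e-5:
#   current_step = 0             -> 5e-6 + 0.5 * 5e-5 * (1 + cos(0))    = 5.5e-5  (peak)
#   current_step = total_steps/2 -> 5e-6 + 0.5 * 5e-5 * (1 + cos(pi/2)) = 3.0e-5
#   current_step = total_steps   -> 5e-6 + 0.5 * 5e-5 * (1 + cos(pi))   = 5.0e-6  (floor)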


# Train the model for one epoch.
def train_epoch(epoch, wandb):
    loss_fct = nn.CrossEntropyLoss(reduction='none')  # per-token cross-entropy; reduction='none' keeps one value per position so it can be masked below
    start_time = time.time()
    for step, (X, Y, loss_mask) in enumerate(train_loader):
        # Move the batch to the target device.
        X = X.to(args.device)
        Y = Y.to(args.device)
        loss_mask = loss_mask.to(args.device)
        # Compute the current learning rate.
        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
        # Update the learning rate for every parameter group.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with ctx:
            res = model(X)  # forward pass
            loss = loss_fct(
                res.logits.view(-1, res.logits.size(-1)),
                Y.view(-1)
            ).view(Y.size())  # per-token loss, reshaped back to (batch, seq_len)
            # Mask out unsupervised positions and average over the remaining tokens.
            loss = (loss * loss_mask).sum() / loss_mask.sum()
            loss += res.aux_loss
            loss = loss / args.accumulation_steps
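
        # Note: loss_mask is produced by SFTDataset and is typically 1 only on the supervised
        # (assistant-response) tokens and 0 on prompt/padding positions, so the average above covers
        # only those tokens. res.aux_loss is the auxiliary load-balancing term, non-zero only for the
        # MoE variant of the model.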

        scaler.scale(loss).backward()  # scale the loss (to avoid FP16 gradient underflow), then backpropagate

        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)  # part of PyTorch AMP: un-scale the gradients that were scaled to prevent underflow
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)  # clip gradients so their norm does not exceed args.grad_clip

            scaler.step(optimizer)  # optimizer step, mediated by the scaler for mixed-precision training
            scaler.update()  # adjust the scale factor depending on whether this iteration saw gradient overflow

            optimizer.zero_grad(set_to_none=True)  # clear gradients
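
        # With accumulation_steps > 1, gradients from several micro-batches are summed before each
        # optimizer step, so the effective batch size is batch_size * accumulation_steps
        # (multiplied further by the number of processes when running under DDP).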

        # Log at the configured interval.
        if step % args.log_interval == 0:
            spend_time = time.time() - start_time
            Logger(
                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
                    epoch + 1,
                    args.epochs,
                    step,
                    iter_per_epoch,
                    loss.item(),
                    optimizer.param_groups[-1]['lr'],
                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))  # projected minutes left in this epoch

            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
                wandb.log({"loss": loss,
                           "lr": optimizer.param_groups[-1]['lr'],
                           "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

        # Save a checkpoint at the configured interval (rank 0 only under DDP).
        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
            model.eval()
            moe_path = '_moe' if lm_config.use_moe else ''
            ckp = f'{args.save_dir}/full_sft_{lm_config.dim}{moe_path}.pth'
            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            torch.save(state_dict, ckp)
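            # Only the model weights are saved (no optimizer or scaler state), and the filename does not
            # include the step, so each save overwrites the previous checkpoint for this configuration.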
            model.train()


# Initialize the model and tokenizer, loading the pretrained weights as the starting point for SFT.
def init_model(lm_config):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config)
    moe_path = '_moe' if lm_config.use_moe else ''
    ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
    state_dict = torch.load(ckp, map_location=args.device)
    model.load_state_dict(state_dict, strict=False)
    Logger(f'Total trainable LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
    model = model.to(args.device)
    return model, tokenizer


# Initialize distributed (DDP) training mode.
def init_distributed_mode():
    if not ddp: return
    global ddp_local_rank, DEVICE

    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)
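

# The RANK / LOCAL_RANK / WORLD_SIZE environment variables read above are expected to be set by the
# launcher; a typical (assumed) multi-GPU invocation would be:
#   torchrun --nproc_per_node=2 train_full_sft.py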


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MiniMind Full SFT")
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", default=True, action="store_true")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
    parser.add_argument("--num_workers", type=int, default=1)
    parser.add_argument("--ddp", action="store_true")
    parser.add_argument("--accumulation_steps", type=int, default=1)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=100)
    parser.add_argument("--save_interval", type=int, default=100)
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--dim', default=1024, type=int)  # model hidden dimension (controls model size)
    parser.add_argument('--n_layers', default=24, type=int)  # number of transformer layers
    parser.add_argument('--max_seq_len', default=1024, type=int)  # maximum input sequence length
    parser.add_argument('--use_moe', default=False, type=bool)
    parser.add_argument("--data_path", type=str, default="./dataset/sft_1024.jsonl")
    args = parser.parse_args()

    lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
    args.save_dir = os.path.join(args.out_dir)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)
    tokens_per_iter = args.batch_size * lm_config.max_seq_len
    device_type = "cuda" if "cuda" in args.device else "cpu"

    args.wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"

    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()  # autocast context on CUDA (uses its default dtype, float16)
    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
    ddp_local_rank, DEVICE = 0, "cuda:0"

    base_seed = 1337
    torch.manual_seed(base_seed)
    torch.cuda.manual_seed(base_seed)

    # If running in distributed mode, initialize DDP and re-seed per rank.
    if ddp:
        init_distributed_mode()
        args.device = torch.device(DEVICE)
        rank = dist.get_rank()
        torch.manual_seed(base_seed + rank)
        # Also set the CUDA random seed per rank.
        torch.cuda.manual_seed(base_seed + rank)

    # If WandB is enabled, initialize it (only on rank 0 under DDP).
    if args.use_wandb and (not ddp or ddp_local_rank == 0):
        import wandb

        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
    else:
        wandb = None
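
    # Because WANDB_MODE is set to "offline" at the top of this file, runs are written to the local
    # ./wandb directory instead of being uploaded; they can be uploaded later with `wandb sync` if needed.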

    # Initialize the model and tokenizer.
    model, tokenizer = init_model(lm_config)

    # Initialize the dataset and dataloader.
    train_ds = SFTDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_sampler = DistributedSampler(train_ds) if ddp else None
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
        num_workers=args.num_workers,
        sampler=train_sampler
    )
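    # Note: shuffle must stay False whenever a sampler is supplied; under DDP the DistributedSampler
    # shuffles by default, while single-process runs iterate the dataset in file order.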

    # Gradient scaler for mixed-precision training: enabled when training in half precision
    # (float16/bfloat16) to prevent gradient underflow.
    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
    # AdamW optimizer (Adam with decoupled weight decay) over all model parameters, at the given learning rate.
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
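    # Note: loss scaling is mainly needed for float16; bfloat16 has the same exponent range as float32,
    # so the scaler does little there, though enabling it is harmless.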

    if ddp:
        # Exclude the "pos_cis" buffer from DDP parameter/buffer synchronization.
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])

    iter_per_epoch = len(train_loader)
    for epoch in range(args.epochs):
        train_epoch(epoch, wandb)