Minimind/train_pretrain.py

215 lines
10 KiB
Python
Raw Normal View History

2025-02-09 23:49:47 +08:00
import os
2025-04-24 15:58:39 +08:00
# 设置环境变量
os.environ["WANDB_MODE"] = "offline" # 或者使用 "dryrun"
2025-02-09 23:49:47 +08:00
import platform
import argparse
import time
import math
import warnings
import pandas as pd
import torch
import torch.distributed as dist
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, DistributedSampler
from contextlib import nullcontext
from transformers import AutoTokenizer
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset
warnings.filterwarnings('ignore')
def Logger(content):
2025-04-24 15:58:39 +08:00
# 如果没有使用ddp或者ddp的主设备那么就打印
2025-02-09 23:49:47 +08:00
if not ddp or dist.get_rank() == 0:
print(content)
def get_lr(current_step, total_steps, lr):
2025-04-24 15:58:39 +08:00
# 更新学习率
# \text{get\_lr}(c, t, l) = \frac{l}{10} + 0.5 \cdot l \cdot \left(1 + \cos\left(\frac{\pi \cdot c}{t}\right)\right)
2025-02-09 23:49:47 +08:00
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
def train_epoch(epoch, wandb):
2025-04-24 15:58:39 +08:00
loss_fct = nn.CrossEntropyLoss(reduction='none') #交叉熵损失Cross-Entropy Loss当 reduction='none' 时nn.CrossEntropyLoss 不会对损失进行任何汇总操作,而是返回每个样本的单独损失值。
2025-02-09 23:49:47 +08:00
start_time = time.time()
for step, (X, Y, loss_mask) in enumerate(train_loader):
2025-04-24 15:58:39 +08:00
# 将数据加载到设备上
2025-02-09 23:49:47 +08:00
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
2025-04-24 15:58:39 +08:00
# 更新学习率
2025-02-09 23:49:47 +08:00
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with ctx:
2025-04-24 15:58:39 +08:00
res = model(X) #获取输出
2025-02-09 23:49:47 +08:00
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
2025-04-24 15:58:39 +08:00
).view(Y.size())#计算损失
loss = (loss * loss_mask).sum() / loss_mask.sum() #计算总的loss
# 为了批次堆叠进行的处理真正的batch size为num gpu*batch size per gpu*accumulation steps
loss += res.aux_loss
2025-02-09 23:49:47 +08:00
loss = loss / args.accumulation_steps
2025-04-24 15:58:39 +08:00
scaler.scale(loss).backward() #用于处理混合精度训练。它的作用是自动缩放损失值,以防止在使用低精度(如 FP16计算时出现数值不稳定的问题。
2025-02-09 23:49:47 +08:00
2025-04-24 15:58:39 +08:00
# 如果达到堆叠数目就进行处理
2025-02-09 23:49:47 +08:00
if (step + 1) % args.accumulation_steps == 0:
2025-04-24 15:58:39 +08:00
scaler.unscale_(optimizer) #PyTorch 自动混合精度(AMP)训练的一部分。它"反缩放"之前为防止在混合精度训练中出现下溢而缩放的梯度。
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) #应用梯度裁剪以防止梯度爆炸。它会缩放梯度使其范数不超过args.grad_clip。
2025-02-09 23:49:47 +08:00
2025-04-24 15:58:39 +08:00
scaler.step(optimizer) #使用优化器更新模型权重,但由缩放器控制以适应混合精度训练。
scaler.update() #根据本次迭代是否有梯度溢出来更新下一次迭代的缩放因子。
2025-02-09 23:49:47 +08:00
2025-04-24 15:58:39 +08:00
optimizer.zero_grad(set_to_none=True) #为下一次迭代清零所有梯度。set_to_none=True参数通过将梯度设置为None而不是零来提高内存效率。
2025-02-09 23:49:47 +08:00
2025-04-24 15:58:39 +08:00
# 打印日志
2025-02-09 23:49:47 +08:00
if step % args.log_interval == 0:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,
iter_per_epoch,
loss.item() * args.accumulation_steps,
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({"loss": loss.item() * args.accumulation_steps,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
2025-04-24 15:58:39 +08:00
# 保存模型
2025-02-09 23:49:47 +08:00
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
model.eval()
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
2025-04-24 15:58:39 +08:00
state_dict = model.module.state_dict() #获取模型参数
2025-02-09 23:49:47 +08:00
else:
2025-04-24 15:58:39 +08:00
state_dict = model.state_dict() #获取模型参数
2025-02-09 23:49:47 +08:00
2025-04-24 15:58:39 +08:00
torch.save(state_dict, ckp) #只保存参数
2025-02-09 23:49:47 +08:00
model.train()
def init_model(lm_config):
2025-04-24 15:58:39 +08:00
# 加载tokenizer
2025-02-09 23:49:47 +08:00
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
2025-04-24 15:58:39 +08:00
# 加载模型
2025-02-09 23:49:47 +08:00
model = MiniMindLM(lm_config).to(args.device)
2025-04-24 15:58:39 +08:00
# 打印模型参数
2025-02-09 23:49:47 +08:00
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
return model, tokenizer
def init_distributed_mode():
2025-04-24 15:58:39 +08:00
if not ddp: return #如果没有启用分布式数据并行(DDP),直接返回,不执行任何操作。
global ddp_local_rank, DEVICE #声明这两个变量为全局变量,以便在函数外部也能访问它们。
2025-02-09 23:49:47 +08:00
2025-04-24 15:58:39 +08:00
dist.init_process_group(backend="nccl") #初始化分布式进程组使用NCCL后端NVIDIA Collective Communications Library这是NVIDIA GPU之间通信的优化库。
ddp_rank = int(os.environ["RANK"]) #从环境变量获取当前进程的全局编号。
ddp_local_rank = int(os.environ["LOCAL_RANK"]) #从环境变量获取当前进程的本地编号。
ddp_world_size = int(os.environ["WORLD_SIZE"]) #从环境变量获取当前进程组中的进程总数。
DEVICE = f"cuda:{ddp_local_rank}" #根据本地编号选择GPU设备。
torch.cuda.set_device(DEVICE) #设置当前进程的GPU设备。
2025-02-09 23:49:47 +08:00
# torchrun --nproc_per_node 2 1-pretrain.py
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind Pretraining")
parser.add_argument("--out_dir", type=str, default="out")
# 若要以最快速度实现zero则epochs设置为1轮否则应当利用有限的数据训练2~6个epochs。
2025-04-24 15:58:39 +08:00
parser.add_argument("--epochs", type=int, default=3)
2025-02-11 23:52:40 +08:00
parser.add_argument("--batch_size", type=int, default=32)
2025-02-09 23:49:47 +08:00
parser.add_argument("--learning_rate", type=float, default=5e-4)
2025-04-24 15:58:39 +08:00
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") #如果GPU可用则使用GPU否则使用CPU。
2025-02-09 23:49:47 +08:00
parser.add_argument("--dtype", type=str, default="bfloat16")
2025-04-25 16:29:28 +08:00
parser.add_argument("--use_wandb", default=False, action="store_true")
2025-02-09 23:49:47 +08:00
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
2025-04-24 15:58:39 +08:00
parser.add_argument("--num_workers", type=int, default=8)
2025-02-09 23:49:47 +08:00
parser.add_argument("--ddp", action="store_true")
2025-04-24 15:58:39 +08:00
parser.add_argument("--accumulation_steps", type=int, default=8) #梯度累积步数,用于控制梯度更新频率。
parser.add_argument("--grad_clip", type=float, default=1.0) #梯度裁剪阈值,用于防止梯度爆炸。
parser.add_argument("--warmup_iters", type=int, default=0) #预热迭代次数,用于控制学习率预热过程。
parser.add_argument("--log_interval", type=int, default=100) #日志打印间隔,用于控制日志打印的频率。
parser.add_argument("--save_interval", type=int, default=100) #模型保存间隔,用于控制模型保存的频率。
parser.add_argument('--local_rank', type=int, default=-1) #本地进程编号,用于分布式训练。
2025-04-25 16:29:28 +08:00
parser.add_argument('--dim', default=768, type=int) #模型维度,用于控制模型的大小。
parser.add_argument('--n_layers', default=8, type=int) #层数,用于控制模型层数。
parser.add_argument('--max_seq_len', default=512, type=int) #最大序列长度,用于控制输入序列的最大长度。
2025-04-24 15:58:39 +08:00
parser.add_argument('--use_moe', default=False, type=bool) #是否使用MOE用于控制是否使用MOE。
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。
2025-02-09 23:49:47 +08:00
args = parser.parse_args()
2025-04-24 15:58:39 +08:00
lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe) #创建LMConfig对象用于控制模型配置。
args.save_dir = os.path.join(args.out_dir) #创建保存目录。
os.makedirs(args.save_dir, exist_ok=True) #创建保存目录。
os.makedirs(args.out_dir, exist_ok=True) #创建输出目录。
tokens_per_iter = args.batch_size * lm_config.max_seq_len #计算每个迭代步骤的token数量。
print(f"tokens_per_iter: {tokens_per_iter}")
device_type = "cuda" if "cuda" in args.device else "cpu" #确定设备类型。
2025-02-09 23:49:47 +08:00
args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
2025-04-04 11:39:41 +08:00
base_seed = 1337
torch.manual_seed(base_seed)
torch.cuda.manual_seed(base_seed)
2025-02-09 23:49:47 +08:00
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)
2025-04-04 11:39:41 +08:00
rank = dist.get_rank()
torch.manual_seed(base_seed + rank)
# 同时设置 CUDA 的随机种子
torch.cuda.manual_seed(base_seed + rank)
2025-02-09 23:49:47 +08:00
if args.use_wandb and (not ddp or ddp_local_rank == 0):
import wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
else:
wandb = None
model, tokenizer = init_model(lm_config)
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
shuffle=False,
num_workers=args.num_workers,
sampler=train_sampler
)
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
if ddp:
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
train_epoch(epoch, wandb)