# Minimind/train_pretrain_accelerate.py
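"""Pretraining entry point for MiniMind using HuggingFace Accelerate with DeepSpeed.

Illustrative launch command (an assumption for convenience, not taken from the
project docs; the accelerate/DeepSpeed configuration on your machine may differ):

    accelerate launch train_pretrain_accelerate.py --epochs 3 --batch_size 24 \
        --data_path ./dataset/pretrain_hq.jsonl

All command-line flags are defined in main() below.
"""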


import os
# Set environment variables
os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
import platform
import argparse
import time
import math
import warnings
import pandas as pd
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from contextlib import nullcontext
from typing import Optional
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset
warnings.filterwarnings('ignore')


# Logging helper
def Logger(msg, accelerator=None):
    # Print when no accelerator is given, otherwise only on the main process
    if accelerator is None or accelerator.is_main_process:
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}")


# Learning-rate helper
def get_lr(it, num_iters, learning_rate):
    # Cosine learning-rate decay
    return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters))
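# Reference note: get_lr above is not called anywhere in this script; the LR is
# driven by the get_cosine_schedule_with_warmup scheduler created in main().
# Illustrative endpoints of the formula: get_lr(0, N, lr) == lr and get_lr(N, N, lr) == 0.0.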
# Model initialization
def init_model(lm_config, pretrained_embedding_path=None):
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model = MiniMindLM(lm_config)
    # Load pretrained embedding weights if they were provided
    if pretrained_embedding_path:
        Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
        pretrained_embeddings = torch.load(pretrained_embedding_path)
        model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
        model.output.weight.data.copy_(pretrained_embeddings)  # tied weights
    Logger(f'Total trainable LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
    return model, tokenizer
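# Assumption: the pretrained embedding tensor loaded above is expected to match the
# model's tok_embeddings weight shape, i.e. (vocab_size, dim); incompatible shapes
# will make Tensor.copy_ fail at load time.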
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx):
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    start_time = time.time()
    # Define moe_path at the start of the function so the exception handler below
    # never references an undefined variable
    moe_path = '_moe' if args.use_moe else ''
    # CUDA events for profiling
    if args.profile and accelerator.is_main_process:
        data_start = torch.cuda.Event(enable_timing=True)
        data_end = torch.cuda.Event(enable_timing=True)
        forward_start = torch.cuda.Event(enable_timing=True)
        forward_end = torch.cuda.Event(enable_timing=True)
        backward_start = torch.cuda.Event(enable_timing=True)
        backward_end = torch.cuda.Event(enable_timing=True)
        optimizer_start = torch.cuda.Event(enable_timing=True)
        optimizer_end = torch.cuda.Event(enable_timing=True)
    # Manual data prefetching
    prefetch_factor = 2  # number of batches to prefetch
    data_iter = iter(train_loader)
    prefetch_batches = []
    # Prefetch the initial batches
    for _ in range(min(prefetch_factor, len(train_loader))):
        try:
            batch = next(data_iter)
            prefetch_batches.append(batch)
        except StopIteration:
            break
    for step in range(len(train_loader)):
        try:
            # Time data loading
            if args.profile and accelerator.is_main_process:
                data_start.record()
            # Use a prefetched batch if one is available
            if prefetch_batches:
                X, Y, loss_mask = prefetch_batches.pop(0)
            else:
                # Fall back to loading directly if the prefetch queue is empty
                X, Y, loss_mask = next(data_iter)
            # Prefetch the next batch
            if step + prefetch_factor < len(train_loader):
                try:
                    batch = next(data_iter)
                    prefetch_batches.append(batch)
                except StopIteration:
                    pass
            if args.profile and accelerator.is_main_process:
                data_end.record()
            # Update the learning rate
            if scheduler is not None:
                scheduler.step()
            # Time the forward pass
            if args.profile and accelerator.is_main_process:
                forward_start.record()
            # Forward pass
            with ctx:
                res = model(X)
                loss = loss_fct(
                    res.logits.view(-1, res.logits.size(-1)),
                    Y.view(-1)
                ).view(Y.size())
                loss = (loss * loss_mask).sum() / loss_mask.sum()
                # Add the auxiliary (MoE) loss if present
                try:
                    aux_loss = sum(l.feed_forward.aux_loss for l in model.module.layers
                                   if hasattr(l.feed_forward, 'aux_loss'))
                    loss += aux_loss
                except Exception as e:
                    # Skip the auxiliary loss on failure
                    Logger(f"Warning: Could not add auxiliary loss: {e}")
                loss = loss / args.accumulation_steps
            if args.profile and accelerator.is_main_process:
                forward_end.record()
            # Time the backward pass
            if args.profile and accelerator.is_main_process:
                backward_start.record()
            # Backward pass; DeepSpeed handles gradient accumulation and clipping automatically
            accelerator.backward(loss)
            if args.profile and accelerator.is_main_process:
                backward_end.record()
            # Time the optimizer step
            if args.profile and accelerator.is_main_process:
                optimizer_start.record()
            # Optimizer step - with DeepSpeed, gradient accumulation and clipping are
            # handled automatically, so there is no need to check step % accumulation_steps;
            # the step only takes effect once the accumulation boundary is reached
            optimizer.step()
            # DeepSpeed calls zero_grad() after step() automatically,
            # but we call it explicitly to be safe
            optimizer.zero_grad()
            if args.profile and accelerator.is_main_process:
                optimizer_end.record()
            # Log training info
            if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
                # Compute timing metrics
                if args.profile:
                    torch.cuda.synchronize()
                    data_time = data_start.elapsed_time(data_end) if step > 0 else 0
                    forward_time = forward_start.elapsed_time(forward_end)
                    backward_time = backward_start.elapsed_time(backward_end)
                    optimizer_time = optimizer_start.elapsed_time(optimizer_end) if (step + 1) % args.accumulation_steps == 0 else 0
                    total_time = data_time + forward_time + backward_time + optimizer_time
                    # Print profiling stats
                    if (step + 1) % (args.log_interval * args.profile_interval) == 0:
                        Logger(f"Profiling - data loading: {data_time:.2f}ms ({data_time/total_time*100:.1f}%), "
                               f"forward: {forward_time:.2f}ms ({forward_time/total_time*100:.1f}%), "
                               f"backward: {backward_time:.2f}ms ({backward_time/total_time*100:.1f}%), "
                               f"optimizer: {optimizer_time:.2f}ms ({optimizer_time/total_time*100:.1f}%)", accelerator)
                # Current learning rate
                current_lr = optimizer.param_groups[0]['lr']
                # Training throughput
                elapsed_time = time.time() - start_time
                tokens_per_sec = (step + 1) * args.batch_size * args.max_seq_len / elapsed_time
                Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{len(train_loader)}, "
                       f"Loss: {loss.item()*args.accumulation_steps:.4f}, "
                       f"LR: {current_lr:.6f}, "
                       f"Speed: {tokens_per_sec:.2f} tokens/sec", accelerator)
            # Save a checkpoint
            if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
                # moe_path was defined at the top of the function
                ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
                # Unwrap the model from its distributed wrapper
                unwrapped_model = accelerator.unwrap_model(model)
                # Save the model weights
                accelerator.save(unwrapped_model.state_dict(), ckp)
                Logger(f"Model saved to {ckp}", accelerator)
        except Exception as e:
            Logger(f"Error in training step: {e}", accelerator)
            import traceback
            Logger(traceback.format_exc(), accelerator)
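# Sketch of the masked-loss reduction used in train_epoch (for reference only,
# assuming a 0/1 loss_mask with the same shape as Y):
#
#   per_token = loss_fct(logits.view(-1, vocab), Y.view(-1)).view(Y.size())
#   loss = (per_token * loss_mask).sum() / loss_mask.sum()
#
# Padded positions (mask == 0) contribute nothing to the loss or its gradient.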
def main():
    parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
    parser.add_argument("--out_dir", type=str, default="out")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=24)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--dtype", type=str, default="bfloat16")
    parser.add_argument("--use_wandb", default=True, action="store_true")
    parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
    parser.add_argument("--num_workers", type=int, default=48)
    parser.add_argument("--accumulation_steps", type=int, default=32)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=100)
    parser.add_argument("--save_interval", type=int, default=10000)
    parser.add_argument('--dim', default=1024, type=int)
    parser.add_argument('--n_layers', default=32, type=int)
    parser.add_argument('--max_seq_len', default=1024, type=int)
    parser.add_argument('--use_moe', action='store_true', default=False)
    parser.add_argument('--disable_db', action='store_true', help="Disable the database feature and use the fixed value 1e-4 instead")
    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
    parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
    parser.add_argument("--profile", action="store_true", default=True, help="Enable performance profiling")
    parser.add_argument("--profile_interval", type=int, default=10, help="Profiling print interval (in steps)")
    parser.add_argument("--use_flash_attn", action="store_true", default=True, help="Enable FlashAttention")
    args = parser.parse_args()
    # Initialize the accelerator
    # DDP kwargs so that unused parameters are tolerated
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    # Build the DeepSpeed plugin
    ds_plugin = DeepSpeedPlugin(
        gradient_accumulation_steps=args.accumulation_steps,
        gradient_clipping=args.grad_clip,
        zero_stage=2,                    # ZeRO-2 optimization
        offload_optimizer_device="cpu",  # offload optimizer state to CPU
        offload_param_device="none",     # keep parameters on the GPU
    )
    accelerator = Accelerator(
        kwargs_handlers=[ddp_kwargs],
        deepspeed_plugin=ds_plugin,
        mixed_precision="bf16" if args.dtype == "bfloat16" else "fp16" if args.dtype == "float16" else "no"
    )
    # Set the random seed (offset per process)
    set_seed(1337 + accelerator.process_index)
    # Model configuration
    lm_config = LMConfig(
        dim=args.dim,
        n_layers=args.n_layers,
        max_seq_len=args.max_seq_len,
        use_moe=args.use_moe,
        disable_db=args.disable_db,
        flash_attn=args.use_flash_attn
    )
    # Create the output directories
    args.save_dir = os.path.join(args.out_dir)
    if accelerator.is_main_process:
        os.makedirs(args.save_dir, exist_ok=True)
        os.makedirs(args.out_dir, exist_ok=True)
    # Tokens processed per iteration (per process)
    tokens_per_iter = args.batch_size * lm_config.max_seq_len
    Logger(f"tokens_per_iter: {tokens_per_iter}", accelerator)
    # Select the torch dtype
    pt_dtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
    # wandb run name
    args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
    # Automatic mixed-precision context
    ctx = nullcontext() if accelerator.device.type == "cpu" else torch.cuda.amp.autocast(dtype=pt_dtype)
    # Initialize the model and tokenizer
    model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
    # Pass the accelerator to Logger so only the main process prints
    Logger('Model initialization complete', accelerator)
    # Handle the complex-valued pos_cis tensor.
    # Option 1 would be to convert pos_cis into two real tensors (real and imaginary parts);
    # here we use option 2 and tell accelerate to ignore pos_cis, which also covers
    # the DeepSpeed case.
    if hasattr(model, "pos_cis"):
        Logger('Detected complex tensor pos_cis; excluding it from distributed synchronization', accelerator)
        # Set the model's _ddp_params_and_buffers_to_ignore attribute
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
    # Dataset and dataloader
    train_ds = PretrainDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=True,
        num_workers=args.num_workers,
        persistent_workers=True if args.num_workers > 0 else False,
        prefetch_factor=2 if args.num_workers > 0 else None
    )
    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    # Learning-rate scheduler
    total_steps = len(train_loader) * args.epochs
    warmup_steps = args.warmup_iters if args.warmup_iters > 0 else int(0.1 * total_steps)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
    # Let accelerate prepare everything for distributed training
    model, optimizer, train_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, scheduler
    )
    # Initialize wandb
    if args.use_wandb and accelerator.is_main_process:
        import wandb
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)
    else:
        wandb = None
    # Training loop
    for epoch in range(args.epochs):
        train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx)
    # Close wandb
    if args.use_wandb and accelerator.is_main_process:
        wandb.finish()


if __name__ == "__main__":
    main()