commit 5f19adcffa
parent 44cd7b4d72
Author: Yu Chengzhang
Date:   2025-06-23 23:05:47 +08:00


@@ -228,15 +228,15 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
     best_loss = float('10000')
     # Add CUDA events to profile performance (main process only)
-    if args.profile and accelerator.is_main_process:
-        data_start = torch.cuda.Event(enable_timing=True)
-        data_end = torch.cuda.Event(enable_timing=True)
-        forward_start = torch.cuda.Event(enable_timing=True)
-        forward_end = torch.cuda.Event(enable_timing=True)
-        backward_start = torch.cuda.Event(enable_timing=True)
-        backward_end = torch.cuda.Event(enable_timing=True)
-        optimizer_start = torch.cuda.Event(enable_timing=True)
-        optimizer_end = torch.cuda.Event(enable_timing=True)
+    # if args.profile and accelerator.is_main_process:
+    #     data_start = torch.cuda.Event(enable_timing=True)
+    #     data_end = torch.cuda.Event(enable_timing=True)
+    #     forward_start = torch.cuda.Event(enable_timing=True)
+    #     forward_end = torch.cuda.Event(enable_timing=True)
+    #     backward_start = torch.cuda.Event(enable_timing=True)
+    #     backward_end = torch.cuda.Event(enable_timing=True)
+    #     optimizer_start = torch.cuda.Event(enable_timing=True)
+    #     optimizer_end = torch.cuda.Event(enable_timing=True)
     # Prefetch data
     prefetch_factor = 2  # number of batches to prefetch
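
For readers unfamiliar with the profiling code being disabled above: torch.cuda.Event timing follows a record / synchronize / elapsed_time pattern. A minimal, self-contained sketch (illustrative tensor and variable names, assumes a CUDA device is available):

    import torch

    start = torch.cuda.Event(enable_timing=True)  # same constructor as in the diff
    end = torch.cuda.Event(enable_timing=True)

    x = torch.randn(1024, 1024, device="cuda")
    start.record()                  # enqueue the start marker on the current stream
    y = x @ x                       # the work being timed
    end.record()                    # enqueue the end marker

    torch.cuda.synchronize()        # events are asynchronous; sync before reading
    print(start.elapsed_time(end))  # elapsed time in milliseconds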
@@ -257,8 +257,8 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
     for step in range(total_steps_in_epoch):
         try:
             # Time data loading (main process only)
-            if args.profile and accelerator.is_main_process:
-                data_start.record()
+            # if args.profile and accelerator.is_main_process:
+            #     data_start.record()
             # Use the prefetched data
             if prefetch_batches:
@@ -276,16 +276,16 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                 pass
             # End timing of data loading (main process only)
-            if args.profile and accelerator.is_main_process:
-                data_end.record()
+            # if args.profile and accelerator.is_main_process:
+            #     data_end.record()
             # Update the learning rate
             if scheduler is not None:
                 scheduler.step()
             # Time the forward pass (main process only)
-            if args.profile and accelerator.is_main_process:
-                forward_start.record()
+            # if args.profile and accelerator.is_main_process:
+            #     forward_start.record()
             # Forward pass
             with ctx:
@@ -311,24 +311,24 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                 loss = loss / args.accumulation_steps
             # End timing of the forward pass (main process only)
-            if args.profile and accelerator.is_main_process:
-                forward_end.record()
+            # if args.profile and accelerator.is_main_process:
+            #     forward_end.record()
             # Time the backward pass (main process only)
-            if args.profile and accelerator.is_main_process:
-                backward_start.record()
+            # if args.profile and accelerator.is_main_process:
+            #     backward_start.record()
             # Backward pass
            # When using DeepSpeed, it handles gradient accumulation and gradient clipping automatically
             accelerator.backward(loss)
             # End timing of the backward pass (main process only)
-            if args.profile and accelerator.is_main_process:
-                backward_end.record()
+            # if args.profile and accelerator.is_main_process:
+            #     backward_end.record()
             # Time the optimizer step (main process only)
-            if args.profile and accelerator.is_main_process:
-                optimizer_start.record()
+            # if args.profile and accelerator.is_main_process:
+            #     optimizer_start.record()
             # Optimizer step - when using DeepSpeed, it handles gradient accumulation and gradient clipping automatically
             # The optimizer step only runs once the accumulation step count is reached
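
The two comments above describe the standard gradient-accumulation pattern: the loss is pre-divided by args.accumulation_steps and the optimizer only steps at the end of each accumulation window. A minimal sketch of that generic pattern with Hugging Face Accelerate (hypothetical toy model and values, not this repository's training loop):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(16, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model, optimizer = accelerator.prepare(model, optimizer)

    accumulation_steps = 4  # illustrative value
    for step in range(32):
        x = torch.randn(8, 16, device=accelerator.device)
        loss = model(x).pow(2).mean() / accumulation_steps  # scale so accumulated grads average
        accelerator.backward(loss)
        if (step + 1) % accumulation_steps == 0:  # step only at the end of each window
            optimizer.step()
            optimizer.zero_grad()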
@@ -340,8 +340,8 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                 optimizer.zero_grad()
             # End timing of the optimizer step (main process only)
-            if args.profile and accelerator.is_main_process:
-                optimizer_end.record()
+            # if args.profile and accelerator.is_main_process:
+            #     optimizer_end.record()
             # Print training information (main process only)
             if (step + 1) % args.log_interval == 0 and accelerator.is_main_process:
@@ -419,7 +419,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                 # Save the model (main process only)
                 loss_total = loss.item() * args.accumulation_steps
-                if best_loss > loss_total and accelerator.is_main_process:
+                if epoch > 1 or best_loss > loss_total and accelerator.is_main_process:
                     best_loss = loss_total
                     # Use the moe_path variable defined at the start of the function
                     ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'
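
One detail of the new save condition worth noting: Python's `and` binds more tightly than `or`, so the expression groups as `epoch > 1 or (best_loss > loss_total and accelerator.is_main_process)`. Once epoch > 1, the branch is therefore taken on every process, regardless of the loss comparison. A tiny illustration with made-up values:

    # Hypothetical values chosen only to show the grouping
    epoch, best_loss, loss_total, is_main_process = 2, 0.5, 1.0, False
    # best_loss > loss_total is False and is_main_process is False,
    # yet the overall condition is True because epoch > 1
    assert epoch > 1 or best_loss > loss_total and is_main_process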