From 83f5cfe6ca6b9d4b4df5486b3ff0af36f9ec9afc Mon Sep 17 00:00:00 2001 From: Jax922 <1322037892@qq.com> Date: Mon, 12 May 2025 19:11:04 +0800 Subject: [PATCH] update --- train_pretrain.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/train_pretrain.py b/train_pretrain.py index 4eeabbd..7397984 100644 --- a/train_pretrain.py +++ b/train_pretrain.py @@ -13,6 +13,7 @@ from torch import optim, nn from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import CosineAnnealingLR from torch.utils.data import DataLoader, DistributedSampler +# 移除通信分析工具导入 from contextlib import nullcontext from typing import Optional @@ -54,6 +55,8 @@ def train_epoch(epoch, wandb): optimizer_start = torch.cuda.Event(enable_timing=True) optimizer_end = torch.cuda.Event(enable_timing=True) + # 移除CUDA图优化代码 + # 预取数据 prefetch_factor = 2 # 预取的批次数 data_iter = iter(train_loader) @@ -100,6 +103,7 @@ def train_epoch(epoch, wandb): if args.profile and (not ddp or dist.get_rank() == 0): forward_start.record() + # 常规前向传播 with ctx: res = model(X) loss = loss_fct( @@ -123,6 +127,9 @@ def train_epoch(epoch, wandb): # 如果出错,不添加辅助损失 loss = loss / args.accumulation_steps + # 反向传播 + scaler.scale(loss).backward() + if args.profile and (not ddp or dist.get_rank() == 0): forward_end.record() backward_start.record() @@ -139,9 +146,6 @@ def train_epoch(epoch, wandb): Logger(f"loss.dtype: {loss.dtype}") Logger("-------------------------") - # 反向传播 - scaler.scale(loss).backward() - if args.profile and (not ddp or dist.get_rank() == 0): backward_end.record() @@ -245,6 +249,8 @@ def train_epoch(epoch, wandb): wandb.log(log_dict) + # 移除通信分析代码 + # 保存模型 if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): model.eval() @@ -303,6 +309,9 @@ def init_model(lm_config, pretrained_embedding_path: Optional[str] = None): return model, tokenizer +# 移除通信分析函数 + + def init_distributed_mode(): if not ddp: return #如果没有启用分布式数据并行(DDP),直接返回,不执行任何操作。 global ddp_local_rank, DEVICE #声明这两个变量为全局变量,以便在函数外部也能访问它们。 @@ -344,7 +353,8 @@ if __name__ == "__main__": parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)") # 性能分析相关参数 parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") - parser.add_argument("--profile_interval", type=int, default=100, help="性能分析打印间隔(步数)") + parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)") + parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention") args = parser.parse_args() print(args) @@ -354,7 +364,8 @@ if __name__ == "__main__": n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe, - disable_db=args.disable_db # 添加禁用数据库参数 + disable_db=args.disable_db, # 添加禁用数据库参数 + flash_attn=args.use_flash_attn # 添加FlashAttention支持 ) #创建LMConfig对象,用于控制模型配置。 args.save_dir = os.path.join(args.out_dir) #创建保存目录。 os.makedirs(args.save_dir, exist_ok=True) #创建保存目录。