This commit is contained in:
Jax922 2025-05-12 19:11:04 +08:00
parent 803d1f1b72
commit 83f5cfe6ca

View File

@ -13,6 +13,7 @@ from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data import DataLoader, DistributedSampler
# 移除通信分析工具导入
from contextlib import nullcontext from contextlib import nullcontext
from typing import Optional from typing import Optional
@ -54,6 +55,8 @@ def train_epoch(epoch, wandb):
optimizer_start = torch.cuda.Event(enable_timing=True) optimizer_start = torch.cuda.Event(enable_timing=True)
optimizer_end = torch.cuda.Event(enable_timing=True) optimizer_end = torch.cuda.Event(enable_timing=True)
# 移除CUDA图优化代码
# 预取数据 # 预取数据
prefetch_factor = 2 # 预取的批次数 prefetch_factor = 2 # 预取的批次数
data_iter = iter(train_loader) data_iter = iter(train_loader)
@ -100,6 +103,7 @@ def train_epoch(epoch, wandb):
if args.profile and (not ddp or dist.get_rank() == 0): if args.profile and (not ddp or dist.get_rank() == 0):
forward_start.record() forward_start.record()
# 常规前向传播
with ctx: with ctx:
res = model(X) res = model(X)
loss = loss_fct( loss = loss_fct(
@ -123,6 +127,9 @@ def train_epoch(epoch, wandb):
# 如果出错,不添加辅助损失 # 如果出错,不添加辅助损失
loss = loss / args.accumulation_steps loss = loss / args.accumulation_steps
# 反向传播
scaler.scale(loss).backward()
if args.profile and (not ddp or dist.get_rank() == 0): if args.profile and (not ddp or dist.get_rank() == 0):
forward_end.record() forward_end.record()
backward_start.record() backward_start.record()
@ -139,9 +146,6 @@ def train_epoch(epoch, wandb):
Logger(f"loss.dtype: {loss.dtype}") Logger(f"loss.dtype: {loss.dtype}")
Logger("-------------------------") Logger("-------------------------")
# 反向传播
scaler.scale(loss).backward()
if args.profile and (not ddp or dist.get_rank() == 0): if args.profile and (not ddp or dist.get_rank() == 0):
backward_end.record() backward_end.record()
@ -245,6 +249,8 @@ def train_epoch(epoch, wandb):
wandb.log(log_dict) wandb.log(log_dict)
# 移除通信分析代码
# 保存模型 # 保存模型
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0): if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
model.eval() model.eval()
@ -303,6 +309,9 @@ def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
return model, tokenizer return model, tokenizer
# 移除通信分析函数
def init_distributed_mode(): def init_distributed_mode():
if not ddp: return #如果没有启用分布式数据并行(DDP),直接返回,不执行任何操作。 if not ddp: return #如果没有启用分布式数据并行(DDP),直接返回,不执行任何操作。
global ddp_local_rank, DEVICE #声明这两个变量为全局变量,以便在函数外部也能访问它们。 global ddp_local_rank, DEVICE #声明这两个变量为全局变量,以便在函数外部也能访问它们。
@ -344,7 +353,8 @@ if __name__ == "__main__":
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)") parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
# 性能分析相关参数 # 性能分析相关参数
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
parser.add_argument("--profile_interval", type=int, default=100, help="性能分析打印间隔(步数)") parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
@ -354,7 +364,8 @@ if __name__ == "__main__":
n_layers=args.n_layers, n_layers=args.n_layers,
max_seq_len=args.max_seq_len, max_seq_len=args.max_seq_len,
use_moe=args.use_moe, use_moe=args.use_moe,
disable_db=args.disable_db # 添加禁用数据库参数 disable_db=args.disable_db, # 添加禁用数据库参数
flash_attn=args.use_flash_attn # 添加FlashAttention支持
) #创建LMConfig对象用于控制模型配置。 ) #创建LMConfig对象用于控制模型配置。
args.save_dir = os.path.join(args.out_dir) #创建保存目录。 args.save_dir = os.path.join(args.out_dir) #创建保存目录。
os.makedirs(args.save_dir, exist_ok=True) #创建保存目录。 os.makedirs(args.save_dir, exist_ok=True) #创建保存目录。