update

commit 83f5cfe6ca
parent 803d1f1b72
@@ -13,6 +13,7 @@ from torch import optim, nn
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from torch.utils.data import DataLoader, DistributedSampler
+# Removed the communication-analysis tool import
 from contextlib import nullcontext
 from typing import Optional
 
@@ -54,6 +55,8 @@ def train_epoch(epoch, wandb):
     optimizer_start = torch.cuda.Event(enable_timing=True)
     optimizer_end = torch.cuda.Event(enable_timing=True)
 
+    # Removed the CUDA graph optimization code
+
     # Prefetch data
     prefetch_factor = 2  # number of batches to prefetch
     data_iter = iter(train_loader)
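Note: the optimizer_start/optimizer_end events in the hunk above follow PyTorch's standard CUDA-event timing pattern. For reference, a minimal self-contained sketch of that pattern (the matmul workload and variable names are illustrative, not from this commit):

import torch

# Paired CUDA events with timing enabled; record() brackets the region to measure.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

x = torch.randn(1024, 1024, device="cuda")
start.record()
y = x @ x  # the GPU work being timed
end.record()

# elapsed_time() is only valid once both events have completed on the device.
torch.cuda.synchronize()
print(f"matmul: {start.elapsed_time(end):.3f} ms")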
@@ -100,6 +103,7 @@ def train_epoch(epoch, wandb):
         if args.profile and (not ddp or dist.get_rank() == 0):
             forward_start.record()
 
+        # Regular forward pass
         with ctx:
             res = model(X)
             loss = loss_fct(
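Note: the forward pass runs under a ctx context manager defined earlier in the script. Given the nullcontext import above, a common arrangement, and only an assumption about what ctx is here, is autocast on GPU and a no-op context on CPU:

from contextlib import nullcontext

import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
# Mixed-precision autocast on GPU, a plain no-op context on CPU.
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()

model = torch.nn.Linear(16, 4).to(device_type)  # stand-in for the real model
X = torch.randn(8, 16, device=device_type)      # stand-in batch
with ctx:
    out = model(X)  # runs in reduced precision under autocast on GPU
print(out.dtype)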
@@ -123,6 +127,9 @@ def train_epoch(epoch, wandb):
             # If an error occurs, do not add the auxiliary loss
         loss = loss / args.accumulation_steps
 
+        # Backward pass
+        scaler.scale(loss).backward()
+
         if args.profile and (not ddp or dist.get_rank() == 0):
             forward_end.record()
             backward_start.record()
@@ -139,9 +146,6 @@ def train_epoch(epoch, wandb):
                 Logger(f"loss.dtype: {loss.dtype}")
                 Logger("-------------------------")
 
-        # Backward pass
-        scaler.scale(loss).backward()
-
         if args.profile and (not ddp or dist.get_rank() == 0):
             backward_end.record()
 
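Note: the net effect of the last two hunks is that scaler.scale(loss).backward() now runs right after the loss is divided by accumulation_steps, before forward_end/backward_start are recorded, instead of after the debug logging. For reference, the usual GradScaler sequence such a call belongs to (a generic sketch with stand-in model and data, not this script's optimizer step):

import torch

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

X = torch.randn(8, 16, device="cuda")
Y = torch.randn(8, 4, device="cuda")

with torch.cuda.amp.autocast():
    loss = torch.nn.functional.mse_loss(model(X), Y)

scaler.scale(loss).backward()                        # scaled backward, as in the hunk
scaler.unscale_(optimizer)                           # unscale before gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)                               # skipped if inf/NaN gradients are found
scaler.update()
optimizer.zero_grad(set_to_none=True)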
@@ -245,6 +249,8 @@ def train_epoch(epoch, wandb):
 
             wandb.log(log_dict)
 
+        # Removed the communication-analysis code
+
         # Save the model
         if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
             model.eval()
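Note: the `not ddp or dist.get_rank() == 0` guard makes only one process write checkpoints. A minimal sketch of that pattern with illustrative names (save_checkpoint is not a function from this repository):

import os

import torch
import torch.distributed as dist

def save_checkpoint(model, save_dir, step, ddp):
    # Only rank 0 writes; every other DDP worker returns immediately.
    if ddp and dist.get_rank() != 0:
        return
    model.eval()
    # Unwrap DistributedDataParallel so saved keys carry no "module." prefix.
    state_dict = model.module.state_dict() if ddp else model.state_dict()
    torch.save(state_dict, os.path.join(save_dir, f"checkpoint_step_{step}.pth"))
    model.train()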
@@ -303,6 +309,9 @@ def init_model(lm_config, pretrained_embedding_path: Optional[str] = None):
     return model, tokenizer
 
 
+# Removed the communication-analysis functions
+
+
 def init_distributed_mode():
     if not ddp: return  # If DDP is not enabled, return immediately and do nothing.
     global ddp_local_rank, DEVICE  # Declare these as globals so they can also be accessed outside the function.
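Note: only the first two lines of init_distributed_mode appear in this hunk. A typical body for such a function under torchrun (an assumption about the rest, not the actual code) reads the environment variables that torchrun sets for each worker:

import os

import torch
import torch.distributed as dist

def init_distributed_mode_sketch():
    # torchrun exports RANK / LOCAL_RANK / WORLD_SIZE for every worker process.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{local_rank}"
    torch.cuda.set_device(device)
    return local_rank, device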
@@ -344,7 +353,8 @@ if __name__ == "__main__":
     parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
     # Profiling-related arguments
     parser.add_argument("--profile", action="store_true", default=True, help="Enable profiling")
-    parser.add_argument("--profile_interval", type=int, default=100, help="Profiling print interval (steps)")
+    parser.add_argument("--profile_interval", type=int, default=10, help="Profiling print interval (steps)")
+    parser.add_argument("--use_flash_attn", action="store_true", default=True, help="Enable FlashAttention")
     args = parser.parse_args()
     print(args)
 
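Two notes on the flags above: declaring --profile and --use_flash_attn with action="store_true" and default=True means they are effectively always on, since a store_true flag cannot be switched off from the command line; and --profile_interval now prints timings every 10 steps instead of every 100. A hypothetical gate that consumes these values (not the script's actual logging code):

import torch

def maybe_log_timing(step, profile, profile_interval, start_evt, end_evt, tag="forward"):
    """Print a CUDA-event timing every `profile_interval` steps when profiling is enabled."""
    if not profile or (step + 1) % profile_interval != 0:
        return
    torch.cuda.synchronize()  # both events must have completed before elapsed_time()
    print(f"step {step + 1}: {tag} {start_evt.elapsed_time(end_evt):.2f} ms")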
@@ -354,7 +364,8 @@ if __name__ == "__main__":
         n_layers=args.n_layers,
         max_seq_len=args.max_seq_len,
         use_moe=args.use_moe,
-        disable_db=args.disable_db  # add the disable-database option
+        disable_db=args.disable_db,  # add the disable-database option
+        flash_attn=args.use_flash_attn  # add FlashAttention support
     )  # Create the LMConfig object that controls the model configuration.
     args.save_dir = os.path.join(args.out_dir)  # Create the save directory.
     os.makedirs(args.save_dir, exist_ok=True)  # Create the save directory.
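Note: the new flash_attn field is forwarded into LMConfig; inside the model it presumably selects a fused attention kernel. A common way to implement such a toggle in PyTorch is scaled_dot_product_attention, which dispatches to FlashAttention on supported GPUs (this is an assumption about how the flag is consumed, not code from this repository):

import torch
import torch.nn.functional as F

def attention(q, k, v, use_flash: bool):
    # q, k, v: (batch, heads, seq_len, head_dim)
    if use_flash:
        # Fused kernel; uses FlashAttention when hardware and dtype allow it.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Naive fallback: softmax(QK^T / sqrt(d)) V with an explicit causal mask.
    d = q.size(-1)
    scores = (q @ k.transpose(-2, -1)) / d ** 0.5
    causal = torch.triu(torch.full(scores.shape[-2:], float("-inf"), device=q.device), diagonal=1)
    return torch.softmax(scores + causal, dim=-1) @ v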
|
Loading…
x
Reference in New Issue
Block a user