Gary 2025-05-07 22:03:41 +08:00
parent 0859f54a88
commit a5a39d8c9b
2 changed files with 87 additions and 63 deletions

View File

@@ -173,31 +173,42 @@ class CrossAttention(nn.Module):
):
super().__init__()
self.config = config
self.num_heads = 8
self.head_dim = 768 // self.num_heads
self.to_q = nn.Linear(768, 768, bias=False)
self.to_k = nn.Linear(768, 768, bias=False)
self.to_v = nn.Linear(768, 768, bias=False)
self.to_out = nn.Linear(768, 768, bias=False)
def forward(self, x, db, context_mask=None, pos_emb=None):
# db = db.permute(0, 2, 1)
q = self.to_q(x)
k = self.to_k(db)
v = self.to_v(db)
batch_size = x.size(0)
# split into multiple heads
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if pos_emb is not None:
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
q = q + pos_emb
k = k + pos_emb
v = v + pos_emb
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if context_mask is not None:
attn_scores = attn_scores.masked_fill(context_mask == 0, -1e10)
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)
context = torch.matmul(attn_weights, v)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, 768)
context = self.to_out(context)
return context
class FeedForward(nn.Module):
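The refactor above turns the single-head cross-attention into 8 heads of 768 // 8 = 96 dimensions each and changes the softmax scaling from sqrt(k.size(-1)) = sqrt(768) to sqrt(head_dim). A minimal standalone sketch of the same reshaping pattern (tensor sizes are hypothetical, and random matrices stand in for the learned to_q/to_k/to_v projections):

import math
import torch
import torch.nn.functional as F

batch_size, q_len, kv_len, hidden, num_heads = 2, 16, 32, 768, 8
head_dim = hidden // num_heads                 # 96

x = torch.randn(batch_size, q_len, hidden)     # queries come from x
db = torch.randn(batch_size, kv_len, hidden)   # keys/values come from db
w_q, w_k, w_v = (torch.randn(hidden, hidden) for _ in range(3))  # stand-ins for the Linear layers

# project, then split into heads: (B, L, 768) -> (B, num_heads, L, head_dim)
q = (x @ w_q).view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
k = (db @ w_k).view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
v = (db @ w_v).view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

# scaled dot-product attention per head, scaled by sqrt(head_dim) as in the new code
attn = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim), dim=-1)
context = torch.matmul(attn, v)

# merge the heads back: (B, num_heads, q_len, head_dim) -> (B, q_len, 768)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, hidden)
print(context.shape)                           # torch.Size([2, 16, 768])

For the masked path, the new code broadcasts context_mask across the head dimension with unsqueeze(1).expand(-1, num_heads, -1, -1), so the same (q_len, kv_len) mask is applied to every head.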

View File

@@ -37,73 +37,85 @@ def get_lr(current_step, total_steps, lr):
def train_epoch(epoch, wandb):
loss_fct = nn.CrossEntropyLoss(reduction='none')  # cross-entropy loss; with reduction='none' it performs no aggregation and returns a separate loss value per token
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
for step, (X, Y, loss_mask) in enumerate(train_loader):
# move the batch to the device
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
try:
# move the batch to the device
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
# update the learning rate
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# update the learning rate
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
with ctx:
res = model(X)  # forward pass to get the model output
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())  # per-token loss, reshaped back to (batch, seq_len)
loss = (loss * loss_mask).sum() / loss_mask.sum()  # masked mean: average over the valid tokens only
# scale for gradient accumulation: the effective batch size is num_gpus * batch_size_per_gpu * accumulation_steps
loss += res.aux_loss
loss = loss / args.accumulation_steps
with ctx:
res = model(X)
loss = loss_fct(
res.logits.view(-1, res.logits.size(-1)),
Y.view(-1)
).view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask.sum()
loss += res.aux_loss
loss = loss / args.accumulation_steps
scaler.scale(loss).backward()  # part of mixed-precision training: the loss is scaled automatically to avoid numerical instability (underflow) in low-precision (e.g. FP16) computation
scaler.scale(loss).backward()
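The masked reduction above averages the per-token cross-entropy over real tokens only, then divides by the accumulation count so that the summed gradients of the micro-batches match a plain average over the larger effective batch; res.aux_loss (an auxiliary loss, e.g. MoE load balancing when use_moe is set) is added on top before that division. A small self-contained sketch of the reduction, with hypothetical shapes:

import torch
import torch.nn as nn

vocab_size, accumulation_steps = 6400, 8                 # hypothetical values
logits = torch.randn(2, 5, vocab_size)                   # (batch, seq_len, vocab) as produced by the model
Y = torch.randint(0, vocab_size, (2, 5))                 # target token ids
loss_mask = torch.tensor([[1, 1, 1, 0, 0],
                          [1, 1, 1, 1, 0]], dtype=torch.float32)  # 1 = real token, 0 = padding

loss_fct = nn.CrossEntropyLoss(reduction='none')         # keep the per-token losses
loss = loss_fct(logits.view(-1, vocab_size), Y.view(-1)).view(Y.size())
loss = (loss * loss_mask).sum() / loss_mask.sum()        # mean over unmasked tokens only
loss = loss / accumulation_steps                         # accumulated grads then average over micro-batches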
# once the accumulation count is reached, apply the update
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)  # part of PyTorch automatic mixed precision (AMP): un-scales the gradients that were scaled to prevent underflow
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)  # gradient clipping against exploding gradients: rescales so the norm does not exceed args.grad_clip
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)  # update the model weights via the optimizer, mediated by the scaler for mixed-precision training
scaler.update()  # adjust the scale factor for the next iteration based on whether gradients overflowed in this one
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)  # clear all gradients for the next iteration; set_to_none=True saves memory by setting them to None instead of zero
optimizer.zero_grad(set_to_none=True)
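The unscale / clip / step / update sequence is the standard torch.cuda.amp recipe for gradient accumulation: gradients are only unscaled and clipped right before the optimizer step, and the scaler skips the step entirely if it finds inf/NaN gradients. A runnable sketch of one accumulation cycle with a hypothetical tiny model (the scaler is disabled when no GPU is present, which turns its calls into pass-throughs on CPU):

import torch
import torch.nn as nn

model = nn.Linear(16, 4)                           # hypothetical tiny model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
accumulation_steps, grad_clip = 4, 1.0

for step in range(8):
    x, y = torch.randn(2, 16), torch.randn(2, 4)
    loss = nn.functional.mse_loss(model(x), y) / accumulation_steps
    scaler.scale(loss).backward()                  # gradients keep accumulating across micro-batches
    if (step + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)                 # bring grads back to true scale before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)                     # skipped internally if inf/NaN grads were found
        scaler.update()                            # adapt the loss scale for the next iteration
        optimizer.zero_grad(set_to_none=True)      # free the grad buffers instead of zeroing them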
# log progress
if step % args.log_interval == 0:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,
iter_per_epoch,
loss.item() * args.accumulation_steps,
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
# log progress
if step % args.log_interval == 0:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,
iter_per_epoch,
loss.item() * args.accumulation_steps,
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({"loss": loss.item() * args.accumulation_steps,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({"loss": loss.item() * args.accumulation_steps,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
# save a checkpoint
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
model.eval()
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
# save a checkpoint
if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
model.eval()
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict() # get the parameters of the underlying model, without the DDP wrapper
else:
state_dict = model.state_dict() # get the model parameters
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict() # get the parameters of the underlying model, without the DDP wrapper
else:
state_dict = model.state_dict() # get the model parameters
torch.save(state_dict, ckp) # save only the parameters (state_dict), not the full model object
model.train()
torch.save(state_dict, ckp) # save only the parameters (state_dict), not the full model object
model.train()
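Because only a plain state_dict is written (the DDP wrapper is stripped via model.module), the checkpoint can later be restored into a bare model without DistributedDataParallel. A sketch of the load side, assuming a bare model built the same way as in init_model and the same args/lm_config in scope:

# load side (sketch): restore the pretrained weights into a bare, non-DDP model
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
model.load_state_dict(torch.load(ckp, map_location=args.device))
model.eval()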
except Exception as e:
print(f"Error occurred: {str(e)}")
for name, param in model.named_parameters():
if param.grad is not None and torch.isnan(param.grad).any():
print(f"NaN gradient in parameter: {name}")
for name, param in model.named_parameters():
if param.grad is not None and torch.isnan(param.grad).any():
print(f"Parameter {name} values: {param.data}")
print(f"Parameter {name} gradients: {param.grad}")
raise ValueError("NaN gradient detected")
def init_model(lm_config):
@@ -208,7 +220,8 @@ if __name__ == "__main__":
if ddp:
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
torch.autograd.set_detect_anomaly(True)
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
train_epoch(epoch, wandb)
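torch.autograd.set_detect_anomaly(True), enabled just before the training loop in this commit, makes every backward operation check its outputs for NaN and raise a RuntimeError pointing at the forward operation that produced it; it slows training substantially, so it is normally only left on while hunting a NaN like the one the new except block reports. A minimal illustration (example is not from this repo):

import torch

torch.autograd.set_detect_anomaly(True)
x = torch.tensor(-1.0, requires_grad=True)
y = torch.sqrt(x)              # sqrt of a negative number is NaN in the forward pass
try:
    y.backward()               # SqrtBackward0 returns NaN, so anomaly mode raises here
except RuntimeError as e:
    print(e)                   # anomaly mode also prints the traceback of the forward sqrt call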