Commit a5a39d8c9b by Gary, 2025-05-07 22:03:41 +08:00
Parent 0859f54a88
2 changed files with 87 additions and 63 deletions

File 1 of 2

@@ -173,31 +173,42 @@ class CrossAttention(nn.Module):
     ):
         super().__init__()
         self.config = config
+        self.num_heads = 8
+        self.head_dim = 768 // self.num_heads
         self.to_q = nn.Linear(768, 768, bias=False)
         self.to_k = nn.Linear(768, 768, bias=False)
         self.to_v = nn.Linear(768, 768, bias=False)
+        self.to_out = nn.Linear(768, 768, bias=False)

     def forward(self, x, db, context_mask=None, pos_emb=None):
-        # db = db.permute(0, 2, 1)
-        q = self.to_q(x)
-        k = self.to_k(db)
-        v = self.to_v(db)
+        batch_size = x.size(0)
+        # Split into multiple heads
+        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

         if pos_emb is not None:
+            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
             q = q + pos_emb
             k = k + pos_emb
             v = v + pos_emb

-        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
+        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
         if context_mask is not None:
-            attn_scores = attn_scores.masked_fill(context_mask == 0, -1e10)
+            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
+            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)

         attn_weights = F.softmax(attn_scores, dim=-1)
         context = torch.matmul(attn_weights, v)
+        context = context.transpose(1, 2).contiguous().view(batch_size, -1, 768)
+        context = self.to_out(context)

         return context

 class FeedForward(nn.Module):
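For reference, a minimal standalone sketch of the head-splitting arithmetic this hunk introduces. The 768-dim width, 8 heads, and the q/k/v/out projections follow the diff; the concrete shapes of x, db, and context_mask are assumptions for illustration only.

import math
import torch
import torch.nn.functional as F

# Assumed shapes: x is (batch, N, 768) queries, db is (batch, M, 768) keys/values,
# context_mask is (batch, N, M) with 1 = attend, 0 = masked out.
batch, N, M, dim, num_heads = 2, 16, 32, 768, 8
head_dim = dim // num_heads  # 96 per head

x, db = torch.randn(batch, N, dim), torch.randn(batch, M, dim)
context_mask = torch.ones(batch, N, M)

# Stand-ins for the to_q / to_k / to_v / to_out projections in the diff.
to_q, to_k, to_v, to_out = (torch.nn.Linear(dim, dim, bias=False) for _ in range(4))

# Split each projection into heads: (batch, seq, dim) -> (batch, heads, seq, head_dim)
q = to_q(x).view(batch, -1, num_heads, head_dim).transpose(1, 2)
k = to_k(db).view(batch, -1, num_heads, head_dim).transpose(1, 2)
v = to_v(db).view(batch, -1, num_heads, head_dim).transpose(1, 2)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)  # (batch, heads, N, M)
expanded = context_mask.unsqueeze(1).expand(-1, num_heads, -1, -1)   # broadcast mask over heads
scores = scores.masked_fill(expanded == 0, -1e10)

out = torch.matmul(F.softmax(scores, dim=-1), v)                     # (batch, heads, N, head_dim)
out = to_out(out.transpose(1, 2).contiguous().view(batch, -1, dim))  # merge heads -> (batch, N, 768)
print(out.shape)  # torch.Size([2, 16, 768])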

File 2 of 2

@@ -37,9 +37,10 @@ def train_epoch(epoch, wandb):
 def train_epoch(epoch, wandb):
-    loss_fct = nn.CrossEntropyLoss(reduction='none')  # cross-entropy loss; with reduction='none' the per-element losses are returned instead of being aggregated
+    loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
     for step, (X, Y, loss_mask) in enumerate(train_loader):
+        try:
             # Move the batch onto the device
             X = X.to(args.device)
             Y = Y.to(args.device)
@@ -51,27 +52,25 @@ def train_epoch(epoch, wandb):
                 param_group['lr'] = lr

             with ctx:
-                res = model(X)  # forward pass to get the model output
+                res = model(X)
                 loss = loss_fct(
                     res.logits.view(-1, res.logits.size(-1)),
                     Y.view(-1)
-                ).view(Y.size())  # per-token loss
-                loss = (loss * loss_mask).sum() / loss_mask.sum()  # masked mean loss
-                # Gradient-accumulation bookkeeping: the effective batch size is num_gpus * batch_size_per_gpu * accumulation_steps
+                ).view(Y.size())
+                loss = (loss * loss_mask).sum() / loss_mask.sum()
                 loss += res.aux_loss
                 loss = loss / args.accumulation_steps

-            scaler.scale(loss).backward()  # mixed-precision training: the loss is scaled automatically to avoid numerical instability in low-precision (e.g. FP16) computation
+            scaler.scale(loss).backward()

-            # Once the accumulation count is reached, apply the update
             if (step + 1) % args.accumulation_steps == 0:
-                scaler.unscale_(optimizer)  # part of PyTorch automatic mixed precision (AMP): un-scales the gradients that were scaled to prevent underflow
-                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)  # gradient clipping to prevent explosion; rescales gradients so their norm does not exceed args.grad_clip
-                scaler.step(optimizer)  # update the model weights via the optimizer, mediated by the scaler for mixed precision
-                scaler.update()  # adjust the scale factor for the next iteration depending on whether this one overflowed
-                optimizer.zero_grad(set_to_none=True)  # clear gradients for the next iteration; set_to_none=True saves memory by setting them to None instead of zero
+                scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
+                scaler.step(optimizer)
+                scaler.update()
+                optimizer.zero_grad(set_to_none=True)

             # Logging
             if step % args.log_interval == 0:
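The inline comments removed in this hunk describe the standard AMP-plus-gradient-accumulation pattern. Below is a minimal, self-contained sketch of that pattern; the model, optimizer, and hyperparameters are placeholders (a CUDA device is assumed), not the repository's own objects.

import torch

# Toy stand-ins; the real script uses its own model, train_loader and args.
model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()
accumulation_steps, grad_clip = 4, 1.0

for step in range(100):
    x = torch.randn(8, 10, device="cuda")
    with torch.cuda.amp.autocast():
        loss = model(x).pow(2).mean() / accumulation_steps  # divide by accumulation steps

    # backward() on the scaled loss keeps FP16 gradients from underflowing
    scaler.scale(loss).backward()

    # Effective batch size = per-step batch * accumulation_steps (* number of GPUs under DDP)
    if (step + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)                                     # restore true gradient magnitudes
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)  # clip the unscaled gradients
        scaler.step(optimizer)                                         # skips the step if overflow was detected
        scaler.update()                                                # adjust the scale factor for the next step
        optimizer.zero_grad(set_to_none=True)                          # free gradient memory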
@@ -105,6 +104,19 @@ def train_epoch(epoch, wandb):
                 torch.save(state_dict, ckp)  # save the parameters only
                 model.train()
+        except Exception as e:
+            print(f"Error occurred: {str(e)}")
+            for name, param in model.named_parameters():
+                if param.grad is not None and torch.isnan(param.grad).any():
+                    print(f"NaN gradient in parameter: {name}")
+            for name, param in model.named_parameters():
+                if param.grad is not None and torch.isnan(param.grad).any():
+                    print(f"Parameter {name} values: {param.data}")
+                    print(f"Parameter {name} gradients: {param.grad}")
+            raise ValueError("NaN gradient detected")

 def init_model(lm_config):
     # Load the tokenizer
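The new except block scans the parameters twice with the same condition. A possible one-pass refactor is sketched below as a hypothetical report_nan_grads helper; it is not part of the commit, just an illustration of the same diagnostic.

import torch

def report_nan_grads(model: torch.nn.Module) -> bool:
    """Print every parameter whose gradient contains NaN; return True if any were found."""
    found = False
    for name, param in model.named_parameters():
        if param.grad is not None and torch.isnan(param.grad).any():
            found = True
            print(f"NaN gradient in parameter: {name}")
            print(f"Parameter {name} values: {param.data}")
            print(f"Parameter {name} gradients: {param.grad}")
    return found

# Usage inside the training loop's except handler:
#     if report_nan_grads(model):
#         raise ValueError("NaN gradient detected")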
@@ -209,6 +221,7 @@ if __name__ == "__main__":
         model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])

+    torch.autograd.set_detect_anomaly(True)
     iter_per_epoch = len(train_loader)
     for epoch in range(args.epochs):
         train_epoch(epoch, wandb)
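torch.autograd.set_detect_anomaly(True), as added above, makes every backward pass record extra metadata and raise as soon as a NaN/Inf is produced, which slows training noticeably. One alternative (an assumption for debugging workflows, not what this commit does) is to scope the check with the context-manager form:

import torch

# Enabled globally, as in the commit: every backward() is checked.
torch.autograd.set_detect_anomaly(True)

# Alternative: enable it only around a suspicious region to limit the overhead.
x = torch.randn(4, 4, requires_grad=True)
with torch.autograd.detect_anomaly():
    loss = (x * 2).sum()
    loss.backward()  # would raise with a traceback to the producing op if NaN/Inf appeared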