From a5a39d8c9b894755c4d9a39e7c4c1660257d244c Mon Sep 17 00:00:00 2001
From: Gary <1601978618@qq.com>
Date: Wed, 7 May 2025 22:03:41 +0800
Subject: [PATCH] 'update'

---
 model/model.py    |  27 +++++++---
 train_pretrain.py | 123 +++++++++++++++++++++++++---------------------
 2 files changed, 87 insertions(+), 63 deletions(-)

diff --git a/model/model.py b/model/model.py
index 182f58b..3dc48b2 100644
--- a/model/model.py
+++ b/model/model.py
@@ -173,31 +173,42 @@ class CrossAttention(nn.Module):
     ):
         super().__init__()
         self.config = config
+        self.num_heads = 8
+        self.head_dim = 768 // self.num_heads
         self.to_q = nn.Linear(768, 768, bias=False)
         self.to_k = nn.Linear(768, 768, bias=False)
         self.to_v = nn.Linear(768, 768, bias=False)
-
+
+        self.to_out = nn.Linear(768, 768, bias=False)

     def forward(self, x, db, context_mask=None, pos_emb=None):
-        # db = db.permute(0, 2, 1)
-
-        q = self.to_q(x)
-        k = self.to_k(db)
-        v = self.to_v(db)
+        batch_size = x.size(0)
+
+        # Split into multiple heads
+        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

         if pos_emb is not None:
+            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
             q = q + pos_emb
             k = k + pos_emb
             v = v + pos_emb

-        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
+        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
         if context_mask is not None:
-            attn_scores = attn_scores.masked_fill(context_mask == 0, -1e10)
+            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
+            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)

         attn_weights = F.softmax(attn_scores, dim=-1)
+
         context = torch.matmul(attn_weights, v)

+        context = context.transpose(1, 2).contiguous().view(batch_size, -1, 768)
+
+        context = self.to_out(context)
+
         return context


 class FeedForward(nn.Module):
diff --git a/train_pretrain.py b/train_pretrain.py
index b286b1e..d619933 100644
--- a/train_pretrain.py
+++ b/train_pretrain.py
@@ -37,73 +37,85 @@ def get_lr(current_step, total_steps, lr):


 def train_epoch(epoch, wandb):
-    loss_fct = nn.CrossEntropyLoss(reduction='none')  # Cross-entropy loss; with reduction='none', nn.CrossEntropyLoss does not aggregate the losses but returns an individual loss value for each sample.
+    loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
     for step, (X, Y, loss_mask) in enumerate(train_loader):
-        # Move the data to the device
-        X = X.to(args.device)
-        Y = Y.to(args.device)
-        loss_mask = loss_mask.to(args.device)
+        try:
+            # Move the data to the device
+            X = X.to(args.device)
+            Y = Y.to(args.device)
+            loss_mask = loss_mask.to(args.device)

-        # Update the learning rate
-        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = lr
+            # Update the learning rate
+            lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
+            for param_group in optimizer.param_groups:
+                param_group['lr'] = lr

-        with ctx:
-            res = model(X)  # get the model output
-            loss = loss_fct(
-                res.logits.view(-1, res.logits.size(-1)),
-                Y.view(-1)
-            ).view(Y.size())  # compute the loss
-            loss = (loss * loss_mask).sum() / loss_mask.sum()  # compute the total loss
-            # Handling for gradient accumulation: the real batch size is num_gpus * batch_size_per_gpu * accumulation_steps
-            loss += res.aux_loss
-            loss = loss / args.accumulation_steps
+            with ctx:
+                res = model(X)
+                loss = loss_fct(
+                    res.logits.view(-1, res.logits.size(-1)),
+                    Y.view(-1)
+                ).view(Y.size())
+                loss = (loss * loss_mask).sum() / loss_mask.sum()
+                loss += res.aux_loss
+                loss = loss / args.accumulation_steps

-        scaler.scale(loss).backward()  # Used for mixed-precision training: automatically scales the loss to avoid numerical instability when computing in low precision (e.g. FP16).
+            scaler.scale(loss).backward()

-        # Step once the accumulation count is reached
-        if (step + 1) % args.accumulation_steps == 0:
-            scaler.unscale_(optimizer)  # Part of PyTorch automatic mixed precision (AMP): "un-scales" the gradients that were previously scaled to prevent underflow.
-            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)  # Apply gradient clipping to prevent exploding gradients; rescales them so the norm does not exceed args.grad_clip.
+            if (step + 1) % args.accumulation_steps == 0:
+                scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

-            scaler.step(optimizer)  # Update the model weights with the optimizer, under control of the scaler for mixed precision.
-            scaler.update()  # Update the scale factor for the next iteration based on whether gradients overflowed in this one.
+                scaler.step(optimizer)
+                scaler.update()

-            optimizer.zero_grad(set_to_none=True)  # Clear all gradients for the next iteration; set_to_none=True is more memory-efficient because it sets gradients to None instead of zero.
+                optimizer.zero_grad(set_to_none=True)

-        # Logging
-        if step % args.log_interval == 0:
-            spend_time = time.time() - start_time
-            Logger(
-                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
-                    epoch + 1,
-                    args.epochs,
-                    step,
-                    iter_per_epoch,
-                    loss.item() * args.accumulation_steps,
-                    optimizer.param_groups[-1]['lr'],
-                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
+            # Logging
+            if step % args.log_interval == 0:
+                spend_time = time.time() - start_time
+                Logger(
+                    'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
+                        epoch + 1,
+                        args.epochs,
+                        step,
+                        iter_per_epoch,
+                        loss.item() * args.accumulation_steps,
+                        optimizer.param_groups[-1]['lr'],
+                        spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

-            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
-                wandb.log({"loss": loss.item() * args.accumulation_steps,
-                           "lr": optimizer.param_groups[-1]['lr'],
-                           "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
+                if (wandb is not None) and (not ddp or dist.get_rank() == 0):
+                    wandb.log({"loss": loss.item() * args.accumulation_steps,
+                               "lr": optimizer.param_groups[-1]['lr'],
+                               "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

-        # Save the model
-        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
-            model.eval()
-            moe_path = '_moe' if lm_config.use_moe else ''
-            ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
+            # Save the model
+            if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
+                model.eval()
+                moe_path = '_moe' if lm_config.use_moe else ''
+                ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'

-            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                state_dict = model.module.state_dict()  # get the model parameters
-            else:
-                state_dict = model.state_dict()  # get the model parameters
+                if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+                    state_dict = model.module.state_dict()  # get the model parameters
+                else:
+                    state_dict = model.state_dict()  # get the model parameters

-            torch.save(state_dict, ckp)  # save the parameters only
-            model.train()
+                torch.save(state_dict, ckp)  # save the parameters only
+                model.train()
+
+        except Exception as e:
+            print(f"Error occurred: {str(e)}")
+            for name, param in model.named_parameters():
+                if param.grad is not None and torch.isnan(param.grad).any():
+                    print(f"NaN gradient in parameter: {name}")
+
+            for name, param in model.named_parameters():
+                if param.grad is not None and torch.isnan(param.grad).any():
+                    print(f"Parameter {name} values: {param.data}")
+                    print(f"Parameter {name} gradients: {param.grad}")
+
+            raise ValueError("NaN gradient detected")

 def init_model(lm_config):
@@ -208,7 +220,8 @@ if __name__ == "__main__":
     if ddp:
         model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
-
+
+    torch.autograd.set_detect_anomaly(True)
     iter_per_epoch = len(train_loader)
     for epoch in range(args.epochs):
        train_epoch(epoch, wandb)
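
Note: below is a minimal, self-contained sketch of the multi-head reshaping round trip that the reworked CrossAttention relies on (hidden size 768, 8 heads, head_dim 96). The toy batch and sequence sizes are illustrative assumptions, not values taken from the training configuration.

# Standalone sanity check of the (B, L, 768) -> (B, 8, L, 96) -> (B, L, 768) round trip
# used by the patched CrossAttention. Sizes here are made up for illustration.
import torch

batch_size, q_len, kv_len = 2, 16, 32   # hypothetical toy sizes
num_heads, hidden = 8, 768
head_dim = hidden // num_heads          # 96

x = torch.randn(batch_size, q_len, hidden)    # queries come from x
db = torch.randn(batch_size, kv_len, hidden)  # keys/values come from db

# Split into heads: (B, L, 768) -> (B, L, 8, 96) -> (B, 8, L, 96)
q = x.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
k = db.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
v = db.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

# Scaled dot-product attention per head, scaled by sqrt(head_dim) as in the patch
attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)
context = attn @ v                       # (B, 8, q_len, 96)

# Merge heads back: (B, 8, q_len, 96) -> (B, q_len, 8, 96) -> (B, q_len, 768)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, hidden)
assert context.shape == (batch_size, q_len, hidden)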
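
Note: the except branch added to train_epoch scans named_parameters for NaN gradients and then re-raises, while the __main__ block now enables anomaly detection. The sketch below shows that debugging pattern in isolation; the tiny linear model and random input are assumptions for illustration only.

# Minimal sketch of the NaN-gradient check used in the patched train_epoch,
# combined with the anomaly-detection switch enabled in the __main__ block.
import torch
import torch.nn as nn

torch.autograd.set_detect_anomaly(True)  # makes autograd report the op that produced NaN/Inf

model = nn.Linear(4, 2)                  # hypothetical stand-in for the real model
x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
loss.backward()

# Scan gradients parameter by parameter, as the except branch does.
for name, param in model.named_parameters():
    if param.grad is not None and torch.isnan(param.grad).any():
        print(f"NaN gradient in parameter: {name}")
        print(f"Parameter {name} values: {param.data}")
        print(f"Parameter {name} gradients: {param.grad}")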