'update'
parent 0859f54a88
commit a5a39d8c9b
@@ -173,31 +173,42 @@ class CrossAttention(nn.Module):
     ):
         super().__init__()
         self.config = config
         self.num_heads = 8
         self.head_dim = 768 // self.num_heads
         self.to_q = nn.Linear(768, 768, bias=False)
         self.to_k = nn.Linear(768, 768, bias=False)
         self.to_v = nn.Linear(768, 768, bias=False)
 
 
         self.to_out = nn.Linear(768, 768, bias=False)
 
     def forward(self, x, db, context_mask=None, pos_emb=None):
         # db = db.permute(0, 2, 1)
 
-        q = self.to_q(x)
-        k = self.to_k(db)
-        v = self.to_v(db)
+        batch_size = x.size(0)
+
+        # split queries, keys and values into multiple heads
+        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        if pos_emb is not None:
+            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+            q = q + pos_emb
+            k = k + pos_emb
+            v = v + pos_emb
 
-        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
+        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
 
         if context_mask is not None:
-            attn_scores = attn_scores.masked_fill(context_mask == 0, -1e10)
+            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
+            attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
 
         attn_weights = F.softmax(attn_scores, dim=-1)
 
         context = torch.matmul(attn_weights, v)
 
         context = context.transpose(1, 2).contiguous().view(batch_size, -1, 768)
 
         context = self.to_out(context)
 
         return context
 
 
 class FeedForward(nn.Module):
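The hunk above replaces the single-head projections with an explicit multi-head split and broadcasts context_mask across the head dimension. As a sanity check, here is a minimal, self-contained sketch of the same reshape and masking; the batch size, sequence lengths and tensor names are illustrative assumptions, not taken from the repository, while the 8 heads of width 96 match the hard-coded values in CrossAttention.

import math
import torch
import torch.nn.functional as F

batch_size, q_len, kv_len, num_heads, head_dim = 2, 16, 24, 8, 96

x = torch.randn(batch_size, q_len, num_heads * head_dim)    # stands in for to_q(x)
db = torch.randn(batch_size, kv_len, num_heads * head_dim)  # stands in for to_k(db) / to_v(db)

# (B, L, 768) -> (B, num_heads, L, head_dim), the same reshape as in the new forward()
q = x.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
k = db.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
v = db.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

# a (B, q_len, kv_len) mask broadcast over heads, as expanded_mask does above
context_mask = torch.ones(batch_size, q_len, kv_len)
expanded_mask = context_mask.unsqueeze(1).expand(-1, num_heads, -1, -1)

attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
attn_weights = F.softmax(attn_scores, dim=-1)

context = torch.matmul(attn_weights, v)                      # (B, num_heads, q_len, head_dim)
context = context.transpose(1, 2).contiguous().view(batch_size, -1, num_heads * head_dim)
print(context.shape)                                         # torch.Size([2, 16, 768])

Note that after the reshape, k.size(-1) equals self.head_dim, so the two scaling variants shown in the hunk compute the same value; before the reshape, k.size(-1) was the full 768-dimensional width.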
@@ -37,73 +37,85 @@ def get_lr(current_step, total_steps, lr):
 
 
 def train_epoch(epoch, wandb):
-    loss_fct = nn.CrossEntropyLoss(reduction='none')  # cross-entropy loss; with reduction='none', nn.CrossEntropyLoss does not aggregate the losses and instead returns a separate loss value per element
+    loss_fct = nn.CrossEntropyLoss(reduction='none')
     start_time = time.time()
     for step, (X, Y, loss_mask) in enumerate(train_loader):
-        # move the batch to the device
-        X = X.to(args.device)
-        Y = Y.to(args.device)
-        loss_mask = loss_mask.to(args.device)
+        try:
+            # move the batch to the device
+            X = X.to(args.device)
+            Y = Y.to(args.device)
+            loss_mask = loss_mask.to(args.device)
 
-        # update the learning rate
-        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = lr
+            # update the learning rate
+            lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
+            for param_group in optimizer.param_groups:
+                param_group['lr'] = lr
 
-        with ctx:
-            res = model(X)  # forward pass
-            loss = loss_fct(
-                res.logits.view(-1, res.logits.size(-1)),
-                Y.view(-1)
-            ).view(Y.size())  # per-token loss
-            loss = (loss * loss_mask).sum() / loss_mask.sum()  # masked mean loss
-            # handled this way because of batch stacking; the real batch size is num_gpus * batch_size_per_gpu * accumulation_steps
-            loss += res.aux_loss
-            loss = loss / args.accumulation_steps
+            with ctx:
+                res = model(X)
+                loss = loss_fct(
+                    res.logits.view(-1, res.logits.size(-1)),
+                    Y.view(-1)
+                ).view(Y.size())
+                loss = (loss * loss_mask).sum() / loss_mask.sum()
+                loss += res.aux_loss
+                loss = loss / args.accumulation_steps
 
-        scaler.scale(loss).backward()  # mixed-precision training: the loss is scaled automatically to avoid numerical instability in low-precision (e.g. FP16) computation
+            scaler.scale(loss).backward()
 
-        # once the accumulation count is reached, take an optimizer step
-        if (step + 1) % args.accumulation_steps == 0:
-            scaler.unscale_(optimizer)  # part of PyTorch automatic mixed precision (AMP): "un-scales" the gradients that were scaled earlier to prevent underflow
-            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)  # gradient clipping against exploding gradients; rescales gradients so their norm does not exceed args.grad_clip
+            if (step + 1) % args.accumulation_steps == 0:
+                scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
 
-            scaler.step(optimizer)  # update the weights through the optimizer, under the scaler's control for mixed precision
-            scaler.update()  # update the scale factor for the next iteration depending on whether this one overflowed
+                scaler.step(optimizer)
+                scaler.update()
 
-            optimizer.zero_grad(set_to_none=True)  # clear all gradients for the next iteration; set_to_none=True saves memory by setting gradients to None instead of zero
+                optimizer.zero_grad(set_to_none=True)
 
-        # logging
-        if step % args.log_interval == 0:
-            spend_time = time.time() - start_time
-            Logger(
-                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
-                    epoch + 1,
-                    args.epochs,
-                    step,
-                    iter_per_epoch,
-                    loss.item() * args.accumulation_steps,
-                    optimizer.param_groups[-1]['lr'],
-                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
+            # logging
+            if step % args.log_interval == 0:
+                spend_time = time.time() - start_time
+                Logger(
+                    'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
+                        epoch + 1,
+                        args.epochs,
+                        step,
+                        iter_per_epoch,
+                        loss.item() * args.accumulation_steps,
+                        optimizer.param_groups[-1]['lr'],
+                        spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
 
-            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
-                wandb.log({"loss": loss.item() * args.accumulation_steps,
-                           "lr": optimizer.param_groups[-1]['lr'],
-                           "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
+                if (wandb is not None) and (not ddp or dist.get_rank() == 0):
+                    wandb.log({"loss": loss.item() * args.accumulation_steps,
+                               "lr": optimizer.param_groups[-1]['lr'],
+                               "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
 
-        # save a checkpoint
-        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
-            model.eval()
-            moe_path = '_moe' if lm_config.use_moe else ''
-            ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
+            # save a checkpoint
+            if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
+                model.eval()
+                moe_path = '_moe' if lm_config.use_moe else ''
+                ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
 
-            if isinstance(model, torch.nn.parallel.DistributedDataParallel):
-                state_dict = model.module.state_dict()  # get the model parameters
-            else:
-                state_dict = model.state_dict()  # get the model parameters
+                if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+                    state_dict = model.module.state_dict()  # get the model parameters
+                else:
+                    state_dict = model.state_dict()  # get the model parameters
 
-            torch.save(state_dict, ckp)  # save only the parameters
-            model.train()
+                torch.save(state_dict, ckp)  # save only the parameters
+                model.train()
+
+        except Exception as e:
+            print(f"Error occurred: {str(e)}")
+            for name, param in model.named_parameters():
+                if param.grad is not None and torch.isnan(param.grad).any():
+                    print(f"NaN gradient in parameter: {name}")
+
+            for name, param in model.named_parameters():
+                if param.grad is not None and torch.isnan(param.grad).any():
+                    print(f"Parameter {name} values: {param.data}")
+                    print(f"Parameter {name} gradients: {param.grad}")
+
+            raise ValueError("NaN gradient detected")
 
 
 def init_model(lm_config):
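The training hunk above follows the standard PyTorch AMP plus gradient-accumulation pattern: scale the loss, backward, unscale, clip, step, update the scaler, then reset the gradients. Below is a minimal sketch of that same pattern with a toy model and random data standing in for the repository's model, train_loader, ctx and args (all sizes and hyperparameters are assumptions); it assumes a CUDA device, since GradScaler is effectively disabled on CPU.

import torch
import torch.nn as nn

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()
loss_fct = nn.CrossEntropyLoss(reduction='none')        # per-element losses, reduced manually below
accumulation_steps, grad_clip = 4, 1.0

for step in range(8):
    X = torch.randn(32, 16, device='cuda')
    Y = torch.randint(0, 4, (32,), device='cuda')
    loss_mask = (torch.rand(32, device='cuda') > 0.1).float()   # ignore ~10% of positions

    with torch.cuda.amp.autocast():
        logits = model(X)
        loss = loss_fct(logits, Y)                          # shape (32,)
        loss = (loss * loss_mask).sum() / loss_mask.sum()   # masked mean, as in train_epoch
        loss = loss / accumulation_steps                    # spread the update over accumulated steps

    scaler.scale(loss).backward()                           # backward on the scaled loss

    if (step + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)                          # restore true gradients before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)                              # skipped internally if gradients overflowed
        scaler.update()                                     # adjust the loss scale for the next step
        optimizer.zero_grad(set_to_none=True)

This is also why the log lines report loss.item() * args.accumulation_steps: the division by accumulation_steps is undone so the logged value reflects the per-batch loss.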
@@ -208,7 +220,8 @@ if __name__ == "__main__":
     if ddp:
         model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
 
+    torch.autograd.set_detect_anomaly(True)
     iter_per_epoch = len(train_loader)
     for epoch in range(args.epochs):
        train_epoch(epoch, wandb)