'update'
This commit is contained in:
parent
0859f54a88
commit
a5a39d8c9b
@ -173,31 +173,42 @@ class CrossAttention(nn.Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.num_heads = 8
|
||||||
|
self.head_dim = 768 // self.num_heads
|
||||||
self.to_q = nn.Linear(768, 768, bias=False)
|
self.to_q = nn.Linear(768, 768, bias=False)
|
||||||
self.to_k = nn.Linear(768, 768, bias=False)
|
self.to_k = nn.Linear(768, 768, bias=False)
|
||||||
self.to_v = nn.Linear(768, 768, bias=False)
|
self.to_v = nn.Linear(768, 768, bias=False)
|
||||||
|
|
||||||
|
self.to_out = nn.Linear(768, 768, bias=False)
|
||||||
|
|
||||||
def forward(self, x, db, context_mask=None, pos_emb=None):
|
def forward(self, x, db, context_mask=None, pos_emb=None):
|
||||||
# db = db.permute(0, 2, 1)
|
batch_size = x.size(0)
|
||||||
|
|
||||||
q = self.to_q(x)
|
# 分离多头
|
||||||
k = self.to_k(db)
|
q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
v = self.to_v(db)
|
k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
if pos_emb is not None:
|
if pos_emb is not None:
|
||||||
|
pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
q = q + pos_emb
|
q = q + pos_emb
|
||||||
k = k + pos_emb
|
k = k + pos_emb
|
||||||
v = v + pos_emb
|
v = v + pos_emb
|
||||||
|
|
||||||
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
|
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
if context_mask is not None:
|
if context_mask is not None:
|
||||||
attn_scores = attn_scores.masked_fill(context_mask == 0, -1e10)
|
expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
|
||||||
|
attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10)
|
||||||
|
|
||||||
attn_weights = F.softmax(attn_scores, dim=-1)
|
attn_weights = F.softmax(attn_scores, dim=-1)
|
||||||
|
|
||||||
context = torch.matmul(attn_weights, v)
|
context = torch.matmul(attn_weights, v)
|
||||||
|
|
||||||
|
context = context.transpose(1, 2).contiguous().view(batch_size, -1, 768)
|
||||||
|
|
||||||
|
context = self.to_out(context)
|
||||||
|
|
||||||
return context
|
return context
|
||||||
|
|
||||||
class FeedForward(nn.Module):
|
class FeedForward(nn.Module):
|
||||||
|
@ -37,9 +37,10 @@ def get_lr(current_step, total_steps, lr):
|
|||||||
|
|
||||||
|
|
||||||
def train_epoch(epoch, wandb):
|
def train_epoch(epoch, wandb):
|
||||||
loss_fct = nn.CrossEntropyLoss(reduction='none') #交叉熵损失(Cross-Entropy Loss);当 reduction='none' 时,nn.CrossEntropyLoss 不会对损失进行任何汇总操作,而是返回每个样本的单独损失值。
|
loss_fct = nn.CrossEntropyLoss(reduction='none')
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
for step, (X, Y, loss_mask) in enumerate(train_loader):
|
for step, (X, Y, loss_mask) in enumerate(train_loader):
|
||||||
|
try:
|
||||||
# 将数据加载到设备上
|
# 将数据加载到设备上
|
||||||
X = X.to(args.device)
|
X = X.to(args.device)
|
||||||
Y = Y.to(args.device)
|
Y = Y.to(args.device)
|
||||||
@ -51,27 +52,25 @@ def train_epoch(epoch, wandb):
|
|||||||
param_group['lr'] = lr
|
param_group['lr'] = lr
|
||||||
|
|
||||||
with ctx:
|
with ctx:
|
||||||
res = model(X) #获取输出
|
res = model(X)
|
||||||
loss = loss_fct(
|
loss = loss_fct(
|
||||||
res.logits.view(-1, res.logits.size(-1)),
|
res.logits.view(-1, res.logits.size(-1)),
|
||||||
Y.view(-1)
|
Y.view(-1)
|
||||||
).view(Y.size())#计算损失
|
).view(Y.size())
|
||||||
loss = (loss * loss_mask).sum() / loss_mask.sum() #计算总的loss
|
loss = (loss * loss_mask).sum() / loss_mask.sum()
|
||||||
# 为了批次堆叠进行的处理,真正的batch size为num gpu*batch size per gpu*accumulation steps
|
|
||||||
loss += res.aux_loss
|
loss += res.aux_loss
|
||||||
loss = loss / args.accumulation_steps
|
loss = loss / args.accumulation_steps
|
||||||
|
|
||||||
scaler.scale(loss).backward() #用于处理混合精度训练。它的作用是自动缩放损失值,以防止在使用低精度(如 FP16)计算时出现数值不稳定的问题。
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
# 如果达到堆叠数目就进行处理
|
|
||||||
if (step + 1) % args.accumulation_steps == 0:
|
if (step + 1) % args.accumulation_steps == 0:
|
||||||
scaler.unscale_(optimizer) #PyTorch 自动混合精度(AMP)训练的一部分。它"反缩放"之前为防止在混合精度训练中出现下溢而缩放的梯度。
|
scaler.unscale_(optimizer)
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) #应用梯度裁剪以防止梯度爆炸。它会缩放梯度,使其范数不超过args.grad_clip。
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||||
|
|
||||||
scaler.step(optimizer) #使用优化器更新模型权重,但由缩放器控制以适应混合精度训练。
|
scaler.step(optimizer)
|
||||||
scaler.update() #根据本次迭代是否有梯度溢出来更新下一次迭代的缩放因子。
|
scaler.update()
|
||||||
|
|
||||||
optimizer.zero_grad(set_to_none=True) #为下一次迭代清零所有梯度。set_to_none=True参数通过将梯度设置为None而不是零来提高内存效率。
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
# 打印日志
|
# 打印日志
|
||||||
if step % args.log_interval == 0:
|
if step % args.log_interval == 0:
|
||||||
@ -105,6 +104,19 @@ def train_epoch(epoch, wandb):
|
|||||||
torch.save(state_dict, ckp) #只保存参数
|
torch.save(state_dict, ckp) #只保存参数
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error occurred: {str(e)}")
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.grad is not None and torch.isnan(param.grad).any():
|
||||||
|
print(f"NaN gradient in parameter: {name}")
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.grad is not None and torch.isnan(param.grad).any():
|
||||||
|
print(f"Parameter {name} values: {param.data}")
|
||||||
|
print(f"Parameter {name} gradients: {param.grad}")
|
||||||
|
|
||||||
|
raise ValueError("NaN gradient detected")
|
||||||
|
|
||||||
|
|
||||||
def init_model(lm_config):
|
def init_model(lm_config):
|
||||||
# 加载tokenizer
|
# 加载tokenizer
|
||||||
@ -209,6 +221,7 @@ if __name__ == "__main__":
|
|||||||
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
|
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
|
||||||
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
|
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
|
||||||
|
|
||||||
|
torch.autograd.set_detect_anomaly(True)
|
||||||
iter_per_epoch = len(train_loader)
|
iter_per_epoch = len(train_loader)
|
||||||
for epoch in range(args.epochs):
|
for epoch in range(args.epochs):
|
||||||
train_epoch(epoch, wandb)
|
train_epoch(epoch, wandb)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user