修复了loss为nan的错误

This commit is contained in:
iomgaa 2025-05-11 15:45:17 +00:00
parent 0c8c6e5d1a
commit eba4311ac5
4 changed files with 1372 additions and 14 deletions

2
.gitignore vendored
View File

@ -2,4 +2,4 @@
/dataset
/out
wandb/
**/*.log
nohup.out

View File

@ -2,6 +2,8 @@ import math
import struct
import inspect
import time
import sys
sys.path.append('/mnt/lzn/Minimind')
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
@ -94,22 +96,25 @@ class Attention(nn.Module):
x: torch.Tensor,
pos_cis: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache=False,
use_cache=True,
db_value=None):
bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换得到查询、键和值。
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim) #将变换后的张量xq重塑为形状为(bsz, seq_len, n_local_heads, head_dim)的形状。
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xk重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xv重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
# 应用旋转位置编码
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
# kv_cache实现
if past_key_value is not None:
xk = torch.cat([past_key_value[0], xk], dim=1)
xv = torch.cat([past_key_value[1], xv], dim=1)
past_kv = (xk, xv) if use_cache else None
# print(xk, xv)
# 重复键值对
xq, xk, xv = (
xq.transpose(1, 2),
@ -373,7 +378,7 @@ class MiniMindBlock(nn.Module):
# self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys
# self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数
def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=False):
def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=True):
# import pdb;pdb.set_trace()
# db_value = None
@ -425,7 +430,7 @@ class MiniMindBlock(nn.Module):
use_cache=use_cache,
db_value=db_value
)
# print(past_kv)
h_attn = self.cross_att(h_attn, db_value)
# 残差连接
@ -485,7 +490,7 @@ class ExtractDB(nn.Module):
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
# 5. 最终top-k选择
scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)# no gradient
indices = all_indices.gather(-1, pk_indices)
flat_indices = indices.view(-1)
return flat_indices
@ -557,7 +562,7 @@ class MiniMindLM(PreTrainedModel):
def forward(self,
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
use_cache: bool = True,
logits_to_keep: Union[int, torch.Tensor] = 0,
**args):
past_key_values = past_key_values or [None] * len(self.layers)
@ -578,16 +583,17 @@ class MiniMindLM(PreTrainedModel):
# 正常模式,使用数据库查询
index = self.extract_db.q_to_k(h)
db_value = self.extract_db.get_data(index)
h, past_kv = layer(
h, db_value, pos_cis,
past_key_value=past_key_values[l],
use_cache=use_cache
)
past_kvs.append(past_kv)
h_list.append(h.unsqueeze(0))
# print(past_kvs)
h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)
# 只在非禁用数据库模式下执行数据库更新逻辑
@ -595,7 +601,9 @@ class MiniMindLM(PreTrainedModel):
# 使用detach()分离计算图,避免多次反向传播
h_tensor_detached = h_tensor.detach()
h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0], -1, self.params.dim)
# 数据库更新逻辑与主计算图分离
with torch.no_grad():
# Compute shared downsampling layer once
@ -604,7 +612,7 @@ class MiniMindLM(PreTrainedModel):
z_q = self.downsample_q_specific(shared_features)
z_k = self.extract_db.q_to_k(z_q)
self.extract_db.updata_value(z_k, z_v)
# import pdb;pdb.set_trace()
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

1346
nohup.out Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,9 @@
import os
# 设置环境变量
os.environ["WANDB_MODE"] = "offline" # 或者使用 "dryrun"
os.environ["WANDB_MODE"] = "offline"
import sys
sys.path.append('/mnt/lzn/Minimind')
import platform
import argparse
import time
@ -119,6 +122,7 @@ def train_epoch(epoch, wandb):
except Exception as e:
print(f"Error occurred: {str(e)}")
moe_path = '_moe' if lm_config.use_moe else ''
save_path = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}_nanERROR.pth'
if os.path.exists(save_path):
os.remove(save_path)