修正了loss为nan的错误

This commit is contained in:
iomgaa 2025-05-11 23:57:34 +08:00
parent cb286d26d1
commit 8dd7cfaf72

View File

@ -94,7 +94,7 @@ class Attention(nn.Module):
x: torch.Tensor, x: torch.Tensor,
pos_cis: torch.Tensor, pos_cis: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache=False, use_cache=True,
db_value=None): db_value=None):
bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度 bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换得到查询、键和值。 xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换得到查询、键和值。
@ -373,7 +373,7 @@ class MiniMindBlock(nn.Module):
# self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys # self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys
# self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数 # self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数
def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=False): def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=True):
# import pdb;pdb.set_trace() # import pdb;pdb.set_trace()
# db_value = None # db_value = None