修正了loss为nan的错误
This commit is contained in:
parent
cb286d26d1
commit
8dd7cfaf72
@ -94,7 +94,7 @@ class Attention(nn.Module):
|
||||
x: torch.Tensor,
|
||||
pos_cis: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
use_cache=False,
|
||||
use_cache=True,
|
||||
db_value=None):
|
||||
bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换,得到查询、键和值。
|
||||
@ -373,7 +373,7 @@ class MiniMindBlock(nn.Module):
|
||||
# self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys
|
||||
# self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数
|
||||
|
||||
def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=False):
|
||||
def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=True):
|
||||
# import pdb;pdb.set_trace()
|
||||
# db_value = None
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user