修正了loss为nan的错误
This commit is contained in:
parent
cb286d26d1
commit
8dd7cfaf72
@ -94,7 +94,7 @@ class Attention(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
pos_cis: torch.Tensor,
|
pos_cis: torch.Tensor,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
use_cache=False,
|
use_cache=True,
|
||||||
db_value=None):
|
db_value=None):
|
||||||
bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度
|
bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度
|
||||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换,得到查询、键和值。
|
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换,得到查询、键和值。
|
||||||
@ -373,7 +373,7 @@ class MiniMindBlock(nn.Module):
|
|||||||
# self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys
|
# self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys
|
||||||
# self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数
|
# self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数
|
||||||
|
|
||||||
def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=False):
|
def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=True):
|
||||||
# import pdb;pdb.set_trace()
|
# import pdb;pdb.set_trace()
|
||||||
# db_value = None
|
# db_value = None
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user