diff --git a/model/model.py b/model/model.py index 31aca85..2873a77 100644 --- a/model/model.py +++ b/model/model.py @@ -94,7 +94,7 @@ class Attention(nn.Module): x: torch.Tensor, pos_cis: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache=False, + use_cache=True, db_value=None): bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度 xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换,得到查询、键和值。 @@ -373,7 +373,7 @@ class MiniMindBlock(nn.Module): # self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys # self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数 - def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=False): + def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=True): # import pdb;pdb.set_trace() # db_value = None