From 8dd7cfaf728e8ec63507804ec371cec0fb51fcf7 Mon Sep 17 00:00:00 2001
From: iomgaa <iomgaaycz@gmail.com>
Date: Sun, 11 May 2025 23:57:34 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E4=BA=86loss=E4=B8=BAnan?=
 =?UTF-8?q?=E7=9A=84=E9=94=99=E8=AF=AF?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 model/model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/model/model.py b/model/model.py
index 31aca85..2873a77 100644
--- a/model/model.py
+++ b/model/model.py
@@ -94,7 +94,7 @@ class Attention(nn.Module):
                 x: torch.Tensor,
                 pos_cis: torch.Tensor,
                 past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-                use_cache=False,
+                use_cache=True,
                 db_value=None):
         bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换,得到查询、键和值。
@@ -373,7 +373,7 @@ class MiniMindBlock(nn.Module):
         #     self.product_key_topk = min(16, self.num_keys)  # 确保不超过num_keys
         #     self.num_experts_per_head_topk = 1  # 最终每个头选取的专家数
 
-    def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=False):
+    def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=True):
         # import pdb;pdb.set_trace()
         # db_value = None