From 0859f54a8879005f31dfdf6709055023575598e0 Mon Sep 17 00:00:00 2001
From: iomgaa
Date: Fri, 25 Apr 2025 16:49:05 +0800
Subject: [PATCH] DynamicKV-LLM 1.0.0: core architecture complete; the model
 trains normally
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 model/model.py | 46 ++++++++++++++++++++++------------------------
 1 file changed, 22 insertions(+), 24 deletions(-)

diff --git a/model/model.py b/model/model.py
index 610af59..182f58b 100644
--- a/model/model.py
+++ b/model/model.py
@@ -431,15 +431,13 @@ class ExtractDB(nn.Module):
         self.batch_size = None
         self.dim = params.dim
         self.dim_key = self.dim // 2
-        self.num_experts = 10 * 10 # 1M experts; must be a perfect square
+        self.num_experts = 10 * 10 # 100 experts; must be a perfect square
         # Set knowledge_dim relative to head_dim so it can be used directly in attention
         self.head_dim = params.dim // params.n_heads
         self.knowledge_dim = 8*params.dim
-        # Use a plain CPU tensor instead of nn.Embedding
-        self.register_buffer('weight_down_embed_cpu', torch.randn(self.num_experts, self.knowledge_dim,
-                                                                  dtype=torch.float32,
-                                                                  device='cpu') * 0.02,
-                             persistent=True)
+
+        # Use register_buffer instead of nn.Parameter to avoid gradient issues
+        self.register_buffer('weight_down_embed', torch.randn(self.num_experts, self.knowledge_dim) * 0.02)
 
         self.num_keys = int(math.sqrt(self.num_experts)) if self.num_experts > 0 else 0
         self.product_key_topk = min(16, self.num_keys)
@@ -482,20 +480,18 @@ class ExtractDB(nn.Module):
         return flat_indices
 
     def get_data(self, index):
-        # Move the required embeddings from the CPU to the current device
-        device = index.device
-        # Look up the embeddings for the given indices
-        db_values = self.weight_down_embed_cpu[index.cpu()].to(device)
+        # Look up the embeddings directly on the GPU
+        db_values = self.weight_down_embed[index]
         db_value = db_values.view(self.batch_size, -1, self.dim)
         return db_value
 
+    @torch.no_grad()
     def updata_value(self, k, v):
-        # Update the tensor values held on the CPU
-        k_cpu = k.cpu()
-        v_cpu = v.view(v.size(0), -1).cpu()
-
-        # Update the in-memory values directly
-        self.weight_down_embed_cpu[k_cpu] = v_cpu
+        # Update the buffer values in place (no gradients needed)
+        v_reshaped = v.view(v.size(0), -1)
+        # Make sure the dtypes match
+        v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
+        self.weight_down_embed[k] = v_reshaped
@@ -554,17 +550,19 @@ class MiniMindLM(PreTrainedModel):
             past_kvs.append(past_kv)
             h_list.append(h.unsqueeze(0))
 
+        # Use detach() to cut the computation graph and avoid backpropagating through it twice
         h_tensor = torch.cat(h_list,dim=0).permute(1,0,2,3)
-        h_tensor = h_tensor.reshape(h_tensor.shape[0],-1,768)
-        z_v = self.downsample_v(h_tensor)
-        z_q = self.downsample_q(h_tensor)
+        h_tensor_detached = h_tensor.detach()
+        h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0],-1,768)
 
-        z_k = self.extract_db.q_to_k(z_q)
-        self.extract_db.updata_value(z_k,z_v)
+        # Keep the database update logic separate from the main computation graph
+        with torch.no_grad():
+            z_v = self.downsample_v(h_tensor_detached)
+            z_q = self.downsample_q(h_tensor_detached)
+            z_k = self.extract_db.q_to_k(z_q)
+            self.extract_db.updata_value(z_k,z_v)
 
-        # Update the database
-        # q,v = f(h_list)
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.output(self.norm(h)[:, slice_indices, :])
         aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
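---

Notes (reviewer commentary, not part of the diff). The sketches below are
illustrative PyTorch with hypothetical names and toy sizes; none of them are
code from model/model.py.

The first hunk replaces a hand-managed CPU tensor with register_buffer. A
buffer behaves like a parameter for .to(device) and state_dict() but is
excluded from model.parameters(), so the optimizer never sees it and autograd
never tracks writes to it, which is what a non-learned knowledge store needs.
A minimal sketch of the pattern (KnowledgeStore and the sizes are made up):

    import torch
    import torch.nn as nn

    class KnowledgeStore(nn.Module):
        def __init__(self, num_experts: int = 100, knowledge_dim: int = 64):
            super().__init__()
            # Registered as a buffer, not a Parameter: saved in state_dict,
            # moved by .to(device), but invisible to the optimizer.
            self.register_buffer(
                'weight_down_embed',
                torch.randn(num_experts, knowledge_dim) * 0.02,
            )

    store = KnowledgeStore()
    print(sum(p.numel() for p in store.parameters()))  # 0: nothing trainable
    print('weight_down_embed' in store.state_dict())   # True: still persisted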
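q_to_k is referenced but not shown in this patch. Given dim_key = dim // 2,
num_keys = sqrt(num_experts), and product_key_topk = min(16, num_keys), it
appears to implement the standard product-key lookup: score each half of the
query against one of two sub-key sets of size sqrt(N), then combine the two
top-k lists by a cartesian re-ranking. A sketch of that technique under those
assumptions; every name here is hypothetical:

    import math
    import torch

    def product_key_topk(q, sub_keys1, sub_keys2, topk):
        # q: (batch, dim); each half is scored against one sub-key set.
        half = q.shape[-1] // 2
        s1 = q[..., :half] @ sub_keys1.t()   # (batch, num_keys)
        s2 = q[..., half:] @ sub_keys2.t()   # (batch, num_keys)
        v1, i1 = s1.topk(topk, dim=-1)
        v2, i2 = s2.topk(topk, dim=-1)
        # topk*topk candidate pairs; re-rank summed scores, keep best topk.
        scores = v1.unsqueeze(-1) + v2.unsqueeze(-2)   # (batch, topk, topk)
        _, flat = scores.flatten(-2).topk(topk, dim=-1)
        num_keys = sub_keys1.shape[0]
        row = i1.gather(-1, flat // topk)
        col = i2.gather(-1, flat % topk)
        return row * num_keys + col   # flat indices into num_keys**2 experts

    num_experts = 10 * 10
    num_keys = int(math.sqrt(num_experts))
    dim = 8
    sub_keys1 = torch.randn(num_keys, dim // 2)
    sub_keys2 = torch.randn(num_keys, dim // 2)
    q = torch.randn(4, dim)
    print(product_key_topk(q, sub_keys1, sub_keys2, min(16, num_keys)).shape)

The min(16, num_keys) clamp in the patch is what keeps the per-half topk from
exceeding the number of sub-keys when the expert count is small.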
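The get_data change removes a host/device round trip: indexing a CPU tensor
with GPU indices forced an index.cpu() plus a .to(device) copy on every
lookup, while indexing a buffer that already lives on the model's device is a
single gather. A toy comparison of the two patterns (shapes arbitrary):

    import torch

    table_cpu = torch.randn(100, 8)   # old layout: store lives on the CPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    table_dev = table_cpu.to(device)  # new layout: store lives with the model

    idx = torch.randint(0, 100, (4, 16), device=device)

    # Old pattern: move indices to CPU, gather there, copy the result back.
    vals_roundtrip = table_cpu[idx.cpu()].to(device)

    # New pattern: one gather on the model's device, no host<->device copies.
    vals_direct = table_dev[idx]

    print(torch.allclose(vals_roundtrip, vals_direct))  # True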
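The @torch.no_grad() decorator added to updata_value makes the in-place
indexed write safe even when v comes out of modules that require grad: under
no_grad nothing is recorded, so the assignment cannot extend or corrupt the
autograd graph. A sketch of the same pattern (update_value is a hypothetical
stand-in for the patch's updata_value):

    import torch
    import torch.nn as nn

    class Store(nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer('table', torch.zeros(100, 8))

        @torch.no_grad()
        def update_value(self, k, v):
            # In-place indexed write; under no_grad it is never recorded,
            # so it cannot interfere with any surrounding backward pass.
            self.table[k] = v.view(v.size(0), -1).to(dtype=self.table.dtype)

    store = Store()
    k = torch.tensor([3, 7])
    v = torch.randn(2, 8, requires_grad=True)  # e.g. output of a trained head
    store.update_value(k, v)
    print(store.table[3].abs().sum() > 0)      # tensor(True): row was written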
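The forward-pass hunk detaches h_tensor and wraps the database update in
no_grad. Without that, z_v and z_q would share a graph with the language-model
loss, gradients would leak into downsample_v/downsample_q, and backpropagating
the update path over activations already freed by the main backward can raise
the "backward through the graph a second time" error. A tiny repro of the
principle (proj stands in for the downsample modules):

    import torch
    import torch.nn as nn

    proj = nn.Linear(4, 4)           # stand-in for downsample_v / downsample_q
    x = torch.randn(2, 4)
    w = torch.randn(4, 4, requires_grad=True)

    h = x @ w                        # main graph
    loss = h.sum()

    # Side branch for the database update: detach + no_grad makes the write
    # pure bookkeeping, leaving the main graph intact for the single
    # loss.backward() below.
    with torch.no_grad():
        z = proj(h.detach())
        cache = z.mean(dim=0)

    loss.backward()
    print(w.grad is not None)        # True: main path trained
    print(proj.weight.grad is None)  # True: side path untouched
    print(cache.requires_grad)       # False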