DynamicKV-LLM 1.0.0 完成了核心架构,模型可以正常训练
This commit is contained in:
parent
e3120f5e62
commit
0859f54a88
@@ -431,15 +431,13 @@ class ExtractDB(nn.Module):
|
||||
self.batch_size = None
|
||||
self.dim = params.dim
|
||||
self.dim_key = self.dim // 2
|
||||
self.num_experts = 10 * 10 # 1M专家,确保是完全平方数
|
||||
self.num_experts = 10 * 10 # 100专家,确保是完全平方数
|
||||
# 将knowledge_dim设置为与head_dim相同,以便在attention中直接使用
|
||||
self.head_dim = params.dim // params.n_heads
|
||||
self.knowledge_dim = 8*params.dim
|
||||
# 使用CPU上的普通tensor替代nn.Embedding
|
||||
self.register_buffer('weight_down_embed_cpu', torch.randn(self.num_experts, self.knowledge_dim,
|
||||
dtype=torch.float32,
|
||||
device='cpu') * 0.02,
|
||||
persistent=True)
|
||||
|
||||
# 使用register_buffer代替nn.Parameter,避免梯度问题
|
||||
self.register_buffer('weight_down_embed', torch.randn(self.num_experts, self.knowledge_dim) * 0.02)
|
||||
|
||||
self.num_keys = int(math.sqrt(self.num_experts)) if self.num_experts > 0 else 0
|
||||
self.product_key_topk = min(16, self.num_keys)
|
||||
@@ -482,20 +480,18 @@ class ExtractDB(nn.Module):
|
||||
return flat_indices
|
||||
|
||||
def get_data(self, index):
|
||||
# 将需要的embedding从CPU移到当前设备上
|
||||
device = index.device
|
||||
# 根据索引获取对应的embedding
|
||||
db_values = self.weight_down_embed_cpu[index.cpu()].to(device)
|
||||
# 直接从GPU获取embedding
|
||||
db_values = self.weight_down_embed[index]
|
||||
db_value = db_values.view(self.batch_size, -1, self.dim)
|
||||
return db_value
|
||||
|
||||
@torch.no_grad()
|
||||
def updata_value(self, k, v):
|
||||
# 更新CPU上的张量值
|
||||
k_cpu = k.cpu()
|
||||
v_cpu = v.view(v.size(0), -1).cpu()
|
||||
|
||||
# 直接更新内存中的值
|
||||
self.weight_down_embed_cpu[k_cpu] = v_cpu
|
||||
# 直接更新buffer上的值 (不需要梯度)
|
||||
v_reshaped = v.view(v.size(0), -1)
|
||||
# 确保数据类型匹配
|
||||
v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
|
||||
self.weight_down_embed[k] = v_reshaped
|
||||
|
||||
|
||||
|
||||
@@ -554,17 +550,19 @@ class MiniMindLM(PreTrainedModel):
|
||||
|
||||
past_kvs.append(past_kv)
|
||||
h_list.append(h.unsqueeze(0))
|
||||
h_tensor = torch.cat(h_list,dim=0).permute(1,0,2,3)
|
||||
h_tensor = h_tensor.reshape(h_tensor.shape[0],-1,768)
|
||||
z_v = self.downsample_v(h_tensor)
|
||||
z_q = self.downsample_q(h_tensor)
|
||||
|
||||
# 使用detach()分离计算图,避免多次反向传播
|
||||
h_tensor = torch.cat(h_list,dim=0).permute(1,0,2,3)
|
||||
h_tensor_detached = h_tensor.detach()
|
||||
h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0],-1,768)
|
||||
|
||||
# 数据库更新逻辑与主计算图分离
|
||||
with torch.no_grad():
|
||||
z_v = self.downsample_v(h_tensor_detached)
|
||||
z_q = self.downsample_q(h_tensor_detached)
|
||||
z_k = self.extract_db.q_to_k(z_q)
|
||||
self.extract_db.updata_value(z_k,z_v)
|
||||
|
||||
#更新数据库
|
||||
# q,v = f(h_list)
|
||||
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.output(self.norm(h)[:, slice_indices, :])
|
||||
aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
|
||||
|
Loading…
x
Reference in New Issue
Block a user