DynamicKV-LLM 1.0.0: the core architecture is complete and the model trains normally
This commit is contained in:
parent e3120f5e62
commit 0859f54a88
```diff
@@ -431,15 +431,13 @@ class ExtractDB(nn.Module):
         self.batch_size = None
         self.dim = params.dim
         self.dim_key = self.dim // 2
-        self.num_experts = 10 * 10  # 1M experts; must be a perfect square
+        self.num_experts = 10 * 10  # 100 experts; must be a perfect square
         # Set knowledge_dim equal to head_dim so it can be used directly in attention
         self.head_dim = params.dim // params.n_heads
         self.knowledge_dim = 8*params.dim
-        # Use a plain tensor on the CPU instead of nn.Embedding
-        self.register_buffer('weight_down_embed_cpu', torch.randn(self.num_experts, self.knowledge_dim,
-                                                                  dtype=torch.float32,
-                                                                  device='cpu') * 0.02,
-                             persistent=True)
+        # Use register_buffer instead of nn.Parameter to avoid gradient issues
+        self.register_buffer('weight_down_embed', torch.randn(self.num_experts, self.knowledge_dim) * 0.02)
 
         self.num_keys = int(math.sqrt(self.num_experts)) if self.num_experts > 0 else 0
         self.product_key_topk = min(16, self.num_keys)
```
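The constructor wires up a product-key memory: `num_experts` must be a perfect square because a flat expert index is factored into two sub-keys of size `sqrt(num_experts)`, so each query is scored against two tables of `num_keys` entries instead of one table of `num_experts`. Switching to `register_buffer` keeps the table in `state_dict` and moves it with `.to(device)` while never exposing it to the optimizer, since buffers are excluded from `parameters()`. The key tables and `q_to_k` are not shown in this diff, so the following is only a hypothetical sketch of the lookup; `keys_a`/`keys_b` and all shapes are assumptions:

```python
import math
import torch
import torch.nn as nn

class ProductKeyMemory(nn.Module):
    """Hypothetical reconstruction of a product-key lookup, not the repo's exact q_to_k."""
    def __init__(self, dim: int, num_experts: int = 100, topk: int = 16):
        super().__init__()
        assert int(math.sqrt(num_experts)) ** 2 == num_experts, "must be a perfect square"
        self.num_keys = int(math.sqrt(num_experts))   # e.g. 10 keys per half for 100 experts
        self.topk = min(topk, self.num_keys)
        self.dim_key = dim // 2
        # Two small key tables replace one table with num_experts keys
        self.keys_a = nn.Parameter(torch.randn(self.num_keys, self.dim_key) * 0.02)
        self.keys_b = nn.Parameter(torch.randn(self.num_keys, self.dim_key) * 0.02)

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        # q: [batch, dim] -> split into two halves, score each against its own key table
        q_a, q_b = q[:, : self.dim_key], q[:, self.dim_key :]
        top_a = (q_a @ self.keys_a.t()).topk(self.topk, dim=-1).indices  # [batch, topk]
        top_b = (q_b @ self.keys_b.t()).topk(self.topk, dim=-1).indices  # [batch, topk]
        # Combine sub-indices: flat index = a * num_keys + b -> topk*topk candidate experts
        flat = top_a.unsqueeze(-1) * self.num_keys + top_b.unsqueeze(1)  # [batch, topk, topk]
        return flat.view(q.size(0), -1)
```

With `num_experts = 100` the module scores 2×10 keys rather than 100, which is the point of the factorization.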
```diff
@@ -482,20 +480,18 @@ class ExtractDB(nn.Module):
         return flat_indices
 
     def get_data(self, index):
-        # Move the required embeddings from the CPU to the current device
-        device = index.device
-        # Look up the embeddings for the given indices
-        db_values = self.weight_down_embed_cpu[index.cpu()].to(device)
+        # Fetch the embeddings directly on the GPU
+        db_values = self.weight_down_embed[index]
         db_value = db_values.view(self.batch_size, -1, self.dim)
         return db_value
 
+    @torch.no_grad()
     def updata_value(self, k, v):
-        # Update the tensor values on the CPU
-        k_cpu = k.cpu()
-        v_cpu = v.view(v.size(0), -1).cpu()
-        # Update the values in memory directly
-        self.weight_down_embed_cpu[k_cpu] = v_cpu
+        # Update the buffer values in place (no gradients needed)
+        v_reshaped = v.view(v.size(0), -1)
+        # Make sure the dtypes match
+        v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
+        self.weight_down_embed[k] = v_reshaped
```
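After this hunk, both the read and the write touch `weight_down_embed` in place on whatever device the module lives on, with no `.cpu()`/`.to()` round trips, and the `@torch.no_grad()` decorator keeps the scatter-style assignment out of autograd entirely. A toy round-trip illustrating the same read/write semantics; `TinyDB` and its dimensions are illustrative, not the repo's:

```python
import torch
import torch.nn as nn

class TinyDB(nn.Module):
    """Toy stand-in for ExtractDB's buffer: read by index, write in place without grads."""
    def __init__(self, num_experts: int = 100, knowledge_dim: int = 32):
        super().__init__()
        # persistent=True is the default, so the buffer is saved in state_dict
        self.register_buffer("weight_down_embed", torch.randn(num_experts, knowledge_dim) * 0.02)

    def get_data(self, index: torch.Tensor) -> torch.Tensor:
        # Plain fancy indexing; runs on the buffer's device
        return self.weight_down_embed[index]

    @torch.no_grad()
    def update_value(self, k: torch.Tensor, v: torch.Tensor) -> None:
        v = v.view(v.size(0), -1).to(dtype=self.weight_down_embed.dtype)
        # In-place indexed assignment; with duplicate indices, one write wins (no accumulation)
        self.weight_down_embed[k] = v

db = TinyDB()
idx = torch.tensor([3, 7, 42])
db.update_value(idx, torch.zeros(3, 32))
assert torch.equal(db.get_data(idx), torch.zeros(3, 32))
```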
```diff
@@ -554,17 +550,19 @@ class MiniMindLM(PreTrainedModel):
 
             past_kvs.append(past_kv)
             h_list.append(h.unsqueeze(0))
-        h_tensor = torch.cat(h_list,dim=0).permute(1,0,2,3)
-        h_tensor = h_tensor.reshape(h_tensor.shape[0],-1,768)
-        z_v = self.downsample_v(h_tensor)
-        z_q = self.downsample_q(h_tensor)
-        z_k = self.extract_db.q_to_k(z_q)
-        self.extract_db.updata_value(z_k,z_v)
-        # Update the database
-        # q,v = f(h_list)
+        # Use detach() to cut the computation graph and avoid backpropagating through it twice
+        h_tensor = torch.cat(h_list,dim=0).permute(1,0,2,3)
+        h_tensor_detached = h_tensor.detach()
+        h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0],-1,768)
 
+        # Keep the database update separate from the main computation graph
+        with torch.no_grad():
+            z_v = self.downsample_v(h_tensor_detached)
+            z_q = self.downsample_q(h_tensor_detached)
+            z_k = self.extract_db.q_to_k(z_q)
+            self.extract_db.updata_value(z_k,z_v)
 
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.output(self.norm(h)[:, slice_indices, :])
         aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))
```
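The motivation for the detach/no_grad rewrite is PyTorch's one-backward-per-graph rule: the old code fed `h_tensor`, which also feeds the language-model loss, into the database update, so backpropagating through both paths hits the same graph twice and raises a RuntimeError unless `retain_graph=True` is paid for. Detaching plus `torch.no_grad()` turns the update into a pure side effect with no graph at all. A minimal repro and fix with throwaway tensors; all names here are illustrative:

```python
import torch

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)

h = x @ w                        # shared intermediate, like h_tensor
main_loss = (h ** 2).mean()
side_score = (h @ w.t()).sum()   # second consumer of the SAME graph

main_loss.backward()             # frees h's graph by default
try:
    side_score.backward()
except RuntimeError as e:
    print("without detach:", e)  # "Trying to backward through the graph a second time ..."

# Fix mirroring the commit: detach + no_grad, so the side path never enters autograd
h2 = x @ w
(h2 ** 2).mean().backward()
with torch.no_grad():
    _ = h2.detach() @ w.t()      # database-update analogue: no graph, no second backward
```

Inside `torch.no_grad()` the extra `detach()` is belt-and-braces, but it mirrors the commit, where the detached tensor is built before entering the block.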