DynamicKV-LLM Pretrain v1.1.2

Jax922 2025-05-23 01:18:08 +08:00
parent 00d3c24e03
commit 45da3b383b


@@ -608,7 +608,9 @@ class MiniMindLM(PreTrainedModel):
         self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        self.database_output = nn.Linear(params.dim, params.knowledge_length, bias=False)
         self.tok_embeddings.weight = self.output.weight
+        self.database_output.weight = self.output.weight
         # Calculate input dimension
         input_dim = (self.params.max_seq_len-1)*self.params.n_layers
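
Side note on the two added lines: assigning output.weight to both tok_embeddings and database_output ties all three modules to a single shared parameter. A minimal, self-contained sketch of that tying follows; the dims here are placeholders, not MiniMind's real config values.

    import torch
    import torch.nn as nn

    # Placeholder dims (assumption); MiniMind reads these from its config.
    dim, vocab_size, knowledge_length = 64, 1000, 512

    tok_embeddings = nn.Embedding(vocab_size, dim)
    output = nn.Linear(dim, vocab_size, bias=False)
    database_output = nn.Linear(dim, knowledge_length, bias=False)

    # Reassigning .weight makes all three modules share one Parameter.
    # Note that database_output now effectively projects to vocab_size,
    # not the knowledge_length declared in its constructor.
    tok_embeddings.weight = output.weight
    database_output.weight = output.weight

    h = torch.randn(2, dim)
    print(database_output(h).shape)  # torch.Size([2, 1000])

This wholesale replacement of the Parameter is what lets the second hunk treat database_output as a vocabulary-sized classifier head.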
@@ -684,23 +686,22 @@ class MiniMindLM(PreTrainedModel):
         # Compute shared downsampling layer once
         shared_features = self.shared_downsample(h_tensor_detached)
-        z_v = self.downsample_v_specific(shared_features)  # need to map from embeddings back to tokens here
-        batch_z, seq_len, dim_z = z_v.shape
-        embedding_weights = self.tok_embeddings.weight.detach()  # [vocab_size, dim]
-        z_v_flat = z_v.reshape(-1, z_v.shape[-1])  # [batch_size_z * knowledge_len, dim]
-        # Cosine-similarity version
-        # z_v_flat_norm = F.normalize(z_v_flat, p=2, dim=1)
-        # embedding_weights_norm = F.normalize(embedding_weights, p=2, dim=1)
-        # similarity_scores = torch.matmul(z_v_flat_norm, embedding_weights_norm.transpose(0, 1))
-        # token_indices_flat = torch.argmax(similarity_scores, dim=1)  # [batch_size_z * seq_len]
-        # token_indices = token_indices_flat.reshape(batch_size, -1)
-        distances = torch.cdist(z_v_flat, embedding_weights, p=2)
-        token_indices_flat = torch.argmin(distances, dim=1)
+        # Get features from v path - now we output embedding-dimension vectors
+        z_v_features = self.downsample_v_specific(shared_features)
+        batch_z, seq_len, dim_z = z_v_features.shape
+        # Reshape to batch_size * knowledge_length, dim
+        z_v_flat = z_v_features.reshape(-1, dim_z)
+        # Direct token prediction - like the main language model head
+        token_logits = self.database_output(z_v_flat)  # [batch_z * seq_len, vocab_size]
+        # Get token indices directly from logits
+        token_indices_flat = torch.argmax(token_logits, dim=-1)
         token_indices = token_indices_flat.reshape(batch_z, -1)
         # Process query path as before
         z_q = self.downsample_q_specific(shared_features)
         # import pdb;pdb.set_trace()
         z_k = self.extract_db.q_to_k(z_q)
         self.extract_db.updata_value(z_k, token_indices)
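
This hunk replaces the L2 nearest-embedding lookup (torch.cdist + argmin) with a direct linear head plus argmax, mirroring greedy decoding through the main LM head. A standalone sketch of both selection rules follows; shapes and tensor names are illustrative, not the repo's actual tensors.

    import torch

    dim, vocab_size = 64, 1000
    z_v_flat = torch.randn(8, dim)                    # flattened v-path features
    embedding_weights = torch.randn(vocab_size, dim)  # stands in for the tied head weight

    # Removed path: nearest token embedding under L2 distance.
    distances = torch.cdist(z_v_flat, embedding_weights, p=2)  # [8, vocab_size]
    old_indices = torch.argmin(distances, dim=1)

    # New path: project through the (tied) linear head and take the argmax,
    # like greedy selection from an LM head.
    token_logits = z_v_flat @ embedding_weights.t()            # [8, vocab_size]
    new_indices = torch.argmax(token_logits, dim=-1)

    # The two rules agree only when all embedding rows have equal norm:
    # ||z - e||^2 = ||z||^2 - 2*z.e + ||e||^2, so argmin over e matches
    # argmax of z.e iff the ||e||^2 term is constant across rows.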
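The last two lines write the predicted token indices into the knowledge store addressed via z_k. ExtractDB's internals are not part of this diff, so the following is only a hypothetical sketch of one plausible update scheme (nearest-key slot selection, then overwriting the stored token rows); the class, method, and shapes are all assumptions, not the repo's API.

    import torch

    class ToyKVStore:
        """Hypothetical stand-in for extract_db; not the repo's implementation."""
        def __init__(self, num_slots, key_dim, knowledge_length):
            self.keys = torch.randn(num_slots, key_dim)
            self.values = torch.zeros(num_slots, knowledge_length, dtype=torch.long)

        def update_value(self, z_k, token_indices):
            # Pick the nearest stored key for each incoming key vector...
            slots = torch.cdist(z_k, self.keys).argmin(dim=1)  # [batch]
            # ...and overwrite that slot's token row with the new indices.
            self.values[slots] = token_indices

    db = ToyKVStore(num_slots=16, key_dim=64, knowledge_length=8)
    db.update_value(torch.randn(4, 64), torch.randint(0, 1000, (4, 8)))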