DynamicKV-LLM Pretrain v1.1.2
This commit is contained in:
parent
00d3c24e03
commit
45da3b383b
@ -608,7 +608,9 @@ class MiniMindLM(PreTrainedModel):
|
||||
self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
|
||||
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
||||
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
|
||||
self.database_output = nn.Linear(params.dim, params.knowledge_length, bias=False)
|
||||
self.tok_embeddings.weight = self.output.weight
|
||||
self.database_output.weight = self.output.weight
|
||||
|
||||
# Calculate input dimension
|
||||
input_dim = (self.params.max_seq_len-1)*self.params.n_layers
|
||||
@ -629,7 +631,7 @@ class MiniMindLM(PreTrainedModel):
|
||||
nn.Conv1d(128*8, 128, kernel_size=1, padding='same'),
|
||||
nn.Conv1d(128, self.params.knowledge_length, kernel_size=1, padding='same')
|
||||
)
|
||||
|
||||
|
||||
# Specific layers for q path
|
||||
self.downsample_q_specific = nn.Sequential(
|
||||
nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
|
||||
@ -684,23 +686,22 @@ class MiniMindLM(PreTrainedModel):
|
||||
|
||||
# Compute shared downsampling layer once
|
||||
shared_features = self.shared_downsample(h_tensor_detached)
|
||||
z_v = self.downsample_v_specific(shared_features)#这里需要从emb返回为token
|
||||
batch_z, seq_len, dim_z = z_v.shape
|
||||
embedding_weights = self.tok_embeddings.weight.detach()#[vocab_size, dim]
|
||||
z_v_flat = z_v.reshape(-1, z_v.shape[-1]) # [batch_size_z * knowledge_len, dim]
|
||||
|
||||
# 余弦相似度版本
|
||||
# z_v_flat_norm = F.normalize(z_v_flat, p=2, dim=1)
|
||||
# embedding_weights_norm = F.normalize(embedding_weights, p=2, dim=1)
|
||||
# similarity_scores = torch.matmul(z_v_flat_norm, embedding_weights_norm.transpose(0, 1))
|
||||
# token_indices_flat = torch.argmax(similarity_scores, dim=1) # [batch_size_z * seq_len]
|
||||
# token_indices = token_indices_flat.reshape(batch_size, -1)
|
||||
|
||||
distances = torch.cdist(z_v_flat, embedding_weights, p=2)
|
||||
token_indices_flat = torch.argmin(distances, dim=1)
|
||||
token_indices = token_indices_flat.reshape(batch_z,-1)
|
||||
# Get features from v path - now we output embedding-dimension vectors
|
||||
z_v_features = self.downsample_v_specific(shared_features)
|
||||
batch_z, seq_len, dim_z = z_v_features.shape
|
||||
|
||||
# Reshape to batch_size * knowledge_length, dim
|
||||
z_v_flat = z_v_features.reshape(-1, dim_z)
|
||||
|
||||
# Direct token prediction - like the main language model head
|
||||
token_logits = self.database_output(z_v_flat) # [batch_z * seq_len, vocab_size]
|
||||
# Get token indices directly from logits
|
||||
token_indices_flat = torch.argmax(token_logits, dim=-1)
|
||||
token_indices = token_indices_flat.reshape(batch_z, -1)
|
||||
|
||||
# Process query path as before
|
||||
z_q = self.downsample_q_specific(shared_features)
|
||||
# import pdb;pdb.set_trace()
|
||||
z_k = self.extract_db.q_to_k(z_q)
|
||||
self.extract_db.updata_value(z_k, token_indices)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user