From 47b120be3a818bfde6f79866e52cd37382057209 Mon Sep 17 00:00:00 2001
From: cy <1433747532@qq.com>
Date: Tue, 3 Jun 2025 07:36:34 +0000
Subject: [PATCH] Update the key and value update mechanism
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 dataset_decoder.py                      |  2 +-
 loss.py                                 | 33 +++++++++++++++++
 model/model.py                          | 49 ++++++++++++++++++-------
 run_file/DynamicKV-LLM_Mini_Minimind.sh |  0
 4 files changed, 69 insertions(+), 15 deletions(-)
 create mode 100644 loss.py
 mode change 100644 => 100755 run_file/DynamicKV-LLM_Mini_Minimind.sh

diff --git a/dataset_decoder.py b/dataset_decoder.py
index cb5be61..3a95ae3 100644
--- a/dataset_decoder.py
+++ b/dataset_decoder.py
@@ -132,7 +132,7 @@ def decode_dataset(model_path, output_path, device="cuda"):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
-    parser.add_argument("--model_path", type=str, default="out/pretrain_1024.pth",
+    parser.add_argument("--model_path", type=str, default="out/pretrain_512.pth",
                         help="Path to the model checkpoint")
     parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt",
                         help="Path to save the decoded text file")
diff --git a/loss.py b/loss.py
new file mode 100644
index 0000000..06fb4ca
--- /dev/null
+++ b/loss.py
@@ -0,0 +1,33 @@
+import re
+import matplotlib.pyplot as plt
+
+log_file = 'out/train.log'
+steps_per_epoch = 58880  # set this to match your actual training log
+
+with open(log_file, 'r', encoding='utf-8') as f:
+    log_text = f.read()
+
+# Extract epoch, step, loss
+pattern = re.compile(r'Epoch\s+(\d+)/\d+,\s+Step\s+(\d+)/\d+,\s+Loss:\s*([0-9.]+)', re.MULTILINE)
+matches = pattern.findall(log_text)
+
+global_steps = []
+losses = []
+
+for epoch, step, loss in matches:
+    epoch = int(epoch)
+    step = int(step)
+    global_step = (epoch - 1) * steps_per_epoch + step
+    global_steps.append(global_step)
+    losses.append(float(loss))
+
+plt.figure(figsize=(12, 6))
+plt.plot(global_steps, losses, label='Loss')
+plt.xlabel('Global Step')
+plt.ylabel('Loss')
+plt.title('Training Loss Curve')
+plt.legend()
+plt.grid(True)
+plt.tight_layout()
+plt.savefig('out/loss_curve.png')
+plt.show()
\ No newline at end of file
diff --git a/model/model.py b/model/model.py
index 814674d..eda9197 100644
--- a/model/model.py
+++ b/model/model.py
@@ -550,7 +550,7 @@ class ExtractDB(nn.Module):
 
         # collapse sequence dimension by averaging
         x_flat = x.mean(dim=1)  # [batch_size, dim]
-        
+
         queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
         queries = queries.reshape(self.batch_size, 2, self.dim_key)  # [batch_size, 2, dim_key]
         queries = queries.permute(1, 0, 2)  # [2, batch_size, dim_key]
@@ -590,6 +590,20 @@ class ExtractDB(nn.Module):
             v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
             self.weight_down_embed[k] = v_reshaped
 
+    @torch.no_grad()
+    def update_keys_with_zq(self, flat_indices, z_q):
+        """
+        flat_indices: [batch], global indices (0 ~ knowledge_num-1) of the keys retrieved by q_to_k
+        z_q: [batch, 2, dim_key], the two sub-space queries for each sample
+        """
+        num_keys = self.num_keys
+        idx_x = flat_indices // num_keys  # [batch]
+        idx_y = flat_indices % num_keys   # [batch]
+
+        # For each sample, overwrite the two key sub-spaces with the matching halves of z_q
+        for i in range(flat_indices.size(0)):
+            self.keys.data[idx_x[i], 0, :] = z_q[i, 0, :].to(self.keys.dtype)
+            self.keys.data[idx_y[i], 1, :] = z_q[i, 1, :].to(self.keys.dtype)
 
 
 class MiniMindLM(PreTrainedModel):
@@ -634,13 +648,16 @@ class MiniMindLM(PreTrainedModel):
         # Specific layers for q path
         self.downsample_q_specific = nn.Sequential(
-            nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
+            nn.Conv1d(128*8, self.params.dim, kernel_size=1, padding='same')
         )
 
         # Use the real-valued positional encoding to avoid segfaults that complex tensors can trigger
         self.register_buffer("pos_cis_real",
                              precompute_pos_cis_real(dim=params.dim // params.n_heads,
                                                      theta=params.rope_theta),
                              persistent=False)
         self.params = params
+        self.value_update_schedule = 0.9  # keys/values stay frozen for the first 90% of training
+        self.global_step = 0  # current training step
+        self.total_steps = None  # total steps; assigned by the training script
 
     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
@@ -683,27 +700,29 @@ class MiniMindLM(PreTrainedModel):
 
         # Database update logic is decoupled from the main computation graph
         with torch.no_grad():
-            # Compute shared downsampling layer once
             shared_features = self.shared_downsample(h_tensor_detached)
 
-            # Get features from v path - now we output embedding-dimension vectors
+            # Get features from v path
             z_v_features = self.downsample_v_specific(shared_features)
             batch_z, seq_len, dim_z = z_v_features.shape
-
-            # Reshape to batch_size * knowledge_length, dim
             z_v_flat = z_v_features.reshape(-1, dim_z)
-
-            # Direct token prediction - like the main language model head
-            token_logits = self.database_output(z_v_flat) # [batch_z * seq_len, vocab_size]
-            # Get token indices directly from logits
+            token_logits = self.database_output(z_v_flat)
             token_indices_flat = torch.argmax(token_logits, dim=-1)
             token_indices = token_indices_flat.reshape(batch_z, -1)
 
-            # Process query path as before
-            z_q = self.downsample_q_specific(shared_features)
-            z_k = self.extract_db.q_to_k(z_q)
-            # self.extract_db.updata_value(z_k, token_indices)
+            # Process query path
+            z_q_input = self.downsample_q_specific(shared_features)  # [batch, dim, seq_len]
+            z_q_input = z_q_input.permute(0, 2, 1)  # [batch, seq_len, dim]
+            z_k = self.extract_db.q_to_k(z_q_input)  # [batch]
+            z_q_pooled = z_q_input.mean(dim=1)  # [batch, dim]
+            z_q_vec = self.extract_db.to_queries(z_q_pooled)  # [batch, 2*dim_key]
+            z_q_vec = z_q_vec.view(z_q_vec.size(0), 2, self.extract_db.dim_key)  # [batch, 2, dim_key]
+
+            progress = self.global_step / self.total_steps if self.total_steps else 0
+            if progress >= self.value_update_schedule:
+                self.extract_db.updata_value(z_k, token_indices)
+                self.extract_db.update_keys_with_zq(z_k, z_q_vec)
 
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.output(self.norm(h)[:, slice_indices, :])
 
@@ -778,3 +797,5 @@ class MiniMindLM(PreTrainedModel):
                 yield input_ids[:, start:]
             if input_ids_next.item() == eos_token_id:
                 break
+
+
diff --git a/run_file/DynamicKV-LLM_Mini_Minimind.sh b/run_file/DynamicKV-LLM_Mini_Minimind.sh
old mode 100644
new mode 100755
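
Reviewer note (text below the patch; not consumed by `git am`):

The new ExtractDB.update_keys_with_zq relies on the product-key layout, where
the flat index returned by q_to_k factors into two sub-key indices as
flat = idx_x * num_keys + idx_y. A minimal standalone sketch of that
decomposition (num_keys=512 is an assumed value, not taken from this patch):

    import torch

    num_keys = 512                       # assumed size of each sub-key table
    flat_indices = torch.tensor([1027])  # one retrieved key, as q_to_k would return
    idx_x = flat_indices // num_keys     # -> tensor([2]): row in sub-space 0
    idx_y = flat_indices % num_keys      # -> tensor([3]): row in sub-space 1
    # The retrieved key is the pair (keys[2, 0, :], keys[3, 1, :]), which is
    # exactly the pair of rows update_keys_with_zq overwrites with z_q's halves.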
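
The patch also adds the `global_step` / `total_steps` fields that forward()
uses to gate the database updates, but it does not touch the training script.
A minimal sketch of the wiring these fields appear to expect (`config`,
`train_loader`, `num_epochs`, and the loss/optimizer steps are hypothetical
placeholders, not part of this patch):

    model = MiniMindLM(config)
    model.total_steps = num_epochs * len(train_loader)   # set once, before training
    for epoch in range(num_epochs):
        for batch in train_loader:
            out = model(input_ids=batch["input_ids"])    # forward() reads global_step/total_steps
            # ... compute the loss, backward, and optimizer step as in the existing script ...
            model.global_step += 1                       # must advance every step

With total_steps left at None, `progress` evaluates to 0 and the key/value
updates never fire, so the training script must assign these fields for the
last-10% update phase to run.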