Compare commits

...

1 Commits
master ... cy

Author SHA1 Message Date
cy
47b120be3a 更新了key和value的更新方式 2025-06-03 07:36:34 +00:00
4 changed files with 69 additions and 15 deletions

View File

@ -132,7 +132,7 @@ def decode_dataset(model_path, output_path, device="cuda"):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
parser.add_argument("--model_path", type=str, default="out/pretrain_1024.pth",
parser.add_argument("--model_path", type=str, default="out/pretrain_512.pth",
help="Path to the model checkpoint")
parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt",
help="Path to save the decoded text file")

33
loss.py Normal file
View File

@ -0,0 +1,33 @@
import re
import matplotlib.pyplot as plt
log_file = 'out/train.log'
steps_per_epoch = 58880 # 你需要根据实际日志设置
with open(log_file, 'r', encoding='utf-8') as f:
log_text = f.read()
# 提取 epoch, step, loss
pattern = re.compile(r'Epoch\s+(\d+)/\d+,\s+Step\s+(\d+)/\d+,\s+Loss:\s*([0-9.]+)', re.MULTILINE)
matches = pattern.findall(log_text)
global_steps = []
losses = []
for epoch, step, loss in matches:
epoch = int(epoch)
step = int(step)
global_step = (epoch - 1) * steps_per_epoch + step
global_steps.append(global_step)
losses.append(float(loss))
plt.figure(figsize=(12, 6))
plt.plot(global_steps, losses, label='Loss')
plt.xlabel('Global Step')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('out/loss_curve.png')
plt.show()

View File

@ -590,6 +590,20 @@ class ExtractDB(nn.Module):
v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
self.weight_down_embed[k] = v_reshaped
@torch.no_grad()
def update_keys_with_zq(self, flat_indices, z_q):
"""
flat_indices: [batch]q_to_k输出的检索到的key的全局索引0~knowledge_num-1
z_q: [batch, 2, dim_key]每个样本的两个子空间query
"""
num_keys = self.num_keys
idx_x = flat_indices // num_keys # [batch]
idx_y = flat_indices % num_keys # [batch]
# 对于每个样本把keys的两个子空间分别替换为z_q的对应部分
for i in range(flat_indices.size(0)):
self.keys.data[idx_x[i], 0, :] = z_q[i, 0, :].to(self.keys.dtype)
self.keys.data[idx_y[i], 1, :] = z_q[i, 1, :].to(self.keys.dtype)
class MiniMindLM(PreTrainedModel):
@ -634,13 +648,16 @@ class MiniMindLM(PreTrainedModel):
# Specific layers for q path
self.downsample_q_specific = nn.Sequential(
nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
nn.Conv1d(128*8, self.params.dim, kernel_size=1, padding='same')
)
# 使用实数版本的位置编码,避免复数张量可能导致的段错误
self.register_buffer("pos_cis_real",
precompute_pos_cis_real(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.params = params
self.value_update_schedule = 0.9 # 前%冻结
self.global_step = 0 # 当前步数
self.total_steps = None # 总步数,训练脚本里赋值
def forward(self,
input_ids: Optional[torch.Tensor] = None,
@ -683,27 +700,29 @@ class MiniMindLM(PreTrainedModel):
# 数据库更新逻辑与主计算图分离
with torch.no_grad():
# Compute shared downsampling layer once
shared_features = self.shared_downsample(h_tensor_detached)
# Get features from v path - now we output embedding-dimension vectors
# Get features from v path
z_v_features = self.downsample_v_specific(shared_features)
batch_z, seq_len, dim_z = z_v_features.shape
# Reshape to batch_size * knowledge_length, dim
z_v_flat = z_v_features.reshape(-1, dim_z)
# Direct token prediction - like the main language model head
token_logits = self.database_output(z_v_flat) # [batch_z * seq_len, vocab_size]
# Get token indices directly from logits
token_logits = self.database_output(z_v_flat)
token_indices_flat = torch.argmax(token_logits, dim=-1)
token_indices = token_indices_flat.reshape(batch_z, -1)
# Process query path as before
z_q = self.downsample_q_specific(shared_features)
z_k = self.extract_db.q_to_k(z_q)
# self.extract_db.updata_value(z_k, token_indices)
# Process query path
z_q_input = self.downsample_q_specific(shared_features) # [batch, dim, seq_len]
z_q_input = z_q_input.permute(0, 2, 1) # [batch, seq_len, dim]
z_k = self.extract_db.q_to_k(z_q_input) # [batch]
z_q_pooled = z_q_input.mean(dim=1) # [batch, dim]
z_q_vec = self.extract_db.to_queries(z_q_pooled) # [batch, 2*dim_key]
z_q_vec = z_q_vec.view(z_q_vec.size(0), 2, self.extract_db.dim_key) # [batch, 2, dim_key]
progress = self.global_step / self.total_steps if self.total_steps else 0
if progress >= self.value_update_schedule:
self.extract_db.updata_value(z_k, token_indices)
self.extract_db.update_keys_with_zq(z_k, z_q_vec)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])
@ -778,3 +797,5 @@ class MiniMindLM(PreTrainedModel):
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break

0
run_file/DynamicKV-LLM_Mini_Minimind.sh Normal file → Executable file
View File