Compare commits
1 Commits
Author | SHA1 | Date | |
---|---|---|---|
47b120be3a |
@ -132,7 +132,7 @@ def decode_dataset(model_path, output_path, device="cuda"):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
|
||||
parser.add_argument("--model_path", type=str, default="out/pretrain_1024.pth",
|
||||
parser.add_argument("--model_path", type=str, default="out/pretrain_512.pth",
|
||||
help="Path to the model checkpoint")
|
||||
parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt",
|
||||
help="Path to save the decoded text file")
|
||||
|
33
loss.py
Normal file
33
loss.py
Normal file
@ -0,0 +1,33 @@
|
||||
import re
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
log_file = 'out/train.log'
|
||||
steps_per_epoch = 58880 # 你需要根据实际日志设置
|
||||
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
log_text = f.read()
|
||||
|
||||
# 提取 epoch, step, loss
|
||||
pattern = re.compile(r'Epoch\s+(\d+)/\d+,\s+Step\s+(\d+)/\d+,\s+Loss:\s*([0-9.]+)', re.MULTILINE)
|
||||
matches = pattern.findall(log_text)
|
||||
|
||||
global_steps = []
|
||||
losses = []
|
||||
|
||||
for epoch, step, loss in matches:
|
||||
epoch = int(epoch)
|
||||
step = int(step)
|
||||
global_step = (epoch - 1) * steps_per_epoch + step
|
||||
global_steps.append(global_step)
|
||||
losses.append(float(loss))
|
||||
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.plot(global_steps, losses, label='Loss')
|
||||
plt.xlabel('Global Step')
|
||||
plt.ylabel('Loss')
|
||||
plt.title('Training Loss Curve')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.tight_layout()
|
||||
plt.savefig('out/loss_curve.png')
|
||||
plt.show()
|
@ -590,6 +590,20 @@ class ExtractDB(nn.Module):
|
||||
v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
|
||||
self.weight_down_embed[k] = v_reshaped
|
||||
|
||||
@torch.no_grad()
|
||||
def update_keys_with_zq(self, flat_indices, z_q):
|
||||
"""
|
||||
flat_indices: [batch],q_to_k输出的检索到的key的全局索引(0~knowledge_num-1)
|
||||
z_q: [batch, 2, dim_key],每个样本的两个子空间query
|
||||
"""
|
||||
num_keys = self.num_keys
|
||||
idx_x = flat_indices // num_keys # [batch]
|
||||
idx_y = flat_indices % num_keys # [batch]
|
||||
|
||||
# 对于每个样本,把keys的两个子空间分别替换为z_q的对应部分
|
||||
for i in range(flat_indices.size(0)):
|
||||
self.keys.data[idx_x[i], 0, :] = z_q[i, 0, :].to(self.keys.dtype)
|
||||
self.keys.data[idx_y[i], 1, :] = z_q[i, 1, :].to(self.keys.dtype)
|
||||
|
||||
|
||||
class MiniMindLM(PreTrainedModel):
|
||||
@ -634,13 +648,16 @@ class MiniMindLM(PreTrainedModel):
|
||||
|
||||
# Specific layers for q path
|
||||
self.downsample_q_specific = nn.Sequential(
|
||||
nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
|
||||
nn.Conv1d(128*8, self.params.dim, kernel_size=1, padding='same')
|
||||
)
|
||||
# 使用实数版本的位置编码,避免复数张量可能导致的段错误
|
||||
self.register_buffer("pos_cis_real",
|
||||
precompute_pos_cis_real(dim=params.dim // params.n_heads, theta=params.rope_theta),
|
||||
persistent=False)
|
||||
self.params = params
|
||||
self.value_update_schedule = 0.9 # 前%冻结
|
||||
self.global_step = 0 # 当前步数
|
||||
self.total_steps = None # 总步数,训练脚本里赋值
|
||||
|
||||
def forward(self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
@ -683,27 +700,29 @@ class MiniMindLM(PreTrainedModel):
|
||||
|
||||
# 数据库更新逻辑与主计算图分离
|
||||
with torch.no_grad():
|
||||
|
||||
# Compute shared downsampling layer once
|
||||
shared_features = self.shared_downsample(h_tensor_detached)
|
||||
|
||||
# Get features from v path - now we output embedding-dimension vectors
|
||||
# Get features from v path
|
||||
z_v_features = self.downsample_v_specific(shared_features)
|
||||
batch_z, seq_len, dim_z = z_v_features.shape
|
||||
|
||||
# Reshape to batch_size * knowledge_length, dim
|
||||
z_v_flat = z_v_features.reshape(-1, dim_z)
|
||||
|
||||
# Direct token prediction - like the main language model head
|
||||
token_logits = self.database_output(z_v_flat) # [batch_z * seq_len, vocab_size]
|
||||
# Get token indices directly from logits
|
||||
token_logits = self.database_output(z_v_flat)
|
||||
token_indices_flat = torch.argmax(token_logits, dim=-1)
|
||||
token_indices = token_indices_flat.reshape(batch_z, -1)
|
||||
|
||||
# Process query path as before
|
||||
z_q = self.downsample_q_specific(shared_features)
|
||||
z_k = self.extract_db.q_to_k(z_q)
|
||||
# self.extract_db.updata_value(z_k, token_indices)
|
||||
# Process query path
|
||||
z_q_input = self.downsample_q_specific(shared_features) # [batch, dim, seq_len]
|
||||
z_q_input = z_q_input.permute(0, 2, 1) # [batch, seq_len, dim]
|
||||
z_k = self.extract_db.q_to_k(z_q_input) # [batch]
|
||||
z_q_pooled = z_q_input.mean(dim=1) # [batch, dim]
|
||||
z_q_vec = self.extract_db.to_queries(z_q_pooled) # [batch, 2*dim_key]
|
||||
z_q_vec = z_q_vec.view(z_q_vec.size(0), 2, self.extract_db.dim_key) # [batch, 2, dim_key]
|
||||
|
||||
progress = self.global_step / self.total_steps if self.total_steps else 0
|
||||
if progress >= self.value_update_schedule:
|
||||
self.extract_db.updata_value(z_k, token_indices)
|
||||
self.extract_db.update_keys_with_zq(z_k, z_q_vec)
|
||||
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.output(self.norm(h)[:, slice_indices, :])
|
||||
@ -778,3 +797,5 @@ class MiniMindLM(PreTrainedModel):
|
||||
yield input_ids[:, start:]
|
||||
if input_ids_next.item() == eos_token_id:
|
||||
break
|
||||
|
||||
|
||||
|
0
run_file/DynamicKV-LLM_Mini_Minimind.sh
Normal file → Executable file
0
run_file/DynamicKV-LLM_Mini_Minimind.sh
Normal file → Executable file
Loading…
x
Reference in New Issue
Block a user