Compare commits
1 commit

Author | SHA1 | Date
---|---|---
 | 47b120be3a | 
@@ -132,7 +132,7 @@ def decode_dataset(model_path, output_path, device="cuda"):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
-    parser.add_argument("--model_path", type=str, default="out/pretrain_1024.pth",
+    parser.add_argument("--model_path", type=str, default="out/pretrain_512.pth",
                         help="Path to the model checkpoint")
     parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt",
                         help="Path to save the decoded text file")
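
With the new default the decoder loads the 512-dim checkpoint instead of the 1024-dim one; other checkpoints can still be passed explicitly. A hypothetical invocation (the script's filename is not shown in this diff, `decode.py` is assumed):

    python decode.py --model_path out/pretrain_512.pth --output_path out/knowledge_db.txt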

loss.py  (new file, +33 lines)
@@ -0,0 +1,33 @@
+import re
+import matplotlib.pyplot as plt
+
+log_file = 'out/train.log'
+steps_per_epoch = 58880  # set this to match your actual training log
+
+with open(log_file, 'r', encoding='utf-8') as f:
+    log_text = f.read()
+
+# extract epoch, step, loss
+pattern = re.compile(r'Epoch\s+(\d+)/\d+,\s+Step\s+(\d+)/\d+,\s+Loss:\s*([0-9.]+)', re.MULTILINE)
+matches = pattern.findall(log_text)
+
+global_steps = []
+losses = []
+
+for epoch, step, loss in matches:
+    epoch = int(epoch)
+    step = int(step)
+    global_step = (epoch - 1) * steps_per_epoch + step
+    global_steps.append(global_step)
+    losses.append(float(loss))
+
+plt.figure(figsize=(12, 6))
+plt.plot(global_steps, losses, label='Loss')
+plt.xlabel('Global Step')
+plt.ylabel('Loss')
+plt.title('Training Loss Curve')
+plt.legend()
+plt.grid(True)
+plt.tight_layout()
+plt.savefig('out/loss_curve.png')
+plt.show()
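
For reference, the regex in loss.py expects log lines shaped like the following (a hypothetical sample; the exact wording comes from whatever the training script prints):

    Epoch 1/3, Step 100/58880, Loss: 2.345

Each match is mapped to a global step via (epoch - 1) * steps_per_epoch + step, so the curve runs continuously across epoch boundaries.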

@@ -550,7 +550,7 @@ class ExtractDB(nn.Module):
 
         # collapse sequence dimension by averaging
         x_flat = x.mean(dim=1)  # [batch_size, dim]
 
         queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
         queries = queries.reshape(self.batch_size, 2, self.dim_key)  # [batch_size, 2, dim_key]
         queries = queries.permute(1, 0, 2)  # [2, batch_size, dim_key]

@@ -590,6 +590,20 @@ class ExtractDB(nn.Module):
         v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
         self.weight_down_embed[k] = v_reshaped
 
+    @torch.no_grad()
+    def update_keys_with_zq(self, flat_indices, z_q):
+        """
+        flat_indices: [batch], global indices (0 ~ knowledge_num-1) of the keys retrieved by q_to_k
+        z_q: [batch, 2, dim_key], the two sub-space queries for each sample
+        """
+        num_keys = self.num_keys
+        idx_x = flat_indices // num_keys  # [batch]
+        idx_y = flat_indices % num_keys   # [batch]
+
+        # For each sample, replace the two key sub-spaces with the corresponding parts of z_q
+        for i in range(flat_indices.size(0)):
+            self.keys.data[idx_x[i], 0, :] = z_q[i, 0, :].to(self.keys.dtype)
+            self.keys.data[idx_y[i], 1, :] = z_q[i, 1, :].to(self.keys.dtype)
 
 
 class MiniMindLM(PreTrainedModel):
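
The per-sample loop in update_keys_with_zq could also be written with advanced indexing, a minimal sketch (note: when a batch contains duplicate indices, the loop guarantees deterministic last-write-wins, while scattered assignment makes no ordering promise):

    # vectorized equivalent: one scatter per key sub-space
    self.keys.data[idx_x, 0, :] = z_q[:, 0, :].to(self.keys.dtype)
    self.keys.data[idx_y, 1, :] = z_q[:, 1, :].to(self.keys.dtype)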

@@ -634,13 +648,16 @@ class MiniMindLM(PreTrainedModel):
 
         # Specific layers for q path
         self.downsample_q_specific = nn.Sequential(
-            nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
+            nn.Conv1d(128*8, self.params.dim, kernel_size=1, padding='same')
         )
         # Use the real-valued positional encoding to avoid the segfaults complex tensors can cause
         self.register_buffer("pos_cis_real",
                              precompute_pos_cis_real(dim=params.dim // params.n_heads, theta=params.rope_theta),
                              persistent=False)
         self.params = params
+        self.value_update_schedule = 0.9  # frozen for the first 90% of training
+        self.global_step = 0  # current step
+        self.total_steps = None  # total steps, assigned by the training script
 
     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
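
Since total_steps is assigned by the training script, the wiring is expected to look roughly like this minimal sketch (epochs and train_loader are placeholder names, not identifiers from this repository):

    model.total_steps = epochs * len(train_loader)
    for epoch in range(epochs):
        for batch in train_loader:
            outputs = model(batch['input_ids'])  # forward pass evaluates the update gate
            model.global_step += 1               # advance once per optimizer step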

@@ -683,27 +700,29 @@ class MiniMindLM(PreTrainedModel):
 
             # Database update logic is kept separate from the main computation graph
             with torch.no_grad():
 
                 # Compute shared downsampling layer once
                 shared_features = self.shared_downsample(h_tensor_detached)
 
-                # Get features from v path - now we output embedding-dimension vectors
+                # Get features from v path
                 z_v_features = self.downsample_v_specific(shared_features)
                 batch_z, seq_len, dim_z = z_v_features.shape
 
-                # Reshape to batch_size * knowledge_length, dim
                 z_v_flat = z_v_features.reshape(-1, dim_z)
-                # Direct token prediction - like the main language model head
-                token_logits = self.database_output(z_v_flat)  # [batch_z * seq_len, vocab_size]
-                # Get token indices directly from logits
+                token_logits = self.database_output(z_v_flat)
                 token_indices_flat = torch.argmax(token_logits, dim=-1)
                 token_indices = token_indices_flat.reshape(batch_z, -1)
 
-                # Process query path as before
-                z_q = self.downsample_q_specific(shared_features)
-                z_k = self.extract_db.q_to_k(z_q)
-                # self.extract_db.updata_value(z_k, token_indices)
+                # Process query path
+                z_q_input = self.downsample_q_specific(shared_features)  # [batch, dim, seq_len]
+                z_q_input = z_q_input.permute(0, 2, 1)  # [batch, seq_len, dim]
+                z_k = self.extract_db.q_to_k(z_q_input)  # [batch]
+                z_q_pooled = z_q_input.mean(dim=1)  # [batch, dim]
+                z_q_vec = self.extract_db.to_queries(z_q_pooled)  # [batch, 2*dim_key]
+                z_q_vec = z_q_vec.view(z_q_vec.size(0), 2, self.extract_db.dim_key)  # [batch, 2, dim_key]
 
+                progress = self.global_step / self.total_steps if self.total_steps else 0
+                if progress >= self.value_update_schedule:
+                    self.extract_db.updata_value(z_k, token_indices)
+                    self.extract_db.update_keys_with_zq(z_k, z_q_vec)
 
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.output(self.norm(h)[:, slice_indices, :])
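
With value_update_schedule = 0.9, the gate above rewrites the database values and keys only during the final 10% of training: for example, with total_steps = 100000 the first updata_value / update_keys_with_zq call fires at global_step = 90000. Note also that if the training script never assigns total_steps, progress stays 0 and the updates never run.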

@@ -778,3 +797,5 @@ class MiniMindLM(PreTrainedModel):
             yield input_ids[:, start:]
             if input_ids_next.item() == eos_token_id:
                 break
+
+

run_file/DynamicKV-LLM_Mini_Minimind.sh  (0 changed lines; Normal file → Executable file)