2025-05-22 11:32:15 +08:00
|
|
|
|
import os
|
|
|
|
|
import argparse
|
|
|
|
|
import torch
|
|
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
from model.model import MiniMindLM, ExtractDB
|
|
|
|
|
from model.LMConfig import LMConfig
|
|
|
|
|
|
|
|
|
|
def _report_and_filter_state_dict(model_state, state_dict):
    """Print checkpoint/model differences and return the loadable checkpoint entries.

    Args:
        model_state: Mapping of name -> tensor built from the model's named
            parameters and named buffers.
        state_dict: Mapping of name -> tensor loaded from the checkpoint.

    Returns:
        A dict holding only the checkpoint entries whose name exists in
        ``model_state`` with an identical shape. ``output.weight`` is
        additionally mapped onto ``tok_embeddings.weight`` when the checkpoint
        has the former, the model lacks it, and the shapes agree.
    """
    # Same name on both sides, but the tensors disagree on shape.
    shape_mismatch = {
        name: (param.shape, state_dict[name].shape)
        for name, param in model_state.items()
        if name in state_dict and param.shape != state_dict[name].shape
    }

    # Names present on one side only.
    model_only = model_state.keys() - state_dict.keys()
    state_dict_only = state_dict.keys() - model_state.keys()

    # Only entries that match by name AND shape can be loaded safely.
    filtered_state_dict = {
        name: param
        for name, param in state_dict.items()
        if name in model_state and param.shape == model_state[name].shape
    }

    if shape_mismatch:
        print(f"Parameters with shape mismatches: {len(shape_mismatch)}")
        for name, (model_shape, state_shape) in shape_mismatch.items():
            print(f" {name}: model={model_shape}, checkpoint={state_shape}")

    if model_only:
        print(f"Parameters in model but not in checkpoint: {len(model_only)}")
        for name in sorted(model_only):
            print(f" {name}: {model_state[name].shape}")

            # Special case: pos_cis_real is only reported, nothing is loaded for it.
            if name == "pos_cis_real":
                print("Detected pos_cis_real parameter. This is a position encoding that will be initialized automatically.")

    if state_dict_only:
        print(f"Parameters in checkpoint but not in model: {len(state_dict_only)}")
        for name in sorted(state_dict_only):
            print(f" {name}: {state_dict[name].shape}")

            # If the checkpoint has output.weight but the model does not,
            # try to load it into tok_embeddings.weight instead.
            if name == "output.weight" and "tok_embeddings.weight" in model_state:
                print("Found output.weight in checkpoint but not in model. Will try to map it to tok_embeddings.weight")
                if model_state["tok_embeddings.weight"].shape == state_dict["output.weight"].shape:
                    filtered_state_dict["tok_embeddings.weight"] = state_dict["output.weight"]

    return filtered_state_dict


def _write_decoded_knowledge(weight_down_embed, tokenizer, output_path):
    """Decode every row of ``weight_down_embed`` and write one line per row.

    Each row is converted to a Python list, passed to ``tokenizer.decode`` and
    written as ``Knowledge_<i>: <text>``. A row that fails to decode is logged
    and written as ``Knowledge_<i>: [ERROR DECODING]`` so that one bad entry
    does not abort the whole dump.
    """
    # Unpacking into two names keeps the original requirement that the buffer is 2-D.
    knowledge_num, _ = weight_down_embed.shape

    # Move the whole buffer to host memory once instead of once per row.
    knowledge_cpu = weight_down_embed.detach().cpu()

    with open(output_path, 'w', encoding='utf-8') as f:
        for i, row in enumerate(knowledge_cpu):
            try:
                # Token IDs for this knowledge entry, decoded to text.
                token_ids = row.tolist()
                text = tokenizer.decode(token_ids, skip_special_tokens=True)
                f.write(f"Knowledge_{i}: {text}\n")

                # Print progress periodically.
                if (i + 1) % 100 == 0:
                    print(f"Decoded {i + 1}/{knowledge_num} knowledge entries")
            except Exception as e:
                # Deliberate best-effort: record the failure and keep going.
                print(f"Error decoding knowledge entry {i}: {e}")
                f.write(f"Knowledge_{i}: [ERROR DECODING]\n")


def decode_dataset(model_path, output_path, device="cuda"):
    """
    Decode the weight_down_embed buffer in the model to readable text

    Builds a MiniMindLM with a fixed configuration, loads the checkpoint
    entries that match the model by name and shape, then decodes each row of
    ``model.extract_db.weight_down_embed`` with the tokenizer stored in
    ``./model/minimind_tokenizer`` and writes the result to ``output_path``.

    Args:
        model_path: Path to the model checkpoint
        output_path: Path to save the decoded text
        device: Device to load the model on

    Returns:
        None. If the model has no ``extract_db`` attribute, or the
        ``weight_down_embed`` buffer cannot be accessed, an error is printed
        and the function returns without writing anything.
    """
    tokenizer_dir = './model/minimind_tokenizer'
    print(f"Loading tokenizer from {tokenizer_dir}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

    print("Setting up model configuration")
    # Create model configuration matching the training parameters
    lm_config = LMConfig(
        dim=1024,
        n_layers=32,
        max_seq_len=1024,
        use_flash_attn=True,
        knowledge_num=16384,  # From the script parameters
        knowledge_length=64,  # From the script parameters
    )

    print("Initializing model")
    model = MiniMindLM(lm_config).to(device)

    print(f"Loading model weights from {model_path}")
    # NOTE(review): torch.load unpickles the file; only use checkpoints from a
    # trusted source, or confirm that weights_only=True works for these files.
    state_dict = torch.load(model_path, map_location=device)

    # Get model parameters and buffers under one name -> tensor mapping.
    model_state = dict(model.named_parameters())
    model_state.update(dict(model.named_buffers()))

    filtered_state_dict = _report_and_filter_state_dict(model_state, state_dict)

    # Load only the compatible parameters
    print(f"Loading {len(filtered_state_dict)}/{len(state_dict)} parameters")
    model.load_state_dict(filtered_state_dict, strict=False)

    # Check that extract_db and weight_down_embed exist
    if not hasattr(model, "extract_db"):
        print("ERROR: Model does not have extract_db attribute. This is required for decoding.")
        return

    print("Accessing weight_down_embed buffer")
    # Get the weight_down_embed buffer from the model
    try:
        weight_down_embed = model.extract_db.weight_down_embed
        print("Successfully accessed weight_down_embed buffer")
    except AttributeError as e:
        print(f"ERROR: Failed to access weight_down_embed buffer: {e}")
        print(f"Model structure: {model.__class__.__name__}")
        print(f"ExtractDB attributes: {dir(model.extract_db)}")
        return

    print(f"Shape of weight_down_embed: {weight_down_embed.shape}")
    print(f"Data type of weight_down_embed: {weight_down_embed.dtype}")

    # Create output directory if it doesn't exist. dirname() is empty for a
    # bare file name, and os.makedirs("") raises FileNotFoundError.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Decoding knowledge and writing to {output_path}")
    _write_decoded_knowledge(weight_down_embed, tokenizer, output_path)

    print(f"Decoding completed. Output saved to {output_path}")
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Prefer the GPU when torch can see one; otherwise run on the CPU.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # (flag, default value, help text) for each string-valued CLI option.
    cli_options = (
        ("--model_path", "out/pretrain_512.pth", "Path to the model checkpoint"),
        ("--output_path", "out/knowledge_db.txt", "Path to save the decoded text file"),
        ("--device", default_device, "Device to load the model on"),
    )

    parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
    for flag, default_value, help_text in cli_options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)

    args = parser.parse_args()

    # Run the decoder with the parsed command-line values.
    decode_dataset(args.model_path, args.output_path, args.device)
|