# Source: Minimind/dataset_decoder.py (Python; 145 lines, 6.2 KiB as originally published)
# Last modified: 2025-05-22 11:32:15 +08:00

import os
import argparse
import torch
from transformers import AutoTokenizer
from model.model import MiniMindLM, ExtractDB
from model.LMConfig import LMConfig
def decode_dataset(model_path, output_path, device="cuda"):
    """
    Decode the weight_down_embed buffer in the model to readable text.

    Builds a MiniMindLM with a fixed configuration, loads every checkpoint
    tensor whose name and shape match the model, then decodes each row of
    ``model.extract_db.weight_down_embed`` with the tokenizer and writes one
    ``Knowledge_<i>: <text>`` line per row to ``output_path``.

    Args:
        model_path: Path to the model checkpoint
        output_path: Path to save the decoded text
        device: Device to load the model on

    Returns:
        None. Returns early, after printing an error, when the model has no
        ``extract_db`` attribute or the buffer cannot be accessed.
    """
    tokenizer_path = './model/minimind_tokenizer'
    print(f"Loading tokenizer from {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    print("Setting up model configuration")
    # Create model configuration matching the training parameters
    lm_config = LMConfig(
        dim=1024,
        n_layers=32,
        max_seq_len=1024,
        use_flash_attn=True,
        knowledge_num=16384,  # From the script parameters
        knowledge_length=64,  # From the script parameters
    )

    print("Initializing model")
    model = MiniMindLM(lm_config).to(device)

    print(f"Loading model weights from {model_path}")
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point this at checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location=device)

    # Every tensor the model owns: parameters plus buffers, keyed by name.
    model_state = dict(model.named_parameters())
    model_state.update(dict(model.named_buffers()))

    # Names present on both sides whose shapes disagree.
    shape_mismatch = {
        name: (param.shape, state_dict[name].shape)
        for name, param in model_state.items()
        if name in state_dict and param.shape != state_dict[name].shape
    }

    # Names present on only one side.
    model_only = set(model_state) - set(state_dict)
    state_dict_only = set(state_dict) - set(model_state)

    # Keep only checkpoint tensors that match the model by name AND shape,
    # so load_state_dict below cannot fail on a size mismatch.
    filtered_state_dict = {
        name: param
        for name, param in state_dict.items()
        if name in model_state and param.shape == model_state[name].shape
    }

    # Report the differences between model and checkpoint.
    if shape_mismatch:
        print(f"Parameters with shape mismatches: {len(shape_mismatch)}")
        for name, (model_shape, state_shape) in shape_mismatch.items():
            print(f"  {name}: model={model_shape}, checkpoint={state_shape}")

    if model_only:
        print(f"Parameters in model but not in checkpoint: {len(model_only)}")
        for name in sorted(model_only):
            print(f"  {name}: {model_state[name].shape}")
            # Special handling for the pos_cis_real parameter
            if name == "pos_cis_real":
                print("Detected pos_cis_real parameter. This is a position encoding that will be initialized automatically.")

    if state_dict_only:
        print(f"Parameters in checkpoint but not in model: {len(state_dict_only)}")
        for name in sorted(state_dict_only):
            print(f"  {name}: {state_dict[name].shape}")
            # If the checkpoint has output.weight but the model does not, try
            # to load it into tok_embeddings instead (presumably the two are
            # weight-tied in the model -- confirm against model/model.py).
            if name == "output.weight" and "tok_embeddings.weight" in model_state:
                print("Found output.weight in checkpoint but not in model. Will try to map it to tok_embeddings.weight")
                if model_state["tok_embeddings.weight"].shape == state_dict["output.weight"].shape:
                    filtered_state_dict["tok_embeddings.weight"] = state_dict["output.weight"]

    # Load only the compatible parameters; strict=False tolerates the names
    # that were filtered out above.
    print(f"Loading {len(filtered_state_dict)}/{len(state_dict)} parameters")
    model.load_state_dict(filtered_state_dict, strict=False)

    # Check that extract_db (and its weight_down_embed) exist
    if not hasattr(model, "extract_db"):
        print("ERROR: Model does not have extract_db attribute. This is required for decoding.")
        return

    print("Accessing weight_down_embed buffer")
    # Get the weight_down_embed buffer from the model
    try:
        weight_down_embed = model.extract_db.weight_down_embed
        print("Successfully accessed weight_down_embed buffer")
    except Exception as e:
        print(f"ERROR: Failed to access weight_down_embed buffer: {e}")
        print(f"Model structure: {model.__class__.__name__}")
        print(f"ExtractDB attributes: {dir(model.extract_db)}")
        return

    print(f"Shape of weight_down_embed: {weight_down_embed.shape}")
    print(f"Data type of weight_down_embed: {weight_down_embed.dtype}")

    # Create the output directory if it doesn't exist. dirname() is "" when
    # output_path is a bare file name, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Decoding knowledge and writing to {output_path}")
    # The unpack also asserts the buffer is 2-D: one row per knowledge entry.
    knowledge_num, _ = weight_down_embed.shape
    # Move the whole buffer to host memory once instead of once per row.
    rows = weight_down_embed.cpu()

    with open(output_path, 'w', encoding='utf-8') as f:
        for i, row in enumerate(rows):
            # Best-effort per entry: one bad row is recorded in the output
            # and decoding continues with the next one.
            try:
                # Get token IDs for this knowledge entry
                token_ids = row.tolist()
                # Decode tokens to text
                text = tokenizer.decode(token_ids, skip_special_tokens=True)
                # Write to file
                f.write(f"Knowledge_{i}: {text}\n")
                # Print progress periodically
                if (i + 1) % 100 == 0:
                    print(f"Decoded {i + 1}/{knowledge_num} knowledge entries")
            except Exception as e:
                print(f"Error decoding knowledge entry {i}: {e}")
                f.write(f"Knowledge_{i}: [ERROR DECODING]\n")

    print(f"Decoding completed. Output saved to {output_path}")
if __name__ == "__main__":
    # Command-line entry point: parse the checkpoint path, output path and
    # device, then decode the model's knowledge buffer to a text file.
    parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
    parser.add_argument("--model_path", type=str, default="out/pretrain_512.pth",
                        help="Path to the model checkpoint")
    parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt",
                        help="Path to save the decoded text file")
    # Default to the GPU when one is available, otherwise fall back to CPU.
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to load the model on")
    args = parser.parse_args()

    decode_dataset(args.model_path, args.output_path, args.device)