# Minimind/dataset_decoder.py
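#
# Decode the ExtractDB knowledge base of a MiniMind checkpoint back into
# human-readable text. Each row of the weight_down_embed buffer stores the
# token IDs of one knowledge entry, which the tokenizer decodes directly.
#
# Example invocation (defaults shown; see the argparse flags at the bottom):
#   python dataset_decoder.py --model_path out/pretrain_512.pth \
#       --output_path out/knowledge_db.txt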

import os
import argparse
import torch
from transformers import AutoTokenizer
from model.model import MiniMindLM, ExtractDB
from model.LMConfig import LMConfig


def decode_dataset(model_path, output_path, device="cuda"):
"""
Decode the weight_down_embed buffer in the model to readable text
Args:
model_path: Path to the model checkpoint
output_path: Path to save the decoded text
device: Device to load the model on
"""
print(f"Loading tokenizer from ./model/minimind_tokenizer")
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
print(f"Setting up model configuration")
    # Create a model configuration matching the training parameters
    lm_config = LMConfig(
        dim=1024,
        n_layers=32,
        max_seq_len=1024,
        use_flash_attn=True,
        knowledge_num=16384,   # From the training script parameters
        knowledge_length=64    # From the training script parameters
    )
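    # NOTE: these hyperparameters must match the checkpoint being decoded;
    # any weights with mismatched shapes are detected and skipped below.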
print(f"Initializing model")
model = MiniMindLM(lm_config).to(device)
print(f"Loading model weights from {model_path}")
state_dict = torch.load(model_path, map_location=device)
# Get model parameters
model_state = dict(model.named_parameters())
model_state.update(dict(model.named_buffers()))
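    # named_buffers() is included because weight_down_embed is registered as
    # a buffer rather than a trainable parameter.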
    # Find parameters with matching names but different shapes
    shape_mismatch = {}
    for name, param in model_state.items():
        if name in state_dict and param.shape != state_dict[name].shape:
            shape_mismatch[name] = (param.shape, state_dict[name].shape)

    # Find parameters in the model but not in the state_dict, and vice versa
    model_only = set(model_state.keys()) - set(state_dict.keys())
    state_dict_only = set(state_dict.keys()) - set(model_state.keys())

    # Create a filtered state_dict with only the compatible parameters
    filtered_state_dict = {}
    for name, param in state_dict.items():
        if name in model_state and param.shape == model_state[name].shape:
            filtered_state_dict[name] = param
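    # Anything filtered out here is reported below and, because the load uses
    # strict=False, is simply left at its freshly initialized value.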
    # Print parameter differences
    if shape_mismatch:
        print(f"Parameters with shape mismatches: {len(shape_mismatch)}")
        for name, (model_shape, state_shape) in shape_mismatch.items():
            print(f"  {name}: model={model_shape}, checkpoint={state_shape}")
    if model_only:
        print(f"Parameters in model but not in checkpoint: {len(model_only)}")
        for name in sorted(model_only):
            print(f"  {name}: {model_state[name].shape}")
            # Special-case the pos_cis_real parameter
            if name == "pos_cis_real":
                print("Detected pos_cis_real parameter. This is a positional encoding that will be initialized automatically.")
    if state_dict_only:
        print(f"Parameters in checkpoint but not in model: {len(state_dict_only)}")
        for name in sorted(state_dict_only):
            print(f"  {name}: {state_dict[name].shape}")
            # If the checkpoint has output.weight but the model does not,
            # try to map it onto tok_embeddings.weight
            if name == "output.weight" and "tok_embeddings.weight" in model_state:
                print("Found output.weight in checkpoint but not in model. Will try to map it to tok_embeddings.weight")
                if model_state["tok_embeddings.weight"].shape == state_dict["output.weight"].shape:
                    filtered_state_dict["tok_embeddings.weight"] = state_dict["output.weight"]
    # Load only the compatible parameters
    print(f"Loading {len(filtered_state_dict)}/{len(state_dict)} parameters")
    model.load_state_dict(filtered_state_dict, strict=False)

    # Check that extract_db and its weight_down_embed buffer exist
    if not hasattr(model, "extract_db"):
        print("ERROR: Model does not have an extract_db attribute. This is required for decoding.")
        return
print("Accessing weight_down_embed buffer")
# Get the weight_down_embed buffer from the model
try:
weight_down_embed = model.extract_db.weight_down_embed
print(f"Successfully accessed weight_down_embed buffer")
except Exception as e:
print(f"ERROR: Failed to access weight_down_embed buffer: {e}")
print(f"Model structure: {model.__class__.__name__}")
print(f"ExtractDB attributes: {dir(model.extract_db)}")
return
print(f"Shape of weight_down_embed: {weight_down_embed.shape}")
print(f"Data type of weight_down_embed: {weight_down_embed.dtype}")
    # Create the output directory if it doesn't exist (guard against a bare
    # filename, for which os.path.dirname() returns an empty string)
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
print(f"Decoding knowledge and writing to {output_path}")
knowledge_num, knowledge_length = weight_down_embed.shape
with open(output_path, 'w', encoding='utf-8') as f:
for i in range(knowledge_num):
try:
# Get token IDs for this knowledge entry
token_ids = weight_down_embed[i].cpu().tolist()
# Decode tokens to text
text = tokenizer.decode(token_ids, skip_special_tokens=True)
# Write to file
f.write(f"Knowledge_{i}: {text}\n")
# Print progress periodically
if (i + 1) % 100 == 0:
print(f"Decoded {i + 1}/{knowledge_num} knowledge entries")
except Exception as e:
print(f"Error decoding knowledge entry {i}: {e}")
f.write(f"Knowledge_{i}: [ERROR DECODING]\n")
print(f"Decoding completed. Output saved to {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
parser.add_argument("--model_path", type=str, default="out/pretrain_512.pth",
help="Path to the model checkpoint")
parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt",
help="Path to save the decoded text file")
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to load the model on")
args = parser.parse_args()
decode_dataset(args.model_path, args.output_path, args.device)