import os
import argparse

import torch
from transformers import AutoTokenizer

from model.model import MiniMindLM, ExtractDB
from model.LMConfig import LMConfig


def decode_dataset(model_path, output_path, device="cuda"):
    """
    Decode the weight_down_embed buffer in the model to readable text

    Args:
        model_path: Path to the model checkpoint
        output_path: Path to save the decoded text
        device: Device to load the model on
    """
    print("Loading tokenizer from ./model/minimind_tokenizer")
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    print("Setting up model configuration")
    # Create model configuration matching the training parameters
    lm_config = LMConfig(
        dim=1024,
        n_layers=32,
        max_seq_len=1024,
        use_flash_attn=True,
        knowledge_num=16384,  # From the script parameters
        knowledge_length=64   # From the script parameters
    )

    print("Initializing model")
    model = MiniMindLM(lm_config).to(device)

    print(f"Loading model weights from {model_path}")
    state_dict = torch.load(model_path, map_location=device)

    # Get model parameters and buffers
    model_state = dict(model.named_parameters())
    model_state.update(dict(model.named_buffers()))

    # Find parameters with matching names but different shapes
    shape_mismatch = {}
    for name, param in model_state.items():
        if name in state_dict and param.shape != state_dict[name].shape:
            shape_mismatch[name] = (param.shape, state_dict[name].shape)

    # Find parameters in model but not in state_dict and vice versa
    model_only = set(model_state.keys()) - set(state_dict.keys())
    state_dict_only = set(state_dict.keys()) - set(model_state.keys())

    # Create filtered state_dict with only compatible parameters
    filtered_state_dict = {}
    for name, param in state_dict.items():
        if name in model_state and param.shape == model_state[name].shape:
            filtered_state_dict[name] = param

    # Print parameter differences
    if shape_mismatch:
        print(f"Parameters with shape mismatches: {len(shape_mismatch)}")
        for name, (model_shape, state_shape) in shape_mismatch.items():
            print(f"  {name}: model={model_shape}, checkpoint={state_shape}")
    if model_only:
        print(f"Parameters in model but not in checkpoint: {len(model_only)}")
        for name in sorted(model_only):
            print(f"  {name}: {model_state[name].shape}")
            # Special handling for the pos_cis_real parameter
            if name == "pos_cis_real":
                print("Detected pos_cis_real parameter. This is a position encoding that will be initialized automatically.")
    if state_dict_only:
        print(f"Parameters in checkpoint but not in model: {len(state_dict_only)}")
        for name in sorted(state_dict_only):
            print(f"  {name}: {state_dict[name].shape}")
            # If the checkpoint has output.weight but the model does not, try mapping it to tok_embeddings
            if name == "output.weight" and "tok_embeddings.weight" in model_state:
                print("Found output.weight in checkpoint but not in model. Will try to map it to tok_embeddings.weight")
                if model_state["tok_embeddings.weight"].shape == state_dict["output.weight"].shape:
                    filtered_state_dict["tok_embeddings.weight"] = state_dict["output.weight"]

    # Load only the compatible parameters
    print(f"Loading {len(filtered_state_dict)}/{len(state_dict)} parameters")
    model.load_state_dict(filtered_state_dict, strict=False)

    # Check that the extract_db module (holding weight_down_embed) exists
    if not hasattr(model, "extract_db"):
        print("ERROR: Model does not have extract_db attribute. This is required for decoding.")
        return

    print("Accessing weight_down_embed buffer")
    # Get the weight_down_embed buffer from the model
    try:
        weight_down_embed = model.extract_db.weight_down_embed
        print("Successfully accessed weight_down_embed buffer")
    except Exception as e:
        print(f"ERROR: Failed to access weight_down_embed buffer: {e}")
        print(f"Model structure: {model.__class__.__name__}")
        print(f"ExtractDB attributes: {dir(model.extract_db)}")
        return

    print(f"Shape of weight_down_embed: {weight_down_embed.shape}")
    print(f"Data type of weight_down_embed: {weight_down_embed.dtype}")

    # Create output directory if it doesn't exist
    # (guard against an empty dirname when output_path is a bare filename)
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Decoding knowledge and writing to {output_path}")
    knowledge_num, knowledge_length = weight_down_embed.shape
    with open(output_path, 'w', encoding='utf-8') as f:
        for i in range(knowledge_num):
            try:
                # Get token IDs for this knowledge entry
                token_ids = weight_down_embed[i].cpu().tolist()
                # Decode tokens to text
                text = tokenizer.decode(token_ids, skip_special_tokens=True)
                # Write to file
                f.write(f"Knowledge_{i}: {text}\n")
                # Print progress periodically
                if (i + 1) % 100 == 0:
                    print(f"Decoded {i + 1}/{knowledge_num} knowledge entries")
            except Exception as e:
                print(f"Error decoding knowledge entry {i}: {e}")
                f.write(f"Knowledge_{i}: [ERROR DECODING]\n")

    print(f"Decoding completed. Output saved to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database")
    parser.add_argument("--model_path", type=str, default="out/pretrain_512.pth",
                        help="Path to the model checkpoint")
    parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt",
                        help="Path to save the decoded text file")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to load the model on")

    args = parser.parse_args()
    decode_dataset(args.model_path, args.output_path, args.device)
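
# Example invocation (a sketch: the script filename "decode_dataset.py" and the
# default paths below are assumptions based on the argparse defaults above;
# adjust --model_path / --output_path to match your local checkpoint layout):
#
#   python decode_dataset.py \
#       --model_path out/pretrain_512.pth \
#       --output_path out/knowledge_db.txt \
#       --device cuda
#
# Each line of the output file then holds one decoded knowledge entry, e.g.:
#
#   Knowledge_0: <decoded text of entry 0>
#   Knowledge_1: <decoded text of entry 1>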