diff --git a/dataset_decoder.py b/dataset_decoder.py new file mode 100644 index 0000000..cb5be61 --- /dev/null +++ b/dataset_decoder.py @@ -0,0 +1,144 @@ +import os +import argparse +import torch +from transformers import AutoTokenizer +from model.model import MiniMindLM, ExtractDB +from model.LMConfig import LMConfig + +def decode_dataset(model_path, output_path, device="cuda"): + """ + Decode the weight_down_embed buffer in the model to readable text + + Args: + model_path: Path to the model checkpoint + output_path: Path to save the decoded text + device: Device to load the model on + """ + print(f"Loading tokenizer from ./model/minimind_tokenizer") + tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer') + + print(f"Setting up model configuration") + # Create model configuration matching the training parameters + lm_config = LMConfig( + dim=1024, + n_layers=32, + max_seq_len=1024, + use_flash_attn=True, + knowledge_num=16384, # From the script parameters + knowledge_length=64 # From the script parameters + ) + + print(f"Initializing model") + model = MiniMindLM(lm_config).to(device) + + print(f"Loading model weights from {model_path}") + state_dict = torch.load(model_path, map_location=device) + + # Get model parameters + model_state = dict(model.named_parameters()) + model_state.update(dict(model.named_buffers())) + + # Find parameters with matching names but different shapes + shape_mismatch = {} + for name, param in model_state.items(): + if name in state_dict and param.shape != state_dict[name].shape: + shape_mismatch[name] = (param.shape, state_dict[name].shape) + + # Find parameters in model but not in state_dict and vice versa + model_only = set(model_state.keys()) - set(state_dict.keys()) + state_dict_only = set(state_dict.keys()) - set(model_state.keys()) + + # Create filtered state_dict with only compatible parameters + filtered_state_dict = {} + for name, param in state_dict.items(): + if name in model_state and param.shape == model_state[name].shape: + filtered_state_dict[name] = param + + # Print parameter differences + if shape_mismatch: + print(f"Parameters with shape mismatches: {len(shape_mismatch)}") + for name, (model_shape, state_shape) in shape_mismatch.items(): + print(f" {name}: model={model_shape}, checkpoint={state_shape}") + + if model_only: + print(f"Parameters in model but not in checkpoint: {len(model_only)}") + for name in sorted(model_only): + print(f" {name}: {model_state[name].shape}") + + # 特殊处理pos_cis_real参数 + if name == "pos_cis_real": + print(f"Detected pos_cis_real parameter. This is a position encoding that will be initialized automatically.") + + if state_dict_only: + print(f"Parameters in checkpoint but not in model: {len(state_dict_only)}") + for name in sorted(state_dict_only): + print(f" {name}: {state_dict[name].shape}") + + # 如果checkpoint中有output.weight但模型中没有,尝试加载到tok_embeddings + if name == "output.weight" and "tok_embeddings.weight" in model_state: + print(f"Found output.weight in checkpoint but not in model. Will try to map it to tok_embeddings.weight") + if model_state["tok_embeddings.weight"].shape == state_dict["output.weight"].shape: + filtered_state_dict["tok_embeddings.weight"] = state_dict["output.weight"] + + # Load only the compatible parameters + print(f"Loading {len(filtered_state_dict)}/{len(state_dict)} parameters") + model.load_state_dict(filtered_state_dict, strict=False) + + # 检查extract_db和weight_down_embed是否存在 + if not hasattr(model, "extract_db"): + print("ERROR: Model does not have extract_db attribute. This is required for decoding.") + return + + print("Accessing weight_down_embed buffer") + # Get the weight_down_embed buffer from the model + try: + weight_down_embed = model.extract_db.weight_down_embed + print(f"Successfully accessed weight_down_embed buffer") + except Exception as e: + print(f"ERROR: Failed to access weight_down_embed buffer: {e}") + print(f"Model structure: {model.__class__.__name__}") + print(f"ExtractDB attributes: {dir(model.extract_db)}") + return + + print(f"Shape of weight_down_embed: {weight_down_embed.shape}") + print(f"Data type of weight_down_embed: {weight_down_embed.dtype}") + + # Create output directory if it doesn't exist + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + print(f"Decoding knowledge and writing to {output_path}") + knowledge_num, knowledge_length = weight_down_embed.shape + + with open(output_path, 'w', encoding='utf-8') as f: + for i in range(knowledge_num): + try: + # Get token IDs for this knowledge entry + token_ids = weight_down_embed[i].cpu().tolist() + + # Decode tokens to text + text = tokenizer.decode(token_ids, skip_special_tokens=True) + + # Write to file + f.write(f"Knowledge_{i}: {text}\n") + + # Print progress periodically + if (i + 1) % 100 == 0: + print(f"Decoded {i + 1}/{knowledge_num} knowledge entries") + except Exception as e: + print(f"Error decoding knowledge entry {i}: {e}") + f.write(f"Knowledge_{i}: [ERROR DECODING]\n") + + print(f"Decoding completed. Output saved to {output_path}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Decode MiniMind model's knowledge database") + parser.add_argument("--model_path", type=str, default="out/pretrain_1024.pth", + help="Path to the model checkpoint") + parser.add_argument("--output_path", type=str, default="out/knowledge_db.txt", + help="Path to save the decoded text file") + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to load the model on") + + args = parser.parse_args() + + decode_dataset(args.model_path, args.output_path, args.device)