#!/usr/bin/env python3
"""
调试模型生成过程
"""
import torch
from transformers import AutoTokenizer
from model.model_original import MiniMindLM
from model.LMConfig import LMConfig


def debug_generation():
    # Load the model and tokenizer
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = 'out/experiment_1_4_0/pretrain_512.pth'
    # Model configuration (must match the checkpoint)
    config = LMConfig(
        dim=512,
        n_layers=8,
        n_heads=32,
        vocab_size=6400,
        max_seq_len=512
    )
    # Initialize the model
    model = MiniMindLM(config)
    # Load the weights; strict=False tolerates key mismatches (e.g. tied embeddings)
    state_dict = torch.load(model_path, map_location=device)
    load_result = model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()
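    # Sanity check added for debugging (not in the original script): with
    # strict=False, missing keys are silently left at their random init,
    # which makes generation look "broken" even though loading succeeded.
    # load_state_dict returns the mismatched key names, so report them here.
    if load_result.missing_keys:
        print(f"Warning, missing keys: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Warning, unexpected keys: {load_result.unexpected_keys}")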
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    # Test text
    text = "The quick brown fox"
    input_tokens = tokenizer.encode(text, add_special_tokens=False)
    print(f"Input text: {text}")
    print(f"Input tokens: {input_tokens}")
    print(f"Decoded back: {tokenizer.decode(input_tokens)}")
    # Convert to a tensor of shape (batch=1, seq_len)
    input_ids = torch.tensor([input_tokens], dtype=torch.long).to(device)
    print(f"Input tensor shape: {input_ids.shape}")
    # Manually generate one step
    with torch.no_grad():
        # Forward pass
        outputs = model(input_ids)
        logits = outputs.logits
        print(f"Output logits shape: {logits.shape}")
        # Take the logits at the last position (the next-token distribution)
        next_token_logits = logits[0, -1, :]
        print(f"Next-token logits shape: {next_token_logits.shape}")
        # Apply temperature (1.0 leaves the logits unchanged)
        next_token_logits = next_token_logits / 1.0
        # Convert to a probability distribution
        probs = torch.softmax(next_token_logits, dim=-1)
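        # Optional health check (my addition, not in the original script):
        # the entropy of the next-token distribution. A badly loaded model is
        # near-uniform, so entropy approaches log(vocab_size) = log(6400)
        # ≈ 8.76 nats; a trained model should be noticeably lower.
        entropy = -(probs * probs.clamp_min(1e-12).log()).sum()
        print(f"Next-token entropy: {entropy.item():.3f} nats")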
        # Inspect the top-10 candidate tokens
        top_probs, top_indices = torch.topk(probs, 10)
        print("\nTop 10 candidate tokens:")
        for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
            token_text = tokenizer.decode([idx.item()], skip_special_tokens=True)
            print(f"  {i+1}. Token {idx.item()}: '{token_text}' (prob: {prob.item():.4f})")
        # Greedy decoding: pick the highest-probability token
        next_token = torch.argmax(next_token_logits, dim=-1)
        print(f"\nGreedy choice: token {next_token.item()}")
        print(f"Corresponding text: '{tokenizer.decode([next_token.item()], skip_special_tokens=True)}'")
    # Use the generate method
    print("\nUsing the generate method:")
    with torch.no_grad():
        generated = model.generate(
            input_ids,
            max_new_tokens=5,
            temperature=1.0,
            top_p=0.95,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )
    print(f"Generated sequence shape: {generated[0].shape}")
    print(f"Generated tokens: {generated[0].tolist()}")
    # Extract only the newly generated part
    if len(generated[0]) > len(input_tokens):
        new_tokens = generated[0][len(input_tokens):].tolist()
        print(f"New tokens: {new_tokens}")
        print(f"New text: '{tokenizer.decode(new_tokens, skip_special_tokens=True)}'")
    else:
        print("No new tokens were generated")


if __name__ == "__main__":
    debug_generation()