#!/usr/bin/env python3
"""
Debug the model's generation process.
"""

import torch
from transformers import AutoTokenizer

from model.model_original import MiniMindLM
from model.LMConfig import LMConfig


def debug_generation():
    # Load model and tokenizer
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = 'out/experiment_1_4_0/pretrain_512.pth'

    # Model configuration
    config = LMConfig(
        dim=512,
        n_layers=8,
        n_heads=32,
        vocab_size=6400,
        max_seq_len=512
    )

    # Initialize the model
    model = MiniMindLM(config)

    # Load weights
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    # Test text
    text = "The quick brown fox"
    input_tokens = tokenizer.encode(text, add_special_tokens=False)
    print(f"Input text: {text}")
    print(f"Input tokens: {input_tokens}")
    print(f"Decoded back: {tokenizer.decode(input_tokens)}")

    # Convert to a tensor
    input_ids = torch.tensor([input_tokens], dtype=torch.long).to(device)
    print(f"Input tensor shape: {input_ids.shape}")

    # Manually generate one step
    with torch.no_grad():
        # Forward pass
        outputs = model(input_ids)
        logits = outputs.logits
        print(f"Output logits shape: {logits.shape}")

        # The logits at the last position predict the next token
        next_token_logits = logits[0, -1, :]
        print(f"Next-token logits shape: {next_token_logits.shape}")

        # Apply temperature (1.0 leaves the distribution unchanged)
        next_token_logits = next_token_logits / 1.0

        # Probability distribution over the vocabulary
        probs = torch.softmax(next_token_logits, dim=-1)

        # Show the top-10 candidate tokens
        top_probs, top_indices = torch.topk(probs, 10)
        print("\nTop 10 candidate tokens:")
        for i, (prob, idx) in enumerate(zip(top_probs, top_indices)):
            token_text = tokenizer.decode([idx.item()], skip_special_tokens=True)
            print(f"  {i+1}. Token {idx.item()}: '{token_text}' (prob: {prob.item():.4f})")

        # Greedy decoding
        next_token = torch.argmax(next_token_logits, dim=-1)
        print(f"\nGreedy decoding picks token: {next_token.item()}")
        print(f"Corresponding text: '{tokenizer.decode([next_token.item()], skip_special_tokens=True)}'")

    # Use the generate() method
    print("\nUsing the generate() method:")
    with torch.no_grad():
        generated = model.generate(
            input_ids,
            max_new_tokens=5,
            temperature=1.0,
            top_p=0.95,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )

    print(f"Generated sequence shape: {generated[0].shape}")
    print(f"Generated tokens: {generated[0].tolist()}")

    # Extract the newly generated part
    if len(generated[0]) > len(input_tokens):
        new_tokens = generated[0][len(input_tokens):].tolist()
        print(f"Newly generated tokens: {new_tokens}")
        print(f"Newly generated text: '{tokenizer.decode(new_tokens, skip_special_tokens=True)}'")
    else:
        print("No new tokens were generated")


if __name__ == "__main__":
    debug_generation()
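

# The manual step above uses greedy argmax, while generate() is called with
# top_p=0.95. For comparison, below is a minimal nucleus (top-p) sampling
# sketch over a 1-D logits vector. This is a generic illustration of the
# technique, NOT MiniMindLM.generate's actual implementation, whose internals
# are not shown in this script.
def sample_top_p(logits: torch.Tensor, top_p: float = 0.95, temperature: float = 1.0) -> int:
    """Sample one token id from `logits` (shape [vocab_size]) with nucleus sampling."""
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Zero out tokens whose cumulative probability *before* including them
    # already exceeds top_p; this always keeps at least the top-1 token.
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum()  # renormalize over the nucleus
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_indices[choice].item()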