# Minimind/analyze_position_slicing.py
# Snapshot: 2025-08-01 15:54:21 +08:00 — 193 lines, 7.8 KiB, Python.
# NOTE(review): the lines above/below originally contained web file-viewer
# chrome ("Raw Blame History", an "ambiguous Unicode characters" warning);
# they were scrape residue, not part of the source, and have been turned
# into this comment header so the file parses as Python.
#!/usr/bin/env python3
"""
深入分析位置切片的问题
验证logits_to_keep和位置索引的正确性
"""
import json
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from model.LMConfig import LMConfig
from model.model_original import MiniMindLM
def analyze_position_indexing():
    """
    Analyze whether position indexing and logits slicing are correct.

    Loads a pretrained MiniMind checkpoint, feeds it the first record of the
    evaluation set, and compares three ways of selecting the logits that
    score the target tokens:

    1. manual slice ``logits[0, input_len-1 : input_len+predict_len-1]``
       (in a causal LM, the logits at position i predict token i+1);
    2. the model's ``logits_to_keep`` argument;
    3. the slice used by ``eval_model.py`` (``logits[0, -predict_len:]``).

    All losses and any element-wise differences are printed to stdout.
    """
    print("🔍 分析位置索引和切片逻辑")
    print("="*60)
    # Prefer GPU but fall back to CPU so the script runs on any machine.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = 'out/experiment_1_4_0/pretrain_512.pth'
    # Build the model and load the checkpoint weights.
    config = LMConfig(
        dim=512, n_layers=8, n_heads=32, vocab_size=6400, max_seq_len=512,
        dropout=0.0, norm_eps=1e-5, rope_theta=1e6, use_moe=False
    )
    model = MiniMindLM(config)
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()
    # Load the first JSON-lines record of the evaluation set.
    with open('dataset/stable/eval_data_from_train.json', 'r', encoding='utf-8') as f:
        sample = json.loads(f.readline().strip())
    text = sample['text']
    tokens = tokenizer.encode(text, add_special_tokens=False)
    input_length = 100
    predict_length = 30
    target_tokens = tokens[input_length:input_length + predict_length]
    print(f"输入长度: {input_length}")
    print(f"预测长度: {predict_length}")
    print(f"总序列长度: {input_length + predict_length}")
    print(f"输入token位置: 0 到 {input_length-1}")
    print(f"目标token位置: {input_length}到{input_length + predict_length - 1}")
    with torch.no_grad():
        full_input = torch.tensor([tokens[:input_length + predict_length]], dtype=torch.long).to(device)
        target_labels = torch.tensor(target_tokens, dtype=torch.long).to(device)
        print(f"\n🔬 详细分析不同切片方法:")
        # Method 1: standard forward pass over the whole sequence.
        outputs1 = model(full_input)
        logits1 = outputs1.logits
        print(f"\n1. 标准forward:")
        print(f" 输入形状: {full_input.shape}")
        print(f" 输出logits形状: {logits1.shape}")
        # In a causal transformer the logits at position i predict the token
        # at position i+1, so predicting tokens at positions
        # [input_length, input_length+predict_length) requires the logits at
        # positions [input_length-1, input_length+predict_length-1).
        correct_slice = logits1[0, input_length-1:input_length+predict_length-1, :].contiguous()
        loss1 = F.cross_entropy(correct_slice, target_labels, reduction='mean')
        print(f" 正确切片 [{input_length-1}:{input_length+predict_length-1}]: {correct_slice.shape}")
        print(f" Loss: {loss1.item():.4f}")
        # Method 2: let the model keep only the last `predict_length`
        # positions via its logits_to_keep argument.
        outputs2 = model(full_input, logits_to_keep=predict_length)
        logits2 = outputs2.logits
        print(f"\n2. logits_to_keep={predict_length}:")
        print(f" 输出logits形状: {logits2.shape}")
        # With logits_to_keep=30 the model returns the last 30 positions'
        # logits; which absolute positions those correspond to is exactly
        # what this script verifies.
        keep_slice = logits2[0, -predict_length:, :].contiguous()
        loss2 = F.cross_entropy(keep_slice, target_labels, reduction='mean')
        print(f" logits_to_keep切片 [-{predict_length}:]: {keep_slice.shape}")
        print(f" Loss: {loss2.item():.4f}")
        # Compare the two slices element-wise.
        print(f"\n🔍 切片对比:")
        if torch.allclose(correct_slice, keep_slice, rtol=1e-6):
            print(f" ✅ 两个切片完全相同")
        else:
            diff = torch.abs(correct_slice - keep_slice).max()
            print(f" ❌ 切片不同,最大差异: {diff.item():.8f}")
            # Report exactly which positions disagree.
            diff_mask = ~torch.isclose(correct_slice, keep_slice, rtol=1e-6)
            diff_positions = torch.where(diff_mask.any(dim=-1))[0]
            print(f" 不同的位置: {diff_positions.tolist()}")
        # Method 3: reproduce the slicing logic used by eval_model.py,
        # which takes logits[0, -predict_length:, :] of the full output.
        print(f"\n3. eval_model.py的逻辑:")
        eval_slice = logits1[0, -predict_length:, :].contiguous()
        loss3 = F.cross_entropy(eval_slice, target_labels, reduction='mean')
        print(f" eval_model.py切片 [-{predict_length}:]: {eval_slice.shape}")
        print(f" 这对应logits中的位置: {logits1.shape[1] - predict_length}到{logits1.shape[1] - 1}")
        print(f" Loss: {loss3.item():.4f}")
        # Verify whether eval_model.py's slice matches the correct one.
        if torch.allclose(correct_slice, eval_slice, rtol=1e-6):
            print(f" ✅ eval_model.py切片正确")
        else:
            diff = torch.abs(correct_slice - eval_slice).max()
            print(f" ❌ eval_model.py切片错误最大差异: {diff.item():.8f}")
def compare_different_sequence_lengths():
    """
    Compare the three logits-slicing strategies across sequence lengths.

    Uses a synthetic token sequence (ids 0..199) and, for each
    (input_len, predict_len) configuration, computes the cross-entropy loss
    from the standard manual slice, the ``logits_to_keep`` slice, and the
    ``eval_model.py`` tail slice, then prints whether the slices agree.
    """
    print(f"\n🧪 测试不同序列长度")
    print("="*60)
    # Prefer GPU but fall back to CPU so the script runs on any machine.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = 'out/experiment_1_4_0/pretrain_512.pth'
    # Build the model and load the checkpoint weights.
    config = LMConfig(
        dim=512, n_layers=8, n_heads=32, vocab_size=6400, max_seq_len=512,
        dropout=0.0, norm_eps=1e-5, rope_theta=1e6, use_moe=False
    )
    model = MiniMindLM(config)
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()
    # Synthetic test sequence: token ids 0..199 (well inside vocab_size).
    test_tokens = list(range(200))
    # (input length, prediction length) pairs to exercise.
    test_configs = [
        (50, 20),   # 50 input, 20 predicted
        (100, 30),  # 100 input, 30 predicted
        (150, 40),  # 150 input, 40 predicted
    ]
    for input_len, predict_len in test_configs:
        print(f"\n测试配置: 输入{input_len}, 预测{predict_len}")
        sequence = test_tokens[:input_len + predict_len]
        input_ids = torch.tensor([sequence], dtype=torch.long).to(device)
        target_labels = torch.tensor(sequence[input_len:], dtype=torch.long).to(device)
        with torch.no_grad():
            # Standard method: full forward pass, manual position slice
            # (position i's logits predict token i+1).
            outputs_std = model(input_ids)
            logits_std = outputs_std.logits
            slice_std = logits_std[0, input_len-1:input_len+predict_len-1, :].contiguous()
            loss_std = F.cross_entropy(slice_std, target_labels, reduction='mean')
            # logits_to_keep method: model returns only the tail logits.
            outputs_keep = model(input_ids, logits_to_keep=predict_len)
            logits_keep = outputs_keep.logits
            slice_keep = logits_keep[0, -predict_len:, :].contiguous()
            loss_keep = F.cross_entropy(slice_keep, target_labels, reduction='mean')
            # eval_model.py method: tail slice of the full logits.
            slice_eval = logits_std[0, -predict_len:, :].contiguous()
            loss_eval = F.cross_entropy(slice_eval, target_labels, reduction='mean')
        print(f" 标准方法loss: {loss_std.item():.4f}")
        print(f" logits_to_keep loss: {loss_keep.item():.4f}")
        print(f" eval_model.py loss: {loss_eval.item():.4f}")
        # Pairwise agreement of the three slices.
        std_vs_keep = torch.allclose(slice_std, slice_keep, rtol=1e-6)
        std_vs_eval = torch.allclose(slice_std, slice_eval, rtol=1e-6)
        keep_vs_eval = torch.allclose(slice_keep, slice_eval, rtol=1e-6)
        # NOTE(review): the scraped copy had `'' if x else ''` here — both
        # branches empty, which prints nothing either way. The ✅/❌ marks
        # were evidently stripped as "ambiguous Unicode"; restored to match
        # the style used elsewhere in this file. Confirm against upstream.
        print(f" 标准 vs logits_to_keep: {'✅' if std_vs_keep else '❌'}")
        print(f" 标准 vs eval_model.py: {'✅' if std_vs_eval else '❌'}")
        print(f" logits_to_keep vs eval_model.py: {'✅' if keep_vs_eval else '❌'}")
if __name__ == "__main__":
    # Run both analyses when executed as a script (indentation restored —
    # the scraped copy had the calls at column 0, which is a SyntaxError).
    analyze_position_indexing()
    compare_different_sequence_lengths()