#!/usr/bin/env python3
"""
深入分析位置切片的问题

验证logits_to_keep和位置索引的正确性
"""
import json

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

from model.LMConfig import LMConfig
from model.model_original import MiniMindLM


def analyze_position_indexing():
    """Analyze whether logits slicing by position index is correct.

    Loads a pretrained MiniMindLM checkpoint and runs one evaluation sample
    through the model, then compares three ways of obtaining the logits that
    should predict the target span:

    1. Standard forward + the canonical slice
       ``[input_length-1 : input_length+predict_length-1]`` (in a decoder-only
       transformer, position ``i``'s logits predict token ``i+1``).
    2. Forward with ``logits_to_keep=predict_length`` + the last
       ``predict_length`` positions of the returned logits.
    3. The slice eval_model.py uses: the last ``predict_length`` positions of
       the *full* logits — this script exists to check whether that slice is
       off by one relative to method 1.

    Side effects only (prints a report); returns None.

    Raises:
        ValueError: if the evaluation sample is too short to provide
            ``input_length + predict_length`` tokens.
    """
    print("🔍 分析位置索引和切片逻辑")
    print("="*60)

    # Fall back to CPU so the script still runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = 'out/experiment_1_4_0/pretrain_512.pth'

    # Configuration matching the 512-dim pretrain checkpoint.
    config = LMConfig(
        dim=512, n_layers=8, n_heads=32, vocab_size=6400, max_seq_len=512,
        dropout=0.0, norm_eps=1e-5, rope_theta=1e6, use_moe=False
    )

    model = MiniMindLM(config)
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    # weights_only=True: the checkpoint is a plain state_dict, and this avoids
    # pickle's arbitrary-code-execution risk on untrusted files.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()

    # Load the first record of the evaluation set (one JSON object per line).
    with open('dataset/stable/eval_data_from_train.json', 'r', encoding='utf-8') as f:
        sample = json.loads(f.readline().strip())

    text = sample['text']
    tokens = tokenizer.encode(text, add_special_tokens=False)

    input_length = 100
    predict_length = 30
    # Guard: a short sample would otherwise cause a confusing shape-mismatch
    # crash inside F.cross_entropy below.
    if len(tokens) < input_length + predict_length:
        raise ValueError(
            f"sample too short: need {input_length + predict_length} tokens, "
            f"got {len(tokens)}"
        )
    target_tokens = tokens[input_length:input_length + predict_length]

    print(f"输入长度: {input_length}")
    print(f"预测长度: {predict_length}")
    print(f"总序列长度: {input_length + predict_length}")
    print(f"输入token位置: 0 到 {input_length-1}")
    print(f"目标token位置: {input_length} 到 {input_length + predict_length - 1}")

    with torch.no_grad():
        full_input = torch.tensor([tokens[:input_length + predict_length]], dtype=torch.long).to(device)
        target_labels = torch.tensor(target_tokens, dtype=torch.long).to(device)

        print(f"\n🔬 详细分析不同切片方法:")

        # Method 1: standard forward over the whole sequence.
        outputs1 = model(full_input)
        logits1 = outputs1.logits
        print(f"\n1. 标准forward:")
        print(f" 输入形状: {full_input.shape}")
        print(f" 输出logits形状: {logits1.shape}")

        # In a transformer, the logits at position i predict the token at
        # position i+1, so predicting positions [100, 130) requires the
        # logits at positions [99, 129).
        correct_slice = logits1[0, input_length-1:input_length+predict_length-1, :].contiguous()
        loss1 = F.cross_entropy(correct_slice, target_labels, reduction='mean')
        print(f" 正确切片 [{input_length-1}:{input_length+predict_length-1}]: {correct_slice.shape}")
        print(f" Loss: {loss1.item():.4f}")

        # Method 2: ask the model to return only the trailing logits.
        outputs2 = model(full_input, logits_to_keep=predict_length)
        logits2 = outputs2.logits
        print(f"\n2. logits_to_keep={predict_length}:")
        print(f" 输出logits形状: {logits2.shape}")

        # With logits_to_keep=30 the model returns the last 30 positions;
        # whether those are positions 100-129 is exactly what we check below.
        keep_slice = logits2[0, -predict_length:, :].contiguous()
        loss2 = F.cross_entropy(keep_slice, target_labels, reduction='mean')
        print(f" logits_to_keep切片 [-{predict_length}:]: {keep_slice.shape}")
        print(f" Loss: {loss2.item():.4f}")

        # Compare the two slices element-wise.
        print(f"\n🔍 切片对比:")
        if torch.allclose(correct_slice, keep_slice, rtol=1e-6):
            print(f" ✅ 两个切片完全相同")
        else:
            diff = torch.abs(correct_slice - keep_slice).max()
            print(f" ❌ 切片不同,最大差异: {diff.item():.8f}")

            # Report which positions disagree.
            diff_mask = ~torch.isclose(correct_slice, keep_slice, rtol=1e-6)
            diff_positions = torch.where(diff_mask.any(dim=-1))[0]
            print(f" 不同的位置: {diff_positions.tolist()}")

        # Method 3: reproduce eval_model.py's slicing of the full logits.
        print(f"\n3. eval_model.py的逻辑:")
        eval_slice = logits1[0, -predict_length:, :].contiguous()
        loss3 = F.cross_entropy(eval_slice, target_labels, reduction='mean')
        print(f" eval_model.py切片 [-{predict_length}:]: {eval_slice.shape}")
        print(f" 这对应logits中的位置: {logits1.shape[1] - predict_length} 到 {logits1.shape[1] - 1}")
        print(f" Loss: {loss3.item():.4f}")

        # Verify whether eval_model.py's slice matches the canonical one.
        if torch.allclose(correct_slice, eval_slice, rtol=1e-6):
            print(f" ✅ eval_model.py切片正确")
        else:
            diff = torch.abs(correct_slice - eval_slice).max()
            print(f" ❌ eval_model.py切片错误,最大差异: {diff.item():.8f}")
||
|
||
def compare_different_sequence_lengths():
    """Compare the three logits-slicing methods across sequence lengths.

    For several (input_len, predict_len) configurations, feeds a synthetic
    token sequence through the model and computes the prediction loss with:

    1. the canonical slice ``[input_len-1 : input_len+predict_len-1]``;
    2. ``logits_to_keep=predict_len`` + the last ``predict_len`` positions;
    3. eval_model.py's slice: the last ``predict_len`` positions of the
       full logits.

    Prints the three losses and pairwise allclose comparisons per
    configuration. Side effects only; returns None.
    """
    print(f"\n🧪 测试不同序列长度")
    print("="*60)

    # Fall back to CPU so the script still runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = 'out/experiment_1_4_0/pretrain_512.pth'

    # Configuration matching the 512-dim pretrain checkpoint.
    config = LMConfig(
        dim=512, n_layers=8, n_heads=32, vocab_size=6400, max_seq_len=512,
        dropout=0.0, norm_eps=1e-5, rope_theta=1e6, use_moe=False
    )

    model = MiniMindLM(config)
    # NOTE: no tokenizer is needed here — the test sequence below is built
    # directly from raw token ids.

    # weights_only=True: the checkpoint is a plain state_dict, and this avoids
    # pickle's arbitrary-code-execution risk on untrusted files.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()

    # Synthetic sequence: token ids 0..199 (all within vocab_size=6400).
    test_tokens = list(range(200))

    test_configs = [
        (50, 20),   # 50 input, 20 predicted
        (100, 30),  # 100 input, 30 predicted
        (150, 40),  # 150 input, 40 predicted
    ]

    for input_len, predict_len in test_configs:
        print(f"\n测试配置: 输入{input_len}, 预测{predict_len}")

        sequence = test_tokens[:input_len + predict_len]
        input_ids = torch.tensor([sequence], dtype=torch.long).to(device)
        target_labels = torch.tensor(sequence[input_len:], dtype=torch.long).to(device)

        with torch.no_grad():
            # Method 1: standard forward + canonical slice.
            outputs_std = model(input_ids)
            logits_std = outputs_std.logits
            slice_std = logits_std[0, input_len-1:input_len+predict_len-1, :].contiguous()
            loss_std = F.cross_entropy(slice_std, target_labels, reduction='mean')

            # Method 2: logits_to_keep + trailing slice.
            outputs_keep = model(input_ids, logits_to_keep=predict_len)
            logits_keep = outputs_keep.logits
            slice_keep = logits_keep[0, -predict_len:, :].contiguous()
            loss_keep = F.cross_entropy(slice_keep, target_labels, reduction='mean')

            # Method 3: eval_model.py's trailing slice of the full logits.
            slice_eval = logits_std[0, -predict_len:, :].contiguous()
            loss_eval = F.cross_entropy(slice_eval, target_labels, reduction='mean')

        print(f" 标准方法loss: {loss_std.item():.4f}")
        print(f" logits_to_keep loss: {loss_keep.item():.4f}")
        print(f" eval_model.py loss: {loss_eval.item():.4f}")

        # Pairwise agreement between the three slices.
        std_vs_keep = torch.allclose(slice_std, slice_keep, rtol=1e-6)
        std_vs_eval = torch.allclose(slice_std, slice_eval, rtol=1e-6)
        keep_vs_eval = torch.allclose(slice_keep, slice_eval, rtol=1e-6)

        print(f" 标准 vs logits_to_keep: {'✅' if std_vs_keep else '❌'}")
        print(f" 标准 vs eval_model.py: {'✅' if std_vs_eval else '❌'}")
        print(f" logits_to_keep vs eval_model.py: {'✅' if keep_vs_eval else '❌'}")
||
if __name__ == "__main__":
|
||
analyze_position_indexing()
|
||
compare_different_sequence_lengths() |