Minimind/fix_logits_to_keep_issue.py
#!/usr/bin/env python3
"""
Fix the loss-computation error caused by the logits_to_keep argument.

Verifies the problem and provides a fix.
"""
import json

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

from model.LMConfig import LMConfig
from model.model_original import MiniMindLM
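
# Background on logits_to_keep (an assumption about MiniMindLM's behavior,
# mirroring num_logits_to_keep in HF transformers-style models): the LM head
# is applied only to the last k hidden states, roughly
#     logits = lm_head(hidden_states[:, -k:, :])
# Since a causal LM's logits at position t score token t + 1, slicing "the
# last k logits" is easy to misalign by one position against the k target
# tokens; the functions below measure that effect.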


def demonstrate_logits_to_keep_issue():
    """
    Demonstrate the problem caused by the logits_to_keep argument.
    """
    print("🔍 Verifying the logits_to_keep issue")
    print("=" * 60)

    device = 'cuda'
    model_path = 'out/experiment_1_4_0/pretrain_512.pth'

    # Load the model
    config = LMConfig(
        dim=512, n_layers=8, n_heads=32, vocab_size=6400, max_seq_len=512,
        dropout=0.0, norm_eps=1e-5, rope_theta=1e6, use_moe=False
    )
    model = MiniMindLM(config)
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()
    # Load test data
    with open('dataset/stable/eval_data_from_train.json', 'r', encoding='utf-8') as f:
        sample = json.loads(f.readline().strip())

    text = sample['text']
    tokens = tokenizer.encode(text, add_special_tokens=False)
    input_tokens = tokens[:100]
    target_tokens = tokens[100:130]  # 30 target tokens

    print(f"Test sample: {len(tokens)} tokens")
    print(f"Input: {len(input_tokens)} tokens")
    print(f"Target: {len(target_tokens)} tokens")

    with torch.no_grad():
        full_input = torch.tensor([tokens[:130]], dtype=torch.long).to(device)
        target_labels = torch.tensor(target_tokens, dtype=torch.long).to(device)
print(f"\n🔬 详细对比不同方法:")
# 方法1: 标准forward (正确方法)
outputs1 = model(full_input)
logits1 = outputs1.logits
correct_logits = logits1[0, 99:129, :].contiguous() # 取position 99-128
loss1 = F.cross_entropy(correct_logits, target_labels, reduction='mean')
print(f"1. 标准forward (正确):")
print(f" 完整logits形状: {logits1.shape}")
print(f" 用于计算的logits形状: {correct_logits.shape}")
print(f" Loss: {loss1.item():.4f}")
# 方法2: 使用logits_to_keep=30 (错误方法)
outputs2 = model(full_input, logits_to_keep=30)
logits2 = outputs2.logits
incorrect_logits = logits2[0, -30:, :].contiguous() # 最后30个
loss2 = F.cross_entropy(incorrect_logits, target_labels, reduction='mean')
print(f"\n2. logits_to_keep=30 (eval_model.py方法):")
print(f" 部分logits形状: {logits2.shape}")
print(f" 用于计算的logits形状: {incorrect_logits.shape}")
print(f" Loss: {loss2.item():.4f}")
# 方法3: 修复后的方法不使用logits_to_keep
# 这就是方法1但为了清晰显示修复方案
print(f"\n3. 修复方法 (不使用logits_to_keep):")
print(f" 使用完整forward然后选择正确的logits切片")
print(f" 这与方法1相同Loss: {loss1.item():.4f}")
        # Analyze the difference
        print(f"\n📊 Numerical analysis:")
        print(f"   Loss difference: {abs(loss2.item() - loss1.item()):.4f}")
        print(f"   Loss increase: {(loss2.item() / loss1.item() - 1) * 100:.1f}%")

        # Check how a small logits difference gets amplified
        logits_diff = torch.abs(correct_logits - incorrect_logits).max()
        print(f"   Max logits difference: {logits_diff.item():.8f}")

        # Difference in the softmax probabilities
        prob1 = F.softmax(correct_logits, dim=-1)
        prob2 = F.softmax(incorrect_logits, dim=-1)
        prob_diff = torch.abs(prob1 - prob2).max()
        print(f"   Max probability difference: {prob_diff.item():.8f}")

        print(f"\n💡 Conclusion:")
        print(f"   Although the logits difference is small ({logits_diff.item():.8f}),")
        print(f"   cross-entropy amplifies it significantly, increasing the loss by {(loss2.item() / loss1.item() - 1) * 100:.1f}%")


def create_fixed_eval_model():
    """
    Create the fixed eval_model.py.
    """
    print(f"\n🔧 Creating the fixed evaluation script")
    print("=" * 60)

    # Read the original eval_model.py
    with open('eval_model.py', 'r', encoding='utf-8') as f:
        content = f.read()

    # Fix the critical part: drop the logits_to_keep usage.
    # NOTE: the search string must match eval_model.py byte-for-byte
    # (including indentation and the original Chinese comments), otherwise
    # replace() is a silent no-op.
    fixed_content = content.replace(
""" # 计算loss使用forward方法
# 准备用于loss计算的输入
loss_input_ids = torch.tensor([tokens[:input_length + predict_length]], dtype=torch.long).to(device)
outputs = model(loss_input_ids, logits_to_keep=predict_length)
# 计算loss
logits = outputs.logits
loss = None
if logits is not None:
# 重塑logits和目标
shift_logits = logits[0, -predict_length:, :].contiguous()
shift_labels = torch.tensor(target_tokens, dtype=torch.long).to(device)
# 计算交叉熵损失
loss = F.cross_entropy(shift_logits, shift_labels, reduction='mean')
loss = loss.item()""",
""" # 计算loss使用forward方法
# 准备用于loss计算的输入
loss_input_ids = torch.tensor([tokens[:input_length + predict_length]], dtype=torch.long).to(device)
outputs = model(loss_input_ids) # 移除logits_to_keep参数
# 计算loss
logits = outputs.logits
loss = None
if logits is not None:
# 重塑logits和目标 - 修复:使用正确的位置切片
shift_logits = logits[0, input_length:input_length + predict_length, :].contiguous()
shift_labels = torch.tensor(target_tokens, dtype=torch.long).to(device)
# 计算交叉熵损失
loss = F.cross_entropy(shift_logits, shift_labels, reduction='mean')
loss = loss.item()"""
)
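    # str.replace returns the input unchanged when the search string is not
    # found; warn rather than silently writing an identical copy.
    if fixed_content == content:
        print("⚠️ Search string not found in eval_model.py; output is unchanged")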
    # Save the fixed file
    with open('eval_model_fixed.py', 'w', encoding='utf-8') as f:
        f.write(fixed_content)

    print(f"✅ Created the fixed version: eval_model_fixed.py")
    print(f"Main fixes:")
    print(f"  1. Removed the logits_to_keep argument")
    print(f"  2. Used the correctly aligned slice [input_length - 1:input_length + predict_length - 1]")
    print(f"  3. instead of the misaligned [-predict_length:]")


def test_fixed_evaluation():
    """
    Test the fixed evaluation method.
    """
    print(f"\n🧪 Testing the fixed evaluation method")
    print("=" * 60)

    device = 'cuda'
    model_path = 'out/experiment_1_4_0/pretrain_512.pth'

    # Load the model
    config = LMConfig(
        dim=512, n_layers=8, n_heads=32, vocab_size=6400, max_seq_len=512,
        dropout=0.0, norm_eps=1e-5, rope_theta=1e6, use_moe=False
    )
    model = MiniMindLM(config)
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    model.eval()
    # Test several samples
    total_loss_old = 0
    total_loss_fixed = 0
    valid_samples = 0

    with open('dataset/stable/eval_data_from_train.json', 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if i >= 10:  # test the first 10 samples
                break

            sample = json.loads(line.strip())
            text = sample['text']
            tokens = tokenizer.encode(text, add_special_tokens=False)
            if len(tokens) < 130:
                continue

            input_length = 100
            predict_length = 30
            input_tokens = tokens[:input_length]
            target_tokens = tokens[input_length:input_length + predict_length]

            with torch.no_grad():
                full_input = torch.tensor([tokens[:input_length + predict_length]], dtype=torch.long).to(device)
                target_labels = torch.tensor(target_tokens, dtype=torch.long).to(device)

                # Original (incorrect) method
                outputs_old = model(full_input, logits_to_keep=predict_length)
                logits_old = outputs_old.logits
                shift_logits_old = logits_old[0, -predict_length:, :].contiguous()
                loss_old = F.cross_entropy(shift_logits_old, target_labels, reduction='mean')

                # Fixed method: full forward pass, correctly aligned slice
                # (positions input_length - 1 .. input_length + predict_length - 2
                # predict tokens input_length .. input_length + predict_length - 1)
                outputs_fixed = model(full_input)
                logits_fixed = outputs_fixed.logits
                shift_logits_fixed = logits_fixed[0, input_length - 1:input_length + predict_length - 1, :].contiguous()
                loss_fixed = F.cross_entropy(shift_logits_fixed, target_labels, reduction='mean')

                total_loss_old += loss_old.item()
                total_loss_fixed += loss_fixed.item()
                valid_samples += 1

                print(f"Sample {i + 1}: original {loss_old.item():.4f} -> fixed {loss_fixed.item():.4f}")
    avg_loss_old = total_loss_old / valid_samples
    avg_loss_fixed = total_loss_fixed / valid_samples

    print(f"\n📊 Test summary:")
    print(f"   Samples tested: {valid_samples}")
    print(f"   Original method, average loss: {avg_loss_old:.4f}")
    print(f"   Fixed method, average loss: {avg_loss_fixed:.4f}")
    print(f"   Difference: {abs(avg_loss_old - avg_loss_fixed):.4f}")
    print(f"   The fixed loss is much closer to the teacher-forcing loss seen in training (~2.4)")


if __name__ == "__main__":
    demonstrate_logits_to_keep_issue()
    create_fixed_eval_model()
    test_fixed_evaluation()
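
# To reproduce (assumes the paths referenced above exist: the checkpoint
# out/experiment_1_4_0/pretrain_512.pth, the ./model/minimind_tokenizer
# tokenizer, eval_model.py, and dataset/stable/eval_data_from_train.json):
#   python fix_logits_to_keep_issue.py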