247 lines
9.5 KiB
Python
247 lines
9.5 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
修复logits_to_keep参数导致的loss计算错误
|
||
验证问题并提供解决方案
|
||
"""
|
||
|
||
import json
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from transformers import AutoTokenizer
|
||
from model.LMConfig import LMConfig
|
||
from model.model_original import MiniMindLM
|
||
|
||
|
||
def demonstrate_logits_to_keep_issue(model_path='out/experiment_1_4_0/pretrain_512.pth',
                                     device=None, input_length=100, predict_length=30):
    """Diagnose why the ``logits_to_keep`` evaluation reports a higher loss.

    Loads the checkpoint, takes the first sample of the eval set and scores the
    targets ``tokens[input_length:input_length + predict_length]`` in two ways:

    1. Full forward pass, logits sliced at positions
       ``input_length - 1 .. input_length + predict_length - 2``.  The logits at
       position ``p`` score token ``p + 1``, so these line up with the targets.
    2. ``model(..., logits_to_keep=predict_length)`` followed by
       ``logits[0, -predict_length:]``, as done in ``eval_model.py``.  For an
       input of ``input_length + predict_length`` tokens these are positions
       ``input_length .. input_length + predict_length - 1``, i.e. one position
       to the right of the targets.

    It then separates the two candidate causes of the loss gap:

    * the numerical effect of ``logits_to_keep`` itself, measured by comparing
      it with the full forward pass at the *same* positions, and
    * the one-position misalignment of the ``[-predict_length:]`` slice.

    Args:
        model_path: path of the state dict to load.
        device: torch device string; defaults to ``'cuda'`` when available,
            otherwise ``'cpu'``.
        input_length: number of context tokens (must be >= 1).
        predict_length: number of target tokens to score.

    Prints its findings and returns ``None``.  Returns early with a warning
    when the first sample has fewer than ``input_length + predict_length``
    tokens, since the logits and labels would otherwise differ in length.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print("🔍 验证logits_to_keep参数问题")
    print("="*60)

    # Load the model (same hyper-parameters as the training run).
    config = LMConfig(
        dim=512, n_layers=8, n_heads=32, vocab_size=6400, max_seq_len=512,
        dropout=0.0, norm_eps=1e-5, rope_theta=1e6, use_moe=False
    )
    model = MiniMindLM(config)
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    state_dict = torch.load(model_path, map_location=device)
    incompatible = model.load_state_dict(state_dict, strict=False)
    # strict=False does not raise on key mismatches; report them so a wrong
    # checkpoint is not mistaken for a loss-computation problem.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"⚠️ 权重未完全匹配: missing={len(incompatible.missing_keys)}, "
              f"unexpected={len(incompatible.unexpected_keys)}")
    model.to(device)
    model.eval()

    # Load the first evaluation sample.
    with open('dataset/stable/eval_data_from_train.json', 'r', encoding='utf-8') as f:
        sample = json.loads(f.readline().strip())

    text = sample['text']
    tokens = tokenizer.encode(text, add_special_tokens=False)

    total_length = input_length + predict_length
    if len(tokens) < total_length:
        # Too short: the target slice would be shorter than predict_length.
        print(f"⚠️ 测试样本只有 {len(tokens)} tokens, 需要至少 {total_length} tokens")
        return

    input_tokens = tokens[:input_length]
    target_tokens = tokens[input_length:total_length]

    print(f"测试样本: {len(tokens)} tokens")
    print(f"输入: {len(input_tokens)} tokens")
    print(f"目标: {len(target_tokens)} tokens")

    with torch.no_grad():
        full_input = torch.tensor([tokens[:total_length]], dtype=torch.long).to(device)
        target_labels = torch.tensor(target_tokens, dtype=torch.long).to(device)

        print("\n🔬 详细对比不同方法:")

        # Method 1: full forward pass; logits at position p predict token p+1,
        # so target tokens[input_length:total_length] use positions
        # input_length-1 .. total_length-2.
        outputs1 = model(full_input)
        logits1 = outputs1.logits
        correct_logits = logits1[0, input_length - 1:total_length - 1, :].contiguous()
        loss1 = F.cross_entropy(correct_logits, target_labels, reduction='mean')

        print("1. 标准forward (正确):")
        print(f" 完整logits形状: {logits1.shape}")
        print(f" 用于计算的logits形状: {correct_logits.shape}")
        print(f" Loss: {loss1.item():.4f}")

        # Method 2: what eval_model.py does. The last predict_length positions
        # are one step to the right of the targets.
        outputs2 = model(full_input, logits_to_keep=predict_length)
        logits2 = outputs2.logits
        incorrect_logits = logits2[0, -predict_length:, :].contiguous()
        loss2 = F.cross_entropy(incorrect_logits, target_labels, reduction='mean')

        print(f"\n2. logits_to_keep={predict_length} (eval_model.py方法):")
        print(f" 部分logits形状: {logits2.shape}")
        print(f" 用于计算的logits形状: {incorrect_logits.shape}")
        print(f" Loss: {loss2.item():.4f}")

        # Method 3: the fix is method 1; printed separately for clarity.
        print("\n3. 修复方法 (不使用logits_to_keep):")
        print(" 使用完整forward,然后选择正确的logits切片")
        print(f" 这与方法1相同,Loss: {loss1.item():.4f}")

        # Analysis of the gap.
        increase_pct = (loss2.item() / loss1.item() - 1) * 100
        print("\n📊 数值分析:")
        print(f" Loss差异: {abs(loss2.item() - loss1.item()):.4f}")
        print(f" Loss增幅: {increase_pct:.1f}%")

        # Same positions from both passes: isolates the numerical effect of
        # logits_to_keep itself from the slicing/alignment.
        same_position_logits = logits1[0, -predict_length:, :]
        same_position_diff = torch.abs(same_position_logits - incorrect_logits).max()
        print(f" 相同位置上logits_to_keep与完整forward的最大logits差异: {same_position_diff.item():.8f}")

        # Aligned vs. misaligned window: these are different positions, so a
        # large difference here reflects the shift, not numerical noise.
        shifted_diff = torch.abs(correct_logits - incorrect_logits).max()
        print(f" 错位一个位置的最大logits差异: {shifted_diff.item():.8f}")

        prob1 = F.softmax(correct_logits, dim=-1)
        prob2 = F.softmax(incorrect_logits, dim=-1)
        prob_diff = torch.abs(prob1 - prob2).max()
        print(f" 最大概率差异(错位一个位置): {prob_diff.item():.8f}")

    print("\n💡 结论:")
    print(f" 相同位置上logits_to_keep与完整forward的最大logits差异为 {same_position_diff.item():.8f},")
    print(" 若该值接近0, 说明logits_to_keep本身不会改变结果;")
    print(f" loss增加{increase_pct:.1f}%来自 [-{predict_length}:] 切片错位一个位置: "
          f"位置p的logits预测的是第p+1个token,")
    print(f" 正确切片为 [{input_length - 1}:{total_length - 1}]")
|
||
|
||
|
||
def create_fixed_eval_model():
    """Write ``eval_model_fixed.py``: ``eval_model.py`` with a correctly aligned loss.

    ``eval_model.py`` computes the loss with
    ``model(loss_input_ids, logits_to_keep=predict_length)`` followed by
    ``logits[0, -predict_length:, :]``.  For an input of
    ``input_length + predict_length`` tokens these are the logits at positions
    ``input_length .. input_length + predict_length - 1``.  The logits at
    position ``p`` score token ``p + 1``, so they are one position to the right
    of the targets ``tokens[input_length:input_length + predict_length]``.
    Slicing ``[input_length:input_length + predict_length]`` after a full
    forward pass selects the very same positions, so it does not fix anything.

    The patch drops ``logits_to_keep`` and slices
    ``[input_length - 1:input_length + predict_length - 1]`` instead.

    The two offending lines are replaced individually rather than matching the
    whole multi-line block, so the patch does not depend on the exact
    indentation or blank lines of ``eval_model.py``.  If either line is not
    found, a warning is printed and no output file is written.  This avoids
    silently producing an unpatched "fixed" copy.

    Reads ``eval_model.py`` and writes ``eval_model_fixed.py``, both in the
    current working directory.  Returns ``None``.
    """
    print("\n🔧 创建修复后的评估脚本")
    print("="*60)

    # Read the original eval_model.py.
    with open('eval_model.py', 'r', encoding='utf-8') as f:
        content = f.read()

    # (old fragment, replacement) pairs. Each fragment is a single line
    # without leading whitespace, so the match is indentation independent.
    replacements = [
        (
            "outputs = model(loss_input_ids, logits_to_keep=predict_length)",
            "outputs = model(loss_input_ids)  # 移除logits_to_keep参数",
        ),
        (
            "shift_logits = logits[0, -predict_length:, :].contiguous()",
            "shift_logits = logits[0, input_length - 1:input_length + predict_length - 1, :].contiguous()",
        ),
    ]

    missing = [old for old, _ in replacements if old not in content]
    if missing:
        print("⚠️ 未在eval_model.py中找到需要修复的代码, 未生成eval_model_fixed.py:")
        for old in missing:
            print(f"   {old}")
        return

    fixed_content = content
    for old, new in replacements:
        fixed_content = fixed_content.replace(old, new)

    # Save the patched file.
    with open('eval_model_fixed.py', 'w', encoding='utf-8') as f:
        f.write(fixed_content)

    print("✅ 创建了修复版本:eval_model_fixed.py")
    print("主要修复:")
    print(" 1. 移除 logits_to_keep 参数")
    print(" 2. 使用正确的位置切片: [input_length - 1:input_length + predict_length - 1]")
    print(" 3. 而不是错误的 [-predict_length:] (位置p的logits预测的是第p+1个token)")
|
||
|
||
|
||
def test_fixed_evaluation(model_path='out/experiment_1_4_0/pretrain_512.pth', device=None,
                          num_samples=10, input_length=100, predict_length=30):
    """Compare the original and the fixed loss computation on several eval samples.

    Considers the first ``num_samples`` lines of the eval file.  Lines that are
    blank, or that encode to fewer than ``input_length + predict_length``
    tokens, are skipped.  For each remaining line, the loss on
    ``tokens[input_length:input_length + predict_length]`` is computed in two ways:

    * original: ``logits_to_keep=predict_length`` plus
      ``logits[0, -predict_length:]``.  These are positions
      ``input_length .. input_length + predict_length - 1``, one to the right
      of the targets.
    * fixed: full forward pass plus
      ``logits[0, input_length - 1:input_length + predict_length - 1]``, because
      the logits at position ``p`` score token ``p + 1``.

    Args:
        model_path: path of the state dict to load.
        device: torch device string; defaults to ``'cuda'`` when available,
            otherwise ``'cpu'``.
        num_samples: number of leading lines of the eval file to consider.
        input_length: number of context tokens (must be >= 1).
        predict_length: number of target tokens to score.

    Prints per-sample and average losses and returns ``None``.  Prints a
    warning instead of dividing by zero when no sample is long enough.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print("\n🧪 测试修复后的评估方法")
    print("="*60)

    # Load the model (same hyper-parameters as the training run).
    config = LMConfig(
        dim=512, n_layers=8, n_heads=32, vocab_size=6400, max_seq_len=512,
        dropout=0.0, norm_eps=1e-5, rope_theta=1e6, use_moe=False
    )
    model = MiniMindLM(config)
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    state_dict = torch.load(model_path, map_location=device)
    incompatible = model.load_state_dict(state_dict, strict=False)
    # strict=False does not raise on key mismatches; report them.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"⚠️ 权重未完全匹配: missing={len(incompatible.missing_keys)}, "
              f"unexpected={len(incompatible.unexpected_keys)}")
    model.to(device)
    model.eval()

    total_length = input_length + predict_length
    # Logits at position p score token p+1, so the targets
    # tokens[input_length:total_length] are scored by these positions.
    aligned_positions = slice(input_length - 1, total_length - 1)

    total_loss_old = 0.0
    total_loss_fixed = 0.0
    valid_samples = 0

    with open('dataset/stable/eval_data_from_train.json', 'r', encoding='utf-8') as f, torch.no_grad():
        for i, line in enumerate(f):
            if i >= num_samples:
                break

            line = line.strip()
            if not line:
                continue  # blank line: nothing to parse

            sample = json.loads(line)
            tokens = tokenizer.encode(sample['text'], add_special_tokens=False)
            if len(tokens) < total_length:
                continue

            target_tokens = tokens[input_length:total_length]
            full_input = torch.tensor([tokens[:total_length]], dtype=torch.long).to(device)
            target_labels = torch.tensor(target_tokens, dtype=torch.long).to(device)

            # Original (misaligned) method, as in eval_model.py.
            outputs_old = model(full_input, logits_to_keep=predict_length)
            shift_logits_old = outputs_old.logits[0, -predict_length:, :].contiguous()
            loss_old = F.cross_entropy(shift_logits_old, target_labels, reduction='mean')

            # Fixed method: full forward pass, aligned slice.
            outputs_fixed = model(full_input)
            shift_logits_fixed = outputs_fixed.logits[0, aligned_positions, :].contiguous()
            loss_fixed = F.cross_entropy(shift_logits_fixed, target_labels, reduction='mean')

            total_loss_old += loss_old.item()
            total_loss_fixed += loss_fixed.item()
            valid_samples += 1

            print(f"样本{i+1}: 原始{loss_old.item():.4f} -> 修复{loss_fixed.item():.4f}")

    if valid_samples == 0:
        print(f"⚠️ 前{num_samples}个样本中没有长度 >= {total_length} tokens 的样本, 无法比较")
        return

    avg_loss_old = total_loss_old / valid_samples
    avg_loss_fixed = total_loss_fixed / valid_samples

    print("\n📊 测试结果总结:")
    print(f" 测试样本数: {valid_samples}")
    print(f" 原始方法平均loss: {avg_loss_old:.4f}")
    print(f" 修复方法平均loss: {avg_loss_fixed:.4f}")
    print(f" 差异: {abs(avg_loss_old - avg_loss_fixed):.4f}")
    print(" 修复后loss更接近训练时的教师强制loss (~2.4)")
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the diagnosis first, then write the patched evaluation script,
    # then re-check the fix on several samples.
    for step in (demonstrate_logits_to_keep_issue,
                 create_fixed_eval_model,
                 test_fixed_evaluation):
        step()