Minimind/final_fix_eval_model.py
2025-08-01 15:54:21 +08:00

218 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
最终修复eval_model.py中的位置索引错误
"""
import json
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from model.LMConfig import LMConfig
from model.model_original import MiniMindLM
def demonstrate_correct_fix():
"""
演示正确的修复方法
"""
print("🔧 演示正确的修复方法")
print("="*60)
device = 'cuda'
model_path = 'out/experiment_1_4_0/pretrain_512.pth'
# 加载模型
config = LMConfig(
dim=512, n_layers=8, n_heads=32, vocab_size=6400, max_seq_len=512,
dropout=0.0, norm_eps=1e-5, rope_theta=1e6, use_moe=False
)
model = MiniMindLM(config)
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict, strict=False)
model.to(device)
model.eval()
# 测试多个样本以验证修复效果
total_loss_wrong = 0
total_loss_correct = 0
valid_samples = 0
print("测试样本的loss对比:")
print("样本 | 错误方法 | 正确方法 | 差异")
print("-" * 45)
with open('dataset/stable/eval_data_from_train.json', 'r', encoding='utf-8') as f:
for i, line in enumerate(f):
if i >= 10: # 测试前10个样本
break
sample = json.loads(line.strip())
text = sample['text']
tokens = tokenizer.encode(text, add_special_tokens=False)
if len(tokens) < 130:
continue
input_length = 100
predict_length = 30
target_tokens = tokens[input_length:input_length + predict_length]
with torch.no_grad():
full_input = torch.tensor([tokens[:input_length + predict_length]], dtype=torch.long).to(device)
target_labels = torch.tensor(target_tokens, dtype=torch.long).to(device)
# 获取完整logits
outputs = model(full_input)
logits = outputs.logits
# 错误方法 (eval_model.py原来的方法)
wrong_slice = logits[0, -predict_length:, :].contiguous() # 取最后30个
loss_wrong = F.cross_entropy(wrong_slice, target_labels, reduction='mean')
# 正确方法
correct_slice = logits[0, input_length-1:input_length+predict_length-1, :].contiguous() # 取99:129
loss_correct = F.cross_entropy(correct_slice, target_labels, reduction='mean')
total_loss_wrong += loss_wrong.item()
total_loss_correct += loss_correct.item()
valid_samples += 1
diff = loss_wrong.item() - loss_correct.item()
print(f"{i+1:2} | {loss_wrong.item():8.4f} | {loss_correct.item():8.4f} | {diff:+6.4f}")
avg_loss_wrong = total_loss_wrong / valid_samples
avg_loss_correct = total_loss_correct / valid_samples
improvement = avg_loss_wrong - avg_loss_correct
print("-" * 45)
print(f"平均 | {avg_loss_wrong:8.4f} | {avg_loss_correct:8.4f} | {improvement:+6.4f}")
print(f"\n📊 修复效果:")
print(f" 错误方法平均loss: {avg_loss_wrong:.4f}")
print(f" 正确方法平均loss: {avg_loss_correct:.4f}")
print(f" 改进幅度: {improvement:.4f} ({improvement/avg_loss_wrong*100:.1f}%)")
print(f" 正确方法更接近训练时的教师强制loss (~2.4)")
def create_final_fixed_eval_model():
"""
创建最终修复版的eval_model.py
"""
print(f"\n🔧 创建最终修复版的eval_model.py")
print("="*60)
# 读取原始eval_model.py
with open('eval_model.py', 'r', encoding='utf-8') as f:
content = f.read()
# 修复evaluate_sample函数中的关键部分
old_loss_calculation = ''' # 计算loss使用forward方法
# 准备用于loss计算的输入
loss_input_ids = torch.tensor([tokens[:input_length + predict_length]], dtype=torch.long).to(device)
outputs = model(loss_input_ids, logits_to_keep=predict_length)
# 计算loss
logits = outputs.logits
loss = None
if logits is not None:
# 重塑logits和目标
shift_logits = logits[0, -predict_length:, :].contiguous()
shift_labels = torch.tensor(target_tokens, dtype=torch.long).to(device)
# 计算交叉熵损失
loss = F.cross_entropy(shift_logits, shift_labels, reduction='mean')
loss = loss.item()'''
new_loss_calculation = ''' # 计算loss使用forward方法
# 准备用于loss计算的输入
loss_input_ids = torch.tensor([tokens[:input_length + predict_length]], dtype=torch.long).to(device)
outputs = model(loss_input_ids) # 移除logits_to_keep参数
# 计算loss
logits = outputs.logits
loss = None
if logits is not None:
# 重塑logits和目标 - 修复:使用正确的位置切片
# 在Transformer中position i的logits预测position i+1的token
# 要预测position input_length到input_length+predict_length-1的token
# 需要使用position input_length-1到input_length+predict_length-2的logits
shift_logits = logits[0, input_length-1:input_length+predict_length-1, :].contiguous()
shift_labels = torch.tensor(target_tokens, dtype=torch.long).to(device)
# 计算交叉熵损失
loss = F.cross_entropy(shift_logits, shift_labels, reduction='mean')
loss = loss.item()'''
# 替换内容
fixed_content = content.replace(old_loss_calculation, new_loss_calculation)
# 保存修复后的文件
with open('eval_model_final_fixed.py', 'w', encoding='utf-8') as f:
f.write(fixed_content)
print(f"✅ 创建了最终修复版本eval_model_final_fixed.py")
print(f"主要修复:")
print(f" 1. 移除 logits_to_keep 参数(避免计算差异)")
print(f" 2. 使用正确的位置切片: [input_length-1:input_length+predict_length-1]")
print(f" 3. 这考虑了Transformer中position i预测position i+1的特性")
# 直接修复原文件
with open('eval_model.py', 'w', encoding='utf-8') as f:
f.write(fixed_content)
print(f"✅ 同时直接修复了原文件eval_model.py")
def test_final_fix():
"""
测试最终修复版本
"""
print(f"\n🧪 测试最终修复版本")
print("="*60)
import subprocess
# 运行修复后的eval_model.py使用较少样本快速测试
cmd = [
'.venv/bin/python', 'eval_model.py',
'--model_path', 'out/experiment_1_4_0/pretrain_512.pth',
'--model_type', 'model_original',
'--num_samples', '5',
'--input_length', '100',
'--predict_length', '30'
]
print("运行命令:")
print(" ".join(cmd))
print("\n运行结果:")
try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
# 提取关键信息
output_lines = result.stdout.split('\n')
for line in output_lines:
if 'Loss:' in line or '平均Loss:' in line or '总体统计:' in line or '有效样本数:' in line:
print(line)
if result.returncode == 0:
print("\n✅ 修复后的eval_model.py运行成功")
else:
print(f"\n❌ 运行失败,错误码: {result.returncode}")
if result.stderr:
print("错误信息:")
print(result.stderr[:500])
except subprocess.TimeoutExpired:
print("❌ 运行超时")
except Exception as e:
print(f"❌ 运行出错: {e}")
if __name__ == "__main__":
demonstrate_correct_fix()
create_final_fixed_eval_model()
test_final_fix()