218 lines
8.0 KiB
Python
218 lines
8.0 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
最终修复eval_model.py中的位置索引错误
|
||
"""
|
||
|
||
import json
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from transformers import AutoTokenizer
|
||
from model.LMConfig import LMConfig
|
||
from model.model_original import MiniMindLM
|
||
|
||
|
||
def demonstrate_correct_fix(*, model_path='out/experiment_1_4_0/pretrain_512.pth',
                            data_path='dataset/stable/eval_data_from_train.json',
                            device=None, num_samples=10,
                            input_length=100, predict_length=30):
    """Compare the old and the corrected logits slicing on real evaluation samples.

    Loads a MiniMindLM checkpoint and, for each of the first ``num_samples``
    lines of ``data_path`` (JSON objects with a ``'text'`` field) whose
    tokenization has at least ``input_length + predict_length`` tokens, runs a
    single forward pass over the first ``input_length + predict_length`` tokens.
    The mean cross-entropy against the target tokens
    ``tokens[input_length:input_length + predict_length]`` is then computed
    from two different slices of the logits:

    * wrong (what eval_model.py originally did): the last ``predict_length``
      positions, so each target token is scored by the logits of its own
      position;
    * correct: positions ``input_length - 1`` .. ``input_length + predict_length - 2``,
      because the logits at position i predict the token at position i + 1.

    Prints a per-sample comparison table followed by the averages.

    Args:
        model_path: checkpoint (state dict) to load.
        data_path: JSON-lines evaluation file.
        device: torch device string; defaults to ``'cuda'`` when CUDA is
            available, otherwise ``'cpu'``.
        num_samples: number of leading lines of ``data_path`` to inspect
            (blank or too-short lines are skipped but still counted).
        input_length: number of context tokens fed before the targets.
        predict_length: number of target tokens scored after the context.

    Raises:
        ValueError: if ``input_length`` or ``predict_length`` is smaller than 1.
    """
    if input_length < 1 or predict_length < 1:
        # input_length - 1 must be a valid start index for the correct slice.
        raise ValueError("input_length and predict_length must both be >= 1")
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    min_tokens = input_length + predict_length

    print("🔧 演示正确的修复方法")
    print("=" * 60)

    # Load the model and tokenizer.
    config = LMConfig(
        dim=512, n_layers=8, n_heads=32, vocab_size=6400, max_seq_len=512,
        dropout=0.0, norm_eps=1e-5, rope_theta=1e6, use_moe=False
    )
    model = MiniMindLM(config)
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')

    state_dict = torch.load(model_path, map_location=device)
    # strict=False tolerates key mismatches; report them instead of hiding them.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"⚠️ 权重未完全匹配: 缺失 {len(incompatible.missing_keys)} 个键, "
              f"多余 {len(incompatible.unexpected_keys)} 个键")
    model.to(device)
    model.eval()

    # Evaluate several samples to compare both slicing methods.
    total_loss_wrong = 0.0
    total_loss_correct = 0.0
    valid_samples = 0

    print("测试样本的loss对比:")
    print("样本 | 错误方法 | 正确方法 | 差异")
    print("-" * 45)

    with open(data_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            if i >= num_samples:
                break
            line = line.strip()
            if not line:  # e.g. a trailing empty line in the JSONL file
                continue

            text = json.loads(line)['text']
            tokens = tokenizer.encode(text, add_special_tokens=False)
            if len(tokens) < min_tokens:
                continue

            with torch.no_grad():
                full_input = torch.tensor([tokens[:min_tokens]], dtype=torch.long, device=device)
                target_labels = torch.tensor(tokens[input_length:min_tokens],
                                             dtype=torch.long, device=device)

                # Full logits for every position of the sequence.
                outputs = model(full_input)
                logits = outputs.logits

                # Wrong: last predict_length positions (one position too late).
                wrong_slice = logits[0, -predict_length:, :]
                loss_wrong = F.cross_entropy(wrong_slice, target_labels, reduction='mean').item()

                # Correct: logits at i predict token i+1, so shift the window left by one.
                correct_slice = logits[0, input_length - 1:min_tokens - 1, :]
                loss_correct = F.cross_entropy(correct_slice, target_labels, reduction='mean').item()

            total_loss_wrong += loss_wrong
            total_loss_correct += loss_correct
            valid_samples += 1

            diff = loss_wrong - loss_correct
            print(f"{i+1:2} | {loss_wrong:8.4f} | {loss_correct:8.4f} | {diff:+6.4f}")

    print("-" * 45)
    if valid_samples == 0:
        print(f"⚠️ 没有长度不少于 {min_tokens} 个token的有效样本, 无法计算平均loss")
        return

    avg_loss_wrong = total_loss_wrong / valid_samples
    avg_loss_correct = total_loss_correct / valid_samples
    improvement = avg_loss_wrong - avg_loss_correct
    relative = improvement / avg_loss_wrong * 100 if avg_loss_wrong else 0.0

    print(f"平均 | {avg_loss_wrong:8.4f} | {avg_loss_correct:8.4f} | {improvement:+6.4f}")

    print("\n📊 修复效果:")
    print(f" 错误方法平均loss: {avg_loss_wrong:.4f}")
    print(f" 正确方法平均loss: {avg_loss_correct:.4f}")
    print(f" 改进幅度: {improvement:.4f} ({relative:.1f}%)")
    print(" 正确方法更接近训练时的教师强制loss (~2.4)")
|
||
|
||
|
||
def _patch_loss_calculation(content):
    """Apply the position-index fix to the source text of eval_model.py.

    Two line-level substitutions are made. Each keeps the indentation of the
    line it replaces, so the patch works however deeply the loss code is
    nested in eval_model.py:

    1. ``outputs = model(loss_input_ids, logits_to_keep=predict_length)`` is
       replaced by a call without ``logits_to_keep``, so logits for every
       position are returned.
    2. ``shift_logits = logits[0, -predict_length:, :].contiguous()`` is
       replaced by the slice ``[input_length-1:input_length+predict_length-1]``,
       preceded by comments explaining the one-position shift.

    Args:
        content: full text of the file to patch.

    Returns:
        ``(patched_content, n_call, n_slice)``, where ``n_call`` and ``n_slice``
        count how many lines of each kind were replaced.
    """
    import re

    call_pattern = re.compile(
        r'^([ \t]*)outputs = model\(loss_input_ids, logits_to_keep=predict_length\)[ \t]*$',
        re.MULTILINE,
    )
    slice_pattern = re.compile(
        r'^([ \t]*)shift_logits = logits\[0, -predict_length:, :\]\.contiguous\(\)[ \t]*$',
        re.MULTILINE,
    )
    fixed_slice_lines = (
        '# 在Transformer中,position i的logits预测position i+1的token',
        '# 要预测position input_length到input_length+predict_length-1的token',
        '# 需要使用position input_length-1到input_length+predict_length-2的logits',
        'shift_logits = logits[0, input_length-1:input_length+predict_length-1, :].contiguous()',
    )

    # Callables rather than template strings: the captured indent is reused
    # verbatim and the brackets in the new code need no escaping.
    content, n_call = call_pattern.subn(
        lambda m: m.group(1) + 'outputs = model(loss_input_ids)  # 移除logits_to_keep参数',
        content,
    )
    content, n_slice = slice_pattern.subn(
        lambda m: '\n'.join(m.group(1) + text for text in fixed_slice_lines),
        content,
    )
    return content, n_call, n_slice


def create_final_fixed_eval_model(src_path='eval_model.py',
                                  fixed_path='eval_model_final_fixed.py'):
    """Patch the loss computation of eval_model.py and write the result to disk.

    Reads ``src_path`` and applies the position-index fix: full logits, plus
    the ``[input_length-1:input_length+predict_length-1]`` slice. It then
    writes the patched text to ``fixed_path``, copies the unpatched file to
    ``src_path + '.bak'``, and overwrites ``src_path`` in place.

    Nothing is written when the code to patch is not found. This includes a
    file that has already been patched, so a failed match can no longer
    rewrite the source unchanged while reporting success.

    Args:
        src_path: script to patch in place.
        fixed_path: where to write a separate copy of the patched script.

    Returns:
        True if the files were written, False if nothing was changed.
    """
    import shutil

    print("\n🔧 创建最终修复版的eval_model.py")
    print("=" * 60)

    with open(src_path, 'r', encoding='utf-8') as f:
        content = f.read()

    fixed_content, n_call, n_slice = _patch_loss_calculation(content)

    # Both substitutions must apply together. The new slice reads positions
    # that exist only when full logits are returned, and full logits alone
    # would keep the off-by-one slice.
    if n_call == 0 or n_slice == 0:
        already_fixed = ('input_length-1:input_length+predict_length-1' in content
                         and 'logits_to_keep=predict_length' not in content)
        if already_fixed:
            print(f"ℹ️ {src_path} 已经包含修复, 未做任何修改")
        else:
            print(f"❌ 未在 {src_path} 中找到需要修复的代码"
                  f"(logits_to_keep调用: {n_call}, 切片: {n_slice}), 未写入任何文件")
        return False

    with open(fixed_path, 'w', encoding='utf-8') as f:
        f.write(fixed_content)

    print(f"✅ 创建了最终修复版本:{fixed_path}")
    print("主要修复:")
    print(" 1. 移除 logits_to_keep 参数(避免计算差异)")
    print(" 2. 使用正确的位置切片: [input_length-1:input_length+predict_length-1]")
    print(" 3. 这考虑了Transformer中position i预测position i+1的特性")

    # Keep a byte-exact copy of the unpatched source before overwriting it.
    backup_path = src_path + '.bak'
    shutil.copy2(src_path, backup_path)
    with open(src_path, 'w', encoding='utf-8') as f:
        f.write(fixed_content)

    print(f"✅ 同时直接修复了原文件:{src_path}(原文件已备份为 {backup_path})")
    return True
|
||
|
||
|
||
def test_final_fix(*, python_executable=None, script='eval_model.py', timeout=120):
    """Run the patched evaluation script on a few samples and echo its key lines.

    Launches ``script`` in a subprocess with a fixed set of CLI arguments
    (checkpoint path, ``model_original``, 5 samples, input length 100,
    predict length 30). It prints only the stdout lines that contain
    ``'Loss:'``, ``'总体统计:'`` or ``'有效样本数:'``, then reports success or
    failure. Every outcome is reported by printing; nothing is returned.

    Args:
        python_executable: interpreter to run ``script`` with. Defaults to
            ``.venv/bin/python`` when that file exists, otherwise the
            interpreter running this script.
        script: path of the evaluation script to run.
        timeout: seconds to wait for the subprocess before giving up.
    """
    import os
    import subprocess
    import sys

    print("\n🧪 测试最终修复版本")
    print("=" * 60)

    if python_executable is None:
        venv_python = '.venv/bin/python'
        python_executable = venv_python if os.path.exists(venv_python) else sys.executable

    # Run the patched script on only a few samples for a quick check.
    cmd = [
        python_executable, script,
        '--model_path', 'out/experiment_1_4_0/pretrain_512.pth',
        '--model_type', 'model_original',
        '--num_samples', '5',
        '--input_length', '100',
        '--predict_length', '30',
    ]

    print("运行命令:")
    print(" ".join(cmd))
    print("\n运行结果:")

    # The child prints Chinese text. Force UTF-8 on both ends so a locale codec
    # mismatch cannot raise a decode error and lose the whole output.
    env = dict(os.environ, PYTHONIOENCODING='utf-8')
    try:
        result = subprocess.run(cmd, capture_output=True, encoding='utf-8',
                                errors='replace', env=env, timeout=timeout)
    except subprocess.TimeoutExpired:
        print("❌ 运行超时")
        return
    except (OSError, ValueError) as e:  # e.g. interpreter not found
        print(f"❌ 运行出错: {e}")
        return

    # 'Loss:' also matches '平均Loss:'.
    key_markers = ('Loss:', '总体统计:', '有效样本数:')
    for line in result.stdout.splitlines():
        if any(marker in line for marker in key_markers):
            print(line)

    if result.returncode == 0:
        print("\n✅ 修复后的eval_model.py运行成功!")
        return

    print(f"\n❌ 运行失败,错误码: {result.returncode}")
    if result.stderr:
        print("错误信息:")
        # A traceback ends with the actual exception, so show the tail.
        print(result.stderr[-500:])
|
||
|
||
|
||
if __name__ == "__main__":
    # Three stages, in order: show the loss difference between the two
    # slicing methods, patch eval_model.py, then smoke-test the patched script.
    for stage in (demonstrate_correct_fix, create_final_fixed_eval_model, test_final_fix):
        stage()