Minimind/test_real_rope.py

98 lines
2.8 KiB
Python
Raw Normal View History

2025-05-14 00:01:40 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
测试实数版本的位置编码
"""
import torch
from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real
from model.LMConfig import LMConfig
from model.model import MiniMindLM
def test_pos_encoding_equivalence():
"""测试复数版本和实数版本的位置编码是否等价"""
print("测试位置编码等价性...")
# 参数设置
dim = 64
seq_len = 10
# 生成复数版本的位置编码
pos_cis = precompute_pos_cis(dim=dim, end=seq_len)
# 生成实数版本的位置编码
pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len)
# 创建随机查询和键
batch_size = 2
n_heads = 4
head_dim = dim
xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
xk = torch.randn(batch_size, seq_len, n_heads, head_dim)
# 应用复数版本的旋转位置编码
xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis)
# 应用实数版本的旋转位置编码
xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real)
# 计算差异
q_diff = torch.abs(xq_complex - xq_real).mean().item()
k_diff = torch.abs(xk_complex - xk_real).mean().item()
print(f"查询差异: {q_diff:.6f}")
print(f"键差异: {k_diff:.6f}")
# 检查差异是否在可接受范围内
tolerance = 1e-5
if q_diff < tolerance and k_diff < tolerance:
print("✅ 测试通过: 复数版本和实数版本的位置编码在数值上等价")
else:
print("❌ 测试失败: 复数版本和实数版本的位置编码存在显著差异")
def test_model_forward():
"""测试模型前向传播"""
print("\n测试模型前向传播...")
# 创建模型配置
config = LMConfig(
dim=128,
n_layers=2,
n_heads=4,
n_kv_heads=4, # 确保n_kv_heads被设置且n_heads能被n_kv_heads整除
vocab_size=1000,
max_seq_len=128,
disable_db=True # 禁用数据库功能,避免额外的复杂性
)
# 创建模型
try:
model = MiniMindLM(config)
print(f"✅ 模型初始化成功")
except Exception as e:
print(f"❌ 模型初始化失败: {str(e)}")
return
# 创建输入
batch_size = 2
seq_len = 10
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
# 前向传播
try:
with torch.no_grad():
outputs = model(input_ids)
print(f"✅ 模型前向传播成功")
print(f"输出形状: {outputs.logits.shape}")
except Exception as e:
print(f"❌ 模型前向传播失败: {str(e)}")
if __name__ == "__main__":
# 测试位置编码等价性
test_pos_encoding_equivalence()
# 测试模型前向传播
test_model_forward()