Minimind/test_real_rope.py

98 lines
2.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
测试实数版本的位置编码
"""
import torch
from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real
from model.LMConfig import LMConfig
from model.model import MiniMindLM
def test_pos_encoding_equivalence():
    """Check that the complex and real rotary position encodings agree.

    Builds both position tables (``precompute_pos_cis`` and
    ``precompute_pos_cis_real``) for the same ``dim`` and ``end``, applies
    each to the same random query/key tensors, and compares the outputs by
    mean absolute difference against a fixed tolerance. The differences and
    the verdict are printed either way.

    Raises:
        AssertionError: If the query or key mean absolute difference is not
            below the tolerance.
    """
    print("测试位置编码等价性...")

    # Parameter setup.
    dim = 64
    seq_len = 10

    # Position table for the complex-valued implementation.
    pos_cis = precompute_pos_cis(dim=dim, end=seq_len)
    # Position table for the real-valued implementation.
    pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len)

    # Random queries and keys, shared by both code paths so that any
    # difference in the outputs comes from the encoding itself.
    batch_size = 2
    n_heads = 4
    head_dim = dim
    xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
    xk = torch.randn(batch_size, seq_len, n_heads, head_dim)

    # Apply the complex-valued rotary embedding.
    xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis)
    # Apply the real-valued rotary embedding.
    xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real)

    # Mean absolute difference between the two implementations.
    q_diff = torch.abs(xq_complex - xq_real).mean().item()
    k_diff = torch.abs(xk_complex - xk_real).mean().item()
    print(f"查询差异: {q_diff:.6f}")
    print(f"键差异: {k_diff:.6f}")

    # Verdict: both differences must be below the tolerance.
    tolerance = 1e-5
    if q_diff < tolerance and k_diff < tolerance:
        print("✅ 测试通过: 复数版本和实数版本的位置编码在数值上等价")
        return

    print("❌ 测试失败: 复数版本和实数版本的位置编码存在显著差异")
    # Raise explicitly (not `assert`, which is stripped under -O) so that
    # pytest and the process exit status reflect the failure, not just stdout.
    raise AssertionError(
        "complex and real rotary embeddings differ: "
        f"q_diff={q_diff:.3e}, k_diff={k_diff:.3e}, tolerance={tolerance:.0e}"
    )
def test_model_forward():
    """Smoke-test building a small ``MiniMindLM`` and running a forward pass.

    Constructs a reduced ``LMConfig``, instantiates the model, feeds it a
    random batch of token ids under ``torch.no_grad()``, and prints the shape
    of ``outputs.logits``. Progress and failure messages are printed.

    Raises:
        Exception: Whatever model construction or the forward pass raised.
            The failure message is printed first and the original exception
            is then re-raised, so the test cannot pass silently.
    """
    print("\n测试模型前向传播...")

    # Model configuration.
    config = LMConfig(
        dim=128,
        n_layers=2,
        n_heads=4,
        n_kv_heads=4,  # set explicitly so n_heads is divisible by n_kv_heads
        vocab_size=1000,
        max_seq_len=128,
        disable_db=True,  # disable the database feature to avoid extra complexity
    )

    # Build the model.
    try:
        model = MiniMindLM(config)
        print("✅ 模型初始化成功")
    except Exception as e:
        print(f"❌ 模型初始化失败: {e}")
        # Re-raise: without this a broken model counts as a passing test.
        raise

    # Random token ids of shape (batch_size, seq_len) within the vocabulary.
    batch_size = 2
    seq_len = 10
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    # Forward pass.
    try:
        with torch.no_grad():
            outputs = model(input_ids)
        print("✅ 模型前向传播成功")
        print(f"输出形状: {outputs.logits.shape}")
    except Exception as e:
        print(f"❌ 模型前向传播失败: {e}")
        # Keep the original traceback for whoever debugs the failure.
        raise
if __name__ == "__main__":
    # Run the checks in order: position-encoding equivalence first,
    # then the model forward-pass smoke test.
    for run_check in (test_pos_encoding_equivalence, test_model_forward):
        run_check()