#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
测试实数版本的位置编码
"""

import torch
from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real
from model.LMConfig import LMConfig
from model.model import MiniMindLM
def test_pos_encoding_equivalence():
    """Check that the complex-valued and real-valued rotary position
    encodings produce numerically equivalent results.

    Both encodings are applied to the same random query/key tensors and
    the mean absolute difference is compared against a tolerance.

    Returns:
        bool: True when both the query and key differences are below
        tolerance, False otherwise.  (Previously this function only
        printed the verdict and returned None, so callers/CI could not
        detect a failure.)
    """
    print("测试位置编码等价性...")

    # Encoding parameters.
    dim = 64
    seq_len = 10

    # Precompute both variants of the rotary tables.
    pos_cis = precompute_pos_cis(dim=dim, end=seq_len)
    pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len)

    # Random queries and keys; head_dim equals the rotary dim here
    # (assumes apply_rotary_emb rotates over the full head dimension —
    # TODO confirm against model.model).
    batch_size = 2
    n_heads = 4
    head_dim = dim

    xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
    xk = torch.randn(batch_size, seq_len, n_heads, head_dim)

    # Apply both implementations to the same input tensors.
    xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis)
    xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real)

    # Mean absolute deviation between the two implementations.
    q_diff = torch.abs(xq_complex - xq_real).mean().item()
    k_diff = torch.abs(xk_complex - xk_real).mean().item()

    print(f"查询差异: {q_diff:.6f}")
    print(f"键差异: {k_diff:.6f}")

    # Accept tiny float discrepancies from the two formulations.
    tolerance = 1e-5
    passed = q_diff < tolerance and k_diff < tolerance
    if passed:
        print("✅ 测试通过: 复数版本和实数版本的位置编码在数值上等价")
    else:
        print("❌ 测试失败: 复数版本和实数版本的位置编码存在显著差异")
    return passed
|
||
|
||
def test_model_forward():
    """Smoke-test MiniMindLM construction and a single forward pass.

    Builds a deliberately tiny configuration, instantiates the model,
    and runs one no-grad forward pass on random token ids.

    Returns:
        bool: True when both initialization and the forward pass
        succeed, False otherwise.  (Previously this function only
        printed the verdict and returned None.)
    """
    print("\n测试模型前向传播...")

    # Small configuration to keep the smoke test fast.
    config = LMConfig(
        dim=128,
        n_layers=2,
        n_heads=4,
        n_kv_heads=4,  # ensure n_kv_heads is set and divides n_heads
        vocab_size=1000,
        max_seq_len=128,
        disable_db=True  # disable the database feature to avoid extra complexity
    )

    # Broad except is deliberate here: this is a top-level smoke test
    # that reports any construction failure instead of crashing.
    try:
        model = MiniMindLM(config)
        print(f"✅ 模型初始化成功")
    except Exception as e:
        print(f"❌ 模型初始化失败: {str(e)}")
        return False

    # Random token ids within the configured vocabulary.
    batch_size = 2
    seq_len = 10
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    # Single inference-only forward pass.
    try:
        with torch.no_grad():
            outputs = model(input_ids)
        print(f"✅ 模型前向传播成功")
        print(f"输出形状: {outputs.logits.shape}")
    except Exception as e:
        print(f"❌ 模型前向传播失败: {str(e)}")
        return False
    return True
|
||
|
||
if __name__ == "__main__":
    # Run the positional-encoding equivalence check first, then the
    # model forward-pass smoke test, in that order.
    for smoke_test in (test_pos_encoding_equivalence, test_model_forward):
        smoke_test()
|