#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
测试实数版本的位置编码
"""

import torch
from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real
from model.LMConfig import LMConfig
from model.model import MiniMindLM
def test_pos_encoding_equivalence():
    """Check that the complex-valued and real-valued rotary position
    encodings produce numerically equivalent rotated queries and keys.

    Prints the mean absolute difference for queries and keys and a
    pass/fail verdict against a fixed tolerance.
    """
    print("测试位置编码等价性...")

    # Encoding parameters for this check.
    dim = 64
    seq_len = 10

    # Precompute both flavours of the rotary table over the same range.
    pos_cis = precompute_pos_cis(dim=dim, end=seq_len)
    pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len)

    # Random queries/keys shaped (batch, seq, heads, head_dim).
    # NOTE(review): head_dim is set to the full `dim` here — presumably
    # intentional for this standalone check; confirm against the model.
    batch_size, n_heads = 2, 4
    head_dim = dim
    xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
    xk = torch.randn(batch_size, seq_len, n_heads, head_dim)

    # Rotate the same inputs with each implementation.
    xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis)
    xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real)

    # Mean absolute deviation between the two implementations.
    q_diff = torch.abs(xq_complex - xq_real).mean().item()
    k_diff = torch.abs(xk_complex - xk_real).mean().item()

    print(f"查询差异: {q_diff:.6f}")
    print(f"键差异: {k_diff:.6f}")

    # Both deviations must stay below the tolerance to count as equivalent.
    tolerance = 1e-5
    if q_diff < tolerance and k_diff < tolerance:
        print("✅ 测试通过: 复数版本和实数版本的位置编码在数值上等价")
    else:
        print("❌ 测试失败: 复数版本和实数版本的位置编码存在显著差异")
|
def test_model_forward():
    """Smoke-test MiniMindLM: build a small model and run one forward pass.

    Prints a status line for each stage; returns early if the model
    cannot be constructed.
    """
    print("\n测试模型前向传播...")

    # Minimal configuration to keep the test fast.
    config = LMConfig(
        dim=128,
        n_layers=2,
        n_heads=4,
        n_kv_heads=4,  # must be set, and n_heads must be divisible by it
        vocab_size=1000,
        max_seq_len=128,
        disable_db=True,  # turn off the database feature to avoid extra complexity
    )

    # Guard clause: bail out if construction fails.
    try:
        model = MiniMindLM(config)
    except Exception as err:
        print(f"❌ 模型初始化失败: {str(err)}")
        return
    else:
        print(f"✅ 模型初始化成功")

    # Random token ids in the configured vocabulary range.
    batch_size, seq_len = 2, 10
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    # Forward pass without gradient tracking.
    try:
        with torch.no_grad():
            outputs = model(input_ids)
    except Exception as err:
        print(f"❌ 模型前向传播失败: {str(err)}")
    else:
        print(f"✅ 模型前向传播成功")
        print(f"输出形状: {outputs.logits.shape}")
|
if __name__ == "__main__":
    # Verify the two rotary-encoding implementations agree numerically.
    test_pos_encoding_equivalence()

    # Smoke-test one forward pass through the model.
    test_model_forward()