#!/usr/bin/env python # -*- coding: utf-8 -*- """ 测试实数版本的位置编码 """ import torch from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real from model.LMConfig import LMConfig from model.model import MiniMindLM def test_pos_encoding_equivalence(): """测试复数版本和实数版本的位置编码是否等价""" print("测试位置编码等价性...") # 参数设置 dim = 64 seq_len = 10 # 生成复数版本的位置编码 pos_cis = precompute_pos_cis(dim=dim, end=seq_len) # 生成实数版本的位置编码 pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len) # 创建随机查询和键 batch_size = 2 n_heads = 4 head_dim = dim xq = torch.randn(batch_size, seq_len, n_heads, head_dim) xk = torch.randn(batch_size, seq_len, n_heads, head_dim) # 应用复数版本的旋转位置编码 xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis) # 应用实数版本的旋转位置编码 xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real) # 计算差异 q_diff = torch.abs(xq_complex - xq_real).mean().item() k_diff = torch.abs(xk_complex - xk_real).mean().item() print(f"查询差异: {q_diff:.6f}") print(f"键差异: {k_diff:.6f}") # 检查差异是否在可接受范围内 tolerance = 1e-5 if q_diff < tolerance and k_diff < tolerance: print("✅ 测试通过: 复数版本和实数版本的位置编码在数值上等价") else: print("❌ 测试失败: 复数版本和实数版本的位置编码存在显著差异") def test_model_forward(): """测试模型前向传播""" print("\n测试模型前向传播...") # 创建模型配置 config = LMConfig( dim=128, n_layers=2, n_heads=4, n_kv_heads=4, # 确保n_kv_heads被设置,且n_heads能被n_kv_heads整除 vocab_size=1000, max_seq_len=128, disable_db=True # 禁用数据库功能,避免额外的复杂性 ) # 创建模型 try: model = MiniMindLM(config) print(f"✅ 模型初始化成功") except Exception as e: print(f"❌ 模型初始化失败: {str(e)}") return # 创建输入 batch_size = 2 seq_len = 10 input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len)) # 前向传播 try: with torch.no_grad(): outputs = model(input_ids) print(f"✅ 模型前向传播成功") print(f"输出形状: {outputs.logits.shape}") except Exception as e: print(f"❌ 模型前向传播失败: {str(e)}") if __name__ == "__main__": # 测试位置编码等价性 test_pos_encoding_equivalence() # 测试模型前向传播 test_model_forward()