#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
测试实数版本的位置编码
"""

import torch
from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real
from model.LMConfig import LMConfig
from model.model import MiniMindLM

def test_pos_encoding_equivalence():
    """测试复数版本和实数版本的位置编码是否等价"""
    print("测试位置编码等价性...")

    # 参数设置
    dim = 64
    seq_len = 10

    # 生成复数版本的位置编码
    pos_cis = precompute_pos_cis(dim=dim, end=seq_len)

    # 生成实数版本的位置编码
    pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len)

    # 创建随机查询和键
    batch_size = 2
    n_heads = 4
    head_dim = dim

    xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
    xk = torch.randn(batch_size, seq_len, n_heads, head_dim)

    # 应用复数版本的旋转位置编码
    xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis)

    # 应用实数版本的旋转位置编码
    xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real)

    # 计算差异
    q_diff = torch.abs(xq_complex - xq_real).mean().item()
    k_diff = torch.abs(xk_complex - xk_real).mean().item()

    print(f"查询差异: {q_diff:.6f}")
    print(f"键差异: {k_diff:.6f}")

    # 检查差异是否在可接受范围内
    tolerance = 1e-5
    if q_diff < tolerance and k_diff < tolerance:
        print("✅ 测试通过: 复数版本和实数版本的位置编码在数值上等价")
    else:
        print("❌ 测试失败: 复数版本和实数版本的位置编码存在显著差异")

def test_model_forward():
    """测试模型前向传播"""
    print("\n测试模型前向传播...")

    # 创建模型配置
    config = LMConfig(
        dim=128,
        n_layers=2,
        n_heads=4,
        n_kv_heads=4,  # 确保n_kv_heads被设置,且n_heads能被n_kv_heads整除
        vocab_size=1000,
        max_seq_len=128,
        disable_db=True  # 禁用数据库功能,避免额外的复杂性
    )

    # 创建模型
    try:
        model = MiniMindLM(config)
        print(f"✅ 模型初始化成功")
    except Exception as e:
        print(f"❌ 模型初始化失败: {str(e)}")
        return

    # 创建输入
    batch_size = 2
    seq_len = 10
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    # 前向传播
    try:
        with torch.no_grad():
            outputs = model(input_ids)
        print(f"✅ 模型前向传播成功")
        print(f"输出形状: {outputs.logits.shape}")
    except Exception as e:
        print(f"❌ 模型前向传播失败: {str(e)}")

if __name__ == "__main__":
    # 测试位置编码等价性
    test_pos_encoding_equivalence()

    # 测试模型前向传播
    test_model_forward()