位置编码从复数变为两次实数计算

This commit is contained in:
Jax922 2025-05-13 10:50:10 +08:00
parent 7cf4228401
commit 7ba51b8571
3 changed files with 201 additions and 13 deletions

@ -31,7 +31,7 @@ class RMSNorm(torch.nn.Module):
def forward(self, x):
return self.weight * self._norm(x.float()).type_as(x)
# precompute_pos_cis 函数用于预计算位置编码
# precompute_pos_cis 函数用于预计算位置编码(复数版本)
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
@ -39,7 +39,7 @@ def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return pos_cis
# apply_rotary_emb 函数用于应用旋转位置编码
# apply_rotary_emb 函数用于应用旋转位置编码(复数版本)
def apply_rotary_emb(xq, xk, pos_cis):
def unite_shape(pos_cis, x):
ndim = x.ndim
@ -55,6 +55,92 @@ def apply_rotary_emb(xq, xk, pos_cis):
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
# precompute_pos_cis_real 函数用于预计算位置编码(实数版本)。
def precompute_pos_cis_real(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
"""使用实数张量实现位置编码,避免使用复数张量
这个函数与precompute_pos_cis完全等价但使用实数张量而非复数张量
原始函数生成形状为[seq_len, dim//2]的复数张量其中实部全为1虚部为旋转角度
这个函数生成形状为[seq_len, dim]的实数张量其中偶数索引是cos(角度)奇数索引是sin(角度)
"""
# 确保dim是偶数
if dim % 2 != 0:
raise ValueError(f"维度必须是偶数,但得到了 {dim}")
# 复制原始函数的频率计算逻辑
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
# 计算cos和sin值
# 在复数版本中pos_cis = torch.polar(torch.ones_like(freqs), freqs)
# 等价于 cos(freqs) + i*sin(freqs)
cos = torch.cos(freqs)
sin = torch.sin(freqs)
# 创建实数张量交错排列cos和sin
pos_emb = torch.zeros((end, dim), device=freqs.device)
pos_emb[:, 0::2] = cos # 偶数索引放cos
pos_emb[:, 1::2] = sin # 奇数索引放sin
return pos_emb
# apply_rotary_emb_real 函数用于应用旋转位置编码(实数版本)。
def apply_rotary_emb_real(xq, xk, pos_emb):
"""使用实数张量实现旋转位置编码,避免使用复数张量
这个函数与apply_rotary_emb完全等价但使用实数张量而非复数张量
原始函数将输入张量转换为复数形式与位置编码相乘然后再转回实数形式
这个函数直接使用实数运算实现相同的旋转操作
"""
# 获取形状信息
bsz, seq_len, n_heads, head_dim = xq.shape
# 确保pos_emb形状正确
assert pos_emb.shape[0] >= seq_len, f"位置编码长度 {pos_emb.shape[0]} 小于序列长度 {seq_len}"
assert pos_emb.shape[1] == head_dim, f"位置编码维度 {pos_emb.shape[1]} 与头维度 {head_dim} 不匹配"
# 截取需要的位置编码长度
pos_emb = pos_emb[:seq_len]
# 将pos_emb调整为广播形状 [1, seq_len, 1, head_dim]
pos_emb = pos_emb.unsqueeze(0).unsqueeze(2)
# 将head_dim分成两半
half_head_dim = head_dim // 2
# 提取cos和sin值偶数索引是cos奇数索引是sin
cos = pos_emb[..., 0::2]
sin = pos_emb[..., 1::2]
# 将xq和xk重新排列以便进行旋转操作
# 原始复数版本中xq和xk被重塑为复数张量其中实部和虚部交错排列
# 在实数版本中,我们需要将偶数索引和奇数索引分开处理
# 分离偶数和奇数索引
xq_even = xq[..., 0::2] # 偶数索引,对应复数的实部
xq_odd = xq[..., 1::2] # 奇数索引,对应复数的虚部
xk_even = xk[..., 0::2]
xk_odd = xk[..., 1::2]
# 应用旋转(等价于复数乘法)
# (a + bi)(cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
# 其中a是偶数索引b是奇数索引
xq_out_even = xq_even * cos - xq_odd * sin # 新的偶数索引(实部)
xq_out_odd = xq_even * sin + xq_odd * cos # 新的奇数索引(虚部)
xk_out_even = xk_even * cos - xk_odd * sin
xk_out_odd = xk_even * sin + xk_odd * cos
# 重新组合偶数和奇数索引
xq_out = torch.zeros_like(xq)
xk_out = torch.zeros_like(xk)
xq_out[..., 0::2] = xq_out_even
xq_out[..., 1::2] = xq_out_odd
xk_out[..., 0::2] = xk_out_even
xk_out[..., 1::2] = xk_out_odd
return xq_out.type_as(xq), xk_out.type_as(xk)
# repeat_kv 函数用于重复键值对。
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
@ -102,8 +188,8 @@ class Attention(nn.Module):
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xk重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xv重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
# 应用旋转位置编码
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
# 应用旋转位置编码(使用实数版本)
xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)
# kv_cache实现
if past_key_value is not None:
xk = torch.cat([past_key_value[0], xk], dim=1)
@ -548,8 +634,9 @@ class MiniMindLM(PreTrainedModel):
self.downsample_q_specific = nn.Sequential(
nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
)
self.register_buffer("pos_cis",
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
# 使用实数版本的位置编码,避免复数张量可能导致的段错误
self.register_buffer("pos_cis_real",
precompute_pos_cis_real(dim=params.dim // params.n_heads, theta=params.rope_theta),
persistent=False)
self.params = params
@ -562,7 +649,7 @@ class MiniMindLM(PreTrainedModel):
past_key_values = past_key_values or [None] * len(self.layers)
start_pos = args.get('start_pos', 0)
h = self.dropout(self.tok_embeddings(input_ids))
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)]
past_kvs = []
h_list = []
@ -579,7 +666,7 @@ class MiniMindLM(PreTrainedModel):
db_value = self.extract_db.get_data(index)
h, past_kv = layer(
h, db_value, pos_cis,
h, db_value, pos_cis_real,
past_key_value=past_key_values[l],
use_cache=use_cache
)

97
test_real_rope.py Normal file

@ -0,0 +1,97 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
测试实数版本的位置编码
"""
import torch
from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real
from model.LMConfig import LMConfig
from model.model import MiniMindLM
def test_pos_encoding_equivalence():
"""测试复数版本和实数版本的位置编码是否等价"""
print("测试位置编码等价性...")
# 参数设置
dim = 64
seq_len = 10
# 生成复数版本的位置编码
pos_cis = precompute_pos_cis(dim=dim, end=seq_len)
# 生成实数版本的位置编码
pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len)
# 创建随机查询和键
batch_size = 2
n_heads = 4
head_dim = dim
xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
xk = torch.randn(batch_size, seq_len, n_heads, head_dim)
# 应用复数版本的旋转位置编码
xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis)
# 应用实数版本的旋转位置编码
xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real)
# 计算差异
q_diff = torch.abs(xq_complex - xq_real).mean().item()
k_diff = torch.abs(xk_complex - xk_real).mean().item()
print(f"查询差异: {q_diff:.6f}")
print(f"键差异: {k_diff:.6f}")
# 检查差异是否在可接受范围内
tolerance = 1e-5
if q_diff < tolerance and k_diff < tolerance:
print("✅ 测试通过: 复数版本和实数版本的位置编码在数值上等价")
else:
print("❌ 测试失败: 复数版本和实数版本的位置编码存在显著差异")
def test_model_forward():
"""测试模型前向传播"""
print("\n测试模型前向传播...")
# 创建模型配置
config = LMConfig(
dim=128,
n_layers=2,
n_heads=4,
n_kv_heads=4, # 确保n_kv_heads被设置且n_heads能被n_kv_heads整除
vocab_size=1000,
max_seq_len=128,
disable_db=True # 禁用数据库功能,避免额外的复杂性
)
# 创建模型
try:
model = MiniMindLM(config)
print(f"✅ 模型初始化成功")
except Exception as e:
print(f"❌ 模型初始化失败: {str(e)}")
return
# 创建输入
batch_size = 2
seq_len = 10
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
# 前向传播
try:
with torch.no_grad():
outputs = model(input_ids)
print(f"✅ 模型前向传播成功")
print(f"输出形状: {outputs.logits.shape}")
except Exception as e:
print(f"❌ 模型前向传播失败: {str(e)}")
if __name__ == "__main__":
# 测试位置编码等价性
test_pos_encoding_equivalence()
# 测试模型前向传播
test_model_forward()

@ -331,11 +331,15 @@ def main():
# 将accelerator传递给init_model函数中的Logger调用
Logger(f'模型初始化完成', accelerator)
# 处理pos_cis复数张量问题
# 方法1将pos_cis转换为实数张量两个实数张量表示实部和虚部
# 这里我们采用方法2告诉accelerate忽略pos_cis
# 在DeepSpeed模式下我们需要设置DeepSpeed的参数
if hasattr(model, "pos_cis"):
# 处理位置编码张量问题
# 我们已经将复数版本的pos_cis替换为实数版本的pos_cis_real
# 但为了安全起见,我们仍然将其设置为不参与分布式训练
if hasattr(model, "pos_cis_real"):
Logger(f'检测到pos_cis_real实数张量将其设置为不参与分布式训练', accelerator)
# 设置模型的_ddp_params_and_buffers_to_ignore属性
model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"}
# 兼容旧版本检查是否仍有pos_cis
elif hasattr(model, "pos_cis"):
Logger(f'检测到pos_cis复数张量将其设置为不参与分布式训练', accelerator)
# 设置模型的_ddp_params_and_buffers_to_ignore属性
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}