位置编码从复数变为两次实数计算
This commit is contained in:
parent
7cf4228401
commit
7ba51b8571
103
model/model.py
103
model/model.py
@ -31,7 +31,7 @@ class RMSNorm(torch.nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.weight * self._norm(x.float()).type_as(x)
|
return self.weight * self._norm(x.float()).type_as(x)
|
||||||
|
|
||||||
# precompute_pos_cis 函数用于预计算位置编码。
|
# precompute_pos_cis 函数用于预计算位置编码(复数版本)。
|
||||||
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
|
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
|
||||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||||
t = torch.arange(end, device=freqs.device) # type: ignore
|
t = torch.arange(end, device=freqs.device) # type: ignore
|
||||||
@ -39,7 +39,7 @@ def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
|
|||||||
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||||
return pos_cis
|
return pos_cis
|
||||||
|
|
||||||
# apply_rotary_emb 函数用于应用旋转位置编码。
|
# apply_rotary_emb 函数用于应用旋转位置编码(复数版本)。
|
||||||
def apply_rotary_emb(xq, xk, pos_cis):
|
def apply_rotary_emb(xq, xk, pos_cis):
|
||||||
def unite_shape(pos_cis, x):
|
def unite_shape(pos_cis, x):
|
||||||
ndim = x.ndim
|
ndim = x.ndim
|
||||||
@ -55,6 +55,92 @@ def apply_rotary_emb(xq, xk, pos_cis):
|
|||||||
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
|
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
|
||||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||||
|
|
||||||
|
# precompute_pos_cis_real 函数用于预计算位置编码(实数版本)。
|
||||||
|
def precompute_pos_cis_real(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
|
||||||
|
"""使用实数张量实现位置编码,避免使用复数张量
|
||||||
|
|
||||||
|
这个函数与precompute_pos_cis完全等价,但使用实数张量而非复数张量。
|
||||||
|
原始函数生成形状为[seq_len, dim//2]的复数张量,其中实部全为1,虚部为旋转角度。
|
||||||
|
这个函数生成形状为[seq_len, dim]的实数张量,其中偶数索引是cos(角度),奇数索引是sin(角度)。
|
||||||
|
"""
|
||||||
|
# 确保dim是偶数
|
||||||
|
if dim % 2 != 0:
|
||||||
|
raise ValueError(f"维度必须是偶数,但得到了 {dim}")
|
||||||
|
|
||||||
|
# 复制原始函数的频率计算逻辑
|
||||||
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||||
|
t = torch.arange(end, device=freqs.device)
|
||||||
|
freqs = torch.outer(t, freqs).float()
|
||||||
|
|
||||||
|
# 计算cos和sin值
|
||||||
|
# 在复数版本中,pos_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||||
|
# 等价于 cos(freqs) + i*sin(freqs)
|
||||||
|
cos = torch.cos(freqs)
|
||||||
|
sin = torch.sin(freqs)
|
||||||
|
|
||||||
|
# 创建实数张量,交错排列cos和sin
|
||||||
|
pos_emb = torch.zeros((end, dim), device=freqs.device)
|
||||||
|
pos_emb[:, 0::2] = cos # 偶数索引放cos
|
||||||
|
pos_emb[:, 1::2] = sin # 奇数索引放sin
|
||||||
|
|
||||||
|
return pos_emb
|
||||||
|
|
||||||
|
# apply_rotary_emb_real 函数用于应用旋转位置编码(实数版本)。
|
||||||
|
def apply_rotary_emb_real(xq, xk, pos_emb):
|
||||||
|
"""使用实数张量实现旋转位置编码,避免使用复数张量
|
||||||
|
|
||||||
|
这个函数与apply_rotary_emb完全等价,但使用实数张量而非复数张量。
|
||||||
|
原始函数将输入张量转换为复数形式,与位置编码相乘,然后再转回实数形式。
|
||||||
|
这个函数直接使用实数运算实现相同的旋转操作。
|
||||||
|
"""
|
||||||
|
# 获取形状信息
|
||||||
|
bsz, seq_len, n_heads, head_dim = xq.shape
|
||||||
|
|
||||||
|
# 确保pos_emb形状正确
|
||||||
|
assert pos_emb.shape[0] >= seq_len, f"位置编码长度 {pos_emb.shape[0]} 小于序列长度 {seq_len}"
|
||||||
|
assert pos_emb.shape[1] == head_dim, f"位置编码维度 {pos_emb.shape[1]} 与头维度 {head_dim} 不匹配"
|
||||||
|
|
||||||
|
# 截取需要的位置编码长度
|
||||||
|
pos_emb = pos_emb[:seq_len]
|
||||||
|
|
||||||
|
# 将pos_emb调整为广播形状 [1, seq_len, 1, head_dim]
|
||||||
|
pos_emb = pos_emb.unsqueeze(0).unsqueeze(2)
|
||||||
|
|
||||||
|
# 将head_dim分成两半
|
||||||
|
half_head_dim = head_dim // 2
|
||||||
|
|
||||||
|
# 提取cos和sin值(偶数索引是cos,奇数索引是sin)
|
||||||
|
cos = pos_emb[..., 0::2]
|
||||||
|
sin = pos_emb[..., 1::2]
|
||||||
|
|
||||||
|
# 将xq和xk重新排列,以便进行旋转操作
|
||||||
|
# 原始复数版本中,xq和xk被重塑为复数张量,其中实部和虚部交错排列
|
||||||
|
# 在实数版本中,我们需要将偶数索引和奇数索引分开处理
|
||||||
|
|
||||||
|
# 分离偶数和奇数索引
|
||||||
|
xq_even = xq[..., 0::2] # 偶数索引,对应复数的实部
|
||||||
|
xq_odd = xq[..., 1::2] # 奇数索引,对应复数的虚部
|
||||||
|
xk_even = xk[..., 0::2]
|
||||||
|
xk_odd = xk[..., 1::2]
|
||||||
|
|
||||||
|
# 应用旋转(等价于复数乘法)
|
||||||
|
# (a + bi)(cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
|
||||||
|
# 其中a是偶数索引,b是奇数索引
|
||||||
|
xq_out_even = xq_even * cos - xq_odd * sin # 新的偶数索引(实部)
|
||||||
|
xq_out_odd = xq_even * sin + xq_odd * cos # 新的奇数索引(虚部)
|
||||||
|
xk_out_even = xk_even * cos - xk_odd * sin
|
||||||
|
xk_out_odd = xk_even * sin + xk_odd * cos
|
||||||
|
|
||||||
|
# 重新组合偶数和奇数索引
|
||||||
|
xq_out = torch.zeros_like(xq)
|
||||||
|
xk_out = torch.zeros_like(xk)
|
||||||
|
xq_out[..., 0::2] = xq_out_even
|
||||||
|
xq_out[..., 1::2] = xq_out_odd
|
||||||
|
xk_out[..., 0::2] = xk_out_even
|
||||||
|
xk_out[..., 1::2] = xk_out_odd
|
||||||
|
|
||||||
|
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||||
|
|
||||||
# repeat_kv 函数用于重复键值对。
|
# repeat_kv 函数用于重复键值对。
|
||||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
||||||
@ -102,8 +188,8 @@ class Attention(nn.Module):
|
|||||||
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xk重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
|
xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xk重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
|
||||||
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xv重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
|
xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xv重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。
|
||||||
|
|
||||||
# 应用旋转位置编码
|
# 应用旋转位置编码(使用实数版本)
|
||||||
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
|
xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)
|
||||||
# kv_cache实现
|
# kv_cache实现
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
xk = torch.cat([past_key_value[0], xk], dim=1)
|
xk = torch.cat([past_key_value[0], xk], dim=1)
|
||||||
@ -548,8 +634,9 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
self.downsample_q_specific = nn.Sequential(
|
self.downsample_q_specific = nn.Sequential(
|
||||||
nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
|
nn.Conv1d(128*8, 512, kernel_size=1, padding='same')
|
||||||
)
|
)
|
||||||
self.register_buffer("pos_cis",
|
# 使用实数版本的位置编码,避免复数张量可能导致的段错误
|
||||||
precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
|
self.register_buffer("pos_cis_real",
|
||||||
|
precompute_pos_cis_real(dim=params.dim // params.n_heads, theta=params.rope_theta),
|
||||||
persistent=False)
|
persistent=False)
|
||||||
self.params = params
|
self.params = params
|
||||||
|
|
||||||
@ -562,7 +649,7 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
past_key_values = past_key_values or [None] * len(self.layers)
|
past_key_values = past_key_values or [None] * len(self.layers)
|
||||||
start_pos = args.get('start_pos', 0)
|
start_pos = args.get('start_pos', 0)
|
||||||
h = self.dropout(self.tok_embeddings(input_ids))
|
h = self.dropout(self.tok_embeddings(input_ids))
|
||||||
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
|
pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)]
|
||||||
past_kvs = []
|
past_kvs = []
|
||||||
h_list = []
|
h_list = []
|
||||||
|
|
||||||
@ -579,7 +666,7 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
db_value = self.extract_db.get_data(index)
|
db_value = self.extract_db.get_data(index)
|
||||||
|
|
||||||
h, past_kv = layer(
|
h, past_kv = layer(
|
||||||
h, db_value, pos_cis,
|
h, db_value, pos_cis_real,
|
||||||
past_key_value=past_key_values[l],
|
past_key_value=past_key_values[l],
|
||||||
use_cache=use_cache
|
use_cache=use_cache
|
||||||
)
|
)
|
||||||
|
97
test_real_rope.py
Normal file
97
test_real_rope.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
测试实数版本的位置编码
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real
|
||||||
|
from model.LMConfig import LMConfig
|
||||||
|
from model.model import MiniMindLM
|
||||||
|
|
||||||
|
def test_pos_encoding_equivalence():
|
||||||
|
"""测试复数版本和实数版本的位置编码是否等价"""
|
||||||
|
print("测试位置编码等价性...")
|
||||||
|
|
||||||
|
# 参数设置
|
||||||
|
dim = 64
|
||||||
|
seq_len = 10
|
||||||
|
|
||||||
|
# 生成复数版本的位置编码
|
||||||
|
pos_cis = precompute_pos_cis(dim=dim, end=seq_len)
|
||||||
|
|
||||||
|
# 生成实数版本的位置编码
|
||||||
|
pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len)
|
||||||
|
|
||||||
|
# 创建随机查询和键
|
||||||
|
batch_size = 2
|
||||||
|
n_heads = 4
|
||||||
|
head_dim = dim
|
||||||
|
|
||||||
|
xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
|
||||||
|
xk = torch.randn(batch_size, seq_len, n_heads, head_dim)
|
||||||
|
|
||||||
|
# 应用复数版本的旋转位置编码
|
||||||
|
xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis)
|
||||||
|
|
||||||
|
# 应用实数版本的旋转位置编码
|
||||||
|
xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real)
|
||||||
|
|
||||||
|
# 计算差异
|
||||||
|
q_diff = torch.abs(xq_complex - xq_real).mean().item()
|
||||||
|
k_diff = torch.abs(xk_complex - xk_real).mean().item()
|
||||||
|
|
||||||
|
print(f"查询差异: {q_diff:.6f}")
|
||||||
|
print(f"键差异: {k_diff:.6f}")
|
||||||
|
|
||||||
|
# 检查差异是否在可接受范围内
|
||||||
|
tolerance = 1e-5
|
||||||
|
if q_diff < tolerance and k_diff < tolerance:
|
||||||
|
print("✅ 测试通过: 复数版本和实数版本的位置编码在数值上等价")
|
||||||
|
else:
|
||||||
|
print("❌ 测试失败: 复数版本和实数版本的位置编码存在显著差异")
|
||||||
|
|
||||||
|
def test_model_forward():
|
||||||
|
"""测试模型前向传播"""
|
||||||
|
print("\n测试模型前向传播...")
|
||||||
|
|
||||||
|
# 创建模型配置
|
||||||
|
config = LMConfig(
|
||||||
|
dim=128,
|
||||||
|
n_layers=2,
|
||||||
|
n_heads=4,
|
||||||
|
n_kv_heads=4, # 确保n_kv_heads被设置,且n_heads能被n_kv_heads整除
|
||||||
|
vocab_size=1000,
|
||||||
|
max_seq_len=128,
|
||||||
|
disable_db=True # 禁用数据库功能,避免额外的复杂性
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建模型
|
||||||
|
try:
|
||||||
|
model = MiniMindLM(config)
|
||||||
|
print(f"✅ 模型初始化成功")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 模型初始化失败: {str(e)}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 创建输入
|
||||||
|
batch_size = 2
|
||||||
|
seq_len = 10
|
||||||
|
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(input_ids)
|
||||||
|
print(f"✅ 模型前向传播成功")
|
||||||
|
print(f"输出形状: {outputs.logits.shape}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 模型前向传播失败: {str(e)}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试位置编码等价性
|
||||||
|
test_pos_encoding_equivalence()
|
||||||
|
|
||||||
|
# 测试模型前向传播
|
||||||
|
test_model_forward()
|
@ -331,11 +331,15 @@ def main():
|
|||||||
# 将accelerator传递给init_model函数中的Logger调用
|
# 将accelerator传递给init_model函数中的Logger调用
|
||||||
Logger(f'模型初始化完成', accelerator)
|
Logger(f'模型初始化完成', accelerator)
|
||||||
|
|
||||||
# 处理pos_cis复数张量问题
|
# 处理位置编码张量问题
|
||||||
# 方法1:将pos_cis转换为实数张量(两个实数张量表示实部和虚部)
|
# 我们已经将复数版本的pos_cis替换为实数版本的pos_cis_real
|
||||||
# 这里我们采用方法2:告诉accelerate忽略pos_cis
|
# 但为了安全起见,我们仍然将其设置为不参与分布式训练
|
||||||
# 在DeepSpeed模式下,我们需要设置DeepSpeed的参数
|
if hasattr(model, "pos_cis_real"):
|
||||||
if hasattr(model, "pos_cis"):
|
Logger(f'检测到pos_cis_real实数张量,将其设置为不参与分布式训练', accelerator)
|
||||||
|
# 设置模型的_ddp_params_and_buffers_to_ignore属性
|
||||||
|
model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"}
|
||||||
|
# 兼容旧版本,检查是否仍有pos_cis
|
||||||
|
elif hasattr(model, "pos_cis"):
|
||||||
Logger(f'检测到pos_cis复数张量,将其设置为不参与分布式训练', accelerator)
|
Logger(f'检测到pos_cis复数张量,将其设置为不参与分布式训练', accelerator)
|
||||||
# 设置模型的_ddp_params_and_buffers_to_ignore属性
|
# 设置模型的_ddp_params_and_buffers_to_ignore属性
|
||||||
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
|
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user