diff --git a/model/model.py b/model/model.py index 54d5599..a98cf20 100644 --- a/model/model.py +++ b/model/model.py @@ -31,7 +31,7 @@ class RMSNorm(torch.nn.Module): def forward(self, x): return self.weight * self._norm(x.float()).type_as(x) -# precompute_pos_cis 函数用于预计算位置编码。 +# precompute_pos_cis 函数用于预计算位置编码(复数版本)。 def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6): freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) # type: ignore @@ -39,7 +39,7 @@ def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6): pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 return pos_cis -# apply_rotary_emb 函数用于应用旋转位置编码。 +# apply_rotary_emb 函数用于应用旋转位置编码(复数版本)。 def apply_rotary_emb(xq, xk, pos_cis): def unite_shape(pos_cis, x): ndim = x.ndim @@ -55,6 +55,92 @@ def apply_rotary_emb(xq, xk, pos_cis): xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) +# precompute_pos_cis_real 函数用于预计算位置编码(实数版本)。 +def precompute_pos_cis_real(dim: int, end: int = int(32 * 1024), theta: float = 1e6): + """使用实数张量实现位置编码,避免使用复数张量 + + 这个函数与precompute_pos_cis完全等价,但使用实数张量而非复数张量。 + 原始函数生成形状为[seq_len, dim//2]的复数张量,其中实部全为1,虚部为旋转角度。 + 这个函数生成形状为[seq_len, dim]的实数张量,其中偶数索引是cos(角度),奇数索引是sin(角度)。 + """ + # 确保dim是偶数 + if dim % 2 != 0: + raise ValueError(f"维度必须是偶数,但得到了 {dim}") + + # 复制原始函数的频率计算逻辑 + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + + # 计算cos和sin值 + # 在复数版本中,pos_cis = torch.polar(torch.ones_like(freqs), freqs) + # 等价于 cos(freqs) + i*sin(freqs) + cos = torch.cos(freqs) + sin = torch.sin(freqs) + + # 创建实数张量,交错排列cos和sin + pos_emb = torch.zeros((end, dim), device=freqs.device) + pos_emb[:, 0::2] = cos # 偶数索引放cos + pos_emb[:, 1::2] = sin # 奇数索引放sin + + return pos_emb + +# apply_rotary_emb_real 函数用于应用旋转位置编码(实数版本)。 +def apply_rotary_emb_real(xq, xk, pos_emb): + """使用实数张量实现旋转位置编码,避免使用复数张量 + + 这个函数与apply_rotary_emb完全等价,但使用实数张量而非复数张量。 + 原始函数将输入张量转换为复数形式,与位置编码相乘,然后再转回实数形式。 + 这个函数直接使用实数运算实现相同的旋转操作。 + """ + # 获取形状信息 + bsz, seq_len, n_heads, head_dim = xq.shape + + # 确保pos_emb形状正确 + assert pos_emb.shape[0] >= seq_len, f"位置编码长度 {pos_emb.shape[0]} 小于序列长度 {seq_len}" + assert pos_emb.shape[1] == head_dim, f"位置编码维度 {pos_emb.shape[1]} 与头维度 {head_dim} 不匹配" + + # 截取需要的位置编码长度 + pos_emb = pos_emb[:seq_len] + + # 将pos_emb调整为广播形状 [1, seq_len, 1, head_dim] + pos_emb = pos_emb.unsqueeze(0).unsqueeze(2) + + # 将head_dim分成两半 + half_head_dim = head_dim // 2 + + # 提取cos和sin值(偶数索引是cos,奇数索引是sin) + cos = pos_emb[..., 0::2] + sin = pos_emb[..., 1::2] + + # 将xq和xk重新排列,以便进行旋转操作 + # 原始复数版本中,xq和xk被重塑为复数张量,其中实部和虚部交错排列 + # 在实数版本中,我们需要将偶数索引和奇数索引分开处理 + + # 分离偶数和奇数索引 + xq_even = xq[..., 0::2] # 偶数索引,对应复数的实部 + xq_odd = xq[..., 1::2] # 奇数索引,对应复数的虚部 + xk_even = xk[..., 0::2] + xk_odd = xk[..., 1::2] + + # 应用旋转(等价于复数乘法) + # (a + bi)(cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i + # 其中a是偶数索引,b是奇数索引 + xq_out_even = xq_even * cos - xq_odd * sin # 新的偶数索引(实部) + xq_out_odd = xq_even * sin + xq_odd * cos # 新的奇数索引(虚部) + xk_out_even = xk_even * cos - xk_odd * sin + xk_out_odd = xk_even * sin + xk_odd * cos + + # 重新组合偶数和奇数索引 + xq_out = torch.zeros_like(xq) + xk_out = torch.zeros_like(xk) + xq_out[..., 0::2] = xq_out_even + xq_out[..., 1::2] = xq_out_odd + xk_out[..., 0::2] = xk_out_even + xk_out[..., 1::2] = xk_out_odd + + return xq_out.type_as(xq), xk_out.type_as(xk) + # repeat_kv 函数用于重复键值对。 def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" @@ -102,8 +188,8 @@ class Attention(nn.Module): xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xk重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。 xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xv重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。 - # 应用旋转位置编码 - xq, xk = apply_rotary_emb(xq, xk, pos_cis) + # 应用旋转位置编码(使用实数版本) + xq, xk = apply_rotary_emb_real(xq, xk, pos_cis) # kv_cache实现 if past_key_value is not None: xk = torch.cat([past_key_value[0], xk], dim=1) @@ -548,8 +634,9 @@ class MiniMindLM(PreTrainedModel): self.downsample_q_specific = nn.Sequential( nn.Conv1d(128*8, 512, kernel_size=1, padding='same') ) - self.register_buffer("pos_cis", - precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta), + # 使用实数版本的位置编码,避免复数张量可能导致的段错误 + self.register_buffer("pos_cis_real", + precompute_pos_cis_real(dim=params.dim // params.n_heads, theta=params.rope_theta), persistent=False) self.params = params @@ -562,7 +649,7 @@ class MiniMindLM(PreTrainedModel): past_key_values = past_key_values or [None] * len(self.layers) start_pos = args.get('start_pos', 0) h = self.dropout(self.tok_embeddings(input_ids)) - pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)] + pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)] past_kvs = [] h_list = [] @@ -579,7 +666,7 @@ class MiniMindLM(PreTrainedModel): db_value = self.extract_db.get_data(index) h, past_kv = layer( - h, db_value, pos_cis, + h, db_value, pos_cis_real, past_key_value=past_key_values[l], use_cache=use_cache ) diff --git a/test_real_rope.py b/test_real_rope.py new file mode 100644 index 0000000..fe65292 --- /dev/null +++ b/test_real_rope.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +测试实数版本的位置编码 +""" + +import torch +from model.model import precompute_pos_cis, precompute_pos_cis_real, apply_rotary_emb, apply_rotary_emb_real +from model.LMConfig import LMConfig +from model.model import MiniMindLM + +def test_pos_encoding_equivalence(): + """测试复数版本和实数版本的位置编码是否等价""" + print("测试位置编码等价性...") + + # 参数设置 + dim = 64 + seq_len = 10 + + # 生成复数版本的位置编码 + pos_cis = precompute_pos_cis(dim=dim, end=seq_len) + + # 生成实数版本的位置编码 + pos_cis_real = precompute_pos_cis_real(dim=dim, end=seq_len) + + # 创建随机查询和键 + batch_size = 2 + n_heads = 4 + head_dim = dim + + xq = torch.randn(batch_size, seq_len, n_heads, head_dim) + xk = torch.randn(batch_size, seq_len, n_heads, head_dim) + + # 应用复数版本的旋转位置编码 + xq_complex, xk_complex = apply_rotary_emb(xq, xk, pos_cis) + + # 应用实数版本的旋转位置编码 + xq_real, xk_real = apply_rotary_emb_real(xq, xk, pos_cis_real) + + # 计算差异 + q_diff = torch.abs(xq_complex - xq_real).mean().item() + k_diff = torch.abs(xk_complex - xk_real).mean().item() + + print(f"查询差异: {q_diff:.6f}") + print(f"键差异: {k_diff:.6f}") + + # 检查差异是否在可接受范围内 + tolerance = 1e-5 + if q_diff < tolerance and k_diff < tolerance: + print("✅ 测试通过: 复数版本和实数版本的位置编码在数值上等价") + else: + print("❌ 测试失败: 复数版本和实数版本的位置编码存在显著差异") + +def test_model_forward(): + """测试模型前向传播""" + print("\n测试模型前向传播...") + + # 创建模型配置 + config = LMConfig( + dim=128, + n_layers=2, + n_heads=4, + n_kv_heads=4, # 确保n_kv_heads被设置,且n_heads能被n_kv_heads整除 + vocab_size=1000, + max_seq_len=128, + disable_db=True # 禁用数据库功能,避免额外的复杂性 + ) + + # 创建模型 + try: + model = MiniMindLM(config) + print(f"✅ 模型初始化成功") + except Exception as e: + print(f"❌ 模型初始化失败: {str(e)}") + return + + # 创建输入 + batch_size = 2 + seq_len = 10 + input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len)) + + # 前向传播 + try: + with torch.no_grad(): + outputs = model(input_ids) + print(f"✅ 模型前向传播成功") + print(f"输出形状: {outputs.logits.shape}") + except Exception as e: + print(f"❌ 模型前向传播失败: {str(e)}") + +if __name__ == "__main__": + # 测试位置编码等价性 + test_pos_encoding_equivalence() + + # 测试模型前向传播 + test_model_forward() diff --git a/train_pretrain_accelerate.py b/train_pretrain_accelerate.py index e72f46f..d675d12 100644 --- a/train_pretrain_accelerate.py +++ b/train_pretrain_accelerate.py @@ -331,11 +331,15 @@ def main(): # 将accelerator传递给init_model函数中的Logger调用 Logger(f'模型初始化完成', accelerator) - # 处理pos_cis复数张量问题 - # 方法1:将pos_cis转换为实数张量(两个实数张量表示实部和虚部) - # 这里我们采用方法2:告诉accelerate忽略pos_cis - # 在DeepSpeed模式下,我们需要设置DeepSpeed的参数 - if hasattr(model, "pos_cis"): + # 处理位置编码张量问题 + # 我们已经将复数版本的pos_cis替换为实数版本的pos_cis_real + # 但为了安全起见,我们仍然将其设置为不参与分布式训练 + if hasattr(model, "pos_cis_real"): + Logger(f'检测到pos_cis_real实数张量,将其设置为不参与分布式训练', accelerator) + # 设置模型的_ddp_params_and_buffers_to_ignore属性 + model._ddp_params_and_buffers_to_ignore = {"pos_cis_real"} + # 兼容旧版本,检查是否仍有pos_cis + elif hasattr(model, "pos_cis"): Logger(f'检测到pos_cis复数张量,将其设置为不参与分布式训练', accelerator) # 设置模型的_ddp_params_and_buffers_to_ignore属性 model._ddp_params_and_buffers_to_ignore = {"pos_cis"}