diff --git a/model/model.py b/model/model.py index 848cf8a..8fc5bda 100644 --- a/model/model.py +++ b/model/model.py @@ -12,7 +12,7 @@ from torch import nn from transformers import PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast - +# RMSNorm 类定义了一个用于归一化输入张量的模块。 class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -25,7 +25,7 @@ class RMSNorm(torch.nn.Module): def forward(self, x): return self.weight * self._norm(x.float()).type_as(x) - +# precompute_pos_cis 函数用于预计算位置编码。 def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6): freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) # type: ignore @@ -33,7 +33,7 @@ def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6): pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 return pos_cis - +# apply_rotary_emb 函数用于应用旋转位置编码。 def apply_rotary_emb(xq, xk, pos_cis): def unite_shape(pos_cis, x): ndim = x.ndim @@ -49,7 +49,7 @@ def apply_rotary_emb(xq, xk, pos_cis): xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) - +# repeat_kv 函数用于重复键值对。 def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" bs, slen, n_kv_heads, head_dim = x.shape @@ -88,13 +88,15 @@ class Attention(nn.Module): x: torch.Tensor, pos_cis: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache=False): - bsz, seq_len, _ = x.shape - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) - xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim) - xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) - xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) + use_cache=False, + db_value=None): + bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度 + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换,得到查询、键和值。 + xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim) #将变换后的张量xq重塑为形状为(bsz, seq_len, n_local_heads, head_dim)的形状。 + xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xk重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。 + xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xv重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。 + # 应用旋转位置编码 xq, xk = apply_rotary_emb(xq, xk, pos_cis) # kv_cache实现 if past_key_value is not None: @@ -102,11 +104,40 @@ class Attention(nn.Module): xv = torch.cat([past_key_value[1], xv], dim=1) past_kv = (xk, xv) if use_cache else None + # 重复键值对 xq, xk, xv = ( xq.transpose(1, 2), repeat_kv(xk, self.n_rep).transpose(1, 2), repeat_kv(xv, self.n_rep).transpose(1, 2) ) + + # 如果提供了db_value,根据头的数量调整它的形状并与xv合并 + if db_value is not None: + # 确保db_value的形状与xv兼容,假设db_value形状为[B, N, H, D] + if db_value.ndim == 4: # [B, N, H, D] + db_value = db_value.transpose(1, 2) # -> [B, H, N, D] + + # 检查是否需要调整D维度 + if db_value.shape[-1] != xv.shape[-1]: + # 如果db_value的维度与xv不同,可以添加一个投影层 + # 或者在这里使用简单的调整方法 + # 这里我们简单地通过均值池化或重复来调整维度 + if db_value.shape[-1] > xv.shape[-1]: + # 降维 + factor = db_value.shape[-1] // xv.shape[-1] + db_value = db_value.view(bsz, self.n_local_heads, seq_len, factor, xv.shape[-1]) + db_value = db_value.mean(dim=3) + else: + # 升维 + factor = xv.shape[-1] // db_value.shape[-1] + db_value = db_value.unsqueeze(-1).repeat(1, 1, 1, 1, factor) + db_value = db_value.view(bsz, self.n_local_heads, seq_len, xv.shape[-1]) + + # 将db_value与xv相加或融合 + # 这里我们简单地将它们相加,但你也可以使用其他融合方法 + xv = xv + db_value + + # 使用Flash Attention if self.flash and seq_len != 1: dropout_p = self.dropout if self.training else 0.0 output = F.scaled_dot_product_attention( @@ -259,7 +290,7 @@ class MOEFeedForward(nn.Module): class MiniMindBlock(nn.Module): - def __init__(self, layer_id: int, config: LMConfig): + def __init__(self, layer_id: int, config: LMConfig, weight_down_embed=None): super().__init__() self.n_heads = config.n_heads self.dim = config.dim @@ -270,13 +301,86 @@ class MiniMindBlock(nn.Module): self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config) + + # Product Key 相关参数 + self.weight_down_embed = weight_down_embed + # 假设num_experts是已定义的总专家数量的平方根 + self.num_keys = int(math.sqrt(self.weight_down_embed.num_embeddings)) if weight_down_embed is not None else 0 + + # 查询生成的参数 + self.dim_key = config.dim // 2 # 一般用特征维度的一半 + + # 创建查询生成模块 + if weight_down_embed is not None: + self.to_queries = nn.Sequential( + nn.Linear(config.dim, self.dim_key * self.n_heads * 2, bias=False), + nn.Unflatten(2, (2, self.n_heads, self.dim_key)) # 替代Rearrange + ) + + # 存储Product Keys + self.keys = nn.Parameter(torch.randn(self.n_heads, self.num_keys, 2, self.dim_key) * 0.02) + + # 超参数 + self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys + self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数 def forward(self, x, pos_cis, past_key_value=None, use_cache=False): + db_value = None + + # 如果有weight_down_embed,使用Product Key机制 + if self.weight_down_embed is not None: + # 1. 生成queries + queries = self.to_queries(x) # [b, n, 2, h, d] + queries = queries.permute(2, 0, 1, 3, 4) # [2, b, n, h, d] + + # 2. 计算queries与keys的相似度 + sim = torch.einsum('p b n h d, h k p d -> p b n h k', queries, self.keys) + + # 3. 在两个子空间分别做top-k + scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)] + scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0] + indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1] + + # 4. 组合两个子空间的分数和索引 + all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) + all_scores = all_scores.view(*all_scores.shape[:-2], -1) + + all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) + all_indices = all_indices.view(*all_indices.shape[:-2], -1) + + # 5. 最终top-k选择 + scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1) + indices = all_indices.gather(-1, pk_indices) + + # 6. 从embedding中获取专家值 + # [b, n, h, k] -> [b, n, h, k, 1] -> [b, n, h, k, d] + indices_expanded = indices.unsqueeze(-1).expand(-1, -1, -1, -1, self.weight_down_embed.embedding_dim) + + # 将索引从3D展平为1D以便gather操作 + batch_size, seq_len = x.shape[0], x.shape[1] + flat_indices = indices.view(-1) + + # 从embedding中获取值 + db_values = self.weight_down_embed(flat_indices) + + # 重塑回原始形状 + db_value = db_values.view(batch_size, seq_len, self.n_heads, + self.num_experts_per_head_topk, -1) + + # 使用分数加权 + db_value = db_value * F.relu(scores.unsqueeze(-1)) + + # 合并多个专家的输出(如果每个头有多个专家) + if self.num_experts_per_head_topk > 1: + db_value = db_value.sum(dim=3) # [b, n, h, d] + + # 注意力计算 h_attn, past_kv = self.attention( self.attention_norm(x), pos_cis, past_key_value=past_key_value, - use_cache=use_cache + use_cache=use_cache, + db_value=db_value ) h = x + h_attn out = h + self.feed_forward(self.ffn_norm(h)) @@ -292,7 +396,20 @@ class MiniMindLM(PreTrainedModel): self.vocab_size, self.n_layers = params.vocab_size, params.n_layers self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) self.dropout = nn.Dropout(params.dropout) - self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)]) + + # 修改专家数量和知识维度,确保能开方 + self.num_experts = 1000 * 1000 # 1M专家,确保是完全平方数 + # 将knowledge_dim设置为与head_dim相同,以便在attention中直接使用 + self.head_dim = params.dim // params.n_heads + self.knowledge_dim = self.head_dim + + # 定义weight_down_embed,用于存储专家知识 + self.weight_down_embed = nn.Embedding(self.num_experts, self.knowledge_dim) + # 初始化embedding权重 + nn.init.normal_(self.weight_down_embed.weight, std=0.02) + + # 将self.weight_down_embed传递给每个MiniMindBlock + self.layers = nn.ModuleList([MiniMindBlock(l, params, self.weight_down_embed) for l in range(self.n_layers)]) self.norm = RMSNorm(params.dim, eps=params.norm_eps) self.output = nn.Linear(params.dim, params.vocab_size, bias=False) self.tok_embeddings.weight = self.output.weight