From 5351ae8a6ae3a5ae8fe0761aa8acce15a2a79cab Mon Sep 17 00:00:00 2001 From: iomgaa Date: Sun, 11 May 2025 11:58:13 +0800 Subject: [PATCH] =?UTF-8?q?=E6=AD=A3=E5=B8=B8=E5=B0=BA=E5=AF=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/model.py | 315 +++---------------------------------------------- 1 file changed, 15 insertions(+), 300 deletions(-) diff --git a/model/model.py b/model/model.py index 31aca85..6e2dcb7 100644 --- a/model/model.py +++ b/model/model.py @@ -11,14 +11,8 @@ import torch.nn.functional as F from torch import nn from transformers import PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast -from torch import nn, einsum -from einops import rearrange, repeat - -def exists(val): - return val is not None -# RMSNorm 类定义了一个用于归一化输入张量的模块。 class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -31,7 +25,7 @@ class RMSNorm(torch.nn.Module): def forward(self, x): return self.weight * self._norm(x.float()).type_as(x) -# precompute_pos_cis 函数用于预计算位置编码。 + def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6): freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) # type: ignore @@ -39,7 +33,7 @@ def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6): pos_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 return pos_cis -# apply_rotary_emb 函数用于应用旋转位置编码。 + def apply_rotary_emb(xq, xk, pos_cis): def unite_shape(pos_cis, x): ndim = x.ndim @@ -55,7 +49,7 @@ def apply_rotary_emb(xq, xk, pos_cis): xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) -# repeat_kv 函数用于重复键值对。 + def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" bs, slen, n_kv_heads, head_dim = x.shape @@ -94,15 +88,13 @@ class Attention(nn.Module): x: torch.Tensor, pos_cis: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache=False, - db_value=None): - bsz, seq_len, _ = x.shape #bsz: 批量大小, seq_len: 序列长度, _: 隐藏维度 - xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) #将输入张量x分别通过线性层wq, wk, wv进行变换,得到查询、键和值。 - xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim) #将变换后的张量xq重塑为形状为(bsz, seq_len, n_local_heads, head_dim)的形状。 - xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xk重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。 - xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) #将变换后的张量xv重塑为形状为(bsz, seq_len, n_local_kv_heads, head_dim)的形状。 + use_cache=False): + bsz, seq_len, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim) - # 应用旋转位置编码 xq, xk = apply_rotary_emb(xq, xk, pos_cis) # kv_cache实现 if past_key_value is not None: @@ -110,40 +102,11 @@ class Attention(nn.Module): xv = torch.cat([past_key_value[1], xv], dim=1) past_kv = (xk, xv) if use_cache else None - # 重复键值对 xq, xk, xv = ( xq.transpose(1, 2), repeat_kv(xk, self.n_rep).transpose(1, 2), repeat_kv(xv, self.n_rep).transpose(1, 2) ) - - # 如果提供了db_value,根据头的数量调整它的形状并与xv合并 - if db_value is not None: - # 确保db_value的形状与xv兼容,假设db_value形状为[B, N, H, D] - if db_value.ndim == 4: # [B, N, H, D] - db_value = db_value.transpose(1, 2) # -> [B, H, N, D] - - # 检查是否需要调整D维度 - if db_value.shape[-1] != xv.shape[-1]: - # 如果db_value的维度与xv不同,可以添加一个投影层 - # 或者在这里使用简单的调整方法 - # 这里我们简单地通过均值池化或重复来调整维度 - if db_value.shape[-1] > xv.shape[-1]: - # 降维 - factor = db_value.shape[-1] // xv.shape[-1] - db_value = db_value.view(bsz, self.n_local_heads, seq_len, factor, xv.shape[-1]) - db_value = db_value.mean(dim=3) - else: - # 升维 - factor = xv.shape[-1] // db_value.shape[-1] - db_value = db_value.unsqueeze(-1).repeat(1, 1, 1, 1, factor) - db_value = db_value.view(bsz, self.n_local_heads, seq_len, xv.shape[-1]) - - # 将db_value与xv相加或融合 - # 这里我们简单地将它们相加,但你也可以使用其他融合方法 - xv = xv + db_value - - # 使用Flash Attention if self.flash and seq_len != 1: dropout_p = self.dropout if self.training else 0.0 output = F.scaled_dot_product_attention( @@ -164,53 +127,6 @@ class Attention(nn.Module): return output, past_kv - - -class CrossAttention(nn.Module): - def __init__( - self, - config - ): - super().__init__() - self.config = config - self.num_heads = 8 - self.head_dim = self.config.dim // self.num_heads - self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False) - self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False) - self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False) - - self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False) - - def forward(self, x, db, context_mask=None, pos_emb=None): - batch_size = x.size(0) - - # 分离多头 - q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) - k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) - v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) - - if pos_emb is not None: - pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) - q = q + pos_emb - k = k + pos_emb - v = v + pos_emb - - attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) - - if context_mask is not None: - expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1) - attn_scores = attn_scores.masked_fill(expanded_mask == 0, -1e10) - - attn_weights = F.softmax(attn_scores, dim=-1) - - context = torch.matmul(attn_weights, v) - - context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim) - - context = self.to_out(context) - - return context - class FeedForward(nn.Module): def __init__(self, config: LMConfig): super().__init__() @@ -349,162 +265,23 @@ class MiniMindBlock(nn.Module): self.dim = config.dim self.head_dim = config.dim // config.n_heads self.attention = Attention(config) - self.cross_att = CrossAttention(config) - + self.layer_id = layer_id self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config) - - # 假设num_experts是已定义的总专家数量的平方根 - - - # 查询生成的参数 - - - # 创建查询生成模块 - # if weight_down_embed is not None: - # self.to_queries = nn.Sequential( - # nn.Linear(config.dim, self.dim_key * 2, bias=False), - # # nn.Unflatten(2, (2, self.n_heads, self.dim_key)) # 替代Rearrange - # ) - - # # 超参数 - # self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys - # self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数 - def forward(self, x, db_value, pos_cis, past_key_value=None, use_cache=False): - # import pdb;pdb.set_trace() - # db_value = None - - # # 如果有weight_down_embed,使用Product Key机制 - # if self.weight_down_embed is not None: - # # 1. 生成queries - # batch_size, seq_len, dim = x.shape - - # # collapse sequence dimension by averaging - # x_flat = x.mean(dim=1) # [batch_size, dim] - # queries = self.to_queries(x_flat) # [batch_size, 2*dim_key] - # queries = queries.reshape(batch_size, 2, self.dim_key) # [batch_size, 2, dim_key] - # queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key] - - # # 2. 计算queries与keys的相似度 - # sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys) - - # # 3. 在两个子空间分别做top-k - # scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)] - # scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0] - # indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1] - - # # 4. 组合两个子空间的分数和索引 - # all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) - # all_scores = all_scores.view(*all_scores.shape[:-2], -1) - - # all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) - # all_indices = all_indices.view(*all_indices.shape[:-2], -1) - - # # 5. 最终top-k选择 - # scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1) - # indices = all_indices.gather(-1, pk_indices) - - # # 6. 从embedding中获取专家值 - - # # 从embedding中获取值 - # flat_indices = indices.view(-1) # 将索引展平为一维张量 - # db_values = self.weight_down_embed(flat_indices) - - # # 重塑回原始形状 - # db_value = db_values.view(batch_size, -1, dim) - - - # 注意力计算 + def forward(self, x, pos_cis, past_key_value=None, use_cache=False): h_attn, past_kv = self.attention( self.attention_norm(x), pos_cis, past_key_value=past_key_value, - use_cache=use_cache, - db_value=db_value + use_cache=use_cache ) - - h_attn = self.cross_att(h_attn, db_value) - - # 残差连接 h = x + h_attn - - # 前馈神经网络 out = h + self.feed_forward(self.ffn_norm(h)) return out, past_kv -class ExtractDB(nn.Module): - def __init__(self,params): - # 修改专家数量和知识维度,确保能开方 - super().__init__() - self.batch_size = None - self.dim = params.dim - self.dim_key = self.dim // 2 - self.num_experts = 10 * 10 # 100专家,确保是完全平方数 - # 将knowledge_dim设置为与head_dim相同,以便在attention中直接使用 - self.head_dim = params.dim // params.n_heads - self.knowledge_dim = 8*params.dim - - # 使用register_buffer代替nn.Parameter,避免梯度问题 - self.register_buffer('weight_down_embed', torch.randn(self.num_experts, self.knowledge_dim) * 0.02) - - self.num_keys = int(math.sqrt(self.num_experts)) if self.num_experts > 0 else 0 - self.product_key_topk = min(16, self.num_keys) - self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02) - self.num_experts_per_head_topk = 1 - self.to_queries = nn.Sequential( - nn.Linear(params.dim, self.dim_key * 2, bias=False), - ) - - def q_to_k(self,x): - # 1. 生成queries - self.batch_size, seq_len, dim = x.shape - - # collapse sequence dimension by averaging - x_flat = x.mean(dim=1) # [batch_size, dim] - - queries = self.to_queries(x_flat) # [batch_size, 2*dim_key] - queries = queries.reshape(self.batch_size, 2, self.dim_key) # [batch_size, 2, dim_key] - queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key] - - # 2. 计算queries与keys的相似度 - sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys) - - # 3. 在两个子空间分别做top-k - scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)] - scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0] - indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1] - - # 4. 组合两个子空间的分数和索引 - all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) - all_scores = all_scores.view(*all_scores.shape[:-2], -1) - - all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) - all_indices = all_indices.view(*all_indices.shape[:-2], -1) - - # 5. 最终top-k选择 - scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1) - indices = all_indices.gather(-1, pk_indices) - flat_indices = indices.view(-1) - return flat_indices - - def get_data(self, index): - # 直接从GPU获取embedding - db_values = self.weight_down_embed[index] - db_value = db_values.view(self.batch_size, -1, self.dim) - return db_value - - @torch.no_grad() - def updata_value(self, k, v): - # 直接更新buffer上的值 (不需要梯度) - v_reshaped = v.view(v.size(0), -1) - # 确保数据类型匹配 - v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype) - self.weight_down_embed[k] = v_reshaped - - class MiniMindLM(PreTrainedModel): config_class = LMConfig @@ -515,44 +292,14 @@ class MiniMindLM(PreTrainedModel): self.vocab_size, self.n_layers = params.vocab_size, params.n_layers self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) self.dropout = nn.Dropout(params.dropout) - # 移除旧的weight_down_embed声明 - self.extract_db = ExtractDB(self.params) - - # 将self.weight_down_embed传递给每个MiniMindBlock self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)]) self.norm = RMSNorm(params.dim, eps=params.norm_eps) self.output = nn.Linear(params.dim, params.vocab_size, bias=False) self.tok_embeddings.weight = self.output.weight - - # Calculate input dimension - input_dim = (self.params.max_seq_len-1)*self.params.n_layers - # Use a bottleneck architecture to reduce parameters - bottleneck_dim = 256 # Significantly smaller bottleneck dimension - - # Factorized shared downsampling using two smaller convolutions - self.shared_downsample = nn.Sequential( - # First reduce input dimension to bottleneck - nn.Conv1d(input_dim, bottleneck_dim, kernel_size=1, padding='same'), - nn.ReLU(), # Non-linearity to improve representation capacity - # Then expand to target dimension - nn.Conv1d(bottleneck_dim, 128*8, kernel_size=1, padding='same') - ) - - # Specific layers for v path - self.downsample_v_specific = nn.Sequential( - nn.Conv1d(128*8, 128, kernel_size=1, padding='same'), - nn.Conv1d(128, 8, kernel_size=1, padding='same') - ) - - # Specific layers for q path - self.downsample_q_specific = nn.Sequential( - nn.Conv1d(128*8, 512, kernel_size=1, padding='same') - ) self.register_buffer("pos_cis", precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta), persistent=False) self.OUT = CausalLMOutputWithPast() - self.params = params def forward(self, input_ids: Optional[torch.Tensor] = None, @@ -565,46 +312,14 @@ class MiniMindLM(PreTrainedModel): h = self.dropout(self.tok_embeddings(input_ids)) pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)] past_kvs = [] - h_list = [] - for l, layer in enumerate(self.layers): - # 禁用数据库模式,使用固定值替代数据库查询 - if self.params.disable_db: - # 创建一个形状为[batch_size, n_layers, dim]的tensor,所有元素值为1e-4 - batch_size = h.size(0) - db_value = torch.full((batch_size, self.n_layers, self.params.dim), 1e-4, - dtype=h.dtype, device=h.device) - else: - # 正常模式,使用数据库查询 - index = self.extract_db.q_to_k(h) - db_value = self.extract_db.get_data(index) - h, past_kv = layer( - h, db_value, pos_cis, + h, pos_cis, past_key_value=past_key_values[l], use_cache=use_cache ) - past_kvs.append(past_kv) - h_list.append(h.unsqueeze(0)) - - h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3) - - # 只在非禁用数据库模式下执行数据库更新逻辑 - if not self.params.disable_db: - # 使用detach()分离计算图,避免多次反向传播 - h_tensor_detached = h_tensor.detach() - h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0], -1, self.params.dim) - - # 数据库更新逻辑与主计算图分离 - with torch.no_grad(): - # Compute shared downsampling layer once - shared_features = self.shared_downsample(h_tensor_detached) - z_v = self.downsample_v_specific(shared_features) - z_q = self.downsample_q_specific(shared_features) - z_k = self.extract_db.q_to_k(z_q) - self.extract_db.updata_value(z_k, z_v) - + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.output(self.norm(h)[:, slice_indices, :]) aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward)) @@ -667,4 +382,4 @@ class MiniMindLM(PreTrainedModel): input_ids = torch.cat((input_ids, input_ids_next), dim=1) yield input_ids[:, start:] if input_ids_next.item() == eos_token_id: - break + break \ No newline at end of file