import math
import struct
import inspect
import time
from .LMConfig import LMConfig
from typing import Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch import nn, einsum
from einops import rearrange, repeat


def exists(val):
    """Return True when ``val`` is not None."""
    return val is not None


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32, then cast back to the input dtype before scaling.
        return self.weight * self._norm(x.float()).type_as(x)


def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    """Precompute rotary position factors as a complex tensor of shape [end, dim // 2].

    Each entry is ``cos(angle) + i*sin(angle)`` (unit magnitude, built with ``torch.polar``).
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return pos_cis


def apply_rotary_emb(xq, xk, pos_cis):
    """Apply rotary position embedding using complex arithmetic.

    ``xq``/``xk`` are indexed as [batch, seq, heads, head_dim]; consecutive pairs of the
    last dimension are treated as (real, imag) and multiplied by ``pos_cis``.
    """
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


def precompute_pos_cis_real(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    """Real-valued equivalent of ``precompute_pos_cis`` that avoids complex tensors.

    ``precompute_pos_cis`` returns a complex tensor of shape [end, dim // 2] whose entries
    are ``cos(angle) + i*sin(angle)``. This function returns a real tensor of shape
    [end, dim] holding ``cos(angle)`` at even column indices and ``sin(angle)`` at odd
    column indices.

    Raises:
        ValueError: if ``dim`` is odd.
    """
    if dim % 2 != 0:
        raise ValueError(f"维度必须是偶数,但得到了 {dim}")

    # Same frequency/angle computation as precompute_pos_cis.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()

    # Interleave: column 2j holds cos, column 2j+1 holds sin of the j-th angle.
    return torch.stack((torch.cos(freqs), torch.sin(freqs)), dim=-1).flatten(-2)


def apply_rotary_emb_real(xq, xk, pos_emb):
    """Apply rotary position embedding with real arithmetic only.

    Equivalent to ``apply_rotary_emb`` with the rotation
    ``(a + bi)(cos + i*sin) = (a*cos - b*sin) + (a*sin + b*cos)i`` written out explicitly,
    where ``a`` is an even-indexed and ``b`` an odd-indexed channel.

    Args:
        xq, xk: tensors unpacked as [batch, seq_len, heads, head_dim].
        pos_emb: [>= seq_len, head_dim] table from ``precompute_pos_cis_real``.

    Returns:
        Rotated ``(xq, xk)`` cast back to their input dtypes.
    """
    _, seq_len, _, head_dim = xq.shape

    assert pos_emb.shape[0] >= seq_len, f"位置编码长度 {pos_emb.shape[0]} 小于序列长度 {seq_len}"
    assert pos_emb.shape[1] == head_dim, f"位置编码维度 {pos_emb.shape[1]} 与头维度 {head_dim} 不匹配"

    # [seq_len, head_dim] -> [1, seq_len, 1, head_dim] to broadcast over batch and heads.
    pos_emb = pos_emb[:seq_len].unsqueeze(0).unsqueeze(2)
    cos = pos_emb[..., 0::2]
    sin = pos_emb[..., 1::2]

    def rotate(x):
        even, odd = x[..., 0::2], x[..., 1::2]
        rotated = torch.stack((even * cos - odd * sin, even * sin + odd * cos), dim=-1)
        # flatten(-2) re-interleaves (even, odd) pairs back into the last dimension.
        return rotated.flatten(-2).type_as(x)

    return rotate(xq), rotate(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


class Attention(nn.Module):
    """Causal self-attention with grouped KV heads, rotary embedding, and optional
    additive fusion of database values into the value tensor."""

    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        # Upper-triangular -inf mask used by the non-flash path.
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, db_value=None):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)

        # Rotary position embedding (real-valued implementation).
        xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)

        # Expand KV heads to the query head count and move heads to dim 1.
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2),
        )

        # Fuse database values into xv (additively).
        if db_value is not None:
            # A 4-D db_value is treated as [B, N, H, D] and moved to [B, H, N, D].
            if db_value.ndim == 4:
                db_value = db_value.transpose(1, 2)

            # Adapt the last dimension to head_dim by mean-pooling or repetition.
            # NOTE(review): MiniMindLM.forward passes tok_embeddings(token ids), i.e. a 3-D
            # [batch, knowledge_length, dim] tensor; the reduce branch's view() then only
            # succeeds when knowledge_length == n_heads * seq_len — confirm this is intended.
            if db_value.shape[-1] != xv.shape[-1]:
                if db_value.shape[-1] > xv.shape[-1]:
                    factor = db_value.shape[-1] // xv.shape[-1]
                    db_value = db_value.view(bsz, self.n_local_heads, seq_len, factor, xv.shape[-1])
                    db_value = db_value.mean(dim=3)
                else:
                    factor = xv.shape[-1] // db_value.shape[-1]
                    db_value = db_value.unsqueeze(-1).repeat(1, 1, 1, 1, factor)
                    db_value = db_value.view(bsz, self.n_local_heads, seq_len, xv.shape[-1])

            xv = xv + db_value

        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output


class CrossAttention(nn.Module):
    """Multi-head (8 heads) attention from ``x`` (queries) to ``db`` (keys/values)."""

    def __init__(
        self,
        config
    ):
        super().__init__()
        self.config = config
        self.num_heads = 8
        self.head_dim = self.config.dim // self.num_heads
        self.to_q = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_k = nn.Linear(self.config.dim, self.config.dim, bias=False)
        self.to_v = nn.Linear(self.config.dim, self.config.dim, bias=False)

        self.to_out = nn.Linear(self.config.dim, self.config.dim, bias=False)

    def forward(self, x, db, context_mask=None, pos_emb=None):
        batch_size = x.size(0)

        # Split into heads: [batch, heads, tokens, head_dim].
        q = self.to_q(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.to_k(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(db).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        if pos_emb is not None:
            pos_emb = pos_emb.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
            q = q + pos_emb
            k = k + pos_emb
            v = v + pos_emb

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if context_mask is not None:
            expanded_mask = context_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
            # Use the dtype's most negative finite value: a literal -1e10 overflows fp16/bf16.
            attn_scores = attn_scores.masked_fill(expanded_mask == 0, torch.finfo(attn_scores.dtype).min)

        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v)

        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.config.dim)

        context = self.to_out(context)

        return context


class FeedForward(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x)) followed by dropout."""

    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            # Default: 2/3 of 4*dim, rounded up to a multiple of config.multiple_of.
            # Note: this writes the derived value back onto the shared config object.
            hidden_dim = 4 * config.dim
            hidden_dim = int(2 * hidden_dim / 3)
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class MoEGate(nn.Module):
    """Softmax top-k router with an optional load-balancing auxiliary loss."""

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts

        self.scoring_func = config.scoring_func
        self.alpha = config.aux_loss_alpha
        self.seq_aux = config.seq_aux

        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Return (topk_idx, topk_weight, aux_loss) for [batch, seq, dim] hidden states."""
        bsz, seq_len, h = hidden_states.shape
        hidden_states = hidden_states.view(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        if self.training and self.alpha > 0.0:
            scores_for_aux = scores
            aux_topk = self.top_k
            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
            if self.seq_aux:
                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
                ce.scatter_add_(1, topk_idx_for_aux_loss,
                                torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
                    seq_len * aux_topk / self.n_routed_experts)
                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
            else:
                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
                ce = mask_ce.float().mean(0)
                Pi = scores_for_aux.mean(0)
                fi = ce * self.n_routed_experts
                aux_loss = (Pi * fi).sum() * self.alpha
        else:
            aux_loss = 0
        return topk_idx, topk_weight, aux_loss


class MOEFeedForward(nn.Module):
    """Mixture-of-experts feed-forward with optional shared experts."""

    def __init__(self, config: LMConfig):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList([
            FeedForward(config)
            for _ in range(config.n_routed_experts)
        ])
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            self.shared_experts = FeedForward(config)

    def forward(self, x):
        identity = x
        orig_shape = x.shape
        # Route each token to its top-k experts.
        topk_idx, topk_weight, aux_loss = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            # Keep the input dtype: a hard-coded fp16 buffer loses fp32 precision and can
            # overflow bf16-range values.
            y = torch.empty_like(x)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        self.aux_loss = aux_loss
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Inference path: process tokens grouped per expert and scatter-add the results."""
        expert_cache = torch.zeros_like(x)
        idxs = flat_expert_indices.argsort()
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.config.num_experts_per_tok
        # Example: tokens_per_expert = [6, 15, 20, 26] (4 experts) and
        # token_idxs = [3, 7, 19, 21, 24, 25, 4, 5, 6, 10, 11, 12, ...] means
        # token_idxs[:6] are the tokens handled by expert 0, token_idxs[6:15] by expert 1,
        # and so on (a token can appear under several experts when num_experts_per_tok > 1).
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens).to(expert_cache.dtype)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
            expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)

        return expert_cache


class MiniMindBlock(nn.Module):
    """Transformer block: self-attention (fused with db values), cross-attention to the
    db values, residual, then a (MoE) feed-forward with residual."""

    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads
        self.attention = Attention(config)
        self.cross_att = CrossAttention(config)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)

    def forward(self, x, db_value, pos_cis):
        h_attn = self.attention(
            self.attention_norm(x),
            pos_cis,
            db_value=db_value
        )

        h_attn = self.cross_att(h_attn, db_value)

        # Residual connection.
        h = x + h_attn

        # Feed-forward with residual.
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class ExtractDB(nn.Module):
    """Product-key database retrieval with optional relevance/usage-aware re-scoring.

    The database (``weight_down_embed``) stores ``knowledge_num`` rows of
    ``knowledge_length`` token ids. ``knowledge_num`` is expected to be a perfect square
    (``num_keys = int(sqrt(knowledge_num))`` keys per sub-space).
    """

    def __init__(self, params, tok_embeddings=None):
        super().__init__()
        self.batch_size = None
        self.dim = params.dim
        self.dim_key = self.dim // 2
        self.knowledge_num = params.knowledge_num
        self.head_dim = params.dim // params.n_heads
        self.knowledge_length = params.knowledge_length

        # Intelligent load-balancing parameters.
        self.enable_intelligent_balance = getattr(params, 'db_intelligent_balance', True)
        self.relevance_threshold = getattr(params, 'db_relevance_threshold', 0.7)
        self.base_balance_strength = getattr(params, 'db_balance_strength', 0.3)
        self.momentum = getattr(params, 'db_momentum', 0.9)
        self.adaptive_weights = getattr(params, 'db_adaptive_weights', True)

        # Dynamic relevance/balance weights (start favouring relevance).
        self.current_relevance_weight = 0.8
        self.current_balance_weight = 0.2
        self.weight_update_frequency = 100  # re-evaluate weights every 100 calls
        self.step_counter = 0

        # Usage statistics as buffers so they follow the module across devices.
        self.register_buffer('usage_counts', torch.zeros(self.knowledge_num))
        self.register_buffer('total_queries', torch.tensor(0.0))

        # Knowledge store: integer token ids (no gradient). Random ids are drawn below
        # vocab_size so they are always valid embedding indices (6400 if unspecified).
        self.register_buffer('weight_down_embed',
            torch.randint(low=0, high=getattr(params, 'vocab_size', 6400),
                          size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
        )

        self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0
        self.product_key_topk = min(16, self.num_keys)
        self.keys = nn.Parameter(torch.randn(self.num_keys, 2, self.dim_key) * 0.02)
        self.num_experts_per_head_topk = 1
        self.to_queries = nn.Sequential(
            nn.Linear(params.dim, self.dim_key * 2, bias=False),
        )

        # Token embedding module used to embed candidate entries for relevance scoring.
        self.tok_embeddings = tok_embeddings

    def update_usage_statistics(self, selected_indices):
        """Update the EMA usage counts and the cumulative query count (training only)."""
        if not self.training or not self.enable_intelligent_balance:
            return

        with torch.no_grad():
            batch_usage = torch.zeros(self.knowledge_num, device=selected_indices.device)
            unique_indices, counts = torch.unique(selected_indices, return_counts=True)
            batch_usage[unique_indices] = counts.float()

            current_usage = self.usage_counts.clone()
            current_total = self.total_queries.clone()

            new_usage = self.momentum * current_usage + (1 - self.momentum) * batch_usage
            new_total = current_total + selected_indices.numel()

            self.usage_counts.copy_(new_usage)
            self.total_queries.copy_(new_total)

    def update_dynamic_weights(self):
        """Shift weight between relevance and balance based on usage variance."""
        if not self.adaptive_weights or not self.training:
            return

        self.step_counter += 1

        if self.step_counter % self.weight_update_frequency == 0:
            with torch.no_grad():
                if self.total_queries > 0:
                    # NOTE(review): usage_counts is an EMA while total_queries is a running
                    # total, so these rates (and their variance) shrink as training
                    # proceeds — confirm the intended normalization for the thresholds below.
                    usage_rates = self.usage_counts / self.total_queries
                    usage_variance = usage_rates.var().item()

                    if usage_variance > 0.01:  # highly imbalanced
                        self.current_relevance_weight = max(0.5, self.current_relevance_weight - 0.1)
                        self.current_balance_weight = min(0.5, self.current_balance_weight + 0.1)
                    elif usage_variance < 0.001:  # already well balanced
                        self.current_relevance_weight = min(0.9, self.current_relevance_weight + 0.1)
                        self.current_balance_weight = max(0.1, self.current_balance_weight - 0.1)

                    # Keep the two weights summing to 1.
                    total_weight = self.current_relevance_weight + self.current_balance_weight
                    self.current_relevance_weight /= total_weight
                    self.current_balance_weight /= total_weight

    def intelligent_selection(self, query, all_scores, all_indices):
        """Boost candidate scores by a mix of semantic relevance and low past usage.

        Only active in training with intelligent balancing enabled; otherwise returns
        ``all_scores`` unchanged.
        """
        if not self.enable_intelligent_balance or not self.training:
            return all_scores

        with torch.no_grad():
            batch_size = all_scores.size(0)
            device = all_scores.device
            dtype = all_scores.dtype

            self.update_dynamic_weights()

            enhanced_scores = all_scores.clone()
            query_features = query.mean(dim=1)  # [batch_size, dim]

            # Embed each distinct candidate entry once (mean of its token embeddings).
            all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
            unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)

            candidate_tokens = self.weight_down_embed[unique_indices]
            flat_tokens = candidate_tokens.view(-1)
            flat_embeddings = self.tok_embeddings(flat_tokens)
            unique_candidate_features = flat_embeddings.view(
                len(unique_indices), self.knowledge_length, -1
            ).mean(dim=1)  # [num_unique_candidates, dim]

            normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
            normalized_queries = F.normalize(query_features, dim=-1)

            for batch_idx in range(batch_size):
                indices = all_indices[batch_idx]
                scores = all_scores[batch_idx]

                # Rows of this batch item inside the flattened candidate list.
                start_idx = batch_idx * len(indices)
                end_idx = start_idx + len(indices)
                batch_inverse_indices = inverse_indices[start_idx:end_idx]

                batch_candidate_features = normalized_candidates[batch_inverse_indices]
                query_feature = normalized_queries[batch_idx]

                # Cosine similarity between the query and each candidate.
                similarity_scores = torch.mv(batch_candidate_features, query_feature)

                # Keep candidates whose softmax relevance exceeds an adaptive threshold.
                relevance_probs = F.softmax(similarity_scores.float(), dim=-1).to(dtype)
                mean_prob = relevance_probs.mean()
                adaptive_threshold = max(self.relevance_threshold * mean_prob, mean_prob * 0.5)
                relevant_mask = relevance_probs > adaptive_threshold

                if relevant_mask.sum() == 0:
                    # Fall back to the most similar candidates.
                    top_k = min(5, len(indices))
                    _, top_indices = similarity_scores.topk(top_k)
                    relevant_mask = torch.zeros_like(relevant_mask, dtype=torch.bool)
                    relevant_mask[top_indices] = True

                # Load balancing among the relevant candidates.
                if relevant_mask.sum() > 1:
                    relevant_indices = indices[relevant_mask]
                    relevant_usage = self.usage_counts[relevant_indices]

                    balance_scores = 1.0 / (relevant_usage + 1.0)
                    balance_scores = balance_scores / (balance_scores.sum() + 1e-8)

                    relevant_rel_scores = relevance_probs[relevant_mask]
                    relevant_rel_scores = relevant_rel_scores / (relevant_rel_scores.sum() + 1e-8)

                    combined_scores = (self.current_relevance_weight * relevant_rel_scores +
                                       self.current_balance_weight * balance_scores.to(dtype))

                    adjustment = self.base_balance_strength * combined_scores.to(dtype)
                    enhanced_scores[batch_idx, relevant_mask] = scores[relevant_mask] + adjustment

            return enhanced_scores.to(device)

    def q_to_k(self, x):
        """Select one database index per batch item using product keys.

        ``x`` is unpacked as [batch, seq, dim]; returns a flat index tensor.
        """
        self.batch_size, seq_len, dim = x.shape

        # Collapse the sequence dimension by averaging, then project to two sub-queries.
        x_flat = x.mean(dim=1)  # [batch_size, dim]
        queries = self.to_queries(x_flat)  # [batch_size, 2*dim_key]
        queries = queries.reshape(self.batch_size, 2, self.dim_key)
        queries = queries.permute(1, 0, 2)  # [2, batch_size, dim_key]

        # Similarity of each sub-query to the keys of its sub-space.
        sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)

        # Top-k in each sub-space.
        scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
        scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
        indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]

        # Cartesian combination of the two sub-spaces.
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        all_scores = all_scores.view(*all_scores.shape[:-2], -1)

        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
        all_indices = all_indices.view(*all_indices.shape[:-2], -1)

        enhanced_scores = self.intelligent_selection(x, all_scores, all_indices)

        # Final top-k on the (possibly re-scored) candidates.
        scores, pk_indices = enhanced_scores.topk(self.num_experts_per_head_topk, dim=-1)
        indices = all_indices.gather(-1, pk_indices)
        flat_indices = indices.view(-1)

        self.update_usage_statistics(flat_indices)

        return flat_indices

    def get_data(self, index):
        """Return the stored token ids for ``index`` (embed them with tok_embeddings)."""
        db_values = self.weight_down_embed[index]
        return db_values

    @torch.no_grad()
    def updata_value(self, k, v):
        """Overwrite database rows ``k`` with token ids ``v`` (reshaped to [len(k), -1])."""
        v_reshaped = v.view(v.size(0), -1)
        v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
        self.weight_down_embed[k] = v_reshaped


class MiniMindLM(PreTrainedModel):
    """Causal language model whose layers retrieve and fuse entries from a token database."""

    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        # Resolve the default config once and use it everywhere (previously the raw
        # `params` argument was dereferenced, crashing when it was None).
        params = params or LMConfig()
        self.params = params
        super().__init__(self.params)
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers

        # Token embeddings are created first so the database module can reuse them.
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        # Alternative (product-key) retrieval:
        # self.extract_db = ExtractDB(self.params, self.tok_embeddings)
        self.extract_db = ExtractDB_DirectSemantic(self.params, self.tok_embeddings)

        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.database_output = nn.Linear(params.dim, params.knowledge_length, bias=False)
        # Weight tying. Note: after this assignment database_output maps dim -> vocab_size
        # (its declared out_features is stale).
        self.tok_embeddings.weight = self.output.weight
        self.database_output.weight = self.output.weight

        # Flattened per-layer hidden states are treated as Conv1d channels.
        # NOTE(review): this fixes the DB-update path to sequences of length max_seq_len - 1.
        input_dim = (self.params.max_seq_len - 1) * self.params.n_layers

        bottleneck_dim = 256

        self.shared_downsample = nn.Sequential(
            nn.Conv1d(input_dim, bottleneck_dim, kernel_size=1, padding='same'),
            nn.ReLU(),
            nn.Conv1d(bottleneck_dim, 128 * 8, kernel_size=1, padding='same')
        )

        # v path.
        self.downsample_v_specific = nn.Sequential(
            nn.Conv1d(128 * 8, 128, kernel_size=1, padding='same'),
            nn.Conv1d(128, self.params.knowledge_length, kernel_size=1, padding='same')
        )

        # q path.
        self.downsample_q_specific = nn.Sequential(
            nn.Conv1d(128 * 8, 512, kernel_size=1, padding='same')
        )

        # Real-valued rotary table (avoids complex tensors).
        self.register_buffer(
            "pos_cis_real",
            precompute_pos_cis_real(dim=params.dim // params.n_heads, theta=params.rope_theta),
            persistent=False,
        )

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                **args):
        """Run the model over ``input_ids`` (no KV cache) and return logits + aux loss."""
        start_pos = args.get('start_pos', 0)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis_real = self.pos_cis_real[start_pos:start_pos + input_ids.size(1)]
        h_list = []

        for layer in self.layers:
            # Retrieve one database entry per batch item and embed its tokens.
            index = self.extract_db.q_to_k(h)
            token_idx = self.extract_db.get_data(index)
            db_value = self.tok_embeddings(token_idx)

            h = layer(
                h, db_value, pos_cis_real
            )

            h_list.append(h.unsqueeze(0))

        if not self.params.disable_db:
            # [n_layers, B, S, D] -> [B, n_layers, S, D]; only built when needed here.
            h_tensor = torch.cat(h_list, dim=0).permute(1, 0, 2, 3)
            # Detached so the database update never joins the main autograd graph.
            h_tensor_detached = h_tensor.detach()
            h_tensor_detached = h_tensor_detached.reshape(h_tensor_detached.shape[0], -1, self.params.dim)

            with torch.no_grad():
                shared_features = self.shared_downsample(h_tensor_detached)

                z_v_features = self.downsample_v_specific(shared_features)
                batch_z, seq_len, dim_z = z_v_features.shape

                z_v_flat = z_v_features.reshape(-1, dim_z)

                # Token prediction through the (tied) output head.
                token_logits = self.database_output(z_v_flat)

                token_indices_flat = torch.argmax(token_logits, dim=-1)
                token_indices = token_indices_flat.reshape(batch_z, -1)

                z_q = self.downsample_q_specific(shared_features)
                # NOTE(review): the write-back below is disabled, so this path currently only
                # has the side effects of q_to_k (step counter, cache refresh, usage stats).
                z_k = self.extract_db.q_to_k(z_q)
                # self.extract_db.updata_value(z_k, token_indices)

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output(self.norm(h)[:, slice_indices, :])
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))

        output = CausalLMOutputWithPast(
            logits=logits,
        )
        output.hidden_states = h

        output.aux_loss = aux_loss

        return output

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=1024, temperature=0.75, top_p=0.90,
                 stream=False, rp=1., pad_token_id=0, num_return_sequences=1, **args):
        """Sample continuations; returns a generator when ``stream`` is True, otherwise a
        right-padded [batch * num_return_sequences, length] tensor."""
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)

        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            for _ in range(num_return_sequences):
                out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, **args)
                tokens_list = [tokens[:, -1:] for tokens in out]
                # With no generated tokens the result is just the prompt (previously the
                # prompt was appended to itself).
                gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad[:, :0]
                full_sequence = torch.cat([non_pad, gen], dim=-1)
                generated.append(full_sequence)

        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat(
                [seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],
                dim=-1)
            for seq in generated
        ]
        output = torch.cat(generated, dim=0)
        res = output.view(input_ids.size(0) * num_return_sequences, -1)
        return res

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, **args):
        """Yield the generated suffix after each sampled token.

        NOTE(review): the loop bounds the TOTAL sequence length by ``max_new_tokens - 1``,
        not the number of new tokens — confirm this is intended.
        """
        start = input_ids.shape[1]
        while input_ids.shape[1] < max_new_tokens - 1:
            # The model keeps no KV cache, so the full sequence must be fed every step;
            # feeding only the last token would drop all prior context.
            out = self(input_ids, **args)
            logits = out.logits[:, -1, :]
            # Repetition penalty on tokens already present in the first row.
            logits[:, list(set(input_ids.tolist()[0]))] /= rp
            logits /= (temperature + 1e-9)
            if top_p is not None and top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break


class ExtractDB_DirectSemantic(nn.Module):
    """Database retrieval by direct semantic matching (no product keys).

    Each of the ``knowledge_num`` entries is a row of ``knowledge_length`` token ids in
    ``weight_down_embed``; its representation is the mean of its token embeddings, cached
    in ``knowledge_embeddings_cache``. Queries are matched by cosine similarity with a
    usage-based bias during training.
    """

    def __init__(self, params, tok_embeddings=None):
        super().__init__()
        self.batch_size = None
        self.dim = params.dim
        self.knowledge_num = params.knowledge_num
        self.knowledge_length = params.knowledge_length
        self.tok_embeddings = tok_embeddings
        self.num_experts_per_head_topk = 1

        # Call counter. It is incremented once per q_to_k call — in MiniMindLM.forward that
        # is once per layer plus once in the DB-update path — not once per optimizer step.
        self.current_step = 0
        # For the first `realtime_steps` calls the whole cache is recomputed every call.
        self.realtime_threshold = getattr(params, 'realtime_steps', 800)

        # Progressive refresh: ~1% of entries per call over a 100-call cycle. Rounded UP so
        # that one cycle always covers every entry (flooring skipped the table's tail).
        self.knowledge_update_rate = 0.01
        self.knowledge_per_step = max(1, math.ceil(self.knowledge_num * self.knowledge_update_rate))
        self.update_cycle = 100

        # Knowledge store: random token ids drawn below vocab_size (6400 if unspecified).
        self.register_buffer('weight_down_embed',
            torch.randint(low=0, high=getattr(params, 'vocab_size', 6400),
                          size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
        )

        # Embedding cache (allocated lazily) and a per-entry "cache is fresh" flag. The flag
        # is a non-persistent buffer so it follows the module's device (indexing it with
        # CUDA indices previously failed) and stays out of the state_dict.
        self.knowledge_embeddings_cache = None
        self.register_buffer('cache_update_mask', torch.zeros(self.knowledge_num, dtype=torch.bool),
                             persistent=False)

        # Normalized copy of the cache used for cosine similarity.
        self.normalized_knowledge_cache = None
        self.normalization_valid = False

        # Load-balancing statistics.
        self.register_buffer('usage_counts', torch.zeros(self.knowledge_num))
        self.register_buffer('total_queries', torch.tensor(0.0))
        self.momentum = getattr(params, 'db_momentum', 0.9)
        self.balance_strength = getattr(params, 'db_balance_strength', 0.1)

    def should_use_realtime_computation(self):
        """True while the call counter is below the real-time threshold."""
        return self.current_step < self.realtime_threshold

    def get_knowledge_indices_to_update(self):
        """Indices of the cache entries scheduled for refresh on this call (CPU tensor)."""
        if self.should_use_realtime_computation():
            return torch.arange(self.knowledge_num)

        # Round-robin slice after the real-time phase.
        cycle_position = self.current_step % self.update_cycle
        start_idx = (cycle_position * self.knowledge_per_step) % self.knowledge_num
        end_idx = min(start_idx + self.knowledge_per_step, self.knowledge_num)
        return torch.arange(start_idx, end_idx)

    def update_knowledge_embeddings(self, force_all=False):
        """Refresh (part of) the knowledge embedding cache.

        Everything is recomputed when forced, during the real-time phase, or when the cache
        is missing / lives on a different device than the module. Otherwise the scheduled
        round-robin slice plus any entries marked stale by ``updata_value`` are refreshed.
        """
        device = self.weight_down_embed.device
        cache = self.knowledge_embeddings_cache
        # A partial refresh of a missing cache would leave most rows as zeros.
        rebuild = cache is None or cache.device != device

        if force_all or rebuild or self.should_use_realtime_computation():
            indices_to_update = torch.arange(self.knowledge_num, device=device)
        else:
            scheduled = self.get_knowledge_indices_to_update().to(device)
            stale = (~self.cache_update_mask).nonzero(as_tuple=True)[0]
            indices_to_update = torch.unique(torch.cat([scheduled, stale]))

        if indices_to_update.numel() == 0:
            return

        with torch.no_grad():
            tokens_to_update = self.weight_down_embed[indices_to_update]  # [num_update, knowledge_length]
            flat_embeddings = self.tok_embeddings(tokens_to_update.view(-1))
            # Mean-pool each entry's token embeddings -> [num_update, dim].
            updated_embeddings = flat_embeddings.view(
                indices_to_update.numel(), self.knowledge_length, -1
            ).mean(dim=1)

            if rebuild:
                # Match tok_embeddings' output dtype so similarity matmuls line up.
                self.knowledge_embeddings_cache = torch.zeros(
                    self.knowledge_num, self.dim,
                    device=device,
                    dtype=updated_embeddings.dtype,
                )

            self.knowledge_embeddings_cache[indices_to_update] = updated_embeddings
            self.cache_update_mask[indices_to_update] = True

        self.normalization_valid = False

    def get_normalized_knowledge_embeddings(self):
        """Return the L2-normalized knowledge cache, recomputing it when invalidated."""
        if not self.normalization_valid or self.normalized_knowledge_cache is None:
            if self.knowledge_embeddings_cache is None:
                self.update_knowledge_embeddings(force_all=True)

            self.normalized_knowledge_cache = F.normalize(
                self.knowledge_embeddings_cache, dim=-1
            )
            self.normalization_valid = True

        return self.normalized_knowledge_cache

    def optimized_similarity_computation(self, query_features):
        """Cosine similarity [batch, knowledge_num] between queries and all entries."""
        normalized_query = F.normalize(query_features, dim=-1)
        normalized_knowledge = self.get_normalized_knowledge_embeddings()
        similarities = torch.mm(normalized_query, normalized_knowledge.t())
        return similarities

    def apply_load_balancing(self, similarities):
        """Add a usage-based bias (training only) that favours less-used entries."""
        if not self.training or self.total_queries == 0:
            return similarities

        # NOTE(review): usage_counts is an EMA while total_queries is a running total, so
        # these rates shrink over training — confirm the intended normalization.
        usage_rates = self.usage_counts / (self.total_queries + 1e-8)

        # Entries used less than the most-used one get a larger (less negative) bias.
        max_usage = usage_rates.max()
        balance_bias = self.balance_strength * (max_usage - usage_rates + 1e-8).log()

        balanced_similarities = similarities + balance_bias.unsqueeze(0)

        return balanced_similarities

    def update_usage_statistics(self, selected_indices):
        """Update the EMA usage counts and the cumulative query count (training only)."""
        if not self.training:
            return

        with torch.no_grad():
            batch_usage = torch.zeros(self.knowledge_num, device=selected_indices.device)
            unique_indices, counts = torch.unique(selected_indices, return_counts=True)
            batch_usage[unique_indices] = counts.float()

            self.usage_counts.copy_(
                self.momentum * self.usage_counts + (1 - self.momentum) * batch_usage
            )
            self.total_queries.copy_(self.total_queries + selected_indices.numel())

    def q_to_k(self, x):
        """Select one knowledge index per batch item for ``x`` unpacked as [batch, seq, dim]."""
        self.current_step += 1
        _batch_size, _seq_len, _dim = x.shape

        self.update_knowledge_embeddings()

        # Mean over the sequence gives one query vector per batch item.
        query_features = x.mean(dim=1)

        similarities = self.optimized_similarity_computation(query_features)

        balanced_similarities = self.apply_load_balancing(similarities)

        _, indices = balanced_similarities.topk(self.num_experts_per_head_topk, dim=-1)
        flat_indices = indices.view(-1)

        self.update_usage_statistics(flat_indices)

        return flat_indices

    def get_data(self, index):
        """Return the stored token ids for ``index`` (same interface as ExtractDB)."""
        return self.weight_down_embed[index]

    @torch.no_grad()
    def updata_value(self, k, v):
        """Overwrite database rows ``k`` with token ids ``v`` and mark their cache stale.

        Stale rows are recomputed by the next ``update_knowledge_embeddings`` call.
        """
        v_reshaped = v.view(v.size(0), -1)
        v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
        self.weight_down_embed[k] = v_reshaped

        self.cache_update_mask[k] = False
        self.normalization_valid = False