diff --git a/model/LMConfig.py b/model/LMConfig.py index 1312e97..9f913f6 100644 --- a/model/LMConfig.py +++ b/model/LMConfig.py @@ -1,5 +1,5 @@ from transformers import PretrainedConfig -from typing import List +from typing import List, Optional, Union class LMConfig(PretrainedConfig): @@ -12,17 +12,22 @@ class LMConfig(PretrainedConfig): n_heads: int = 32, n_kv_heads: int = 8, vocab_size: int = 6400, - hidden_dim: int = None, + hidden_dim: Optional[int] = None, multiple_of: int = 64, norm_eps: float = 1e-5, max_seq_len: int = 8192, - rope_theta: int = 1e6, + rope_theta: float = 1e6, dropout: float = 0.0, flash_attn: bool = True, #################################################### # DB related configurations #################################################### disable_db: bool = False, # 特殊模式:禁用数据库功能 + db_intelligent_balance: bool = True, # 是否启用智能负载均衡 + db_relevance_threshold: float = 0.7, # 相关性阈值(第一层过滤) + db_balance_strength: float = 0.3, # 平衡权重的基础值 + db_momentum: float = 0.9, # 使用频率统计的动量 + db_adaptive_weights: bool = True, # 是否启用动态权重调整 #################################################### # Here are the specific configurations of MOE # When use_moe is false, the following is invalid @@ -57,6 +62,11 @@ class LMConfig(PretrainedConfig): # DB related configurations #################################################### self.disable_db = disable_db # 设置是否禁用数据库 + self.db_intelligent_balance = db_intelligent_balance # 是否启用智能负载均衡 + self.db_relevance_threshold = db_relevance_threshold # 相关性阈值(第一层过滤) + self.db_balance_strength = db_balance_strength # 平衡权重的基础值 + self.db_momentum = db_momentum # 使用频率统计的动量 + self.db_adaptive_weights = db_adaptive_weights # 是否启用动态权重调整 #################################################### # Here are the specific configurations of MOE # When use_moe is false, the following is invalid diff --git a/model/model.py b/model/model.py index 814674d..e442c87 100644 --- a/model/model.py +++ b/model/model.py @@ -188,11 +188,6 @@ class Attention(nn.Module): # 应用旋转位置编码(使用实数版本) xq, xk = apply_rotary_emb_real(xq, xk, pos_cis) - # kv_cache实现 REMOVED - # if past_key_value is not None: - # xk = torch.cat([past_key_value[0], xk], dim=1) - # xv = torch.cat([past_key_value[1], xv], dim=1) - # past_kv = (xk, xv) if use_cache else None # 重复键值对 xq, xk, xv = ( @@ -440,66 +435,7 @@ class MiniMindBlock(nn.Module): self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config) - # 假设num_experts是已定义的总专家数量的平方根 - - - # 查询生成的参数 - - - # 创建查询生成模块 - # if weight_down_embed is not None: - # self.to_queries = nn.Sequential( - # nn.Linear(config.dim, self.dim_key * 2, bias=False), - # # nn.Unflatten(2, (2, self.n_heads, self.dim_key)) # 替代Rearrange - # ) - - # # 超参数 - # self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys - # self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数 - def forward(self, x, db_value, pos_cis): - # import pdb;pdb.set_trace() - # db_value = None - - # # 如果有weight_down_embed,使用Product Key机制 - # if self.weight_down_embed is not None: - # # 1. 生成queries - # batch_size, seq_len, dim = x.shape - - # # collapse sequence dimension by averaging - # x_flat = x.mean(dim=1) # [batch_size, dim] - # queries = self.to_queries(x_flat) # [batch_size, 2*dim_key] - # queries = queries.reshape(batch_size, 2, self.dim_key) # [batch_size, 2, dim_key] - # queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key] - - # # 2. 计算queries与keys的相似度 - # sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys) - - # # 3. 在两个子空间分别做top-k - # scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)] - # scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0] - # indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1] - - # # 4. 组合两个子空间的分数和索引 - # all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) - # all_scores = all_scores.view(*all_scores.shape[:-2], -1) - - # all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) - # all_indices = all_indices.view(*all_indices.shape[:-2], -1) - - # # 5. 最终top-k选择 - # scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1) - # indices = all_indices.gather(-1, pk_indices) - - # # 6. 从embedding中获取专家值 - - # # 从embedding中获取值 - # flat_indices = indices.view(-1) # 将索引展平为一维张量 - # db_values = self.weight_down_embed(flat_indices) - - # # 重塑回原始形状 - # db_value = db_values.view(batch_size, -1, dim) - # 注意力计算 h_attn = self.attention( @@ -518,7 +454,7 @@ class MiniMindBlock(nn.Module): return out class ExtractDB(nn.Module): - def __init__(self,params): + def __init__(self, params, tok_embeddings=None): # 修改专家数量和知识维度,确保能开方 super().__init__() self.batch_size = None @@ -529,12 +465,27 @@ class ExtractDB(nn.Module): self.head_dim = params.dim // params.n_heads self.knowledge_length = params.knowledge_length - # 使用register_buffer代替nn.Parameter,避免梯度问题 - # self.register_buffer('weight_down_embed', torch.randn(self.knowledge_num, self.knowledge_length) * 0.02) - self.register_buffer('weight_down_embed',torch.randint(low=0,high=6400, size=(self.knowledge_num, self.knowledge_length),dtype=torch.long)) - - + # 智能负载均衡相关参数 + self.enable_intelligent_balance = getattr(params, 'db_intelligent_balance', True) + self.relevance_threshold = getattr(params, 'db_relevance_threshold', 0.7) + self.base_balance_strength = getattr(params, 'db_balance_strength', 0.3) + self.momentum = getattr(params, 'db_momentum', 0.9) + self.adaptive_weights = getattr(params, 'db_adaptive_weights', True) + + # 动态权重调整参数 + self.current_relevance_weight = 0.8 # 开始时更重视相关性 + self.current_balance_weight = 0.2 + self.weight_update_frequency = 100 # 每100步调整一次权重 + self.step_counter = 0 + + # 使用频率统计 - 使用register_buffer以便在GPU/CPU间正确移动 + self.register_buffer('usage_counts', torch.zeros(self.knowledge_num)) + self.register_buffer('total_queries', torch.tensor(0.0)) + # 知识库存储 - 使用register_buffer因为这是整数索引,不需要梯度 + self.register_buffer('weight_down_embed', + torch.randint(low=0, high=6400, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long) + ) self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0 self.product_key_topk = min(16, self.num_keys) @@ -543,7 +494,153 @@ class ExtractDB(nn.Module): self.to_queries = nn.Sequential( nn.Linear(params.dim, self.dim_key * 2, bias=False), ) - + + # 存储token embeddings的引用,用于计算真实的语义相关性 + self.tok_embeddings = tok_embeddings + + def update_usage_statistics(self, selected_indices): + """更新数据库条目的使用统计""" + if not self.training or not self.enable_intelligent_balance: + return + + with torch.no_grad(): + # 统计当前batch中每个条目的使用次数 + batch_usage = torch.zeros(self.knowledge_num, device=selected_indices.device) + unique_indices, counts = torch.unique(selected_indices, return_counts=True) + batch_usage[unique_indices] = counts.float() + + # 使用简单的tensor操作来更新统计 + current_usage = self.usage_counts.clone() + current_total = self.total_queries.clone() + + new_usage = self.momentum * current_usage + (1 - self.momentum) * batch_usage + new_total = current_total + selected_indices.numel() + + # 直接替换buffer内容 + self.usage_counts.copy_(new_usage) + self.total_queries.copy_(new_total) + + def update_dynamic_weights(self): + """动态调整相关性和平衡权重""" + if not self.adaptive_weights or not self.training: + return + + self.step_counter += 1 + + # 每隔一定步数调整权重 + if self.step_counter % self.weight_update_frequency == 0: + with torch.no_grad(): + if self.total_queries > 0: + # 计算使用分布的方差(不平衡程度) + usage_rates = self.usage_counts / self.total_queries + usage_variance = usage_rates.var().item() + + # 根据不平衡程度调整权重 + if usage_variance > 0.01: # 高度不平衡 + self.current_relevance_weight = max(0.5, self.current_relevance_weight - 0.1) + self.current_balance_weight = min(0.5, self.current_balance_weight + 0.1) + elif usage_variance < 0.001: # 已经很平衡 + self.current_relevance_weight = min(0.9, self.current_relevance_weight + 0.1) + self.current_balance_weight = max(0.1, self.current_balance_weight - 0.1) + + # 确保权重和为1 + total_weight = self.current_relevance_weight + self.current_balance_weight + self.current_relevance_weight /= total_weight + self.current_balance_weight /= total_weight + + def intelligent_selection(self, query, all_scores, all_indices): + """智能分层选择策略""" + if not self.enable_intelligent_balance or not self.training: + # 如果禁用智能平衡或在推理模式,使用原始分数 + return all_scores + + with torch.no_grad(): + batch_size = all_scores.size(0) + device = all_scores.device + dtype = all_scores.dtype + + # 更新动态权重 + self.update_dynamic_weights() + + # 对每个batch进行分层选择 + enhanced_scores = all_scores.clone() + + # 预先计算query的特征表示(取平均) + query_features = query.mean(dim=1) # [batch_size, dim] + + for batch_idx in range(batch_size): + indices = all_indices[batch_idx] # 当前batch的候选条目 + scores = all_scores[batch_idx] # 当前batch的原始分数 + + # 第一层:基于value内容计算真正的相关性 + # 1. 获取候选条目的value tokens(只获取当前需要的) + candidate_tokens = self.weight_down_embed[indices] # [num_candidates, knowledge_length] + + # 2. 高效计算:直接使用embedding层,避免中间变量 + # 将tokens reshape为一维,批量计算embeddings,然后reshape回来 + num_candidates, knowledge_length = candidate_tokens.shape + flat_tokens = candidate_tokens.view(-1) # [num_candidates * knowledge_length] + + # 批量计算所有token的embeddings + flat_embeddings = self.tok_embeddings(flat_tokens) # [num_candidates * knowledge_length, dim] + + # Reshape回原始形状并进行mean pooling + candidate_embeddings = flat_embeddings.view(num_candidates, knowledge_length, -1) + candidate_features = candidate_embeddings.mean(dim=1) # [num_candidates, dim] + + # 3. 计算query与候选条目的相似度 + query_feature = query_features[batch_idx] # [dim] + similarity_scores = F.cosine_similarity( + query_feature.unsqueeze(0), candidate_features, dim=1 + ) # [num_candidates] + + # 4. 将相似度分数归一化为概率分布 + relevance_probs = F.softmax(similarity_scores.float(), dim=-1).to(dtype) + + # 相关性阈值:选择概率大于某个阈值的候选项 + # 动态阈值:如果所有候选项的相似度都很平均,降低阈值 + mean_prob = relevance_probs.mean() + adaptive_threshold = max(self.relevance_threshold * mean_prob, mean_prob * 0.5) + relevant_mask = relevance_probs > adaptive_threshold + + if relevant_mask.sum() == 0: + # 如果没有足够相关的,选择相似度最高的top-k + top_k = min(5, len(indices)) + _, top_indices = similarity_scores.topk(top_k) + relevant_mask = torch.zeros_like(relevant_mask, dtype=torch.bool) + relevant_mask[top_indices] = True + + # 第二层:在相关候选中应用平衡策略 + if relevant_mask.sum() > 1: + # 计算平衡分数(使用频率低的分数高) + relevant_indices = indices[relevant_mask] + relevant_usage = self.usage_counts[relevant_indices] + + # 平衡分数:使用频率的倒数(加1避免除零) + balance_scores = 1.0 / (relevant_usage + 1.0) + balance_scores = balance_scores / (balance_scores.sum() + 1e-8) + + # 相关性分数(基于真实的语义相似度) + relevant_rel_scores = relevance_probs[relevant_mask] + relevant_rel_scores = relevant_rel_scores / (relevant_rel_scores.sum() + 1e-8) + + # 综合分数:动态权重组合 + combined_scores = (self.current_relevance_weight * relevant_rel_scores + + self.current_balance_weight * balance_scores.to(dtype)) + + # 确保数据类型一致 + adjustment = self.base_balance_strength * combined_scores.to(dtype) + + # 将综合分数应用到enhanced_scores + enhanced_scores[batch_idx, relevant_mask] = ( + scores[relevant_mask] + adjustment + ) + + # 清理中间变量,释放显存 + del candidate_tokens, flat_tokens, flat_embeddings, candidate_embeddings, candidate_features + + return enhanced_scores.to(device) + def q_to_k(self,x): # 1. 生成queries self.batch_size, seq_len, dim = x.shape @@ -570,10 +667,17 @@ class ExtractDB(nn.Module): all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) all_indices = all_indices.view(*all_indices.shape[:-2], -1) - # 5. 最终top-k选择 - scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1) + # 5. 应用智能分层选择策略 + enhanced_scores = self.intelligent_selection(x, all_scores, all_indices) + + # 6. 基于增强后的分数进行最终top-k选择 + scores, pk_indices = enhanced_scores.topk(self.num_experts_per_head_topk, dim=-1) indices = all_indices.gather(-1, pk_indices) flat_indices = indices.view(-1) + + # 7. 更新使用统计 + self.update_usage_statistics(flat_indices) + return flat_indices def get_data(self, index): @@ -599,10 +703,13 @@ class MiniMindLM(PreTrainedModel): self.params = params or LMConfig() super().__init__(self.params) self.vocab_size, self.n_layers = params.vocab_size, params.n_layers + + # 先创建token embeddings self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) self.dropout = nn.Dropout(params.dropout) - # 移除旧的weight_down_embed声明 - self.extract_db = ExtractDB(self.params) + + # 创建ExtractDB,传入tok_embeddings引用 + self.extract_db = ExtractDB(self.params, self.tok_embeddings) # 将self.weight_down_embed传递给每个MiniMindBlock self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)]) @@ -652,20 +759,13 @@ class MiniMindLM(PreTrainedModel): h_list = [] for l, layer in enumerate(self.layers): - # 禁用数据库模式,使用固定值替代数据库查询 - if self.params.disable_db: - # 创建一个形状为[batch_size, n_layers, dim]的tensor,所有元素值为1e-4 - batch_size = h.size(0) - db_value = torch.full((batch_size, self.n_layers, self.params.dim), 1e-4, - dtype=h.dtype, device=h.device) - else: - # 正常模式,使用数据库查询 - # import pdb;pdb.set_trace() - index = self.extract_db.q_to_k(h) + # 正常模式,使用数据库查询 + # import pdb;pdb.set_trace() + index = self.extract_db.q_to_k(h) - token_idx = self.extract_db.get_data(index) #这里是index + token_idx = self.extract_db.get_data(index) #这里是index - db_value =self.tok_embeddings(token_idx) + db_value =self.tok_embeddings(token_idx) h = layer( h, db_value, pos_cis_real