From 7d726c5b20c73d480ae8d2f8223f0aa48a8de11f Mon Sep 17 00:00:00 2001 From: iomgaa Date: Fri, 30 May 2025 14:57:32 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E7=9B=B4=E6=8E=A5=E8=AF=AD?= =?UTF-8?q?=E4=B9=89=E5=8C=B9=E9=85=8D=E5=8A=9F=E8=83=BD=EF=BC=8C=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E6=95=B0=E6=8D=AE=E5=BA=93=E6=A3=80=E7=B4=A2=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=EF=BC=8C=E6=94=AF=E6=8C=81=E5=AE=9E=E6=97=B6=E8=AE=A1?= =?UTF-8?q?=E7=AE=97=E5=92=8C=E5=8A=A8=E6=80=81=E8=B4=9F=E8=BD=BD=E5=9D=87?= =?UTF-8?q?=E8=A1=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/LMConfig.py | 4 + model/model.py | 346 ++++++++++++++++++++++++++++++++++++---------- 2 files changed, 277 insertions(+), 73 deletions(-) diff --git a/model/LMConfig.py b/model/LMConfig.py index 9f913f6..e58ed37 100644 --- a/model/LMConfig.py +++ b/model/LMConfig.py @@ -23,6 +23,8 @@ class LMConfig(PretrainedConfig): # DB related configurations #################################################### disable_db: bool = False, # 特殊模式:禁用数据库功能 + use_direct_semantic: bool = False, # 是否使用直接语义匹配(替代Product Key) + realtime_steps: int = 2000, # 前多少步使用实时计算(后续使用渐进式缓存) db_intelligent_balance: bool = True, # 是否启用智能负载均衡 db_relevance_threshold: float = 0.7, # 相关性阈值(第一层过滤) db_balance_strength: float = 0.3, # 平衡权重的基础值 @@ -62,6 +64,8 @@ class LMConfig(PretrainedConfig): # DB related configurations #################################################### self.disable_db = disable_db # 设置是否禁用数据库 + self.use_direct_semantic = use_direct_semantic # 是否使用直接语义匹配(替代Product Key) + self.realtime_steps = realtime_steps # 前多少步使用实时计算(后续使用渐进式缓存) self.db_intelligent_balance = db_intelligent_balance # 是否启用智能负载均衡 self.db_relevance_threshold = db_relevance_threshold # 相关性阈值(第一层过滤) self.db_balance_strength = db_balance_strength # 平衡权重的基础值 diff --git a/model/model.py b/model/model.py index e442c87..9079ab1 100644 --- a/model/model.py +++ b/model/model.py @@ -551,7 +551,6 @@ class ExtractDB(nn.Module): def intelligent_selection(self, query, all_scores, all_indices): """智能分层选择策略""" if not self.enable_intelligent_balance or not self.training: - # 如果禁用智能平衡或在推理模式,使用原始分数 return all_scores with torch.no_grad(): @@ -564,121 +563,115 @@ class ExtractDB(nn.Module): # 对每个batch进行分层选择 enhanced_scores = all_scores.clone() - - # 预先计算query的特征表示(取平均) query_features = query.mean(dim=1) # [batch_size, dim] + # 预先计算所有候选条目的嵌入(批量优化) + all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0) + unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True) + + # 批量计算唯一候选条目的嵌入 + candidate_tokens = self.weight_down_embed[unique_indices] + flat_tokens = candidate_tokens.view(-1) + flat_embeddings = self.tok_embeddings(flat_tokens) + unique_candidate_features = flat_embeddings.view( + len(unique_indices), self.knowledge_length, -1 + ).mean(dim=1) # [num_unique_candidates, dim] + + # 归一化候选特征(优化相似度计算) + normalized_candidates = F.normalize(unique_candidate_features, dim=-1) + normalized_queries = F.normalize(query_features, dim=-1) + for batch_idx in range(batch_size): - indices = all_indices[batch_idx] # 当前batch的候选条目 - scores = all_scores[batch_idx] # 当前batch的原始分数 + indices = all_indices[batch_idx] + scores = all_scores[batch_idx] - # 第一层:基于value内容计算真正的相关性 - # 1. 获取候选条目的value tokens(只获取当前需要的) - candidate_tokens = self.weight_down_embed[indices] # [num_candidates, knowledge_length] + # 获取当前batch候选条目对应的特征索引 + start_idx = batch_idx * len(indices) + end_idx = start_idx + len(indices) + batch_inverse_indices = inverse_indices[start_idx:end_idx] - # 2. 高效计算:直接使用embedding层,避免中间变量 - # 将tokens reshape为一维,批量计算embeddings,然后reshape回来 - num_candidates, knowledge_length = candidate_tokens.shape - flat_tokens = candidate_tokens.view(-1) # [num_candidates * knowledge_length] + # 使用预计算的归一化特征进行优化相似度计算 + batch_candidate_features = normalized_candidates[batch_inverse_indices] + query_feature = normalized_queries[batch_idx] - # 批量计算所有token的embeddings - flat_embeddings = self.tok_embeddings(flat_tokens) # [num_candidates * knowledge_length, dim] + # 使用矩阵乘法计算余弦相似度 + similarity_scores = torch.mv(batch_candidate_features, query_feature) - # Reshape回原始形状并进行mean pooling - candidate_embeddings = flat_embeddings.view(num_candidates, knowledge_length, -1) - candidate_features = candidate_embeddings.mean(dim=1) # [num_candidates, dim] - - # 3. 计算query与候选条目的相似度 - query_feature = query_features[batch_idx] # [dim] - similarity_scores = F.cosine_similarity( - query_feature.unsqueeze(0), candidate_features, dim=1 - ) # [num_candidates] - - # 4. 将相似度分数归一化为概率分布 + # 应用相关性阈值过滤 relevance_probs = F.softmax(similarity_scores.float(), dim=-1).to(dtype) - - # 相关性阈值:选择概率大于某个阈值的候选项 - # 动态阈值:如果所有候选项的相似度都很平均,降低阈值 mean_prob = relevance_probs.mean() adaptive_threshold = max(self.relevance_threshold * mean_prob, mean_prob * 0.5) relevant_mask = relevance_probs > adaptive_threshold if relevant_mask.sum() == 0: - # 如果没有足够相关的,选择相似度最高的top-k + # 如果没有相关候选,选择相似度最高的 top_k = min(5, len(indices)) _, top_indices = similarity_scores.topk(top_k) relevant_mask = torch.zeros_like(relevant_mask, dtype=torch.bool) relevant_mask[top_indices] = True - # 第二层:在相关候选中应用平衡策略 + # 在相关候选中应用负载均衡 if relevant_mask.sum() > 1: - # 计算平衡分数(使用频率低的分数高) relevant_indices = indices[relevant_mask] relevant_usage = self.usage_counts[relevant_indices] - # 平衡分数:使用频率的倒数(加1避免除零) + # 计算平衡分数 balance_scores = 1.0 / (relevant_usage + 1.0) balance_scores = balance_scores / (balance_scores.sum() + 1e-8) - # 相关性分数(基于真实的语义相似度) + # 相关性分数 relevant_rel_scores = relevance_probs[relevant_mask] relevant_rel_scores = relevant_rel_scores / (relevant_rel_scores.sum() + 1e-8) - # 综合分数:动态权重组合 + # 综合分数 combined_scores = (self.current_relevance_weight * relevant_rel_scores + self.current_balance_weight * balance_scores.to(dtype)) - # 确保数据类型一致 + # 应用调整 adjustment = self.base_balance_strength * combined_scores.to(dtype) - - # 将综合分数应用到enhanced_scores - enhanced_scores[batch_idx, relevant_mask] = ( - scores[relevant_mask] + adjustment - ) - - # 清理中间变量,释放显存 - del candidate_tokens, flat_tokens, flat_embeddings, candidate_embeddings, candidate_features + enhanced_scores[batch_idx, relevant_mask] = scores[relevant_mask] + adjustment return enhanced_scores.to(device) def q_to_k(self,x): # 1. 生成queries - self.batch_size, seq_len, dim = x.shape + self.batch_size, seq_len, dim = x.shape - # collapse sequence dimension by averaging - x_flat = x.mean(dim=1) # [batch_size, dim] + # collapse sequence dimension by averaging + x_flat = x.mean(dim=1) # [batch_size, dim] - queries = self.to_queries(x_flat) # [batch_size, 2*dim_key] - queries = queries.reshape(self.batch_size, 2, self.dim_key) # [batch_size, 2, dim_key] - queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key] + queries = self.to_queries(x_flat) # [batch_size, 2*dim_key] + queries = queries.reshape(self.batch_size, 2, self.dim_key) # [batch_size, 2, dim_key] + queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key] - # 2. 计算queries与keys的相似度 - sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys) + # 2. 计算queries与keys的相似度 + sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys) - # 3. 在两个子空间分别做top-k - scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)] - scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0] - indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1] + # 3. 在两个子空间分别做top-k + scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)] + scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0] + indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1] - # 4. 组合两个子空间的分数和索引 - all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) - all_scores = all_scores.view(*all_scores.shape[:-2], -1) + # 4. 组合两个子空间的分数和索引 + all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2) + all_scores = all_scores.view(*all_scores.shape[:-2], -1) - all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) - all_indices = all_indices.view(*all_indices.shape[:-2], -1) + all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2) + all_indices = all_indices.view(*all_indices.shape[:-2], -1) - # 5. 应用智能分层选择策略 - enhanced_scores = self.intelligent_selection(x, all_scores, all_indices) + # 5. 应用智能分层选择策略 + enhanced_scores = self.intelligent_selection(x, all_scores, all_indices) - # 6. 基于增强后的分数进行最终top-k选择 - scores, pk_indices = enhanced_scores.topk(self.num_experts_per_head_topk, dim=-1) - indices = all_indices.gather(-1, pk_indices) - flat_indices = indices.view(-1) - - # 7. 更新使用统计 - self.update_usage_statistics(flat_indices) - - return flat_indices + # 6. 基于增强后的分数进行最终top-k选择 + scores, pk_indices = enhanced_scores.topk(self.num_experts_per_head_topk, dim=-1) + indices = all_indices.gather(-1, pk_indices) + + flat_indices = indices.view(-1) + + # 7. 更新使用统计 + self.update_usage_statistics(flat_indices) + + return flat_indices def get_data(self, index): # 直接从GPU获取embedding @@ -708,8 +701,13 @@ class MiniMindLM(PreTrainedModel): self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim) self.dropout = nn.Dropout(params.dropout) - # 创建ExtractDB,传入tok_embeddings引用 - self.extract_db = ExtractDB(self.params, self.tok_embeddings) + # 根据配置选择ExtractDB版本 + # use_direct_semantic = getattr(params, 'use_direct_semantic', False) + # if use_direct_semantic: + # self.extract_db = ExtractDB_DirectSemantic(self.params, self.tok_embeddings) + # else: + # self.extract_db = ExtractDB(self.params, self.tok_embeddings) + self.extract_db = ExtractDB_DirectSemantic(self.params, self.tok_embeddings) # 将self.weight_down_embed传递给每个MiniMindBlock self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)]) @@ -878,3 +876,205 @@ class MiniMindLM(PreTrainedModel): yield input_ids[:, start:] if input_ids_next.item() == eos_token_id: break + +class ExtractDB_DirectSemantic(nn.Module): + """直接语义匹配的数据库检索模块,完全移除Product Key""" + def __init__(self, params, tok_embeddings=None): + super().__init__() + self.batch_size = None + self.dim = params.dim + self.knowledge_num = params.knowledge_num + self.knowledge_length = params.knowledge_length + self.tok_embeddings = tok_embeddings + self.num_experts_per_head_topk = 1 + + # 训练步数管理 + self.current_step = 0 + self.realtime_threshold = getattr(params, 'realtime_steps', 800) # 前800步实时计算 + + # 渐进式缓存策略参数 + self.knowledge_update_rate = 0.01 # 每步更新1%的知识 + self.knowledge_per_step = max(1, int(self.knowledge_num * self.knowledge_update_rate)) + self.update_cycle = 100 # 100步循环 + + # 知识库存储 + self.register_buffer('weight_down_embed', + torch.randint(low=0, high=6400, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long) + ) + + # 嵌入缓存 + self.knowledge_embeddings_cache = None + self.cache_update_mask = torch.zeros(self.knowledge_num, dtype=torch.bool) # 跟踪哪些已更新 + + # 归一化缓存(用于优化相似度计算) + self.normalized_knowledge_cache = None + self.normalization_valid = False + + # 负载均衡组件 + self.register_buffer('usage_counts', torch.zeros(self.knowledge_num)) + self.register_buffer('total_queries', torch.tensor(0.0)) + self.momentum = getattr(params, 'db_momentum', 0.9) + self.balance_strength = getattr(params, 'db_balance_strength', 0.1) + + def should_use_realtime_computation(self): + """判断是否应该使用实时计算""" + return self.current_step < self.realtime_threshold + + def get_knowledge_indices_to_update(self): + """获取本步需要更新的知识条目索引""" + if self.should_use_realtime_computation(): + # 前800步:全部实时计算 + return torch.arange(self.knowledge_num) + + # 后续步数:循环更新策略 + cycle_position = self.current_step % self.update_cycle + start_idx = (cycle_position * self.knowledge_per_step) % self.knowledge_num + end_idx = min(start_idx + self.knowledge_per_step, self.knowledge_num) + + return torch.arange(start_idx, end_idx) + + def update_knowledge_embeddings(self, force_all=False): + """智能更新知识嵌入缓存""" + if force_all or self.should_use_realtime_computation(): + # 全量更新 + indices_to_update = torch.arange(self.knowledge_num) + else: + # 渐进式更新 + indices_to_update = self.get_knowledge_indices_to_update() + + if len(indices_to_update) == 0: + return + + # 初始化缓存 + if self.knowledge_embeddings_cache is None: + # 获取tok_embeddings的dtype,确保类型一致 + dummy_input = torch.zeros(1, dtype=torch.long, device=self.weight_down_embed.device) + dummy_embedding = self.tok_embeddings(dummy_input) + embedding_dtype = dummy_embedding.dtype + + self.knowledge_embeddings_cache = torch.zeros( + self.knowledge_num, self.dim, + device=self.weight_down_embed.device, + dtype=embedding_dtype # 使用与tok_embeddings相同的dtype + ) + + with torch.no_grad(): + # 只更新指定的知识条目 + tokens_to_update = self.weight_down_embed[indices_to_update] # [num_update, knowledge_length] + flat_tokens = tokens_to_update.view(-1) # [num_update * knowledge_length] + + # 批量计算嵌入 + flat_embeddings = self.tok_embeddings(flat_tokens) # [num_update * knowledge_length, dim] + + # 重塑并平均池化 + updated_embeddings = flat_embeddings.view( + len(indices_to_update), self.knowledge_length, -1 + ).mean(dim=1) # [num_update, dim] + + # 更新缓存 - 现在类型应该匹配了 + self.knowledge_embeddings_cache[indices_to_update] = updated_embeddings + self.cache_update_mask[indices_to_update] = True + + # 使归一化缓存失效 + self.normalization_valid = False + + def get_normalized_knowledge_embeddings(self): + """获取归一化的知识嵌入(用于优化相似度计算)""" + if not self.normalization_valid or self.normalized_knowledge_cache is None: + if self.knowledge_embeddings_cache is None: + self.update_knowledge_embeddings(force_all=True) + + self.normalized_knowledge_cache = F.normalize( + self.knowledge_embeddings_cache, dim=-1 + ) + self.normalization_valid = True + + return self.normalized_knowledge_cache + + def optimized_similarity_computation(self, query_features): + """优化的相似度计算""" + # 归一化查询特征 + normalized_query = F.normalize(query_features, dim=-1) # [batch_size, dim] + + # 获取归一化的知识嵌入 + normalized_knowledge = self.get_normalized_knowledge_embeddings() # [knowledge_num, dim] + + # 使用矩阵乘法计算余弦相似度 + similarities = torch.mm(normalized_query, normalized_knowledge.t()) # [batch_size, knowledge_num] + + return similarities + + def apply_load_balancing(self, similarities): + """应用负载均衡策略""" + if not self.training or self.total_queries == 0: + return similarities + + # 计算使用频率 + usage_rates = self.usage_counts / (self.total_queries + 1e-8) + + # 创建平衡偏置(低频率条目获得正偏置) + max_usage = usage_rates.max() + balance_bias = self.balance_strength * (max_usage - usage_rates + 1e-8).log() + + # 应用偏置 + balanced_similarities = similarities + balance_bias.unsqueeze(0) + + return balanced_similarities + + def update_usage_statistics(self, selected_indices): + """更新使用统计""" + if not self.training: + return + + with torch.no_grad(): + # 统计当前batch中每个条目的使用次数 + batch_usage = torch.zeros(self.knowledge_num, device=selected_indices.device) + unique_indices, counts = torch.unique(selected_indices, return_counts=True) + batch_usage[unique_indices] = counts.float() + + # 更新统计 + self.usage_counts.copy_( + self.momentum * self.usage_counts + (1 - self.momentum) * batch_usage + ) + self.total_queries.copy_(self.total_queries + selected_indices.numel()) + + def q_to_k(self, x): + """直接语义检索的主方法""" + self.current_step += 1 + batch_size, seq_len, dim = x.shape + + # 智能更新知识嵌入缓存 + self.update_knowledge_embeddings() + + # 计算查询特征(序列平均) + query_features = x.mean(dim=1) # [batch_size, dim] + + # 优化的相似度计算 + similarities = self.optimized_similarity_computation(query_features) + + # 应用负载均衡 + balanced_similarities = self.apply_load_balancing(similarities) + + # 选择top-k + _, indices = balanced_similarities.topk(self.num_experts_per_head_topk, dim=-1) + flat_indices = indices.view(-1) + + # 更新使用统计 + self.update_usage_statistics(flat_indices) + + return flat_indices + + def get_data(self, index): + """获取数据,与原版本兼容""" + return self.weight_down_embed[index] + + @torch.no_grad() + def updata_value(self, k, v): + """更新数据,与原版本兼容""" + v_reshaped = v.view(v.size(0), -1) + v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype) + self.weight_down_embed[k] = v_reshaped + + # 标记相关缓存需要更新 + self.cache_update_mask[k] = False + self.normalization_valid = False