添加直接语义匹配功能,优化数据库检索模块,支持实时计算和动态负载均衡

This commit is contained in:
iomgaa 2025-05-30 14:57:32 +08:00
parent 0b53e1b951
commit 7d726c5b20
2 changed files with 277 additions and 73 deletions

View File

@ -23,6 +23,8 @@ class LMConfig(PretrainedConfig):
# DB related configurations
####################################################
disable_db: bool = False, # 特殊模式:禁用数据库功能
use_direct_semantic: bool = False, # 是否使用直接语义匹配替代Product Key
realtime_steps: int = 2000, # 前多少步使用实时计算(后续使用渐进式缓存)
db_intelligent_balance: bool = True, # 是否启用智能负载均衡
db_relevance_threshold: float = 0.7, # 相关性阈值(第一层过滤)
db_balance_strength: float = 0.3, # 平衡权重的基础值
@ -62,6 +64,8 @@ class LMConfig(PretrainedConfig):
# DB related configurations
####################################################
self.disable_db = disable_db # 设置是否禁用数据库
self.use_direct_semantic = use_direct_semantic # 是否使用直接语义匹配替代Product Key
self.realtime_steps = realtime_steps # 前多少步使用实时计算(后续使用渐进式缓存)
self.db_intelligent_balance = db_intelligent_balance # 是否启用智能负载均衡
self.db_relevance_threshold = db_relevance_threshold # 相关性阈值(第一层过滤)
self.db_balance_strength = db_balance_strength # 平衡权重的基础值

View File

@ -551,7 +551,6 @@ class ExtractDB(nn.Module):
def intelligent_selection(self, query, all_scores, all_indices):
"""智能分层选择策略"""
if not self.enable_intelligent_balance or not self.training:
# 如果禁用智能平衡或在推理模式,使用原始分数
return all_scores
with torch.no_grad():
@ -564,121 +563,115 @@ class ExtractDB(nn.Module):
# 对每个batch进行分层选择
enhanced_scores = all_scores.clone()
# 预先计算query的特征表示取平均
query_features = query.mean(dim=1) # [batch_size, dim]
# 预先计算所有候选条目的嵌入(批量优化)
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
# 批量计算唯一候选条目的嵌入
candidate_tokens = self.weight_down_embed[unique_indices]
flat_tokens = candidate_tokens.view(-1)
flat_embeddings = self.tok_embeddings(flat_tokens)
unique_candidate_features = flat_embeddings.view(
len(unique_indices), self.knowledge_length, -1
).mean(dim=1) # [num_unique_candidates, dim]
# 归一化候选特征(优化相似度计算)
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
normalized_queries = F.normalize(query_features, dim=-1)
for batch_idx in range(batch_size):
indices = all_indices[batch_idx] # 当前batch的候选条目
scores = all_scores[batch_idx] # 当前batch的原始分数
indices = all_indices[batch_idx]
scores = all_scores[batch_idx]
# 第一层基于value内容计算真正的相关性
# 1. 获取候选条目的value tokens只获取当前需要的
candidate_tokens = self.weight_down_embed[indices] # [num_candidates, knowledge_length]
# 获取当前batch候选条目对应的特征索引
start_idx = batch_idx * len(indices)
end_idx = start_idx + len(indices)
batch_inverse_indices = inverse_indices[start_idx:end_idx]
# 2. 高效计算直接使用embedding层避免中间变量
# 将tokens reshape为一维批量计算embeddings然后reshape回来
num_candidates, knowledge_length = candidate_tokens.shape
flat_tokens = candidate_tokens.view(-1) # [num_candidates * knowledge_length]
# 使用预计算的归一化特征进行优化相似度计算
batch_candidate_features = normalized_candidates[batch_inverse_indices]
query_feature = normalized_queries[batch_idx]
# 批量计算所有token的embeddings
flat_embeddings = self.tok_embeddings(flat_tokens) # [num_candidates * knowledge_length, dim]
# 使用矩阵乘法计算余弦相似度
similarity_scores = torch.mv(batch_candidate_features, query_feature)
# Reshape回原始形状并进行mean pooling
candidate_embeddings = flat_embeddings.view(num_candidates, knowledge_length, -1)
candidate_features = candidate_embeddings.mean(dim=1) # [num_candidates, dim]
# 3. 计算query与候选条目的相似度
query_feature = query_features[batch_idx] # [dim]
similarity_scores = F.cosine_similarity(
query_feature.unsqueeze(0), candidate_features, dim=1
) # [num_candidates]
# 4. 将相似度分数归一化为概率分布
# 应用相关性阈值过滤
relevance_probs = F.softmax(similarity_scores.float(), dim=-1).to(dtype)
# 相关性阈值:选择概率大于某个阈值的候选项
# 动态阈值:如果所有候选项的相似度都很平均,降低阈值
mean_prob = relevance_probs.mean()
adaptive_threshold = max(self.relevance_threshold * mean_prob, mean_prob * 0.5)
relevant_mask = relevance_probs > adaptive_threshold
if relevant_mask.sum() == 0:
# 如果没有足够相关的选择相似度最高的top-k
# 如果没有相关候选,选择相似度最高的
top_k = min(5, len(indices))
_, top_indices = similarity_scores.topk(top_k)
relevant_mask = torch.zeros_like(relevant_mask, dtype=torch.bool)
relevant_mask[top_indices] = True
# 第二层:在相关候选中应用平衡策略
# 在相关候选中应用负载均衡
if relevant_mask.sum() > 1:
# 计算平衡分数(使用频率低的分数高)
relevant_indices = indices[relevant_mask]
relevant_usage = self.usage_counts[relevant_indices]
# 平衡分数使用频率的倒数加1避免除零
# 计算平衡分数
balance_scores = 1.0 / (relevant_usage + 1.0)
balance_scores = balance_scores / (balance_scores.sum() + 1e-8)
# 相关性分数(基于真实的语义相似度)
# 相关性分数
relevant_rel_scores = relevance_probs[relevant_mask]
relevant_rel_scores = relevant_rel_scores / (relevant_rel_scores.sum() + 1e-8)
# 综合分数:动态权重组合
# 综合分数
combined_scores = (self.current_relevance_weight * relevant_rel_scores +
self.current_balance_weight * balance_scores.to(dtype))
# 确保数据类型一致
# 应用调整
adjustment = self.base_balance_strength * combined_scores.to(dtype)
# 将综合分数应用到enhanced_scores
enhanced_scores[batch_idx, relevant_mask] = (
scores[relevant_mask] + adjustment
)
# 清理中间变量,释放显存
del candidate_tokens, flat_tokens, flat_embeddings, candidate_embeddings, candidate_features
enhanced_scores[batch_idx, relevant_mask] = scores[relevant_mask] + adjustment
return enhanced_scores.to(device)
def q_to_k(self,x):
# 1. 生成queries
self.batch_size, seq_len, dim = x.shape
self.batch_size, seq_len, dim = x.shape
# collapse sequence dimension by averaging
x_flat = x.mean(dim=1) # [batch_size, dim]
# collapse sequence dimension by averaging
x_flat = x.mean(dim=1) # [batch_size, dim]
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
queries = queries.reshape(self.batch_size, 2, self.dim_key) # [batch_size, 2, dim_key]
queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key]
queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
queries = queries.reshape(self.batch_size, 2, self.dim_key) # [batch_size, 2, dim_key]
queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key]
# 2. 计算queries与keys的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 2. 计算queries与keys的相似度
sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
# 3. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 3. 在两个子空间分别做top-k
scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
# 4. 组合两个子空间的分数和索引
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
all_scores = all_scores.view(*all_scores.shape[:-2], -1)
# 4. 组合两个子空间的分数和索引
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
all_scores = all_scores.view(*all_scores.shape[:-2], -1)
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
# 5. 应用智能分层选择策略
enhanced_scores = self.intelligent_selection(x, all_scores, all_indices)
# 5. 应用智能分层选择策略
enhanced_scores = self.intelligent_selection(x, all_scores, all_indices)
# 6. 基于增强后的分数进行最终top-k选择
scores, pk_indices = enhanced_scores.topk(self.num_experts_per_head_topk, dim=-1)
indices = all_indices.gather(-1, pk_indices)
flat_indices = indices.view(-1)
# 6. 基于增强后的分数进行最终top-k选择
scores, pk_indices = enhanced_scores.topk(self.num_experts_per_head_topk, dim=-1)
indices = all_indices.gather(-1, pk_indices)
# 7. 更新使用统计
self.update_usage_statistics(flat_indices)
flat_indices = indices.view(-1)
return flat_indices
# 7. 更新使用统计
self.update_usage_statistics(flat_indices)
return flat_indices
def get_data(self, index):
# 直接从GPU获取embedding
@ -708,8 +701,13 @@ class MiniMindLM(PreTrainedModel):
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.dropout = nn.Dropout(params.dropout)
# 创建ExtractDB传入tok_embeddings引用
self.extract_db = ExtractDB(self.params, self.tok_embeddings)
# 根据配置选择ExtractDB版本
# use_direct_semantic = getattr(params, 'use_direct_semantic', False)
# if use_direct_semantic:
# self.extract_db = ExtractDB_DirectSemantic(self.params, self.tok_embeddings)
# else:
# self.extract_db = ExtractDB(self.params, self.tok_embeddings)
self.extract_db = ExtractDB_DirectSemantic(self.params, self.tok_embeddings)
# 将self.weight_down_embed传递给每个MiniMindBlock
self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
@ -878,3 +876,205 @@ class MiniMindLM(PreTrainedModel):
yield input_ids[:, start:]
if input_ids_next.item() == eos_token_id:
break
class ExtractDB_DirectSemantic(nn.Module):
"""直接语义匹配的数据库检索模块完全移除Product Key"""
def __init__(self, params, tok_embeddings=None):
super().__init__()
self.batch_size = None
self.dim = params.dim
self.knowledge_num = params.knowledge_num
self.knowledge_length = params.knowledge_length
self.tok_embeddings = tok_embeddings
self.num_experts_per_head_topk = 1
# 训练步数管理
self.current_step = 0
self.realtime_threshold = getattr(params, 'realtime_steps', 800) # 前800步实时计算
# 渐进式缓存策略参数
self.knowledge_update_rate = 0.01 # 每步更新1%的知识
self.knowledge_per_step = max(1, int(self.knowledge_num * self.knowledge_update_rate))
self.update_cycle = 100 # 100步循环
# 知识库存储
self.register_buffer('weight_down_embed',
torch.randint(low=0, high=6400, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
)
# 嵌入缓存
self.knowledge_embeddings_cache = None
self.cache_update_mask = torch.zeros(self.knowledge_num, dtype=torch.bool) # 跟踪哪些已更新
# 归一化缓存(用于优化相似度计算)
self.normalized_knowledge_cache = None
self.normalization_valid = False
# 负载均衡组件
self.register_buffer('usage_counts', torch.zeros(self.knowledge_num))
self.register_buffer('total_queries', torch.tensor(0.0))
self.momentum = getattr(params, 'db_momentum', 0.9)
self.balance_strength = getattr(params, 'db_balance_strength', 0.1)
def should_use_realtime_computation(self):
"""判断是否应该使用实时计算"""
return self.current_step < self.realtime_threshold
def get_knowledge_indices_to_update(self):
"""获取本步需要更新的知识条目索引"""
if self.should_use_realtime_computation():
# 前800步全部实时计算
return torch.arange(self.knowledge_num)
# 后续步数:循环更新策略
cycle_position = self.current_step % self.update_cycle
start_idx = (cycle_position * self.knowledge_per_step) % self.knowledge_num
end_idx = min(start_idx + self.knowledge_per_step, self.knowledge_num)
return torch.arange(start_idx, end_idx)
def update_knowledge_embeddings(self, force_all=False):
"""智能更新知识嵌入缓存"""
if force_all or self.should_use_realtime_computation():
# 全量更新
indices_to_update = torch.arange(self.knowledge_num)
else:
# 渐进式更新
indices_to_update = self.get_knowledge_indices_to_update()
if len(indices_to_update) == 0:
return
# 初始化缓存
if self.knowledge_embeddings_cache is None:
# 获取tok_embeddings的dtype确保类型一致
dummy_input = torch.zeros(1, dtype=torch.long, device=self.weight_down_embed.device)
dummy_embedding = self.tok_embeddings(dummy_input)
embedding_dtype = dummy_embedding.dtype
self.knowledge_embeddings_cache = torch.zeros(
self.knowledge_num, self.dim,
device=self.weight_down_embed.device,
dtype=embedding_dtype # 使用与tok_embeddings相同的dtype
)
with torch.no_grad():
# 只更新指定的知识条目
tokens_to_update = self.weight_down_embed[indices_to_update] # [num_update, knowledge_length]
flat_tokens = tokens_to_update.view(-1) # [num_update * knowledge_length]
# 批量计算嵌入
flat_embeddings = self.tok_embeddings(flat_tokens) # [num_update * knowledge_length, dim]
# 重塑并平均池化
updated_embeddings = flat_embeddings.view(
len(indices_to_update), self.knowledge_length, -1
).mean(dim=1) # [num_update, dim]
# 更新缓存 - 现在类型应该匹配了
self.knowledge_embeddings_cache[indices_to_update] = updated_embeddings
self.cache_update_mask[indices_to_update] = True
# 使归一化缓存失效
self.normalization_valid = False
def get_normalized_knowledge_embeddings(self):
"""获取归一化的知识嵌入(用于优化相似度计算)"""
if not self.normalization_valid or self.normalized_knowledge_cache is None:
if self.knowledge_embeddings_cache is None:
self.update_knowledge_embeddings(force_all=True)
self.normalized_knowledge_cache = F.normalize(
self.knowledge_embeddings_cache, dim=-1
)
self.normalization_valid = True
return self.normalized_knowledge_cache
def optimized_similarity_computation(self, query_features):
"""优化的相似度计算"""
# 归一化查询特征
normalized_query = F.normalize(query_features, dim=-1) # [batch_size, dim]
# 获取归一化的知识嵌入
normalized_knowledge = self.get_normalized_knowledge_embeddings() # [knowledge_num, dim]
# 使用矩阵乘法计算余弦相似度
similarities = torch.mm(normalized_query, normalized_knowledge.t()) # [batch_size, knowledge_num]
return similarities
def apply_load_balancing(self, similarities):
"""应用负载均衡策略"""
if not self.training or self.total_queries == 0:
return similarities
# 计算使用频率
usage_rates = self.usage_counts / (self.total_queries + 1e-8)
# 创建平衡偏置(低频率条目获得正偏置)
max_usage = usage_rates.max()
balance_bias = self.balance_strength * (max_usage - usage_rates + 1e-8).log()
# 应用偏置
balanced_similarities = similarities + balance_bias.unsqueeze(0)
return balanced_similarities
def update_usage_statistics(self, selected_indices):
"""更新使用统计"""
if not self.training:
return
with torch.no_grad():
# 统计当前batch中每个条目的使用次数
batch_usage = torch.zeros(self.knowledge_num, device=selected_indices.device)
unique_indices, counts = torch.unique(selected_indices, return_counts=True)
batch_usage[unique_indices] = counts.float()
# 更新统计
self.usage_counts.copy_(
self.momentum * self.usage_counts + (1 - self.momentum) * batch_usage
)
self.total_queries.copy_(self.total_queries + selected_indices.numel())
def q_to_k(self, x):
"""直接语义检索的主方法"""
self.current_step += 1
batch_size, seq_len, dim = x.shape
# 智能更新知识嵌入缓存
self.update_knowledge_embeddings()
# 计算查询特征(序列平均)
query_features = x.mean(dim=1) # [batch_size, dim]
# 优化的相似度计算
similarities = self.optimized_similarity_computation(query_features)
# 应用负载均衡
balanced_similarities = self.apply_load_balancing(similarities)
# 选择top-k
_, indices = balanced_similarities.topk(self.num_experts_per_head_topk, dim=-1)
flat_indices = indices.view(-1)
# 更新使用统计
self.update_usage_statistics(flat_indices)
return flat_indices
def get_data(self, index):
"""获取数据,与原版本兼容"""
return self.weight_down_embed[index]
@torch.no_grad()
def updata_value(self, k, v):
"""更新数据,与原版本兼容"""
v_reshaped = v.view(v.size(0), -1)
v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
self.weight_down_embed[k] = v_reshaped
# 标记相关缓存需要更新
self.cache_update_mask[k] = False
self.normalization_valid = False