Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
7d726c5b20 | |||
0b53e1b951 |
@ -1,5 +1,5 @@
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from typing import List
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
class LMConfig(PretrainedConfig):
|
class LMConfig(PretrainedConfig):
|
||||||
@ -12,17 +12,24 @@ class LMConfig(PretrainedConfig):
|
|||||||
n_heads: int = 32,
|
n_heads: int = 32,
|
||||||
n_kv_heads: int = 8,
|
n_kv_heads: int = 8,
|
||||||
vocab_size: int = 6400,
|
vocab_size: int = 6400,
|
||||||
hidden_dim: int = None,
|
hidden_dim: Optional[int] = None,
|
||||||
multiple_of: int = 64,
|
multiple_of: int = 64,
|
||||||
norm_eps: float = 1e-5,
|
norm_eps: float = 1e-5,
|
||||||
max_seq_len: int = 8192,
|
max_seq_len: int = 8192,
|
||||||
rope_theta: int = 1e6,
|
rope_theta: float = 1e6,
|
||||||
dropout: float = 0.0,
|
dropout: float = 0.0,
|
||||||
flash_attn: bool = True,
|
flash_attn: bool = True,
|
||||||
####################################################
|
####################################################
|
||||||
# DB related configurations
|
# DB related configurations
|
||||||
####################################################
|
####################################################
|
||||||
disable_db: bool = False, # 特殊模式:禁用数据库功能
|
disable_db: bool = False, # 特殊模式:禁用数据库功能
|
||||||
|
use_direct_semantic: bool = False, # 是否使用直接语义匹配(替代Product Key)
|
||||||
|
realtime_steps: int = 2000, # 前多少步使用实时计算(后续使用渐进式缓存)
|
||||||
|
db_intelligent_balance: bool = True, # 是否启用智能负载均衡
|
||||||
|
db_relevance_threshold: float = 0.7, # 相关性阈值(第一层过滤)
|
||||||
|
db_balance_strength: float = 0.3, # 平衡权重的基础值
|
||||||
|
db_momentum: float = 0.9, # 使用频率统计的动量
|
||||||
|
db_adaptive_weights: bool = True, # 是否启用动态权重调整
|
||||||
####################################################
|
####################################################
|
||||||
# Here are the specific configurations of MOE
|
# Here are the specific configurations of MOE
|
||||||
# When use_moe is false, the following is invalid
|
# When use_moe is false, the following is invalid
|
||||||
@ -57,6 +64,13 @@ class LMConfig(PretrainedConfig):
|
|||||||
# DB related configurations
|
# DB related configurations
|
||||||
####################################################
|
####################################################
|
||||||
self.disable_db = disable_db # 设置是否禁用数据库
|
self.disable_db = disable_db # 设置是否禁用数据库
|
||||||
|
self.use_direct_semantic = use_direct_semantic # 是否使用直接语义匹配(替代Product Key)
|
||||||
|
self.realtime_steps = realtime_steps # 前多少步使用实时计算(后续使用渐进式缓存)
|
||||||
|
self.db_intelligent_balance = db_intelligent_balance # 是否启用智能负载均衡
|
||||||
|
self.db_relevance_threshold = db_relevance_threshold # 相关性阈值(第一层过滤)
|
||||||
|
self.db_balance_strength = db_balance_strength # 平衡权重的基础值
|
||||||
|
self.db_momentum = db_momentum # 使用频率统计的动量
|
||||||
|
self.db_adaptive_weights = db_adaptive_weights # 是否启用动态权重调整
|
||||||
####################################################
|
####################################################
|
||||||
# Here are the specific configurations of MOE
|
# Here are the specific configurations of MOE
|
||||||
# When use_moe is false, the following is invalid
|
# When use_moe is false, the following is invalid
|
||||||
|
458
model/model.py
458
model/model.py
@ -188,11 +188,6 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
# 应用旋转位置编码(使用实数版本)
|
# 应用旋转位置编码(使用实数版本)
|
||||||
xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)
|
xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)
|
||||||
# kv_cache实现 REMOVED
|
|
||||||
# if past_key_value is not None:
|
|
||||||
# xk = torch.cat([past_key_value[0], xk], dim=1)
|
|
||||||
# xv = torch.cat([past_key_value[1], xv], dim=1)
|
|
||||||
# past_kv = (xk, xv) if use_cache else None
|
|
||||||
|
|
||||||
# 重复键值对
|
# 重复键值对
|
||||||
xq, xk, xv = (
|
xq, xk, xv = (
|
||||||
@ -440,66 +435,7 @@ class MiniMindBlock(nn.Module):
|
|||||||
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
||||||
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
|
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
|
||||||
|
|
||||||
# 假设num_experts是已定义的总专家数量的平方根
|
|
||||||
|
|
||||||
|
|
||||||
# 查询生成的参数
|
|
||||||
|
|
||||||
|
|
||||||
# 创建查询生成模块
|
|
||||||
# if weight_down_embed is not None:
|
|
||||||
# self.to_queries = nn.Sequential(
|
|
||||||
# nn.Linear(config.dim, self.dim_key * 2, bias=False),
|
|
||||||
# # nn.Unflatten(2, (2, self.n_heads, self.dim_key)) # 替代Rearrange
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # 超参数
|
|
||||||
# self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys
|
|
||||||
# self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数
|
|
||||||
|
|
||||||
def forward(self, x, db_value, pos_cis):
|
def forward(self, x, db_value, pos_cis):
|
||||||
# import pdb;pdb.set_trace()
|
|
||||||
# db_value = None
|
|
||||||
|
|
||||||
# # 如果有weight_down_embed,使用Product Key机制
|
|
||||||
# if self.weight_down_embed is not None:
|
|
||||||
# # 1. 生成queries
|
|
||||||
# batch_size, seq_len, dim = x.shape
|
|
||||||
|
|
||||||
# # collapse sequence dimension by averaging
|
|
||||||
# x_flat = x.mean(dim=1) # [batch_size, dim]
|
|
||||||
# queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
|
|
||||||
# queries = queries.reshape(batch_size, 2, self.dim_key) # [batch_size, 2, dim_key]
|
|
||||||
# queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key]
|
|
||||||
|
|
||||||
# # 2. 计算queries与keys的相似度
|
|
||||||
# sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
|
|
||||||
|
|
||||||
# # 3. 在两个子空间分别做top-k
|
|
||||||
# scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
|
|
||||||
# scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
|
|
||||||
# indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
|
|
||||||
|
|
||||||
# # 4. 组合两个子空间的分数和索引
|
|
||||||
# all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
|
|
||||||
# all_scores = all_scores.view(*all_scores.shape[:-2], -1)
|
|
||||||
|
|
||||||
# all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
|
|
||||||
# all_indices = all_indices.view(*all_indices.shape[:-2], -1)
|
|
||||||
|
|
||||||
# # 5. 最终top-k选择
|
|
||||||
# scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
|
|
||||||
# indices = all_indices.gather(-1, pk_indices)
|
|
||||||
|
|
||||||
# # 6. 从embedding中获取专家值
|
|
||||||
|
|
||||||
# # 从embedding中获取值
|
|
||||||
# flat_indices = indices.view(-1) # 将索引展平为一维张量
|
|
||||||
# db_values = self.weight_down_embed(flat_indices)
|
|
||||||
|
|
||||||
# # 重塑回原始形状
|
|
||||||
# db_value = db_values.view(batch_size, -1, dim)
|
|
||||||
|
|
||||||
|
|
||||||
# 注意力计算
|
# 注意力计算
|
||||||
h_attn = self.attention(
|
h_attn = self.attention(
|
||||||
@ -518,7 +454,7 @@ class MiniMindBlock(nn.Module):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
class ExtractDB(nn.Module):
|
class ExtractDB(nn.Module):
|
||||||
def __init__(self,params):
|
def __init__(self, params, tok_embeddings=None):
|
||||||
# 修改专家数量和知识维度,确保能开方
|
# 修改专家数量和知识维度,确保能开方
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.batch_size = None
|
self.batch_size = None
|
||||||
@ -529,12 +465,27 @@ class ExtractDB(nn.Module):
|
|||||||
self.head_dim = params.dim // params.n_heads
|
self.head_dim = params.dim // params.n_heads
|
||||||
self.knowledge_length = params.knowledge_length
|
self.knowledge_length = params.knowledge_length
|
||||||
|
|
||||||
# 使用register_buffer代替nn.Parameter,避免梯度问题
|
# 智能负载均衡相关参数
|
||||||
# self.register_buffer('weight_down_embed', torch.randn(self.knowledge_num, self.knowledge_length) * 0.02)
|
self.enable_intelligent_balance = getattr(params, 'db_intelligent_balance', True)
|
||||||
self.register_buffer('weight_down_embed',torch.randint(low=0,high=6400, size=(self.knowledge_num, self.knowledge_length),dtype=torch.long))
|
self.relevance_threshold = getattr(params, 'db_relevance_threshold', 0.7)
|
||||||
|
self.base_balance_strength = getattr(params, 'db_balance_strength', 0.3)
|
||||||
|
self.momentum = getattr(params, 'db_momentum', 0.9)
|
||||||
|
self.adaptive_weights = getattr(params, 'db_adaptive_weights', True)
|
||||||
|
|
||||||
|
# 动态权重调整参数
|
||||||
|
self.current_relevance_weight = 0.8 # 开始时更重视相关性
|
||||||
|
self.current_balance_weight = 0.2
|
||||||
|
self.weight_update_frequency = 100 # 每100步调整一次权重
|
||||||
|
self.step_counter = 0
|
||||||
|
|
||||||
|
# 使用频率统计 - 使用register_buffer以便在GPU/CPU间正确移动
|
||||||
|
self.register_buffer('usage_counts', torch.zeros(self.knowledge_num))
|
||||||
|
self.register_buffer('total_queries', torch.tensor(0.0))
|
||||||
|
|
||||||
|
# 知识库存储 - 使用register_buffer因为这是整数索引,不需要梯度
|
||||||
|
self.register_buffer('weight_down_embed',
|
||||||
|
torch.randint(low=0, high=6400, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
|
||||||
|
)
|
||||||
|
|
||||||
self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0
|
self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0
|
||||||
self.product_key_topk = min(16, self.num_keys)
|
self.product_key_topk = min(16, self.num_keys)
|
||||||
@ -544,6 +495,144 @@ class ExtractDB(nn.Module):
|
|||||||
nn.Linear(params.dim, self.dim_key * 2, bias=False),
|
nn.Linear(params.dim, self.dim_key * 2, bias=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 存储token embeddings的引用,用于计算真实的语义相关性
|
||||||
|
self.tok_embeddings = tok_embeddings
|
||||||
|
|
||||||
|
def update_usage_statistics(self, selected_indices):
|
||||||
|
"""更新数据库条目的使用统计"""
|
||||||
|
if not self.training or not self.enable_intelligent_balance:
|
||||||
|
return
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# 统计当前batch中每个条目的使用次数
|
||||||
|
batch_usage = torch.zeros(self.knowledge_num, device=selected_indices.device)
|
||||||
|
unique_indices, counts = torch.unique(selected_indices, return_counts=True)
|
||||||
|
batch_usage[unique_indices] = counts.float()
|
||||||
|
|
||||||
|
# 使用简单的tensor操作来更新统计
|
||||||
|
current_usage = self.usage_counts.clone()
|
||||||
|
current_total = self.total_queries.clone()
|
||||||
|
|
||||||
|
new_usage = self.momentum * current_usage + (1 - self.momentum) * batch_usage
|
||||||
|
new_total = current_total + selected_indices.numel()
|
||||||
|
|
||||||
|
# 直接替换buffer内容
|
||||||
|
self.usage_counts.copy_(new_usage)
|
||||||
|
self.total_queries.copy_(new_total)
|
||||||
|
|
||||||
|
def update_dynamic_weights(self):
|
||||||
|
"""动态调整相关性和平衡权重"""
|
||||||
|
if not self.adaptive_weights or not self.training:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.step_counter += 1
|
||||||
|
|
||||||
|
# 每隔一定步数调整权重
|
||||||
|
if self.step_counter % self.weight_update_frequency == 0:
|
||||||
|
with torch.no_grad():
|
||||||
|
if self.total_queries > 0:
|
||||||
|
# 计算使用分布的方差(不平衡程度)
|
||||||
|
usage_rates = self.usage_counts / self.total_queries
|
||||||
|
usage_variance = usage_rates.var().item()
|
||||||
|
|
||||||
|
# 根据不平衡程度调整权重
|
||||||
|
if usage_variance > 0.01: # 高度不平衡
|
||||||
|
self.current_relevance_weight = max(0.5, self.current_relevance_weight - 0.1)
|
||||||
|
self.current_balance_weight = min(0.5, self.current_balance_weight + 0.1)
|
||||||
|
elif usage_variance < 0.001: # 已经很平衡
|
||||||
|
self.current_relevance_weight = min(0.9, self.current_relevance_weight + 0.1)
|
||||||
|
self.current_balance_weight = max(0.1, self.current_balance_weight - 0.1)
|
||||||
|
|
||||||
|
# 确保权重和为1
|
||||||
|
total_weight = self.current_relevance_weight + self.current_balance_weight
|
||||||
|
self.current_relevance_weight /= total_weight
|
||||||
|
self.current_balance_weight /= total_weight
|
||||||
|
|
||||||
|
def intelligent_selection(self, query, all_scores, all_indices):
|
||||||
|
"""智能分层选择策略"""
|
||||||
|
if not self.enable_intelligent_balance or not self.training:
|
||||||
|
return all_scores
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_size = all_scores.size(0)
|
||||||
|
device = all_scores.device
|
||||||
|
dtype = all_scores.dtype
|
||||||
|
|
||||||
|
# 更新动态权重
|
||||||
|
self.update_dynamic_weights()
|
||||||
|
|
||||||
|
# 对每个batch进行分层选择
|
||||||
|
enhanced_scores = all_scores.clone()
|
||||||
|
query_features = query.mean(dim=1) # [batch_size, dim]
|
||||||
|
|
||||||
|
# 预先计算所有候选条目的嵌入(批量优化)
|
||||||
|
all_candidate_indices = torch.cat([all_indices[i] for i in range(batch_size)], dim=0)
|
||||||
|
unique_indices, inverse_indices = torch.unique(all_candidate_indices, return_inverse=True)
|
||||||
|
|
||||||
|
# 批量计算唯一候选条目的嵌入
|
||||||
|
candidate_tokens = self.weight_down_embed[unique_indices]
|
||||||
|
flat_tokens = candidate_tokens.view(-1)
|
||||||
|
flat_embeddings = self.tok_embeddings(flat_tokens)
|
||||||
|
unique_candidate_features = flat_embeddings.view(
|
||||||
|
len(unique_indices), self.knowledge_length, -1
|
||||||
|
).mean(dim=1) # [num_unique_candidates, dim]
|
||||||
|
|
||||||
|
# 归一化候选特征(优化相似度计算)
|
||||||
|
normalized_candidates = F.normalize(unique_candidate_features, dim=-1)
|
||||||
|
normalized_queries = F.normalize(query_features, dim=-1)
|
||||||
|
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
indices = all_indices[batch_idx]
|
||||||
|
scores = all_scores[batch_idx]
|
||||||
|
|
||||||
|
# 获取当前batch候选条目对应的特征索引
|
||||||
|
start_idx = batch_idx * len(indices)
|
||||||
|
end_idx = start_idx + len(indices)
|
||||||
|
batch_inverse_indices = inverse_indices[start_idx:end_idx]
|
||||||
|
|
||||||
|
# 使用预计算的归一化特征进行优化相似度计算
|
||||||
|
batch_candidate_features = normalized_candidates[batch_inverse_indices]
|
||||||
|
query_feature = normalized_queries[batch_idx]
|
||||||
|
|
||||||
|
# 使用矩阵乘法计算余弦相似度
|
||||||
|
similarity_scores = torch.mv(batch_candidate_features, query_feature)
|
||||||
|
|
||||||
|
# 应用相关性阈值过滤
|
||||||
|
relevance_probs = F.softmax(similarity_scores.float(), dim=-1).to(dtype)
|
||||||
|
mean_prob = relevance_probs.mean()
|
||||||
|
adaptive_threshold = max(self.relevance_threshold * mean_prob, mean_prob * 0.5)
|
||||||
|
relevant_mask = relevance_probs > adaptive_threshold
|
||||||
|
|
||||||
|
if relevant_mask.sum() == 0:
|
||||||
|
# 如果没有相关候选,选择相似度最高的
|
||||||
|
top_k = min(5, len(indices))
|
||||||
|
_, top_indices = similarity_scores.topk(top_k)
|
||||||
|
relevant_mask = torch.zeros_like(relevant_mask, dtype=torch.bool)
|
||||||
|
relevant_mask[top_indices] = True
|
||||||
|
|
||||||
|
# 在相关候选中应用负载均衡
|
||||||
|
if relevant_mask.sum() > 1:
|
||||||
|
relevant_indices = indices[relevant_mask]
|
||||||
|
relevant_usage = self.usage_counts[relevant_indices]
|
||||||
|
|
||||||
|
# 计算平衡分数
|
||||||
|
balance_scores = 1.0 / (relevant_usage + 1.0)
|
||||||
|
balance_scores = balance_scores / (balance_scores.sum() + 1e-8)
|
||||||
|
|
||||||
|
# 相关性分数
|
||||||
|
relevant_rel_scores = relevance_probs[relevant_mask]
|
||||||
|
relevant_rel_scores = relevant_rel_scores / (relevant_rel_scores.sum() + 1e-8)
|
||||||
|
|
||||||
|
# 综合分数
|
||||||
|
combined_scores = (self.current_relevance_weight * relevant_rel_scores +
|
||||||
|
self.current_balance_weight * balance_scores.to(dtype))
|
||||||
|
|
||||||
|
# 应用调整
|
||||||
|
adjustment = self.base_balance_strength * combined_scores.to(dtype)
|
||||||
|
enhanced_scores[batch_idx, relevant_mask] = scores[relevant_mask] + adjustment
|
||||||
|
|
||||||
|
return enhanced_scores.to(device)
|
||||||
|
|
||||||
def q_to_k(self,x):
|
def q_to_k(self,x):
|
||||||
# 1. 生成queries
|
# 1. 生成queries
|
||||||
self.batch_size, seq_len, dim = x.shape
|
self.batch_size, seq_len, dim = x.shape
|
||||||
@ -570,10 +659,18 @@ class ExtractDB(nn.Module):
|
|||||||
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
|
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
|
||||||
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
|
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
|
||||||
|
|
||||||
# 5. 最终top-k选择
|
# 5. 应用智能分层选择策略
|
||||||
scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
|
enhanced_scores = self.intelligent_selection(x, all_scores, all_indices)
|
||||||
|
|
||||||
|
# 6. 基于增强后的分数进行最终top-k选择
|
||||||
|
scores, pk_indices = enhanced_scores.topk(self.num_experts_per_head_topk, dim=-1)
|
||||||
indices = all_indices.gather(-1, pk_indices)
|
indices = all_indices.gather(-1, pk_indices)
|
||||||
|
|
||||||
flat_indices = indices.view(-1)
|
flat_indices = indices.view(-1)
|
||||||
|
|
||||||
|
# 7. 更新使用统计
|
||||||
|
self.update_usage_statistics(flat_indices)
|
||||||
|
|
||||||
return flat_indices
|
return flat_indices
|
||||||
|
|
||||||
def get_data(self, index):
|
def get_data(self, index):
|
||||||
@ -599,10 +696,18 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
self.params = params or LMConfig()
|
self.params = params or LMConfig()
|
||||||
super().__init__(self.params)
|
super().__init__(self.params)
|
||||||
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
|
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
|
||||||
|
|
||||||
|
# 先创建token embeddings
|
||||||
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
|
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
|
||||||
self.dropout = nn.Dropout(params.dropout)
|
self.dropout = nn.Dropout(params.dropout)
|
||||||
# 移除旧的weight_down_embed声明
|
|
||||||
self.extract_db = ExtractDB(self.params)
|
# 根据配置选择ExtractDB版本
|
||||||
|
# use_direct_semantic = getattr(params, 'use_direct_semantic', False)
|
||||||
|
# if use_direct_semantic:
|
||||||
|
# self.extract_db = ExtractDB_DirectSemantic(self.params, self.tok_embeddings)
|
||||||
|
# else:
|
||||||
|
# self.extract_db = ExtractDB(self.params, self.tok_embeddings)
|
||||||
|
self.extract_db = ExtractDB_DirectSemantic(self.params, self.tok_embeddings)
|
||||||
|
|
||||||
# 将self.weight_down_embed传递给每个MiniMindBlock
|
# 将self.weight_down_embed传递给每个MiniMindBlock
|
||||||
self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
|
self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
|
||||||
@ -652,13 +757,6 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
h_list = []
|
h_list = []
|
||||||
|
|
||||||
for l, layer in enumerate(self.layers):
|
for l, layer in enumerate(self.layers):
|
||||||
# 禁用数据库模式,使用固定值替代数据库查询
|
|
||||||
if self.params.disable_db:
|
|
||||||
# 创建一个形状为[batch_size, n_layers, dim]的tensor,所有元素值为1e-4
|
|
||||||
batch_size = h.size(0)
|
|
||||||
db_value = torch.full((batch_size, self.n_layers, self.params.dim), 1e-4,
|
|
||||||
dtype=h.dtype, device=h.device)
|
|
||||||
else:
|
|
||||||
# 正常模式,使用数据库查询
|
# 正常模式,使用数据库查询
|
||||||
# import pdb;pdb.set_trace()
|
# import pdb;pdb.set_trace()
|
||||||
index = self.extract_db.q_to_k(h)
|
index = self.extract_db.q_to_k(h)
|
||||||
@ -778,3 +876,205 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
yield input_ids[:, start:]
|
yield input_ids[:, start:]
|
||||||
if input_ids_next.item() == eos_token_id:
|
if input_ids_next.item() == eos_token_id:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
class ExtractDB_DirectSemantic(nn.Module):
|
||||||
|
"""直接语义匹配的数据库检索模块,完全移除Product Key"""
|
||||||
|
def __init__(self, params, tok_embeddings=None):
|
||||||
|
super().__init__()
|
||||||
|
self.batch_size = None
|
||||||
|
self.dim = params.dim
|
||||||
|
self.knowledge_num = params.knowledge_num
|
||||||
|
self.knowledge_length = params.knowledge_length
|
||||||
|
self.tok_embeddings = tok_embeddings
|
||||||
|
self.num_experts_per_head_topk = 1
|
||||||
|
|
||||||
|
# 训练步数管理
|
||||||
|
self.current_step = 0
|
||||||
|
self.realtime_threshold = getattr(params, 'realtime_steps', 800) # 前800步实时计算
|
||||||
|
|
||||||
|
# 渐进式缓存策略参数
|
||||||
|
self.knowledge_update_rate = 0.01 # 每步更新1%的知识
|
||||||
|
self.knowledge_per_step = max(1, int(self.knowledge_num * self.knowledge_update_rate))
|
||||||
|
self.update_cycle = 100 # 100步循环
|
||||||
|
|
||||||
|
# 知识库存储
|
||||||
|
self.register_buffer('weight_down_embed',
|
||||||
|
torch.randint(low=0, high=6400, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 嵌入缓存
|
||||||
|
self.knowledge_embeddings_cache = None
|
||||||
|
self.cache_update_mask = torch.zeros(self.knowledge_num, dtype=torch.bool) # 跟踪哪些已更新
|
||||||
|
|
||||||
|
# 归一化缓存(用于优化相似度计算)
|
||||||
|
self.normalized_knowledge_cache = None
|
||||||
|
self.normalization_valid = False
|
||||||
|
|
||||||
|
# 负载均衡组件
|
||||||
|
self.register_buffer('usage_counts', torch.zeros(self.knowledge_num))
|
||||||
|
self.register_buffer('total_queries', torch.tensor(0.0))
|
||||||
|
self.momentum = getattr(params, 'db_momentum', 0.9)
|
||||||
|
self.balance_strength = getattr(params, 'db_balance_strength', 0.1)
|
||||||
|
|
||||||
|
def should_use_realtime_computation(self):
|
||||||
|
"""判断是否应该使用实时计算"""
|
||||||
|
return self.current_step < self.realtime_threshold
|
||||||
|
|
||||||
|
def get_knowledge_indices_to_update(self):
|
||||||
|
"""获取本步需要更新的知识条目索引"""
|
||||||
|
if self.should_use_realtime_computation():
|
||||||
|
# 前800步:全部实时计算
|
||||||
|
return torch.arange(self.knowledge_num)
|
||||||
|
|
||||||
|
# 后续步数:循环更新策略
|
||||||
|
cycle_position = self.current_step % self.update_cycle
|
||||||
|
start_idx = (cycle_position * self.knowledge_per_step) % self.knowledge_num
|
||||||
|
end_idx = min(start_idx + self.knowledge_per_step, self.knowledge_num)
|
||||||
|
|
||||||
|
return torch.arange(start_idx, end_idx)
|
||||||
|
|
||||||
|
def update_knowledge_embeddings(self, force_all=False):
|
||||||
|
"""智能更新知识嵌入缓存"""
|
||||||
|
if force_all or self.should_use_realtime_computation():
|
||||||
|
# 全量更新
|
||||||
|
indices_to_update = torch.arange(self.knowledge_num)
|
||||||
|
else:
|
||||||
|
# 渐进式更新
|
||||||
|
indices_to_update = self.get_knowledge_indices_to_update()
|
||||||
|
|
||||||
|
if len(indices_to_update) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 初始化缓存
|
||||||
|
if self.knowledge_embeddings_cache is None:
|
||||||
|
# 获取tok_embeddings的dtype,确保类型一致
|
||||||
|
dummy_input = torch.zeros(1, dtype=torch.long, device=self.weight_down_embed.device)
|
||||||
|
dummy_embedding = self.tok_embeddings(dummy_input)
|
||||||
|
embedding_dtype = dummy_embedding.dtype
|
||||||
|
|
||||||
|
self.knowledge_embeddings_cache = torch.zeros(
|
||||||
|
self.knowledge_num, self.dim,
|
||||||
|
device=self.weight_down_embed.device,
|
||||||
|
dtype=embedding_dtype # 使用与tok_embeddings相同的dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# 只更新指定的知识条目
|
||||||
|
tokens_to_update = self.weight_down_embed[indices_to_update] # [num_update, knowledge_length]
|
||||||
|
flat_tokens = tokens_to_update.view(-1) # [num_update * knowledge_length]
|
||||||
|
|
||||||
|
# 批量计算嵌入
|
||||||
|
flat_embeddings = self.tok_embeddings(flat_tokens) # [num_update * knowledge_length, dim]
|
||||||
|
|
||||||
|
# 重塑并平均池化
|
||||||
|
updated_embeddings = flat_embeddings.view(
|
||||||
|
len(indices_to_update), self.knowledge_length, -1
|
||||||
|
).mean(dim=1) # [num_update, dim]
|
||||||
|
|
||||||
|
# 更新缓存 - 现在类型应该匹配了
|
||||||
|
self.knowledge_embeddings_cache[indices_to_update] = updated_embeddings
|
||||||
|
self.cache_update_mask[indices_to_update] = True
|
||||||
|
|
||||||
|
# 使归一化缓存失效
|
||||||
|
self.normalization_valid = False
|
||||||
|
|
||||||
|
def get_normalized_knowledge_embeddings(self):
|
||||||
|
"""获取归一化的知识嵌入(用于优化相似度计算)"""
|
||||||
|
if not self.normalization_valid or self.normalized_knowledge_cache is None:
|
||||||
|
if self.knowledge_embeddings_cache is None:
|
||||||
|
self.update_knowledge_embeddings(force_all=True)
|
||||||
|
|
||||||
|
self.normalized_knowledge_cache = F.normalize(
|
||||||
|
self.knowledge_embeddings_cache, dim=-1
|
||||||
|
)
|
||||||
|
self.normalization_valid = True
|
||||||
|
|
||||||
|
return self.normalized_knowledge_cache
|
||||||
|
|
||||||
|
def optimized_similarity_computation(self, query_features):
|
||||||
|
"""优化的相似度计算"""
|
||||||
|
# 归一化查询特征
|
||||||
|
normalized_query = F.normalize(query_features, dim=-1) # [batch_size, dim]
|
||||||
|
|
||||||
|
# 获取归一化的知识嵌入
|
||||||
|
normalized_knowledge = self.get_normalized_knowledge_embeddings() # [knowledge_num, dim]
|
||||||
|
|
||||||
|
# 使用矩阵乘法计算余弦相似度
|
||||||
|
similarities = torch.mm(normalized_query, normalized_knowledge.t()) # [batch_size, knowledge_num]
|
||||||
|
|
||||||
|
return similarities
|
||||||
|
|
||||||
|
def apply_load_balancing(self, similarities):
|
||||||
|
"""应用负载均衡策略"""
|
||||||
|
if not self.training or self.total_queries == 0:
|
||||||
|
return similarities
|
||||||
|
|
||||||
|
# 计算使用频率
|
||||||
|
usage_rates = self.usage_counts / (self.total_queries + 1e-8)
|
||||||
|
|
||||||
|
# 创建平衡偏置(低频率条目获得正偏置)
|
||||||
|
max_usage = usage_rates.max()
|
||||||
|
balance_bias = self.balance_strength * (max_usage - usage_rates + 1e-8).log()
|
||||||
|
|
||||||
|
# 应用偏置
|
||||||
|
balanced_similarities = similarities + balance_bias.unsqueeze(0)
|
||||||
|
|
||||||
|
return balanced_similarities
|
||||||
|
|
||||||
|
def update_usage_statistics(self, selected_indices):
|
||||||
|
"""更新使用统计"""
|
||||||
|
if not self.training:
|
||||||
|
return
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# 统计当前batch中每个条目的使用次数
|
||||||
|
batch_usage = torch.zeros(self.knowledge_num, device=selected_indices.device)
|
||||||
|
unique_indices, counts = torch.unique(selected_indices, return_counts=True)
|
||||||
|
batch_usage[unique_indices] = counts.float()
|
||||||
|
|
||||||
|
# 更新统计
|
||||||
|
self.usage_counts.copy_(
|
||||||
|
self.momentum * self.usage_counts + (1 - self.momentum) * batch_usage
|
||||||
|
)
|
||||||
|
self.total_queries.copy_(self.total_queries + selected_indices.numel())
|
||||||
|
|
||||||
|
def q_to_k(self, x):
|
||||||
|
"""直接语义检索的主方法"""
|
||||||
|
self.current_step += 1
|
||||||
|
batch_size, seq_len, dim = x.shape
|
||||||
|
|
||||||
|
# 智能更新知识嵌入缓存
|
||||||
|
self.update_knowledge_embeddings()
|
||||||
|
|
||||||
|
# 计算查询特征(序列平均)
|
||||||
|
query_features = x.mean(dim=1) # [batch_size, dim]
|
||||||
|
|
||||||
|
# 优化的相似度计算
|
||||||
|
similarities = self.optimized_similarity_computation(query_features)
|
||||||
|
|
||||||
|
# 应用负载均衡
|
||||||
|
balanced_similarities = self.apply_load_balancing(similarities)
|
||||||
|
|
||||||
|
# 选择top-k
|
||||||
|
_, indices = balanced_similarities.topk(self.num_experts_per_head_topk, dim=-1)
|
||||||
|
flat_indices = indices.view(-1)
|
||||||
|
|
||||||
|
# 更新使用统计
|
||||||
|
self.update_usage_statistics(flat_indices)
|
||||||
|
|
||||||
|
return flat_indices
|
||||||
|
|
||||||
|
def get_data(self, index):
|
||||||
|
"""获取数据,与原版本兼容"""
|
||||||
|
return self.weight_down_embed[index]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def updata_value(self, k, v):
|
||||||
|
"""更新数据,与原版本兼容"""
|
||||||
|
v_reshaped = v.view(v.size(0), -1)
|
||||||
|
v_reshaped = v_reshaped.to(dtype=self.weight_down_embed.dtype)
|
||||||
|
self.weight_down_embed[k] = v_reshaped
|
||||||
|
|
||||||
|
# 标记相关缓存需要更新
|
||||||
|
self.cache_update_mask[k] = False
|
||||||
|
self.normalization_valid = False
|
||||||
|
Loading…
x
Reference in New Issue
Block a user