Compare commits
2 Commits
6932e5fa8e
...
0b53e1b951
Author | SHA1 | Date | |
---|---|---|---|
0b53e1b951 | |||
64e92473c3 |
@ -1,5 +1,5 @@
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from typing import List
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
class LMConfig(PretrainedConfig):
|
class LMConfig(PretrainedConfig):
|
||||||
@ -12,17 +12,22 @@ class LMConfig(PretrainedConfig):
|
|||||||
n_heads: int = 32,
|
n_heads: int = 32,
|
||||||
n_kv_heads: int = 8,
|
n_kv_heads: int = 8,
|
||||||
vocab_size: int = 6400,
|
vocab_size: int = 6400,
|
||||||
hidden_dim: int = None,
|
hidden_dim: Optional[int] = None,
|
||||||
multiple_of: int = 64,
|
multiple_of: int = 64,
|
||||||
norm_eps: float = 1e-5,
|
norm_eps: float = 1e-5,
|
||||||
max_seq_len: int = 8192,
|
max_seq_len: int = 8192,
|
||||||
rope_theta: int = 1e6,
|
rope_theta: float = 1e6,
|
||||||
dropout: float = 0.0,
|
dropout: float = 0.0,
|
||||||
flash_attn: bool = True,
|
flash_attn: bool = True,
|
||||||
####################################################
|
####################################################
|
||||||
# DB related configurations
|
# DB related configurations
|
||||||
####################################################
|
####################################################
|
||||||
disable_db: bool = False, # 特殊模式:禁用数据库功能
|
disable_db: bool = False, # 特殊模式:禁用数据库功能
|
||||||
|
db_intelligent_balance: bool = True, # 是否启用智能负载均衡
|
||||||
|
db_relevance_threshold: float = 0.7, # 相关性阈值(第一层过滤)
|
||||||
|
db_balance_strength: float = 0.3, # 平衡权重的基础值
|
||||||
|
db_momentum: float = 0.9, # 使用频率统计的动量
|
||||||
|
db_adaptive_weights: bool = True, # 是否启用动态权重调整
|
||||||
####################################################
|
####################################################
|
||||||
# Here are the specific configurations of MOE
|
# Here are the specific configurations of MOE
|
||||||
# When use_moe is false, the following is invalid
|
# When use_moe is false, the following is invalid
|
||||||
@ -57,6 +62,11 @@ class LMConfig(PretrainedConfig):
|
|||||||
# DB related configurations
|
# DB related configurations
|
||||||
####################################################
|
####################################################
|
||||||
self.disable_db = disable_db # 设置是否禁用数据库
|
self.disable_db = disable_db # 设置是否禁用数据库
|
||||||
|
self.db_intelligent_balance = db_intelligent_balance # 是否启用智能负载均衡
|
||||||
|
self.db_relevance_threshold = db_relevance_threshold # 相关性阈值(第一层过滤)
|
||||||
|
self.db_balance_strength = db_balance_strength # 平衡权重的基础值
|
||||||
|
self.db_momentum = db_momentum # 使用频率统计的动量
|
||||||
|
self.db_adaptive_weights = db_adaptive_weights # 是否启用动态权重调整
|
||||||
####################################################
|
####################################################
|
||||||
# Here are the specific configurations of MOE
|
# Here are the specific configurations of MOE
|
||||||
# When use_moe is false, the following is invalid
|
# When use_moe is false, the following is invalid
|
||||||
|
268
model/model.py
268
model/model.py
@ -188,11 +188,6 @@ class Attention(nn.Module):
|
|||||||
|
|
||||||
# 应用旋转位置编码(使用实数版本)
|
# 应用旋转位置编码(使用实数版本)
|
||||||
xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)
|
xq, xk = apply_rotary_emb_real(xq, xk, pos_cis)
|
||||||
# kv_cache实现 REMOVED
|
|
||||||
# if past_key_value is not None:
|
|
||||||
# xk = torch.cat([past_key_value[0], xk], dim=1)
|
|
||||||
# xv = torch.cat([past_key_value[1], xv], dim=1)
|
|
||||||
# past_kv = (xk, xv) if use_cache else None
|
|
||||||
|
|
||||||
# 重复键值对
|
# 重复键值对
|
||||||
xq, xk, xv = (
|
xq, xk, xv = (
|
||||||
@ -440,66 +435,7 @@ class MiniMindBlock(nn.Module):
|
|||||||
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
||||||
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
|
self.feed_forward = FeedForward(config) if not config.use_moe else MOEFeedForward(config)
|
||||||
|
|
||||||
# 假设num_experts是已定义的总专家数量的平方根
|
|
||||||
|
|
||||||
|
|
||||||
# 查询生成的参数
|
|
||||||
|
|
||||||
|
|
||||||
# 创建查询生成模块
|
|
||||||
# if weight_down_embed is not None:
|
|
||||||
# self.to_queries = nn.Sequential(
|
|
||||||
# nn.Linear(config.dim, self.dim_key * 2, bias=False),
|
|
||||||
# # nn.Unflatten(2, (2, self.n_heads, self.dim_key)) # 替代Rearrange
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # 超参数
|
|
||||||
# self.product_key_topk = min(16, self.num_keys) # 确保不超过num_keys
|
|
||||||
# self.num_experts_per_head_topk = 1 # 最终每个头选取的专家数
|
|
||||||
|
|
||||||
def forward(self, x, db_value, pos_cis):
|
def forward(self, x, db_value, pos_cis):
|
||||||
# import pdb;pdb.set_trace()
|
|
||||||
# db_value = None
|
|
||||||
|
|
||||||
# # 如果有weight_down_embed,使用Product Key机制
|
|
||||||
# if self.weight_down_embed is not None:
|
|
||||||
# # 1. 生成queries
|
|
||||||
# batch_size, seq_len, dim = x.shape
|
|
||||||
|
|
||||||
# # collapse sequence dimension by averaging
|
|
||||||
# x_flat = x.mean(dim=1) # [batch_size, dim]
|
|
||||||
# queries = self.to_queries(x_flat) # [batch_size, 2*dim_key]
|
|
||||||
# queries = queries.reshape(batch_size, 2, self.dim_key) # [batch_size, 2, dim_key]
|
|
||||||
# queries = queries.permute(1, 0, 2) # [2, batch_size, dim_key]
|
|
||||||
|
|
||||||
# # 2. 计算queries与keys的相似度
|
|
||||||
# sim = torch.einsum('p b d, k p d -> p b k', queries, self.keys)
|
|
||||||
|
|
||||||
# # 3. 在两个子空间分别做top-k
|
|
||||||
# scores_and_indices = [sim[p].topk(self.product_key_topk, dim=-1) for p in range(2)]
|
|
||||||
# scores_x, scores_y = scores_and_indices[0][0], scores_and_indices[1][0]
|
|
||||||
# indices_x, indices_y = scores_and_indices[0][1], scores_and_indices[1][1]
|
|
||||||
|
|
||||||
# # 4. 组合两个子空间的分数和索引
|
|
||||||
# all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
|
|
||||||
# all_scores = all_scores.view(*all_scores.shape[:-2], -1)
|
|
||||||
|
|
||||||
# all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
|
|
||||||
# all_indices = all_indices.view(*all_indices.shape[:-2], -1)
|
|
||||||
|
|
||||||
# # 5. 最终top-k选择
|
|
||||||
# scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
|
|
||||||
# indices = all_indices.gather(-1, pk_indices)
|
|
||||||
|
|
||||||
# # 6. 从embedding中获取专家值
|
|
||||||
|
|
||||||
# # 从embedding中获取值
|
|
||||||
# flat_indices = indices.view(-1) # 将索引展平为一维张量
|
|
||||||
# db_values = self.weight_down_embed(flat_indices)
|
|
||||||
|
|
||||||
# # 重塑回原始形状
|
|
||||||
# db_value = db_values.view(batch_size, -1, dim)
|
|
||||||
|
|
||||||
|
|
||||||
# 注意力计算
|
# 注意力计算
|
||||||
h_attn = self.attention(
|
h_attn = self.attention(
|
||||||
@ -518,7 +454,7 @@ class MiniMindBlock(nn.Module):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
class ExtractDB(nn.Module):
|
class ExtractDB(nn.Module):
|
||||||
def __init__(self,params):
|
def __init__(self, params, tok_embeddings=None):
|
||||||
# 修改专家数量和知识维度,确保能开方
|
# 修改专家数量和知识维度,确保能开方
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.batch_size = None
|
self.batch_size = None
|
||||||
@ -529,12 +465,27 @@ class ExtractDB(nn.Module):
|
|||||||
self.head_dim = params.dim // params.n_heads
|
self.head_dim = params.dim // params.n_heads
|
||||||
self.knowledge_length = params.knowledge_length
|
self.knowledge_length = params.knowledge_length
|
||||||
|
|
||||||
# 使用register_buffer代替nn.Parameter,避免梯度问题
|
# 智能负载均衡相关参数
|
||||||
# self.register_buffer('weight_down_embed', torch.randn(self.knowledge_num, self.knowledge_length) * 0.02)
|
self.enable_intelligent_balance = getattr(params, 'db_intelligent_balance', True)
|
||||||
self.register_buffer('weight_down_embed',torch.randint(low=0,high=6400, size=(self.knowledge_num, self.knowledge_length),dtype=torch.long))
|
self.relevance_threshold = getattr(params, 'db_relevance_threshold', 0.7)
|
||||||
|
self.base_balance_strength = getattr(params, 'db_balance_strength', 0.3)
|
||||||
|
self.momentum = getattr(params, 'db_momentum', 0.9)
|
||||||
|
self.adaptive_weights = getattr(params, 'db_adaptive_weights', True)
|
||||||
|
|
||||||
|
# 动态权重调整参数
|
||||||
|
self.current_relevance_weight = 0.8 # 开始时更重视相关性
|
||||||
|
self.current_balance_weight = 0.2
|
||||||
|
self.weight_update_frequency = 100 # 每100步调整一次权重
|
||||||
|
self.step_counter = 0
|
||||||
|
|
||||||
|
# 使用频率统计 - 使用register_buffer以便在GPU/CPU间正确移动
|
||||||
|
self.register_buffer('usage_counts', torch.zeros(self.knowledge_num))
|
||||||
|
self.register_buffer('total_queries', torch.tensor(0.0))
|
||||||
|
|
||||||
|
# 知识库存储 - 使用register_buffer因为这是整数索引,不需要梯度
|
||||||
|
self.register_buffer('weight_down_embed',
|
||||||
|
torch.randint(low=0, high=6400, size=(self.knowledge_num, self.knowledge_length), dtype=torch.long)
|
||||||
|
)
|
||||||
|
|
||||||
self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0
|
self.num_keys = int(math.sqrt(self.knowledge_num)) if self.knowledge_num > 0 else 0
|
||||||
self.product_key_topk = min(16, self.num_keys)
|
self.product_key_topk = min(16, self.num_keys)
|
||||||
@ -544,6 +495,152 @@ class ExtractDB(nn.Module):
|
|||||||
nn.Linear(params.dim, self.dim_key * 2, bias=False),
|
nn.Linear(params.dim, self.dim_key * 2, bias=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 存储token embeddings的引用,用于计算真实的语义相关性
|
||||||
|
self.tok_embeddings = tok_embeddings
|
||||||
|
|
||||||
|
def update_usage_statistics(self, selected_indices):
|
||||||
|
"""更新数据库条目的使用统计"""
|
||||||
|
if not self.training or not self.enable_intelligent_balance:
|
||||||
|
return
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# 统计当前batch中每个条目的使用次数
|
||||||
|
batch_usage = torch.zeros(self.knowledge_num, device=selected_indices.device)
|
||||||
|
unique_indices, counts = torch.unique(selected_indices, return_counts=True)
|
||||||
|
batch_usage[unique_indices] = counts.float()
|
||||||
|
|
||||||
|
# 使用简单的tensor操作来更新统计
|
||||||
|
current_usage = self.usage_counts.clone()
|
||||||
|
current_total = self.total_queries.clone()
|
||||||
|
|
||||||
|
new_usage = self.momentum * current_usage + (1 - self.momentum) * batch_usage
|
||||||
|
new_total = current_total + selected_indices.numel()
|
||||||
|
|
||||||
|
# 直接替换buffer内容
|
||||||
|
self.usage_counts.copy_(new_usage)
|
||||||
|
self.total_queries.copy_(new_total)
|
||||||
|
|
||||||
|
def update_dynamic_weights(self):
|
||||||
|
"""动态调整相关性和平衡权重"""
|
||||||
|
if not self.adaptive_weights or not self.training:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.step_counter += 1
|
||||||
|
|
||||||
|
# 每隔一定步数调整权重
|
||||||
|
if self.step_counter % self.weight_update_frequency == 0:
|
||||||
|
with torch.no_grad():
|
||||||
|
if self.total_queries > 0:
|
||||||
|
# 计算使用分布的方差(不平衡程度)
|
||||||
|
usage_rates = self.usage_counts / self.total_queries
|
||||||
|
usage_variance = usage_rates.var().item()
|
||||||
|
|
||||||
|
# 根据不平衡程度调整权重
|
||||||
|
if usage_variance > 0.01: # 高度不平衡
|
||||||
|
self.current_relevance_weight = max(0.5, self.current_relevance_weight - 0.1)
|
||||||
|
self.current_balance_weight = min(0.5, self.current_balance_weight + 0.1)
|
||||||
|
elif usage_variance < 0.001: # 已经很平衡
|
||||||
|
self.current_relevance_weight = min(0.9, self.current_relevance_weight + 0.1)
|
||||||
|
self.current_balance_weight = max(0.1, self.current_balance_weight - 0.1)
|
||||||
|
|
||||||
|
# 确保权重和为1
|
||||||
|
total_weight = self.current_relevance_weight + self.current_balance_weight
|
||||||
|
self.current_relevance_weight /= total_weight
|
||||||
|
self.current_balance_weight /= total_weight
|
||||||
|
|
||||||
|
def intelligent_selection(self, query, all_scores, all_indices):
|
||||||
|
"""智能分层选择策略"""
|
||||||
|
if not self.enable_intelligent_balance or not self.training:
|
||||||
|
# 如果禁用智能平衡或在推理模式,使用原始分数
|
||||||
|
return all_scores
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_size = all_scores.size(0)
|
||||||
|
device = all_scores.device
|
||||||
|
dtype = all_scores.dtype
|
||||||
|
|
||||||
|
# 更新动态权重
|
||||||
|
self.update_dynamic_weights()
|
||||||
|
|
||||||
|
# 对每个batch进行分层选择
|
||||||
|
enhanced_scores = all_scores.clone()
|
||||||
|
|
||||||
|
# 预先计算query的特征表示(取平均)
|
||||||
|
query_features = query.mean(dim=1) # [batch_size, dim]
|
||||||
|
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
indices = all_indices[batch_idx] # 当前batch的候选条目
|
||||||
|
scores = all_scores[batch_idx] # 当前batch的原始分数
|
||||||
|
|
||||||
|
# 第一层:基于value内容计算真正的相关性
|
||||||
|
# 1. 获取候选条目的value tokens(只获取当前需要的)
|
||||||
|
candidate_tokens = self.weight_down_embed[indices] # [num_candidates, knowledge_length]
|
||||||
|
|
||||||
|
# 2. 高效计算:直接使用embedding层,避免中间变量
|
||||||
|
# 将tokens reshape为一维,批量计算embeddings,然后reshape回来
|
||||||
|
num_candidates, knowledge_length = candidate_tokens.shape
|
||||||
|
flat_tokens = candidate_tokens.view(-1) # [num_candidates * knowledge_length]
|
||||||
|
|
||||||
|
# 批量计算所有token的embeddings
|
||||||
|
flat_embeddings = self.tok_embeddings(flat_tokens) # [num_candidates * knowledge_length, dim]
|
||||||
|
|
||||||
|
# Reshape回原始形状并进行mean pooling
|
||||||
|
candidate_embeddings = flat_embeddings.view(num_candidates, knowledge_length, -1)
|
||||||
|
candidate_features = candidate_embeddings.mean(dim=1) # [num_candidates, dim]
|
||||||
|
|
||||||
|
# 3. 计算query与候选条目的相似度
|
||||||
|
query_feature = query_features[batch_idx] # [dim]
|
||||||
|
similarity_scores = F.cosine_similarity(
|
||||||
|
query_feature.unsqueeze(0), candidate_features, dim=1
|
||||||
|
) # [num_candidates]
|
||||||
|
|
||||||
|
# 4. 将相似度分数归一化为概率分布
|
||||||
|
relevance_probs = F.softmax(similarity_scores.float(), dim=-1).to(dtype)
|
||||||
|
|
||||||
|
# 相关性阈值:选择概率大于某个阈值的候选项
|
||||||
|
# 动态阈值:如果所有候选项的相似度都很平均,降低阈值
|
||||||
|
mean_prob = relevance_probs.mean()
|
||||||
|
adaptive_threshold = max(self.relevance_threshold * mean_prob, mean_prob * 0.5)
|
||||||
|
relevant_mask = relevance_probs > adaptive_threshold
|
||||||
|
|
||||||
|
if relevant_mask.sum() == 0:
|
||||||
|
# 如果没有足够相关的,选择相似度最高的top-k
|
||||||
|
top_k = min(5, len(indices))
|
||||||
|
_, top_indices = similarity_scores.topk(top_k)
|
||||||
|
relevant_mask = torch.zeros_like(relevant_mask, dtype=torch.bool)
|
||||||
|
relevant_mask[top_indices] = True
|
||||||
|
|
||||||
|
# 第二层:在相关候选中应用平衡策略
|
||||||
|
if relevant_mask.sum() > 1:
|
||||||
|
# 计算平衡分数(使用频率低的分数高)
|
||||||
|
relevant_indices = indices[relevant_mask]
|
||||||
|
relevant_usage = self.usage_counts[relevant_indices]
|
||||||
|
|
||||||
|
# 平衡分数:使用频率的倒数(加1避免除零)
|
||||||
|
balance_scores = 1.0 / (relevant_usage + 1.0)
|
||||||
|
balance_scores = balance_scores / (balance_scores.sum() + 1e-8)
|
||||||
|
|
||||||
|
# 相关性分数(基于真实的语义相似度)
|
||||||
|
relevant_rel_scores = relevance_probs[relevant_mask]
|
||||||
|
relevant_rel_scores = relevant_rel_scores / (relevant_rel_scores.sum() + 1e-8)
|
||||||
|
|
||||||
|
# 综合分数:动态权重组合
|
||||||
|
combined_scores = (self.current_relevance_weight * relevant_rel_scores +
|
||||||
|
self.current_balance_weight * balance_scores.to(dtype))
|
||||||
|
|
||||||
|
# 确保数据类型一致
|
||||||
|
adjustment = self.base_balance_strength * combined_scores.to(dtype)
|
||||||
|
|
||||||
|
# 将综合分数应用到enhanced_scores
|
||||||
|
enhanced_scores[batch_idx, relevant_mask] = (
|
||||||
|
scores[relevant_mask] + adjustment
|
||||||
|
)
|
||||||
|
|
||||||
|
# 清理中间变量,释放显存
|
||||||
|
del candidate_tokens, flat_tokens, flat_embeddings, candidate_embeddings, candidate_features
|
||||||
|
|
||||||
|
return enhanced_scores.to(device)
|
||||||
|
|
||||||
def q_to_k(self,x):
|
def q_to_k(self,x):
|
||||||
# 1. 生成queries
|
# 1. 生成queries
|
||||||
self.batch_size, seq_len, dim = x.shape
|
self.batch_size, seq_len, dim = x.shape
|
||||||
@ -570,10 +667,17 @@ class ExtractDB(nn.Module):
|
|||||||
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
|
all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
|
||||||
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
|
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
|
||||||
|
|
||||||
# 5. 最终top-k选择
|
# 5. 应用智能分层选择策略
|
||||||
scores, pk_indices = all_scores.topk(self.num_experts_per_head_topk, dim=-1)
|
enhanced_scores = self.intelligent_selection(x, all_scores, all_indices)
|
||||||
|
|
||||||
|
# 6. 基于增强后的分数进行最终top-k选择
|
||||||
|
scores, pk_indices = enhanced_scores.topk(self.num_experts_per_head_topk, dim=-1)
|
||||||
indices = all_indices.gather(-1, pk_indices)
|
indices = all_indices.gather(-1, pk_indices)
|
||||||
flat_indices = indices.view(-1)
|
flat_indices = indices.view(-1)
|
||||||
|
|
||||||
|
# 7. 更新使用统计
|
||||||
|
self.update_usage_statistics(flat_indices)
|
||||||
|
|
||||||
return flat_indices
|
return flat_indices
|
||||||
|
|
||||||
def get_data(self, index):
|
def get_data(self, index):
|
||||||
@ -599,10 +703,13 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
self.params = params or LMConfig()
|
self.params = params or LMConfig()
|
||||||
super().__init__(self.params)
|
super().__init__(self.params)
|
||||||
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
|
self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
|
||||||
|
|
||||||
|
# 先创建token embeddings
|
||||||
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
|
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
|
||||||
self.dropout = nn.Dropout(params.dropout)
|
self.dropout = nn.Dropout(params.dropout)
|
||||||
# 移除旧的weight_down_embed声明
|
|
||||||
self.extract_db = ExtractDB(self.params)
|
# 创建ExtractDB,传入tok_embeddings引用
|
||||||
|
self.extract_db = ExtractDB(self.params, self.tok_embeddings)
|
||||||
|
|
||||||
# 将self.weight_down_embed传递给每个MiniMindBlock
|
# 将self.weight_down_embed传递给每个MiniMindBlock
|
||||||
self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
|
self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
|
||||||
@ -652,20 +759,13 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
h_list = []
|
h_list = []
|
||||||
|
|
||||||
for l, layer in enumerate(self.layers):
|
for l, layer in enumerate(self.layers):
|
||||||
# 禁用数据库模式,使用固定值替代数据库查询
|
# 正常模式,使用数据库查询
|
||||||
if self.params.disable_db:
|
# import pdb;pdb.set_trace()
|
||||||
# 创建一个形状为[batch_size, n_layers, dim]的tensor,所有元素值为1e-4
|
index = self.extract_db.q_to_k(h)
|
||||||
batch_size = h.size(0)
|
|
||||||
db_value = torch.full((batch_size, self.n_layers, self.params.dim), 1e-4,
|
|
||||||
dtype=h.dtype, device=h.device)
|
|
||||||
else:
|
|
||||||
# 正常模式,使用数据库查询
|
|
||||||
# import pdb;pdb.set_trace()
|
|
||||||
index = self.extract_db.q_to_k(h)
|
|
||||||
|
|
||||||
token_idx = self.extract_db.get_data(index) #这里是index
|
token_idx = self.extract_db.get_data(index) #这里是index
|
||||||
|
|
||||||
db_value =self.tok_embeddings(token_idx)
|
db_value =self.tok_embeddings(token_idx)
|
||||||
|
|
||||||
h = layer(
|
h = layer(
|
||||||
h, db_value, pos_cis_real
|
h, db_value, pos_cis_real
|
||||||
|
@ -92,316 +92,346 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
import os
|
import os
|
||||||
|
|
||||||
Logger(f"Loading database initialization data from {database_init_path}")
|
# 聚类参数(需要提前定义用于缓存检查)
|
||||||
|
|
||||||
# 1. 加载JSON文件并转换为字典
|
|
||||||
with open(database_init_path, 'r', encoding='utf-8') as f:
|
|
||||||
database_data = json.load(f)
|
|
||||||
|
|
||||||
# 提取sentences列表
|
|
||||||
sentences_data = database_data.get('sentences', [])
|
|
||||||
Logger(f"Loaded {len(sentences_data)} sentences from database")
|
|
||||||
|
|
||||||
# 2. 按照importance_score进行排序(从高到低)
|
|
||||||
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
|
|
||||||
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
|
|
||||||
|
|
||||||
# 3. 下载并初始化本地嵌入模型
|
|
||||||
embedding_model_name = "sentence-transformers/all-mpnet-base-v2" # 轻量级但效果好的模型
|
|
||||||
embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2"
|
|
||||||
embedding_cache_dir = "./models/sentence_transformers/cache"
|
|
||||||
os.makedirs(embedding_cache_dir, exist_ok=True)
|
|
||||||
|
|
||||||
Logger(f"Loading embedding model: {embedding_model_name}")
|
|
||||||
try:
|
|
||||||
embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir)
|
|
||||||
Logger("Embedding model loaded successfully")
|
|
||||||
except Exception as e:
|
|
||||||
Logger(f"Failed to load embedding model: {e}")
|
|
||||||
Logger("Falling back to random embeddings")
|
|
||||||
embedding_model = None
|
|
||||||
|
|
||||||
# 4. 对每个corrected_sentence进行嵌入和token长度计算
|
|
||||||
Logger("Processing sentences for embeddings and token lengths...")
|
|
||||||
|
|
||||||
# 提取所有句子
|
|
||||||
sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences]
|
|
||||||
|
|
||||||
# 批量计算token长度
|
|
||||||
Logger("Computing token lengths...")
|
|
||||||
token_lengths = []
|
|
||||||
for sentence in sentences:
|
|
||||||
tokens = tokenizer.encode(sentence, add_special_tokens=False)
|
|
||||||
token_lengths.append(len(tokens))
|
|
||||||
|
|
||||||
# 批量计算嵌入 - 大幅提升速度
|
|
||||||
Logger("Computing embeddings in batches...")
|
|
||||||
embeddings_list = []
|
|
||||||
batch_size = 256 # 可以根据GPU内存调整
|
|
||||||
|
|
||||||
if embedding_model is not None:
|
|
||||||
try:
|
|
||||||
for i in range(0, len(sentences), batch_size):
|
|
||||||
batch_sentences = sentences[i:i+batch_size]
|
|
||||||
batch_embeddings = embedding_model.encode(
|
|
||||||
batch_sentences,
|
|
||||||
convert_to_tensor=False,
|
|
||||||
show_progress_bar=True if i == 0 else False,
|
|
||||||
batch_size=batch_size
|
|
||||||
)
|
|
||||||
embeddings_list.extend(batch_embeddings)
|
|
||||||
|
|
||||||
if (i + batch_size) % (batch_size * 10) == 0:
|
|
||||||
Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences")
|
|
||||||
|
|
||||||
Logger("Batch embedding computation completed")
|
|
||||||
except Exception as e:
|
|
||||||
Logger(f"Error in batch encoding: {e}")
|
|
||||||
Logger("Falling back to random embeddings")
|
|
||||||
embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
|
|
||||||
else:
|
|
||||||
# 使用随机嵌入
|
|
||||||
embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
|
|
||||||
|
|
||||||
# 创建处理后的句子列表
|
|
||||||
processed_sentences = []
|
|
||||||
for i, (sentence_data, embedding, token_length) in enumerate(zip(sorted_sentences, embeddings_list, token_lengths)):
|
|
||||||
processed_sentences.append({
|
|
||||||
'sentence': sentence_data.get('corrected_sentence', ''),
|
|
||||||
'importance_score': sentence_data.get('importance_score', 0.0),
|
|
||||||
'token_length': token_length,
|
|
||||||
'embedding': embedding, # Convert numpy array to list
|
|
||||||
'original_index': i
|
|
||||||
})
|
|
||||||
|
|
||||||
# # Create a JSON-serializable version for saving
|
|
||||||
# json_serializable_sentences = []
|
|
||||||
# for sentence in processed_sentences:
|
|
||||||
# json_sentence = sentence.copy()
|
|
||||||
# # Convert embedding to list if it's a numpy array
|
|
||||||
# if hasattr(json_sentence['embedding'], 'tolist'):
|
|
||||||
# json_sentence['embedding'] = json_sentence['embedding'].tolist()
|
|
||||||
# json_serializable_sentences.append(json_sentence)
|
|
||||||
|
|
||||||
# json.dump(json_serializable_sentences, open('processed_sentences.json', 'w', encoding='utf-8'))
|
|
||||||
|
|
||||||
# processed_sentences = json.load(open('processed_sentences.json', 'r', encoding='utf-8'))
|
|
||||||
|
|
||||||
# 转换为numpy数组以便后续处理
|
|
||||||
embeddings_array = np.array(embeddings_list)
|
|
||||||
token_lengths_array = np.array(token_lengths)
|
|
||||||
|
|
||||||
Logger(f"Embedding processing completed:")
|
|
||||||
Logger(f" - Total sentences: {len(processed_sentences)}")
|
|
||||||
Logger(f" - Embedding shape: {embeddings_array.shape}")
|
|
||||||
Logger(f" - Average token length: {np.mean(token_lengths_array):.2f}")
|
|
||||||
Logger(f" - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}")
|
|
||||||
|
|
||||||
# 2. 聚类处理 - 优化版本
|
|
||||||
Logger("Starting optimized clustering process...")
|
|
||||||
|
|
||||||
# 聚类参数
|
|
||||||
knowledge_num = args.knowledge_num
|
knowledge_num = args.knowledge_num
|
||||||
knowledge_length = args.knowledge_length
|
knowledge_length = args.knowledge_length
|
||||||
min_tokens = int(0.85 * knowledge_length)
|
|
||||||
max_tokens = int(0.95 * knowledge_length)
|
|
||||||
|
|
||||||
# 优化1: 预计算所有嵌入的相似度矩阵(如果数据量不太大)
|
# 检查是否使用缓存(提前检查,避免不必要的数据处理)
|
||||||
if len(processed_sentences) <= 10000: # 只有在数据量不太大时才预计算
|
cache_dir = os.path.dirname(args.cluster_cache_path)
|
||||||
Logger("Pre-computing similarity matrix for faster clustering...")
|
if cache_dir:
|
||||||
embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
similarity_matrix = cosine_similarity(embeddings_matrix)
|
|
||||||
Logger(f"Similarity matrix computed: {similarity_matrix.shape}")
|
|
||||||
else:
|
|
||||||
similarity_matrix = None
|
|
||||||
embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
|
|
||||||
|
|
||||||
clustered_rows = []
|
clustered_tensor = None
|
||||||
remaining_indices = list(range(len(processed_sentences))) # 使用索引而不是对象
|
|
||||||
|
|
||||||
Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens")
|
# 尝试加载缓存的聚类结果
|
||||||
|
if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
|
||||||
|
try:
|
||||||
|
Logger(f"Loading cached cluster results from {args.cluster_cache_path}")
|
||||||
|
clustered_tensor = torch.load(args.cluster_cache_path)
|
||||||
|
|
||||||
# 选择聚类算法
|
# 验证缓存文件的形状是否可用
|
||||||
if args.fast_clustering and len(processed_sentences) > 5000:
|
cached_knowledge_num, cached_knowledge_length = clustered_tensor.shape
|
||||||
Logger("Using ultra-fast approximate clustering algorithm...")
|
|
||||||
|
|
||||||
# 超快速聚类:随机采样 + 批量处理
|
if cached_knowledge_length == knowledge_length:
|
||||||
import random
|
if cached_knowledge_num >= knowledge_num:
|
||||||
random.seed(42) # 确保可重现性
|
# 缓存足够大,可以截取使用
|
||||||
|
clustered_tensor = clustered_tensor[:knowledge_num, :]
|
||||||
# 按重要性分层采样
|
Logger(f"Successfully loaded cached clusters with shape {clustered_tensor.shape}")
|
||||||
high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7]
|
Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
|
||||||
medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7]
|
Logger("Skipping database initialization and clustering - using cached results")
|
||||||
low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3]
|
else:
|
||||||
|
# 缓存太小,需要重新计算
|
||||||
Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}")
|
Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
|
||||||
|
clustered_tensor = None
|
||||||
for cluster_idx in tqdm(range(knowledge_num)):
|
|
||||||
# 分层选择种子:优先选择高重要性句子
|
|
||||||
if high_importance:
|
|
||||||
seed_pool = high_importance
|
|
||||||
elif medium_importance:
|
|
||||||
seed_pool = medium_importance
|
|
||||||
else:
|
else:
|
||||||
seed_pool = low_importance if low_importance else list(range(len(processed_sentences)))
|
# knowledge_length不匹配,需要重新计算
|
||||||
|
Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
|
||||||
|
clustered_tensor = None
|
||||||
|
except Exception as e:
|
||||||
|
Logger(f"Failed to load cached clusters: {e}, recomputing...")
|
||||||
|
clustered_tensor = None
|
||||||
|
|
||||||
if not seed_pool:
|
# 只有在没有有效缓存时才进行数据库初始化和聚类计算
|
||||||
break
|
if clustered_tensor is None:
|
||||||
|
Logger(f"Loading database initialization data from {database_init_path}")
|
||||||
|
|
||||||
# 随机选择种子(在同一重要性层级内)
|
# 1. 加载JSON文件并转换为字典
|
||||||
seed_global_idx = random.choice(seed_pool)
|
with open(database_init_path, 'r', encoding='utf-8') as f:
|
||||||
seed_sentence = processed_sentences[seed_global_idx]
|
database_data = json.load(f)
|
||||||
|
|
||||||
# 从所有池中移除种子
|
# 提取sentences列表
|
||||||
for pool in [high_importance, medium_importance, low_importance]:
|
sentences_data = database_data.get('sentences', [])
|
||||||
if seed_global_idx in pool:
|
Logger(f"Loaded {len(sentences_data)} sentences from database")
|
||||||
pool.remove(seed_global_idx)
|
|
||||||
|
|
||||||
current_cluster_indices = [seed_global_idx]
|
# 2. 按照importance_score进行排序(从高到低)
|
||||||
current_tokens = seed_sentence['token_length']
|
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
|
||||||
|
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
|
||||||
|
|
||||||
if current_tokens < max_tokens:
|
# 3. 下载并初始化本地嵌入模型
|
||||||
# 快速选择:只从附近的句子中随机选择
|
embedding_model_name = "sentence-transformers/all-mpnet-base-v2" # 轻量级但效果好的模型
|
||||||
all_remaining = high_importance + medium_importance + low_importance
|
embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2"
|
||||||
if all_remaining:
|
embedding_cache_dir = "./models/sentence_transformers/cache"
|
||||||
# 随机采样候选句子(而不是计算所有相似度)
|
os.makedirs(embedding_cache_dir, exist_ok=True)
|
||||||
sample_size = min(100, len(all_remaining))
|
|
||||||
candidates = random.sample(all_remaining, sample_size)
|
|
||||||
|
|
||||||
# 简单按token长度和重要性选择
|
Logger(f"Loading embedding model: {embedding_model_name}")
|
||||||
for candidate_idx in candidates:
|
try:
|
||||||
candidate = processed_sentences[candidate_idx]
|
embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir)
|
||||||
candidate_tokens = candidate['token_length']
|
Logger("Embedding model loaded successfully")
|
||||||
|
except Exception as e:
|
||||||
|
Logger(f"Failed to load embedding model: {e}")
|
||||||
|
Logger("Falling back to random embeddings")
|
||||||
|
embedding_model = None
|
||||||
|
|
||||||
if current_tokens + candidate_tokens + 1 <= max_tokens:
|
# 4. 对每个corrected_sentence进行嵌入和token长度计算
|
||||||
current_cluster_indices.append(candidate_idx)
|
Logger("Processing sentences for embeddings and token lengths...")
|
||||||
current_tokens += candidate_tokens + 1
|
|
||||||
|
|
||||||
# 从池中移除
|
# 提取所有句子
|
||||||
for pool in [high_importance, medium_importance, low_importance]:
|
sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences]
|
||||||
if candidate_idx in pool:
|
|
||||||
pool.remove(candidate_idx)
|
|
||||||
break
|
|
||||||
|
|
||||||
if current_tokens >= min_tokens:
|
# 批量计算token长度
|
||||||
break
|
Logger("Computing token lengths...")
|
||||||
|
token_lengths = []
|
||||||
|
for sentence in sentences:
|
||||||
|
tokens = tokenizer.encode(sentence, add_special_tokens=False)
|
||||||
|
token_lengths.append(len(tokens))
|
||||||
|
|
||||||
# 生成聚类文本
|
# 批量计算嵌入 - 大幅提升速度
|
||||||
cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
|
Logger("Computing embeddings in batches...")
|
||||||
cluster_text = '\n '.join(cluster_sentences)
|
embeddings_list = []
|
||||||
|
batch_size = 256 # 可以根据GPU内存调整
|
||||||
|
|
||||||
# 转换为tokens
|
if embedding_model is not None:
|
||||||
cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
|
try:
|
||||||
if len(cluster_tokens) > knowledge_length:
|
for i in range(0, len(sentences), batch_size):
|
||||||
cluster_tokens = cluster_tokens[:knowledge_length]
|
batch_sentences = sentences[i:i+batch_size]
|
||||||
else:
|
batch_embeddings = embedding_model.encode(
|
||||||
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
batch_sentences,
|
||||||
cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
|
convert_to_tensor=False,
|
||||||
|
show_progress_bar=True if i == 0 else False,
|
||||||
|
batch_size=batch_size
|
||||||
|
)
|
||||||
|
embeddings_list.extend(batch_embeddings)
|
||||||
|
|
||||||
clustered_rows.append(cluster_tokens)
|
if (i + batch_size) % (batch_size * 10) == 0:
|
||||||
|
Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences")
|
||||||
|
|
||||||
if (cluster_idx + 1) % 1000 == 0:
|
Logger("Batch embedding computation completed")
|
||||||
total_remaining = len(high_importance) + len(medium_importance) + len(low_importance)
|
except Exception as e:
|
||||||
Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining")
|
Logger(f"Error in batch encoding: {e}")
|
||||||
|
Logger("Falling back to random embeddings")
|
||||||
|
embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
|
||||||
|
else:
|
||||||
|
# 使用随机嵌入
|
||||||
|
embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
|
||||||
|
|
||||||
else:
|
# 创建处理后的句子列表
|
||||||
# 原始优化算法(适用于中等规模数据集)
|
processed_sentences = []
|
||||||
# 优化2: 批量处理和更高效的数据结构
|
for i, (sentence_data, embedding, token_length) in enumerate(zip(sorted_sentences, embeddings_list, token_lengths)):
|
||||||
for cluster_idx in tqdm(range(knowledge_num)):
|
processed_sentences.append({
|
||||||
if not remaining_indices:
|
'sentence': sentence_data.get('corrected_sentence', ''),
|
||||||
Logger(f"No more sentences available. Created {cluster_idx} clusters.")
|
'importance_score': sentence_data.get('importance_score', 0.0),
|
||||||
break
|
'token_length': token_length,
|
||||||
|
'embedding': embedding, # Convert numpy array to list
|
||||||
|
'original_index': i
|
||||||
|
})
|
||||||
|
|
||||||
# 2.1 选择importance_score最高的句子作为种子
|
# 转换为numpy数组以便后续处理
|
||||||
remaining_sentences_subset = [processed_sentences[i] for i in remaining_indices]
|
embeddings_array = np.array(embeddings_list)
|
||||||
seed_idx_in_subset = max(range(len(remaining_sentences_subset)),
|
token_lengths_array = np.array(token_lengths)
|
||||||
key=lambda i: remaining_sentences_subset[i]['importance_score'])
|
|
||||||
seed_global_idx = remaining_indices[seed_idx_in_subset]
|
|
||||||
seed_sentence = processed_sentences[seed_global_idx]
|
|
||||||
|
|
||||||
# 从剩余索引中移除种子
|
Logger(f"Embedding processing completed:")
|
||||||
remaining_indices.remove(seed_global_idx)
|
Logger(f" - Total sentences: {len(processed_sentences)}")
|
||||||
|
Logger(f" - Embedding shape: {embeddings_array.shape}")
|
||||||
|
Logger(f" - Average token length: {np.mean(token_lengths_array):.2f}")
|
||||||
|
Logger(f" - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}")
|
||||||
|
|
||||||
# 当前聚类
|
# 聚类参数定义
|
||||||
current_cluster_indices = [seed_global_idx]
|
min_tokens = int(0.85 * knowledge_length)
|
||||||
current_tokens = seed_sentence['token_length']
|
max_tokens = int(0.95 * knowledge_length)
|
||||||
|
|
||||||
if current_tokens >= max_tokens:
|
# 优化1: 预计算所有嵌入的相似度矩阵(如果数据量不太大)
|
||||||
# 如果种子句子已经超过最大token数,直接作为一个聚类
|
if len(processed_sentences) <= 10000: # 只有在数据量不太大时才预计算
|
||||||
cluster_text = seed_sentence['sentence']
|
Logger("Pre-computing similarity matrix for faster clustering...")
|
||||||
else:
|
embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
|
||||||
# 2.2 优化的相似度计算和选择
|
similarity_matrix = cosine_similarity(embeddings_matrix)
|
||||||
if remaining_indices:
|
Logger(f"Similarity matrix computed: {similarity_matrix.shape}")
|
||||||
if similarity_matrix is not None:
|
else:
|
||||||
# 使用预计算的相似度矩阵
|
similarity_matrix = None
|
||||||
similarities = similarity_matrix[seed_global_idx][remaining_indices]
|
embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
|
||||||
else:
|
|
||||||
# 动态计算相似度(批量)
|
|
||||||
seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx+1]
|
|
||||||
remaining_embeddings = embeddings_matrix[remaining_indices]
|
|
||||||
similarities = cosine_similarity(seed_embedding, remaining_embeddings)[0]
|
|
||||||
|
|
||||||
# 创建(相似度, 原始索引, 在remaining_indices中的位置)的元组列表
|
clustered_rows = []
|
||||||
similarity_tuples = [(similarities[i], remaining_indices[i], i)
|
remaining_indices = list(range(len(processed_sentences))) # 使用索引而不是对象
|
||||||
for i in range(len(remaining_indices))]
|
|
||||||
|
|
||||||
# 按相似度排序(降序)
|
Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens")
|
||||||
similarity_tuples.sort(key=lambda x: x[0], reverse=True)
|
|
||||||
|
|
||||||
# 优化3: 贪心选择,但限制搜索范围以提高速度
|
# 选择聚类算法
|
||||||
max_candidates = min(len(similarity_tuples), 500) # 只考虑前500个最相似的句子
|
if args.fast_clustering and len(processed_sentences) > 5000:
|
||||||
|
Logger("Using ultra-fast approximate clustering algorithm...")
|
||||||
|
|
||||||
selected_indices_in_remaining = []
|
# 超快速聚类:随机采样 + 批量处理
|
||||||
for sim_score, global_idx, pos_in_remaining in similarity_tuples[:max_candidates]:
|
import random
|
||||||
candidate = processed_sentences[global_idx]
|
random.seed(42) # 确保可重现性
|
||||||
candidate_tokens = candidate['token_length']
|
|
||||||
|
|
||||||
if current_tokens + candidate_tokens + 1 <= max_tokens: # +1 for newline
|
# 按重要性分层采样
|
||||||
current_cluster_indices.append(global_idx)
|
high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7]
|
||||||
selected_indices_in_remaining.append(pos_in_remaining)
|
medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7]
|
||||||
current_tokens += candidate_tokens + 1
|
low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3]
|
||||||
|
|
||||||
if current_tokens >= min_tokens:
|
Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}")
|
||||||
break
|
|
||||||
|
|
||||||
# 批量移除选中的句子(从后往前移除以避免索引问题)
|
for cluster_idx in tqdm(range(knowledge_num)):
|
||||||
for pos in sorted(selected_indices_in_remaining, reverse=True):
|
# 分层选择种子:优先选择高重要性句子
|
||||||
remaining_indices.pop(pos)
|
if high_importance:
|
||||||
|
seed_pool = high_importance
|
||||||
|
elif medium_importance:
|
||||||
|
seed_pool = medium_importance
|
||||||
|
else:
|
||||||
|
seed_pool = low_importance if low_importance else list(range(len(processed_sentences)))
|
||||||
|
|
||||||
# 拼接句子
|
if not seed_pool:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 随机选择种子(在同一重要性层级内)
|
||||||
|
seed_global_idx = random.choice(seed_pool)
|
||||||
|
seed_sentence = processed_sentences[seed_global_idx]
|
||||||
|
|
||||||
|
# 从所有池中移除种子
|
||||||
|
for pool in [high_importance, medium_importance, low_importance]:
|
||||||
|
if seed_global_idx in pool:
|
||||||
|
pool.remove(seed_global_idx)
|
||||||
|
|
||||||
|
current_cluster_indices = [seed_global_idx]
|
||||||
|
current_tokens = seed_sentence['token_length']
|
||||||
|
|
||||||
|
if current_tokens < max_tokens:
|
||||||
|
# 快速选择:只从附近的句子中随机选择
|
||||||
|
all_remaining = high_importance + medium_importance + low_importance
|
||||||
|
if all_remaining:
|
||||||
|
# 随机采样候选句子(而不是计算所有相似度)
|
||||||
|
sample_size = min(2000, len(all_remaining))
|
||||||
|
candidates = random.sample(all_remaining, sample_size)
|
||||||
|
|
||||||
|
# 简单按token长度和重要性选择
|
||||||
|
for candidate_idx in candidates:
|
||||||
|
candidate = processed_sentences[candidate_idx]
|
||||||
|
candidate_tokens = candidate['token_length']
|
||||||
|
|
||||||
|
if current_tokens + candidate_tokens + 1 <= max_tokens:
|
||||||
|
current_cluster_indices.append(candidate_idx)
|
||||||
|
current_tokens += candidate_tokens + 1
|
||||||
|
|
||||||
|
# 从池中移除
|
||||||
|
for pool in [high_importance, medium_importance, low_importance]:
|
||||||
|
if candidate_idx in pool:
|
||||||
|
pool.remove(candidate_idx)
|
||||||
|
break
|
||||||
|
|
||||||
|
if current_tokens >= min_tokens:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 生成聚类文本
|
||||||
cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
|
cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
|
||||||
cluster_text = '\n'.join(cluster_sentences)
|
cluster_text = '\n '.join(cluster_sentences)
|
||||||
|
|
||||||
# 将聚类文本转换为token
|
# 转换为tokens
|
||||||
cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
|
cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
|
||||||
|
if len(cluster_tokens) > knowledge_length:
|
||||||
|
cluster_tokens = cluster_tokens[:knowledge_length]
|
||||||
|
else:
|
||||||
|
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||||
|
cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
|
||||||
|
|
||||||
# 截断或填充到knowledge_length
|
clustered_rows.append(cluster_tokens)
|
||||||
if len(cluster_tokens) > knowledge_length:
|
|
||||||
cluster_tokens = cluster_tokens[:knowledge_length]
|
|
||||||
else:
|
|
||||||
# 用pad_token_id填充
|
|
||||||
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
|
||||||
cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
|
|
||||||
|
|
||||||
clustered_rows.append(cluster_tokens)
|
if (cluster_idx + 1) % 1000 == 0:
|
||||||
|
total_remaining = len(high_importance) + len(medium_importance) + len(low_importance)
|
||||||
|
Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining")
|
||||||
|
|
||||||
# 优化4: 减少日志频率
|
else:
|
||||||
if (cluster_idx + 1) % 500 == 0:
|
# 原始优化算法(适用于中等规模数据集)
|
||||||
Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining")
|
# 优化2: 批量处理和更高效的数据结构
|
||||||
|
for cluster_idx in tqdm(range(knowledge_num)):
|
||||||
|
if not remaining_indices:
|
||||||
|
Logger(f"No more sentences available. Created {cluster_idx} clusters.")
|
||||||
|
break
|
||||||
|
|
||||||
# 如果聚类数量不足,用随机token填充
|
# 2.1 选择importance_score最高的句子作为种子
|
||||||
while len(clustered_rows) < knowledge_num:
|
remaining_sentences_subset = [processed_sentences[i] for i in remaining_indices]
|
||||||
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
seed_idx_in_subset = max(range(len(remaining_sentences_subset)),
|
||||||
random_tokens = [pad_token_id] * knowledge_length
|
key=lambda i: remaining_sentences_subset[i]['importance_score'])
|
||||||
clustered_rows.append(random_tokens)
|
seed_global_idx = remaining_indices[seed_idx_in_subset]
|
||||||
|
seed_sentence = processed_sentences[seed_global_idx]
|
||||||
|
|
||||||
# 转换为tensor
|
# 从剩余索引中移除种子
|
||||||
clustered_tensor = torch.tensor(clustered_rows, dtype=torch.long)
|
remaining_indices.remove(seed_global_idx)
|
||||||
|
|
||||||
Logger(f"Clustering completed:")
|
# 当前聚类
|
||||||
Logger(f" - Created {len(clustered_rows)} clusters")
|
current_cluster_indices = [seed_global_idx]
|
||||||
Logger(f" - Cluster shape: {clustered_tensor.shape}")
|
current_tokens = seed_sentence['token_length']
|
||||||
Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")
|
|
||||||
|
if current_tokens >= max_tokens:
|
||||||
|
# 如果种子句子已经超过最大token数,直接作为一个聚类
|
||||||
|
cluster_text = seed_sentence['sentence']
|
||||||
|
else:
|
||||||
|
# 2.2 优化的相似度计算和选择
|
||||||
|
if remaining_indices:
|
||||||
|
if similarity_matrix is not None:
|
||||||
|
# 使用预计算的相似度矩阵
|
||||||
|
similarities = similarity_matrix[seed_global_idx][remaining_indices]
|
||||||
|
else:
|
||||||
|
# 动态计算相似度(批量)
|
||||||
|
seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx+1]
|
||||||
|
remaining_embeddings = embeddings_matrix[remaining_indices]
|
||||||
|
similarities = cosine_similarity(seed_embedding, remaining_embeddings)[0]
|
||||||
|
|
||||||
|
# 创建(相似度, 原始索引, 在remaining_indices中的位置)的元组列表
|
||||||
|
similarity_tuples = [(similarities[i], remaining_indices[i], i)
|
||||||
|
for i in range(len(remaining_indices))]
|
||||||
|
|
||||||
|
# 按相似度排序(降序)
|
||||||
|
similarity_tuples.sort(key=lambda x: x[0], reverse=True)
|
||||||
|
|
||||||
|
# 优化3: 贪心选择,但限制搜索范围以提高速度
|
||||||
|
max_candidates = min(len(similarity_tuples), 500) # 只考虑前500个最相似的句子
|
||||||
|
|
||||||
|
selected_indices_in_remaining = []
|
||||||
|
for sim_score, global_idx, pos_in_remaining in similarity_tuples[:max_candidates]:
|
||||||
|
candidate = processed_sentences[global_idx]
|
||||||
|
candidate_tokens = candidate['token_length']
|
||||||
|
|
||||||
|
if current_tokens + candidate_tokens + 1 <= max_tokens: # +1 for newline
|
||||||
|
current_cluster_indices.append(global_idx)
|
||||||
|
selected_indices_in_remaining.append(pos_in_remaining)
|
||||||
|
current_tokens += candidate_tokens + 1
|
||||||
|
|
||||||
|
if current_tokens >= min_tokens:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 批量移除选中的句子(从后往前移除以避免索引问题)
|
||||||
|
for pos in sorted(selected_indices_in_remaining, reverse=True):
|
||||||
|
remaining_indices.pop(pos)
|
||||||
|
|
||||||
|
# 拼接句子
|
||||||
|
cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
|
||||||
|
cluster_text = '\n'.join(cluster_sentences)
|
||||||
|
|
||||||
|
# 将聚类文本转换为token
|
||||||
|
cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
|
||||||
|
|
||||||
|
# 截断或填充到knowledge_length
|
||||||
|
if len(cluster_tokens) > knowledge_length:
|
||||||
|
cluster_tokens = cluster_tokens[:knowledge_length]
|
||||||
|
else:
|
||||||
|
# 用pad_token_id填充
|
||||||
|
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||||
|
cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
|
||||||
|
|
||||||
|
clustered_rows.append(cluster_tokens)
|
||||||
|
|
||||||
|
# 优化4: 减少日志频率
|
||||||
|
if (cluster_idx + 1) % 500 == 0:
|
||||||
|
Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining")
|
||||||
|
|
||||||
|
# 如果聚类数量不足,用随机token填充
|
||||||
|
while len(clustered_rows) < knowledge_num:
|
||||||
|
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||||
|
random_tokens = [pad_token_id] * knowledge_length
|
||||||
|
clustered_rows.append(random_tokens)
|
||||||
|
|
||||||
|
# 转换为tensor
|
||||||
|
clustered_tensor = torch.tensor(clustered_rows, dtype=torch.long)
|
||||||
|
|
||||||
|
Logger(f"Clustering completed:")
|
||||||
|
Logger(f" - Created {len(clustered_rows)} clusters")
|
||||||
|
Logger(f" - Cluster shape: {clustered_tensor.shape}")
|
||||||
|
Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")
|
||||||
|
|
||||||
|
# 保存聚类结果到缓存文件
|
||||||
|
try:
|
||||||
|
torch.save(clustered_tensor, args.cluster_cache_path)
|
||||||
|
Logger(f"Cluster results saved to {args.cluster_cache_path}")
|
||||||
|
except Exception as e:
|
||||||
|
Logger(f"Failed to save cluster results: {e}")
|
||||||
|
|
||||||
# 3. 初始化模型的weight_down_embed
|
# 3. 初始化模型的weight_down_embed
|
||||||
if hasattr(model, 'extract_db') and hasattr(model.extract_db, 'weight_down_embed'):
|
if hasattr(model, 'extract_db') and hasattr(model.extract_db, 'weight_down_embed'):
|
||||||
@ -651,10 +681,12 @@ def main():
|
|||||||
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
|
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
|
||||||
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
|
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
|
||||||
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
|
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
|
||||||
parser.add_argument("--knowledge_num", type=int, default=64*64,help="知识库的数据数目")
|
parser.add_argument("--knowledge_num", type=int, default=65536,help="知识库的数据数目")
|
||||||
parser.add_argument("--knowledge_length", type=int, default=64,help="知识库的句子长度")
|
parser.add_argument("--knowledge_length", type=int, default=64,help="知识库的句子长度")
|
||||||
parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径")
|
parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径")
|
||||||
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
|
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
|
||||||
|
parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens.pt", help="聚类结果缓存文件路径")
|
||||||
|
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
#########################################################
|
#########################################################
|
||||||
|
Loading…
x
Reference in New Issue
Block a user