This commit is contained in:
iomgaa 2025-05-27 11:46:18 +08:00
parent c96a9c35d5
commit 67c632d010
2 changed files with 4 additions and 4 deletions

View File

@@ -703,7 +703,7 @@ class MiniMindLM(PreTrainedModel):
# Process query path as before
z_q = self.downsample_q_specific(shared_features)
z_k = self.extract_db.q_to_k(z_q)
self.extract_db.updata_value(z_k, token_indices)
# self.extract_db.updata_value(z_k, token_indices)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])

View File

@@ -203,8 +203,8 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
# 聚类参数
knowledge_num = args.knowledge_num
knowledge_length = args.knowledge_length
min_tokens = int(0.9 * knowledge_length)
max_tokens = knowledge_length
min_tokens = int(0.85 * knowledge_length)
max_tokens = int(0.95 * knowledge_length)
# 优化1: 预计算所有嵌入的相似度矩阵(如果数据量不太大)
if len(processed_sentences) <= 10000: # 只有在数据量不太大时才预计算
@@ -288,7 +288,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
# 生成聚类文本
cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
cluster_text = '\n'.join(cluster_sentences)
cluster_text = '\n '.join(cluster_sentences)
# 转换为tokens
cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)