This commit is contained in:
iomgaa 2025-05-27 11:46:18 +08:00
parent c96a9c35d5
commit 67c632d010
2 changed files with 4 additions and 4 deletions

View File

@@ -703,7 +703,7 @@ class MiniMindLM(PreTrainedModel):
# Process query path as before # Process query path as before
z_q = self.downsample_q_specific(shared_features) z_q = self.downsample_q_specific(shared_features)
z_k = self.extract_db.q_to_k(z_q) z_k = self.extract_db.q_to_k(z_q)
self.extract_db.updata_value(z_k, token_indices) # self.extract_db.updata_value(z_k, token_indices)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :]) logits = self.output(self.norm(h)[:, slice_indices, :])

View File

@@ -203,8 +203,8 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
# 聚类参数 # 聚类参数
knowledge_num = args.knowledge_num knowledge_num = args.knowledge_num
knowledge_length = args.knowledge_length knowledge_length = args.knowledge_length
min_tokens = int(0.9 * knowledge_length) min_tokens = int(0.85 * knowledge_length)
max_tokens = knowledge_length max_tokens = int(0.95 * knowledge_length)
# 优化1: 预计算所有嵌入的相似度矩阵(如果数据量不太大) # 优化1: 预计算所有嵌入的相似度矩阵(如果数据量不太大)
if len(processed_sentences) <= 10000: # 只有在数据量不太大时才预计算 if len(processed_sentences) <= 10000: # 只有在数据量不太大时才预计算
@@ -288,7 +288,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
# 生成聚类文本 # 生成聚类文本
cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices] cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
cluster_text = '\n'.join(cluster_sentences) cluster_text = '\n '.join(cluster_sentences)
# 转换为tokens # 转换为tokens
cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False) cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)