From 67c632d010198a12d1b70bde4ea3d5127dbeac5d Mon Sep 17 00:00:00 2001
From: iomgaa
Date: Tue, 27 May 2025 11:46:18 +0800
Subject: [PATCH] update

---
 model/model.py               | 2 +-
 train_pretrain_accelerate.py | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/model/model.py b/model/model.py
index ff85866..814674d 100644
--- a/model/model.py
+++ b/model/model.py
@@ -703,7 +703,7 @@ class MiniMindLM(PreTrainedModel):
         # Process query path as before
         z_q = self.downsample_q_specific(shared_features)
         z_k = self.extract_db.q_to_k(z_q)
-        self.extract_db.updata_value(z_k, token_indices)
+        # self.extract_db.updata_value(z_k, token_indices)
 
         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
         logits = self.output(self.norm(h)[:, slice_indices, :])
diff --git a/train_pretrain_accelerate.py b/train_pretrain_accelerate.py
index dd35a19..54e1c05 100644
--- a/train_pretrain_accelerate.py
+++ b/train_pretrain_accelerate.py
@@ -203,8 +203,8 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
         # Clustering parameters
         knowledge_num = args.knowledge_num
         knowledge_length = args.knowledge_length
-        min_tokens = int(0.9 * knowledge_length)
-        max_tokens = knowledge_length
+        min_tokens = int(0.85 * knowledge_length)
+        max_tokens = int(0.95 * knowledge_length)
 
         # Optimization 1: precompute the similarity matrix for all embeddings (if the data volume is not too large)
         if len(processed_sentences) <= 10000:  # only precompute when the data volume is not too large
@@ -288,7 +288,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
 
             # Generate the cluster text
            cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
-            cluster_text = '\n'.join(cluster_sentences)
+            cluster_text = '\n '.join(cluster_sentences)
 
             # Convert to tokens
             cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
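
A minimal sketch (not part of the patch) of the cluster-tokenization logic the train_pretrain_accelerate.py hunks adjust: the token budget moves from 90%-100% to 85%-95% of knowledge_length, and sentences are joined with "\n " instead of "\n". The helper name and the clipping to max_tokens are assumptions for illustration; the actual truncation/padding code is not shown in this patch.

    def build_cluster_tokens(cluster_sentences, tokenizer, knowledge_length):
        # Token budget after this patch: aim for 85%-95% of knowledge_length,
        # leaving headroom below the hard cap (previously 90%-100%).
        min_tokens = int(0.85 * knowledge_length)
        max_tokens = int(0.95 * knowledge_length)
        # Join with "\n " so every sentence after the first starts with a space,
        # matching the '\n '.join(...) change above.
        cluster_text = '\n '.join(cluster_sentences)
        cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
        # Hypothetical handling: clip to the upper bound here; enforcing
        # min_tokens would happen elsewhere in init_model (not in this patch).
        return cluster_tokens[:max_tokens]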