update
This commit is contained in:
parent
c96a9c35d5
commit
67c632d010
@ -703,7 +703,7 @@ class MiniMindLM(PreTrainedModel):
|
|||||||
# Process query path as before
|
# Process query path as before
|
||||||
z_q = self.downsample_q_specific(shared_features)
|
z_q = self.downsample_q_specific(shared_features)
|
||||||
z_k = self.extract_db.q_to_k(z_q)
|
z_k = self.extract_db.q_to_k(z_q)
|
||||||
self.extract_db.updata_value(z_k, token_indices)
|
# self.extract_db.updata_value(z_k, token_indices)
|
||||||
|
|
||||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||||
logits = self.output(self.norm(h)[:, slice_indices, :])
|
logits = self.output(self.norm(h)[:, slice_indices, :])
|
||||||
|
@ -203,8 +203,8 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
|||||||
# 聚类参数
|
# 聚类参数
|
||||||
knowledge_num = args.knowledge_num
|
knowledge_num = args.knowledge_num
|
||||||
knowledge_length = args.knowledge_length
|
knowledge_length = args.knowledge_length
|
||||||
min_tokens = int(0.9 * knowledge_length)
|
min_tokens = int(0.85 * knowledge_length)
|
||||||
max_tokens = knowledge_length
|
max_tokens = int(0.95 * knowledge_length)
|
||||||
|
|
||||||
# 优化1: 预计算所有嵌入的相似度矩阵(如果数据量不太大)
|
# 优化1: 预计算所有嵌入的相似度矩阵(如果数据量不太大)
|
||||||
if len(processed_sentences) <= 10000: # 只有在数据量不太大时才预计算
|
if len(processed_sentences) <= 10000: # 只有在数据量不太大时才预计算
|
||||||
|
Loading…
x
Reference in New Issue
Block a user