Compare commits

2 Commits

| Author | SHA1 | Date |
|--------|------------|------|
|        | 7d726c5b20 |      |
|        | 0b53e1b951 |      |
.gitignore (vendored, 2 changed lines)

@@ -7,5 +7,3 @@ models/sentence_transformers/
 models/sentence_transformers_cache/
 **/*.pyc
 qwen2-1.7B/
-images/
-cache/

.vscode/launch.json (vendored, 102 changed lines)

@@ -1,102 +0,0 @@
-{
-    "version": "0.2.0",
-    "configurations": [
-        {
-            "name": "Debug Train Pretrain Accelerate",
-            "type": "python",
-            "request": "launch",
-            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
-            "console": "integratedTerminal",
-            "python": "/opt/conda/envs/mini/bin/python",
-            "cwd": "${workspaceFolder}",
-            "env": {
-                "PYTHONPATH": "${workspaceFolder}",
-                "CUDA_VISIBLE_DEVICES": "0"
-            },
-            "justMyCode": false,
-            "stopOnEntry": false,
-            "redirectOutput": true
-        },
-        {
-            "name": "Debug Train Pretrain Accelerate (Multi-GPU)",
-            "type": "python",
-            "request": "launch",
-            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
-            "console": "integratedTerminal",
-            "python": "/opt/conda/envs/mini/bin/python",
-            "args": [
-                "--hidden_size", "512",
-                "--max_seq_len", "512",
-                "--n_layers", "8",
-                "--batch_size", "8",
-                "--epochs", "1"
-            ],
-            "cwd": "${workspaceFolder}",
-            "env": {
-                "PYTHONPATH": "${workspaceFolder}",
-                "CUDA_VISIBLE_DEVICES": "0,1"
-            },
-            "justMyCode": false,
-            "stopOnEntry": false,
-            "redirectOutput": true
-        },
-        {
-            "name": "Debug Train Pretrain Accelerate (Small Test)",
-            "type": "python",
-            "request": "launch",
-            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
-            "console": "integratedTerminal",
-            "python": "/opt/conda/envs/mini/bin/python",
-            "args": [
-                "--hidden_size", "512",
-                "--max_seq_len", "512",
-                "--n_layers", "8",
-                "--batch_size", "2",
-                "--epochs", "1",
-                "--log_interval", "10",
-                "--save_interval", "100",
-                "--accumulation_steps", "4"
-            ],
-            "cwd": "${workspaceFolder}",
-            "env": {
-                "PYTHONPATH": "${workspaceFolder}",
-                "CUDA_VISIBLE_DEVICES": "0",
-                "WANDB_MODE": "offline"
-            },
-            "justMyCode": false,
-            "stopOnEntry": false,
-            "redirectOutput": true
-        },
-        {
-            "name": "Debug ExtractDB Comparison",
-            "type": "python",
-            "request": "launch",
-            "program": "${workspaceFolder}/train_pretrain_accelerate.py",
-            "console": "integratedTerminal",
-            "python": "/opt/conda/envs/mini/bin/python",
-            "args": [
-                "--hidden_size", "512",
-                "--max_seq_len", "256",
-                "--n_layers", "4",
-                "--batch_size", "2",
-                "--epochs", "1",
-                "--log_interval", "10",
-                "--save_interval", "200",
-                "--accumulation_steps", "2",
-                "--comparison_mode",
-                "--knowledge_num", "256",
-                "--knowledge_length", "64",
-                "--comparison_mode"
-            ],
-            "cwd": "${workspaceFolder}",
-            "env": {
-                "PYTHONPATH": "${workspaceFolder}",
-                "CUDA_VISIBLE_DEVICES": "0",
-                "WANDB_MODE": "offline"
-            },
-            "justMyCode": false,
-            "stopOnEntry": false,
-            "redirectOutput": true
-        }
-    ]
-}

.vscode/settings.json (vendored, 18 changed lines)

@@ -1,18 +0,0 @@
-{
-    "python.pythonPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
-    "python.defaultInterpreterPath": "/home/iomgaa/miniconda3/envs/accelerate/bin/python",
-    "python.terminal.activateEnvironment": true,
-    "python.terminal.activateEnvInCurrentTerminal": true,
-    "python.linting.enabled": true,
-    "python.linting.pylintEnabled": false,
-    "python.linting.flake8Enabled": true,
-    "python.formatting.provider": "black",
-    "python.analysis.autoImportCompletions": true,
-    "python.analysis.typeCheckingMode": "off",
-    "files.exclude": {
-        "**/__pycache__": true,
-        "**/*.pyc": true,
-        "**/.git": false,
-        "**/wandb": false
-    }
-}

@@ -1,5 +1,5 @@
 from transformers import PretrainedConfig
-from typing import List
+from typing import List, Optional, Union


 class LMConfig(PretrainedConfig):
@@ -12,18 +12,24 @@ class LMConfig(PretrainedConfig):
             n_heads: int = 32,
             n_kv_heads: int = 8,
             vocab_size: int = 6400,
-            hidden_dim: int = None,
+            hidden_dim: Optional[int] = None,
             multiple_of: int = 64,
             norm_eps: float = 1e-5,
             max_seq_len: int = 8192,
-            rope_theta: int = 1e6,
+            rope_theta: float = 1e6,
             dropout: float = 0.0,
             flash_attn: bool = True,
-            embeddings_epoch: int = 2,
             ####################################################
             # DB related configurations
             ####################################################
             disable_db: bool = False,  # special mode: disable the database feature
+            use_direct_semantic: bool = False,  # whether to use direct semantic matching (instead of Product Key)
+            realtime_steps: int = 2000,  # number of initial steps computed in real time (progressive caching afterwards)
+            db_intelligent_balance: bool = True,  # whether to enable intelligent load balancing
+            db_relevance_threshold: float = 0.7,  # relevance threshold (first-stage filter)
+            db_balance_strength: float = 0.3,  # base value of the balancing weight
+            db_momentum: float = 0.9,  # momentum for the usage-frequency statistics
+            db_adaptive_weights: bool = True,  # whether to enable dynamic weight adjustment
             ####################################################
             # Here are the specific configurations of MOE
             # When use_moe is false, the following is invalid
@@ -40,7 +46,6 @@ class LMConfig(PretrainedConfig):
             ####################################################
             knowledge_num: int = 64*64,
             knowledge_length: int = 8,
-            knowledge_dim: int = 128,
             **kwargs,
     ):
         self.dim = dim
@@ -55,11 +60,17 @@ class LMConfig(PretrainedConfig):
         self.rope_theta = rope_theta
         self.dropout = dropout
         self.flash_attn = flash_attn
-        self.embeddings_epoch = embeddings_epoch
         ####################################################
         # DB related configurations
         ####################################################
         self.disable_db = disable_db  # whether the database feature is disabled
+        self.use_direct_semantic = use_direct_semantic  # whether to use direct semantic matching (instead of Product Key)
+        self.realtime_steps = realtime_steps  # number of initial steps computed in real time (progressive caching afterwards)
+        self.db_intelligent_balance = db_intelligent_balance  # whether to enable intelligent load balancing
+        self.db_relevance_threshold = db_relevance_threshold  # relevance threshold (first-stage filter)
+        self.db_balance_strength = db_balance_strength  # base value of the balancing weight
+        self.db_momentum = db_momentum  # momentum for the usage-frequency statistics
+        self.db_adaptive_weights = db_adaptive_weights  # whether to enable dynamic weight adjustment
         ####################################################
         # Here are the specific configurations of MOE
         # When use_moe is false, the following is invalid
@@ -75,5 +86,4 @@ class LMConfig(PretrainedConfig):
         ####################################################
         self.knowledge_num = knowledge_num
         self.knowledge_length = knowledge_length
-        self.knowledge_dim = knowledge_dim
         super().__init__(**kwargs)

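For orientation, the DB-related fields introduced above can be set explicitly when building a config. The snippet below is an illustrative sketch, not part of the diff; it assumes the class is importable as `model.LMConfig.LMConfig` (the file header for this hunk is not shown) and only uses parameters that appear in the diff, with their default values.

```python
# Illustrative sketch only; the import path is an assumption, not confirmed by this diff.
from model.LMConfig import LMConfig

config = LMConfig(
    dim=512,
    max_seq_len=512,
    knowledge_num=64 * 64,          # number of knowledge-base entries
    knowledge_length=8,             # tokens per knowledge entry
    # Fields added in this change:
    use_direct_semantic=False,      # direct semantic matching instead of Product Key
    realtime_steps=2000,            # steps computed in real time before progressive caching
    db_intelligent_balance=True,    # intelligent load balancing
    db_relevance_threshold=0.7,     # first-stage relevance filter
    db_balance_strength=0.3,        # base balancing weight
    db_momentum=0.9,                # momentum for usage-frequency statistics
    db_adaptive_weights=True,       # dynamic weight adjustment
)

print(config.realtime_steps, config.db_relevance_threshold)  # 2000 0.7
```
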
model/model.py (997 changed lines)

File diff suppressed because it is too large.

@@ -1,8 +1,8 @@
 #!/bin/bash

 # Activate the conda environment
-source $(conda info --base)/etc/profile.d/conda.sh
-conda activate mini
+# source $(conda info --base)/etc/profile.d/conda.sh
+# conda activate ycz_accelerate

 # Set environment variables to help with debugging
 export NCCL_DEBUG=INFO
@@ -26,9 +26,24 @@ export PYTHONFAULTHANDLER=1
 # --profile_interval 10

 # Method 2: configure accelerate directly via command-line arguments
-CUDA_VISIBLE_DEVICES=0 /opt/conda/envs/mini/bin/python -m accelerate.commands.launch \
+CUDA_VISIBLE_DEVICES=0 accelerate launch \
     --num_processes=1 \
     --mixed_precision=bf16 \
     --main_process_port=29500 \
     train_pretrain_accelerate.py \
+    --epochs 3 \
+    --batch_size 24 \
+    --learning_rate 2e-4 \
+    --dtype bfloat16 \
+    --accumulation_steps 32 \
+    --grad_clip 1.0 \
+    --log_interval 100 \
+    --save_interval 10000 \
+    --dim 512 \
+    --n_layers 12 \
+    --max_seq_len 512 \
+    --use_flash_attn \
+    --profile \
+    --profile_interval 10 \
+    --knowledge_num 4096 \
+    --knowledge_length 8

@@ -74,8 +74,8 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
                 nn.init.ones_(module.weight)

         # Initialize the positional-encoding related parameters
-        if hasattr(model.knowledge_dataset, 'keys'):
-            nn.init.normal_(model.knowledge_dataset.keys, mean=0.0, std=0.02)
+        if hasattr(model.extract_db, 'keys'):
+            nn.init.normal_(model.extract_db.keys, mean=0.0, std=0.02)

         Logger("Default model initialization completed")

@@ -88,52 +88,54 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non

     if database_init_path:
         import json
+        import numpy as np
+        from sentence_transformers import SentenceTransformer
         import os

-        # Database parameters
+        # Clustering parameters (defined up front for the cache check)
         knowledge_num = args.knowledge_num
         knowledge_length = args.knowledge_length

-        # Check whether a cache can be used
+        # Check whether a cache can be used (checked early to avoid unnecessary data processing)
         cache_dir = os.path.dirname(args.cluster_cache_path)
         if cache_dir:
             os.makedirs(cache_dir, exist_ok=True)

-        processed_tensor = None
+        clustered_tensor = None

-        # Try to load cached processing results
+        # Try to load cached clustering results
         if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
             try:
-                Logger(f"Loading cached processed results from {args.cluster_cache_path}")
-                processed_tensor = torch.load(args.cluster_cache_path)
+                Logger(f"Loading cached cluster results from {args.cluster_cache_path}")
+                clustered_tensor = torch.load(args.cluster_cache_path)

                 # Verify that the shape of the cached file is usable
-                cached_knowledge_num, cached_knowledge_length = processed_tensor.shape
+                cached_knowledge_num, cached_knowledge_length = clustered_tensor.shape

                 if cached_knowledge_length == knowledge_length:
                     if cached_knowledge_num >= knowledge_num:
                         # The cache is large enough; truncate and use it
-                        processed_tensor = processed_tensor[:knowledge_num, :]
-                        Logger(f"Successfully loaded cached data with shape {processed_tensor.shape}")
+                        clustered_tensor = clustered_tensor[:knowledge_num, :]
+                        Logger(f"Successfully loaded cached clusters with shape {clustered_tensor.shape}")
                         Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
-                        Logger("Skipping database initialization - using cached results")
+                        Logger("Skipping database initialization and clustering - using cached results")
                     else:
                         # The cache is too small; recompute
                         Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
-                        processed_tensor = None
+                        clustered_tensor = None
                 else:
                     # knowledge_length does not match; recompute
                     Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
-                    processed_tensor = None
+                    clustered_tensor = None
             except Exception as e:
-                Logger(f"Failed to load cached data: {e}, recomputing...")
-                processed_tensor = None
+                Logger(f"Failed to load cached clusters: {e}, recomputing...")
+                clustered_tensor = None

-        # Only run database initialization and processing when there is no valid cache
-        if processed_tensor is None:
+        # Only run database initialization and clustering when there is no valid cache
+        if clustered_tensor is None:
             Logger(f"Loading database initialization data from {database_init_path}")

-            # 1. Load the JSON file
+            # 1. Load the JSON file and convert it to a dict
             with open(database_init_path, 'r', encoding='utf-8') as f:
                 database_data = json.load(f)

@@ -145,73 +147,300 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
             sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
             Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")

-            # 3. Process each entry individually, without clustering
-            Logger("Processing individual sentences...")
-            processed_rows = []
+            # 3. Download and initialize the local embedding model
+            embedding_model_name = "sentence-transformers/all-mpnet-base-v2"  # a lightweight model that works well
+            embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2"
+            embedding_cache_dir = "./models/sentence_transformers/cache"
+            os.makedirs(embedding_cache_dir, exist_ok=True)

-            # Get the id of the empty (pad) token, used for padding
-            pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
+            Logger(f"Loading embedding model: {embedding_model_name}")
+            try:
+                embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir)
+                Logger("Embedding model loaded successfully")
+            except Exception as e:
+                Logger(f"Failed to load embedding model: {e}")
+                Logger("Falling back to random embeddings")
+                embedding_model = None

-            # Process the required number of sentences
-            num_to_process = min(knowledge_num, len(sorted_sentences))
+            # 4. Compute the embedding and token length of every corrected_sentence
+            Logger("Processing sentences for embeddings and token lengths...")

-            for i in range(num_to_process):
-                sentence_data = sorted_sentences[i]
-                sentence = sentence_data.get('corrected_sentence', '')
+            # Extract all sentences
+            sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences]

-                # Convert the sentence into tokens
-                sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
+            # Compute token lengths in batch
+            Logger("Computing token lengths...")
+            token_lengths = []
+            for sentence in sentences:
+                tokens = tokenizer.encode(sentence, add_special_tokens=False)
+                token_lengths.append(len(tokens))

-                # Truncate or pad to knowledge_length
-                if len(sentence_tokens) > knowledge_length:
-                    # Truncate if longer than the limit
-                    sentence_tokens = sentence_tokens[:knowledge_length]
-                    Logger(f"Sentence {i+1} truncated from {len(tokenizer.encode(sentence, add_special_tokens=False))} to {knowledge_length} tokens")
-                else:
-                    # Pad with the empty token if shorter than the limit
-                    original_length = len(sentence_tokens)
-                    sentence_tokens.extend([pad_token_id] * (knowledge_length - len(sentence_tokens)))
-                    if original_length < knowledge_length:
-                        Logger(f"Sentence {i+1} padded from {original_length} to {knowledge_length} tokens")
+            # Compute embeddings in batches - a big speed-up
+            Logger("Computing embeddings in batches...")
+            embeddings_list = []
+            batch_size = 256  # can be tuned to the available GPU memory

-                processed_rows.append(sentence_tokens)
+            if embedding_model is not None:
+                try:
+                    for i in range(0, len(sentences), batch_size):
+                        batch_sentences = sentences[i:i+batch_size]
+                        batch_embeddings = embedding_model.encode(
+                            batch_sentences,
+                            convert_to_tensor=False,
+                            show_progress_bar=True if i == 0 else False,
+                            batch_size=batch_size
+                        )
+                        embeddings_list.extend(batch_embeddings)

-                if (i + 1) % 1000 == 0:
-                    Logger(f"Processed {i + 1}/{num_to_process} sentences")
+                        if (i + batch_size) % (batch_size * 10) == 0:
+                            Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences")

-            # If there are not enough sentences, pad the remaining slots with empty tokens
-            while len(processed_rows) < knowledge_num:
-                empty_tokens = [pad_token_id] * knowledge_length
-                processed_rows.append(empty_tokens)
-                if len(processed_rows) % 1000 == 0:
-                    Logger(f"Added empty entry {len(processed_rows)}/{knowledge_num}")
+                    Logger("Batch embedding computation completed")
+                except Exception as e:
+                    Logger(f"Error in batch encoding: {e}")
+                    Logger("Falling back to random embeddings")
+                    embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
+            else:
+                # Use random embeddings
+                embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]

-            Logger(f"Finished adding empty entries. Total: {len(processed_rows)}/{knowledge_num}")
+            # Build the list of processed sentences
+            processed_sentences = []
+            for i, (sentence_data, embedding, token_length) in enumerate(zip(sorted_sentences, embeddings_list, token_lengths)):
+                processed_sentences.append({
+                    'sentence': sentence_data.get('corrected_sentence', ''),
+                    'importance_score': sentence_data.get('importance_score', 0.0),
+                    'token_length': token_length,
+                    'embedding': embedding,  # Convert numpy array to list
+                    'original_index': i
+                })

+            # Convert to numpy arrays for later processing
+            embeddings_array = np.array(embeddings_list)
+            token_lengths_array = np.array(token_lengths)
+
+            Logger(f"Embedding processing completed:")
+            Logger(f"  - Total sentences: {len(processed_sentences)}")
+            Logger(f"  - Embedding shape: {embeddings_array.shape}")
+            Logger(f"  - Average token length: {np.mean(token_lengths_array):.2f}")
+            Logger(f"  - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}")
+
+            # Clustering parameter definitions
+            min_tokens = int(0.85 * knowledge_length)
+            max_tokens = int(0.95 * knowledge_length)
+
+            # Optimization 1: pre-compute the similarity matrix of all embeddings (if the dataset is not too large)
+            if len(processed_sentences) <= 10000:  # only pre-compute when the dataset is not too large
+                Logger("Pre-computing similarity matrix for faster clustering...")
+                embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
+                similarity_matrix = cosine_similarity(embeddings_matrix)
+                Logger(f"Similarity matrix computed: {similarity_matrix.shape}")
+            else:
+                similarity_matrix = None
+                embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
+
+            clustered_rows = []
+            remaining_indices = list(range(len(processed_sentences)))  # use indices instead of objects
+
+            Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens")
+
+            # Choose the clustering algorithm
+            if args.fast_clustering and len(processed_sentences) > 5000:
+                Logger("Using ultra-fast approximate clustering algorithm...")
+
+                # Ultra-fast clustering: random sampling + batch processing
+                import random
+                random.seed(42)  # ensure reproducibility
+
+                # Stratified sampling by importance
+                high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7]
+                medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7]
+                low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3]
+
+                Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}")
+
+                for cluster_idx in tqdm(range(knowledge_num)):
+                    # Stratified seed selection: prefer high-importance sentences
+                    if high_importance:
+                        seed_pool = high_importance
+                    elif medium_importance:
+                        seed_pool = medium_importance
+                    else:
+                        seed_pool = low_importance if low_importance else list(range(len(processed_sentences)))
+
+                    if not seed_pool:
+                        break
+
+                    # Randomly pick a seed (within the same importance tier)
+                    seed_global_idx = random.choice(seed_pool)
+                    seed_sentence = processed_sentences[seed_global_idx]
+
+                    # Remove the seed from all pools
+                    for pool in [high_importance, medium_importance, low_importance]:
+                        if seed_global_idx in pool:
+                            pool.remove(seed_global_idx)
+
+                    current_cluster_indices = [seed_global_idx]
+                    current_tokens = seed_sentence['token_length']
+
+                    if current_tokens < max_tokens:
+                        # Fast selection: only pick randomly from nearby sentences
+                        all_remaining = high_importance + medium_importance + low_importance
+                        if all_remaining:
+                            # Randomly sample candidate sentences (instead of computing all similarities)
+                            sample_size = min(2000, len(all_remaining))
+                            candidates = random.sample(all_remaining, sample_size)
+
+                            # Simple selection by token length and importance
+                            for candidate_idx in candidates:
+                                candidate = processed_sentences[candidate_idx]
+                                candidate_tokens = candidate['token_length']
+
+                                if current_tokens + candidate_tokens + 1 <= max_tokens:
+                                    current_cluster_indices.append(candidate_idx)
+                                    current_tokens += candidate_tokens + 1
+
+                                    # Remove from the pools
+                                    for pool in [high_importance, medium_importance, low_importance]:
+                                        if candidate_idx in pool:
+                                            pool.remove(candidate_idx)
+                                            break
+
+                                if current_tokens >= min_tokens:
+                                    break
+
+                    # Build the cluster text
+                    cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
+                    cluster_text = '\n '.join(cluster_sentences)
+
+                    # Convert to tokens
+                    cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
+                    if len(cluster_tokens) > knowledge_length:
+                        cluster_tokens = cluster_tokens[:knowledge_length]
+                    else:
+                        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
+                        cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
+
+                    clustered_rows.append(cluster_tokens)
+
+                    if (cluster_idx + 1) % 1000 == 0:
+                        total_remaining = len(high_importance) + len(medium_importance) + len(low_importance)
+                        Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining")
+
+            else:
+                # Original optimized algorithm (for medium-sized datasets)
+                # Optimization 2: batch processing and more efficient data structures
+                for cluster_idx in tqdm(range(knowledge_num)):
+                    if not remaining_indices:
+                        Logger(f"No more sentences available. Created {cluster_idx} clusters.")
+                        break
+
+                    # 2.1 Pick the sentence with the highest importance_score as the seed
+                    remaining_sentences_subset = [processed_sentences[i] for i in remaining_indices]
+                    seed_idx_in_subset = max(range(len(remaining_sentences_subset)),
+                                             key=lambda i: remaining_sentences_subset[i]['importance_score'])
+                    seed_global_idx = remaining_indices[seed_idx_in_subset]
+                    seed_sentence = processed_sentences[seed_global_idx]
+
+                    # Remove the seed from the remaining indices
+                    remaining_indices.remove(seed_global_idx)
+
+                    # Current cluster
+                    current_cluster_indices = [seed_global_idx]
+                    current_tokens = seed_sentence['token_length']
+
+                    if current_tokens >= max_tokens:
+                        # If the seed sentence already exceeds the max token count, it forms a cluster on its own
+                        cluster_text = seed_sentence['sentence']
+                    else:
+                        # 2.2 Optimized similarity computation and selection
+                        if remaining_indices:
+                            if similarity_matrix is not None:
+                                # Use the pre-computed similarity matrix
+                                similarities = similarity_matrix[seed_global_idx][remaining_indices]
+                            else:
+                                # Compute similarities on the fly (batched)
+                                seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx+1]
+                                remaining_embeddings = embeddings_matrix[remaining_indices]
+                                similarities = cosine_similarity(seed_embedding, remaining_embeddings)[0]
+
+                            # Build a list of (similarity, original index, position in remaining_indices) tuples
+                            similarity_tuples = [(similarities[i], remaining_indices[i], i)
+                                                 for i in range(len(remaining_indices))]
+
+                            # Sort by similarity (descending)
+                            similarity_tuples.sort(key=lambda x: x[0], reverse=True)
+
+                            # Optimization 3: greedy selection, with a limited search range for speed
+                            max_candidates = min(len(similarity_tuples), 500)  # only consider the 500 most similar sentences
+
+                            selected_indices_in_remaining = []
+                            for sim_score, global_idx, pos_in_remaining in similarity_tuples[:max_candidates]:
+                                candidate = processed_sentences[global_idx]
+                                candidate_tokens = candidate['token_length']
+
+                                if current_tokens + candidate_tokens + 1 <= max_tokens:  # +1 for newline
+                                    current_cluster_indices.append(global_idx)
+                                    selected_indices_in_remaining.append(pos_in_remaining)
+                                    current_tokens += candidate_tokens + 1
+
+                                    if current_tokens >= min_tokens:
+                                        break
+
+                            # Remove the selected sentences in batch (back to front to avoid index shifts)
+                            for pos in sorted(selected_indices_in_remaining, reverse=True):
+                                remaining_indices.pop(pos)
+
+                        # Concatenate the sentences
+                        cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
+                        cluster_text = '\n'.join(cluster_sentences)
+
+                    # Convert the cluster text into tokens
+                    cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
+
+                    # Truncate or pad to knowledge_length
+                    if len(cluster_tokens) > knowledge_length:
+                        cluster_tokens = cluster_tokens[:knowledge_length]
+                    else:
+                        # Pad with pad_token_id
+                        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
+                        cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
+
+                    clustered_rows.append(cluster_tokens)
+
+                    # Optimization 4: reduce logging frequency
+                    if (cluster_idx + 1) % 500 == 0:
+                        Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining")
+
+            # If there are not enough clusters, pad with random tokens
+            while len(clustered_rows) < knowledge_num:
+                pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
+                random_tokens = [pad_token_id] * knowledge_length
+                clustered_rows.append(random_tokens)
+
             # Convert to a tensor
-            processed_tensor = torch.tensor(processed_rows, dtype=torch.long)
+            clustered_tensor = torch.tensor(clustered_rows, dtype=torch.long)

-            Logger(f"Data processing completed:")
-            Logger(f"  - Processed {num_to_process} sentences")
-            Logger(f"  - Added {knowledge_num - num_to_process} empty entries")
-            Logger(f"  - Final shape: {processed_tensor.shape}")
+            Logger(f"Clustering completed:")
+            Logger(f"  - Created {len(clustered_rows)} clusters")
+            Logger(f"  - Cluster shape: {clustered_tensor.shape}")
             Logger(f"  - Expected shape: ({knowledge_num}, {knowledge_length})")

-            # Save the processed results to the cache file
+            # Save the clustering results to the cache file
             try:
-                torch.save(processed_tensor, args.cluster_cache_path)
-                Logger(f"Processed results saved to {args.cluster_cache_path}")
+                torch.save(clustered_tensor, args.cluster_cache_path)
+                Logger(f"Cluster results saved to {args.cluster_cache_path}")
             except Exception as e:
-                Logger(f"Failed to save processed results: {e}")
+                Logger(f"Failed to save cluster results: {e}")

-        # 4. Initialize the model's knowledge_dataset
-        if hasattr(model, 'knowledge_dataset') and hasattr(model.knowledge_dataset, 'knowledge_dataset'):
-            model.knowledge_dataset.knowledge_dataset.data.copy_(processed_tensor)
-            Logger("Successfully initialized model.knowledge_dataset.knowledge_dataset with processed data")
+        # 3. Initialize the model's weight_down_embed
+        if hasattr(model, 'extract_db') and hasattr(model.extract_db, 'weight_down_embed'):
+            model.extract_db.weight_down_embed.data.copy_(clustered_tensor)
+            Logger("Successfully initialized model.extract_db.weight_down_embed with clustered data")
         else:
-            Logger("Warning: Could not find model.knowledge_dataset.knowledge_dataset to initialize")
+            Logger("Warning: Could not find model.extract_db.weight_down_embed to initialize")
             # Store it in a global variable as a fallback
-            globals()['processed_database'] = processed_tensor
+            globals()['clustered_database'] = clustered_tensor

         Logger(f"Database embeddings and sentences stored in model")

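To make the token budget used by the greedy clustering above concrete: each cluster is packed until it holds between int(0.85 * knowledge_length) and int(0.95 * knowledge_length) tokens, with one extra token per sentence for the newline separator. A small standalone sketch (not part of the diff) using the knowledge_length=64 default set elsewhere in this change, with hypothetical per-sentence token lengths:

```python
# Standalone sketch of the token budget behind the greedy packing above.
knowledge_length = 64                       # default introduced in this change
min_tokens = int(0.85 * knowledge_length)   # 54: stop adding candidates once reached
max_tokens = int(0.95 * knowledge_length)   # 60: never exceed while packing

token_lengths = [21, 18, 15, 9, 30]         # hypothetical sentence lengths
current, packed = 0, []
for n in token_lengths:
    if current + n + 1 <= max_tokens:       # +1 accounts for the newline separator
        packed.append(n)
        current += n + 1
    if current >= min_tokens:
        break

print(min_tokens, max_tokens, packed, current)  # 54 60 [21, 18, 15] 57
```
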
@@ -224,7 +453,6 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
     total_steps_in_epoch = len(train_loader)
     total_training_steps = args.epochs * total_steps_in_epoch
     moe_path = '_moe' if args.use_moe else ''
-    best_loss = float('10000')

     # Add CUDA events for performance profiling (main process only)
     if args.profile and accelerator.is_main_process:
@@ -288,12 +516,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a

             # Forward pass
             with ctx:
-                if step == 0 and args.embedding_epoch == epoch:
-                    # The freeze_embedding attribute must be set on the original model, not on the wrapped one
-                    unwrapped_model = accelerator.unwrap_model(model)
-                    unwrapped_model.freeze_embedding = True
-                    Logger(f"Set freeze_embedding=True for epoch {epoch}, step {step}", accelerator)
-                res = model(X, step=step)
+                res = model(X)
                 loss = loss_fct(
                     res.logits.view(-1, res.logits.size(-1)),
                     Y.view(-1)
@@ -417,9 +640,7 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                     wandb.log(log_dict)

             # Save the model (main process only)
-            loss_total = loss.item() * args.accumulation_steps
-            if best_loss > loss_total and accelerator.is_main_process:
-                best_loss = loss_total
+            if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
                 # Use the moe_path variable defined at the top of the function
                 ckp = f'{args.save_dir}/pretrain_{args.dim}{moe_path}.pth'

@@ -438,22 +659,21 @@ def main():
 def main():
     parser = argparse.ArgumentParser(description="MiniMind Pretraining with Accelerate")
     parser.add_argument("--out_dir", type=str, default="out")
-    parser.add_argument("--epochs", type=int, default=4)
-    parser.add_argument("--embedding_epoch", type=int, default=2, help="number of epochs for embedding training")
-    parser.add_argument("--batch_size", type=int, default=64)
+    parser.add_argument("--epochs", type=int, default=3)
+    parser.add_argument("--batch_size", type=int, default=24)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
     parser.add_argument("--dtype", type=str, default="bfloat16")
     parser.add_argument("--use_wandb", default=True, action="store_true")
     parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
-    parser.add_argument("--num_workers", type=int, default=8)
+    parser.add_argument("--num_workers", type=int, default=48)
     parser.add_argument("--accumulation_steps", type=int, default=32)
     parser.add_argument("--grad_clip", type=float, default=1.0)
     parser.add_argument("--warmup_iters", type=int, default=0)
     parser.add_argument("--log_interval", type=int, default=100)
     parser.add_argument("--save_interval", type=int, default=10000)
-    parser.add_argument('--dim', default=512, type=int)
-    parser.add_argument('--n_layers', default=8, type=int)
-    parser.add_argument('--max_seq_len', default=512, type=int)
+    parser.add_argument('--dim', default=1024, type=int)
+    parser.add_argument('--n_layers', default=32, type=int)
+    parser.add_argument('--max_seq_len', default=1024, type=int)
     parser.add_argument('--use_moe', default=False, type=bool)
     parser.add_argument('--disable_db', action='store_true', help="disable the database feature and use a fixed value of 1e-4 instead")
     parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
@@ -461,11 +681,11 @@ def main():
     parser.add_argument("--profile", action="store_true", default=True, help="enable performance profiling")
     parser.add_argument("--profile_interval", type=int, default=10, help="profiling print interval (steps)")
     parser.add_argument("--use_flash_attn", action="store_true", default=True, help="enable FlashAttention")
-    parser.add_argument("--knowledge_num", type=int, default=8192, help="number of entries in the knowledge base")
-    parser.add_argument("--knowledge_length", type=int, default=32, help="sentence length of the knowledge base entries")
+    parser.add_argument("--knowledge_num", type=int, default=65536, help="number of entries in the knowledge base")
+    parser.add_argument("--knowledge_length", type=int, default=64, help="sentence length of the knowledge base entries")
     parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="database initialization path")
     parser.add_argument("--fast_clustering", action="store_true", default=True, help="use the fast approximate clustering algorithm (for large datasets)")
-    parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="path of the clustering result cache file")
+    parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens.pt", help="path of the clustering result cache file")
     parser.add_argument("--recompute_clusters", action="store_true", default=False, help="force recomputation of the clusters, ignoring the cache file")
     args = parser.parse_args()

@@ -504,8 +724,7 @@ def main():
         disable_db=args.disable_db,
         flash_attn=args.use_flash_attn,
         knowledge_num=args.knowledge_num,
-        knowledge_length=args.knowledge_length,
-        embeddings_epoch=args.embedding_epoch
+        knowledge_length=args.knowledge_length
     )

     #########################################################
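The cache handling added to init_model above follows a single rule: a cached tensor of token ids is reused only when its knowledge_length matches exactly and its knowledge_num is at least the requested size, in which case it is truncated; otherwise everything is recomputed. A minimal standalone sketch of that rule (the cache path and sizes are taken from defaults in this change; the helper name is illustrative only):

```python
# Minimal sketch of the cache-validation rule used in init_model above.
# The helper name is illustrative and not defined in the diff.
import torch

def load_cluster_cache(path, knowledge_num, knowledge_length):
    """Return a (knowledge_num, knowledge_length) tensor from cache, or None if unusable."""
    try:
        cached = torch.load(path)
    except Exception:
        return None                         # missing or unreadable cache: recompute
    cached_num, cached_len = cached.shape
    if cached_len != knowledge_length:      # wrong entry length: recompute
        return None
    if cached_num < knowledge_num:          # cache too small: recompute
        return None
    return cached[:knowledge_num, :]        # cache large enough: truncate and reuse

cache = load_cluster_cache("./cache/cluster_tokens.pt", knowledge_num=4096, knowledge_length=8)
print(None if cache is None else cache.shape)
```
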