Use a cache for data initialization (数据初始化使用了缓存)
parent 6932e5fa8e
commit 64e92473c3
@@ -92,6 +92,47 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
     from sentence_transformers import SentenceTransformer
     import os

+    # Clustering parameters (defined early so the cache can be checked first)
+    knowledge_num = args.knowledge_num
+    knowledge_length = args.knowledge_length
+
+    # Check for a usable cache up front, to avoid unnecessary data processing
+    cache_dir = os.path.dirname(args.cluster_cache_path)
+    if cache_dir:
+        os.makedirs(cache_dir, exist_ok=True)
+
+    clustered_tensor = None
+
+    # Try to load cached clustering results
+    if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
+        try:
+            Logger(f"Loading cached cluster results from {args.cluster_cache_path}")
+            clustered_tensor = torch.load(args.cluster_cache_path)
+
+            # Check whether the cached tensor's shape is usable
+            cached_knowledge_num, cached_knowledge_length = clustered_tensor.shape
+
+            if cached_knowledge_length == knowledge_length:
+                if cached_knowledge_num >= knowledge_num:
+                    # Cache is large enough; truncate it to the required size
+                    clustered_tensor = clustered_tensor[:knowledge_num, :]
+                    Logger(f"Successfully loaded cached clusters with shape {clustered_tensor.shape}")
+                    Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
+                    Logger("Skipping database initialization and clustering - using cached results")
+                else:
+                    # Cache is too small; recompute
+                    Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
+                    clustered_tensor = None
+            else:
+                # knowledge_length mismatch; recompute
+                Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
+                clustered_tensor = None
+        except Exception as e:
+            Logger(f"Failed to load cached clusters: {e}, recomputing...")
+            clustered_tensor = None
+
+    # Only run database initialization and clustering when no valid cache exists
+    if clustered_tensor is None:
     Logger(f"Loading database initialization data from {database_init_path}")

     # 1. Load the JSON file and convert it to a dict
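Outside the diff context, the cache-loading branch added above amounts to a small load-or-invalidate helper: reuse the cache only if the row length matches and there are at least as many rows as needed, otherwise fall back to recomputation. The sketch below is a minimal standalone rendering of that logic, assuming plain print in place of the project's Logger; the helper name load_cluster_cache and its signature are illustrative, not part of the commit.

    import os
    import torch

    def load_cluster_cache(cache_path, knowledge_num, knowledge_length, recompute=False):
        """Return a (knowledge_num, knowledge_length) tensor from cache, or None if unusable."""
        if recompute or not os.path.exists(cache_path):
            return None
        try:
            cached = torch.load(cache_path)
            cached_num, cached_len = cached.shape
            if cached_len != knowledge_length:
                print(f"Cached knowledge_length {cached_len} != required {knowledge_length}, recomputing...")
                return None
            if cached_num < knowledge_num:
                print(f"Cached knowledge_num {cached_num} < required {knowledge_num}, recomputing...")
                return None
            # Cache is at least as large as required: keep only the first knowledge_num rows
            return cached[:knowledge_num, :]
        except Exception as e:  # corrupt or incompatible file falls back to recomputation
            print(f"Failed to load cached clusters: {e}, recomputing...")
            return None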
@@ -174,19 +215,6 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
                 'original_index': i
             })

-    # # Create a JSON-serializable version for saving
-    # json_serializable_sentences = []
-    # for sentence in processed_sentences:
-    #     json_sentence = sentence.copy()
-    #     # Convert embedding to list if it's a numpy array
-    #     if hasattr(json_sentence['embedding'], 'tolist'):
-    #         json_sentence['embedding'] = json_sentence['embedding'].tolist()
-    #     json_serializable_sentences.append(json_sentence)
-
-    # json.dump(json_serializable_sentences, open('processed_sentences.json', 'w', encoding='utf-8'))
-
-    # processed_sentences = json.load(open('processed_sentences.json', 'r', encoding='utf-8'))
-
     # Convert to numpy arrays for downstream processing
     embeddings_array = np.array(embeddings_list)
     token_lengths_array = np.array(token_lengths)
@@ -197,12 +225,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
     Logger(f" - Average token length: {np.mean(token_lengths_array):.2f}")
     Logger(f" - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}")

-    # 2. Clustering - optimized version
-    Logger("Starting optimized clustering process...")
-
-    # Clustering parameters
-    knowledge_num = args.knowledge_num
-    knowledge_length = args.knowledge_length
+    # Clustering parameter definitions
     min_tokens = int(0.85 * knowledge_length)
     max_tokens = int(0.95 * knowledge_length)

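With the default knowledge_length of 64 set in main() below, these bounds evaluate to 54 and 60 tokens. A quick check, not part of the diff:

    knowledge_length = 64
    min_tokens = int(0.85 * knowledge_length)   # int(54.4) -> 54
    max_tokens = int(0.95 * knowledge_length)   # int(60.8) -> 60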
@@ -265,7 +288,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
             all_remaining = high_importance + medium_importance + low_importance
             if all_remaining:
                 # Randomly sample candidate sentences (instead of computing all pairwise similarities)
-                sample_size = min(100, len(all_remaining))
+                sample_size = min(2000, len(all_remaining))
                 candidates = random.sample(all_remaining, sample_size)

                 # Simple selection by token length and importance
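The min() clamp matters because random.sample raises ValueError when the requested sample size exceeds the population, so raising the cap from 100 to 2000 only widens the candidate pool when enough sentences remain. A small illustration with made-up numbers:

    import random

    all_remaining = list(range(500))               # suppose only 500 candidates remain
    sample_size = min(2000, len(all_remaining))    # clamps to 500
    candidates = random.sample(all_remaining, sample_size)
    assert len(candidates) == 500
    # random.sample(all_remaining, 2000) would raise ValueError here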
@@ -403,6 +426,13 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
     Logger(f" - Cluster shape: {clustered_tensor.shape}")
     Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")

+    # Save clustering results to the cache file
+    try:
+        torch.save(clustered_tensor, args.cluster_cache_path)
+        Logger(f"Cluster results saved to {args.cluster_cache_path}")
+    except Exception as e:
+        Logger(f"Failed to save cluster results: {e}")
+
     # 3. Initialize the model's weight_down_embed
     if hasattr(model, 'extract_db') and hasattr(model.extract_db, 'weight_down_embed'):
         model.extract_db.weight_down_embed.data.copy_(clustered_tensor)
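The save path mirrors the loader: the computed tensor is written with torch.save and read back with torch.load on the next run. A minimal roundtrip check under assumptions not stated in the diff (a writable ./cache directory, and a dummy integer tensor standing in for the real clustered token tensor):

    import os
    import torch

    cache_path = "./cache/cluster_tokens.pt"
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)

    dummy = torch.zeros(65536, 64, dtype=torch.long)  # stand-in for clustered token ids
    torch.save(dummy, cache_path)

    reloaded = torch.load(cache_path)
    assert reloaded.shape == (65536, 64)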
@@ -651,10 +681,12 @@ def main():
     parser.add_argument("--profile", action="store_true", default=True, help="Enable performance profiling")
     parser.add_argument("--profile_interval", type=int, default=10, help="Profiling print interval (in steps)")
     parser.add_argument("--use_flash_attn", action="store_true", default=True, help="Enable FlashAttention")
-    parser.add_argument("--knowledge_num", type=int, default=64*64,help="Number of entries in the knowledge base")
+    parser.add_argument("--knowledge_num", type=int, default=65536,help="Number of entries in the knowledge base")
     parser.add_argument("--knowledge_length", type=int, default=64,help="Sentence length of each knowledge base entry")
     parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="Database initialization path")
     parser.add_argument("--fast_clustering", action="store_true", default=True, help="Use fast approximate clustering (suited to large datasets)")
+    parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens.pt", help="Path of the clustering results cache file")
+    parser.add_argument("--recompute_clusters", action="store_true", default=False, help="Force recomputation of clusters and ignore the cache file")
     args = parser.parse_args()

     #########################################################
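The two new flags work together: by default the script reuses ./cache/cluster_tokens.pt when it exists, while --recompute_clusters forces clustering to run again even if the cache is present. A minimal sketch of just these two arguments (the rest of the parser is omitted):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens.pt")
    parser.add_argument("--recompute_clusters", action="store_true", default=False)

    args = parser.parse_args([])                        # defaults: use the cache if it exists
    print(args.cluster_cache_path, args.recompute_clusters)

    args = parser.parse_args(["--recompute_clusters"])  # ignore any cached clusters
    print(args.recompute_clusters)                      # True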