diff --git a/train_pretrain_accelerate.py b/train_pretrain_accelerate.py index 54e1c05..ab1cc81 100644 --- a/train_pretrain_accelerate.py +++ b/train_pretrain_accelerate.py @@ -92,316 +92,346 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non from sentence_transformers import SentenceTransformer import os - Logger(f"Loading database initialization data from {database_init_path}") - - # 1. 加载JSON文件并转换为字典 - with open(database_init_path, 'r', encoding='utf-8') as f: - database_data = json.load(f) - - # 提取sentences列表 - sentences_data = database_data.get('sentences', []) - Logger(f"Loaded {len(sentences_data)} sentences from database") - - # 2. 按照importance_score进行排序(从高到低) - sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True) - Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})") - - # 3. 下载并初始化本地嵌入模型 - embedding_model_name = "sentence-transformers/all-mpnet-base-v2" # 轻量级但效果好的模型 - embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2" - embedding_cache_dir = "./models/sentence_transformers/cache" - os.makedirs(embedding_cache_dir, exist_ok=True) - - Logger(f"Loading embedding model: {embedding_model_name}") - try: - embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir) - Logger("Embedding model loaded successfully") - except Exception as e: - Logger(f"Failed to load embedding model: {e}") - Logger("Falling back to random embeddings") - embedding_model = None - - # 4. 对每个corrected_sentence进行嵌入和token长度计算 - Logger("Processing sentences for embeddings and token lengths...") - - # 提取所有句子 - sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences] - - # 批量计算token长度 - Logger("Computing token lengths...") - token_lengths = [] - for sentence in sentences: - tokens = tokenizer.encode(sentence, add_special_tokens=False) - token_lengths.append(len(tokens)) - - # 批量计算嵌入 - 大幅提升速度 - Logger("Computing embeddings in batches...") - embeddings_list = [] - batch_size = 256 # 可以根据GPU内存调整 - - if embedding_model is not None: - try: - for i in range(0, len(sentences), batch_size): - batch_sentences = sentences[i:i+batch_size] - batch_embeddings = embedding_model.encode( - batch_sentences, - convert_to_tensor=False, - show_progress_bar=True if i == 0 else False, - batch_size=batch_size - ) - embeddings_list.extend(batch_embeddings) - - if (i + batch_size) % (batch_size * 10) == 0: - Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences") - - Logger("Batch embedding computation completed") - except Exception as e: - Logger(f"Error in batch encoding: {e}") - Logger("Falling back to random embeddings") - embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences] - else: - # 使用随机嵌入 - embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences] - - # 创建处理后的句子列表 - processed_sentences = [] - for i, (sentence_data, embedding, token_length) in enumerate(zip(sorted_sentences, embeddings_list, token_lengths)): - processed_sentences.append({ - 'sentence': sentence_data.get('corrected_sentence', ''), - 'importance_score': sentence_data.get('importance_score', 0.0), - 'token_length': token_length, - 'embedding': embedding, # Convert numpy array to list - 'original_index': i - }) - - # # Create a JSON-serializable version for saving - # 
json_serializable_sentences = [] - # for sentence in processed_sentences: - # json_sentence = sentence.copy() - # # Convert embedding to list if it's a numpy array - # if hasattr(json_sentence['embedding'], 'tolist'): - # json_sentence['embedding'] = json_sentence['embedding'].tolist() - # json_serializable_sentences.append(json_sentence) - - # json.dump(json_serializable_sentences, open('processed_sentences.json', 'w', encoding='utf-8')) - - # processed_sentences = json.load(open('processed_sentences.json', 'r', encoding='utf-8')) - - # 转换为numpy数组以便后续处理 - embeddings_array = np.array(embeddings_list) - token_lengths_array = np.array(token_lengths) - - Logger(f"Embedding processing completed:") - Logger(f" - Total sentences: {len(processed_sentences)}") - Logger(f" - Embedding shape: {embeddings_array.shape}") - Logger(f" - Average token length: {np.mean(token_lengths_array):.2f}") - Logger(f" - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}") - - # 2. 聚类处理 - 优化版本 - Logger("Starting optimized clustering process...") - - # 聚类参数 + # 聚类参数(需要提前定义用于缓存检查) knowledge_num = args.knowledge_num knowledge_length = args.knowledge_length - min_tokens = int(0.85 * knowledge_length) - max_tokens = int(0.95 * knowledge_length) - # 优化1: 预计算所有嵌入的相似度矩阵(如果数据量不太大) - if len(processed_sentences) <= 10000: # 只有在数据量不太大时才预计算 - Logger("Pre-computing similarity matrix for faster clustering...") - embeddings_matrix = np.array([s['embedding'] for s in processed_sentences]) - similarity_matrix = cosine_similarity(embeddings_matrix) - Logger(f"Similarity matrix computed: {similarity_matrix.shape}") - else: - similarity_matrix = None - embeddings_matrix = np.array([s['embedding'] for s in processed_sentences]) + # 检查是否使用缓存(提前检查,避免不必要的数据处理) + cache_dir = os.path.dirname(args.cluster_cache_path) + if cache_dir: + os.makedirs(cache_dir, exist_ok=True) - clustered_rows = [] - remaining_indices = list(range(len(processed_sentences))) # 使用索引而不是对象 + clustered_tensor = None - Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens") - - # 选择聚类算法 - if args.fast_clustering and len(processed_sentences) > 5000: - Logger("Using ultra-fast approximate clustering algorithm...") - - # 超快速聚类:随机采样 + 批量处理 - import random - random.seed(42) # 确保可重现性 - - # 按重要性分层采样 - high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7] - medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7] - low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3] - - Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}") - - for cluster_idx in tqdm(range(knowledge_num)): - # 分层选择种子:优先选择高重要性句子 - if high_importance: - seed_pool = high_importance - elif medium_importance: - seed_pool = medium_importance + # 尝试加载缓存的聚类结果 + if not args.recompute_clusters and os.path.exists(args.cluster_cache_path): + try: + Logger(f"Loading cached cluster results from {args.cluster_cache_path}") + clustered_tensor = torch.load(args.cluster_cache_path) + + # 验证缓存文件的形状是否可用 + cached_knowledge_num, cached_knowledge_length = clustered_tensor.shape + + if cached_knowledge_length == knowledge_length: + if cached_knowledge_num >= knowledge_num: + # 缓存足够大,可以截取使用 + clustered_tensor = clustered_tensor[:knowledge_num, :] + Logger(f"Successfully loaded cached clusters with shape {clustered_tensor.shape}") + Logger(f"Truncated from cached shape 
({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})") + Logger("Skipping database initialization and clustering - using cached results") + else: + # 缓存太小,需要重新计算 + Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...") + clustered_tensor = None else: - seed_pool = low_importance if low_importance else list(range(len(processed_sentences))) - - if not seed_pool: - break - - # 随机选择种子(在同一重要性层级内) - seed_global_idx = random.choice(seed_pool) - seed_sentence = processed_sentences[seed_global_idx] - - # 从所有池中移除种子 - for pool in [high_importance, medium_importance, low_importance]: - if seed_global_idx in pool: - pool.remove(seed_global_idx) - - current_cluster_indices = [seed_global_idx] - current_tokens = seed_sentence['token_length'] - - if current_tokens < max_tokens: - # 快速选择:只从附近的句子中随机选择 - all_remaining = high_importance + medium_importance + low_importance - if all_remaining: - # 随机采样候选句子(而不是计算所有相似度) - sample_size = min(100, len(all_remaining)) - candidates = random.sample(all_remaining, sample_size) + # knowledge_length不匹配,需要重新计算 + Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...") + clustered_tensor = None + except Exception as e: + Logger(f"Failed to load cached clusters: {e}, recomputing...") + clustered_tensor = None + + # 只有在没有有效缓存时才进行数据库初始化和聚类计算 + if clustered_tensor is None: + Logger(f"Loading database initialization data from {database_init_path}") + + # 1. 加载JSON文件并转换为字典 + with open(database_init_path, 'r', encoding='utf-8') as f: + database_data = json.load(f) + + # 提取sentences列表 + sentences_data = database_data.get('sentences', []) + Logger(f"Loaded {len(sentences_data)} sentences from database") + + # 2. 按照importance_score进行排序(从高到低) + sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True) + Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})") + + # 3. 下载并初始化本地嵌入模型 + embedding_model_name = "sentence-transformers/all-mpnet-base-v2" # 轻量级但效果好的模型 + embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2" + embedding_cache_dir = "./models/sentence_transformers/cache" + os.makedirs(embedding_cache_dir, exist_ok=True) + + Logger(f"Loading embedding model: {embedding_model_name}") + try: + embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir) + Logger("Embedding model loaded successfully") + except Exception as e: + Logger(f"Failed to load embedding model: {e}") + Logger("Falling back to random embeddings") + embedding_model = None + + # 4. 
对每个corrected_sentence进行嵌入和token长度计算 + Logger("Processing sentences for embeddings and token lengths...") + + # 提取所有句子 + sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences] + + # 批量计算token长度 + Logger("Computing token lengths...") + token_lengths = [] + for sentence in sentences: + tokens = tokenizer.encode(sentence, add_special_tokens=False) + token_lengths.append(len(tokens)) + + # 批量计算嵌入 - 大幅提升速度 + Logger("Computing embeddings in batches...") + embeddings_list = [] + batch_size = 256 # 可以根据GPU内存调整 + + if embedding_model is not None: + try: + for i in range(0, len(sentences), batch_size): + batch_sentences = sentences[i:i+batch_size] + batch_embeddings = embedding_model.encode( + batch_sentences, + convert_to_tensor=False, + show_progress_bar=True if i == 0 else False, + batch_size=batch_size + ) + embeddings_list.extend(batch_embeddings) - # 简单按token长度和重要性选择 - for candidate_idx in candidates: - candidate = processed_sentences[candidate_idx] - candidate_tokens = candidate['token_length'] + if (i + batch_size) % (batch_size * 10) == 0: + Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences") - if current_tokens + candidate_tokens + 1 <= max_tokens: - current_cluster_indices.append(candidate_idx) - current_tokens += candidate_tokens + 1 - - # 从池中移除 - for pool in [high_importance, medium_importance, low_importance]: - if candidate_idx in pool: - pool.remove(candidate_idx) - break - - if current_tokens >= min_tokens: - break + Logger("Batch embedding computation completed") + except Exception as e: + Logger(f"Error in batch encoding: {e}") + Logger("Falling back to random embeddings") + embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences] + else: + # 使用随机嵌入 + embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences] + + # 创建处理后的句子列表 + processed_sentences = [] + for i, (sentence_data, embedding, token_length) in enumerate(zip(sorted_sentences, embeddings_list, token_lengths)): + processed_sentences.append({ + 'sentence': sentence_data.get('corrected_sentence', ''), + 'importance_score': sentence_data.get('importance_score', 0.0), + 'token_length': token_length, + 'embedding': embedding, # Convert numpy array to list + 'original_index': i + }) + + # 转换为numpy数组以便后续处理 + embeddings_array = np.array(embeddings_list) + token_lengths_array = np.array(token_lengths) + + Logger(f"Embedding processing completed:") + Logger(f" - Total sentences: {len(processed_sentences)}") + Logger(f" - Embedding shape: {embeddings_array.shape}") + Logger(f" - Average token length: {np.mean(token_lengths_array):.2f}") + Logger(f" - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}") + + # 聚类参数定义 + min_tokens = int(0.85 * knowledge_length) + max_tokens = int(0.95 * knowledge_length) + + # 优化1: 预计算所有嵌入的相似度矩阵(如果数据量不太大) + if len(processed_sentences) <= 10000: # 只有在数据量不太大时才预计算 + Logger("Pre-computing similarity matrix for faster clustering...") + embeddings_matrix = np.array([s['embedding'] for s in processed_sentences]) + similarity_matrix = cosine_similarity(embeddings_matrix) + Logger(f"Similarity matrix computed: {similarity_matrix.shape}") + else: + similarity_matrix = None + embeddings_matrix = np.array([s['embedding'] for s in processed_sentences]) + + clustered_rows = [] + remaining_indices = list(range(len(processed_sentences))) # 使用索引而不是对象 + + Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens") + + # 选择聚类算法 + if 
args.fast_clustering and len(processed_sentences) > 5000: + Logger("Using ultra-fast approximate clustering algorithm...") - # 生成聚类文本 - cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices] - cluster_text = '\n '.join(cluster_sentences) + # 超快速聚类:随机采样 + 批量处理 + import random + random.seed(42) # 确保可重现性 - # 转换为tokens - cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False) - if len(cluster_tokens) > knowledge_length: - cluster_tokens = cluster_tokens[:knowledge_length] - else: - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 - cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens))) + # 按重要性分层采样 + high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7] + medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7] + low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3] - clustered_rows.append(cluster_tokens) + Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}") - if (cluster_idx + 1) % 1000 == 0: - total_remaining = len(high_importance) + len(medium_importance) + len(low_importance) - Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining") - - else: - # 原始优化算法(适用于中等规模数据集) - # 优化2: 批量处理和更高效的数据结构 - for cluster_idx in tqdm(range(knowledge_num)): - if not remaining_indices: - Logger(f"No more sentences available. Created {cluster_idx} clusters.") - break - - # 2.1 选择importance_score最高的句子作为种子 - remaining_sentences_subset = [processed_sentences[i] for i in remaining_indices] - seed_idx_in_subset = max(range(len(remaining_sentences_subset)), - key=lambda i: remaining_sentences_subset[i]['importance_score']) - seed_global_idx = remaining_indices[seed_idx_in_subset] - seed_sentence = processed_sentences[seed_global_idx] - - # 从剩余索引中移除种子 - remaining_indices.remove(seed_global_idx) - - # 当前聚类 - current_cluster_indices = [seed_global_idx] - current_tokens = seed_sentence['token_length'] - - if current_tokens >= max_tokens: - # 如果种子句子已经超过最大token数,直接作为一个聚类 - cluster_text = seed_sentence['sentence'] - else: - # 2.2 优化的相似度计算和选择 - if remaining_indices: - if similarity_matrix is not None: - # 使用预计算的相似度矩阵 - similarities = similarity_matrix[seed_global_idx][remaining_indices] - else: - # 动态计算相似度(批量) - seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx+1] - remaining_embeddings = embeddings_matrix[remaining_indices] - similarities = cosine_similarity(seed_embedding, remaining_embeddings)[0] - - # 创建(相似度, 原始索引, 在remaining_indices中的位置)的元组列表 - similarity_tuples = [(similarities[i], remaining_indices[i], i) - for i in range(len(remaining_indices))] - - # 按相似度排序(降序) - similarity_tuples.sort(key=lambda x: x[0], reverse=True) - - # 优化3: 贪心选择,但限制搜索范围以提高速度 - max_candidates = min(len(similarity_tuples), 500) # 只考虑前500个最相似的句子 - - selected_indices_in_remaining = [] - for sim_score, global_idx, pos_in_remaining in similarity_tuples[:max_candidates]: - candidate = processed_sentences[global_idx] - candidate_tokens = candidate['token_length'] - - if current_tokens + candidate_tokens + 1 <= max_tokens: # +1 for newline - current_cluster_indices.append(global_idx) - selected_indices_in_remaining.append(pos_in_remaining) - current_tokens += candidate_tokens + 1 - - if current_tokens >= min_tokens: - break - - # 批量移除选中的句子(从后往前移除以避免索引问题) - for pos in 
sorted(selected_indices_in_remaining, reverse=True): - remaining_indices.pop(pos) + for cluster_idx in tqdm(range(knowledge_num)): + # 分层选择种子:优先选择高重要性句子 + if high_importance: + seed_pool = high_importance + elif medium_importance: + seed_pool = medium_importance + else: + seed_pool = low_importance if low_importance else list(range(len(processed_sentences))) - # 拼接句子 + if not seed_pool: + break + + # 随机选择种子(在同一重要性层级内) + seed_global_idx = random.choice(seed_pool) + seed_sentence = processed_sentences[seed_global_idx] + + # 从所有池中移除种子 + for pool in [high_importance, medium_importance, low_importance]: + if seed_global_idx in pool: + pool.remove(seed_global_idx) + + current_cluster_indices = [seed_global_idx] + current_tokens = seed_sentence['token_length'] + + if current_tokens < max_tokens: + # 快速选择:只从附近的句子中随机选择 + all_remaining = high_importance + medium_importance + low_importance + if all_remaining: + # 随机采样候选句子(而不是计算所有相似度) + sample_size = min(2000, len(all_remaining)) + candidates = random.sample(all_remaining, sample_size) + + # 简单按token长度和重要性选择 + for candidate_idx in candidates: + candidate = processed_sentences[candidate_idx] + candidate_tokens = candidate['token_length'] + + if current_tokens + candidate_tokens + 1 <= max_tokens: + current_cluster_indices.append(candidate_idx) + current_tokens += candidate_tokens + 1 + + # 从池中移除 + for pool in [high_importance, medium_importance, low_importance]: + if candidate_idx in pool: + pool.remove(candidate_idx) + break + + if current_tokens >= min_tokens: + break + + # 生成聚类文本 cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices] - cluster_text = '\n'.join(cluster_sentences) - - # 将聚类文本转换为token - cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False) - - # 截断或填充到knowledge_length - if len(cluster_tokens) > knowledge_length: - cluster_tokens = cluster_tokens[:knowledge_length] - else: - # 用pad_token_id填充 - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 - cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens))) - - clustered_rows.append(cluster_tokens) - - # 优化4: 减少日志频率 - if (cluster_idx + 1) % 500 == 0: - Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining") - - # 如果聚类数量不足,用随机token填充 - while len(clustered_rows) < knowledge_num: - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 - random_tokens = [pad_token_id] * knowledge_length - clustered_rows.append(random_tokens) - - # 转换为tensor - clustered_tensor = torch.tensor(clustered_rows, dtype=torch.long) - - Logger(f"Clustering completed:") - Logger(f" - Created {len(clustered_rows)} clusters") - Logger(f" - Cluster shape: {clustered_tensor.shape}") - Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})") + cluster_text = '\n '.join(cluster_sentences) + + # 转换为tokens + cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False) + if len(cluster_tokens) > knowledge_length: + cluster_tokens = cluster_tokens[:knowledge_length] + else: + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 + cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens))) + + clustered_rows.append(cluster_tokens) + + if (cluster_idx + 1) % 1000 == 0: + total_remaining = len(high_importance) + len(medium_importance) + len(low_importance) + Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining") + 
+ else: + # 原始优化算法(适用于中等规模数据集) + # 优化2: 批量处理和更高效的数据结构 + for cluster_idx in tqdm(range(knowledge_num)): + if not remaining_indices: + Logger(f"No more sentences available. Created {cluster_idx} clusters.") + break + + # 2.1 选择importance_score最高的句子作为种子 + remaining_sentences_subset = [processed_sentences[i] for i in remaining_indices] + seed_idx_in_subset = max(range(len(remaining_sentences_subset)), + key=lambda i: remaining_sentences_subset[i]['importance_score']) + seed_global_idx = remaining_indices[seed_idx_in_subset] + seed_sentence = processed_sentences[seed_global_idx] + + # 从剩余索引中移除种子 + remaining_indices.remove(seed_global_idx) + + # 当前聚类 + current_cluster_indices = [seed_global_idx] + current_tokens = seed_sentence['token_length'] + + if current_tokens >= max_tokens: + # 如果种子句子已经超过最大token数,直接作为一个聚类 + cluster_text = seed_sentence['sentence'] + else: + # 2.2 优化的相似度计算和选择 + if remaining_indices: + if similarity_matrix is not None: + # 使用预计算的相似度矩阵 + similarities = similarity_matrix[seed_global_idx][remaining_indices] + else: + # 动态计算相似度(批量) + seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx+1] + remaining_embeddings = embeddings_matrix[remaining_indices] + similarities = cosine_similarity(seed_embedding, remaining_embeddings)[0] + + # 创建(相似度, 原始索引, 在remaining_indices中的位置)的元组列表 + similarity_tuples = [(similarities[i], remaining_indices[i], i) + for i in range(len(remaining_indices))] + + # 按相似度排序(降序) + similarity_tuples.sort(key=lambda x: x[0], reverse=True) + + # 优化3: 贪心选择,但限制搜索范围以提高速度 + max_candidates = min(len(similarity_tuples), 500) # 只考虑前500个最相似的句子 + + selected_indices_in_remaining = [] + for sim_score, global_idx, pos_in_remaining in similarity_tuples[:max_candidates]: + candidate = processed_sentences[global_idx] + candidate_tokens = candidate['token_length'] + + if current_tokens + candidate_tokens + 1 <= max_tokens: # +1 for newline + current_cluster_indices.append(global_idx) + selected_indices_in_remaining.append(pos_in_remaining) + current_tokens += candidate_tokens + 1 + + if current_tokens >= min_tokens: + break + + # 批量移除选中的句子(从后往前移除以避免索引问题) + for pos in sorted(selected_indices_in_remaining, reverse=True): + remaining_indices.pop(pos) + + # 拼接句子 + cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices] + cluster_text = '\n'.join(cluster_sentences) + + # 将聚类文本转换为token + cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False) + + # 截断或填充到knowledge_length + if len(cluster_tokens) > knowledge_length: + cluster_tokens = cluster_tokens[:knowledge_length] + else: + # 用pad_token_id填充 + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 + cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens))) + + clustered_rows.append(cluster_tokens) + + # 优化4: 减少日志频率 + if (cluster_idx + 1) % 500 == 0: + Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining") + + # 如果聚类数量不足,用随机token填充 + while len(clustered_rows) < knowledge_num: + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 + random_tokens = [pad_token_id] * knowledge_length + clustered_rows.append(random_tokens) + + # 转换为tensor + clustered_tensor = torch.tensor(clustered_rows, dtype=torch.long) + + Logger(f"Clustering completed:") + Logger(f" - Created {len(clustered_rows)} clusters") + Logger(f" - Cluster shape: {clustered_tensor.shape}") + Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})") + + # 保存聚类结果到缓存文件 + 
try: + torch.save(clustered_tensor, args.cluster_cache_path) + Logger(f"Cluster results saved to {args.cluster_cache_path}") + except Exception as e: + Logger(f"Failed to save cluster results: {e}") # 3. 初始化模型的weight_down_embed if hasattr(model, 'extract_db') and hasattr(model.extract_db, 'weight_down_embed'): @@ -651,10 +681,12 @@ def main(): parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析") parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)") parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention") - parser.add_argument("--knowledge_num", type=int, default=64*64,help="知识库的数据数目") + parser.add_argument("--knowledge_num", type=int, default=65536,help="知识库的数据数目") parser.add_argument("--knowledge_length", type=int, default=64,help="知识库的句子长度") parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径") parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)") + parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens.pt", help="聚类结果缓存文件路径") + parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件") args = parser.parse_args() #########################################################
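The core change in this patch is the cluster cache: before loading the embedding model or running either clustering loop, the script now tries to reuse a previously saved token tensor and only falls back to the full pipeline when the cache is missing, malformed, or too small. Below is a minimal, self-contained sketch of that load/validate/save flow; the helper names (`try_load_cluster_cache`, `save_cluster_cache`) are illustrative and not part of the patch, but the shape checks mirror the ones added inside `init_model`.

```python
import os
import torch

def try_load_cluster_cache(cache_path, knowledge_num, knowledge_length, recompute=False):
    """Return a (knowledge_num, knowledge_length) LongTensor from cache, or None to recompute."""
    if recompute or not os.path.exists(cache_path):
        return None
    try:
        cached = torch.load(cache_path)            # expected shape: (cached_num, cached_len)
        cached_num, cached_len = cached.shape
        if cached_len != knowledge_length:         # row length must match exactly
            return None
        if cached_num < knowledge_num:             # cache too small -> recompute
            return None
        return cached[:knowledge_num, :]           # larger caches are truncated row-wise
    except Exception:
        return None

def save_cluster_cache(clustered_tensor, cache_path):
    """Best-effort save; a failure here should not abort training."""
    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    try:
        torch.save(clustered_tensor, cache_path)
    except Exception:
        pass
```

This pairs with the two new CLI flags: a run with `--cluster_cache_path ./cache/cluster_tokens.pt` reuses the cached tensor when its row length equals `knowledge_length` and it has at least `knowledge_num` rows (extra rows are truncated), while `--recompute_clusters` bypasses the cache and rebuilds it.

Both clustering paths that remain in the patch share the same token-budget packing idea: pick a seed sentence, then greedily add candidates until the cluster reaches the `min_tokens`/`max_tokens` window (85%–95% of `knowledge_length`), counting one extra token per joining newline, before encoding and padding/truncating the joined text to exactly `knowledge_length`. The following is a simplified sketch of the fast path's packing step only, assuming the seed has already been removed from the candidate pool; it is not the exact loop from the patch.

```python
import random
from typing import List, Optional

def pack_cluster(seed_idx: int, pool: List[int], token_lengths: List[int],
                 min_tokens: int, max_tokens: int, sample_size: int = 2000,
                 rng: Optional[random.Random] = None) -> List[int]:
    """Greedily grow a cluster around seed_idx. `pool` holds candidate sentence
    indices (seed already removed); chosen indices are removed from it in place."""
    rng = rng or random.Random(42)
    cluster = [seed_idx]
    total = token_lengths[seed_idx]
    if total >= max_tokens or not pool:
        return cluster
    # Sample a bounded number of candidates instead of scoring every remaining sentence.
    for idx in rng.sample(pool, min(sample_size, len(pool))):
        if total + token_lengths[idx] + 1 <= max_tokens:   # +1 for the joining newline
            cluster.append(idx)
            pool.remove(idx)
            total += token_lengths[idx] + 1
            if total >= min_tokens:
                break
    return cluster
```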