Compare commits
3 Commits
67c632d010 ... 64e92473c3

Author | SHA1 | Date
---|---|---
 | 64e92473c3 | 
 | 6932e5fa8e | 
 | c5d0a3aba3 | 
2	.gitignore	vendored
@@ -5,3 +5,5 @@ wandb/
**/*.log
models/sentence_transformers/
models/sentence_transformers_cache/
**/*.pyc
qwen2-1.7B/
@@ -4,10 +4,35 @@

1. **Sentence extraction**: extract triples from the TREx dataset and convert them into natural-language sentences
2. **LLM processing**: use the ollama qwen3:4b model to correct the sentences and score their importance

## 🆕 Anti-hang mechanisms

To address the hangs that can occur during LLM processing, the following features were added:

### Timeout and retry
- **Timeout**: each LLM request times out after 60 seconds
- **Retries**: at most 2 retries after a failure, with exponential backoff
- **Concurrency control**: concurrency lowered to 4 requests to reduce pressure on the server (a sketch of this pattern follows below)
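
The implementation is not shown in this diff, but a minimal sketch of the timeout/retry/concurrency pattern described above might look like the following. Here `call_llm` is a hypothetical placeholder for the actual ollama/agno request; the 60 s timeout, 2 retries with exponential backoff, and concurrency limit of 4 mirror the values listed above, and the default score used in the fallback is an assumed placeholder.

```python
import asyncio
import random

# Hypothetical placeholder for the real ollama/agno request.
async def call_llm(sentence: str) -> dict:
    raise NotImplementedError

SEMAPHORE = asyncio.Semaphore(4)   # at most 4 concurrent LLM requests
REQUEST_TIMEOUT = 60               # seconds per request
MAX_RETRIES = 2                    # retries after the first failure

async def process_sentence(sentence: str) -> dict:
    """Call the LLM with a timeout, exponential backoff, and a safe fallback."""
    async with SEMAPHORE:
        for attempt in range(MAX_RETRIES + 1):
            try:
                return await asyncio.wait_for(call_llm(sentence), timeout=REQUEST_TIMEOUT)
            except Exception:
                if attempt == MAX_RETRIES:
                    break
                # Exponential backoff with a little jitter: ~1 s, ~2 s, ...
                await asyncio.sleep(2 ** attempt + random.random())
    # Fallback described above: keep the original sentence with a default score.
    return {"corrected_sentence": sentence, "importance_score": 0.0}
```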

### Heartbeat monitoring
- **Real-time monitoring**: the LLM response status is checked every 30 seconds
- **Warnings**: a warning is raised after 30 seconds without a successful response
- **Service checks**: the ollama service status is checked automatically
- **Detailed statistics**: success rate, timeout rate, and other statistics are reported in real time (a monitoring sketch follows below)
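
As a rough illustration only, a background task implementing the heartbeat checks described above could be structured like this; the 30 s interval and 90 s severe threshold come from the warning messages documented below, while the shared `last_success_time` variable and the logger name are assumptions.

```python
import asyncio
import logging
import time

logger = logging.getLogger("trex_processor")

# Assumed shared state, updated by the workers after every successful LLM call.
last_success_time = time.time()

async def heartbeat_monitor(interval: int = 30, severe_after: int = 90) -> None:
    """Periodically warn when no LLM request has succeeded for too long."""
    while True:
        await asyncio.sleep(interval)
        silence = time.time() - last_success_time
        if silence > severe_after:
            logger.error("❌ Severe warning: no successful response for %.0f s", silence)
        elif silence >= interval:
            logger.warning("⚠️ Heartbeat: no successful response for %.0f s", silence)
```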

### Logging
- **Detailed logs**: every operation is recorded under the `logs/` directory
- **Dual output**: messages go to both the log file and the console (see the configuration sketch below)
- **Timestamps**: the log file name includes the startup timestamp
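
A minimal sketch of how the dual file/console logging could be configured; the `logs/trex_processor_{timestamp}.log` name follows the pattern listed under the log files section below, while the format string and logger name are assumptions.

```python
import logging
import os
from datetime import datetime

def setup_logging(log_dir: str = "logs") -> logging.Logger:
    """Log to a timestamped file and to the console at the same time."""
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_path = os.path.join(log_dir, f"trex_processor_{timestamp}.log")

    logger = logging.getLogger("trex_processor")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
    for handler in (logging.FileHandler(log_path, encoding="utf-8"), logging.StreamHandler()):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
```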

### Improved error handling
- **Recovery**: when LLM processing fails, the original sentence and a default score are used
- **Status checks**: the ollama service status is checked before processing starts
- **Pause between batches**: a 5-second rest between batches avoids putting too much pressure on the server

## Installing dependencies

```bash
pip install agno asyncio pydantic
pip install agno asyncio pydantic requests
```

Make sure ollama is installed and running, and that the qwen3:4b model has been downloaded:

@@ -50,24 +75,52 @@ python trex_to_sentences_simple.py --step llm --sentences_json my_sentences.json

## Output files

**Note: all output files are saved automatically in the `./output/` directory**
**Note: all output files are saved automatically in their respective directories**

### Step 1 output
### Sentence extraction output
- `output/extracted_sentences.json`: the raw extracted sentences, including metadata

### Step 2 output
### LLM processing output
- `output/{output_file}.txt`: text file with the corrected sentences
- `output/{output_file}.json`: the full processing results (original sentence, corrected sentence, score)
- `output/{output_file}_sorted_by_importance.txt`: sentences sorted by importance score

### Checkpoint files
- `output/{output_file}_checkpoint_{count}.json`: checkpoint saved automatically every 2000 sentences
- `output/{output_file}_checkpoint_{count}.json`: checkpoint saved automatically every 1000 sentences

### Log files
- `logs/trex_processor_{timestamp}.log`: detailed processing log

## 🆕 Troubleshooting

### If the process appears to hang:

1. **Check the log files**: look at the latest log under the `logs/` directory
2. **Watch the heartbeat monitor**: pay attention to heartbeat warnings in the console
3. **Check the ollama service**:
   ```bash
   ps aux | grep ollama
   curl http://localhost:11434/api/tags
   ```
4. **Restart the ollama service** (if needed):
   ```bash
   pkill ollama
   ollama serve &
   ```

### Common warning messages:

- `⚠️ 心跳检测` (heartbeat check): no successful response for 30 seconds (normally recovers on its own)
- `❌ 严重警告` (severe warning): no successful response for 90 seconds (the service may need to be checked)
- `💀 Ollama服务异常` (ollama service error): the ollama service may have stopped
- `💀 致命错误` (fatal error): several consecutive warnings (restarting the program is recommended)

## Checkpoint recovery

- Step 2 automatically detects existing checkpoint files (in the `output/` directory)
- Only sentences that have not been processed yet are handled, avoiding duplicate work
- If every sentence has already been processed, the final output files are generated directly
- After an interruption, rerunning the script resumes from the latest checkpoint automatically (a resume sketch follows below)
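
For illustration, resuming from the newest checkpoint might be implemented roughly as follows; the `_checkpoint_{count}.json` naming matches the output files listed above and the `trex_sentences_enhanced` prefix appears in `preprocessing/merge_output_json.py`, while the function name and return format are assumptions.

```python
import glob
import json
import os

def load_latest_checkpoint(output_dir: str = "output",
                           prefix: str = "trex_sentences_enhanced") -> list:
    """Return the already-processed entries from the newest checkpoint, if any."""
    pattern = os.path.join(output_dir, f"{prefix}_checkpoint_*.json")
    checkpoints = glob.glob(pattern)
    if not checkpoints:
        return []
    # The checkpoint with the largest sentence count is the most recent one.
    latest = max(checkpoints, key=lambda p: int(p.rsplit("_", 1)[1].split(".")[0]))
    with open(latest, "r", encoding="utf-8") as f:
        return json.load(f)

# Sentences already present in the returned list can be skipped on the next run.
```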

## Example workflow

@@ -84,10 +137,11 @@ python trex_to_sentences_simple.py --step llm

## Performance characteristics

- **Concurrent processing**: up to 54 concurrent LLM requests
- **Checkpoint saving**: automatic save every 2000 sentences, with resume support
- **Progress display**: detailed processing progress and time estimates
- **Error handling**: when an LLM request fails, the original sentence and a default score are used
- **Conservative concurrency**: at most 4 concurrent LLM requests (lower risk of hangs)
- **Checkpoint saving**: automatic save every 1000 sentences, with resume support
- **Smart monitoring**: detailed processing progress and time estimates
- **Robust error handling**: when an LLM request fails, the original sentence and a default score are used
- **Service monitoring**: automatic detection of the ollama service status

## Notes

@@ -95,3 +149,6 @@ python trex_to_sentences_simple.py --step llm

2. Checkpoint files take up extra disk space (each one contains all data processed so far)
3. LLM processing speed depends on model performance and network conditions
4. It is recommended to test on a small batch of data first with the `--max_files` parameter
5. **New**: if the process hangs, check the log files and the heartbeat monitoring output
6. **New**: the program automatically detects and reports the ollama service status
7. **New**: every processing step is logged in detail to make diagnosis easier
225	preprocessing/merge_output_json.py	Normal file

@@ -0,0 +1,225 @@
#!/usr/bin/env python3
"""
JSON merge script.
Reads several JSON files and merges them into a single JSON file.
"""

import json
import os
from typing import Dict, List, Any, Union

# JSON files to merge
JSON_FILES_TO_MERGE = [
    "output/trex_sentences_enhanced_checkpoint_360000.json"
]
for i in range(1, 1010):
    JSON_FILES_TO_MERGE.append(f"output/trex_sentences_enhanced_batch_{i}.json")

def load_json_file(file_path: str) -> Union[Dict, List, None]:
    """Load a JSON file."""
    if not os.path.exists(file_path):
        print(f"Warning: file {file_path} does not exist")
        return None

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"Loaded: {file_path}")
        return data
    except json.JSONDecodeError as e:
        print(f"Error: could not parse JSON file {file_path} - {e}")
        return None
    except Exception as e:
        print(f"Error: failed to read file {file_path} - {e}")
        return None

def merge_json_data(data1: Union[Dict, List], data2: Union[Dict, List]) -> Union[Dict, List]:
    """Merge two JSON data structures."""

    # If both are lists, simply concatenate them
    if isinstance(data1, list) and isinstance(data2, list):
        print(f"Merging two lists: {len(data1)} + {len(data2)} = {len(data1) + len(data2)} items")
        return data1 + data2

    # If both are dicts
    elif isinstance(data1, dict) and isinstance(data2, dict):
        print("Merging two dict structures")
        merged = data1.copy()

        # Special case: if both have a 'sentences' list, merge the sentences
        if 'sentences' in data1 and 'sentences' in data2:
            if isinstance(data1['sentences'], list) and isinstance(data2['sentences'], list):
                print(f"Merging 'sentences' fields: {len(data1['sentences'])} + {len(data2['sentences'])} = {len(data1['sentences']) + len(data2['sentences'])} items")
                merged['sentences'] = data1['sentences'] + data2['sentences']

                # Update metadata if it exists
                if 'metadata' in merged:
                    if isinstance(merged['metadata'], dict):
                        merged['metadata']['total_sentences'] = len(merged['sentences'])
                        merged['metadata']['merged_from'] = [os.path.basename(f) for f in JSON_FILES_TO_MERGE if os.path.exists(f)]

                # Merge the remaining fields
                for key, value in data2.items():
                    if key != 'sentences' and key not in merged:
                        merged[key] = value

                return merged

        # Plain dict merge
        for key, value in data2.items():
            if key in merged:
                # Duplicate key and both values are lists: concatenate them
                if isinstance(merged[key], list) and isinstance(value, list):
                    merged[key] = merged[key] + value
                # Duplicate key and both values are dicts: merge recursively
                elif isinstance(merged[key], dict) and isinstance(value, dict):
                    merged[key] = merge_json_data(merged[key], value)
                else:
                    # Otherwise keep the value from the second file
                    merged[key] = value
                    print(f"Field '{key}' was overwritten")
            else:
                merged[key] = value

        return merged

    # Mismatched types: wrap both in a new structure
    else:
        print("Data types do not match, wrapping both in a new structure")
        return {
            "data_from_save.json": data1,
            "data_from_save2.json": data2,
            "merged_at": "test.py"
        }

def save_merged_json(data: Union[Dict, List], output_path: str):
    """Save the merged JSON data."""
    try:
        # Make sure the output directory exists
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        print(f"Merged result saved to: {output_path}")

        # Print some statistics
        if isinstance(data, dict):
            if 'sentences' in data and isinstance(data['sentences'], list):
                print(f"Total sentences: {len(data['sentences'])}")
            print(f"Total fields: {len(data)}")
        elif isinstance(data, list):
            print(f"Total list items: {len(data)}")

    except Exception as e:
        print(f"Error: failed to save file - {e}")

def remove_duplicates_from_sentences(data: Union[Dict, List]) -> Union[Dict, List]:
    """Remove duplicate sentences from the merged result (based on sentence content)."""
    if isinstance(data, dict) and 'sentences' in data:
        if isinstance(data['sentences'], list):
            original_count = len(data['sentences'])
            seen_sentences = set()
            unique_sentences = []

            for item in data['sentences']:
                if isinstance(item, dict):
                    # For dicts, use the sentence / corrected_sentence / original_sentence field as the key
                    sentence_key = item.get('sentence') or item.get('corrected_sentence') or item.get('original_sentence')
                elif isinstance(item, str):
                    sentence_key = item
                else:
                    sentence_key = str(item)

                if sentence_key and sentence_key not in seen_sentences:
                    seen_sentences.add(sentence_key)
                    unique_sentences.append(item)

            data['sentences'] = unique_sentences

            # Update metadata
            if 'metadata' in data and isinstance(data['metadata'], dict):
                data['metadata']['total_sentences'] = len(unique_sentences)
                data['metadata']['duplicates_removed'] = original_count - len(unique_sentences)

            print(f"Deduplication finished: {original_count} -> {len(unique_sentences)} ({original_count - len(unique_sentences)} duplicates removed)")

    return data

def merge_multiple_json_data(data_list: List[Union[Dict, List]]) -> Union[Dict, List]:
    """Merge multiple JSON data structures."""
    if not data_list:
        return {}

    if len(data_list) == 1:
        return data_list[0]

    print(f"About to merge {len(data_list)} JSON data structures")

    # Start from the first structure and fold in the rest
    merged_data = data_list[0]

    for i, data in enumerate(data_list[1:], 1):
        print(f"Merging data structure {i+1}...")
        merged_data = merge_json_data(merged_data, data)

    return merged_data

def main():
    """Entry point."""
    print("=== JSON merge script ===")

    # Output path
    output_path = "output/merged.json"

    print("The following files will be merged:")
    for i, file_path in enumerate(JSON_FILES_TO_MERGE, 1):
        print(f"  {i}. {file_path}")
    print(f"Output file: {output_path}")
    print()

    # Load all files
    loaded_data = []
    successfully_loaded = []

    for file_path in JSON_FILES_TO_MERGE:
        data = load_json_file(file_path)
        if data is not None:
            loaded_data.append(data)
            successfully_loaded.append(file_path)

    # Make sure at least one file loaded
    if not loaded_data:
        print("Error: no file could be loaded, exiting")
        return

    print(f"Successfully loaded {len(loaded_data)} files:")
    for file_path in successfully_loaded:
        print(f"  ✓ {file_path}")

    if len(loaded_data) < len(JSON_FILES_TO_MERGE):
        failed_count = len(JSON_FILES_TO_MERGE) - len(loaded_data)
        print(f"Warning: {failed_count} files failed to load")
    print()

    # Merge all data
    if len(loaded_data) == 1:
        print("Only one file available, using it directly...")
        merged_data = loaded_data[0]
    else:
        print("Merging all files...")
        merged_data = merge_multiple_json_data(loaded_data)

    # Deduplicate
    print("\nChecking for and removing duplicates...")
    merged_data = remove_duplicates_from_sentences(merged_data)

    # Save the merged result
    print("\nSaving the merged result...")
    save_merged_json(merged_data, output_path)

    print("\n=== Merge finished ===")
    print(f"Merged data from {len(successfully_loaded)} files")

if __name__ == "__main__":
    main()
File diff suppressed because it is too large
@@ -92,316 +92,346 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None,
        from sentence_transformers import SentenceTransformer
        import os

        Logger(f"Loading database initialization data from {database_init_path}")

        # 1. Load the JSON file into a dict
        with open(database_init_path, 'r', encoding='utf-8') as f:
            database_data = json.load(f)

        # Extract the sentences list
        sentences_data = database_data.get('sentences', [])
        Logger(f"Loaded {len(sentences_data)} sentences from database")

        # 2. Sort by importance_score (descending)
        sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
        Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")

        # 3. Download and initialize the local embedding model
        embedding_model_name = "sentence-transformers/all-mpnet-base-v2"  # lightweight but effective model
        embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2"
        embedding_cache_dir = "./models/sentence_transformers/cache"
        os.makedirs(embedding_cache_dir, exist_ok=True)

        Logger(f"Loading embedding model: {embedding_model_name}")
        try:
            embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir)
            Logger("Embedding model loaded successfully")
        except Exception as e:
            Logger(f"Failed to load embedding model: {e}")
            Logger("Falling back to random embeddings")
            embedding_model = None

        # 4. Compute the embedding and token length of each corrected_sentence
        Logger("Processing sentences for embeddings and token lengths...")

        # Collect all sentences
        sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences]

        # Compute token lengths
        Logger("Computing token lengths...")
        token_lengths = []
        for sentence in sentences:
            tokens = tokenizer.encode(sentence, add_special_tokens=False)
            token_lengths.append(len(tokens))

        # Compute embeddings in batches - much faster
        Logger("Computing embeddings in batches...")
        embeddings_list = []
        batch_size = 256  # can be adjusted to the available GPU memory

        if embedding_model is not None:
            try:
                for i in range(0, len(sentences), batch_size):
                    batch_sentences = sentences[i:i+batch_size]
                    batch_embeddings = embedding_model.encode(
                        batch_sentences,
                        convert_to_tensor=False,
                        show_progress_bar=True if i == 0 else False,
                        batch_size=batch_size
                    )
                    embeddings_list.extend(batch_embeddings)

                    if (i + batch_size) % (batch_size * 10) == 0:
                        Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences")

                Logger("Batch embedding computation completed")
            except Exception as e:
                Logger(f"Error in batch encoding: {e}")
                Logger("Falling back to random embeddings")
                embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
        else:
            # Use random embeddings
            embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]

        # Build the list of processed sentences
        processed_sentences = []
        for i, (sentence_data, embedding, token_length) in enumerate(zip(sorted_sentences, embeddings_list, token_lengths)):
            processed_sentences.append({
                'sentence': sentence_data.get('corrected_sentence', ''),
                'importance_score': sentence_data.get('importance_score', 0.0),
                'token_length': token_length,
                'embedding': embedding,  # Convert numpy array to list
                'original_index': i
            })

        # # Create a JSON-serializable version for saving
        # json_serializable_sentences = []
        # for sentence in processed_sentences:
        #     json_sentence = sentence.copy()
        #     # Convert embedding to list if it's a numpy array
        #     if hasattr(json_sentence['embedding'], 'tolist'):
        #         json_sentence['embedding'] = json_sentence['embedding'].tolist()
        #     json_serializable_sentences.append(json_sentence)

        # json.dump(json_serializable_sentences, open('processed_sentences.json', 'w', encoding='utf-8'))

        # processed_sentences = json.load(open('processed_sentences.json', 'r', encoding='utf-8'))

        # Convert to numpy arrays for later processing
        embeddings_array = np.array(embeddings_list)
        token_lengths_array = np.array(token_lengths)

        Logger(f"Embedding processing completed:")
        Logger(f"  - Total sentences: {len(processed_sentences)}")
        Logger(f"  - Embedding shape: {embeddings_array.shape}")
        Logger(f"  - Average token length: {np.mean(token_lengths_array):.2f}")
        Logger(f"  - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}")

        # 2. Clustering - optimized version
        Logger("Starting optimized clustering process...")
# 聚类参数
|
||||
# 聚类参数(需要提前定义用于缓存检查)
|
||||
knowledge_num = args.knowledge_num
|
||||
knowledge_length = args.knowledge_length
|
||||
min_tokens = int(0.85 * knowledge_length)
|
||||
max_tokens = int(0.95 * knowledge_length)
|
||||
|
||||
# 优化1: 预计算所有嵌入的相似度矩阵(如果数据量不太大)
|
||||
if len(processed_sentences) <= 10000: # 只有在数据量不太大时才预计算
|
||||
Logger("Pre-computing similarity matrix for faster clustering...")
|
||||
embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
|
||||
similarity_matrix = cosine_similarity(embeddings_matrix)
|
||||
Logger(f"Similarity matrix computed: {similarity_matrix.shape}")
|
||||
else:
|
||||
similarity_matrix = None
|
||||
embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
|
||||
# 检查是否使用缓存(提前检查,避免不必要的数据处理)
|
||||
cache_dir = os.path.dirname(args.cluster_cache_path)
|
||||
if cache_dir:
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
clustered_rows = []
|
||||
remaining_indices = list(range(len(processed_sentences))) # 使用索引而不是对象
|
||||
clustered_tensor = None
|
||||
|
||||
Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens")
|
||||
# 尝试加载缓存的聚类结果
|
||||
if not args.recompute_clusters and os.path.exists(args.cluster_cache_path):
|
||||
try:
|
||||
Logger(f"Loading cached cluster results from {args.cluster_cache_path}")
|
||||
clustered_tensor = torch.load(args.cluster_cache_path)
|
||||
|
||||
# 选择聚类算法
|
||||
if args.fast_clustering and len(processed_sentences) > 5000:
|
||||
Logger("Using ultra-fast approximate clustering algorithm...")
|
||||
# 验证缓存文件的形状是否可用
|
||||
cached_knowledge_num, cached_knowledge_length = clustered_tensor.shape
|
||||
|
||||
# 超快速聚类:随机采样 + 批量处理
|
||||
import random
|
||||
random.seed(42) # 确保可重现性
|
||||
|
||||
# 按重要性分层采样
|
||||
high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7]
|
||||
medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7]
|
||||
low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3]
|
||||
|
||||
Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}")
|
||||
|
||||
for cluster_idx in tqdm(range(knowledge_num)):
|
||||
# 分层选择种子:优先选择高重要性句子
|
||||
if high_importance:
|
||||
seed_pool = high_importance
|
||||
elif medium_importance:
|
||||
seed_pool = medium_importance
|
||||
if cached_knowledge_length == knowledge_length:
|
||||
if cached_knowledge_num >= knowledge_num:
|
||||
# 缓存足够大,可以截取使用
|
||||
clustered_tensor = clustered_tensor[:knowledge_num, :]
|
||||
Logger(f"Successfully loaded cached clusters with shape {clustered_tensor.shape}")
|
||||
Logger(f"Truncated from cached shape ({cached_knowledge_num}, {cached_knowledge_length}) to required shape ({knowledge_num}, {knowledge_length})")
|
||||
Logger("Skipping database initialization and clustering - using cached results")
|
||||
else:
|
||||
# 缓存太小,需要重新计算
|
||||
Logger(f"Cached knowledge_num ({cached_knowledge_num}) < required knowledge_num ({knowledge_num}), recomputing...")
|
||||
clustered_tensor = None
|
||||
else:
|
||||
seed_pool = low_importance if low_importance else list(range(len(processed_sentences)))
|
||||
# knowledge_length不匹配,需要重新计算
|
||||
Logger(f"Cached knowledge_length ({cached_knowledge_length}) != required knowledge_length ({knowledge_length}), recomputing...")
|
||||
clustered_tensor = None
|
||||
except Exception as e:
|
||||
Logger(f"Failed to load cached clusters: {e}, recomputing...")
|
||||
clustered_tensor = None
|
||||
|
||||
if not seed_pool:
|
||||
break
|
||||
# 只有在没有有效缓存时才进行数据库初始化和聚类计算
|
||||
if clustered_tensor is None:
|
||||
Logger(f"Loading database initialization data from {database_init_path}")
|
||||
|
||||
# 随机选择种子(在同一重要性层级内)
|
||||
seed_global_idx = random.choice(seed_pool)
|
||||
seed_sentence = processed_sentences[seed_global_idx]
|
||||
# 1. 加载JSON文件并转换为字典
|
||||
with open(database_init_path, 'r', encoding='utf-8') as f:
|
||||
database_data = json.load(f)
|
||||
|
||||
# 从所有池中移除种子
|
||||
for pool in [high_importance, medium_importance, low_importance]:
|
||||
if seed_global_idx in pool:
|
||||
pool.remove(seed_global_idx)
|
||||
# 提取sentences列表
|
||||
sentences_data = database_data.get('sentences', [])
|
||||
Logger(f"Loaded {len(sentences_data)} sentences from database")
|
||||
|
||||
current_cluster_indices = [seed_global_idx]
|
||||
current_tokens = seed_sentence['token_length']
|
||||
# 2. 按照importance_score进行排序(从高到低)
|
||||
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
|
||||
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
|
||||
|
||||
if current_tokens < max_tokens:
|
||||
# 快速选择:只从附近的句子中随机选择
|
||||
all_remaining = high_importance + medium_importance + low_importance
|
||||
if all_remaining:
|
||||
# 随机采样候选句子(而不是计算所有相似度)
|
||||
sample_size = min(100, len(all_remaining))
|
||||
candidates = random.sample(all_remaining, sample_size)
|
||||
# 3. 下载并初始化本地嵌入模型
|
||||
embedding_model_name = "sentence-transformers/all-mpnet-base-v2" # 轻量级但效果好的模型
|
||||
embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2"
|
||||
embedding_cache_dir = "./models/sentence_transformers/cache"
|
||||
os.makedirs(embedding_cache_dir, exist_ok=True)
|
||||
|
||||
# 简单按token长度和重要性选择
|
||||
for candidate_idx in candidates:
|
||||
candidate = processed_sentences[candidate_idx]
|
||||
candidate_tokens = candidate['token_length']
|
||||
Logger(f"Loading embedding model: {embedding_model_name}")
|
||||
try:
|
||||
embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir)
|
||||
Logger("Embedding model loaded successfully")
|
||||
except Exception as e:
|
||||
Logger(f"Failed to load embedding model: {e}")
|
||||
Logger("Falling back to random embeddings")
|
||||
embedding_model = None
|
||||
|
||||
if current_tokens + candidate_tokens + 1 <= max_tokens:
|
||||
current_cluster_indices.append(candidate_idx)
|
||||
current_tokens += candidate_tokens + 1
|
||||
# 4. 对每个corrected_sentence进行嵌入和token长度计算
|
||||
Logger("Processing sentences for embeddings and token lengths...")
|
||||
|
||||
# 从池中移除
|
||||
for pool in [high_importance, medium_importance, low_importance]:
|
||||
if candidate_idx in pool:
|
||||
pool.remove(candidate_idx)
|
||||
break
|
||||
# 提取所有句子
|
||||
sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences]
|
||||
|
||||
if current_tokens >= min_tokens:
|
||||
break
|
||||
# 批量计算token长度
|
||||
Logger("Computing token lengths...")
|
||||
token_lengths = []
|
||||
for sentence in sentences:
|
||||
tokens = tokenizer.encode(sentence, add_special_tokens=False)
|
||||
token_lengths.append(len(tokens))
|
||||
|
||||
# 生成聚类文本
|
||||
cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
|
||||
cluster_text = '\n '.join(cluster_sentences)
|
||||
# 批量计算嵌入 - 大幅提升速度
|
||||
Logger("Computing embeddings in batches...")
|
||||
embeddings_list = []
|
||||
batch_size = 256 # 可以根据GPU内存调整
|
||||
|
||||
# 转换为tokens
|
||||
cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
|
||||
if len(cluster_tokens) > knowledge_length:
|
||||
cluster_tokens = cluster_tokens[:knowledge_length]
|
||||
else:
|
||||
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||
cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
|
||||
if embedding_model is not None:
|
||||
try:
|
||||
for i in range(0, len(sentences), batch_size):
|
||||
batch_sentences = sentences[i:i+batch_size]
|
||||
batch_embeddings = embedding_model.encode(
|
||||
batch_sentences,
|
||||
convert_to_tensor=False,
|
||||
show_progress_bar=True if i == 0 else False,
|
||||
batch_size=batch_size
|
||||
)
|
||||
embeddings_list.extend(batch_embeddings)
|
||||
|
||||
clustered_rows.append(cluster_tokens)
|
||||
if (i + batch_size) % (batch_size * 10) == 0:
|
||||
Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences")
|
||||
|
||||
if (cluster_idx + 1) % 1000 == 0:
|
||||
total_remaining = len(high_importance) + len(medium_importance) + len(low_importance)
|
||||
Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining")
|
||||
Logger("Batch embedding computation completed")
|
||||
except Exception as e:
|
||||
Logger(f"Error in batch encoding: {e}")
|
||||
Logger("Falling back to random embeddings")
|
||||
embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
|
||||
else:
|
||||
# 使用随机嵌入
|
||||
embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
|
||||
|
||||
else:
|
||||
# 原始优化算法(适用于中等规模数据集)
|
||||
# 优化2: 批量处理和更高效的数据结构
|
||||
for cluster_idx in tqdm(range(knowledge_num)):
|
||||
if not remaining_indices:
|
||||
Logger(f"No more sentences available. Created {cluster_idx} clusters.")
|
||||
break
|
||||
# 创建处理后的句子列表
|
||||
processed_sentences = []
|
||||
for i, (sentence_data, embedding, token_length) in enumerate(zip(sorted_sentences, embeddings_list, token_lengths)):
|
||||
processed_sentences.append({
|
||||
'sentence': sentence_data.get('corrected_sentence', ''),
|
||||
'importance_score': sentence_data.get('importance_score', 0.0),
|
||||
'token_length': token_length,
|
||||
'embedding': embedding, # Convert numpy array to list
|
||||
'original_index': i
|
||||
})
|
||||
|
||||
# 2.1 选择importance_score最高的句子作为种子
|
||||
remaining_sentences_subset = [processed_sentences[i] for i in remaining_indices]
|
||||
seed_idx_in_subset = max(range(len(remaining_sentences_subset)),
|
||||
key=lambda i: remaining_sentences_subset[i]['importance_score'])
|
||||
seed_global_idx = remaining_indices[seed_idx_in_subset]
|
||||
seed_sentence = processed_sentences[seed_global_idx]
|
||||
# 转换为numpy数组以便后续处理
|
||||
embeddings_array = np.array(embeddings_list)
|
||||
token_lengths_array = np.array(token_lengths)
|
||||
|
||||
# 从剩余索引中移除种子
|
||||
remaining_indices.remove(seed_global_idx)
|
||||
Logger(f"Embedding processing completed:")
|
||||
Logger(f" - Total sentences: {len(processed_sentences)}")
|
||||
Logger(f" - Embedding shape: {embeddings_array.shape}")
|
||||
Logger(f" - Average token length: {np.mean(token_lengths_array):.2f}")
|
||||
Logger(f" - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}")
|
||||
|
||||
# 当前聚类
|
||||
current_cluster_indices = [seed_global_idx]
|
||||
current_tokens = seed_sentence['token_length']
|
||||
# 聚类参数定义
|
||||
min_tokens = int(0.85 * knowledge_length)
|
||||
max_tokens = int(0.95 * knowledge_length)
|
||||
|
||||
if current_tokens >= max_tokens:
|
||||
# 如果种子句子已经超过最大token数,直接作为一个聚类
|
||||
cluster_text = seed_sentence['sentence']
|
||||
else:
|
||||
# 2.2 优化的相似度计算和选择
|
||||
if remaining_indices:
|
||||
if similarity_matrix is not None:
|
||||
# 使用预计算的相似度矩阵
|
||||
similarities = similarity_matrix[seed_global_idx][remaining_indices]
|
||||
else:
|
||||
# 动态计算相似度(批量)
|
||||
seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx+1]
|
||||
remaining_embeddings = embeddings_matrix[remaining_indices]
|
||||
similarities = cosine_similarity(seed_embedding, remaining_embeddings)[0]
|
||||
# 优化1: 预计算所有嵌入的相似度矩阵(如果数据量不太大)
|
||||
if len(processed_sentences) <= 10000: # 只有在数据量不太大时才预计算
|
||||
Logger("Pre-computing similarity matrix for faster clustering...")
|
||||
embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
|
||||
similarity_matrix = cosine_similarity(embeddings_matrix)
|
||||
Logger(f"Similarity matrix computed: {similarity_matrix.shape}")
|
||||
else:
|
||||
similarity_matrix = None
|
||||
embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
|
||||
|
||||
# 创建(相似度, 原始索引, 在remaining_indices中的位置)的元组列表
|
||||
similarity_tuples = [(similarities[i], remaining_indices[i], i)
|
||||
for i in range(len(remaining_indices))]
|
||||
clustered_rows = []
|
||||
remaining_indices = list(range(len(processed_sentences))) # 使用索引而不是对象
|
||||
|
||||
# 按相似度排序(降序)
|
||||
similarity_tuples.sort(key=lambda x: x[0], reverse=True)
|
||||
Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens")
|
||||
|
||||
# 优化3: 贪心选择,但限制搜索范围以提高速度
|
||||
max_candidates = min(len(similarity_tuples), 500) # 只考虑前500个最相似的句子
|
||||
# 选择聚类算法
|
||||
if args.fast_clustering and len(processed_sentences) > 5000:
|
||||
Logger("Using ultra-fast approximate clustering algorithm...")
|
||||
|
||||
selected_indices_in_remaining = []
|
||||
for sim_score, global_idx, pos_in_remaining in similarity_tuples[:max_candidates]:
|
||||
candidate = processed_sentences[global_idx]
|
||||
candidate_tokens = candidate['token_length']
|
||||
# 超快速聚类:随机采样 + 批量处理
|
||||
import random
|
||||
random.seed(42) # 确保可重现性
|
||||
|
||||
if current_tokens + candidate_tokens + 1 <= max_tokens: # +1 for newline
|
||||
current_cluster_indices.append(global_idx)
|
||||
selected_indices_in_remaining.append(pos_in_remaining)
|
||||
current_tokens += candidate_tokens + 1
|
||||
# 按重要性分层采样
|
||||
high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7]
|
||||
medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7]
|
||||
low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3]
|
||||
|
||||
if current_tokens >= min_tokens:
|
||||
break
|
||||
Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}")
|
||||
|
||||
# 批量移除选中的句子(从后往前移除以避免索引问题)
|
||||
for pos in sorted(selected_indices_in_remaining, reverse=True):
|
||||
remaining_indices.pop(pos)
|
||||
for cluster_idx in tqdm(range(knowledge_num)):
|
||||
# 分层选择种子:优先选择高重要性句子
|
||||
if high_importance:
|
||||
seed_pool = high_importance
|
||||
elif medium_importance:
|
||||
seed_pool = medium_importance
|
||||
else:
|
||||
seed_pool = low_importance if low_importance else list(range(len(processed_sentences)))
|
||||
|
||||
# 拼接句子
|
||||
if not seed_pool:
|
||||
break
|
||||
|
||||
# 随机选择种子(在同一重要性层级内)
|
||||
seed_global_idx = random.choice(seed_pool)
|
||||
seed_sentence = processed_sentences[seed_global_idx]
|
||||
|
||||
# 从所有池中移除种子
|
||||
for pool in [high_importance, medium_importance, low_importance]:
|
||||
if seed_global_idx in pool:
|
||||
pool.remove(seed_global_idx)
|
||||
|
||||
current_cluster_indices = [seed_global_idx]
|
||||
current_tokens = seed_sentence['token_length']
|
||||
|
||||
if current_tokens < max_tokens:
|
||||
# 快速选择:只从附近的句子中随机选择
|
||||
all_remaining = high_importance + medium_importance + low_importance
|
||||
if all_remaining:
|
||||
# 随机采样候选句子(而不是计算所有相似度)
|
||||
sample_size = min(2000, len(all_remaining))
|
||||
candidates = random.sample(all_remaining, sample_size)
|
||||
|
||||
# 简单按token长度和重要性选择
|
||||
for candidate_idx in candidates:
|
||||
candidate = processed_sentences[candidate_idx]
|
||||
candidate_tokens = candidate['token_length']
|
||||
|
||||
if current_tokens + candidate_tokens + 1 <= max_tokens:
|
||||
current_cluster_indices.append(candidate_idx)
|
||||
current_tokens += candidate_tokens + 1
|
||||
|
||||
# 从池中移除
|
||||
for pool in [high_importance, medium_importance, low_importance]:
|
||||
if candidate_idx in pool:
|
||||
pool.remove(candidate_idx)
|
||||
break
|
||||
|
||||
if current_tokens >= min_tokens:
|
||||
break
|
||||
|
||||
# 生成聚类文本
|
||||
cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
|
||||
cluster_text = '\n'.join(cluster_sentences)
|
||||
cluster_text = '\n '.join(cluster_sentences)
|
||||
|
||||
# 将聚类文本转换为token
|
||||
cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
|
||||
# 转换为tokens
|
||||
cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
|
||||
if len(cluster_tokens) > knowledge_length:
|
||||
cluster_tokens = cluster_tokens[:knowledge_length]
|
||||
else:
|
||||
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||
cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
|
||||
|
||||
# 截断或填充到knowledge_length
|
||||
if len(cluster_tokens) > knowledge_length:
|
||||
cluster_tokens = cluster_tokens[:knowledge_length]
|
||||
else:
|
||||
# 用pad_token_id填充
|
||||
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||
cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
|
||||
clustered_rows.append(cluster_tokens)
|
||||
|
||||
clustered_rows.append(cluster_tokens)
|
||||
if (cluster_idx + 1) % 1000 == 0:
|
||||
total_remaining = len(high_importance) + len(medium_importance) + len(low_importance)
|
||||
Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining")
|
||||
|
||||
# 优化4: 减少日志频率
|
||||
if (cluster_idx + 1) % 500 == 0:
|
||||
Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining")
|
||||
else:
|
||||
# 原始优化算法(适用于中等规模数据集)
|
||||
# 优化2: 批量处理和更高效的数据结构
|
||||
for cluster_idx in tqdm(range(knowledge_num)):
|
||||
if not remaining_indices:
|
||||
Logger(f"No more sentences available. Created {cluster_idx} clusters.")
|
||||
break
|
||||
|
||||
# 如果聚类数量不足,用随机token填充
|
||||
while len(clustered_rows) < knowledge_num:
|
||||
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||
random_tokens = [pad_token_id] * knowledge_length
|
||||
clustered_rows.append(random_tokens)
|
||||
# 2.1 选择importance_score最高的句子作为种子
|
||||
remaining_sentences_subset = [processed_sentences[i] for i in remaining_indices]
|
||||
seed_idx_in_subset = max(range(len(remaining_sentences_subset)),
|
||||
key=lambda i: remaining_sentences_subset[i]['importance_score'])
|
||||
seed_global_idx = remaining_indices[seed_idx_in_subset]
|
||||
seed_sentence = processed_sentences[seed_global_idx]
|
||||
|
||||
# 转换为tensor
|
||||
clustered_tensor = torch.tensor(clustered_rows, dtype=torch.long)
|
||||
# 从剩余索引中移除种子
|
||||
remaining_indices.remove(seed_global_idx)
|
||||
|
||||
Logger(f"Clustering completed:")
|
||||
Logger(f" - Created {len(clustered_rows)} clusters")
|
||||
Logger(f" - Cluster shape: {clustered_tensor.shape}")
|
||||
Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")
|
||||
# 当前聚类
|
||||
current_cluster_indices = [seed_global_idx]
|
||||
current_tokens = seed_sentence['token_length']
|
||||
|
||||
if current_tokens >= max_tokens:
|
||||
# 如果种子句子已经超过最大token数,直接作为一个聚类
|
||||
cluster_text = seed_sentence['sentence']
|
||||
else:
|
||||
# 2.2 优化的相似度计算和选择
|
||||
if remaining_indices:
|
||||
if similarity_matrix is not None:
|
||||
# 使用预计算的相似度矩阵
|
||||
similarities = similarity_matrix[seed_global_idx][remaining_indices]
|
||||
else:
|
||||
# 动态计算相似度(批量)
|
||||
seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx+1]
|
||||
remaining_embeddings = embeddings_matrix[remaining_indices]
|
||||
similarities = cosine_similarity(seed_embedding, remaining_embeddings)[0]
|
||||
|
||||
# 创建(相似度, 原始索引, 在remaining_indices中的位置)的元组列表
|
||||
similarity_tuples = [(similarities[i], remaining_indices[i], i)
|
||||
for i in range(len(remaining_indices))]
|
||||
|
||||
# 按相似度排序(降序)
|
||||
similarity_tuples.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# 优化3: 贪心选择,但限制搜索范围以提高速度
|
||||
max_candidates = min(len(similarity_tuples), 500) # 只考虑前500个最相似的句子
|
||||
|
||||
selected_indices_in_remaining = []
|
||||
for sim_score, global_idx, pos_in_remaining in similarity_tuples[:max_candidates]:
|
||||
candidate = processed_sentences[global_idx]
|
||||
candidate_tokens = candidate['token_length']
|
||||
|
||||
if current_tokens + candidate_tokens + 1 <= max_tokens: # +1 for newline
|
||||
current_cluster_indices.append(global_idx)
|
||||
selected_indices_in_remaining.append(pos_in_remaining)
|
||||
current_tokens += candidate_tokens + 1
|
||||
|
||||
if current_tokens >= min_tokens:
|
||||
break
|
||||
|
||||
# 批量移除选中的句子(从后往前移除以避免索引问题)
|
||||
for pos in sorted(selected_indices_in_remaining, reverse=True):
|
||||
remaining_indices.pop(pos)
|
||||
|
||||
# 拼接句子
|
||||
cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
|
||||
cluster_text = '\n'.join(cluster_sentences)
|
||||
|
||||
# 将聚类文本转换为token
|
||||
cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
|
||||
|
||||
# 截断或填充到knowledge_length
|
||||
if len(cluster_tokens) > knowledge_length:
|
||||
cluster_tokens = cluster_tokens[:knowledge_length]
|
||||
else:
|
||||
# 用pad_token_id填充
|
||||
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||
cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
|
||||
|
||||
clustered_rows.append(cluster_tokens)
|
||||
|
||||
# 优化4: 减少日志频率
|
||||
if (cluster_idx + 1) % 500 == 0:
|
||||
Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining")
|
||||
|
||||
# 如果聚类数量不足,用随机token填充
|
||||
while len(clustered_rows) < knowledge_num:
|
||||
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||
random_tokens = [pad_token_id] * knowledge_length
|
||||
clustered_rows.append(random_tokens)
|
||||
|
||||
# 转换为tensor
|
||||
clustered_tensor = torch.tensor(clustered_rows, dtype=torch.long)
|
||||
|
||||
Logger(f"Clustering completed:")
|
||||
Logger(f" - Created {len(clustered_rows)} clusters")
|
||||
Logger(f" - Cluster shape: {clustered_tensor.shape}")
|
||||
Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")
|
||||
|
||||
# 保存聚类结果到缓存文件
|
||||
try:
|
||||
torch.save(clustered_tensor, args.cluster_cache_path)
|
||||
Logger(f"Cluster results saved to {args.cluster_cache_path}")
|
||||
except Exception as e:
|
||||
Logger(f"Failed to save cluster results: {e}")
|
||||
|
||||
# 3. 初始化模型的weight_down_embed
|
||||
if hasattr(model, 'extract_db') and hasattr(model.extract_db, 'weight_down_embed'):
|
||||
@@ -651,10 +681,12 @@ def main():
    parser.add_argument("--profile", action="store_true", default=True, help="Enable profiling")
    parser.add_argument("--profile_interval", type=int, default=10, help="Profiling print interval (in steps)")
    parser.add_argument("--use_flash_attn", action="store_true", default=True, help="Enable FlashAttention")
    parser.add_argument("--knowledge_num", type=int, default=64*64, help="Number of entries in the knowledge base")
    parser.add_argument("--knowledge_num", type=int, default=65536, help="Number of entries in the knowledge base")
    parser.add_argument("--knowledge_length", type=int, default=64, help="Token length of each knowledge base entry")
    parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="Database initialization path")
    parser.add_argument("--fast_clustering", action="store_true", default=True, help="Use the fast approximate clustering algorithm (for large datasets)")
    parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens.pt", help="Path of the cluster result cache file")
    parser.add_argument("--recompute_clusters", action="store_true", default=False, help="Force recomputation of clusters and ignore the cache file")
    args = parser.parse_args()

    #########################################################