# Minimind/preprocessing/preprocess_triple.py

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
import re
import asyncio
import aiofiles
from concurrent.futures import ThreadPoolExecutor
from preprocessing.agent_system.extractor_agent.agent import DepartmentAgent
from typing import Dict, List, Tuple
import gc
import time
import psutil
from tqdm.asyncio import tqdm as async_tqdm
from tqdm import tqdm
json_path = "dataset/merged_pretrain_extra.jsonl"
output_path = "dataset/processed_triples.jsonl"

# Tuned configuration - keeps resource consumption down
BATCH_SIZE = 5000       # records per batch
MAX_CONCURRENT = 200    # maximum number of paragraphs processed concurrently
AGENT_POOL_SIZE = 20    # number of DepartmentAgent instances in the pool
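# Rough sizing note (an assumption, not measured here): each in-flight
# paragraph holds all of its sentence-extraction tasks at once, so peak load
# scales with MAX_CONCURRENT times the average sentences per paragraph. If
# memory or rate limits become the bottleneck, lower MAX_CONCURRENT first,
# then BATCH_SIZE.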
def get_memory_usage():
    """Return current process memory usage (RSS) in MB."""
    process = psutil.Process(os.getpid())
    memory_info = process.memory_info()
    memory_mb = memory_info.rss / 1024 / 1024
    return memory_mb


def print_memory_info(stage=""):
    """Print memory usage for the given stage."""
    memory_mb = get_memory_usage()
    print(f"🔧 {stage} - memory usage: {memory_mb:.1f} MB")
# Create a pool of extractor agents to avoid contention under concurrency
def create_extractor_pool(pool_size: int = 5):
    """Create a pool of DepartmentAgent extractor instances."""
    print(f"Creating {pool_size} agent instances...")
    agents = []
    for i in range(pool_size):
        try:
            agent = DepartmentAgent(model_type="deepseek")
            agents.append(agent)
            print(f"  ✓ Agent {i+1}/{pool_size} created")
        except Exception as e:
            print(f"  ✗ Agent {i+1} creation failed: {e}")
    print(f"Agent pool ready: {len(agents)} instances created")
    return agents
# Lazily initialized agent pool
AGENT_POOL = None
agent_pool_index = 0


def get_agent_pool():
    """Get the agent pool, creating it on first use."""
    global AGENT_POOL
    if AGENT_POOL is None:
        print_memory_info("Before creating agent pool")
        AGENT_POOL = create_extractor_pool(pool_size=AGENT_POOL_SIZE)
        print_memory_info("After creating agent pool")
    return AGENT_POOL


def get_next_agent():
    """Round-robin over the pool to pick the next agent."""
    global agent_pool_index
    pool = get_agent_pool()
    agent = pool[agent_pool_index % len(pool)]
    agent_pool_index += 1
    return agent
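
# The bare global counter above is safe only because every caller runs on the
# single asyncio event-loop thread, with no await between the read and the
# increment. A minimal sketch of a lock-guarded variant, assuming agents were
# ever handed out from worker threads (hypothetical - nothing in this file
# does so, even though ThreadPoolExecutor is imported):
import threading

_pool_lock = threading.Lock()


def get_next_agent_threadsafe():
    """Hypothetical thread-safe variant of get_next_agent for threaded callers."""
    global agent_pool_index
    pool = get_agent_pool()
    with _pool_lock:
        agent = pool[agent_pool_index % len(pool)]
        agent_pool_index += 1
    return agent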
def clean_and_split_text(text):
    """
    Strip the special tokens at the start/end of the text, then split it into sentences.
    """
    # Remove a leading <|im_start|> and a trailing <|im_end|>
    text = text.strip()
    if text.startswith('<|im_start|>'):
        text = text[len('<|im_start|>'):]
    if text.endswith('<|im_end|>'):
        text = text[:-len('<|im_end|>')]
    # Normalize surrounding whitespace
    text = text.strip()
    # Split on sentence-ending punctuation (ASCII and CJK full stops,
    # question marks, and exclamation marks)
    sentence_endings = r'[.!?。!?]'
    sentences = re.split(sentence_endings, text)
    # Drop empty fragments and trivially short ones
    cleaned_sentences = []
    for sentence in sentences:
        sentence = sentence.strip()
        if sentence and len(sentence) > 5:  # keep only non-empty, meaningful sentences
            cleaned_sentences.append(sentence)
    return cleaned_sentences
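
# Example (illustrative input, not from the dataset):
#   clean_and_split_text("<|im_start|>The weather is nice today. Short. "
#                        "We went to the park afterwards.<|im_end|>")
#   -> ['The weather is nice today', 'We went to the park afterwards']
# "Short" is dropped by the len(sentence) > 5 filter.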
async def extract_triple_from_sentence_async(sentence: str, context: str = None) -> Dict:
    """
    Asynchronously extract a triple from a sentence using an extractor agent.
    """
    try:
        # Grab an agent instance from the pool
        agent = get_next_agent()
        result = await agent.async_run(sentence=sentence, context=context)
        return {
            "sentence": sentence,
            "triple": {
                "subject": result.triple.subject,
                "predicate": result.triple.predicate,
                "object": result.triple.object
            },
            "confidence": result.confidence
        }
    except Exception as e:
        return {
            "sentence": sentence,
            "triple": {
                "subject": "",
                "predicate": "",
                "object": ""
            },
            "confidence": 0.0,
            "error": str(e)
        }
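
# Shape of a successful result (field values illustrative):
#   {"sentence": "...", "triple": {"subject": "...", "predicate": "...",
#    "object": "..."}, "confidence": 0.92}
# Failures return the same shape with empty strings, confidence 0.0, and an
# added "error" key, so downstream code can filter on confidence > 0.0.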
async def process_paragraph_async(line_num: int, original_text: str, semaphore: asyncio.Semaphore) -> Dict:
    """
    Asynchronously process a single paragraph.
    """
    async with semaphore:  # bound the number of paragraphs in flight
        try:
            # Clean and split the text into sentences
            sentences = clean_and_split_text(original_text)
            if not sentences:
                return None
            # Build the result record for this paragraph
            paragraph_result = {
                "source_line": line_num,
                "original_paragraph": original_text,
                "sentences": [],
                "triples": []
            }
            # Kick off extraction for every sentence concurrently
            tasks = []
            for sentence in sentences:
                task = extract_triple_from_sentence_async(sentence, context=original_text)
                tasks.append(task)
            # Wait for all sentences to finish
            triple_results = await asyncio.gather(*tasks)
            # Collect results in sentence order
            for i, sentence in enumerate(sentences):
                paragraph_result["sentences"].append(sentence)
                paragraph_result["triples"].append(triple_results[i])
            return paragraph_result
        except Exception as e:
            print(f"Error while processing line {line_num}: {e}")
            return None
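
# Note: the semaphore is acquired per paragraph, so MAX_CONCURRENT bounds
# paragraphs in flight, not individual agent calls - a paragraph with N
# sentences still issues N extractor requests at once via asyncio.gather.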
async def process_batch_async(batch_data: List[Tuple[int, str]], batch_num: int) -> List[Dict]:
    """
    Asynchronously process one batch of data, with a progress bar and memory monitoring.
    """
    print(f"\n=== Processing batch {batch_num} asynchronously ===")
    print(f"Batch size: {len(batch_data)} records")
    print_memory_info(f"Before batch {batch_num}")
    start_time = time.time()
    # Semaphore bounding the number of concurrent paragraphs
    semaphore = asyncio.Semaphore(MAX_CONCURRENT)
    # Process the batch in chunks to avoid creating too many tasks at once
    chunk_size = 1000  # tasks per chunk
    all_results = []
    for chunk_start in range(0, len(batch_data), chunk_size):
        chunk_end = min(chunk_start + chunk_size, len(batch_data))
        chunk_data = batch_data[chunk_start:chunk_end]
        print(f"Processing chunk {chunk_start//chunk_size + 1}/{(len(batch_data)-1)//chunk_size + 1} ({len(chunk_data)} records)")
        # Create the async tasks for this chunk
        tasks = []
        for line_num, original_text in chunk_data:
            task = process_paragraph_async(line_num, original_text, semaphore)
            tasks.append(task)
        # Track this chunk with a progress bar
        progress_bar = tqdm(total=len(tasks), desc=f"batch {batch_num} - chunk {chunk_start//chunk_size + 1}", unit="paragraph", ncols=100)
        chunk_results = []
        completed_tasks = 0
        # Use as_completed to update the progress bar as tasks finish
        for coro in asyncio.as_completed(tasks):
            try:
                result = await coro
                chunk_results.append(result)
                completed_tasks += 1
                progress_bar.update(1)
                # Refresh the postfix stats every 50 completed tasks
                if completed_tasks % 50 == 0:
                    valid_results = [r for r in chunk_results if r is not None]
                    progress_bar.set_postfix({
                        'valid': len(valid_results),
                        'done': completed_tasks,
                        'success': f"{len(valid_results)/completed_tasks*100:.1f}%"
                    })
            except Exception as e:
                print(f"Task failed: {e}")
                completed_tasks += 1
                progress_bar.update(1)
        progress_bar.close()
        all_results.extend(chunk_results)
        # Free memory after each chunk
        del tasks, chunk_results
        gc.collect()
        print_memory_info(f"After batch {batch_num}, chunk {chunk_start//chunk_size + 1}")
    # Filter out None results
    valid_results = [result for result in all_results if result is not None]
    # Batch statistics
    batch_sentences = sum(len(result["sentences"]) for result in valid_results)
    batch_triples = sum(
        sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
        for result in valid_results
    )
    end_time = time.time()
    processing_time = end_time - start_time
    print(f"Batch {batch_num} finished:")
    print(f"  - valid paragraphs: {len(valid_results)}/{len(batch_data)} ({len(valid_results)/len(batch_data)*100:.1f}%)")
    print(f"  - total sentences: {batch_sentences}")
    print(f"  - successful triples: {batch_triples}")
    if batch_sentences > 0:
        print(f"  - triple success rate: {batch_triples/batch_sentences*100:.1f}%")
    else:
        print("  - no sentences")
    print(f"  - processing time: {processing_time:.2f}s")
    print(f"  - throughput: {len(batch_data)/processing_time:.2f} paragraphs/s")
    print_memory_info(f"After batch {batch_num}")
    return valid_results
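
# Results come back in completion order (asyncio.as_completed), not input
# order; that is fine here because every record carries its own source_line.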
async def write_results_batch(results: List[Dict], output_path: str):
    """
    Asynchronously write results in bulk, with progress reporting.
    """
    try:
        print(f"Writing {len(results)} results in bulk...")
        # Serialize everything up front
        content_lines = []
        for result in results:
            content_lines.append(json.dumps(result, ensure_ascii=False))
        # Single bulk write
        async with aiofiles.open(output_path, "a", encoding="utf-8") as f:
            await f.write("\n".join(content_lines) + "\n")
        print(f"✓ Wrote {len(results)} results to {output_path}")
    except Exception as e:
        print(f"✗ Bulk write failed: {e}")
        print("Falling back to writing one record at a time...")
        # If the bulk write fails, fall back to per-record writes (with a progress bar)
        async with aiofiles.open(output_path, "a", encoding="utf-8") as f:
            for result in tqdm(results, desc="per-record write", unit="record"):
                await f.write(json.dumps(result, ensure_ascii=False) + "\n")
        print("✓ Per-record write finished")
# Main processing pipeline
async def main_async():
    total_processed = 0
    total_sentences = 0
    total_triples = 0
    batch_num = 0
    print("=== Starting async batch processing of the JSONL file ===")
    print("Configuration:")
    print(f"  - batch size: {BATCH_SIZE:,} records")
    print(f"  - max concurrency: {MAX_CONCURRENT}")
    print(f"  - agent pool size: {AGENT_POOL_SIZE}")
    print(f"  - input file: {json_path}")
    print(f"  - output file: {output_path}")
    print()
    print_memory_info("Program start")
    # Truncate the output file
    async with aiofiles.open(output_path, "w", encoding="utf-8") as f:
        pass
    # Read and process the data
    with open(json_path, "r", encoding="utf-8") as f_in:
        batch_data = []
        for line_num, line in enumerate(f_in):
            if line.strip():  # skip blank lines
                try:
                    item = json.loads(line)
                    original_text = item.get("text", "")
                    if original_text:
                        batch_data.append((line_num + 1, original_text))
                    # Once the batch reaches the configured size, process it
                    if len(batch_data) >= BATCH_SIZE:
                        batch_num += 1
                        # Process the batch asynchronously
                        batch_results = await process_batch_async(batch_data, batch_num)
                        # Write the results in bulk
                        if batch_results:
                            await write_results_batch(batch_results, output_path)
                            # Update running totals
                            batch_sentences = sum(len(result["sentences"]) for result in batch_results)
                            batch_triples = sum(
                                sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
                                for result in batch_results
                            )
                            total_processed += len(batch_data)
                            total_sentences += batch_sentences
                            total_triples += batch_triples
                            print(f"\n📊 Running totals after batch {batch_num}:")
                            print(f"  - paragraphs processed: {total_processed:,}")
                            print(f"  - sentences: {total_sentences:,}")
                            print(f"  - triples: {total_triples:,}")
                            print(f"  - overall success rate: {total_triples/total_sentences*100:.1f}%")
                            print("-" * 80)
                        # Clear batch data to release memory
                        batch_data.clear()
                        batch_results.clear()
                        gc.collect()  # force garbage collection
                        print_memory_info(f"After cleaning up batch {batch_num}")
                except json.JSONDecodeError as e:
                    print(f"JSON parse error on line {line_num + 1}: {e}")
                except Exception as e:
                    print(f"Error while processing line {line_num + 1}: {e}")
        # Handle the final, possibly partial batch
        if batch_data:
            batch_num += 1
            batch_results = await process_batch_async(batch_data, batch_num)
            if batch_results:
                await write_results_batch(batch_results, output_path)
                batch_sentences = sum(len(result["sentences"]) for result in batch_results)
                batch_triples = sum(
                    sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
                    for result in batch_results
                )
                total_processed += len(batch_data)
                total_sentences += batch_sentences
                total_triples += batch_triples
    # Final statistics
    print(f"\n{'='*80}")
    print("🎉 All batches processed!")
    print(f"{'='*80}")
    print("Final statistics:")
    print(f"  - batches: {batch_num}")
    print(f"  - paragraphs: {total_processed:,}")
    print(f"  - sentences: {total_sentences:,}")
    print(f"  - triples: {total_triples:,}")
    if total_sentences > 0:
        print(f"  - overall success rate: {total_triples/total_sentences*100:.1f}%")
    else:
        print("  - no valid sentences")
    print(f"  - output file: {output_path}")
    print(f"{'='*80}")
    print_memory_info("Before program end")
    # Show a few sample results
    await show_sample_results()
async def show_sample_results():
    """Show the first few processed records as samples."""
    print("\n📋 First 3 processed records:")
    try:
        async with aiofiles.open(output_path, "r", encoding="utf-8") as f:
            i = 0
            async for line in f:
                if i >= 3:
                    break
                item = json.loads(line)
                print(f"\n--- Paragraph {i+1} (source line: {item['source_line']}) ---")
                print(f"Original paragraph: {item['original_paragraph'][:100]}...")
                print(f"Sentence count: {len(item['sentences'])}")
                if item['triples']:
                    for j, triple in enumerate(item['triples'][:2]):  # show only the first 2 triples
                        print(f"  Sentence {j+1}: {triple['sentence'][:50]}...")
                        if triple['confidence'] > 0:
                            print(f"    Triple: {triple['triple']['subject']} -> {triple['triple']['predicate']} -> {triple['triple']['object']}")
                            print(f"    Confidence: {triple['confidence']:.2f}")
                        else:
                            print(f"    Extraction failed: {triple.get('error', 'unknown error')}")
                i += 1
    except Exception as e:
        print(f"Error while reading sample results: {e}")
def main():
    """Entry point."""
    try:
        # Run the async main function
        asyncio.run(main_async())
    except KeyboardInterrupt:
        print("\n⚠️ Interrupted by user")
    except Exception as e:
        print(f"❌ Error during processing: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
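
# Typical invocation, assuming the working directory is the repo root (the
# dataset paths above are relative) and dataset/merged_pretrain_extra.jsonl
# exists:
#
#   python preprocessing/preprocess_triple.py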