# Source metadata (from viewer chrome): 441 lines, 16 KiB, Python.
import sys
|
||
import os
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
import json
|
||
import re
|
||
import asyncio
|
||
import aiofiles
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from preprocessing.agent_system.extractor_agent.agent import DepartmentAgent
|
||
from typing import Dict, List, Tuple
|
||
import gc
|
||
import time
|
||
import psutil
|
||
from tqdm.asyncio import tqdm as async_tqdm
|
||
from tqdm import tqdm
|
||
|
||
# Input corpus (one JSON object per line with a "text" field) and the
# output file that receives one processed-paragraph record per line.
json_path = "dataset/merged_pretrain_extra.jsonl"
output_path = "dataset/processed_triples.jsonl"

# Tuned configuration - reduced resource consumption.
BATCH_SIZE = 5000  # records read from the input file per batch
MAX_CONCURRENT = 200  # maximum paragraphs processed concurrently
AGENT_POOL_SIZE = 20  # number of DepartmentAgent instances in the pool
|
||
|
||
def get_memory_usage():
    """Return the resident set size of the current process in megabytes."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024
|
||
|
||
def print_memory_info(stage=""):
    """Print the current process memory usage, labelled with *stage*."""
    usage = get_memory_usage()
    print(f"🔧 {stage} - 内存使用: {usage:.1f} MB")
|
||
|
||
# Pool of extractor agents so concurrent requests do not share one instance.
def create_extractor_pool(pool_size: int = 5):
    """Build a pool of DepartmentAgent instances.

    Creation failures are reported and skipped, so the returned list may
    contain fewer than *pool_size* agents (possibly none).
    """
    print(f"正在创建 {pool_size} 个agent实例...")
    pool = []
    for idx in range(1, pool_size + 1):
        try:
            pool.append(DepartmentAgent(model_type="deepseek"))
        except Exception as exc:
            print(f"  ✗ Agent {idx} 创建失败: {exc}")
        else:
            print(f"  ✓ Agent {idx}/{pool_size} 创建成功")
    print(f"Agent池创建完成,实际创建了 {len(pool)} 个实例")
    return pool
|
||
|
||
# Lazily-initialised agent pool (created on first use by get_agent_pool())
# and the round-robin cursor advanced by get_next_agent().
AGENT_POOL = None
agent_pool_index = 0
|
||
|
||
def get_agent_pool():
    """Return the shared agent pool, creating it lazily on first use."""
    global AGENT_POOL
    if AGENT_POOL is not None:
        return AGENT_POOL
    # First call: build the pool and log memory before/after creation.
    print_memory_info("创建Agent池前")
    AGENT_POOL = create_extractor_pool(pool_size=AGENT_POOL_SIZE)
    print_memory_info("创建Agent池后")
    return AGENT_POOL
|
||
|
||
def get_next_agent():
    """Return the next available agent, round-robin over the pool.

    Returns:
        A DepartmentAgent instance from the shared pool.

    Raises:
        RuntimeError: if the pool is empty (every agent creation failed);
            without this guard ``agent_pool_index % len(pool)`` would raise
            a confusing ZeroDivisionError instead.
    """
    global agent_pool_index
    pool = get_agent_pool()
    if not pool:
        # create_extractor_pool() skips agents that fail to initialise, so
        # an empty pool is a real possibility that deserves a clear error.
        raise RuntimeError("Agent pool is empty: no DepartmentAgent could be created")
    agent = pool[agent_pool_index % len(pool)]
    agent_pool_index += 1
    return agent
|
||
|
||
def clean_and_split_text(text):
    """Strip chat-template markers from *text* and split it into sentences.

    A leading ``<|im_start|>`` and trailing ``<|im_end|>`` tag are removed,
    the remainder is split on ASCII and full-width sentence-ending
    punctuation, and only stripped fragments longer than five characters
    are kept.
    """
    start_tag, end_tag = '<|im_start|>', '<|im_end|>'
    text = text.strip()
    if text.startswith(start_tag):
        text = text[len(start_tag):]
    if text.endswith(end_tag):
        text = text[:-len(end_tag)]
    text = text.strip()

    # Split on sentence terminators: . ! ? and their full-width variants.
    fragments = re.split(r'[.!?。!?]', text)

    # Keep only meaningful fragments (more than five characters).
    return [frag for frag in (piece.strip() for piece in fragments) if len(frag) > 5]
|
||
|
||
async def extract_triple_from_sentence_async(sentence: str, context: str = None) -> Dict:
    """Extract one (subject, predicate, object) triple from *sentence*.

    A pooled agent performs the extraction, optionally guided by the
    surrounding paragraph *context*. Any failure yields an empty triple
    with confidence 0.0 and the error message instead of raising.
    """
    try:
        # Round-robin over the shared agent pool.
        extractor = get_next_agent()
        outcome = await extractor.async_run(sentence=sentence, context=context)
        return {
            "sentence": sentence,
            "triple": {
                "subject": outcome.triple.subject,
                "predicate": outcome.triple.predicate,
                "object": outcome.triple.object,
            },
            "confidence": outcome.confidence,
        }
    except Exception as exc:
        # Swallow the error: callers count confidence > 0 as success.
        return {
            "sentence": sentence,
            "triple": {"subject": "", "predicate": "", "object": ""},
            "confidence": 0.0,
            "error": str(exc),
        }
|
||
|
||
async def process_paragraph_async(line_num: int, original_text: str, semaphore: asyncio.Semaphore) -> Dict:
    """Process a single paragraph: split it into sentences and extract one
    triple per sentence.

    Concurrency across paragraphs is bounded by *semaphore*. Returns None
    when the paragraph yields no usable sentences or an unexpected error
    occurs.
    """
    async with semaphore:  # cap the number of in-flight paragraphs
        try:
            sentences = clean_and_split_text(original_text)
            if not sentences:
                return None

            # One extraction task per sentence, each seeing the whole
            # paragraph as context; order of results matches `sentences`.
            triples = await asyncio.gather(
                *(extract_triple_from_sentence_async(s, context=original_text)
                  for s in sentences)
            )

            return {
                "source_line": line_num,
                "original_paragraph": original_text,
                "sentences": list(sentences),
                "triples": list(triples),
            }
        except Exception as exc:
            print(f"处理第 {line_num} 行时出错: {exc}")
            return None
|
||
|
||
async def process_batch_async(batch_data: List[Tuple[int, str]], batch_num: int) -> List[Dict]:
    """
    Process one batch of (line_number, paragraph_text) pairs concurrently,
    with a progress bar and memory monitoring.

    The batch is consumed in sub-chunks of 1000 so that only a bounded
    number of coroutine objects exist at once; paragraph-level concurrency
    is additionally capped by a shared semaphore. Returns the list of
    non-None paragraph results.
    """
    print(f"\n=== 异步处理批次 {batch_num} ===")
    print(f"批次大小: {len(batch_data)} 条记录")
    print_memory_info(f"批次 {batch_num} 开始前")

    start_time = time.time()

    # Bounds how many paragraphs are processed at the same time.
    semaphore = asyncio.Semaphore(MAX_CONCURRENT)

    # Work in sub-chunks to avoid creating every task up front.
    chunk_size = 1000  # tasks per sub-chunk
    all_results = []

    for chunk_start in range(0, len(batch_data), chunk_size):
        chunk_end = min(chunk_start + chunk_size, len(batch_data))
        chunk_data = batch_data[chunk_start:chunk_end]

        print(f"处理子块 {chunk_start//chunk_size + 1}/{(len(batch_data)-1)//chunk_size + 1} ({len(chunk_data)} 条记录)")

        # Build the coroutines for the current chunk.
        tasks = []
        for line_num, original_text in chunk_data:
            task = process_paragraph_async(line_num, original_text, semaphore)
            tasks.append(task)

        # Progress bar covering just this chunk.
        progress_bar = tqdm(total=len(tasks), desc=f"批次{batch_num}-块{chunk_start//chunk_size + 1}", unit="段落", ncols=100)

        chunk_results = []
        completed_tasks = 0

        # Consume tasks as they finish so the bar advances in real time.
        for coro in asyncio.as_completed(tasks):
            try:
                result = await coro
                chunk_results.append(result)
                completed_tasks += 1

                progress_bar.update(1)

                # Refresh the postfix statistics every 50 completions.
                if completed_tasks % 50 == 0:
                    valid_results = [r for r in chunk_results if r is not None]
                    progress_bar.set_postfix({
                        '有效': len(valid_results),
                        '完成': completed_tasks,
                        '成功率': f"{len(valid_results)/completed_tasks*100:.1f}%"
                    })
            except Exception as e:
                # A failed task still counts toward progress.
                print(f"任务执行失败: {e}")
                completed_tasks += 1
                progress_bar.update(1)

        progress_bar.close()
        all_results.extend(chunk_results)

        # Release per-chunk structures before starting the next chunk.
        del tasks, chunk_results
        gc.collect()

        print_memory_info(f"批次 {batch_num} 块 {chunk_start//chunk_size + 1} 完成后")

    # Drop paragraphs that produced no result.
    valid_results = [result for result in all_results if result is not None]

    # Batch-level statistics: a triple counts as successful when its
    # confidence is above zero (failures are recorded with 0.0).
    batch_sentences = sum(len(result["sentences"]) for result in valid_results)
    batch_triples = sum(
        sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
        for result in valid_results
    )

    end_time = time.time()
    processing_time = end_time - start_time

    print(f"批次 {batch_num} 异步处理完成:")
    print(f"  - 有效段落: {len(valid_results)}/{len(batch_data)} ({len(valid_results)/len(batch_data)*100:.1f}%)")
    print(f"  - 总句子数: {batch_sentences}")
    print(f"  - 成功三元组: {batch_triples}")
    print(f"  - 三元组成功率: {batch_triples/batch_sentences*100:.1f}%" if batch_sentences > 0 else "无句子")
    print(f"  - 处理时间: {processing_time:.2f}秒")
    print(f"  - 处理速度: {len(batch_data)/processing_time:.2f}段落/秒")

    print_memory_info(f"批次 {batch_num} 完成后")

    return valid_results
|
||
|
||
async def write_results_batch(results: List[Dict], output_path: str):
    """Append *results* to *output_path* as JSON Lines.

    A single bulk write is attempted first; if it fails, the function
    falls back to writing one record at a time with a progress bar.
    """
    try:
        print(f"开始批量写入 {len(results)} 条结果...")

        # Serialise everything up front, then write in one call.
        payload = "\n".join(json.dumps(item, ensure_ascii=False) for item in results)

        async with aiofiles.open(output_path, "a", encoding="utf-8") as f:
            await f.write(payload + "\n")

        print(f"✓ 成功批量写入 {len(results)} 条结果到 {output_path}")

    except Exception as exc:
        print(f"✗ 批量写入失败: {exc}")
        print("尝试逐条写入...")

        # Fallback path: per-record writes, with progress feedback.
        async with aiofiles.open(output_path, "a", encoding="utf-8") as f:
            for item in tqdm(results, desc="逐条写入", unit="条"):
                await f.write(json.dumps(item, ensure_ascii=False) + "\n")
        print(f"✓ 逐条写入完成")
|
||
|
||
# Main processing pipeline.
async def main_async():
    """Stream the input JSONL in batches, extract triples for every
    paragraph concurrently, append results to the output file, and print
    running statistics plus memory usage along the way.
    """
    total_processed = 0
    total_sentences = 0
    total_triples = 0
    batch_num = 0

    print("=== 开始异步批次处理JSONL文件 ===")
    print(f"优化后的配置信息:")
    print(f"  - 批次大小: {BATCH_SIZE:,} 条记录")
    print(f"  - 最大并发数: {MAX_CONCURRENT}")
    print(f"  - Agent池大小: {AGENT_POOL_SIZE}")
    print(f"  - 输入文件: {json_path}")
    print(f"  - 输出文件: {output_path}")
    print()

    print_memory_info("程序开始")

    # Truncate the output file so this run starts from a clean slate
    # (later writes append).
    async with aiofiles.open(output_path, "w", encoding="utf-8") as f:
        pass

    # Stream the input line by line, accumulating fixed-size batches.
    with open(json_path, "r", encoding="utf-8") as f_in:
        batch_data = []

        for line_num, line in enumerate(f_in):
            if line.strip():  # skip blank lines
                try:
                    item = json.loads(line)
                    original_text = item.get("text", "")

                    if original_text:
                        # Stored line numbers are 1-based.
                        batch_data.append((line_num + 1, original_text))

                        # Once the batch is full, process it asynchronously.
                        if len(batch_data) >= BATCH_SIZE:
                            batch_num += 1

                            batch_results = await process_batch_async(batch_data, batch_num)

                            # Persist the batch and update running totals.
                            if batch_results:
                                await write_results_batch(batch_results, output_path)

                                batch_sentences = sum(len(result["sentences"]) for result in batch_results)
                                batch_triples = sum(
                                    sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
                                    for result in batch_results
                                )

                                total_processed += len(batch_data)
                                total_sentences += batch_sentences
                                total_triples += batch_triples

                                print(f"\n📊 批次 {batch_num} 累计统计:")
                                print(f"  - 累计处理段落: {total_processed:,}")
                                print(f"  - 累计句子数: {total_sentences:,}")
                                print(f"  - 累计三元组: {total_triples:,}")
                                print(f"  - 整体成功率: {total_triples/total_sentences*100:.1f}%")
                                print("-" * 80)

                            # Release batch buffers to keep memory flat.
                            batch_data.clear()
                            batch_results.clear()
                            gc.collect()  # force a garbage-collection pass

                            print_memory_info(f"批次 {batch_num} 清理后")

                except json.JSONDecodeError as e:
                    print(f"第 {line_num + 1} 行JSON解析错误: {e}")
                except Exception as e:
                    print(f"处理第 {line_num + 1} 行时出错: {e}")

        # Flush the final, possibly partial batch.
        if batch_data:
            batch_num += 1
            batch_results = await process_batch_async(batch_data, batch_num)

            if batch_results:
                await write_results_batch(batch_results, output_path)

                batch_sentences = sum(len(result["sentences"]) for result in batch_results)
                batch_triples = sum(
                    sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
                    for result in batch_results
                )

                total_processed += len(batch_data)
                total_sentences += batch_sentences
                total_triples += batch_triples

    # Final summary.
    print(f"\n{'='*80}")
    print(f"🎉 所有批次异步处理完成!")
    print(f"{'='*80}")
    print(f"最终统计:")
    print(f"  - 总批次数: {batch_num}")
    print(f"  - 总段落数: {total_processed:,}")
    print(f"  - 总句子数: {total_sentences:,}")
    print(f"  - 总三元组: {total_triples:,}")
    print(f"  - 整体成功率: {total_triples/total_sentences*100:.1f}%" if total_sentences > 0 else "无有效句子")
    print(f"  - 输出文件: {output_path}")
    print(f"{'='*80}")

    print_memory_info("程序结束前")

    # Show a few example records from the output file.
    await show_sample_results()
|
||
|
||
async def show_sample_results():
    """Print the first three records of the output file as a sample."""
    print("\n📋 前3个处理结果示例:")
    try:
        async with aiofiles.open(output_path, "r", encoding="utf-8") as f:
            i = 0
            async for line in f:
                if i >= 3:
                    break
                item = json.loads(line)
                print(f"\n--- 段落 {i+1} (来源行: {item['source_line']}) ---")
                print(f"原始段落: {item['original_paragraph'][:100]}...")
                print(f"句子数量: {len(item['sentences'])}")
                if item['triples']:
                    # Show at most the first two triples per paragraph.
                    for j, triple in enumerate(item['triples'][:2]):
                        print(f"  句子 {j+1}: {triple['sentence'][:50]}...")
                        if triple['confidence'] > 0:
                            print(f"    三元组: {triple['triple']['subject']} -> {triple['triple']['predicate']} -> {triple['triple']['object']}")
                            print(f"    置信度: {triple['confidence']:.2f}")
                        else:
                            print(f"    提取失败: {triple.get('error', '未知错误')}")
                i += 1
    except Exception as e:
        print(f"读取示例结果时出错: {e}")
|
||
|
||
def main():
    """Program entry point: run the async pipeline and report failures."""
    try:
        asyncio.run(main_async())
    except KeyboardInterrupt:
        # Ctrl-C: exit quietly with a notice rather than a traceback.
        print("\n⚠️ 用户中断处理")
    except Exception as error:
        print(f"❌ 处理过程中出现错误: {error}")
        import traceback
        traceback.print_exc()
|
||
|
||
# Script entry point.
if __name__ == "__main__":
    main()