triage/main.py
iomgaa 4d7a7b1ba4 Implement the main.py batch-processing system: multi-threaded parallel execution of the medical-consultation workflow
Key features:
- Full argparse argument system covering thread count, model type, data range, and other settings
- AgentOps tracing integration with session management and performance monitoring
- Thread-safe BatchProcessor manager for concurrent execution and progress statistics
- Dataset loading and validation with range selection and format checks
- Multi-threaded execution framework built on ThreadPoolExecutor
- Single-sample processing function that invokes MedicalWorkflow with tracing attached
- Real-time progress monitoring via a background thread that reports status periodically
- Comprehensive error handling and exception recovery
- Result aggregation and report generation in JSON and plain-text formats
- Unified configuration management, with the AgentOps API key integrated into config.py

Technical characteristics:
- Supports 20 concurrent threads processing 1,677 medical case samples
- Thread-safe progress tracking and state management
- Detailed logging and debug output
- Dry-run mode for configuration validation
- Graceful interrupt handling and resource cleanup

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-11 21:04:11 +08:00
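
Example invocations (illustrative only; the flags and defaults shown are the ones defined in parse_arguments() in this file):

    python triage/main.py --dry-run --sample-limit 5
    python triage/main.py --dataset-path dataset/update.json --num-threads 20 --model-type gpt-oss:latest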

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AIM medical-consultation workflow batch-processing system.
Processes every case sample in the dataset in parallel using multiple threads,
with AgentOps tracing integrated.
"""
import argparse
import json
import logging
import os
import sys
import time
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Dict, Any, List, Optional

# AgentOps integration
try:
    import agno
except ImportError:
    print("Warning: failed to import agno; AgentOps features will be disabled")
    agno = None

# Local modules
from workflow import MedicalWorkflow
from config import AGENTOPS_API_KEY


class BatchProcessor:
    """Batch-processing manager: coordinates multi-threaded execution and shared state."""

    def __init__(self, num_threads: int = 20):
        self.num_threads = num_threads
        self.lock = threading.Lock()   # thread-safety lock
        self.processed_count = 0       # samples processed so far
        self.success_count = 0         # successfully processed samples
        self.failed_count = 0          # failed samples
        self.results = []              # collected results
        self.failed_samples = []       # details of failed samples
        self.start_time = None         # batch start time

    def update_progress(self, success: bool, result: Dict[str, Any] = None,
                        error: Exception = None, sample_index: int = None):
        """Update processing progress in a thread-safe way."""
        with self.lock:
            self.processed_count += 1
            if success:
                self.success_count += 1
                if result:
                    self.results.append(result)
            else:
                self.failed_count += 1
                if error and sample_index is not None:
                    self.failed_samples.append({
                        'sample_index': sample_index,
                        'error': str(error),
                        'timestamp': datetime.now().isoformat()
                    })

    def get_progress_stats(self) -> Dict[str, Any]:
        """Return the current progress statistics."""
        with self.lock:
            elapsed_time = time.time() - self.start_time if self.start_time else 0
            return {
                'processed': self.processed_count,
                'success': self.success_count,
                'failed': self.failed_count,
                'success_rate': self.success_count / max(self.processed_count, 1),
                'elapsed_time': elapsed_time,
                'samples_per_minute': self.processed_count / max(elapsed_time / 60, 0.01)
            }
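

# Minimal usage sketch for BatchProcessor (illustrative only; the real call sites are the
# worker function and the progress monitor further down in this file):
#   processor = BatchProcessor(num_threads=4)
#   processor.start_time = time.time()
#   processor.update_progress(success=True, result={'sample_index': 0})  # from a worker thread
#   stats = processor.get_progress_stats()                               # from the monitor thread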


def setup_logging(log_level: str = "INFO") -> None:
    """Configure logging for the batch run."""
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(f'batch_processing_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
        ]
    )


def parse_arguments() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="AIM medical-consultation workflow batch-processing system",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Data and output configuration
    parser.add_argument(
        '--dataset-path',
        type=str,
        default='dataset/update.json',
        help='Path to the dataset JSON file'
    )
    parser.add_argument(
        '--log-dir',
        type=str,
        default='logs',
        help='Directory for log files'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='batch_results',
        help='Directory for batch-processing results'
    )

    # Execution parameters
    parser.add_argument(
        '--num-threads',
        type=int,
        default=20,
        help='Number of parallel worker threads'
    )
    parser.add_argument(
        '--max-steps',
        type=int,
        default=30,
        help='Maximum number of steps per workflow'
    )
    parser.add_argument(
        '--start-index',
        type=int,
        default=0,
        help='Index of the first sample to process'
    )
    parser.add_argument(
        '--end-index',
        type=int,
        default=None,
        help='Index at which to stop processing (exclusive)'
    )
    parser.add_argument(
        '--sample-limit',
        type=int,
        default=None,
        help='Limit on the number of samples to process (useful for testing)'
    )

    # Model configuration
    parser.add_argument(
        '--model-type',
        type=str,
        default='gpt-oss:latest',
        help='Language model type to use'
    )
    parser.add_argument(
        '--model-config',
        type=str,
        default=None,
        help='Model configuration as a JSON string'
    )

    # AgentOps configuration
    parser.add_argument(
        '--agentops-api-key',
        type=str,
        default=AGENTOPS_API_KEY,
        help='AgentOps API key (read from config.py by default)'
    )
    parser.add_argument(
        '--disable-agentops',
        action='store_true',
        help='Disable AgentOps tracing'
    )

    # Debugging and logging
    parser.add_argument(
        '--log-level',
        type=str,
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
        default='INFO',
        help='Logging level'
    )
    parser.add_argument(
        '--progress-interval',
        type=int,
        default=10,
        help='Interval between progress reports (seconds)'
    )
    parser.add_argument(
        '--dry-run',
        action='store_true',
        help='Dry-run mode: validate the configuration without processing'
    )
    return parser.parse_args()


def load_dataset(dataset_path: str, start_index: int = 0,
                 end_index: Optional[int] = None,
                 sample_limit: Optional[int] = None) -> List[Dict[str, Any]]:
    """Load and validate the dataset."""
    logging.info(f"Loading dataset: {dataset_path}")
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(f"Dataset file not found: {dataset_path}")

    try:
        with open(dataset_path, 'r', encoding='utf-8') as f:
            full_dataset = json.load(f)
    except json.JSONDecodeError as e:
        raise ValueError(f"Dataset is not valid JSON: {e}")
    except Exception as e:
        raise Exception(f"Failed to load dataset: {e}")

    if not isinstance(full_dataset, list):
        raise ValueError("The dataset should be a JSON array of case records")

    total_samples = len(full_dataset)
    logging.info(f"Total samples in dataset: {total_samples}")

    # Determine the processing range
    if end_index is None:
        end_index = total_samples
    end_index = min(end_index, total_samples)
    start_index = max(0, start_index)
    if sample_limit:
        end_index = min(start_index + sample_limit, end_index)
    if start_index >= end_index:
        raise ValueError(f"Invalid index range: start_index={start_index}, end_index={end_index}")

    # Slice out the requested range
    dataset = full_dataset[start_index:end_index]
    logging.info(f"Processing sample range [{start_index}, {end_index}), {len(dataset)} samples in total")

    # Validate the data format
    for i, sample in enumerate(dataset[:5]):  # only check the first 5 samples
        if not isinstance(sample, dict):
            raise ValueError(f"Sample {start_index + i} is malformed; a dict was expected")
        required_keys = ['病案介绍']
        for key in required_keys:
            if key not in sample:
                logging.warning(f"Sample {start_index + i} is missing required field: {key}")
    return dataset
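

# Expected dataset shape, inferred from the validation above (other fields may be present):
# [
#   {"病案介绍": {...}, ...},   # one dict per medical case
#   ...
# ]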


def initialize_agentops(api_key: str, disable: bool = False) -> Optional[Any]:
    """Initialize the AgentOps session."""
    if disable or not agno:
        logging.info("AgentOps is disabled")
        return None
    try:
        if not api_key:
            logging.warning("No AgentOps API key provided; AgentOps features are disabled")
            return None
        # Initialize the AgentOps session
        agno.init(api_key)
        session = agno.start_session(tags=['medical_workflow', 'batch_processing'])
        logging.info(f"AgentOps session started: {session}")
        return session
    except Exception as e:
        logging.error(f"Failed to initialize AgentOps: {e}")
        return None


def process_single_sample(sample_data: Dict[str, Any], sample_index: int,
                          args: argparse.Namespace,
                          processor: BatchProcessor) -> Dict[str, Any]:
    """Worker function that processes a single sample."""
    thread_id = threading.current_thread().ident
    start_time = time.time()

    # Create an AgentOps span for the current sample
    span = None
    if agno:
        try:
            span = agno.start_span(
                name=f"process_sample_{sample_index}",
                span_type="workflow",
                tags={
                    'sample_index': sample_index,
                    'thread_id': str(thread_id),
                    'model_type': args.model_type
                }
            )
        except Exception as e:
            logging.warning(f"Failed to create AgentOps span: {e}")

    try:
        # Parse the model configuration
        llm_config = {}
        if args.model_config:
            try:
                llm_config = json.loads(args.model_config)
            except json.JSONDecodeError:
                logging.warning(f"Sample {sample_index}: model config is not valid JSON; using the default configuration")

        # Create the workflow instance
        workflow = MedicalWorkflow(
            case_data=sample_data,
            model_type=args.model_type,
            llm_config=llm_config,
            max_steps=args.max_steps,
            log_dir=args.log_dir,
            case_index=sample_index
        )

        # Run the workflow
        logging.debug(f"Thread {thread_id}: processing sample {sample_index}")
        log_file_path = workflow.run()
        execution_time = time.time() - start_time

        # Collect the execution results
        workflow_status = workflow.get_current_status()
        medical_summary = workflow.get_medical_summary()

        # Build the result record
        result = {
            'sample_index': sample_index,
            'thread_id': thread_id,
            'execution_time': execution_time,
            'log_file_path': log_file_path,
            'workflow_status': workflow_status,
            'medical_summary': medical_summary,
            'processed_at': datetime.now().isoformat()
        }

        # Record the AgentOps event
        if span:
            try:
                agno.record(span, {
                    'action': 'workflow_completed',
                    'execution_time': execution_time,
                    'steps_completed': workflow_status['current_step'],
                    'workflow_success': workflow_status['workflow_success'],
                    'completion_rate': workflow_status['completion_summary']['overall_completion_rate']
                })
                agno.end_span(span, 'Success')
            except Exception as e:
                logging.warning(f"Failed to record AgentOps event: {e}")

        # Update progress
        processor.update_progress(success=True, result=result)
        logging.info(f"Sample {sample_index} processed (elapsed: {execution_time:.2f}s, "
                     f"steps: {workflow_status['current_step']}, "
                     f"success: {workflow_status['workflow_success']})")
        return result

    except Exception as e:
        execution_time = time.time() - start_time
        error_msg = f"Sample {sample_index} failed: {str(e)}"

        # Record the AgentOps error
        if span:
            try:
                agno.record(span, {
                    'action': 'workflow_failed',
                    'error': str(e),
                    'execution_time': execution_time
                })
                agno.end_span(span, 'Error')
            except Exception:
                pass

        logging.error(error_msg)
        processor.update_progress(success=False, error=e, sample_index=sample_index)

        # Return an error result
        return {
            'sample_index': sample_index,
            'thread_id': thread_id,
            'execution_time': execution_time,
            'error': str(e),
            'processed_at': datetime.now().isoformat(),
            'success': False
        }
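

# Note on result shapes: successful samples are returned without a 'success' key and are
# identified by the absence of an 'error' field, while failures carry both 'error' and
# 'success': False. Downstream consumers of the report should therefore check for 'error'.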


def print_progress_report(processor: BatchProcessor, total_samples: int):
    """Print a progress report."""
    stats = processor.get_progress_stats()
    print("\n=== Progress report ===")
    print(f"Processed: {stats['processed']}/{total_samples} ({stats['processed']/total_samples:.1%})")
    print(f"Succeeded: {stats['success']} | Failed: {stats['failed']} | Success rate: {stats['success_rate']:.1%}")
    print(f"Elapsed: {stats['elapsed_time']:.1f}s | Throughput: {stats['samples_per_minute']:.1f} samples/min")
    print(f"Estimated time remaining: {(total_samples - stats['processed']) / max(stats['samples_per_minute'] / 60, 0.01):.1f}s")
    print("=" * 50)


def run_workflow_batch(dataset: List[Dict[str, Any]], args: argparse.Namespace,
                       agentops_session: Optional[Any] = None) -> Dict[str, Any]:
    """Run the batch workflow processing."""
    total_samples = len(dataset)
    logging.info(f"Starting batch processing of {total_samples} samples with {args.num_threads} threads")

    # Create the batch-processing manager
    processor = BatchProcessor(num_threads=args.num_threads)
    processor.start_time = time.time()

    # Create output directories
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    # Start the progress-monitoring thread
    def progress_monitor():
        while processor.processed_count < total_samples:
            time.sleep(args.progress_interval)
            if processor.processed_count < total_samples:
                print_progress_report(processor, total_samples)

    progress_thread = threading.Thread(target=progress_monitor, daemon=True)
    progress_thread.start()

    try:
        # Run the batch through a thread pool
        with ThreadPoolExecutor(max_workers=args.num_threads) as executor:
            # Submit all tasks
            future_to_index = {}
            for i, sample_data in enumerate(dataset):
                sample_index = args.start_index + i
                future = executor.submit(
                    process_single_sample,
                    sample_data,
                    sample_index,
                    args,
                    processor
                )
                future_to_index[future] = sample_index

            # Wait for all tasks to complete
            for future in as_completed(future_to_index):
                sample_index = future_to_index[future]
                try:
                    _ = future.result()  # results are already handled inside process_single_sample
                except Exception as e:
                    logging.error(f"Thread execution error (sample {sample_index}): {e}")
    except KeyboardInterrupt:
        logging.warning("Interrupt received; stopping processing...")
        executor.shutdown(wait=False)
        raise

    # Final progress report
    total_time = time.time() - processor.start_time
    stats = processor.get_progress_stats()
    print_progress_report(processor, total_samples)

    # Build the final summary
    summary = {
        'total_samples': total_samples,
        'processed_samples': processor.processed_count,
        'successful_samples': processor.success_count,
        'failed_samples': processor.failed_count,
        'success_rate': stats['success_rate'],
        'total_execution_time': total_time,
        'average_time_per_sample': total_time / max(processor.processed_count, 1),
        'samples_per_minute': stats['samples_per_minute'],
        'failed_sample_details': processor.failed_samples,
        'processing_config': {
            'num_threads': args.num_threads,
            'model_type': args.model_type,
            'max_steps': args.max_steps,
            'dataset_range': f"[{args.start_index}, {args.start_index + len(dataset)})"
        }
    }
    return {
        'summary': summary,
        'results': processor.results,
        'agentops_session': agentops_session
    }


def generate_summary_report(batch_results: Dict[str, Any],
                            output_path: str) -> None:
    """Generate a detailed execution summary report."""
    summary = batch_results['summary']
    results = batch_results['results']
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Detailed report in JSON format
    detailed_report = {
        'batch_execution_summary': summary,
        'sample_results': results,
        'generated_at': datetime.now().isoformat(),
        'report_version': '1.0'
    }
    report_file = os.path.join(output_path, f'batch_report_{timestamp}.json')
    try:
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(detailed_report, f, ensure_ascii=False, indent=2)
        logging.info(f"Detailed report saved: {report_file}")

        # Human-readable summary
        summary_file = os.path.join(output_path, f'batch_summary_{timestamp}.txt')
        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write("AIM medical-consultation workflow batch-processing summary\n")
            f.write("=" * 50 + "\n\n")
            f.write(f"Run time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Total samples: {summary['total_samples']}\n")
            f.write(f"Processed samples: {summary['processed_samples']}\n")
            f.write(f"Successful samples: {summary['successful_samples']}\n")
            f.write(f"Failed samples: {summary['failed_samples']}\n")
            f.write(f"Success rate: {summary['success_rate']:.2%}\n")
            f.write(f"Total execution time: {summary['total_execution_time']:.2f} s\n")
            f.write(f"Average processing time: {summary['average_time_per_sample']:.2f} s/sample\n")
            f.write(f"Throughput: {summary['samples_per_minute']:.2f} samples/min\n\n")
            f.write("Processing configuration:\n")
            for key, value in summary['processing_config'].items():
                f.write(f"  {key}: {value}\n")
            if summary['failed_samples'] > 0:
                f.write("\nFailed sample details:\n")
                for failed in summary['failed_sample_details']:
                    f.write(f"  Sample {failed['sample_index']}: {failed['error']}\n")
        logging.info(f"Summary report saved: {summary_file}")
    except Exception as e:
        logging.error(f"Failed to generate report: {e}")
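

# Output artifacts written into --output-dir by the function above:
#   batch_report_<timestamp>.json   - full machine-readable report (summary + per-sample results)
#   batch_summary_<timestamp>.txt   - human-readable summary of the same run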


def main():
    """Main entry point."""
    # Parse arguments
    args = parse_arguments()

    # Set up logging
    setup_logging(args.log_level)
    logging.info("=" * 60)
    logging.info("AIM medical-consultation workflow batch-processing system starting")
    logging.info("=" * 60)

    try:
        # Validate arguments
        if args.num_threads <= 0:
            raise ValueError("The number of threads must be greater than 0")
        if args.max_steps <= 0:
            raise ValueError("The maximum number of steps must be greater than 0")

        # Dry-run mode
        if args.dry_run:
            logging.info("Dry-run mode: validating configuration...")
            dataset = load_dataset(
                args.dataset_path,
                args.start_index,
                args.end_index,
                min(args.sample_limit or 5, 5)  # dry runs only validate the first 5 samples
            )
            logging.info(f"Configuration is valid; {len(dataset)} samples would be processed")
            return 0

        # Load the dataset
        dataset = load_dataset(
            args.dataset_path,
            args.start_index,
            args.end_index,
            args.sample_limit
        )
        if len(dataset) == 0:
            logging.warning("No samples to process")
            return 0

        # Initialize AgentOps
        agentops_session = initialize_agentops(
            args.agentops_api_key,
            args.disable_agentops
        )

        # Run the batch
        logging.info("Starting batch processing...")
        batch_results = run_workflow_batch(dataset, args, agentops_session)

        # Generate reports
        generate_summary_report(batch_results, args.output_dir)

        # Close the AgentOps session
        if agentops_session and agno:
            try:
                agno.end_session('Success')
                logging.info("AgentOps session closed")
            except Exception as e:
                logging.error(f"Failed to close AgentOps session: {e}")

        # Final statistics
        summary = batch_results['summary']
        logging.info("=" * 60)
        logging.info("Batch processing complete!")
        logging.info(f"Success rate: {summary['success_rate']:.2%} ({summary['successful_samples']}/{summary['total_samples']})")
        logging.info(f"Total time: {summary['total_execution_time']:.2f} s")
        logging.info(f"Throughput: {summary['samples_per_minute']:.2f} samples/min")
        logging.info("=" * 60)
        return 0 if summary['success_rate'] > 0.8 else 1

    except KeyboardInterrupt:
        logging.warning("Interrupted by user")
        return 1
    except Exception as e:
        logging.error(f"Program execution failed: {e}")
        return 1


if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)