#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AIM medical-consultation workflow batch-processing system.

Processes every case sample in a dataset in parallel using a thread pool,
with optional AgentOps tracing for each workflow run.

NOTE(review): the accompanying config.py change hard-codes API_KEY and
AGENTOPS_API_KEY in version control. Secrets should be injected via
environment variables (see the --agentops-api-key default below) and the
committed keys rotated.
"""
import argparse
import json
import logging
import os
import sys
import time
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Dict, Any, List, Optional

# AgentOps integration.
# NOTE(review): the package imported is `agno`, but the calls below
# (init / start_session / start_span / record / end_span / end_session)
# follow the AgentOps SDK surface — confirm the intended package name.
try:
    import agno
except ImportError:
    print("警告:无法导入agno,AgentOps功能将被禁用")
    agno = None

# Local project modules.
from workflow import MedicalWorkflow
from config import AGENTOPS_API_KEY


class BatchProcessor:
    """Batch coordinator: thread-safe progress counters and result storage."""

    def __init__(self, num_threads: int = 20):
        self.num_threads = num_threads
        self.lock = threading.Lock()   # guards all counters/lists below
        self.processed_count = 0       # samples finished (success + failure)
        self.success_count = 0         # samples that completed successfully
        self.failed_count = 0          # samples that raised
        self.results = []              # per-sample result dicts (successes)
        self.failed_samples = []       # {'sample_index', 'error', 'timestamp'}
        self.start_time = None         # set by the caller when the batch begins

    def update_progress(self, success: bool, result: Dict[str, Any] = None,
                        error: Exception = None, sample_index: int = None):
        """Thread-safely record the outcome of one sample."""
        with self.lock:
            self.processed_count += 1
            if success:
                self.success_count += 1
                if result:
                    self.results.append(result)
            else:
                self.failed_count += 1
                if error and sample_index is not None:
                    self.failed_samples.append({
                        'sample_index': sample_index,
                        'error': str(error),
                        'timestamp': datetime.now().isoformat()
                    })

    def get_progress_stats(self) -> Dict[str, Any]:
        """Return a consistent snapshot of the current progress counters."""
        with self.lock:
            elapsed_time = time.time() - self.start_time if self.start_time else 0
            return {
                'processed': self.processed_count,
                'success': self.success_count,
                'failed': self.failed_count,
                # max(..., 1) avoids division by zero before any sample finishes
                'success_rate': self.success_count / max(self.processed_count, 1),
                'elapsed_time': elapsed_time,
                'samples_per_minute': self.processed_count / max(elapsed_time / 60, 0.01)
            }


def setup_logging(log_level: str = "INFO") -> None:
    """Configure root logging to both stderr and a timestamped log file."""
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(f'batch_processing_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
        ]
    )


def parse_arguments() -> argparse.Namespace:
    """Parse command-line arguments for the batch run."""
    parser = argparse.ArgumentParser(
        description="AIM医疗问诊工作流批处理系统",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Data and output locations
    parser.add_argument(
        '--dataset-path',
        type=str,
        default='dataset/update.json',
        help='数据集JSON文件路径'
    )
    parser.add_argument(
        '--log-dir',
        type=str,
        default='logs',
        help='日志文件保存目录'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='batch_results',
        help='批处理结果保存目录'
    )

    # Execution parameters
    parser.add_argument(
        '--num-threads',
        type=int,
        default=20,
        help='并行处理线程数'
    )
    parser.add_argument(
        '--max-steps',
        type=int,
        default=30,
        help='每个工作流的最大执行步数'
    )
    parser.add_argument(
        '--start-index',
        type=int,
        default=0,
        help='开始处理的样本索引'
    )
    parser.add_argument(
        '--end-index',
        type=int,
        default=None,
        help='结束处理的样本索引(不包含)'
    )
    parser.add_argument(
        '--sample-limit',
        type=int,
        default=None,
        help='限制处理的样本数量(用于测试)'
    )

    # Model configuration
    parser.add_argument(
        '--model-type',
        type=str,
        default='gpt-oss:latest',
        help='使用的语言模型类型'
    )
    parser.add_argument(
        '--model-config',
        type=str,
        default=None,
        help='模型配置JSON字符串'
    )

    # AgentOps configuration.
    # Environment variable takes precedence over the committed config value,
    # so deployments can avoid keeping the secret in source control.
    parser.add_argument(
        '--agentops-api-key',
        type=str,
        default=os.environ.get('AGENTOPS_API_KEY', AGENTOPS_API_KEY),
        help='AgentOps API密钥(默认从config.py读取)'
    )
    parser.add_argument(
        '--disable-agentops',
        action='store_true',
        help='禁用AgentOps追踪'
    )

    # Debugging and logging
    parser.add_argument(
        '--log-level',
        type=str,
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
        default='INFO',
        help='日志记录级别'
    )
    parser.add_argument(
        '--progress-interval',
        type=int,
        default=10,
        help='进度报告间隔(秒)'
    )
    parser.add_argument(
        '--dry-run',
        action='store_true',
        help='试运行模式,只验证配置不执行处理'
    )

    return parser.parse_args()


def load_dataset(dataset_path: str, start_index: int = 0,
                 end_index: Optional[int] = None,
                 sample_limit: Optional[int] = None) -> List[Dict[str, Any]]:
    """Load the dataset JSON array and return the requested slice.

    Raises FileNotFoundError / ValueError on a missing or malformed dataset.
    """
    logging.info(f"正在加载数据集: {dataset_path}")

    if not os.path.exists(dataset_path):
        raise FileNotFoundError(f"数据集文件不存在: {dataset_path}")

    try:
        with open(dataset_path, 'r', encoding='utf-8') as f:
            full_dataset = json.load(f)
    except json.JSONDecodeError as e:
        raise ValueError(f"数据集JSON格式错误: {e}")
    except Exception as e:
        raise Exception(f"加载数据集失败: {e}")

    if not isinstance(full_dataset, list):
        raise ValueError("数据集应该是包含病例的JSON数组")

    total_samples = len(full_dataset)
    logging.info(f"数据集总样本数: {total_samples}")

    # Clamp the requested range to the dataset bounds.
    if end_index is None:
        end_index = total_samples

    end_index = min(end_index, total_samples)
    start_index = max(0, start_index)

    if sample_limit:
        end_index = min(start_index + sample_limit, end_index)

    if start_index >= end_index:
        raise ValueError(f"无效的索引范围: start_index={start_index}, end_index={end_index}")

    # Extract the slice to process.
    dataset = full_dataset[start_index:end_index]

    logging.info(f"将处理样本范围: [{start_index}, {end_index}), 共 {len(dataset)} 个样本")

    # Spot-check only the first 5 samples for shape; a missing required
    # field is a warning, not an error.
    for i, sample in enumerate(dataset[:5]):
        if not isinstance(sample, dict):
            raise ValueError(f"样本 {start_index + i} 格式错误,应为字典类型")

        required_keys = ['病案介绍']
        for key in required_keys:
            if key not in sample:
                logging.warning(f"样本 {start_index + i} 缺少必需字段: {key}")

    return dataset


def initialize_agentops(api_key: str, disable: bool = False) -> Optional[Any]:
    """Start an AgentOps session; return None when disabled or unavailable."""
    if disable or not agno:
        logging.info("AgentOps已禁用")
        return None

    try:
        if not api_key:
            logging.warning("未提供AgentOps API密钥,AgentOps功能被禁用")
            return None

        # Initialize the AgentOps session.
        agno.init(api_key)
        session = agno.start_session(tags=['medical_workflow', 'batch_processing'])

        logging.info(f"AgentOps会话已启动: {session}")
        return session

    except Exception as e:
        # Tracing is best-effort; never abort the batch over it.
        logging.error(f"初始化AgentOps失败: {e}")
        return None


def process_single_sample(sample_data: Dict[str, Any], sample_index: int,
                          args: argparse.Namespace,
                          processor: BatchProcessor) -> Dict[str, Any]:
    """Worker function: run one MedicalWorkflow for one sample.

    Records progress on `processor` and returns a result dict either way;
    a failure is reported in the dict rather than raised, so the thread
    pool's future never carries an exception from this layer.
    """
    thread_id = threading.current_thread().ident
    start_time = time.time()

    # Open an AgentOps span for this sample (best-effort).
    span = None
    if agno:
        try:
            span = agno.start_span(
                name=f"process_sample_{sample_index}",
                span_type="workflow",
                tags={
                    'sample_index': sample_index,
                    'thread_id': str(thread_id),
                    'model_type': args.model_type
                }
            )
        except Exception as e:
            logging.warning(f"创建AgentOps span失败: {e}")

    try:
        # Parse the optional model-config JSON; fall back to defaults on error.
        llm_config = {}
        if args.model_config:
            try:
                llm_config = json.loads(args.model_config)
            except json.JSONDecodeError:
                logging.warning(f"样本 {sample_index}: 模型配置JSON格式错误,使用默认配置")

        # Build the workflow for this case.
        workflow = MedicalWorkflow(
            case_data=sample_data,
            model_type=args.model_type,
            llm_config=llm_config,
            max_steps=args.max_steps,
            log_dir=args.log_dir,
            case_index=sample_index
        )

        # Run the workflow to completion.
        logging.debug(f"线程 {thread_id}: 开始处理样本 {sample_index}")
        log_file_path = workflow.run()

        execution_time = time.time() - start_time

        # Collect execution results.
        workflow_status = workflow.get_current_status()
        medical_summary = workflow.get_medical_summary()

        result = {
            'sample_index': sample_index,
            'thread_id': thread_id,
            'execution_time': execution_time,
            'log_file_path': log_file_path,
            'workflow_status': workflow_status,
            'medical_summary': medical_summary,
            'processed_at': datetime.now().isoformat()
        }

        # Record the AgentOps completion event (best-effort).
        if span:
            try:
                agno.record(span, {
                    'action': 'workflow_completed',
                    'execution_time': execution_time,
                    'steps_completed': workflow_status['current_step'],
                    'workflow_success': workflow_status['workflow_success'],
                    'completion_rate': workflow_status['completion_summary']['overall_completion_rate']
                })
                agno.end_span(span, 'Success')
            except Exception as e:
                logging.warning(f"记录AgentOps事件失败: {e}")

        processor.update_progress(success=True, result=result)

        logging.info(f"样本 {sample_index} 处理完成 (耗时: {execution_time:.2f}s, "
                     f"步数: {workflow_status['current_step']}, "
                     f"成功: {workflow_status['workflow_success']})")

        return result

    except Exception as e:
        execution_time = time.time() - start_time
        error_msg = f"样本 {sample_index} 处理失败: {str(e)}"

        # Record the AgentOps failure event (best-effort).
        if span:
            try:
                agno.record(span, {
                    'action': 'workflow_failed',
                    'error': str(e),
                    'execution_time': execution_time
                })
                agno.end_span(span, 'Error')
            except Exception:
                pass

        logging.error(error_msg)
        processor.update_progress(success=False, error=e, sample_index=sample_index)

        # Return an error result instead of re-raising.
        return {
            'sample_index': sample_index,
            'thread_id': thread_id,
            'execution_time': execution_time,
            'error': str(e),
            'processed_at': datetime.now().isoformat(),
            'success': False
        }


def print_progress_report(processor: BatchProcessor, total_samples: int):
    """Print a human-readable progress report to stdout."""
    stats = processor.get_progress_stats()

    # max(..., 1) guards against division by zero on an empty batch
    # (main() normally filters that case out, but guard anyway).
    print(f"\n=== 处理进度报告 ===")
    print(f"已处理: {stats['processed']}/{total_samples} ({stats['processed']/max(total_samples, 1):.1%})")
    print(f"成功: {stats['success']} | 失败: {stats['failed']} | 成功率: {stats['success_rate']:.1%}")
    print(f"耗时: {stats['elapsed_time']:.1f}s | 处理速度: {stats['samples_per_minute']:.1f} 样本/分钟")
    print(f"预计剩余时间: {(total_samples - stats['processed']) / max(stats['samples_per_minute'] / 60, 0.01):.1f}s")
    print("=" * 50)


def run_workflow_batch(dataset: List[Dict[str, Any]], args: argparse.Namespace,
                       agentops_session: Optional[Any] = None) -> Dict[str, Any]:
    """Run the batch over the thread pool and return summary + results."""
    total_samples = len(dataset)
    logging.info(f"开始批量处理 {total_samples} 个样本,使用 {args.num_threads} 个线程")

    processor = BatchProcessor(num_threads=args.num_threads)
    processor.start_time = time.time()

    # Ensure output directories exist.
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    # Background daemon thread printing periodic progress reports.
    def progress_monitor():
        while processor.processed_count < total_samples:
            time.sleep(args.progress_interval)
            if processor.processed_count < total_samples:
                print_progress_report(processor, total_samples)

    progress_thread = threading.Thread(target=progress_monitor, daemon=True)
    progress_thread.start()

    with ThreadPoolExecutor(max_workers=args.num_threads) as executor:
        # Submit every sample; sample_index is the absolute dataset index.
        future_to_index = {}
        for i, sample_data in enumerate(dataset):
            sample_index = args.start_index + i
            future = executor.submit(
                process_single_sample,
                sample_data,
                sample_index,
                args,
                processor
            )
            future_to_index[future] = sample_index

        try:
            # Drain the futures as they complete. process_single_sample
            # handles its own exceptions, so a raise here is unexpected.
            for future in as_completed(future_to_index):
                sample_index = future_to_index[future]
                try:
                    _ = future.result()
                except Exception as e:
                    logging.error(f"线程执行异常 (样本 {sample_index}): {e}")
        except KeyboardInterrupt:
            logging.warning("收到中断信号,正在停止处理...")
            # BUG FIX: the interrupt must be handled *inside* the `with`
            # block — the executor's __exit__ waits for all queued futures,
            # so a handler placed outside would only run after every sample
            # finished. Cancel everything still pending, then shut down
            # without waiting so the re-raise reaches the caller promptly.
            for pending in future_to_index:
                pending.cancel()
            executor.shutdown(wait=False)
            raise

    # Final progress report and summary.
    total_time = time.time() - processor.start_time
    stats = processor.get_progress_stats()

    print_progress_report(processor, total_samples)

    summary = {
        'total_samples': total_samples,
        'processed_samples': processor.processed_count,
        'successful_samples': processor.success_count,
        'failed_samples': processor.failed_count,
        'success_rate': stats['success_rate'],
        'total_execution_time': total_time,
        'average_time_per_sample': total_time / max(processor.processed_count, 1),
        'samples_per_minute': stats['samples_per_minute'],
        'failed_sample_details': processor.failed_samples,
        'processing_config': {
            'num_threads': args.num_threads,
            'model_type': args.model_type,
            'max_steps': args.max_steps,
            'dataset_range': f"[{args.start_index}, {args.start_index + len(dataset)})"
        }
    }

    return {
        'summary': summary,
        'results': processor.results,
        'agentops_session': agentops_session
    }


def generate_summary_report(batch_results: Dict[str, Any],
                            output_path: str) -> None:
    """Write a detailed JSON report and a human-readable text summary."""
    summary = batch_results['summary']
    results = batch_results['results']

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Detailed machine-readable report.
    detailed_report = {
        'batch_execution_summary': summary,
        'sample_results': results,
        'generated_at': datetime.now().isoformat(),
        'report_version': '1.0'
    }

    report_file = os.path.join(output_path, f'batch_report_{timestamp}.json')

    try:
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(detailed_report, f, ensure_ascii=False, indent=2)

        logging.info(f"详细报告已保存: {report_file}")

        # Human-readable summary.
        summary_file = os.path.join(output_path, f'batch_summary_{timestamp}.txt')
        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write("AIM医疗问诊工作流批处理执行摘要\n")
            f.write("=" * 50 + "\n\n")

            f.write(f"执行时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"总样本数: {summary['total_samples']}\n")
            f.write(f"处理样本数: {summary['processed_samples']}\n")
            f.write(f"成功样本数: {summary['successful_samples']}\n")
            f.write(f"失败样本数: {summary['failed_samples']}\n")
            f.write(f"成功率: {summary['success_rate']:.2%}\n")
            f.write(f"总执行时间: {summary['total_execution_time']:.2f} 秒\n")
            f.write(f"平均处理时间: {summary['average_time_per_sample']:.2f} 秒/样本\n")
            f.write(f"处理速度: {summary['samples_per_minute']:.2f} 样本/分钟\n\n")

            f.write("处理配置:\n")
            for key, value in summary['processing_config'].items():
                f.write(f"  {key}: {value}\n")

            if summary['failed_samples'] > 0:
                f.write(f"\n失败样本详情:\n")
                for failed in summary['failed_sample_details']:
                    f.write(f"  样本 {failed['sample_index']}: {failed['error']}\n")

        logging.info(f"摘要报告已保存: {summary_file}")

    except Exception as e:
        # Report generation is non-fatal: the batch itself already ran.
        logging.error(f"生成报告失败: {e}")


def main():
    """CLI entry point; returns a process exit code (0 on success)."""
    args = parse_arguments()

    setup_logging(args.log_level)

    logging.info("=" * 60)
    logging.info("AIM医疗问诊工作流批处理系统启动")
    logging.info("=" * 60)

    try:
        # Validate numeric parameters early.
        if args.num_threads <= 0:
            raise ValueError("线程数必须大于0")

        if args.max_steps <= 0:
            raise ValueError("最大步数必须大于0")

        # Dry-run mode: validate configuration and the first few samples only.
        if args.dry_run:
            logging.info("试运行模式:验证配置...")
            dataset = load_dataset(
                args.dataset_path,
                args.start_index,
                args.end_index,
                min(args.sample_limit or 5, 5)  # dry run checks at most 5 samples
            )
            logging.info(f"配置验证成功,将处理 {len(dataset)} 个样本")
            return 0

        # Load the full requested slice.
        dataset = load_dataset(
            args.dataset_path,
            args.start_index,
            args.end_index,
            args.sample_limit
        )

        if len(dataset) == 0:
            logging.warning("没有样本需要处理")
            return 0

        # Initialize AgentOps tracing (may return None).
        agentops_session = initialize_agentops(
            args.agentops_api_key,
            args.disable_agentops
        )

        # Run the batch.
        logging.info("开始批量处理...")
        batch_results = run_workflow_batch(dataset, args, agentops_session)

        # Persist reports.
        generate_summary_report(batch_results, args.output_dir)

        # Close the AgentOps session (best-effort).
        if agentops_session and agno:
            try:
                agno.end_session('Success')
                logging.info("AgentOps会话已结束")
            except Exception as e:
                logging.error(f"关闭AgentOps会话失败: {e}")

        # Final statistics.
        summary = batch_results['summary']
        logging.info("=" * 60)
        logging.info("批处理执行完成!")
        logging.info(f"成功率: {summary['success_rate']:.2%} ({summary['successful_samples']}/{summary['total_samples']})")
        logging.info(f"总耗时: {summary['total_execution_time']:.2f} 秒")
        logging.info(f"处理速度: {summary['samples_per_minute']:.2f} 样本/分钟")
        logging.info("=" * 60)

        # Exit code reflects whether the batch was overwhelmingly successful.
        return 0 if summary['success_rate'] > 0.8 else 1

    except KeyboardInterrupt:
        logging.warning("程序被用户中断")
        return 1
    except Exception as e:
        logging.error(f"程序执行失败: {e}")
        return 1


if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)