#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AIM医疗问诊工作流批处理系统
使用多线程并行处理数据集中的所有病例样本
"""
import argparse
import copy
import glob
import json
import logging
import os
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Dict, Any, List, Optional

# Local modules
from workflow import MedicalWorkflow
from config import LLM_CONFIG
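
# LLM_CONFIG (from config.py) is assumed to map model names to entries shaped
# as below; this shape is inferred from how --list-models and
# process_single_sample read the config, not from config.py itself:
#   {"<model-name>": {"class": ..., "params": {"id": ..., "base_url": ..., ...}}}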


class BatchProcessor:
    """Batch manager: coordinates multithreaded execution and shared state."""

    def __init__(self, num_threads: int = 20):
        self.num_threads = num_threads
        self.lock = threading.Lock()   # guards all counters and result lists
        self.processed_count = 0       # samples processed (success + failed)
        self.success_count = 0         # samples processed successfully
        self.failed_count = 0          # samples that raised an error
        self.skipped_count = 0         # samples skipped (already completed)
        self.results = []              # per-sample result dicts
        self.failed_samples = []       # details of failed samples
        self.start_time = None         # batch start timestamp

def update_progress(self, success: bool, result: Dict[str, Any] = None,
error: Exception = None, sample_index: int = None):
"""线程安全地更新处理进度"""
with self.lock:
self.processed_count += 1
if success:
self.success_count += 1
if result:
self.results.append(result)
else:
self.failed_count += 1
if error and sample_index is not None:
self.failed_samples.append({
'sample_index': sample_index,
'error': str(error),
'timestamp': datetime.now().isoformat()
})

    def update_skipped(self, sample_index: int):
        """Thread-safely increment the skipped-sample counter."""
        with self.lock:
            self.skipped_count += 1
            logging.info(f"Sample {sample_index} already completed, skipping")

    def get_progress_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the current progress statistics."""
with self.lock:
elapsed_time = time.time() - self.start_time if self.start_time else 0
return {
'processed': self.processed_count,
'success': self.success_count,
'failed': self.failed_count,
'skipped': self.skipped_count,
'success_rate': self.success_count / max(self.processed_count, 1),
'elapsed_time': elapsed_time,
'samples_per_minute': self.processed_count / max(elapsed_time / 60, 0.01)
}
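

# Typical BatchProcessor usage, mirroring run_workflow_batch below (sketch,
# not executed here):
#   processor = BatchProcessor(num_threads=8)
#   processor.start_time = time.time()
#   # worker threads call processor.update_progress(...) / update_skipped(...)
#   stats = processor.get_progress_stats()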


def setup_logging(log_level: str = "INFO") -> None:
    """Configure logging to the console and a timestamped log file."""
logging.basicConfig(
level=getattr(logging, log_level.upper()),
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler(f'batch_processing_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
]
)
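
# With the format configured above, emitted lines look like this
# (timestamp illustrative):
#   2025-01-01 12:00:00,123 - INFO - Loading dataset: dataset/bbb.json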


def parse_arguments() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="AIM medical consultation workflow batch processor",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # Data and output configuration
    parser.add_argument(
        '--dataset-path',
        type=str,
        default='dataset/bbb.json',
        help='Path to the dataset JSON file'
    )
    parser.add_argument(
        '--log-dir',
        type=str,
        default='results/results09010-score_driven',
        help='Directory for per-case workflow log files'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='batch_results',
        help='Directory for batch result reports'
    )
    # Execution parameters
    parser.add_argument(
        '--num-threads',
        type=int,
        default=45,
        help='Number of parallel worker threads'
    )
    parser.add_argument(
        '--max-steps',
        type=int,
        default=30,
        help='Maximum number of execution steps per workflow'
    )
    parser.add_argument(
        '--start-index',
        type=int,
        default=0,
        help='Index of the first sample to process'
    )
    parser.add_argument(
        '--end-index',
        type=int,
        default=5000,
        help='Index at which to stop processing (exclusive)'
    )
    parser.add_argument(
        '--sample-limit',
        type=int,
        default=None,
        help='Cap on the number of samples to process (useful for testing)'
    )
    # Model configuration
    available_models = list(LLM_CONFIG.keys())
    parser.add_argument(
        '--model-type',
        type=str,
        choices=available_models,
        default='openai-mirror/gpt-oss-20b',
        help=f'Language model to use; one of: {", ".join(available_models)}'
    )
    parser.add_argument(
        '--list-models',
        action='store_true',
        help='List all available model configurations and exit'
    )
    parser.add_argument(
        '--model-config',
        type=str,
        default=None,
        help='JSON string of model-config overrides (optional; merged into the '
             'defaults for the selected model)'
    )
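    # Example override (illustrative; which keys are honoured under "params"
    # depends on the model class defined in config.LLM_CONFIG):
    #   --model-config '{"params": {"temperature": 0.2}}'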
    parser.add_argument(
        '--controller-mode',
        type=str,
        choices=['normal', 'sequence', 'score_driven'],
        default='score_driven',
        help='Task controller mode: "normal" picks the next task via LLM '
             'reasoning, "sequence" simply takes the first task in order, and '
             '"score_driven" picks the lowest-scoring task in the current '
             'task group'
    )
    # Debugging and logging
    parser.add_argument(
        '--log-level',
        type=str,
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
        default='INFO',
        help='Logging level'
    )
    parser.add_argument(
        '--progress-interval',
        type=int,
        default=10,
        help='Progress report interval in seconds'
    )
    parser.add_argument(
        '--dry-run',
        action='store_true',
        help='Dry run: validate the configuration without processing anything'
    )
    return parser.parse_args()


def is_case_completed(log_dir: str, case_index: int) -> bool:
    """
    Check whether the given case has already completed its workflow.

    Incomplete log files are deleted, ensuring each case appears at most
    once in the directory.

    Args:
        log_dir: directory containing workflow log files
        case_index: index of the case

    Returns:
        bool: True if the case has already completed, False otherwise.
    """
    # File name pattern: workflow_*_case_{case_index:04d}.jsonl
    pattern = os.path.join(log_dir, f"workflow_*_case_{case_index:04d}.jsonl")
    matching_files = glob.glob(pattern)
    if not matching_files:
        return False
    # There should be exactly one matching file
    if len(matching_files) > 1:
        logging.warning(f"Multiple files match case {case_index}: {matching_files}")
    # Inspect every matching file
    for log_file in matching_files:
        try:
            # Read all lines, closing the file before any deletion below
            with open(log_file, 'r', encoding='utf-8') as f:
                lines = f.readlines()
            if not lines:
                # Empty file: delete it
                os.remove(log_file)
                logging.info(f"Deleted empty file: {log_file}")
                continue
            last_line = lines[-1].strip()
            if not last_line:
                # Blank last line: delete the file
                os.remove(log_file)
                logging.info(f"Deleted file with blank last line: {log_file}")
                continue
            # Parse the last line as JSON
            try:
                last_entry = json.loads(last_line)
                if last_entry.get("event_type") == "workflow_complete":
                    # Found a complete log
                    logging.info(f"Case {case_index} already completed: {log_file}")
                    return True
                # Incomplete log: delete it
                os.remove(log_file)
                logging.info(f"Deleted incomplete file: {log_file}")
            except json.JSONDecodeError:
                # Malformed JSON: delete the file
                os.remove(log_file)
                logging.info(f"Deleted file with malformed JSON: {log_file}")
        except Exception as e:
            logging.warning(f"Error while checking file {log_file}: {e}")
            # Delete the file on unexpected errors too, to avoid trouble later
            try:
                os.remove(log_file)
                logging.info(f"Deleted problematic file: {log_file}")
            except OSError:
                pass
    # Every matching file was deleted, or none was complete
    return False


def load_dataset(dataset_path: str, start_index: int = 0,
                 end_index: Optional[int] = None,
                 sample_limit: Optional[int] = None) -> List[Dict[str, Any]]:
    """Load and validate the dataset."""
    logging.info(f"Loading dataset: {dataset_path}")
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(f"Dataset file does not exist: {dataset_path}")
    try:
        with open(dataset_path, 'r', encoding='utf-8') as f:
            full_dataset = json.load(f)
    except json.JSONDecodeError as e:
        raise ValueError(f"Dataset is not valid JSON: {e}")
    except Exception as e:
        raise Exception(f"Failed to load dataset: {e}")
    if not isinstance(full_dataset, list):
        raise ValueError("The dataset must be a JSON array of cases")
    total_samples = len(full_dataset)
    logging.info(f"Total samples in dataset: {total_samples}")
    # Determine the processing range
    if end_index is None:
        end_index = total_samples
    end_index = min(end_index, total_samples)
    start_index = max(0, start_index)
    if sample_limit:
        end_index = min(start_index + sample_limit, end_index)
    if start_index >= end_index:
        raise ValueError(f"Invalid index range: start_index={start_index}, end_index={end_index}")
    # Slice out the requested range
    dataset = full_dataset[start_index:end_index]
    logging.info(f"Processing sample range [{start_index}, {end_index}), {len(dataset)} samples in total")
    # Validate the data format (first 5 samples only)
    for i, sample in enumerate(dataset[:5]):
        if not isinstance(sample, dict):
            raise ValueError(f"Sample {start_index + i} is malformed; expected a dict")
        required_keys = ['病案介绍']
        for key in required_keys:
            if key not in sample:
                logging.warning(f"Sample {start_index + i} is missing required key: {key}")
    return dataset
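

# Minimal dataset shape accepted by load_dataset (values are placeholders;
# real cases may carry additional keys consumed by MedicalWorkflow):
#   [{"病案介绍": "..."}, {"病案介绍": "..."}]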


def process_single_sample(sample_data: Dict[str, Any], sample_index: int,
                          args: argparse.Namespace,
                          processor: BatchProcessor) -> Dict[str, Any]:
    """Worker function: process a single sample."""
    thread_id = threading.current_thread().ident
    start_time = time.time()
    try:
        # Start from LLM_CONFIG as the base configuration; BaseAgent selects
        # the right model entry from it based on model_type. Deep-copy so that
        # per-sample overrides below cannot mutate the shared global config.
        llm_config = copy.deepcopy(LLM_CONFIG)
        # Merge user-supplied overrides into the selected model's config
        if args.model_config:
            try:
                user_config = json.loads(args.model_config)
                if args.model_type in llm_config:
                    llm_config[args.model_type]["params"].update(user_config.get("params", {}))
                else:
                    logging.warning(f"Sample {sample_index}: model type {args.model_type} not found, ignoring user config")
            except json.JSONDecodeError:
                logging.warning(f"Sample {sample_index}: model config is not valid JSON, using defaults")
        # Create the workflow instance
        workflow = MedicalWorkflow(
            case_data=sample_data,
            model_type=args.model_type,
            llm_config=llm_config,
            max_steps=args.max_steps,
            log_dir=args.log_dir,
            case_index=sample_index,
            controller_mode=args.controller_mode
        )
        # Run the workflow
        logging.debug(f"Thread {thread_id}: processing sample {sample_index}")
        log_file_path = workflow.run()
execution_time = time.time() - start_time
        # Collect execution results
        workflow_status = workflow.get_current_status()
        medical_summary = workflow.get_medical_summary()
        # Build the result record
result = {
'sample_index': sample_index,
'thread_id': thread_id,
'execution_time': execution_time,
'log_file_path': log_file_path,
'workflow_status': workflow_status,
'medical_summary': medical_summary,
'processed_at': datetime.now().isoformat()
}
        # Update progress
        processor.update_progress(success=True, result=result)
        logging.info(f"Sample {sample_index} done (elapsed: {execution_time:.2f}s, "
                     f"steps: {workflow_status['current_step']}, "
                     f"success: {workflow_status['workflow_success']})")
        return result
    except Exception as e:
        execution_time = time.time() - start_time
        error_msg = f"Sample {sample_index} failed: {str(e)}"
        logging.error(error_msg)
        processor.update_progress(success=False, error=e, sample_index=sample_index)
        # Return an error record
        return {
            'sample_index': sample_index,
            'thread_id': thread_id,
            'execution_time': execution_time,
            'error': str(e),
            'processed_at': datetime.now().isoformat(),
            'success': False
        }


def print_progress_report(processor: BatchProcessor, total_samples: int):
    """Print a progress report to stdout."""
    stats = processor.get_progress_stats()
    print("\n=== Progress Report ===")
    print(f"Processed: {stats['processed']}/{total_samples} ({stats['processed']/total_samples:.1%})")
    print(f"Success: {stats['success']} | Failed: {stats['failed']} | Skipped: {stats['skipped']} | Success rate: {stats['success_rate']:.1%}")
    print(f"Elapsed: {stats['elapsed_time']:.1f}s | Throughput: {stats['samples_per_minute']:.1f} samples/min")
    remaining_samples = total_samples - stats['processed'] - stats['skipped']
    print(f"Estimated time remaining: {remaining_samples / max(stats['samples_per_minute'] / 60, 0.01):.1f}s")
    print("=" * 50)


def run_workflow_batch(dataset: List[Dict[str, Any]], args: argparse.Namespace) -> Dict[str, Any]:
    """Run the batch workflow over the dataset."""
    total_samples = len(dataset)
    logging.info(f"Batch-processing {total_samples} samples with {args.num_threads} threads")
    # Create the batch manager
    processor = BatchProcessor(num_threads=args.num_threads)
    processor.start_time = time.time()
    # Create output directories
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    # Progress-monitor thread: reads the counters without the lock, which is a
    # benign race since it only prints a periodic snapshot
    def progress_monitor():
        while processor.processed_count + processor.skipped_count < total_samples:
            time.sleep(args.progress_interval)
            if processor.processed_count + processor.skipped_count < total_samples:
                print_progress_report(processor, total_samples)

    progress_thread = threading.Thread(target=progress_monitor, daemon=True)
    progress_thread.start()
    # Run the batch on a thread pool
    with ThreadPoolExecutor(max_workers=args.num_threads) as executor:
        try:
            # Submit all tasks
            future_to_index = {}
            for i, sample_data in enumerate(dataset):
                sample_index = args.start_index + i
                # Skip cases that have already completed
                if is_case_completed(args.log_dir, sample_index):
                    processor.update_skipped(sample_index)
                    continue
                future = executor.submit(
                    process_single_sample,
                    sample_data,
                    sample_index,
                    args,
                    processor
                )
                future_to_index[future] = sample_index
            # Wait for all tasks to finish
            for future in as_completed(future_to_index):
                sample_index = future_to_index[future]
                try:
                    _ = future.result()  # results are recorded inside process_single_sample
                except Exception as e:
                    logging.error(f"Worker thread error (sample {sample_index}): {e}")
        except KeyboardInterrupt:
            logging.warning("Interrupt received, cancelling queued tasks...")
            # Cancel still-queued tasks so the executor's implicit shutdown on
            # exiting the with-block does not wait for them
            # (cancel_futures requires Python 3.9+)
            executor.shutdown(wait=False, cancel_futures=True)
            raise
    # Final progress report
    total_time = time.time() - processor.start_time
    stats = processor.get_progress_stats()
    print_progress_report(processor, total_samples)
    # Build the final summary
    summary = {
        'total_samples': total_samples,
        'processed_samples': processor.processed_count,
        'successful_samples': processor.success_count,
        'failed_samples': processor.failed_count,
        'skipped_samples': processor.skipped_count,
        'success_rate': stats['success_rate'],
        'total_execution_time': total_time,
        'average_time_per_sample': total_time / max(processor.processed_count, 1),
        'samples_per_minute': stats['samples_per_minute'],
        'failed_sample_details': processor.failed_samples,
        'processing_config': {
            'num_threads': args.num_threads,
            'model_type': args.model_type,
            'max_steps': args.max_steps,
            'dataset_range': f"[{args.start_index}, {args.start_index + len(dataset)})"
        }
    }
    return {
        'summary': summary,
        'results': processor.results
    }


def generate_summary_report(batch_results: Dict[str, Any],
                            output_path: str) -> None:
    """Generate detailed execution summary reports."""
    summary = batch_results['summary']
    results = batch_results['results']
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    # Detailed JSON report
detailed_report = {
'batch_execution_summary': summary,
'sample_results': results,
'generated_at': datetime.now().isoformat(),
'report_version': '1.0'
}
report_file = os.path.join(output_path, f'batch_report_{timestamp}.json')
try:
with open(report_file, 'w', encoding='utf-8') as f:
json.dump(detailed_report, f, ensure_ascii=False, indent=2)
logging.info(f"详细报告已保存: {report_file}")
# 生成人类可读的摘要
summary_file = os.path.join(output_path, f'batch_summary_{timestamp}.txt')
with open(summary_file, 'w', encoding='utf-8') as f:
f.write("AIM医疗问诊工作流批处理执行摘要\n")
f.write("=" * 50 + "\n\n")
f.write(f"执行时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"总样本数: {summary['total_samples']}\n")
f.write(f"处理样本数: {summary['processed_samples']}\n")
f.write(f"成功样本数: {summary['successful_samples']}\n")
f.write(f"失败样本数: {summary['failed_samples']}\n")
f.write(f"跳过样本数: {summary['skipped_samples']}\n")
f.write(f"成功率: {summary['success_rate']:.2%}\n")
f.write(f"总执行时间: {summary['total_execution_time']:.2f}\n")
f.write(f"平均处理时间: {summary['average_time_per_sample']:.2f} 秒/样本\n")
f.write(f"处理速度: {summary['samples_per_minute']:.2f} 样本/分钟\n\n")
f.write("处理配置:\n")
for key, value in summary['processing_config'].items():
f.write(f" {key}: {value}\n")
if summary['failed_samples'] > 0:
f.write(f"\n失败样本详情:\n")
for failed in summary['failed_sample_details']:
f.write(f" 样本 {failed['sample_index']}: {failed['error']}\n")
logging.info(f"摘要报告已保存: {summary_file}")
except Exception as e:
logging.error(f"生成报告失败: {e}")


def main():
    """Main entry point."""
    # Parse arguments
    args = parse_arguments()
    # Handle --list-models
    if args.list_models:
        print("Available language model configurations:")
        print("=" * 50)
        for model_name, config in LLM_CONFIG.items():
            print(f"Model name: {model_name}")
            print(f"  Class: {config['class']}")
            print(f"  Model ID: {config['params']['id']}")
            print(f"  API endpoint: {config['params']['base_url']}")
            print("-" * 30)
        return 0
    # Set up logging
    setup_logging(args.log_level)
    logging.info("=" * 60)
    logging.info("AIM medical consultation workflow batch processor starting")
    logging.info("=" * 60)
    try:
        # Validate arguments
        if args.num_threads <= 0:
            raise ValueError("Thread count must be greater than 0")
        if args.max_steps <= 0:
            raise ValueError("Max steps must be greater than 0")
        # Validate the model type
        if args.model_type not in LLM_CONFIG:
            available_models = ', '.join(LLM_CONFIG.keys())
            raise ValueError(f"Unsupported model type: {args.model_type}; available models: {available_models}")
        logging.info(f"Using model: {args.model_type} ({LLM_CONFIG[args.model_type]['class']})")
        # Dry-run mode
        if args.dry_run:
            logging.info("Dry run: validating configuration...")
            dataset = load_dataset(
                args.dataset_path,
                args.start_index,
                args.end_index,
                min(args.sample_limit or 5, 5)  # a dry run validates at most 5 samples
            )
            logging.info(f"Configuration OK; {len(dataset)} samples would be processed")
            return 0
        # Load the dataset
        dataset = load_dataset(
            args.dataset_path,
            args.start_index,
            args.end_index,
            args.sample_limit
        )
        if len(dataset) == 0:
            logging.warning("No samples to process")
            return 0
        # Run the batch
        logging.info("Starting batch processing...")
        batch_results = run_workflow_batch(dataset, args)
        # Generate reports
        generate_summary_report(batch_results, args.output_dir)
        # Final statistics
        summary = batch_results['summary']
        logging.info("=" * 60)
        logging.info("Batch processing complete!")
        logging.info(f"Success rate: {summary['success_rate']:.2%} ({summary['successful_samples']}/{summary['total_samples']})")
        logging.info(f"Total time: {summary['total_execution_time']:.2f} s")
        logging.info(f"Throughput: {summary['samples_per_minute']:.2f} samples/min")
        logging.info("=" * 60)
        return 0 if summary['success_rate'] > 0.8 else 1
    except KeyboardInterrupt:
        logging.warning("Interrupted by user")
        return 1
    except Exception as e:
        logging.error(f"Execution failed: {e}")
        return 1
if __name__ == "__main__":
exit_code = main()
sys.exit(exit_code)