#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Medical workflow data analysis script.

Analyzes how many steps each case needs to complete the triage, hpi and ph
phases, prints a summary and saves the distributions as a 2x2 bar chart.

Usage:
    python <script> [RESULTS_DIR [OUTPUT_DIR]]
"""

import json
import os
import sys
from collections import Counter, defaultdict
from typing import Dict, List

import matplotlib.pyplot as plt

from file_filter_utils import filter_complete_files, print_filter_summary

# Phases whose distinct step counts are tracked per case.
_TRACKED_PHASES = ('triage', 'hpi', 'ph')


class MedicalWorkflowAnalyzer:
    """Analyzer for medical workflow step logs.

    Loads per-case ``.jsonl`` event logs from ``results_dir``, counts the
    distinct step numbers recorded in each phase, prints a summary and saves
    a 2x2 bar-chart figure into ``output_dir``.
    """

    def __init__(self, results_dir: str = "results", output_dir: str = "analysis/0902"):
        """
        Initialize the analyzer.

        Args:
            results_dir: Directory containing the input ``.jsonl`` result files.
            output_dir: Directory where figures are written; also passed to
                ``filter_complete_files`` / ``print_filter_summary``.
        """
        self.results_dir = results_dir
        self.output_dir = output_dir
        # One entry per loaded file: {'filename': str, 'data': [parsed JSON lines]}.
        self.workflow_data = []
        # NOTE(review): not read anywhere in this file; kept for compatibility.
        self.step_statistics = defaultdict(int)

    @staticmethod
    def _read_jsonl(filepath: str, filename: str) -> list:
        """Parse one JSON-lines file, reporting and skipping malformed lines.

        Raises:
            OSError: if the file cannot be opened or read.
            UnicodeDecodeError: if the file is not valid UTF-8.
        """
        records = []
        with open(filepath, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print(f"文件 {filename} 第{line_num}行解析失败: {e}")
        return records

    def load_workflow_data(self) -> None:
        """Load every completed ``.jsonl`` workflow file into ``self.workflow_data``.

        Files rejected by ``filter_complete_files`` are not read. Unreadable
        files are reported and skipped, and so are malformed lines. Files
        without any parseable line are ignored.
        """
        # isdir (not exists): listdir would crash on a plain file.
        if not os.path.isdir(self.results_dir):
            print(f"结果目录不存在: {self.results_dir}")
            return

        # Collect all jsonl files.
        all_files = [os.path.join(self.results_dir, f)
                     for f in os.listdir(self.results_dir)
                     if f.endswith('.jsonl')]

        # Keep only the files the project filter considers complete.
        filtered_files = filter_complete_files(all_files, self.output_dir)
        print_filter_summary(self.output_dir)

        print(f"找到 {len(all_files)} 个数据文件,将处理 {len(filtered_files)} 个完成的文件")

        for filepath in sorted(filtered_files):
            filename = os.path.basename(filepath)
            # Re-anchor under results_dir; this works whether the filter returns
            # full paths or bare file names.
            filepath = os.path.join(self.results_dir, filename)
            try:
                case_data = self._read_jsonl(filepath, filename)
            except (OSError, UnicodeDecodeError) as e:
                print(f"读取文件 {filename} 失败: {e}")
                continue
            if case_data:
                self.workflow_data.append({
                    'filename': filename,
                    'data': case_data
                })

        print(f"成功加载 {len(self.workflow_data)} 个病例的数据")

    def analyze_workflow_steps(self) -> Dict[str, List[int]]:
        """
        Count, for every loaded case, the distinct steps spent in each phase.

        Only ``step_start`` events that carry a ``current_phase`` key are
        considered. Phase names are compared case-insensitively.

        Returns:
            Dict with keys 'triage', 'hpi', 'ph' and 'final_step'.
            Each phase list holds, per case, the number of distinct step
            numbers seen in that phase. 'final_step' holds each case's
            highest step number. A case with no steps in a phase, or with a
            final step <= 0, contributes nothing to that list.
        """
        stage_steps = {
            'triage': [],
            'hpi': [],
            'ph': [],
            'final_step': []
        }

        for case_info in self.workflow_data:
            phase_steps = {phase: set() for phase in _TRACKED_PHASES}
            all_steps = set()

            for entry in case_info['data']:
                # Non-object JSON lines (lists, numbers, ...) have no .get().
                if not isinstance(entry, dict):
                    continue
                if entry.get('event_type') != 'step_start' or 'current_phase' not in entry:
                    continue
                step_num = entry.get('step_number')
                # Missing or null step numbers count as 0, so that None is never
                # mixed with ints in max().
                if step_num is None:
                    step_num = 0
                # `or ''` guards against an explicit null phase.
                phase = str(entry.get('current_phase') or '').lower()
                all_steps.add(step_num)
                if phase in phase_steps:
                    phase_steps[phase].add(step_num)

            # Record only the phases that actually have data.
            for phase, steps in phase_steps.items():
                if steps:
                    stage_steps[phase].append(len(steps))

            final_step = max(all_steps) if all_steps else 0
            if final_step > 0:
                stage_steps['final_step'].append(final_step)

        print(f"成功分析 {len(self.workflow_data)} 个病例")
        return stage_steps

    def generate_stage_statistics(self, stage_steps: Dict[str, List[int]]) -> Dict[str, Dict[int, int]]:
        """
        Build a step-count histogram for every stage.

        Args:
            stage_steps: Per-stage lists of step counts, as returned by
                ``analyze_workflow_steps``.

        Returns:
            Dict mapping each non-empty stage to {step_count: number_of_cases}.
            Keys appear in order of first occurrence. Empty stages are omitted.
        """
        return {stage: dict(Counter(steps))
                for stage, steps in stage_steps.items()
                if steps}

    @staticmethod
    def _draw_stage_bars(ax, title: str, stats: Dict[int, int], color: str) -> None:
        """Draw one stage's step-count histogram, bar labels and a stats box on ``ax``."""
        steps = sorted(stats)
        counts = [stats[step] for step in steps]
        total = sum(counts)

        bars = ax.bar(steps, counts, color=color, alpha=0.7,
                      edgecolor='black', linewidth=0.5)

        # Label each bar with its case count, lifted by 1% of the tallest bar.
        label_offset = max(counts) * 0.01
        for bar, count in zip(bars, counts):
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + label_offset,
                    f'{count}', ha='center', va='bottom', fontsize=9, fontweight='bold')

        ax.set_title(f'{title}\n(n={total})', fontsize=12, fontweight='bold')
        ax.set_xlabel('Number of Steps', fontsize=10)
        ax.set_ylabel('Number of Cases', fontsize=10)
        ax.grid(True, alpha=0.3, linestyle='--')
        ax.set_xticks(steps)
        ax.set_xticklabels(steps, rotation=45)

        # Weighted mean over the histogram (each step value counted once per
        # case); skipped when every count is zero, to avoid dividing by zero.
        if total:
            mean_val = sum(s * c for s, c in zip(steps, counts)) / total
            stats_text = f'Mean: {mean_val:.1f}\nRange: {steps[0]}-{steps[-1]}'
            ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, fontsize=9,
                    verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    def plot_step_distribution_subplots(self, stage_stats: Dict[str, Dict[int, int]],
                                        output_file: str = "step_distribution_subplots.png") -> None:
        """
        Draw a 2x2 grid of step-count bar charts and save it as an image.

        Args:
            stage_stats: Per-stage histograms from ``generate_stage_statistics``.
            output_file: Image file name, written inside ``self.output_dir``
                (created if missing).
        """
        if not stage_stats:
            print("没有数据可供绘制")
            return

        subplot_titles = {
            'triage': 'TRIAGE Phase',
            'hpi': 'HPI Phase',
            'ph': 'PH Phase',
            'final_step': 'Total Steps'
        }
        stages_order = ['triage', 'hpi', 'ph', 'final_step']
        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4']

        # Scope the English font settings to this figure instead of
        # permanently mutating the global rcParams.
        with plt.rc_context({'font.family': 'DejaVu Sans', 'axes.unicode_minus': False}):
            fig, axes = plt.subplots(2, 2, figsize=(16, 12))
            try:
                fig.suptitle('Medical Workflow Step Distribution Analysis',
                             fontsize=16, fontweight='bold')

                # axes.flat is row-major: (0,0), (0,1), (1,0), (1,1).
                for idx, (stage, ax) in enumerate(zip(stages_order, axes.flat)):
                    stats = stage_stats.get(stage)
                    if stats:
                        self._draw_stage_bars(ax, subplot_titles[stage], stats,
                                              colors[idx % len(colors)])
                    else:
                        ax.text(0.5, 0.5, 'No Data Available', ha='center', va='center',
                                transform=ax.transAxes, fontsize=12)
                        ax.set_title(f'{subplot_titles[stage]}\n(n=0)',
                                     fontsize=12, fontweight='bold')

                fig.tight_layout()

                os.makedirs(self.output_dir, exist_ok=True)
                output_path = os.path.join(self.output_dir, output_file)
                fig.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
            finally:
                # Always release the figure, even if drawing or saving fails.
                plt.close(fig)

        print(f"Four-subplot chart saved to: {output_path}")

    def print_statistics_summary(self, stage_steps: Dict[str, List[int]]) -> None:
        """Print count, mean, min, max and full distribution for each stage."""
        print("\n=== Medical Workflow Step Statistics Summary ===")

        # English display names for the stages.
        stage_names = {
            'triage': 'TRIAGE Phase',
            'hpi': 'HPI Phase',
            'ph': 'PH Phase',
            'final_step': 'Total Steps'
        }

        for stage, steps in stage_steps.items():
            stage_name = stage_names.get(stage, stage.upper())
            if not steps:
                print(f"\n{stage_name}: No Data")
                continue
            distribution = dict(sorted(Counter(steps).items()))
            print(f"\n{stage_name}:")
            print(f" Total Cases: {len(steps)}")
            print(f" Mean Steps: {sum(steps)/len(steps):.2f}")
            print(f" Min Steps: {min(steps)}")
            print(f" Max Steps: {max(steps)}")
            print(f" Step Distribution: {distribution}")

    def run_analysis(self) -> None:
        """Run the full pipeline: load, analyze, summarize and plot."""
        print("Starting medical workflow data analysis...")

        # 1. Load data
        self.load_workflow_data()
        if not self.workflow_data:
            print("No data available for analysis")
            return

        # 2. Analyze step counts
        stage_steps = self.analyze_workflow_steps()

        # 3. Generate stage statistics
        stage_stats = self.generate_stage_statistics(stage_steps)

        # 4. Print summary
        self.print_statistics_summary(stage_steps)

        # 5. Generate subplots
        self.plot_step_distribution_subplots(stage_stats)

        print("Data analysis completed successfully!")


def main():
    """Command-line entry point: ``script [RESULTS_DIR [OUTPUT_DIR]]``.

    Each argument is optional on its own. A single argument now sets the
    results directory instead of being silently ignored.
    """
    results_dir = sys.argv[1] if len(sys.argv) >= 2 else "results/results0902"
    output_dir = sys.argv[2] if len(sys.argv) >= 3 else "analysis/0902"

    analyzer = MedicalWorkflowAnalyzer(results_dir=results_dir, output_dir=output_dir)
    analyzer.run_analysis()


if __name__ == "__main__":
    main()