#!/usr/bin/env python3
"""
Extract detailed information about cases whose triage result was wrong.

The script reads workflow logs (one ``*.jsonl`` file per case), compares the
final triager output with the expected departments stored in the case data,
and writes the mismatching cases as JSON files plus a CSV summary.
"""

import csv
import json
import os
from pathlib import Path
from typing import List, Dict, Optional, Tuple

from file_filter_utils import filter_complete_files, print_filter_summary


def load_workflow_data(data_dir: str, output_dir: str = "", limit: int = 5000) -> List[List[Dict]]:
    """Load workflow logs from every ``*.jsonl`` file in *data_dir*.

    Args:
        data_dir: Directory that is searched (non-recursively) for
            ``*.jsonl`` files; files are processed in sorted order.
        output_dir: When non-empty, the file list is first passed through
            ``filter_complete_files`` and a summary is printed with
            ``print_filter_summary``. When empty, no filtering happens.
        limit: Maximum number of files to read (applied after filtering).

    Returns:
        One list per successfully read, non-empty file. Each inner list holds
        the parsed JSON value of every line that was valid JSON, in file order.
    """
    workflow_data = []

    # Collect every jsonl file in a deterministic order.
    all_files = sorted(Path(data_dir).glob("*.jsonl"))

    # Keep only the files that filter_complete_files reports as complete.
    # The helper is called with string paths, so convert there and back.
    if output_dir:
        kept = filter_complete_files([str(f) for f in all_files], output_dir)
        filtered_files = [Path(f) for f in kept]
        print_filter_summary(output_dir)
    else:
        filtered_files = all_files

    # Cap the number of files that are actually read.
    for file_path in filtered_files[:limit]:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                workflow = []
                for line in f:
                    try:
                        workflow.append(json.loads(line.strip()))
                    except json.JSONDecodeError:
                        # Blank or malformed lines are skipped on purpose.
                        continue
            if workflow:
                workflow_data.append(workflow)
        except Exception as e:
            # Best-effort loader: report the broken file and keep going.
            print(f"加载文件 {file_path} 时出错: {e}")

    return workflow_data


def extract_triage_steps(workflow: List[Dict]) -> List[Dict]:
    """Return the steps of *workflow* produced by the triager agent.

    A step qualifies when it is a dict whose ``agent_name`` is ``'triager'``
    and which has an ``output_data`` key. Entries that are not dicts (a JSON
    line may hold any JSON value) are ignored instead of raising.
    """
    return [
        step
        for step in workflow
        if isinstance(step, dict)
        and step.get('agent_name') == 'triager'
        and 'output_data' in step
    ]


def _find_case_data(workflow: List[Dict]) -> Optional[Dict]:
    """Return ``case_data`` of the first ``workflow_start`` step that has one.

    Returns None when no such step exists or when its ``case_data`` is not a
    dict, so callers can skip the workflow instead of crashing.
    """
    for step in workflow:
        if (
            isinstance(step, dict)
            and step.get('event_type') == 'workflow_start'
            and 'case_data' in step
        ):
            case_data = step['case_data']
            return case_data if isinstance(case_data, dict) else None
    return None


def extract_error_cases(workflow_data: List[List[Dict]]) -> List[Dict]:
    """Collect the cases whose final triage result differs from the answer.

    For every workflow the last triager step is compared with the expected
    departments (keys ``一级科室`` and ``二级科室`` of the case data). Workflows
    without a triager step or without usable case data are skipped.

    Args:
        workflow_data: Workflows as returned by ``load_workflow_data``.

    Returns:
        One dict per mismatching case with the keys ``case_index``,
        ``case_id``, ``expected_level1``, ``expected_level2``,
        ``predicted_level1``, ``predicted_level2``, ``level1_correct``,
        ``level2_correct``, ``triage_reasoning`` and ``case_introduction``.
        ``case_index`` is the position of the workflow in *workflow_data*.
    """
    error_cases = []

    for index, workflow in enumerate(workflow_data):
        triage_steps = extract_triage_steps(workflow)
        if not triage_steps:
            continue

        # The expected answer lives in the case data of the start event.
        case_data = _find_case_data(workflow)
        if case_data is None:
            continue
        expected_level1 = case_data.get('一级科室')
        expected_level2 = case_data.get('二级科室')

        # Only the last triager step counts as the final decision.
        final_output = triage_steps[-1].get('output_data')
        if not isinstance(final_output, dict):
            # A null / malformed output is treated as "nothing predicted".
            final_output = {}
        predicted_level1 = final_output.get('primary_department')
        predicted_level2 = final_output.get('secondary_department')

        level1_correct = predicted_level1 == expected_level1
        level2_correct = predicted_level2 == expected_level2
        if level1_correct and level2_correct:
            continue

        error_cases.append({
            'case_index': index,
            'case_id': f"case_{index:04d}",
            'expected_level1': expected_level1,
            'expected_level2': expected_level2,
            'predicted_level1': predicted_level1,
            'predicted_level2': predicted_level2,
            'level1_correct': level1_correct,
            'level2_correct': level2_correct,
            'triage_reasoning': final_output.get('triage_reasoning', ''),
            # None when the case data carries no introduction.
            'case_introduction': case_data.get('病案介绍'),
        })

    return error_cases


def _summarize_reasoning(reasoning, max_chars: int = 100) -> str:
    """Return *reasoning* as text, shortened to *max_chars* plus ``...``.

    None becomes an empty string; the ellipsis is only added when text was
    actually cut off.
    """
    text = '' if reasoning is None else str(reasoning)
    if len(text) <= max_chars:
        return text
    return text[:max_chars] + '...'


def save_error_analysis(error_cases: List[Dict], output_dir: str):
    """Write the error analysis files into *output_dir*.

    Files written (all UTF-8):
        * ``error_cases_detailed.json`` - every error case.
        * ``level1_errors.json`` - cases whose level-1 department is wrong.
        * ``level2_errors.json`` - cases whose level-1 department is right
          but whose level-2 department is wrong.
        * ``error_cases_summary.csv`` - one row per error case; the triage
          reasoning is shortened to 100 characters.

    Args:
        error_cases: Case dicts as produced by ``extract_error_cases``.
        output_dir: Target directory; created if it does not exist. An empty
            string means the current working directory.
    """
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Split by error type; a case lands in at most one of the two buckets.
    level1_errors = [case for case in error_cases if not case['level1_correct']]
    level2_errors = [
        case for case in error_cases
        if case['level1_correct'] and not case['level2_correct']
    ]

    json_outputs = (
        ('error_cases_detailed.json', error_cases),
        ('level1_errors.json', level1_errors),
        ('level2_errors.json', level2_errors),
    )
    for filename, payload in json_outputs:
        with open(os.path.join(output_dir, filename), 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    # The csv module quotes fields containing commas, quotes or newlines,
    # which free-text reasoning and department names may well contain.
    header = [
        "病例索引", "病例ID", "期望一级科室", "预测一级科室", "一级是否正确",
        "期望二级科室", "预测二级科室", "二级是否正确", "分诊理由",
    ]
    summary_path = os.path.join(output_dir, 'error_cases_summary.csv')
    # newline='' lets the csv writer control line endings itself.
    with open(summary_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(header)
        for case in error_cases:
            # str() keeps the textual rendering of None/bool values that the
            # summary has always used (e.g. "None", "True").
            writer.writerow([
                str(case['case_index']),
                str(case['case_id']),
                str(case['expected_level1']),
                str(case['predicted_level1']),
                str(case['level1_correct']),
                str(case['expected_level2']),
                str(case['predicted_level2']),
                str(case['level2_correct']),
                _summarize_reasoning(case['triage_reasoning']),
            ])


def main():
    """Run the analysis: load workflows, extract error cases, save results.

    Command line: ``script.py <data_dir> <output_dir>``. When fewer than two
    arguments are given, default directories relative to the parent of this
    file's directory are used.
    """
    import sys

    # Take the paths from the command line, otherwise fall back to defaults.
    if len(sys.argv) >= 3:
        data_dir = Path(sys.argv[1])
        output_dir = Path(sys.argv[2])
    else:
        base_dir = Path(__file__).parent.parent
        data_dir = base_dir / "results" / "results0902"
        output_dir = base_dir / "analysis" / "0902"

    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"正在加载数据从: {data_dir}")
    workflow_data = load_workflow_data(str(data_dir), str(output_dir), limit=5000)
    print(f"成功加载 {len(workflow_data)} 个病例数据")

    print("正在提取错误病例...")
    error_cases = extract_error_cases(workflow_data)
    print(f"发现 {len(error_cases)} 个错误病例")

    # Count per error type.
    # NOTE(review): the level-2 count here includes cases whose level-1
    # department is also wrong, whereas level2_errors.json written by
    # save_error_analysis excludes them - confirm which one is intended.
    level1_errors = [case for case in error_cases if not case['level1_correct']]
    level2_errors = [case for case in error_cases if not case['level2_correct']]

    print(f"一级科室错误: {len(level1_errors)} 个")
    print(f"二级科室错误: {len(level2_errors)} 个")

    print("一级科室错误示例:")
    for case in level1_errors[:5]:
        print(f" 病例 {case['case_index']}: 期望={case['expected_level1']}, 预测={case['predicted_level1']}")

    print("二级科室错误示例:")
    for case in level2_errors[:5]:
        print(f" 病例 {case['case_index']}: 期望={case['expected_level2']}, 预测={case['predicted_level2']}")

    print("正在保存错误分析结果...")
    save_error_analysis(error_cases, str(output_dir))
    print(f"错误分析完成!结果已保存到: {output_dir}")


if __name__ == "__main__":
    main()