211 lines
7.7 KiB
Python
211 lines
7.7 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
提取分诊错误的病例详细信息
|
||
|
|
"""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import List, Dict, Tuple
|
||
|
|
from file_filter_utils import filter_complete_files, print_filter_summary
|
||
|
|
|
||
|
|
|
||
|
|
def load_workflow_data(data_dir: str, output_dir: str = "", limit: int = 5000) -> List[List[Dict]]:
    """Load workflow event logs from the ``*.jsonl`` files in *data_dir*.

    Each file is treated as one workflow (one case). Every non-blank line is
    parsed as JSON, and the JSON objects are collected in file order.

    Args:
        data_dir: Directory scanned (non-recursively) for ``*.jsonl`` files.
            Files are processed in sorted path order.
        output_dir: When non-empty, the file list is first passed through
            ``filter_complete_files(files, output_dir)``, and a summary is
            printed with ``print_filter_summary(output_dir)``.
        limit: Maximum number of files to read, applied after filtering.

    Returns:
        One list of event dicts per file. Files that yield no JSON objects,
        or that cannot be read or decoded, are omitted. Indices in the result
        therefore need not match positions in the file list.
    """
    all_files = sorted(Path(data_dir).glob("*.jsonl"))

    # Optionally keep only the files the project filter considers complete.
    if output_dir:
        kept = filter_complete_files([str(f) for f in all_files], output_dir)
        candidate_files = [Path(f) for f in kept]
        print_filter_summary(output_dir)
    else:
        candidate_files = all_files

    workflow_data: List[List[Dict]] = []
    for file_path in candidate_files[:limit]:
        try:
            workflow = _read_jsonl_objects(file_path)
        except (OSError, UnicodeDecodeError) as e:
            # Best effort: one unreadable file must not abort the whole analysis.
            print(f"加载文件 {file_path} 时出错: {e}")
            continue
        if workflow:
            workflow_data.append(workflow)

    return workflow_data


def _read_jsonl_objects(file_path: Path) -> List[Dict]:
    """Return the JSON objects in *file_path*, skipping blank, malformed and non-object lines."""
    records: List[Dict] = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue  # tolerate truncated or partially written lines
            # Downstream code calls .get() on every event, so only dicts are usable.
            if isinstance(data, dict):
                records.append(data)
    return records
|
||
|
|
|
||
|
|
|
||
|
|
def extract_triage_steps(workflow: List[Dict]) -> List[Dict]:
    """Select the triager steps of *workflow* that carry an ``output_data`` entry.

    A step qualifies when its ``agent_name`` equals ``'triager'`` and it has an
    ``output_data`` key (whatever that key's value is). The qualifying step
    objects themselves are returned, in their original order.
    """
    return [
        event
        for event in workflow
        if event.get('agent_name') == 'triager' and 'output_data' in event
    ]
|
||
|
|
|
||
|
|
|
||
|
|
def extract_error_cases(workflow_data: List[List[Dict]]) -> List[Dict]:
    """Collect the cases whose final triage result disagrees with the reference answer.

    For each workflow:

    * The reference departments come from the ``case_data`` of its first
      ``workflow_start`` event.
    * The prediction comes from the ``output_data`` of its last triager step,
      as selected by ``extract_triage_steps``.

    Workflows without triager steps, or without usable reference case data,
    are skipped.

    Args:
        workflow_data: One list of event dicts per case.

    Returns:
        One dict per mis-triaged case (level-1 and/or level-2 department
        wrong), with these keys:

        * ``case_index``: position in *workflow_data*.
        * ``case_id``: ``case_NNNN``, derived from that index.
        * ``expected_level1``, ``expected_level2``.
        * ``predicted_level1``, ``predicted_level2``.
        * ``level1_correct``, ``level2_correct``.
        * ``triage_reasoning``.
        * ``case_introduction``: ``None`` when the case data has no 病案介绍.
    """
    error_cases = []

    for index, workflow in enumerate(workflow_data):
        triage_steps = extract_triage_steps(workflow)
        if not triage_steps:
            continue

        case_data = _find_case_data(workflow)
        if case_data is None:
            continue
        expected_level1 = case_data.get('一级科室')
        expected_level2 = case_data.get('二级科室')

        # The last triager step holds the final decision. output_data may be
        # present but null, so fall back to an empty dict instead of crashing.
        final_output = triage_steps[-1].get('output_data') or {}
        predicted_level1 = final_output.get('primary_department')
        predicted_level2 = final_output.get('secondary_department')

        level1_correct = predicted_level1 == expected_level1
        level2_correct = predicted_level2 == expected_level2
        if level1_correct and level2_correct:
            continue

        error_cases.append({
            'case_index': index,
            # The workflow's source file name is not available here, so the ID is positional.
            'case_id': f"case_{index:04d}",
            'expected_level1': expected_level1,
            'expected_level2': expected_level2,
            'predicted_level1': predicted_level1,
            'predicted_level2': predicted_level2,
            'level1_correct': level1_correct,
            'level2_correct': level2_correct,
            'triage_reasoning': final_output.get('triage_reasoning', ''),
            'case_introduction': case_data.get('病案介绍'),
        })

    return error_cases


def _find_case_data(workflow: List[Dict]):
    """Return the dict ``case_data`` of the first ``workflow_start`` event that has one, else ``None``."""
    for step in workflow:
        if step.get('event_type') == 'workflow_start' and 'case_data' in step:
            case_data = step['case_data']
            return case_data if isinstance(case_data, dict) else None
    return None
|
||
|
|
|
||
|
|
|
||
|
|
def save_error_analysis(error_cases: List[Dict], output_dir: str):
    """Write the error cases to JSON and CSV files in *output_dir*.

    Files written (all UTF-8):

    * ``error_cases_detailed.json``: every error case.
    * ``level1_errors.json``: cases whose level-1 department is wrong.
    * ``level2_errors.json``: cases with a correct level-1 but a wrong
      level-2 department.
    * ``error_cases_summary.csv``: one row per error case. The triage
      reasoning is cut to 100 characters, with "..." appended only when
      truncated. ``None`` values are written as empty fields.

    Args:
        error_cases: Case dicts with the ``case_index``, ``case_id``,
            ``expected_level1/2``, ``predicted_level1/2``,
            ``level1_correct/2_correct`` and ``triage_reasoning`` keys.
        output_dir: Destination directory; created if it does not exist.
    """
    import csv

    os.makedirs(output_dir, exist_ok=True)

    # Each case lands in at most one bucket: a level-2 error only counts when level 1 was right.
    level1_errors = [case for case in error_cases if not case['level1_correct']]
    level2_errors = [case for case in error_cases
                     if case['level1_correct'] and not case['level2_correct']]

    for file_name, cases in (
        ('error_cases_detailed.json', error_cases),
        ('level1_errors.json', level1_errors),
        ('level2_errors.json', level2_errors),
    ):
        with open(os.path.join(output_dir, file_name), 'w', encoding='utf-8') as f:
            json.dump(cases, f, ensure_ascii=False, indent=2)

    header = ["病例索引", "病例ID", "期望一级科室", "预测一级科室", "一级是否正确",
              "期望二级科室", "预测二级科室", "二级是否正确", "分诊理由"]

    # The csv module quotes and escapes commas, quotes and newlines inside the
    # free-text reasoning; hand-built f-string rows would be corrupted by them.
    summary_path = os.path.join(output_dir, 'error_cases_summary.csv')
    with open(summary_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(header)
        for case in error_cases:
            reasoning = str(case.get('triage_reasoning') or '')
            if len(reasoning) > 100:
                reasoning = reasoning[:100] + '...'
            writer.writerow([
                case['case_index'],
                case['case_id'],
                case['expected_level1'],
                case['predicted_level1'],
                case['level1_correct'],
                case['expected_level2'],
                case['predicted_level2'],
                case['level2_correct'],
                reasoning,
            ])
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Command-line entry point: load workflows, extract mis-triaged cases, save the analysis.

    Usage: ``script.py [DATA_DIR [OUTPUT_DIR]]``.

    Each argument is optional. A missing DATA_DIR defaults to
    ``<base>/results/results0902`` and a missing OUTPUT_DIR to
    ``<base>/analysis/0902``, where ``<base>`` is the parent of this script's
    directory. OUTPUT_DIR is created if needed.
    """
    import sys

    # Honour each positional argument independently, so a lone DATA_DIR is not ignored.
    args = sys.argv[1:]
    base_dir = Path(__file__).parent.parent
    data_dir = Path(args[0]) if len(args) >= 1 else base_dir / "results" / "results0902"
    output_dir = Path(args[1]) if len(args) >= 2 else base_dir / "analysis" / "0902"

    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"正在加载数据从: {data_dir}")
    workflow_data = load_workflow_data(str(data_dir), str(output_dir), limit=5000)
    print(f"成功加载 {len(workflow_data)} 个病例数据")

    print("正在提取错误病例...")
    error_cases = extract_error_cases(workflow_data)

    print(f"发现 {len(error_cases)} 个错误病例")

    # Use the same buckets as save_error_analysis, so the printed counts match
    # level1_errors.json / level2_errors.json and no case is counted twice.
    level1_errors = [case for case in error_cases if not case['level1_correct']]
    level2_errors = [case for case in error_cases
                     if case['level1_correct'] and not case['level2_correct']]

    print(f"一级科室错误: {len(level1_errors)} 个")
    print(f"二级科室错误: {len(level2_errors)} 个")

    print("一级科室错误示例:")
    for case in level1_errors[:5]:
        print(f" 病例 {case['case_index']}: 期望={case['expected_level1']}, 预测={case['predicted_level1']}")

    print("二级科室错误示例:")
    for case in level2_errors[:5]:
        print(f" 病例 {case['case_index']}: 期望={case['expected_level2']}, 预测={case['predicted_level2']}")

    print("正在保存错误分析结果...")
    save_error_analysis(error_cases, str(output_dir))

    print(f"错误分析完成!结果已保存到: {output_dir}")
|
||
|
|
|
||
|
|
|
||
|
|
# Run the analysis only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|