triage/analysis/triage_accuracy_analysis.py

265 lines
8.8 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
分诊结果正确率分析脚本
用于计算每一步分诊结果的一级科室分诊和二级科室分诊的正确率
"""
import json
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, List, Tuple
import re
from file_filter_utils import filter_complete_files, print_filter_summary
def load_workflow_data(data_dir: str, output_dir: str = "", limit: int = 5000) -> List[Dict]:
"""
加载工作流数据
Args:
data_dir: 数据目录路径
output_dir: 输出目录路径用于文件过滤
limit: 限制加载的病例数量
Returns:
工作流数据列表
"""
workflow_data = []
# 获取所有jsonl文件
all_files = sorted(Path(data_dir).glob("*.jsonl"))
# 过滤出完成的文件
if output_dir:
all_files = [str(f) for f in all_files]
filtered_files = filter_complete_files(all_files, output_dir)
filtered_files = [Path(f) for f in filtered_files]
print_filter_summary(output_dir)
else:
filtered_files = all_files
# 限制文件数量
jsonl_files = filtered_files[:limit]
for file_path in jsonl_files:
try:
with open(file_path, 'r', encoding='utf-8') as f:
workflow = []
for line in f:
try:
data = json.loads(line.strip())
workflow.append(data)
except json.JSONDecodeError:
continue
if workflow:
workflow_data.append(workflow)
except Exception as e:
print(f"加载文件 {file_path} 时出错: {e}")
return workflow_data
def extract_triage_steps(workflow: List[Dict]) -> List[Dict]:
"""
提取分诊步骤
Args:
workflow: 单个工作流数据
Returns:
分诊步骤列表
"""
triage_steps = []
for step in workflow:
if step.get('agent_name') == 'triager' and 'output_data' in step:
triage_steps.append(step)
return triage_steps
def calculate_accuracy(workflow_data: List[List[Dict]]) -> Tuple[List[float], List[float]]:
"""
计算每一步的一级和二级科室分诊正确率
对于提前结束的病例沿用最后一步的分诊结果
Args:
workflow_data: 所有工作流数据
Returns:
(一级科室正确率列表, 二级科室正确率列表)
"""
# 找出最大步骤数
max_steps = 0
for workflow in workflow_data:
triage_steps = extract_triage_steps(workflow)
max_steps = max(max_steps, len(triage_steps))
# 初始化计数器
level1_correct = [0] * max_steps
level2_correct = [0] * max_steps
total_cases = [0] * max_steps
for workflow in workflow_data:
triage_steps = extract_triage_steps(workflow)
# 获取标准答案从case_data
standard_answer = None
for step in workflow:
if step.get('event_type') == 'workflow_start' and 'case_data' in step:
case_data = step['case_data']
standard_answer = {
'一级科室': case_data.get('一级科室'),
'二级科室': case_data.get('二级科室')
}
break
if not standard_answer:
continue
if not triage_steps:
continue
# 获取该病例的最后一步分诊结果
final_step = triage_steps[-1]
final_output = final_step.get('output_data', {})
# 计算一级科室是否正确
level1_is_correct = final_output.get('primary_department') == standard_answer['一级科室']
# 计算二级科室是否正确
level2_is_correct = final_output.get('secondary_department') == standard_answer['二级科室']
# 对于该病例的每一步,都使用最终的分诊结果进行计算
for i in range(max_steps):
# 如果该病例在步骤i+1有分诊步骤则使用该步骤的结果
if i < len(triage_steps):
step_output = triage_steps[i].get('output_data', {})
level1_is_correct = step_output.get('primary_department') == standard_answer['一级科室']
level2_is_correct = step_output.get('secondary_department') == standard_answer['二级科室']
# 对于后续的步骤,沿用最后一步的结果
level1_correct[i] += 1 if level1_is_correct else 0
level2_correct[i] += 1 if level2_is_correct else 0
total_cases[i] += 1
# 计算正确率
level1_accuracy = []
level2_accuracy = []
for i in range(max_steps):
if total_cases[i] > 0:
level1_accuracy.append(level1_correct[i] / total_cases[i])
level2_accuracy.append(level2_correct[i] / total_cases[i])
else:
level1_accuracy.append(0.0)
level2_accuracy.append(0.0)
return level1_accuracy, level2_accuracy
def plot_accuracy_curves(level1_accuracy: List[float], level2_accuracy: List[float], output_dir: str):
"""
绘制正确率折线图
Args:
level1_accuracy: 一级科室正确率列表
level2_accuracy: 二级科室正确率列表
output_dir: 输出目录
"""
plt.figure(figsize=(12, 8))
steps = list(range(1, len(level1_accuracy) + 1))
plt.plot(steps, level1_accuracy, marker='o', linewidth=2, label='Level 1 Department Accuracy', color='#2E86AB')
plt.plot(steps, level2_accuracy, marker='s', linewidth=2, label='Level 2 Department Accuracy', color='#A23B72')
plt.xlabel('Triage Step', fontsize=12)
plt.ylabel('Accuracy Rate', fontsize=12)
plt.title('Triage Accuracy Trends Over Steps', fontsize=14, fontweight='bold')
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.ylim(0, 1.1)
# 添加数值标签
for i, (l1, l2) in enumerate(zip(level1_accuracy, level2_accuracy)):
if l1 > 0: # 只显示非零值
plt.annotate(f'{l1:.2f}', (steps[i], l1), textcoords="offset points",
xytext=(0,10), ha='center', fontsize=9)
if l2 > 0: # 只显示非零值
plt.annotate(f'{l2:.2f}', (steps[i], l2), textcoords="offset points",
xytext=(0,10), ha='center', fontsize=9)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'triage_accuracy_trends.png'), dpi=300, bbox_inches='tight')
plt.close()
def save_accuracy_data(level1_accuracy: List[float], level2_accuracy: List[float], output_dir: str):
"""
保存正确率数据到JSON文件
Args:
level1_accuracy: 一级科室正确率列表
level2_accuracy: 二级科室正确率列表
output_dir: 输出目录
"""
accuracy_data = {
'一级科室分诊正确率': level1_accuracy,
'二级科室分诊正确率': level2_accuracy,
'步骤': list(range(1, len(level1_accuracy) + 1))
}
with open(os.path.join(output_dir, 'triage_accuracy_data.json'), 'w', encoding='utf-8') as f:
json.dump(accuracy_data, f, ensure_ascii=False, indent=2)
def main():
"""主函数"""
import sys
# 从命令行参数获取路径,如果没有提供则使用默认值
if len(sys.argv) >= 3:
data_dir = Path(sys.argv[1])
output_dir = Path(sys.argv[2])
else:
base_dir = Path(__file__).parent.parent
data_dir = base_dir / "results" / "results0902"
output_dir = base_dir / "analysis" / "0902"
# 创建输出目录
output_dir.mkdir(parents=True, exist_ok=True)
print(f"正在加载数据从: {data_dir}")
workflow_data = load_workflow_data(str(data_dir), str(output_dir), limit=5000)
print(f"成功加载 {len(workflow_data)} 个病例数据")
if not workflow_data:
print("未找到有效的工作流数据")
return
print("正在计算分诊正确率...")
level1_accuracy, level2_accuracy = calculate_accuracy(workflow_data)
print("一级科室分诊正确率:")
for i, acc in enumerate(level1_accuracy, 1):
print(f" 步骤 {i}: {acc:.4f}")
print("二级科室分诊正确率:")
for i, acc in enumerate(level2_accuracy, 1):
print(f" 步骤 {i}: {acc:.4f}")
print("正在生成图表...")
plot_accuracy_curves(level1_accuracy, level2_accuracy, str(output_dir))
print("正在保存数据...")
save_accuracy_data(level1_accuracy, level2_accuracy, str(output_dir))
print(f"分析完成!结果已保存到: {output_dir}")
if __name__ == "__main__":
main()