265 lines
8.8 KiB
Python
265 lines
8.8 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
分诊结果正确率分析脚本
|
|||
|
|
用于计算每一步分诊结果的一级科室分诊和二级科室分诊的正确率
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import numpy as np
|
|||
|
|
import matplotlib.pyplot as plt
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Dict, List, Tuple
|
|||
|
|
import re
|
|||
|
|
from file_filter_utils import filter_complete_files, print_filter_summary
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_workflow_data(data_dir: str, output_dir: str = "", limit: int = 5000) -> List[List[Dict]]:
    """
    Load workflow records from the ``*.jsonl`` files found in a directory.

    Each file is read as one case: every line that parses to a JSON object
    becomes one record of that case's workflow. Blank lines, lines that are
    not valid JSON, and JSON values that are not objects are skipped. A file
    that yields no records is left out of the result. A file that cannot be
    read is reported on stdout and skipped entirely.

    Args:
        data_dir: Directory searched (non-recursively) for ``*.jsonl`` files.
        output_dir: When non-empty, the candidate files are passed through
            ``filter_complete_files`` together with this directory and
            ``print_filter_summary`` is called. When empty, no filtering
            is applied.
        limit: Maximum number of files to load, taken from the front of the
            (filtered) file list.

    Returns:
        One list of record dicts per loaded file, in file-list order.
    """
    candidates = sorted(Path(data_dir).glob("*.jsonl"))

    if output_dir:
        # The filter helper is handed string paths; convert its result back
        # to Path objects so both branches produce the same element type.
        kept = filter_complete_files([str(p) for p in candidates], output_dir)
        candidates = [Path(p) for p in kept]
        print_filter_summary(output_dir)

    workflow_data = []
    for file_path in candidates[:limit]:
        records = []
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        record = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    # Downstream code calls .get() on every record, so only
                    # JSON objects are usable as workflow steps.
                    if isinstance(record, dict):
                        records.append(record)
        except Exception as e:
            # Deliberate best effort: one unreadable file must not abort
            # the whole load; report it and drop its partial content.
            print(f"加载文件 {file_path} 时出错: {e}")
            continue

        if records:
            workflow_data.append(records)

    return workflow_data
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_triage_steps(workflow: List[Dict]) -> List[Dict]:
    """
    Select the triage records of one workflow.

    A record counts as a triage step when its ``agent_name`` is
    ``'triager'`` and it carries an ``output_data`` key.

    Args:
        workflow: Records of a single workflow.

    Returns:
        The matching records, in their original order.
    """
    return [
        record
        for record in workflow
        if 'output_data' in record and record.get('agent_name') == 'triager'
    ]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def calculate_accuracy(workflow_data: List[List[Dict]]) -> Tuple[List[float], List[float]]:
    """
    Compute the per-step accuracy of level-1 and level-2 department triage.

    The number of steps reported is the largest triage-step count found in
    any workflow. A case is scored only if it has a ``workflow_start``
    record carrying ``case_data`` (the gold answer) and at least one triage
    step. Every scored case contributes to every step: for steps beyond a
    case's last triage step, the correctness of its last step is carried
    forward.

    Args:
        workflow_data: All workflows, one list of records per case.

    Returns:
        ``(level1_accuracy, level2_accuracy)``: two lists of equal length,
        one entry per step. Entries are 0.0 when no case could be scored.
    """
    # Extract triage steps once per workflow and reuse them for both the
    # step-count scan and the scoring pass.
    triage_by_case = [extract_triage_steps(workflow) for workflow in workflow_data]
    max_steps = max((len(steps) for steps in triage_by_case), default=0)

    level1_correct = [0] * max_steps
    level2_correct = [0] * max_steps
    # Every scored case is counted at every step, so a single counter is
    # enough as the denominator.
    scored_cases = 0

    for workflow, triage_steps in zip(workflow_data, triage_by_case):
        # Gold answer comes from the first workflow_start record with case_data.
        start_record = next(
            (
                step for step in workflow
                if step.get('event_type') == 'workflow_start' and 'case_data' in step
            ),
            None,
        )
        if start_record is None or not triage_steps:
            continue

        case_data = start_record['case_data']
        expected_level1 = case_data.get('一级科室')
        expected_level2 = case_data.get('二级科室')
        # NOTE(review): a missing gold label and a missing prediction are both
        # None and therefore compare equal (counted as correct) -- confirm
        # this is intended.

        level1_is_correct = False
        level2_is_correct = False
        for i in range(max_steps):
            if i < len(triage_steps):
                step_output = triage_steps[i].get('output_data', {})
                level1_is_correct = step_output.get('primary_department') == expected_level1
                level2_is_correct = step_output.get('secondary_department') == expected_level2
            # else: the case ended early; keep the last step's verdict.

            if level1_is_correct:
                level1_correct[i] += 1
            if level2_is_correct:
                level2_correct[i] += 1

        scored_cases += 1

    if scored_cases == 0:
        return [0.0] * max_steps, [0.0] * max_steps

    level1_accuracy = [hits / scored_cases for hits in level1_correct]
    level2_accuracy = [hits / scored_cases for hits in level2_correct]
    return level1_accuracy, level2_accuracy
|
|||
|
|
|
|||
|
|
|
|||
|
|
def plot_accuracy_curves(level1_accuracy: List[float], level2_accuracy: List[float], output_dir: str) -> None:
    """
    Plot both accuracy series as line charts and save them as a PNG.

    The image is written to ``triage_accuracy_trends.png`` inside
    ``output_dir``. The x axis is the 1-based step number. Every non-zero
    point is annotated with its value rounded to two decimals.

    Args:
        level1_accuracy: Level-1 department accuracy per step.
        level2_accuracy: Level-2 department accuracy per step.
        output_dir: Existing directory the image is written to.
    """
    fig = plt.figure(figsize=(12, 8))
    try:
        steps = list(range(1, len(level1_accuracy) + 1))

        plt.plot(steps, level1_accuracy, marker='o', linewidth=2, label='Level 1 Department Accuracy', color='#2E86AB')
        plt.plot(steps, level2_accuracy, marker='s', linewidth=2, label='Level 2 Department Accuracy', color='#A23B72')

        plt.xlabel('Triage Step', fontsize=12)
        plt.ylabel('Accuracy Rate', fontsize=12)
        plt.title('Triage Accuracy Trends Over Steps', fontsize=14, fontweight='bold')
        plt.legend(fontsize=12)
        plt.grid(True, alpha=0.3)
        # Leave headroom above 1.0 so labels on perfect scores stay visible.
        plt.ylim(0, 1.1)

        # Value labels; zero values are left unlabelled.
        for step, pair in zip(steps, zip(level1_accuracy, level2_accuracy)):
            for value in pair:
                if value > 0:
                    plt.annotate(f'{value:.2f}', (step, value), textcoords="offset points",
                                 xytext=(0,10), ha='center', fontsize=9)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'triage_accuracy_trends.png'), dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails;
        # pyplot otherwise keeps it registered.
        plt.close(fig)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_accuracy_data(level1_accuracy: List[float], level2_accuracy: List[float], output_dir: str):
    """
    Write both accuracy series, plus the 1-based step numbers, to JSON.

    The file is ``triage_accuracy_data.json`` inside ``output_dir``. It is
    UTF-8 encoded, indented by two spaces, and keeps non-ASCII keys unescaped.

    Args:
        level1_accuracy: Level-1 department accuracy per step.
        level2_accuracy: Level-2 department accuracy per step.
        output_dir: Existing directory the file is written to.
    """
    step_numbers = [index + 1 for index in range(len(level1_accuracy))]

    payload = {}
    payload['一级科室分诊正确率'] = level1_accuracy
    payload['二级科室分诊正确率'] = level2_accuracy
    payload['步骤'] = step_numbers

    target_path = os.path.join(output_dir, 'triage_accuracy_data.json')
    with open(target_path, 'w', encoding='utf-8') as out_file:
        json.dump(payload, out_file, ensure_ascii=False, indent=2)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
    """
    Script entry point: load the workflows, compute accuracies, and save the results.

    With at least two command-line arguments, the first is the data directory
    and the second the output directory. Otherwise both default to paths under
    the parent of this file's directory.
    """
    import sys

    argv = sys.argv
    if len(argv) < 3:
        # No explicit paths given: fall back to the built-in defaults.
        project_root = Path(__file__).parent.parent
        data_dir = project_root / "results" / "results0902"
        output_dir = project_root / "analysis" / "0902"
    else:
        data_dir, output_dir = Path(argv[1]), Path(argv[2])

    # Make sure the output directory exists before anything is written.
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"正在加载数据从: {data_dir}")
    workflow_data = load_workflow_data(str(data_dir), str(output_dir), limit=5000)
    print(f"成功加载 {len(workflow_data)} 个病例数据")

    if not workflow_data:
        print("未找到有效的工作流数据")
        return

    print("正在计算分诊正确率...")
    level1_accuracy, level2_accuracy = calculate_accuracy(workflow_data)

    # Print both series in the same layout: a heading, then one line per step.
    report = (
        ("一级科室分诊正确率:", level1_accuracy),
        ("二级科室分诊正确率:", level2_accuracy),
    )
    for heading, series in report:
        print(heading)
        for i, acc in enumerate(series, 1):
            print(f" 步骤 {i}: {acc:.4f}")

    print("正在生成图表...")
    plot_accuracy_curves(level1_accuracy, level2_accuracy, str(output_dir))

    print("正在保存数据...")
    save_accuracy_data(level1_accuracy, level2_accuracy, str(output_dir))

    print(f"分析完成!结果已保存到: {output_dir}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Run the analysis only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|