#!/usr/bin/env python3
"""
Analysis script for the evaluate agent's assessment metrics.

Collects the scores of every evaluate dimension and plots trend line charts.
"""

import json
import os
import sys
from pathlib import Path
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np

from file_filter_utils import filter_complete_files, print_filter_summary
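
# Typical invocation (the script name below is illustrative; both arguments
# are optional and fall back to the defaults in main()):
#
#   python analyze_evaluate_metrics.py results/results0902 analysis/0902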


def load_workflow_data(data_dir: str, output_dir: str = "", limit: int = 5000) -> List[Dict]:
    """
    Load workflow data.

    Args:
        data_dir: Path to the data directory.
        output_dir: Path to the output directory (used for file filtering).
        limit: Maximum number of cases to load.

    Returns:
        A list of workflows, one per case.
    """
    workflow_data = []

    # Collect all JSONL files.
    all_files = sorted(Path(data_dir).glob("*.jsonl"))

    # Keep only the files marked as complete.
    if output_dir:
        all_files = [str(f) for f in all_files]
        filtered_files = filter_complete_files(all_files, output_dir)
        filtered_files = [Path(f) for f in filtered_files]
        print_filter_summary(output_dir)
    else:
        filtered_files = all_files

    # Cap the number of files.
    jsonl_files = filtered_files[:limit]

    print(f"Processing {len(jsonl_files)} completed files")

    for file_path in jsonl_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                workflow = []
                for line in f:
                    try:
                        data = json.loads(line.strip())
                        workflow.append(data)
                    except json.JSONDecodeError:
                        # Skip malformed lines.
                        continue

                if workflow:
                    workflow_data.append(workflow)
        except Exception as e:
            print(f"Error loading file {file_path}: {e}")

    return workflow_data
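
# Minimal usage sketch (paths are illustrative):
#
#   workflows = load_workflow_data("results/results0902", limit=10)
#   print(f"loaded {len(workflows)} workflows")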


def extract_evaluate_scores(workflow: List[Dict]) -> List[Dict]:
    """
    Extract the evaluator's score records from a workflow.

    Args:
        workflow: A single workflow (a list of step records).

    Returns:
        A list of evaluate score records, one per evaluator step.
    """
    evaluate_scores = []

    for step in workflow:
        if step.get('agent_name') == 'evaluator' and 'output_data' in step:
            output_data = step['output_data']
            # Keep the step only if it carries at least one assessment score.
            if any(key in output_data for key in [
                'clinical_inquiry', 'communication_quality',
                'multi_round_consistency', 'overall_professionalism',
                'present_illness_similarity', 'past_history_similarity',
                'chief_complaint_similarity'
            ]):
                evaluate_scores.append(output_data)

    return evaluate_scores
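
# Shape of a JSONL record this function matches, inferred from the key checks
# above and the 'score' lookup in calculate_metrics_by_step (values are
# illustrative):
#
#   {
#       "agent_name": "evaluator",
#       "output_data": {
#           "clinical_inquiry": {"score": 4.2},
#           "communication_quality": {"score": 3.8}
#       }
#   }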


def calculate_metrics_by_step(workflow_data: List[List[Dict]]) -> Dict[str, List[float]]:
    """
    Compute the per-step average of every assessment metric.

    Args:
        workflow_data: All workflows.

    Returns:
        For each dimension, the average score at each step.
    """
    # Find the maximum number of evaluation steps across all workflows.
    max_steps = 0
    for workflow in workflow_data:
        evaluate_scores = extract_evaluate_scores(workflow)
        max_steps = max(max_steps, len(evaluate_scores))

    # One score bucket per step for each metric.
    metrics_data = {
        'clinical_inquiry': [[] for _ in range(max_steps)],
        'communication_quality': [[] for _ in range(max_steps)],
        'multi_round_consistency': [[] for _ in range(max_steps)],
        'overall_professionalism': [[] for _ in range(max_steps)],
        'present_illness_similarity': [[] for _ in range(max_steps)],
        'past_history_similarity': [[] for _ in range(max_steps)],
        'chief_complaint_similarity': [[] for _ in range(max_steps)]
    }

    # Collect the scores of every step.
    for workflow in workflow_data:
        evaluate_scores = extract_evaluate_scores(workflow)

        for step_idx, score_data in enumerate(evaluate_scores):
            # Pull out each dimension's score.
            for metric in metrics_data.keys():
                if metric in score_data and isinstance(score_data[metric], dict):
                    score = score_data[metric].get('score', 0.0)
                    metrics_data[metric][step_idx].append(score)

    # Average each bucket; steps with no data fall back to 0.0.
    result = {}
    for metric, step_data in metrics_data.items():
        result[metric] = []
        for scores in step_data:
            if scores:
                result[metric].append(float(np.mean(scores)))
            else:
                result[metric].append(0.0)

    return result
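
# Worked example (illustrative): two workflows whose evaluator steps carry
#   workflow A: [{'clinical_inquiry': {'score': 4.0}}, {'clinical_inquiry': {'score': 3.0}}]
#   workflow B: [{'clinical_inquiry': {'score': 2.0}}]
# give max_steps == 2 and
#   result['clinical_inquiry'] == [3.0, 3.0]  # mean(4.0, 2.0), then mean(3.0)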


def plot_metrics_curves(metrics_data: Dict[str, List[float]], output_dir: str):
    """
    Plot line charts of the assessment metrics.

    Args:
        metrics_data: Per-dimension metric data.
        output_dir: Output directory.
    """
    steps = list(range(1, len(next(iter(metrics_data.values()))) + 1))
    if not steps:
        # Nothing to plot: no evaluator steps were found.
        print("No evaluate scores available; skipping plots")
        return

    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57', '#FF9FF3', '#54A0FF', '#5F27CD']

    # Combined figure with every dimension on one axis.
    plt.figure(figsize=(16, 10))

    for idx, (metric_name, scores) in enumerate(metrics_data.items()):
        # Skip dimensions with no data at all.
        if all(score == 0.0 for score in scores):
            continue

        plt.plot(steps, scores, marker='o', linewidth=2,
                 label=metric_name.replace('_', ' ').title(),
                 color=colors[idx % len(colors)])

    plt.xlabel('Conversation Round', fontsize=12)
    plt.ylabel('Score', fontsize=12)
    plt.title('Evaluate Agent Multi-Dimensional Assessment Trends', fontsize=14, fontweight='bold')
    plt.legend(fontsize=10, bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, alpha=0.3)
    plt.ylim(0, 5.5)

    # Annotate each plotted point with its value.
    for metric_name, scores in metrics_data.items():
        if not all(score == 0.0 for score in scores):
            for i, score in enumerate(scores):
                if score > 0:
                    plt.annotate(f'{score:.1f}', (steps[i], score),
                                 textcoords="offset points",
                                 xytext=(0, 5), ha='center', fontsize=8)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'evaluate_metrics_trends.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # One subplot per dimension.
    _, axes = plt.subplots(2, 4, figsize=(20, 12))
    axes = axes.flatten()

    for idx, (metric_name, scores) in enumerate(metrics_data.items()):
        if idx >= len(axes):
            break

        ax = axes[idx]
        if not all(score == 0.0 for score in scores):
            ax.plot(steps, scores, marker='o', linewidth=2, color=colors[idx % len(colors)])
            ax.set_title(metric_name.replace('_', ' ').title(), fontsize=12)
            ax.set_xlabel('Conversation Round')
            ax.set_ylabel('Score')
            ax.grid(True, alpha=0.3)
            ax.set_ylim(0, 5.5)
        else:
            ax.text(0.5, 0.5, 'No Data', ha='center', va='center', transform=ax.transAxes)

    # Hide the unused subplots.
    for idx in range(len(metrics_data), len(axes)):
        axes[idx].set_visible(False)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'evaluate_metrics_subplots.png'), dpi=300, bbox_inches='tight')
    plt.close()
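
# Note: on a headless server matplotlib may need a non-interactive backend.
# If figure saving fails, select one before pyplot is first imported:
#
#   import matplotlib
#   matplotlib.use("Agg")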


def save_metrics_data(metrics_data: Dict[str, List[float]], output_dir: str):
    """
    Save the assessment metric data to JSON files.

    Args:
        metrics_data: Per-dimension metric data.
        output_dir: Output directory.
    """
    # Self-describing format.
    formatted_data = {
        'dimensions': list(metrics_data.keys()),
        'steps': list(range(1, len(next(iter(metrics_data.values()))) + 1)),
        'scores_by_dimension': {}
    }

    for metric, scores in metrics_data.items():
        formatted_data['scores_by_dimension'][metric] = scores

    with open(os.path.join(output_dir, 'evaluate_metrics_data.json'), 'w', encoding='utf-8') as f:
        json.dump(formatted_data, f, ensure_ascii=False, indent=2)

    # Flat, simplified format.
    simplified_data = {
        'rounds': list(range(1, len(next(iter(metrics_data.values()))) + 1))
    }
    simplified_data.update(metrics_data)

    with open(os.path.join(output_dir, 'evaluate_metrics_summary.json'), 'w', encoding='utf-8') as f:
        json.dump(simplified_data, f, ensure_ascii=False, indent=2)
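
# evaluate_metrics_summary.json then looks like (values illustrative):
#
#   {
#     "rounds": [1, 2, 3],
#     "clinical_inquiry": [3.1, 3.4, 3.6],
#     "communication_quality": [4.0, 4.1, 4.2]
#   }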


def generate_report(metrics_data: Dict[str, List[float]], output_dir: str):
    """
    Generate the assessment report.

    Args:
        metrics_data: Per-dimension metric data.
        output_dir: Output directory.
    """
    report_lines = [
        "# Evaluate Agent Assessment Report",
        "",
        "## Average Scores by Dimension",
        ""
    ]

    for metric_name, scores in metrics_data.items():
        valid_scores = [s for s in scores if s > 0]
        if valid_scores:
            avg_score = np.mean(valid_scores)
            max_score = max(valid_scores)
            min_score = min(valid_scores)
            report_lines.append(
                f"- **{metric_name.replace('_', ' ').title()}**: mean {avg_score:.2f} (max: {max_score:.2f}, min: {min_score:.2f})"
            )

    report_lines.extend([
        "",
        "",
        "## Analysis",
        "",
        "### Strong dimensions (average score > 4.0):"
    ])

    good_metrics = []
    for metric_name, scores in metrics_data.items():
        valid_scores = [s for s in scores if s > 0]
        if valid_scores and np.mean(valid_scores) > 4.0:
            good_metrics.append(metric_name.replace('_', ' ').title())

    if good_metrics:
        report_lines.extend([f"- {metric}" for metric in good_metrics])
    else:
        report_lines.append("- None")

    report_lines.extend([
        "",
        "### Dimensions needing improvement (average score < 2.0):"
    ])

    poor_metrics = []
    for metric_name, scores in metrics_data.items():
        valid_scores = [s for s in scores if s > 0]
        if valid_scores and np.mean(valid_scores) < 2.0:
            poor_metrics.append(metric_name.replace('_', ' ').title())

    if poor_metrics:
        report_lines.extend([f"- {metric}" for metric in poor_metrics])
    else:
        report_lines.append("- None")

    with open(os.path.join(output_dir, 'evaluate_report.md'), 'w', encoding='utf-8') as f:
        f.write('\n'.join(report_lines))
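
# A generated report line looks like (values illustrative):
#
#   - **Clinical Inquiry**: mean 3.85 (max: 4.20, min: 3.40)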


def main():
    """Main function."""
    # Take the paths from the command line; use defaults if not provided.
    if len(sys.argv) >= 3:
        data_dir = Path(sys.argv[1])
        output_dir = Path(sys.argv[2])
    else:
        base_dir = Path(__file__).parent.parent
        data_dir = base_dir / "results" / "results0902"
        output_dir = base_dir / "analysis" / "0902"

    # Create the output directory.
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading data from: {data_dir}")
    workflow_data = load_workflow_data(str(data_dir), str(output_dir), limit=5000)
    print(f"Successfully loaded {len(workflow_data)} cases")

    if not workflow_data:
        print("No valid workflow data found")
        return

    print("Computing assessment metrics...")
    metrics_data = calculate_metrics_by_step(workflow_data)

    print("Per-dimension statistics:")
    for metric, scores in metrics_data.items():
        valid_scores = [s for s in scores if s > 0]
        if valid_scores:
            avg_score = np.mean(valid_scores)
            print(f"  {metric}: mean {avg_score:.2f} (rounds: {len(valid_scores)})")

    print("Generating plots...")
    plot_metrics_curves(metrics_data, str(output_dir))

    print("Saving data...")
    save_metrics_data(metrics_data, str(output_dir))

    print("Generating report...")
    generate_report(metrics_data, str(output_dir))

    print(f"Analysis complete! Results saved to: {output_dir}")
    print("Output files:")
    print("  - evaluate_metrics_data.json: detailed data")
    print("  - evaluate_metrics_summary.json: simplified data")
    print("  - evaluate_metrics_trends.png: trend chart")
    print("  - evaluate_metrics_subplots.png: per-dimension subplots")
    print("  - evaluate_report.md: assessment report")


if __name__ == "__main__":
    main()