diff --git a/agent_system/evaluetor/prompt.py b/agent_system/evaluetor/prompt.py index 76db867..46996c2 100755 --- a/agent_system/evaluetor/prompt.py +++ b/agent_system/evaluetor/prompt.py @@ -47,83 +47,36 @@ class EvaluatorPrompt(BasePrompt): "- **5分:优秀** - 表现突出,超出基本预期", "", "**各维度具体标准**:", - "", - "### 评分参考背景案例", - "**患者背景**: 55岁男性,主诉'胸痛3天'", - "**真实现病史**: 3天前无明显诱因出现胸骨后疼痛,呈压榨性,持续性,向左肩放射,伴出汗、恶心,活动时加重,休息后稍缓解", - "**真实既往史**: 高血压10年,糖尿病5年,吸烟20年每天1包,父亲心梗病史", - "**真实主述**: 胸痛3天", - "", "### 临床问诊能力 (clinical_inquiry)", "- **5分**: 问题设计科学系统,问诊逻辑清晰,信息收集全面深入", - " - *问诊例子*: '这个胸痛是什么性质的?压榨性还是刺痛?有没有向肩膀、手臂放射?什么情况下加重?休息能缓解吗?伴随有出汗、恶心吗?您有高血压糖尿病病史吗?家族有心脏病史吗?'", - " - *评分理由*: 系统询问疼痛PQRST要素,主动询问相关既往史和家族史,体现完整临床思维", "- **4分**: 问题针对性强,问诊思路合理,能有效收集关键信息", - " - *问诊例子*: '胸痛是压榨性的吗?有放射痛吗?活动时会加重吗?有没有出汗恶心?您有高血压病史吗?'", - " - *评分理由*: 重点询问胸痛典型特征和重要危险因素,针对性强", "- **3分**: 能提出基本相关问题,问诊方向基本正确,能收集必要信息", - " - *问诊例子*: '胸痛什么时候开始的?疼痛严重吗?在哪个部位?有其他不舒服吗?'", - " - *评分理由*: 问题相关且方向正确,但缺乏针对性和深度", "- **2分**: 能提出问题并收集基本信息,方向基本正确", - " - *问诊例子*: '胸痛多长时间了?现在还痛吗?严重吗?'", - " - *评分理由*: 能问基本信息但不够深入,遗漏重要诊断要素", "- **1分**: 能完成基本问诊任务,收集基础信息", - " - *问诊例子*: '哪里不舒服?什么时候开始的?'", - " - *评分理由*: 过于简单,缺乏针对性和专业性", "- **0分**: 无法判断问诊质量", "", "### 沟通表达能力 (communication_quality)", "- **5分**: 语言通俗易懂,避免过度专业术语,患者完全理解,沟通亲和温暖", - " - *沟通例子*: '您好,我想了解一下您胸口疼痛的情况。这种疼痛是像被什么东西压着的感觉,还是像针扎一样刺痛?疼痛会不会传到肩膀或者胳膊上?'", - " - *评分理由*: 用通俗比喻解释医学概念,语言亲切自然,患者容易理解", "- **4分**: 用词恰当亲民,适度使用通俗解释,患者较易理解", - " - *沟通例子*: '您的胸痛是压榨性的吗?就是感觉胸口被挤压的那种?有放射到其他地方吗?'", - " - *评分理由*: 使用医学术语但有通俗解释,表达清晰易懂", "- **3分**: 表达基本清晰,偶有专业术语但有解释,患者基本能理解", - " - *沟通例子*: '胸痛的性质是怎样的?是压榨性疼痛吗?有放射痛吗?就是疼痛向别的地方传播。'", - " - *评分理由*: 有专业术语但能解释,基本清晰,患者能理解", "- **2分**: 表达清楚但专业性较强,患者需要一定努力才能理解", - " - *沟通例子*: '请描述疼痛的性质,是压榨性还是刺痛性?有无放射痛?诱发因素是什么?'", - " - *评分理由*: 用词较专业但表达清楚,患者需要思考才能理解", "- **1分**: 过度使用专业术语,患者理解困难,缺乏亲和力", - " - *沟通例子*: '疼痛性质如何?PQRST分析?有无心绞痛典型症状?危险分层如何?'", - " - *评分理由*: 专业术语过多,缺乏解释,患者难以理解", "- **0分**: 无法评价沟通质量", "", "### 信息收集全面性 (information_completeness)", "- **5分**: 信息收集系统全面,涵盖现病史、既往史、危险因素等关键要素", - " - *全面性例子*: '请详细说说胸痛的性质、部位、放射情况?有伴随症状吗?您有高血压、糖尿病病史吗?吸烟史如何?家族有心脏病史吗?平时活动耐量怎样?'", - " - *评分理由*: 系统询问现病史PQRST要素、相关既往史、危险因素,信息收集全面", "- **4分**: 信息收集较为全面,涵盖主要诊断要素", - " - *全面性例子*: '胸痛的性质和部位如何?有放射痛吗?您有高血压糖尿病吗?抽烟吗?'", - " - *评分理由*: 收集了症状主要特征和重要危险因素,比较全面", "- **3分**: 信息收集基本全面,涵盖必要要素", - " - *全面性例子*: '胸痛什么性质?在哪个部位?您有什么基础病吗?'", - " - *评分理由*: 收集了基本症状信息和既往史,基本全面但深度不足", "- **2分**: 信息收集不够全面,遗漏部分重要信息", - " - *全面性例子*: '胸痛怎样?什么时候开始的?严重吗?'", - " - *评分理由*: 只询问了症状基本信息,未涉及危险因素和既往史", "- **1分**: 信息收集很不全面,仅收集表面信息", - " - *全面性例子*: '哪里不舒服?什么时候开始的?'", - " - *评分理由*: 信息收集过于简单,缺乏系统性", "- **0分**: 第一轮或信息不足,无法评价全面性", "", "### 整体专业性 (overall_professionalism)", "- **5分**: 医学思维出色,风险识别准确,问诊逻辑严谨", - " - *专业性例子*: '根据您的症状,这可能涉及心血管问题,我需要了解:您有高血压糖尿病吗?家族有心脏病史吗?平时抽烟吗?我们需要排除心绞痛或心肌梗死的可能。'", - " - *评分理由*: 体现出色临床思维,准确识别高危因素,具备风险分层概念", "- **4分**: 医学思维良好,能抓住重点,问诊方向准确", - " - *专业性例子*: '胸痛伴出汗恶心需要警惕心脏问题,您有高血压病史吗?抽烟吗?我们需要进一步检查。'", - " - *评分理由*: 能识别重要症状组合,重点询问危险因素,方向正确", "- **3分**: 具备医学思维,问诊方向基本正确,体现专业性", - " - *专业性例子*: '胸痛需要了解具体情况,您有什么基础疾病吗?平时身体怎么样?'", - " - *评分理由*: 有基本医学概念,知道询问基础疾病,但缺乏针对性", "- **2分**: 医学思维基本合理,问诊方向基本正确", - " - *专业性例子*: '胸痛可能是心脏问题,您身体有什么毛病吗?'", - " - *评分理由*: 有基本判断但表达不够专业,思维简单", "- **1分**: 具备基本医学思维,能完成基本问诊", - " - *专业性例子*: '胸痛要检查一下,您还有哪里不舒服?'", - " - *评分理由*: 缺乏专业分析,问诊过于简单", "- **0分**: 无法评价专业水平", "", "### 相似度评价标准 (各维度通用)", diff --git a/analysis/workflow_file_cleaner.py b/analysis/workflow_file_cleaner.py index 7a46d32..626982a 100644 --- a/analysis/workflow_file_cleaner.py +++ b/analysis/workflow_file_cleaner.py @@ -1,44 +1,263 @@ #!/usr/bin/env python3 # 
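# --- Illustrative aside, not part of the patch: the cleaner below scores runs against the
# same four dimensions kept in the trimmed rubric above. The record layout sketched here is
# an assumption inferred from how calculate_professional_penalty unwraps nested "score"
# dicts; the real evaluator output_data may carry extra fields (reasons, similarity, etc.).
example_evaluator_output = {
    "clinical_inquiry": {"score": 4.0},
    "communication_quality": {"score": 3.5},
    "information_completeness": {"score": 3.0},
    "overall_professionalism": {"score": 4.5},
}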
-*- coding: utf-8 -*- """ -工作流文件清理器 -检测指定目录中的所有JSONL文件,删除不完整的工作流记录文件 +智能工作流文件清理器 +基于质量评估的智能清理策略: +- 不完整项目:保留10%最优质的,删除90% +- 完整项目:删除20%质量最差的,保留80% """ import json import os import glob +import re +import shutil from pathlib import Path -from typing import Dict, Any, List +from typing import Dict, Any, List, Optional, Set import argparse import logging +from dataclasses import dataclass +from datetime import datetime # 配置日志 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) -class WorkflowFileCleaner: - """工作流文件清理器""" +@dataclass +class QualityScore: + """质量评分数据类""" + professional_penalty: float # 专业指标惩罚分 + triage_penalty: float # 分诊错误惩罚分 + total_penalty: float # 总惩罚分 + is_complete: bool # 是否完整 + file_path: str # 文件路径 + + +class IntelligentWorkflowCleaner: + """基于质量评估的智能工作流文件清理器""" - def __init__(self, directory: str, dry_run: bool = False): + def __init__(self, directory: str, dry_run: bool = False, + keep_incomplete_ratio: float = 0.1, + remove_complete_ratio: float = 0.2): """ - 初始化清理器 + 初始化智能清理器 Args: directory: 要检查的目录路径 dry_run: 是否为试运行模式(不实际删除文件) + keep_incomplete_ratio: 不完整项目保留比例(默认10%) + remove_complete_ratio: 完整项目删除比例(默认20%) """ self.directory = Path(directory) self.dry_run = dry_run + self.keep_incomplete_ratio = keep_incomplete_ratio + self.remove_complete_ratio = remove_complete_ratio + + # 质量评估相关的评估指标映射(四个核心指标) + self.quality_indicators = { + 'clinical_inquiry': 'clinical_inquiry', + 'communication_quality': 'communication_quality', + 'information_completeness': 'information_completeness', # 修正为正确的字段名 + 'overall_professional': 'overall_professionalism' + } + + # Dataset路径 + self.dataset_path = Path('dataset/bbb.json') + self.stats = { 'total_files': 0, 'complete_files': 0, 'incomplete_files': 0, - 'deleted_files': [], - 'error_files': [] + 'kept_incomplete_files': [], + 'deleted_incomplete_files': [], + 'kept_complete_files': [], + 'deleted_complete_files': [], + 'error_files': [], + 'deleted_case_indices': [], # 记录被删除的case索引 + 'deleted_cases_info': [], # 记录被删除的case详细信息 + 'dataset_backup_path': '', # 备份文件路径 + 'quality_analysis': { + 'incomplete': {'avg_penalty': 0.0, 'score_range': (0.0, 0.0)}, + 'complete': {'avg_penalty': 0.0, 'score_range': (0.0, 0.0)} + } } + def calculate_professional_penalty(self, evaluation_data_by_round: Dict[int, Dict[str, Any]]) -> float: + """ + 计算专业指标惩罚分数 + + 公式: Σ(round_i * Σ(四个指标的惩罚分)) + + Args: + evaluation_data_by_round: 按轮次组织的评估数据字典 + + Returns: + float: 专业指标惩罚分数 + """ + penalty = 0.0 + + # 遍历所有轮次 + for round_num, round_data in evaluation_data_by_round.items(): + # 计算该轮次四个指标的惩罚分总和 + round_penalty_sum = 0.0 + + for indicator_key in self.quality_indicators.values(): + if indicator_key in round_data: + indicator_data = round_data[indicator_key] + + # 处理嵌套的score结构 + if isinstance(indicator_data, dict): + score = indicator_data.get('score', 3.0) + else: + # 兼容直接存储score的情况 + score = float(indicator_data) if isinstance(indicator_data, (int, float)) else 3.0 + + # 只有分数低于3.0才计算惩罚 + if score < 3.0: + round_penalty_sum += (3.0 - score) + + # 轮次惩罚 = 轮次编号 × 该轮次四个指标惩罚分之和 + penalty += round_num * round_penalty_sum + + return penalty + + def calculate_triage_penalty(self, jsonl_file: str, case_data: Dict[str, Any]) -> float: + """ + 计算分诊错误惩罚分数 + + 如果第一轮的一级和二级都正确,才开始计算。后续错几轮就是几分 + + Args: + jsonl_file: JSONL文件路径 + case_data: 案例数据 + + Returns: + float: 分诊错误惩罚分数 + """ + try: + correct_primary = case_data.get('一级科室', '') + correct_secondary = case_data.get('二级科室', '') + + # 提取所有triager 
agent的分诊结果 + triage_steps = [] + with open(jsonl_file, 'r', encoding='utf-8') as f: + for line in f: + try: + event = json.loads(line.strip()) + if (event.get('event_type') == 'agent_execution' and + event.get('agent_name') == 'triager'): + + output_data = event.get('output_data', {}) + step_number = event.get('step_number', 0) + + predicted_primary = output_data.get('primary_department', '') + predicted_secondary = output_data.get('secondary_department', '') + + triage_steps.append({ + 'step_number': step_number, + 'primary_department': predicted_primary, + 'secondary_department': predicted_secondary, + 'primary_correct': predicted_primary == correct_primary, + 'secondary_correct': predicted_secondary == correct_secondary + }) + + except (json.JSONDecodeError, KeyError): + continue + + if not triage_steps: + return 0.0 + + # 按步骤号排序 + triage_steps.sort(key=lambda x: x['step_number']) + + # 检查第一轮是否完全正确(一级和二级都正确) + first_round = triage_steps[0] + if not (first_round['primary_correct'] and first_round['secondary_correct']): + # 第一轮不完全正确,不计算惩罚分 + return 0.0 + + # 计算后续轮次的错误数 + error_rounds = 0 + for step in triage_steps[1:]: # 从第二轮开始 + # 只要一级或二级有一个错误,就算这轮错误 + if not (step['primary_correct'] and step['secondary_correct']): + error_rounds += 1 + + return float(error_rounds) + + except Exception as e: + logger.warning(f"计算分诊惩罚分时出错 {jsonl_file}: {e}") + + return 0.0 + + def calculate_quality_score(self, jsonl_file: str) -> Optional[QualityScore]: + """ + 计算文件的质量分数 + + Returns: + QualityScore: 质量评分对象,如果无法计算则返回None + """ + try: + with open(jsonl_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + + if not lines: + return None + + # 检查是否完整 + is_complete = self.check_workflow_completion(jsonl_file) + + # 获取案例数据 + case_data = {} + evaluation_data_by_round = {} # 按轮次组织评估数据 + + for line in lines: + try: + event = json.loads(line.strip()) + + # 获取案例数据 + if event.get('event_type') == 'workflow_start': + case_data = event.get('case_data', {}) + + # 获取评估数据,按轮次组织 + elif (event.get('event_type') == 'agent_execution' and + event.get('agent_name') == 'evaluator'): + output_data = event.get('output_data', {}) + + # 从execution_metadata中获取轮次信息 + execution_metadata = event.get('execution_metadata', {}) + round_num = execution_metadata.get('round', 1) # 默认第1轮 + + # 按轮次存储评估数据 + if round_num not in evaluation_data_by_round: + evaluation_data_by_round[round_num] = {} + evaluation_data_by_round[round_num].update(output_data) + + except (json.JSONDecodeError, KeyError): + continue + + # 计算专业指标惩罚分 + professional_penalty = self.calculate_professional_penalty(evaluation_data_by_round) + + # 计算分诊惩罚分 + triage_penalty = self.calculate_triage_penalty(jsonl_file, case_data) + + # 计算总惩罚分 + total_penalty = professional_penalty + 5 * triage_penalty + + return QualityScore( + professional_penalty=professional_penalty, + triage_penalty=triage_penalty, + total_penalty=total_penalty, + is_complete=is_complete, + file_path=jsonl_file + ) + + except Exception as e: + logger.error(f"计算质量分数时出错 {jsonl_file}: {e}") + return None + def check_workflow_completion(self, jsonl_file: str) -> bool: """ 检查工作流是否完整 @@ -95,8 +314,264 @@ class WorkflowFileCleaner: logger.error(f"检查文件时发生错误 {jsonl_file}: {e}") return False - def scan_and_clean_files(self) -> None: - """扫描目录中的所有JSONL文件并清理不完整的文件""" + def extract_case_index_from_filename(self, filename: str) -> Optional[int]: + """ + 从工作流文件名中提取case索引 + + Args: + filename: 工作流文件名 (如: workflow_20250819_001717_case_0000.jsonl) + + Returns: + int: case索引号,如果无法提取则返回None + """ + try: + # 匹配模式: 
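# --- Illustrative aside, not part of the patch: a minimal worked example of the penalty
# formula implemented above, assuming flat per-indicator scores for two rounds (the real
# events nest them as {"indicator": {"score": ...}}). Only scores below 3.0 add penalty,
# each round is weighted by its round number, and triage errors are weighted by 5.
rounds = {
    1: {"clinical_inquiry": 2.0, "communication_quality": 3.5,
        "information_completeness": 2.5, "overall_professionalism": 3.0},
    2: {"clinical_inquiry": 2.5, "communication_quality": 3.0,
        "information_completeness": 3.0, "overall_professionalism": 2.0},
}
professional = sum(r * sum(max(0.0, 3.0 - s) for s in scores.values())
                   for r, scores in rounds.items())   # 1*1.5 + 2*1.5 = 4.5
triage_errors = 1          # one wrong round after a fully correct first-round triage
total_penalty = professional + 5 * triage_errors      # 4.5 + 5 = 9.5; lower is better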
workflow_*_case_*.jsonl + match = re.search(r'workflow_.*_case_(\d+)\.jsonl$', filename) + if match: + return int(match.group(1)) + return None + except Exception as e: + logger.warning(f"无法从文件名提取case索引 {filename}: {e}") + return None + + def backup_dataset(self) -> bool: + """ + 备份dataset文件 + + Returns: + bool: 备份成功返回True,失败返回False + """ + try: + if not self.dataset_path.exists(): + logger.warning(f"Dataset文件不存在: {self.dataset_path}") + return False + + # 生成带时间戳的备份文件名 + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_filename = f"bbb_backup_{timestamp}.json" + backup_path = self.dataset_path.parent / backup_filename + + # 执行备份 + shutil.copy2(self.dataset_path, backup_path) + self.stats['dataset_backup_path'] = str(backup_path) + logger.info(f"Dataset已备份到: {backup_path}") + return True + + except Exception as e: + logger.error(f"备份dataset失败: {e}") + return False + + def load_dataset(self) -> Optional[List[Dict[str, Any]]]: + """ + 加载dataset数据 + + Returns: + List: dataset数据列表,失败返回None + """ + try: + if not self.dataset_path.exists(): + logger.error(f"Dataset文件不存在: {self.dataset_path}") + return None + + with open(self.dataset_path, 'r', encoding='utf-8') as f: + dataset = json.load(f) + + logger.info(f"成功加载dataset,包含{len(dataset)}个case") + return dataset + + except Exception as e: + logger.error(f"加载dataset失败: {e}") + return None + + def save_dataset(self, dataset: List[Dict[str, Any]]) -> bool: + """ + 保存更新后的dataset + + Args: + dataset: 更新后的dataset数据 + + Returns: + bool: 保存成功返回True,失败返回False + """ + try: + with open(self.dataset_path, 'w', encoding='utf-8') as f: + json.dump(dataset, f, ensure_ascii=False, indent=2) + + logger.info(f"成功保存更新后的dataset,包含{len(dataset)}个case") + return True + + except Exception as e: + logger.error(f"保存dataset失败: {e}") + return False + + def collect_case_info(self, jsonl_file: str, case_index: int, + dataset: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + 收集被删除case的详细信息 + + Args: + jsonl_file: 工作流文件路径 + case_index: case索引号 + dataset: dataset数据 + + Returns: + Dict: case详细信息 + """ + case_info = { + 'case_index': case_index, + 'jsonl_file': jsonl_file, + 'case_data': None, + 'primary_department': '', + 'secondary_department': '' + } + + try: + # 从dataset获取case数据 + if 0 <= case_index < len(dataset): + case_info['case_data'] = dataset[case_index] + case_info['primary_department'] = dataset[case_index].get('一级科室', '') + case_info['secondary_department'] = dataset[case_index].get('二级科室', '') + else: + logger.warning(f"Case索引超出范围: {case_index}") + + except Exception as e: + logger.error(f"收集case信息时出错 {jsonl_file}: {e}") + + return case_info + + def sync_delete_dataset_cases(self, deleted_case_indices: Set[int]) -> bool: + """ + 同步删除dataset中的case数据 + + Args: + deleted_case_indices: 要删除的case索引集合 + + Returns: + bool: 删除成功返回True,失败返回False + """ + try: + # 加载dataset + dataset = self.load_dataset() + if dataset is None: + return False + + # 备份dataset + if not self.backup_dataset(): + logger.error("无法备份dataset,取消删除操作") + return False + + # 按索引降序排列,避免删除时索引偏移 + sorted_indices = sorted(deleted_case_indices, reverse=True) + original_count = len(dataset) + + # 删除对应的case + for case_index in sorted_indices: + if 0 <= case_index < len(dataset): + removed_case = dataset.pop(case_index) + logger.info(f"从dataset删除case {case_index}: {removed_case.get('一级科室', '')}-{removed_case.get('二级科室', '')}") + else: + logger.warning(f"无效的case索引: {case_index}") + + # 保存更新后的dataset + if self.save_dataset(dataset): + logger.info(f"成功从dataset删除{original_count - len(dataset)}个case") + 
return True + else: + logger.error("保存更新后的dataset失败") + return False + + except Exception as e: + logger.error(f"同步删除dataset中的case时出错: {e}") + return False + + def validate_data_consistency(self) -> Dict[str, Any]: + """ + 验证工作流文件与dataset的数据一致性 + + Returns: + Dict: 验证结果 + """ + validation_results = { + 'total_jsonl_files': 0, + 'valid_case_mappings': 0, + 'invalid_case_mappings': [], + 'missing_case_indices': [], + 'dataset_size': 0, + 'max_case_index': -1, + 'consistency_rate': 0.0, + 'validation_passed': False + } + + try: + # 加载dataset + dataset = self.load_dataset() + if dataset is None: + validation_results['error'] = "无法加载dataset" + return validation_results + + validation_results['dataset_size'] = len(dataset) + + # 查找所有JSONL文件 + jsonl_pattern = str(self.directory / "**" / "*.jsonl") + jsonl_files = glob.glob(jsonl_pattern, recursive=True) + validation_results['total_jsonl_files'] = len(jsonl_files) + + # 验证每个文件的case索引 + for jsonl_file in jsonl_files: + filename = os.path.basename(jsonl_file) + case_index = self.extract_case_index_from_filename(filename) + + if case_index is not None: + validation_results['max_case_index'] = max(validation_results['max_case_index'], case_index) + + if 0 <= case_index < len(dataset): + validation_results['valid_case_mappings'] += 1 + else: + validation_results['invalid_case_mappings'].append({ + 'file': jsonl_file, + 'case_index': case_index, + 'reason': '索引超出dataset范围' + }) + else: + validation_results['invalid_case_mappings'].append({ + 'file': jsonl_file, + 'case_index': None, + 'reason': '无法从文件名提取case索引' + }) + + # 检查缺失的case索引 + if validation_results['max_case_index'] >= 0: + existing_indices = set() + for jsonl_file in jsonl_files: + filename = os.path.basename(jsonl_file) + case_index = self.extract_case_index_from_filename(filename) + if case_index is not None: + existing_indices.add(case_index) + + expected_indices = set(range(validation_results['max_case_index'] + 1)) + missing_indices = expected_indices - existing_indices + validation_results['missing_case_indices'] = sorted(missing_indices) + + # 计算一致性率 + if validation_results['total_jsonl_files'] > 0: + validation_results['consistency_rate'] = validation_results['valid_case_mappings'] / validation_results['total_jsonl_files'] + + # 判断验证是否通过 + validation_results['validation_passed'] = ( + validation_results['consistency_rate'] >= 0.95 and + len(validation_results['missing_case_indices']) == 0 + ) + + logger.info(f"数据一致性验证完成: 一致性率 {validation_results['consistency_rate']:.2%}") + + except Exception as e: + logger.error(f"数据一致性验证时出错: {e}") + validation_results['error'] = str(e) + + return validation_results + + def analyze_and_clean_files(self) -> None: + """基于质量评估扫描并智能清理文件""" if not self.directory.exists(): logger.error(f"目录不存在: {self.directory}") return @@ -108,50 +583,292 @@ class WorkflowFileCleaner: self.stats['total_files'] = len(jsonl_files) logger.info(f"找到 {len(jsonl_files)} 个JSONL文件") + # 预加载dataset以供后续使用 + dataset = self.load_dataset() + if dataset is None: + logger.warning("无法加载dataset,将跳过dataset同步删除") + + # 计算所有文件的质量分数 + logger.info("正在计算质量分数...") + complete_files = [] + incomplete_files = [] + for jsonl_file in jsonl_files: try: - is_complete = self.check_workflow_completion(jsonl_file) - - if is_complete: + quality_score = self.calculate_quality_score(jsonl_file) + if quality_score is None: + self.stats['error_files'].append(jsonl_file) + continue + + if quality_score.is_complete: + complete_files.append(quality_score) self.stats['complete_files'] += 1 else: + 
incomplete_files.append(quality_score) self.stats['incomplete_files'] += 1 - if self.dry_run: - logger.info(f"[试运行] 将删除不完整文件: {jsonl_file}") - self.stats['deleted_files'].append(jsonl_file) - else: - os.remove(jsonl_file) - logger.info(f"已删除不完整文件: {jsonl_file}") - self.stats['deleted_files'].append(jsonl_file) - except Exception as e: logger.error(f"处理文件时发生错误 {jsonl_file}: {e}") self.stats['error_files'].append(jsonl_file) + + # 智能清理逻辑(增强版,包含dataset同步删除) + self._smart_cleanup_with_sync(complete_files, incomplete_files, dataset) + + def _smart_cleanup_with_sync(self, complete_files: List[QualityScore], + incomplete_files: List[QualityScore], + dataset: Optional[List[Dict[str, Any]]]) -> None: + """ + 执行智能清理逻辑,包含dataset同步删除功能 + + Args: + complete_files: 完整文件的质量评分列表 + incomplete_files: 不完整文件的质量评分列表 + dataset: dataset数据,用于收集case信息和同步删除 + """ + deleted_case_indices = set() # 收集所有要删除的case索引 + + # 处理不完整文件:保留10%最优质的 + if incomplete_files: + # 按总惩罚分排序(分数越低质量越好) + incomplete_files.sort(key=lambda x: x.total_penalty) + + keep_count = max(1, int(len(incomplete_files) * self.keep_incomplete_ratio)) + keep_files = incomplete_files[:keep_count] + delete_files = incomplete_files[keep_count:] + + self.stats['kept_incomplete_files'] = [f.file_path for f in keep_files] + + # 记录质量分析 + if incomplete_files: + penalties = [f.total_penalty for f in incomplete_files] + self.stats['quality_analysis']['incomplete'] = { + 'avg_penalty': sum(penalties) / len(penalties), + 'score_range': (min(penalties), max(penalties)) + } + + logger.info(f"不完整文件: 总数 {len(incomplete_files)}, 保留 {len(keep_files)}, 删除 {len(delete_files)}") + + # 删除不完整文件并收集case信息 + for quality_score in delete_files: + self._delete_file_with_case_tracking(quality_score, "低质量不完整文件", dataset, deleted_case_indices) + self.stats['deleted_incomplete_files'].append(quality_score.file_path) + + # 处理完整文件:删除20%质量最差的 + if complete_files: + # 按总惩罚分排序(分数越高质量越差) + complete_files.sort(key=lambda x: x.total_penalty, reverse=True) + + delete_count = int(len(complete_files) * self.remove_complete_ratio) + delete_files = complete_files[:delete_count] + keep_files = complete_files[delete_count:] + + self.stats['kept_complete_files'] = [f.file_path for f in keep_files] + + # 记录质量分析 + if complete_files: + penalties = [f.total_penalty for f in complete_files] + self.stats['quality_analysis']['complete'] = { + 'avg_penalty': sum(penalties) / len(penalties), + 'score_range': (min(penalties), max(penalties)) + } + + logger.info(f"完整文件: 总数 {len(complete_files)}, 保留 {len(keep_files)}, 删除 {len(delete_files)}") + + # 删除低质量完整文件并收集case信息 + for quality_score in delete_files: + self._delete_file_with_case_tracking(quality_score, "低质量完整文件", dataset, deleted_case_indices) + self.stats['deleted_complete_files'].append(quality_score.file_path) + + # 同步删除dataset中的对应case + if deleted_case_indices and dataset is not None: + logger.info(f"准备从dataset中删除 {len(deleted_case_indices)} 个case: {sorted(deleted_case_indices)}") + if self.sync_delete_dataset_cases(deleted_case_indices): + logger.info("Dataset同步删除完成") + else: + logger.error("Dataset同步删除失败") + elif deleted_case_indices: + logger.warning(f"检测到 {len(deleted_case_indices)} 个case需要删除,但dataset不可用") + + # 记录删除的case索引 + self.stats['deleted_case_indices'] = sorted(deleted_case_indices) + + def _delete_file_with_case_tracking(self, quality_score: QualityScore, reason: str, + dataset: Optional[List[Dict[str, Any]]], + deleted_case_indices: Set[int]) -> None: + """ + 删除文件并跟踪相关的case信息 + + Args: + quality_score: 质量评分对象 + reason: 删除原因 + dataset: 
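# --- Illustrative aside, not part of the patch: how the ratio-based selection in
# _smart_cleanup_with_sync plays out on a hypothetical batch of total_penalty values.
# The numbers are made up; only the sorting and threshold logic mirrors the code above.
incomplete = sorted([9.5, 4.5, 30.0, 1.0, 12.0, 15.0, 18.0, 20.0, 22.0, 6.0])
keep_n = max(1, int(len(incomplete) * 0.1))   # 10% of 10 -> keep only the best file (penalty 1.0)
kept, dropped = incomplete[:keep_n], incomplete[keep_n:]

complete = sorted([0.0, 0.5, 1.0, 2.0, 3.5, 4.0, 5.5, 7.0, 8.0, 11.0], reverse=True)
delete_n = int(len(complete) * 0.2)           # 20% of 10 -> delete the two worst files (11.0, 8.0)
deleted, retained = complete[:delete_n], complete[delete_n:]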
dataset数据 + deleted_case_indices: 用于收集被删除case索引的集合 + """ + file_path = quality_score.file_path + + # 从文件名提取case索引 + filename = os.path.basename(file_path) + case_index = self.extract_case_index_from_filename(filename) + + if case_index is not None and dataset is not None: + # 收集case信息 + case_info = self.collect_case_info(file_path, case_index, dataset) + self.stats['deleted_cases_info'].append(case_info) + deleted_case_indices.add(case_index) + + logger.info(f"准备删除{reason}: {file_path} (case_{case_index}: {case_info['primary_department']}-{case_info['secondary_department']})") + else: + logger.info(f"准备删除{reason}: {file_path} (无法提取case索引)") + + # 执行文件删除 + if self.dry_run: + logger.info(f"[试运行] 将删除{reason}: {file_path}") + else: + try: + os.remove(file_path) + logger.info(f"已删除{reason}: {file_path}") + except Exception as e: + logger.error(f"删除文件失败 {file_path}: {e}") + self.stats['error_files'].append(file_path) + + def _delete_file(self, file_path: str, reason: str) -> None: + """ + 删除文件(兼容性方法) + + Args: + file_path: 文件路径 + reason: 删除原因 + """ + if self.dry_run: + logger.info(f"[试运行] 将删除{reason}: {file_path}") + else: + try: + os.remove(file_path) + logger.info(f"已删除{reason}: {file_path}") + except Exception as e: + logger.error(f"删除文件失败 {file_path}: {e}") + self.stats['error_files'].append(file_path) def print_summary(self) -> None: - """打印统计摘要""" - print("\n" + "="*60) - print("工作流文件清理摘要") - print("="*60) - print(f"总文件数: {self.stats['total_files']}") - print(f"完整文件数: {self.stats['complete_files']}") - print(f"不完整文件数: {self.stats['incomplete_files']}") - print(f"删除文件数: {len(self.stats['deleted_files'])}") - print(f"错误文件数: {len(self.stats['error_files'])}") + """打印详细的统计摘要""" + print("\n" + "="*80) + print("🧠 智能工作流文件清理摘要") + print("="*80) - if self.stats['deleted_files']: - print("\n已删除的文件:") - for file in self.stats['deleted_files']: - print(f" - {file}") + # 基本统计 + print(f"📊 基本统计:") + print(f" 总文件数: {self.stats['total_files']}") + print(f" 完整文件数: {self.stats['complete_files']}") + print(f" 不完整文件数: {self.stats['incomplete_files']}") + print(f" 错误文件数: {len(self.stats['error_files'])}") + # 清理策略统计 + print(f"\n🎯 清理策略统计:") + print(f" 不完整文件保留比例: {self.keep_incomplete_ratio*100:.1f}%") + print(f" 完整文件删除比例: {self.remove_complete_ratio*100:.1f}%") + + # 不完整文件处理结果 + if self.stats['incomplete_files'] > 0: + kept_incomplete = len(self.stats['kept_incomplete_files']) + deleted_incomplete = len(self.stats['deleted_incomplete_files']) + print(f"\n📋 不完整文件处理:") + print(f" 保留数量: {kept_incomplete} ({kept_incomplete/self.stats['incomplete_files']*100:.1f}%)") + print(f" 删除数量: {deleted_incomplete} ({deleted_incomplete/self.stats['incomplete_files']*100:.1f}%)") + + qa = self.stats['quality_analysis']['incomplete'] + if qa['avg_penalty'] > 0: + print(f" 平均惩罚分: {qa['avg_penalty']:.2f}") + print(f" 分数范围: {qa['score_range'][0]:.2f} - {qa['score_range'][1]:.2f}") + + # 完整文件处理结果 + if self.stats['complete_files'] > 0: + kept_complete = len(self.stats['kept_complete_files']) + deleted_complete = len(self.stats['deleted_complete_files']) + print(f"\n✅ 完整文件处理:") + print(f" 保留数量: {kept_complete} ({kept_complete/self.stats['complete_files']*100:.1f}%)") + print(f" 删除数量: {deleted_complete} ({deleted_complete/self.stats['complete_files']*100:.1f}%)") + + qa = self.stats['quality_analysis']['complete'] + if qa['avg_penalty'] > 0: + print(f" 平均惩罚分: {qa['avg_penalty']:.2f}") + print(f" 分数范围: {qa['score_range'][0]:.2f} - {qa['score_range'][1]:.2f}") + + # 总删除统计 + total_deleted = len(self.stats['deleted_incomplete_files']) + 
len(self.stats['deleted_complete_files']) + if total_deleted > 0: + print(f"\n🗑️ 总删除统计:") + print(f" 删除的不完整文件: {len(self.stats['deleted_incomplete_files'])}") + print(f" 删除的完整文件: {len(self.stats['deleted_complete_files'])}") + print(f" 总删除数量: {total_deleted}") + + # 删除的case信息统计 + if self.stats['deleted_case_indices']: + print(f"\n📋 删除的Case统计:") + print(f" 删除的case数量: {len(self.stats['deleted_case_indices'])}") + print(f" 删除的case索引: {self.stats['deleted_case_indices'][:10]}{'...' if len(self.stats['deleted_case_indices']) > 10 else ''}") + + # 按科室统计删除的case + if self.stats['deleted_cases_info']: + dept_stats = {} + for case_info in self.stats['deleted_cases_info']: + dept_key = f"{case_info['primary_department']}-{case_info['secondary_department']}" + dept_stats[dept_key] = dept_stats.get(dept_key, 0) + 1 + + print(f"\n 按科室统计删除的case:") + for dept, count in sorted(dept_stats.items(), key=lambda x: x[1], reverse=True)[:10]: + print(f" {dept}: {count}个") + if len(dept_stats) > 10: + print(f" ... 以及其他 {len(dept_stats) - 10} 个科室") + + # Dataset备份信息 + if self.stats['dataset_backup_path']: + print(f"\n💾 Dataset备份:") + print(f" 备份文件: {self.stats['dataset_backup_path']}") + + # 错误文件 if self.stats['error_files']: - print("\n处理错误的文件:") - for file in self.stats['error_files']: - print(f" - {file}") + print(f"\n⚠️ 处理错误的文件 ({len(self.stats['error_files'])})个:") + for file in self.stats['error_files'][:5]: # 只显示前5个 + print(f" - {file}") + if len(self.stats['error_files']) > 5: + print(f" ... 以及其他 {len(self.stats['error_files'])-5} 个文件") - if self.dry_run and self.stats['deleted_files']: - print(f"\n注意: 这是试运行模式,实际上没有删除任何文件") + # 数据一致性验证结果 + if 'validation_results' in self.stats: + validation = self.stats['validation_results'] + print(f"\n🔍 数据一致性验证:") + print(f" Dataset大小: {validation.get('dataset_size', 0)}") + print(f" JSONL文件数: {validation.get('total_jsonl_files', 0)}") + print(f" 有效映射数: {validation.get('valid_case_mappings', 0)}") + print(f" 一致性率: {validation.get('consistency_rate', 0):.2%}") + print(f" 验证状态: {'✅ 通过' if validation.get('validation_passed', False) else '❌ 未通过'}") + + if validation.get('missing_case_indices'): + missing_count = len(validation['missing_case_indices']) + print(f" 缺失索引: {missing_count}个 {validation['missing_case_indices'][:5]}{'...' 
if missing_count > 5 else ''}") + + if validation.get('invalid_case_mappings'): + invalid_count = len(validation['invalid_case_mappings']) + print(f" 无效映射: {invalid_count}个") + + if self.dry_run: + print(f"\n💡 注意: 这是试运行模式,实际上没有删除任何文件") + + # 质量分析建议 + print(f"\n🔍 质量分析建议:") + incomplete_avg = self.stats['quality_analysis']['incomplete']['avg_penalty'] + complete_avg = self.stats['quality_analysis']['complete']['avg_penalty'] + + if incomplete_avg > complete_avg: + print(f" - 不完整文件的平均质量较低,建议优化工作流执行") + else: + print(f" - 完整文件中仍有质量问题,建议加强质量控制") + + if incomplete_avg > 3.0: + print(f" - 不完整文件质量分数偏高,建议检查中断原因") + + if complete_avg > 2.0: + print(f" - 完整文件质量有待提升,建议优化评估标准") def run(self) -> Dict[str, Any]: """ @@ -160,11 +877,26 @@ class WorkflowFileCleaner: Returns: Dict: 包含统计信息的字典 """ - logger.info(f"开始检查目录: {self.directory}") + logger.info(f"🚀 开始智能分析目录: {self.directory}") + logger.info(f"📋 清理策略: 保留{self.keep_incomplete_ratio*100:.0f}%最优不完整文件,删除{self.remove_complete_ratio*100:.0f}%最差完整文件") if self.dry_run: - logger.info("运行在试运行模式") + logger.info("🧪 运行在试运行模式") - self.scan_and_clean_files() + # 执行数据一致性验证 + logger.info("🔍 执行数据一致性验证...") + validation_results = self.validate_data_consistency() + self.stats['validation_results'] = validation_results + + if not validation_results.get('validation_passed', False): + logger.warning(f"⚠️ 数据一致性验证未通过: 一致性率 {validation_results.get('consistency_rate', 0):.2%}") + if validation_results.get('missing_case_indices'): + logger.warning(f" 缺失的case索引: {validation_results['missing_case_indices'][:10]}{'...' if len(validation_results['missing_case_indices']) > 10 else ''}") + if validation_results.get('invalid_case_mappings'): + logger.warning(f" 无效的case映射: {len(validation_results['invalid_case_mappings'])} 个") + else: + logger.info("✅ 数据一致性验证通过") + + self.analyze_and_clean_files() self.print_summary() return self.stats @@ -172,15 +904,33 @@ class WorkflowFileCleaner: def main(): """主函数""" - parser = argparse.ArgumentParser(description='工作流文件清理器') - parser.add_argument('directory', nargs='?', default='results/results0903', - help='要检查的目录路径 (默认: results)') + parser = argparse.ArgumentParser(description='基于质量评估的智能工作流文件清理器') + parser.add_argument('directory', nargs='?', default='results/results0905-2', + help='要检查的目录路径 (默认: results/results0903)') parser.add_argument('--dry-run', action='store_true', help='试运行模式,不实际删除文件') + parser.add_argument('--keep-incomplete', type=float, default=0.1, + help='不完整文件保留比例 (默认: 0.1, 即10%%)') + parser.add_argument('--remove-complete', type=float, default=0.2, + help='完整文件删除比例 (默认: 0.2, 即20%%)') args = parser.parse_args() - cleaner = WorkflowFileCleaner(args.directory, args.dry_run) + # 参数验证 + if not (0.0 <= args.keep_incomplete <= 1.0): + logger.error("--keep-incomplete 参数必须在 0.0 到 1.0 之间") + return + + if not (0.0 <= args.remove_complete <= 1.0): + logger.error("--remove-complete 参数必须在 0.0 到 1.0 之间") + return + + cleaner = IntelligentWorkflowCleaner( + directory=args.directory, + dry_run=args.dry_run, + keep_incomplete_ratio=args.keep_incomplete, + remove_complete_ratio=args.remove_complete + ) cleaner.run() diff --git a/config.py b/config.py index 1a9d2e2..8b3aaad 100755 --- a/config.py +++ b/config.py @@ -32,6 +32,14 @@ LLM_CONFIG = { "api_key": "gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c" # Ollama不需要真实API密钥,任意字符串即可 } }, + "Gemma3-4b": { + "class": "OpenAILike", + "params": { + "id": "gemma-3-4b-it", + "base_url": "http://100.82.33.121:19090/v1", # Ollama OpenAI兼容端点 + "api_key": 
"gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c" # Ollama不需要真实API密钥,任意字符串即可 + } + }, "deepseek-v3": { "class": "OpenAILike", "params": { diff --git a/main.py b/main.py index 75a5c03..cf06117 100755 --- a/main.py +++ b/main.py @@ -101,7 +101,7 @@ def parse_arguments() -> argparse.Namespace: parser.add_argument( '--log-dir', type=str, - default='results/results0905-2', + default='results/results0905-2-qwen3', help='日志文件保存目录' ) parser.add_argument( @@ -149,7 +149,7 @@ def parse_arguments() -> argparse.Namespace: '--model-type', type=str, choices=available_models, - default='gpt-oss:latest', + default='Qwen3-7B', help=f'使用的语言模型类型,可选: {", ".join(available_models)}' ) parser.add_argument(