diff --git a/config.py b/config.py
index 0e9a32d..022bf96 100755
--- a/config.py
+++ b/config.py
@@ -19,9 +19,9 @@ LLM_CONFIG = {
     "gpt-oss:latest": {
         "class": "OpenAILike",
         "params": {
-            "id": "gpt-oss-20b",
-            "base_url": "http://100.82.33.121:11001/v1",  # Ollama OpenAI-compatible endpoint
-            "api_key": "ollama"  # Ollama does not need a real API key; any string works
+            "id": "gpt-oss",
+            "base_url": "http://100.82.33.121:19090/v1",  # GPUStack OpenAI-compatible endpoint
+            "api_key": "gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c"  # GPUStack API key
         }
     },
     "deepseek-v3": {
diff --git a/main.py b/main.py
index 9027427..56c84e2 100755
--- a/main.py
+++ b/main.py
@@ -12,6 +12,7 @@ import os
 import sys
 import time
 import threading
+import glob
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from datetime import datetime
 from typing import Dict, Any, List, Optional
@@ -28,6 +29,7 @@ class BatchProcessor:
         self.processed_count = 0  # number of samples processed
         self.success_count = 0  # number of successful samples
         self.failed_count = 0  # number of failed samples
+        self.skipped_count = 0  # number of skipped samples
         self.results = []  # result list
         self.failed_samples = []  # failed-sample list
         self.start_time = None  # start time
@@ -49,6 +51,12 @@
             'error': str(error),
             'timestamp': datetime.now().isoformat()
         })
+
+    def update_skipped(self, sample_index: int):
+        """Thread-safely increment the skipped-sample counter."""
+        with self.lock:
+            self.skipped_count += 1
+            logging.info(f"Sample {sample_index} already completed, skipping")
 
     def get_progress_stats(self) -> Dict[str, Any]:
         """Return the current progress statistics."""
@@ -58,6 +66,7 @@
             'processed': self.processed_count,
             'success': self.success_count,
             'failed': self.failed_count,
+            'skipped': self.skipped_count,
             'success_rate': self.success_count / max(self.processed_count, 1),
             'elapsed_time': elapsed_time,
             'samples_per_minute': self.processed_count / max(elapsed_time / 60, 0.01)
         }
@@ -91,7 +100,7 @@ def parse_arguments() -> argparse.Namespace:
     parser.add_argument(
         '--log-dir',
         type=str,
-        default='results/results0902',
+        default='results/results0904',
         help='Directory where log files are saved'
     )
     parser.add_argument(
@@ -105,7 +114,7 @@
     parser.add_argument(
         '--num-threads',
         type=int,
-        default=40,
+        default=60,
         help='Number of parallel worker threads'
     )
     parser.add_argument(
@@ -123,7 +132,7 @@
     parser.add_argument(
         '--end-index',
         type=int,
-        default=120,
+        default=5000,
         help='Index at which to stop processing (exclusive)'
     )
     parser.add_argument(
@@ -170,6 +179,80 @@
     return parser.parse_args()
 
 
+def is_case_completed(log_dir: str, case_index: int) -> bool:
+    """
+    Check whether the given case has already completed its workflow.
+    Incomplete files are deleted, so each case appears in the directory
+    at most once.
+
+    Args:
+        log_dir: log directory
+        case_index: case number
+
+    Returns:
+        bool: True if the case is complete, False otherwise
+    """
+    # Build the file name pattern: workflow_*_case_{case_index:04d}.jsonl
+    pattern = os.path.join(log_dir, f"workflow_*_case_{case_index:04d}.jsonl")
+    matching_files = glob.glob(pattern)
+
+    if not matching_files:
+        return False
+
+    # There should be exactly one matching file
+    if len(matching_files) > 1:
+        logging.warning(f"Multiple files match case {case_index}: {matching_files}")
+
+    # Inspect each matching file
+    for log_file in matching_files:
+        try:
+            with open(log_file, 'r', encoding='utf-8') as f:
+                # Read the last line
+                lines = f.readlines()
+                if not lines:
+                    # Empty file: delete it
+                    os.remove(log_file)
+                    logging.info(f"Deleted empty file: {log_file}")
+                    continue
+
+                last_line = lines[-1].strip()
+                if not last_line:
+                    # Last line is blank: delete the file
+                    os.remove(log_file)
+                    logging.info(f"Deleted file with blank last line: {log_file}")
+                    continue
+
+                # Parse the last line as JSON
+                try:
+                    last_entry = json.loads(last_line)
+                    if last_entry.get("event_type") == "workflow_complete":
+                        # Found a complete file
+                        logging.info(f"Found completed case {case_index}: {log_file}")
+                        return True
+                    else:
+                        # File is incomplete: delete it
+                        os.remove(log_file)
+                        logging.info(f"Deleted incomplete file: {log_file}")
+                        continue
+
+                except json.JSONDecodeError:
+                    # JSON parsing failed: delete the file
+                    os.remove(log_file)
+                    logging.info(f"Deleted file with malformed JSON: {log_file}")
+                    continue
+
+        except Exception as e:
+            logging.warning(f"Error while checking file {log_file}: {e}")
+            # Delete the file on error as well, to avoid trouble later
+            try:
+                os.remove(log_file)
+                logging.info(f"Deleted unreadable file: {log_file}")
+            except OSError:
+                pass
+            continue
+
+    # All matching files were deleted, or none was complete
+    return False
+
+
 def load_dataset(dataset_path: str, start_index: int = 0,
                  end_index: Optional[int] = None,
                  sample_limit: Optional[int] = None) -> List[Dict[str, Any]]:
@@ -306,9 +389,10 @@ def print_progress_report(processor: BatchProcessor, total_samples: int):
 
     print(f"\n=== Progress Report ===")
     print(f"Processed: {stats['processed']}/{total_samples} ({stats['processed']/total_samples:.1%})")
-    print(f"Success: {stats['success']} | Failed: {stats['failed']} | Success rate: {stats['success_rate']:.1%}")
+    print(f"Success: {stats['success']} | Failed: {stats['failed']} | Skipped: {stats['skipped']} | Success rate: {stats['success_rate']:.1%}")
     print(f"Elapsed: {stats['elapsed_time']:.1f}s | Rate: {stats['samples_per_minute']:.1f} samples/min")
-    print(f"Estimated time remaining: {(total_samples - stats['processed']) / max(stats['samples_per_minute'] / 60, 0.01):.1f}s")
+    remaining_samples = total_samples - stats['processed'] - stats['skipped']
+    print(f"Estimated time remaining: {remaining_samples / max(stats['samples_per_minute'] / 60, 0.01):.1f}s")
     print("=" * 50)
 
 def run_workflow_batch(dataset: List[Dict[str, Any]], args: argparse.Namespace) -> Dict[str, Any]:
@@ -326,9 +410,9 @@
 
     # Start the progress-monitoring thread
     def progress_monitor():
-        while processor.processed_count < total_samples:
+        while processor.processed_count + processor.skipped_count < total_samples:
             time.sleep(args.progress_interval)
-            if processor.processed_count < total_samples:
+            if processor.processed_count + processor.skipped_count < total_samples:
                 print_progress_report(processor, total_samples)
 
     progress_thread = threading.Thread(target=progress_monitor, daemon=True)
@@ -341,6 +425,12 @@
         future_to_index = {}
         for i, sample_data in enumerate(dataset):
            sample_index = args.start_index + i
+
+            # Skip the case if it has already completed
+            if is_case_completed(args.log_dir, sample_index):
+                processor.update_skipped(sample_index)
+                continue
+
             future = executor.submit(
                 process_single_sample,
                 sample_data,
@@ -375,6 +465,7 @@
         'processed_samples': processor.processed_count,
         'successful_samples': processor.success_count,
         'failed_samples': processor.failed_count,
+        'skipped_samples': processor.skipped_count,
         'success_rate': stats['success_rate'],
         'total_execution_time': total_time,
         'average_time_per_sample': total_time / max(processor.processed_count, 1),
@@ -428,6 +519,7 @@
         f.write(f"Samples processed: {summary['processed_samples']}\n")
         f.write(f"Successful samples: {summary['successful_samples']}\n")
         f.write(f"Failed samples: {summary['failed_samples']}\n")
+        f.write(f"Skipped samples: {summary['skipped_samples']}\n")
         f.write(f"Success rate: {summary['success_rate']:.2%}\n")
         f.write(f"Total execution time: {summary['total_execution_time']:.2f} s\n")
         f.write(f"Average time per sample: {summary['average_time_per_sample']:.2f} s/sample\n")
diff --git a/workflow/step_executor.py b/workflow/step_executor.py
index 95977a6..0098d6e 100755
--- a/workflow/step_executor.py
+++ b/workflow/step_executor.py
@@ -21,7 +21,7 @@ class StepExecutor:
     _global_historical_scores = {
         "clinical_inquiry": 0.0,
         "communication_quality": 0.0,
-        "multi_round_consistency": 0.0,
+        "information_completeness": 0.0,
         "overall_professionalism": 0.0,
         "present_illness_similarity": 0.0,
         "past_history_similarity": 0.0,
@@ -34,7 +34,7 @@
         cls._global_historical_scores = {
             "clinical_inquiry": 0.0,
             "communication_quality": 0.0,
-            "multi_round_consistency": 0.0,
+            "information_completeness": 0.0,
             "overall_professionalism": 0.0,
             "present_illness_similarity": 0.0,
             "past_history_similarity": 0.0,
@@ -545,7 +545,7 @@
         round_data["evaluation_scores"] = {
             "clinical_inquiry": 0.0,
             "communication_quality": 0.0,
-            "multi_round_consistency": 0.0,
+            "information_completeness": 0.0,
             "overall_professionalism": 0.0,
             "present_illness_similarity": 0.0,
             "past_history_similarity": 0.0,
@@ -571,9 +571,9 @@
                 "score": result.communication_quality.score,
                 "comment": result.communication_quality.comment
             },
-            "multi_round_consistency": {
-                "score": result.multi_round_consistency.score,
-                "comment": result.multi_round_consistency.comment
+            "information_completeness": {
+                "score": result.information_completeness.score,
+                "comment": result.information_completeness.comment
             },
             "overall_professionalism": {
                 "score": result.overall_professionalism.score,
@@ -601,7 +601,7 @@
         self._global_historical_scores = {
             "clinical_inquiry": result.clinical_inquiry.score,
             "communication_quality": result.communication_quality.score,
-            "multi_round_consistency": result.multi_round_consistency.score,
+            "information_completeness": result.information_completeness.score,
             "overall_professionalism": result.overall_professionalism.score,
             "present_illness_similarity": result.present_illness_similarity.score,
             "past_history_similarity": result.past_history_similarity.score,
@@ -620,7 +620,7 @@
         return EvaluatorResult(
             clinical_inquiry=default_dimension,
             communication_quality=default_dimension,
-            multi_round_consistency=default_dimension,
+            information_completeness=default_dimension,
             overall_professionalism=default_dimension,
             present_illness_similarity=default_dimension,
             past_history_similarity=default_dimension,
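
Note on the resume logic added to main.py: is_case_completed treats a case's JSONL log as complete only when its final line is a workflow_complete event; anything else (empty file, blank or malformed last line, any other event type) causes the file to be deleted so the case is re-run cleanly. A minimal sketch of that behavior, assuming main.py imports cleanly as a module; the file names and the round_complete event below are hypothetical, chosen only to match the workflow_*_case_{index:04d}.jsonl pattern:

    # Sketch: exercise is_case_completed() from main.py (assumed importable).
    import json
    import os
    import tempfile

    from main import is_case_completed

    with tempfile.TemporaryDirectory() as log_dir:
        done = os.path.join(log_dir, "workflow_demo_case_0001.jsonl")
        with open(done, "w", encoding="utf-8") as f:
            f.write(json.dumps({"event_type": "round_complete"}) + "\n")
            f.write(json.dumps({"event_type": "workflow_complete"}) + "\n")

        partial = os.path.join(log_dir, "workflow_demo_case_0002.jsonl")
        with open(partial, "w", encoding="utf-8") as f:
            f.write(json.dumps({"event_type": "round_complete"}) + "\n")

        print(is_case_completed(log_dir, 1))  # True: sentinel is the last line
        print(is_case_completed(log_dir, 2))  # False: log ends without sentinel
        print(os.path.exists(partial))        # False: cleanup deleted the file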
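
The writer side of this protocol is not part of the diff. For the check to be sound, whatever produces the per-case logs must append the workflow_complete sentinel strictly last. A hypothetical sketch of a conforming writer; append_event, the path, and the event fields are illustrative, not the project's actual logger:

    # Hypothetical conforming writer: one JSON object per line, sentinel last.
    import json
    import os

    def append_event(log_file: str, event: dict) -> None:
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(event, ensure_ascii=False) + "\n")

    log_file = "results/results0904/workflow_demo_case_0003.jsonl"  # example name
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    append_event(log_file, {"event_type": "round_complete", "round": 1})
    append_event(log_file, {"event_type": "workflow_complete"})

A crash before the sentinel leaves at worst a truncated last line, which the JSONDecodeError branch in is_case_completed also detects and cleans up on the next run.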