diff --git a/.gitignore b/.gitignore index adc90da..e03e71f 100644 --- a/.gitignore +++ b/.gitignore @@ -53,4 +53,10 @@ results/ file_analyzer.py file_deleter_renumber.py -files_to_delete.txt \ No newline at end of file +files_to_delete.txt + +output0902-qwen3 +.claude + + +.nfs* diff --git a/agent_system/controller/agent.py b/agent_system/controller/agent.py index 819b3b2..36f7f6e 100644 --- a/agent_system/controller/agent.py +++ b/agent_system/controller/agent.py @@ -20,16 +20,25 @@ class TaskController(BaseAgent): Attributes: model_type (str): 使用的大语言模型类型,默认为 gpt-oss:latest llm_config (dict): LLM模型配置参数 + simple_mode (bool): 简化模式标志,True时自动选择第一个任务并返回固定指导 """ - def __init__(self, model_type: str = "gpt-oss:latest", llm_config: dict = None): + def __init__(self, model_type: str = "gpt-oss:latest", llm_config: dict = None, + simple_mode: bool = False, score_driven_mode: bool = False): """ 初始化任务控制器智能体 Args: model_type (str): 大语言模型类型,默认使用 gpt-oss:latest llm_config (dict): LLM模型的配置参数,如果为None则使用默认配置 + simple_mode (bool): 简化模式,如果为True则自动选择第一个任务并返回固定指导,默认为False + score_driven_mode (bool): 分数驱动模式,如果为True则选择当前任务组中分数最低的任务,默认为False + + Note: + score_driven_mode和simple_mode不能同时为True,如果同时为True则优先使用score_driven_mode """ + self.simple_mode = simple_mode + self.score_driven_mode = score_driven_mode super().__init__( model_type=model_type, description="医疗任务控制器,负责任务选择和预问诊询问指导", @@ -46,7 +55,8 @@ class TaskController(BaseAgent): chief_complaint: str, hpi_content: str = "", ph_content: str = "", - additional_info: str = "") -> ControllerDecision: + additional_info: str = "", + task_manager = None) -> ControllerDecision: """ 执行任务控制决策 @@ -69,6 +79,14 @@ class TaskController(BaseAgent): Exception: 当LLM调用失败时,返回包含默认信息的ControllerDecision """ try: + # 分数驱动模式:选择当前任务组中分数最低的任务 + if self.score_driven_mode and task_manager is not None: + return self._get_score_driven_result(pending_tasks, task_manager) + + # 简化模式:直接选择第一个任务并返回固定指导 + elif self.simple_mode: + return self._get_simple_mode_result(pending_tasks) + # 构建决策提示词 prompt = self._build_decision_prompt( pending_tasks, chief_complaint, hpi_content, ph_content, additional_info @@ -103,6 +121,90 @@ class TaskController(BaseAgent): # 如果类型不匹配,返回默认结果 return self._get_fallback_result([]) + def _get_score_driven_result(self, pending_tasks: List[Dict[str, str]], task_manager) -> ControllerDecision: + """ + 分数驱动模式下生成决策结果 + + 在分数驱动模式下,从当前任务组的未完成任务中选择分数最低的任务, + 并返回相应的询问指导。这是基于数值比较的算法选择,无需LLM参与。 + + Args: + pending_tasks (List[Dict[str, str]]): 待执行的任务列表,包含name和description + task_manager: 任务管理器实例,用于获取任务分数信息 + + Returns: + ControllerDecision: 包含分数驱动模式任务选择和指导的结果 + """ + if not pending_tasks: + return ControllerDecision( + selected_task="基本信息收集", + specific_guidance="当前没有待执行任务,请按照标准医疗询问流程进行患者评估。" + ) + + # 获取当前任务阶段 + current_phase = task_manager.get_current_phase() + + # 获取当前阶段的任务分数 + phase_scores = task_manager.get_task_scores(current_phase) + + # 在待执行任务中找到分数最低的任务 + lowest_score_task = None + lowest_score = float('inf') + + for task in pending_tasks: + task_name = task.get("name", "") + task_score = phase_scores.get(task_name, 0.0) + + if task_score < lowest_score: + lowest_score = task_score + lowest_score_task = task + + # 如果没有找到合适的任务,选择第一个并记录错误日志 + if lowest_score_task is None: + # 使用logger记录错误,如果没有logger则使用print作为后备 + error_msg = f"Controller-ScoreDriven警告:在阶段{current_phase.value}中未找到合适任务,使用默认第一个任务" + try: + import logging + logger = logging.getLogger(__name__) + logger.error(error_msg) + except: + print(f"[ERROR] {error_msg}") + + lowest_score_task = pending_tasks[0] + lowest_score = phase_scores.get(lowest_score_task.get("name", ""), 0.0) + + selected_task_name = lowest_score_task.get("name", "未知任务") + + # 使用和simple模式相同的固定指导 + return ControllerDecision( + selected_task=selected_task_name, + specific_guidance="请按照标准医疗询问流程进行患者评估,基于患者临床信息选择最重要的询问任务,提供针对性的、具体的、可操作的询问指导建议,确保指导内容仅限于医生可以通过询问获取的信息。" + ) + + def _get_simple_mode_result(self, pending_tasks: List[Dict[str, str]]) -> ControllerDecision: + """ + 简化模式下生成决策结果 + + 在简化模式下,直接选择第一个待执行任务,并返回固定的询问指导。 + + Args: + pending_tasks (List[Dict[str, str]]): 待执行的任务列表 + + Returns: + ControllerDecision: 包含简化模式任务选择和固定指导的结果 + """ + # 如果有待执行任务,选择第一个作为默认任务 + if pending_tasks: + selected_task = pending_tasks[0] + selected_task_name = selected_task.get("name", "未知任务") + else: + selected_task_name = "基本信息收集" + + return ControllerDecision( + selected_task=selected_task_name, + specific_guidance="请按照标准医疗询问流程进行患者评估,基于患者临床信息选择最重要的询问任务,提供针对性的、具体的、可操作的询问指导建议,确保指导内容仅限于医生可以通过询问获取的信息。" + ) + def _get_fallback_result(self, pending_tasks: List[Dict[str, str]]) -> ControllerDecision: """ 生成决策失败时的默认结果 diff --git a/config.py b/config.py index 8b3aaad..19c8a23 100755 --- a/config.py +++ b/config.py @@ -24,6 +24,14 @@ LLM_CONFIG = { "api_key": "gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c" # Ollama不需要真实API密钥,任意字符串即可 } }, + "phi4": { + "class": "OpenAILike", + "params": { + "id": "microsoft/phi-4", + "base_url": "http://127.0.0.1:8000/v1", # Ollama OpenAI兼容端点 + "api_key": "gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c" # Ollama不需要真实API密钥,任意字符串即可 + } + }, "Qwen3-7B": { "class": "OpenAILike", "params": { diff --git a/main.py b/main.py index 597b56a..ed36bce 100755 --- a/main.py +++ b/main.py @@ -101,7 +101,7 @@ def parse_arguments() -> argparse.Namespace: parser.add_argument( '--log-dir', type=str, - default='results/results0905-gemma3', + default='results/results09010-score_driven', help='日志文件保存目录' ) parser.add_argument( @@ -115,7 +115,7 @@ def parse_arguments() -> argparse.Namespace: parser.add_argument( '--num-threads', type=int, - default=40, + default=45, help='并行处理线程数' ) parser.add_argument( @@ -149,7 +149,7 @@ def parse_arguments() -> argparse.Namespace: '--model-type', type=str, choices=available_models, - default='Gemma3-4b', + default='openai-mirror/gpt-oss-20b', help=f'使用的语言模型类型,可选: {", ".join(available_models)}' ) parser.add_argument( @@ -163,6 +163,13 @@ def parse_arguments() -> argparse.Namespace: default=None, help='模型配置JSON字符串(可选,覆盖默认配置)' ) + parser.add_argument( + '--controller-mode', + type=str, + choices=['normal', 'sequence', 'score_driven'], + default='score_driven', + help='任务控制器模式:normal为智能模式(需要LLM推理),sequence为顺序模式(直接选择第一个任务),score_driven为分数驱动模式(选择当前任务组中分数最低的任务)' + ) # 调试和日志 @@ -347,7 +354,8 @@ def process_single_sample(sample_data: Dict[str, Any], sample_index: int, llm_config=llm_config, max_steps=args.max_steps, log_dir=args.log_dir, - case_index=sample_index + case_index=sample_index, + controller_mode=args.controller_mode ) # 执行工作流 diff --git a/workflow/medical_workflow.py b/workflow/medical_workflow.py index 751d26f..61a3346 100755 --- a/workflow/medical_workflow.py +++ b/workflow/medical_workflow.py @@ -12,7 +12,7 @@ class MedicalWorkflow: def __init__(self, case_data: Dict[str, Any], model_type: str = "gpt-oss:latest", llm_config: Optional[Dict] = None, max_steps: int = 30, log_dir: str = "logs", - case_index: Optional[int] = None): + case_index: Optional[int] = None, controller_mode: str = "normal"): """ 初始化医疗问诊工作流 @@ -23,6 +23,7 @@ class MedicalWorkflow: max_steps: 最大执行步数,默认为30 log_dir: 日志目录,默认为"logs" case_index: 病例序号,用于日志文件命名 + controller_mode: 任务控制器模式,'normal'为智能模式,'sequence'为顺序模式,'score_driven'为分数驱动模式 """ self.case_data = case_data self.model_type = model_type @@ -31,7 +32,7 @@ class MedicalWorkflow: # 初始化核心组件 self.task_manager = TaskManager() - self.step_executor = StepExecutor(model_type=model_type, llm_config=self.llm_config) + self.step_executor = StepExecutor(model_type=model_type, llm_config=self.llm_config, controller_mode=controller_mode) self.logger = WorkflowLogger(case_data=case_data, log_dir=log_dir, case_index=case_index) # 重置历史评分,确保新的工作流从零开始 diff --git a/workflow/step_executor.py b/workflow/step_executor.py index 77a5bb6..926e335 100755 --- a/workflow/step_executor.py +++ b/workflow/step_executor.py @@ -41,25 +41,35 @@ class StepExecutor: "chief_complaint_similarity": 0.0 } - def __init__(self, model_type: str = "gpt-oss:latest", llm_config: dict = None): + def __init__(self, model_type: str = "gpt-oss:latest", llm_config: dict = None, controller_mode: str = "normal"): """ 初始化step执行器 Args: model_type: 使用的语言模型类型(除Evaluator外的所有agent使用) llm_config: 语言模型配置 + controller_mode: 任务控制器模式,'normal'为智能模式,'sequence'为顺序模式,'score_driven'为分数驱动模式 Note: Evaluator agent 固定使用 gpt-oss:latest 模型,不受 model_type 参数影响 """ self.model_type = model_type self.llm_config = llm_config or {} + self.controller_mode = controller_mode # 初始化所有agent self.recipient = RecipientAgent(model_type=model_type, llm_config=self.llm_config) self.triager = TriageAgent(model_type=model_type, llm_config=self.llm_config) self.monitor = Monitor(model_type=model_type, llm_config=self.llm_config) - self.controller = TaskController(model_type=model_type, llm_config=self.llm_config) + # 根据模式初始化TaskController + simple_mode = (controller_mode == "sequence") + score_driven_mode = (controller_mode == "score_driven") + self.controller = TaskController( + model_type=model_type, + llm_config=self.llm_config, + simple_mode=simple_mode, + score_driven_mode=score_driven_mode + ) self.prompter = Prompter(model_type=model_type, llm_config=self.llm_config) self.virtual_patient = VirtualPatientAgent(model_type=model_type, llm_config=self.llm_config) # Evaluator 固定使用 gpt-oss:latest 模型 @@ -402,18 +412,28 @@ class StepExecutor: "pending_tasks": pending_tasks, "chief_complaint": recipient_result.chief_complaint, "hpi_content": recipient_result.updated_HPI, - "ph_content": recipient_result.updated_PH + "ph_content": recipient_result.updated_PH, + "task_manager": task_manager # 传递task_manager用于score_driven模式 } result = self.controller.run(**input_data) execution_time = time.time() - start_time + # 为日志记录创建可序列化的input_data副本(移除TaskManager对象) + log_input_data = { + "pending_tasks": input_data["pending_tasks"], + "chief_complaint": input_data["chief_complaint"], + "hpi_content": input_data["hpi_content"], + "ph_content": input_data["ph_content"] + # 不包含task_manager,因为它不能JSON序列化 + } + output_data = { "selected_task": result.selected_task, "specific_guidance": result.specific_guidance } - logger.log_agent_execution(step_num, "controller", input_data, output_data, execution_time) + logger.log_agent_execution(step_num, "controller", log_input_data, output_data, execution_time) return result