增强任务控制器智能体模式和模型配置

- 新增TaskController简化模式和分数驱动模式支持
- 添加phi4模型配置选项
- 优化主程序参数配置和默认设置
- 完善工作流和步骤执行器功能
- 更新.gitignore忽略规则

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
iomgaa 2025-09-10 23:59:41 +08:00
parent a4eca4897d
commit cc59034f1d
6 changed files with 158 additions and 13 deletions

8
.gitignore vendored
View File

@ -53,4 +53,10 @@ results/
file_analyzer.py
file_deleter_renumber.py
files_to_delete.txt
files_to_delete.txt
output0902-qwen3
.claude
.nfs*

View File

@ -20,16 +20,25 @@ class TaskController(BaseAgent):
Attributes:
model_type (str): 使用的大语言模型类型默认为 gpt-oss:latest
llm_config (dict): LLM模型配置参数
simple_mode (bool): 简化模式标志True时自动选择第一个任务并返回固定指导
"""
def __init__(self, model_type: str = "gpt-oss:latest", llm_config: dict = None):
def __init__(self, model_type: str = "gpt-oss:latest", llm_config: dict = None,
simple_mode: bool = False, score_driven_mode: bool = False):
"""
初始化任务控制器智能体
Args:
model_type (str): 大语言模型类型默认使用 gpt-oss:latest
llm_config (dict): LLM模型的配置参数如果为None则使用默认配置
simple_mode (bool): 简化模式如果为True则自动选择第一个任务并返回固定指导默认为False
score_driven_mode (bool): 分数驱动模式如果为True则选择当前任务组中分数最低的任务默认为False
Note:
score_driven_mode和simple_mode不能同时为True如果同时为True则优先使用score_driven_mode
"""
self.simple_mode = simple_mode
self.score_driven_mode = score_driven_mode
super().__init__(
model_type=model_type,
description="医疗任务控制器,负责任务选择和预问诊询问指导",
@ -46,7 +55,8 @@ class TaskController(BaseAgent):
chief_complaint: str,
hpi_content: str = "",
ph_content: str = "",
additional_info: str = "") -> ControllerDecision:
additional_info: str = "",
task_manager = None) -> ControllerDecision:
"""
执行任务控制决策
@ -69,6 +79,14 @@ class TaskController(BaseAgent):
Exception: 当LLM调用失败时返回包含默认信息的ControllerDecision
"""
try:
# 分数驱动模式:选择当前任务组中分数最低的任务
if self.score_driven_mode and task_manager is not None:
return self._get_score_driven_result(pending_tasks, task_manager)
# 简化模式:直接选择第一个任务并返回固定指导
elif self.simple_mode:
return self._get_simple_mode_result(pending_tasks)
# 构建决策提示词
prompt = self._build_decision_prompt(
pending_tasks, chief_complaint, hpi_content, ph_content, additional_info
@ -103,6 +121,90 @@ class TaskController(BaseAgent):
# 如果类型不匹配,返回默认结果
return self._get_fallback_result([])
def _get_score_driven_result(self, pending_tasks: List[Dict[str, str]], task_manager) -> ControllerDecision:
"""
分数驱动模式下生成决策结果
在分数驱动模式下从当前任务组的未完成任务中选择分数最低的任务
并返回相应的询问指导这是基于数值比较的算法选择无需LLM参与
Args:
pending_tasks (List[Dict[str, str]]): 待执行的任务列表包含name和description
task_manager: 任务管理器实例用于获取任务分数信息
Returns:
ControllerDecision: 包含分数驱动模式任务选择和指导的结果
"""
if not pending_tasks:
return ControllerDecision(
selected_task="基本信息收集",
specific_guidance="当前没有待执行任务,请按照标准医疗询问流程进行患者评估。"
)
# 获取当前任务阶段
current_phase = task_manager.get_current_phase()
# 获取当前阶段的任务分数
phase_scores = task_manager.get_task_scores(current_phase)
# 在待执行任务中找到分数最低的任务
lowest_score_task = None
lowest_score = float('inf')
for task in pending_tasks:
task_name = task.get("name", "")
task_score = phase_scores.get(task_name, 0.0)
if task_score < lowest_score:
lowest_score = task_score
lowest_score_task = task
# 如果没有找到合适的任务,选择第一个并记录错误日志
if lowest_score_task is None:
# 使用logger记录错误如果没有logger则使用print作为后备
error_msg = f"Controller-ScoreDriven警告在阶段{current_phase.value}中未找到合适任务,使用默认第一个任务"
try:
import logging
logger = logging.getLogger(__name__)
logger.error(error_msg)
except:
print(f"[ERROR] {error_msg}")
lowest_score_task = pending_tasks[0]
lowest_score = phase_scores.get(lowest_score_task.get("name", ""), 0.0)
selected_task_name = lowest_score_task.get("name", "未知任务")
# 使用和simple模式相同的固定指导
return ControllerDecision(
selected_task=selected_task_name,
specific_guidance="请按照标准医疗询问流程进行患者评估,基于患者临床信息选择最重要的询问任务,提供针对性的、具体的、可操作的询问指导建议,确保指导内容仅限于医生可以通过询问获取的信息。"
)
def _get_simple_mode_result(self, pending_tasks: List[Dict[str, str]]) -> ControllerDecision:
"""
简化模式下生成决策结果
在简化模式下直接选择第一个待执行任务并返回固定的询问指导
Args:
pending_tasks (List[Dict[str, str]]): 待执行的任务列表
Returns:
ControllerDecision: 包含简化模式任务选择和固定指导的结果
"""
# 如果有待执行任务,选择第一个作为默认任务
if pending_tasks:
selected_task = pending_tasks[0]
selected_task_name = selected_task.get("name", "未知任务")
else:
selected_task_name = "基本信息收集"
return ControllerDecision(
selected_task=selected_task_name,
specific_guidance="请按照标准医疗询问流程进行患者评估,基于患者临床信息选择最重要的询问任务,提供针对性的、具体的、可操作的询问指导建议,确保指导内容仅限于医生可以通过询问获取的信息。"
)
def _get_fallback_result(self, pending_tasks: List[Dict[str, str]]) -> ControllerDecision:
"""
生成决策失败时的默认结果

View File

@ -24,6 +24,14 @@ LLM_CONFIG = {
"api_key": "gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c" # Ollama不需要真实API密钥任意字符串即可
}
},
"phi4": {
"class": "OpenAILike",
"params": {
"id": "microsoft/phi-4",
"base_url": "http://127.0.0.1:8000/v1", # Ollama OpenAI兼容端点
"api_key": "gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c" # Ollama不需要真实API密钥任意字符串即可
}
},
"Qwen3-7B": {
"class": "OpenAILike",
"params": {

16
main.py
View File

@ -101,7 +101,7 @@ def parse_arguments() -> argparse.Namespace:
parser.add_argument(
'--log-dir',
type=str,
default='results/results0905-gemma3',
default='results/results09010-score_driven',
help='日志文件保存目录'
)
parser.add_argument(
@ -115,7 +115,7 @@ def parse_arguments() -> argparse.Namespace:
parser.add_argument(
'--num-threads',
type=int,
default=40,
default=45,
help='并行处理线程数'
)
parser.add_argument(
@ -149,7 +149,7 @@ def parse_arguments() -> argparse.Namespace:
'--model-type',
type=str,
choices=available_models,
default='Gemma3-4b',
default='openai-mirror/gpt-oss-20b',
help=f'使用的语言模型类型,可选: {", ".join(available_models)}'
)
parser.add_argument(
@ -163,6 +163,13 @@ def parse_arguments() -> argparse.Namespace:
default=None,
help='模型配置JSON字符串可选覆盖默认配置'
)
parser.add_argument(
'--controller-mode',
type=str,
choices=['normal', 'sequence', 'score_driven'],
default='score_driven',
help='任务控制器模式normal为智能模式需要LLM推理sequence为顺序模式直接选择第一个任务score_driven为分数驱动模式选择当前任务组中分数最低的任务'
)
# 调试和日志
@ -347,7 +354,8 @@ def process_single_sample(sample_data: Dict[str, Any], sample_index: int,
llm_config=llm_config,
max_steps=args.max_steps,
log_dir=args.log_dir,
case_index=sample_index
case_index=sample_index,
controller_mode=args.controller_mode
)
# 执行工作流

View File

@ -12,7 +12,7 @@ class MedicalWorkflow:
def __init__(self, case_data: Dict[str, Any], model_type: str = "gpt-oss:latest",
llm_config: Optional[Dict] = None, max_steps: int = 30, log_dir: str = "logs",
case_index: Optional[int] = None):
case_index: Optional[int] = None, controller_mode: str = "normal"):
"""
初始化医疗问诊工作流
@ -23,6 +23,7 @@ class MedicalWorkflow:
max_steps: 最大执行步数默认为30
log_dir: 日志目录默认为"logs"
case_index: 病例序号用于日志文件命名
controller_mode: 任务控制器模式'normal'为智能模式'sequence'为顺序模式'score_driven'为分数驱动模式
"""
self.case_data = case_data
self.model_type = model_type
@ -31,7 +32,7 @@ class MedicalWorkflow:
# 初始化核心组件
self.task_manager = TaskManager()
self.step_executor = StepExecutor(model_type=model_type, llm_config=self.llm_config)
self.step_executor = StepExecutor(model_type=model_type, llm_config=self.llm_config, controller_mode=controller_mode)
self.logger = WorkflowLogger(case_data=case_data, log_dir=log_dir, case_index=case_index)
# 重置历史评分,确保新的工作流从零开始

View File

@ -41,25 +41,35 @@ class StepExecutor:
"chief_complaint_similarity": 0.0
}
def __init__(self, model_type: str = "gpt-oss:latest", llm_config: dict = None):
def __init__(self, model_type: str = "gpt-oss:latest", llm_config: dict = None, controller_mode: str = "normal"):
"""
初始化step执行器
Args:
model_type: 使用的语言模型类型除Evaluator外的所有agent使用
llm_config: 语言模型配置
controller_mode: 任务控制器模式'normal'为智能模式'sequence'为顺序模式'score_driven'为分数驱动模式
Note:
Evaluator agent 固定使用 gpt-oss:latest 模型不受 model_type 参数影响
"""
self.model_type = model_type
self.llm_config = llm_config or {}
self.controller_mode = controller_mode
# 初始化所有agent
self.recipient = RecipientAgent(model_type=model_type, llm_config=self.llm_config)
self.triager = TriageAgent(model_type=model_type, llm_config=self.llm_config)
self.monitor = Monitor(model_type=model_type, llm_config=self.llm_config)
self.controller = TaskController(model_type=model_type, llm_config=self.llm_config)
# 根据模式初始化TaskController
simple_mode = (controller_mode == "sequence")
score_driven_mode = (controller_mode == "score_driven")
self.controller = TaskController(
model_type=model_type,
llm_config=self.llm_config,
simple_mode=simple_mode,
score_driven_mode=score_driven_mode
)
self.prompter = Prompter(model_type=model_type, llm_config=self.llm_config)
self.virtual_patient = VirtualPatientAgent(model_type=model_type, llm_config=self.llm_config)
# Evaluator 固定使用 gpt-oss:latest 模型
@ -402,18 +412,28 @@ class StepExecutor:
"pending_tasks": pending_tasks,
"chief_complaint": recipient_result.chief_complaint,
"hpi_content": recipient_result.updated_HPI,
"ph_content": recipient_result.updated_PH
"ph_content": recipient_result.updated_PH,
"task_manager": task_manager # 传递task_manager用于score_driven模式
}
result = self.controller.run(**input_data)
execution_time = time.time() - start_time
# 为日志记录创建可序列化的input_data副本移除TaskManager对象
log_input_data = {
"pending_tasks": input_data["pending_tasks"],
"chief_complaint": input_data["chief_complaint"],
"hpi_content": input_data["hpi_content"],
"ph_content": input_data["ph_content"]
# 不包含task_manager因为它不能JSON序列化
}
output_data = {
"selected_task": result.selected_task,
"specific_guidance": result.specific_guidance
}
logger.log_agent_execution(step_num, "controller", input_data, output_data, execution_time)
logger.log_agent_execution(step_num, "controller", log_input_data, output_data, execution_time)
return result