update
This commit is contained in:
parent
a4eca4897d
commit
cfa2fdf705
5
.gitignore
vendored
5
.gitignore
vendored
@ -53,4 +53,7 @@ results/
|
|||||||
|
|
||||||
file_analyzer.py
|
file_analyzer.py
|
||||||
file_deleter_renumber.py
|
file_deleter_renumber.py
|
||||||
files_to_delete.txt
|
files_to_delete.txt
|
||||||
|
|
||||||
|
output0902-qwen3
|
||||||
|
.claude
|
||||||
@ -20,16 +20,19 @@ class TaskController(BaseAgent):
|
|||||||
Attributes:
|
Attributes:
|
||||||
model_type (str): 使用的大语言模型类型,默认为 gpt-oss:latest
|
model_type (str): 使用的大语言模型类型,默认为 gpt-oss:latest
|
||||||
llm_config (dict): LLM模型配置参数
|
llm_config (dict): LLM模型配置参数
|
||||||
|
simple_mode (bool): 简化模式标志,True时自动选择第一个任务并返回固定指导
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_type: str = "gpt-oss:latest", llm_config: dict = None):
|
def __init__(self, model_type: str = "gpt-oss:latest", llm_config: dict = None, simple_mode: bool = False):
|
||||||
"""
|
"""
|
||||||
初始化任务控制器智能体
|
初始化任务控制器智能体
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_type (str): 大语言模型类型,默认使用 gpt-oss:latest
|
model_type (str): 大语言模型类型,默认使用 gpt-oss:latest
|
||||||
llm_config (dict): LLM模型的配置参数,如果为None则使用默认配置
|
llm_config (dict): LLM模型的配置参数,如果为None则使用默认配置
|
||||||
|
simple_mode (bool): 简化模式,如果为True则自动选择第一个任务并返回固定指导,默认为False
|
||||||
"""
|
"""
|
||||||
|
self.simple_mode = simple_mode
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
description="医疗任务控制器,负责任务选择和预问诊询问指导",
|
description="医疗任务控制器,负责任务选择和预问诊询问指导",
|
||||||
@ -69,6 +72,10 @@ class TaskController(BaseAgent):
|
|||||||
Exception: 当LLM调用失败时,返回包含默认信息的ControllerDecision
|
Exception: 当LLM调用失败时,返回包含默认信息的ControllerDecision
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# 简化模式:直接选择第一个任务并返回固定指导
|
||||||
|
if self.simple_mode:
|
||||||
|
return self._get_simple_mode_result(pending_tasks)
|
||||||
|
|
||||||
# 构建决策提示词
|
# 构建决策提示词
|
||||||
prompt = self._build_decision_prompt(
|
prompt = self._build_decision_prompt(
|
||||||
pending_tasks, chief_complaint, hpi_content, ph_content, additional_info
|
pending_tasks, chief_complaint, hpi_content, ph_content, additional_info
|
||||||
@ -103,6 +110,30 @@ class TaskController(BaseAgent):
|
|||||||
# 如果类型不匹配,返回默认结果
|
# 如果类型不匹配,返回默认结果
|
||||||
return self._get_fallback_result([])
|
return self._get_fallback_result([])
|
||||||
|
|
||||||
|
def _get_simple_mode_result(self, pending_tasks: List[Dict[str, str]]) -> ControllerDecision:
|
||||||
|
"""
|
||||||
|
简化模式下生成决策结果
|
||||||
|
|
||||||
|
在简化模式下,直接选择第一个待执行任务,并返回固定的询问指导。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pending_tasks (List[Dict[str, str]]): 待执行的任务列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ControllerDecision: 包含简化模式任务选择和固定指导的结果
|
||||||
|
"""
|
||||||
|
# 如果有待执行任务,选择第一个作为默认任务
|
||||||
|
if pending_tasks:
|
||||||
|
selected_task = pending_tasks[0]
|
||||||
|
selected_task_name = selected_task.get("name", "未知任务")
|
||||||
|
else:
|
||||||
|
selected_task_name = "基本信息收集"
|
||||||
|
|
||||||
|
return ControllerDecision(
|
||||||
|
selected_task=selected_task_name,
|
||||||
|
specific_guidance="请按照标准医疗询问流程进行患者评估,基于患者临床信息选择最重要的询问任务,提供针对性的、具体的、可操作的询问指导建议,确保指导内容仅限于医生可以通过询问获取的信息。"
|
||||||
|
)
|
||||||
|
|
||||||
def _get_fallback_result(self, pending_tasks: List[Dict[str, str]]) -> ControllerDecision:
|
def _get_fallback_result(self, pending_tasks: List[Dict[str, str]]) -> ControllerDecision:
|
||||||
"""
|
"""
|
||||||
生成决策失败时的默认结果
|
生成决策失败时的默认结果
|
||||||
|
|||||||
12
config.py
12
config.py
@ -19,8 +19,16 @@ LLM_CONFIG = {
|
|||||||
"gpt-oss:latest": {
|
"gpt-oss:latest": {
|
||||||
"class": "OpenAILike",
|
"class": "OpenAILike",
|
||||||
"params": {
|
"params": {
|
||||||
"id": "gpt-oss",
|
"id": "openai-mirror/gpt-oss-20b",
|
||||||
"base_url": "http://100.82.33.121:19090/v1", # Ollama OpenAI兼容端点
|
"base_url": "http://127.0.0.1:8001/v1", # Ollama OpenAI兼容端点
|
||||||
|
"api_key": "gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c" # Ollama不需要真实API密钥,任意字符串即可
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"phi4": {
|
||||||
|
"class": "OpenAILike",
|
||||||
|
"params": {
|
||||||
|
"id": "microsoft/phi-4",
|
||||||
|
"base_url": "http://127.0.0.1:8000/v1", # Ollama OpenAI兼容端点
|
||||||
"api_key": "gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c" # Ollama不需要真实API密钥,任意字符串即可
|
"api_key": "gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c" # Ollama不需要真实API密钥,任意字符串即可
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
16
main.py
16
main.py
@ -101,7 +101,7 @@ def parse_arguments() -> argparse.Namespace:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--log-dir',
|
'--log-dir',
|
||||||
type=str,
|
type=str,
|
||||||
default='results/results0905-gemma3',
|
default='results/results09010',
|
||||||
help='日志文件保存目录'
|
help='日志文件保存目录'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -115,7 +115,7 @@ def parse_arguments() -> argparse.Namespace:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--num-threads',
|
'--num-threads',
|
||||||
type=int,
|
type=int,
|
||||||
default=40,
|
default=85,
|
||||||
help='并行处理线程数'
|
help='并行处理线程数'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -149,7 +149,7 @@ def parse_arguments() -> argparse.Namespace:
|
|||||||
'--model-type',
|
'--model-type',
|
||||||
type=str,
|
type=str,
|
||||||
choices=available_models,
|
choices=available_models,
|
||||||
default='Gemma3-4b',
|
default='phi4',
|
||||||
help=f'使用的语言模型类型,可选: {", ".join(available_models)}'
|
help=f'使用的语言模型类型,可选: {", ".join(available_models)}'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -163,6 +163,13 @@ def parse_arguments() -> argparse.Namespace:
|
|||||||
default=None,
|
default=None,
|
||||||
help='模型配置JSON字符串(可选,覆盖默认配置)'
|
help='模型配置JSON字符串(可选,覆盖默认配置)'
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--controller-mode',
|
||||||
|
type=str,
|
||||||
|
choices=['normal', 'sequence'],
|
||||||
|
default='normal',
|
||||||
|
help='任务控制器模式:normal为智能模式(需要LLM推理),sequence为顺序模式(直接选择第一个任务)'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# 调试和日志
|
# 调试和日志
|
||||||
@ -347,7 +354,8 @@ def process_single_sample(sample_data: Dict[str, Any], sample_index: int,
|
|||||||
llm_config=llm_config,
|
llm_config=llm_config,
|
||||||
max_steps=args.max_steps,
|
max_steps=args.max_steps,
|
||||||
log_dir=args.log_dir,
|
log_dir=args.log_dir,
|
||||||
case_index=sample_index
|
case_index=sample_index,
|
||||||
|
controller_mode=args.controller_mode
|
||||||
)
|
)
|
||||||
|
|
||||||
# 执行工作流
|
# 执行工作流
|
||||||
|
|||||||
@ -12,7 +12,7 @@ class MedicalWorkflow:
|
|||||||
|
|
||||||
def __init__(self, case_data: Dict[str, Any], model_type: str = "gpt-oss:latest",
|
def __init__(self, case_data: Dict[str, Any], model_type: str = "gpt-oss:latest",
|
||||||
llm_config: Optional[Dict] = None, max_steps: int = 30, log_dir: str = "logs",
|
llm_config: Optional[Dict] = None, max_steps: int = 30, log_dir: str = "logs",
|
||||||
case_index: Optional[int] = None):
|
case_index: Optional[int] = None, controller_mode: str = "normal"):
|
||||||
"""
|
"""
|
||||||
初始化医疗问诊工作流
|
初始化医疗问诊工作流
|
||||||
|
|
||||||
@ -23,6 +23,7 @@ class MedicalWorkflow:
|
|||||||
max_steps: 最大执行步数,默认为30
|
max_steps: 最大执行步数,默认为30
|
||||||
log_dir: 日志目录,默认为"logs"
|
log_dir: 日志目录,默认为"logs"
|
||||||
case_index: 病例序号,用于日志文件命名
|
case_index: 病例序号,用于日志文件命名
|
||||||
|
controller_mode: 任务控制器模式,'normal'为智能模式,'sequence'为顺序模式
|
||||||
"""
|
"""
|
||||||
self.case_data = case_data
|
self.case_data = case_data
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
@ -31,7 +32,7 @@ class MedicalWorkflow:
|
|||||||
|
|
||||||
# 初始化核心组件
|
# 初始化核心组件
|
||||||
self.task_manager = TaskManager()
|
self.task_manager = TaskManager()
|
||||||
self.step_executor = StepExecutor(model_type=model_type, llm_config=self.llm_config)
|
self.step_executor = StepExecutor(model_type=model_type, llm_config=self.llm_config, controller_mode=controller_mode)
|
||||||
self.logger = WorkflowLogger(case_data=case_data, log_dir=log_dir, case_index=case_index)
|
self.logger = WorkflowLogger(case_data=case_data, log_dir=log_dir, case_index=case_index)
|
||||||
|
|
||||||
# 重置历史评分,确保新的工作流从零开始
|
# 重置历史评分,确保新的工作流从零开始
|
||||||
|
|||||||
@ -41,25 +41,29 @@ class StepExecutor:
|
|||||||
"chief_complaint_similarity": 0.0
|
"chief_complaint_similarity": 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, model_type: str = "gpt-oss:latest", llm_config: dict = None):
|
def __init__(self, model_type: str = "gpt-oss:latest", llm_config: dict = None, controller_mode: str = "normal"):
|
||||||
"""
|
"""
|
||||||
初始化step执行器
|
初始化step执行器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_type: 使用的语言模型类型(除Evaluator外的所有agent使用)
|
model_type: 使用的语言模型类型(除Evaluator外的所有agent使用)
|
||||||
llm_config: 语言模型配置
|
llm_config: 语言模型配置
|
||||||
|
controller_mode: 任务控制器模式,'normal'为智能模式,'sequence'为顺序模式
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
Evaluator agent 固定使用 gpt-oss:latest 模型,不受 model_type 参数影响
|
Evaluator agent 固定使用 gpt-oss:latest 模型,不受 model_type 参数影响
|
||||||
"""
|
"""
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.llm_config = llm_config or {}
|
self.llm_config = llm_config or {}
|
||||||
|
self.controller_mode = controller_mode
|
||||||
|
|
||||||
# 初始化所有agent
|
# 初始化所有agent
|
||||||
self.recipient = RecipientAgent(model_type=model_type, llm_config=self.llm_config)
|
self.recipient = RecipientAgent(model_type=model_type, llm_config=self.llm_config)
|
||||||
self.triager = TriageAgent(model_type=model_type, llm_config=self.llm_config)
|
self.triager = TriageAgent(model_type=model_type, llm_config=self.llm_config)
|
||||||
self.monitor = Monitor(model_type=model_type, llm_config=self.llm_config)
|
self.monitor = Monitor(model_type=model_type, llm_config=self.llm_config)
|
||||||
self.controller = TaskController(model_type=model_type, llm_config=self.llm_config)
|
# 根据模式初始化TaskController
|
||||||
|
simple_mode = (controller_mode == "sequence")
|
||||||
|
self.controller = TaskController(model_type=model_type, llm_config=self.llm_config, simple_mode=simple_mode)
|
||||||
self.prompter = Prompter(model_type=model_type, llm_config=self.llm_config)
|
self.prompter = Prompter(model_type=model_type, llm_config=self.llm_config)
|
||||||
self.virtual_patient = VirtualPatientAgent(model_type=model_type, llm_config=self.llm_config)
|
self.virtual_patient = VirtualPatientAgent(model_type=model_type, llm_config=self.llm_config)
|
||||||
# Evaluator 固定使用 gpt-oss:latest 模型
|
# Evaluator 固定使用 gpt-oss:latest 模型
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user