增强模型配置管理和评估器优化
主要改进: • 新增Qwen3-7B模型配置支持 • 完善main.py模型类型验证和配置管理 • 新增--list-models参数显示所有可用模型 • 固定Evaluator使用gpt-oss:latest模型提升评估一致性 • 优化评估器历史记录处理逻辑 • 更新默认日志目录为results0905-2 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
d783229372
commit
862af60984
@ -233,6 +233,8 @@ class Evaluator(BaseAgent):
|
|||||||
history_parts = []
|
history_parts = []
|
||||||
|
|
||||||
for i, round_data in enumerate(all_rounds_data, 1):
|
for i, round_data in enumerate(all_rounds_data, 1):
|
||||||
|
if i < len(all_rounds_data):
|
||||||
|
continue
|
||||||
history_parts.append(f"### 第{i}轮对话")
|
history_parts.append(f"### 第{i}轮对话")
|
||||||
|
|
||||||
if 'patient_response' in round_data:
|
if 'patient_response' in round_data:
|
||||||
|
|||||||
@ -24,6 +24,14 @@ LLM_CONFIG = {
|
|||||||
"api_key": "gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c" # Ollama不需要真实API密钥,任意字符串即可
|
"api_key": "gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c" # Ollama不需要真实API密钥,任意字符串即可
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"Qwen3-7B": {
|
||||||
|
"class": "OpenAILike",
|
||||||
|
"params": {
|
||||||
|
"id": "qwen3",
|
||||||
|
"base_url": "http://100.82.33.121:19090/v1", # Ollama OpenAI兼容端点
|
||||||
|
"api_key": "gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c" # Ollama不需要真实API密钥,任意字符串即可
|
||||||
|
}
|
||||||
|
},
|
||||||
"deepseek-v3": {
|
"deepseek-v3": {
|
||||||
"class": "OpenAILike",
|
"class": "OpenAILike",
|
||||||
"params": {
|
"params": {
|
||||||
|
|||||||
47
main.py
47
main.py
@ -19,6 +19,7 @@ from typing import Dict, Any, List, Optional
|
|||||||
|
|
||||||
# 导入本地模块
|
# 导入本地模块
|
||||||
from workflow import MedicalWorkflow
|
from workflow import MedicalWorkflow
|
||||||
|
from config import LLM_CONFIG
|
||||||
|
|
||||||
class BatchProcessor:
|
class BatchProcessor:
|
||||||
"""批处理管理器,负责协调多线程执行和状态管理"""
|
"""批处理管理器,负责协调多线程执行和状态管理"""
|
||||||
@ -100,7 +101,7 @@ def parse_arguments() -> argparse.Namespace:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--log-dir',
|
'--log-dir',
|
||||||
type=str,
|
type=str,
|
||||||
default='results/results0904',
|
default='results/results0905-2',
|
||||||
help='日志文件保存目录'
|
help='日志文件保存目录'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -143,17 +144,24 @@ def parse_arguments() -> argparse.Namespace:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 模型配置
|
# 模型配置
|
||||||
|
available_models = list(LLM_CONFIG.keys())
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--model-type',
|
'--model-type',
|
||||||
type=str,
|
type=str,
|
||||||
|
choices=available_models,
|
||||||
default='gpt-oss:latest',
|
default='gpt-oss:latest',
|
||||||
help='使用的语言模型类型'
|
help=f'使用的语言模型类型,可选: {", ".join(available_models)}'
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--list-models',
|
||||||
|
action='store_true',
|
||||||
|
help='显示所有可用的模型配置并退出'
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--model-config',
|
'--model-config',
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help='模型配置JSON字符串'
|
help='模型配置JSON字符串(可选,覆盖默认配置)'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -316,11 +324,19 @@ def process_single_sample(sample_data: Dict[str, Any], sample_index: int,
|
|||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 解析模型配置
|
# 使用 LLM_CONFIG 作为基础配置
|
||||||
llm_config = {}
|
# BaseAgent 会根据 model_type 自动选择正确的模型配置
|
||||||
|
llm_config = LLM_CONFIG.copy()
|
||||||
|
|
||||||
|
# 如果用户提供了额外的模型配置,则合并到对应的模型配置中
|
||||||
if args.model_config:
|
if args.model_config:
|
||||||
try:
|
try:
|
||||||
llm_config = json.loads(args.model_config)
|
user_config = json.loads(args.model_config)
|
||||||
|
# 更新选定模型的配置
|
||||||
|
if args.model_type in llm_config:
|
||||||
|
llm_config[args.model_type]["params"].update(user_config.get("params", {}))
|
||||||
|
else:
|
||||||
|
logging.warning(f"样本 {sample_index}: 模型类型 {args.model_type} 不存在,忽略用户配置")
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logging.warning(f"样本 {sample_index}: 模型配置JSON格式错误,使用默认配置")
|
logging.warning(f"样本 {sample_index}: 模型配置JSON格式错误,使用默认配置")
|
||||||
|
|
||||||
@ -544,6 +560,18 @@ def main():
|
|||||||
# 解析参数
|
# 解析参数
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
|
||||||
|
# 处理 --list-models 参数
|
||||||
|
if args.list_models:
|
||||||
|
print("可用的语言模型配置:")
|
||||||
|
print("=" * 50)
|
||||||
|
for model_name, config in LLM_CONFIG.items():
|
||||||
|
print(f"模型名称: {model_name}")
|
||||||
|
print(f" 类别: {config['class']}")
|
||||||
|
print(f" 模型ID: {config['params']['id']}")
|
||||||
|
print(f" API端点: {config['params']['base_url']}")
|
||||||
|
print("-" * 30)
|
||||||
|
return 0
|
||||||
|
|
||||||
# 设置日志
|
# 设置日志
|
||||||
setup_logging(args.log_level)
|
setup_logging(args.log_level)
|
||||||
|
|
||||||
@ -559,6 +587,13 @@ def main():
|
|||||||
if args.max_steps <= 0:
|
if args.max_steps <= 0:
|
||||||
raise ValueError("最大步数必须大于0")
|
raise ValueError("最大步数必须大于0")
|
||||||
|
|
||||||
|
# 验证模型类型
|
||||||
|
if args.model_type not in LLM_CONFIG:
|
||||||
|
available_models = ', '.join(LLM_CONFIG.keys())
|
||||||
|
raise ValueError(f"不支持的模型类型: {args.model_type},可用模型: {available_models}")
|
||||||
|
|
||||||
|
logging.info(f"使用模型: {args.model_type} ({LLM_CONFIG[args.model_type]['class']})")
|
||||||
|
|
||||||
# 试运行模式
|
# 试运行模式
|
||||||
if args.dry_run:
|
if args.dry_run:
|
||||||
logging.info("试运行模式:验证配置...")
|
logging.info("试运行模式:验证配置...")
|
||||||
|
|||||||
@ -46,8 +46,11 @@ class StepExecutor:
|
|||||||
初始化step执行器
|
初始化step执行器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_type: 使用的语言模型类型
|
model_type: 使用的语言模型类型(除Evaluator外的所有agent使用)
|
||||||
llm_config: 语言模型配置
|
llm_config: 语言模型配置
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Evaluator agent 固定使用 gpt-oss:latest 模型,不受 model_type 参数影响
|
||||||
"""
|
"""
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.llm_config = llm_config or {}
|
self.llm_config = llm_config or {}
|
||||||
@ -59,7 +62,8 @@ class StepExecutor:
|
|||||||
self.controller = TaskController(model_type=model_type, llm_config=self.llm_config)
|
self.controller = TaskController(model_type=model_type, llm_config=self.llm_config)
|
||||||
self.prompter = Prompter(model_type=model_type, llm_config=self.llm_config)
|
self.prompter = Prompter(model_type=model_type, llm_config=self.llm_config)
|
||||||
self.virtual_patient = VirtualPatientAgent(model_type=model_type, llm_config=self.llm_config)
|
self.virtual_patient = VirtualPatientAgent(model_type=model_type, llm_config=self.llm_config)
|
||||||
self.evaluator = Evaluator(model_type=model_type, llm_config=self.llm_config)
|
# Evaluator 固定使用 gpt-oss:latest 模型
|
||||||
|
self.evaluator = Evaluator(model_type="gpt-oss:latest", llm_config=self.llm_config)
|
||||||
|
|
||||||
def execute_step(self,
|
def execute_step(self,
|
||||||
step_num: int,
|
step_num: int,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user