增强模型配置管理和评估器优化
主要改进: • 新增Qwen3-7B模型配置支持 • 完善main.py模型类型验证和配置管理 • 新增--list-models参数显示所有可用模型 • 固定Evaluator使用gpt-oss:latest模型提升评估一致性 • 优化评估器历史记录处理逻辑 • 更新默认日志目录为results0905-2 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
d783229372
commit
862af60984
@ -233,6 +233,8 @@ class Evaluator(BaseAgent):
|
||||
history_parts = []
|
||||
|
||||
for i, round_data in enumerate(all_rounds_data, 1):
|
||||
if i < len(all_rounds_data):
|
||||
continue
|
||||
history_parts.append(f"### 第{i}轮对话")
|
||||
|
||||
if 'patient_response' in round_data:
|
||||
|
||||
@ -24,6 +24,14 @@ LLM_CONFIG = {
|
||||
"api_key": "gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c" # Ollama不需要真实API密钥,任意字符串即可
|
||||
}
|
||||
},
|
||||
"Qwen3-7B": {
|
||||
"class": "OpenAILike",
|
||||
"params": {
|
||||
"id": "qwen3",
|
||||
"base_url": "http://100.82.33.121:19090/v1", # Ollama OpenAI兼容端点
|
||||
"api_key": "gpustack_d402860477878812_9ec494a501497d25b565987754f4db8c" # Ollama不需要真实API密钥,任意字符串即可
|
||||
}
|
||||
},
|
||||
"deepseek-v3": {
|
||||
"class": "OpenAILike",
|
||||
"params": {
|
||||
|
||||
47
main.py
47
main.py
@ -19,6 +19,7 @@ from typing import Dict, Any, List, Optional
|
||||
|
||||
# 导入本地模块
|
||||
from workflow import MedicalWorkflow
|
||||
from config import LLM_CONFIG
|
||||
|
||||
class BatchProcessor:
|
||||
"""批处理管理器,负责协调多线程执行和状态管理"""
|
||||
@ -100,7 +101,7 @@ def parse_arguments() -> argparse.Namespace:
|
||||
parser.add_argument(
|
||||
'--log-dir',
|
||||
type=str,
|
||||
default='results/results0904',
|
||||
default='results/results0905-2',
|
||||
help='日志文件保存目录'
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -143,17 +144,24 @@ def parse_arguments() -> argparse.Namespace:
|
||||
)
|
||||
|
||||
# 模型配置
|
||||
available_models = list(LLM_CONFIG.keys())
|
||||
parser.add_argument(
|
||||
'--model-type',
|
||||
type=str,
|
||||
choices=available_models,
|
||||
default='gpt-oss:latest',
|
||||
help='使用的语言模型类型'
|
||||
help=f'使用的语言模型类型,可选: {", ".join(available_models)}'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--list-models',
|
||||
action='store_true',
|
||||
help='显示所有可用的模型配置并退出'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--model-config',
|
||||
type=str,
|
||||
default=None,
|
||||
help='模型配置JSON字符串'
|
||||
help='模型配置JSON字符串(可选,覆盖默认配置)'
|
||||
)
|
||||
|
||||
|
||||
@ -316,11 +324,19 @@ def process_single_sample(sample_data: Dict[str, Any], sample_index: int,
|
||||
|
||||
|
||||
try:
|
||||
# 解析模型配置
|
||||
llm_config = {}
|
||||
# 使用 LLM_CONFIG 作为基础配置
|
||||
# BaseAgent 会根据 model_type 自动选择正确的模型配置
|
||||
llm_config = LLM_CONFIG.copy()
|
||||
|
||||
# 如果用户提供了额外的模型配置,则合并到对应的模型配置中
|
||||
if args.model_config:
|
||||
try:
|
||||
llm_config = json.loads(args.model_config)
|
||||
user_config = json.loads(args.model_config)
|
||||
# 更新选定模型的配置
|
||||
if args.model_type in llm_config:
|
||||
llm_config[args.model_type]["params"].update(user_config.get("params", {}))
|
||||
else:
|
||||
logging.warning(f"样本 {sample_index}: 模型类型 {args.model_type} 不存在,忽略用户配置")
|
||||
except json.JSONDecodeError:
|
||||
logging.warning(f"样本 {sample_index}: 模型配置JSON格式错误,使用默认配置")
|
||||
|
||||
@ -544,6 +560,18 @@ def main():
|
||||
# 解析参数
|
||||
args = parse_arguments()
|
||||
|
||||
# 处理 --list-models 参数
|
||||
if args.list_models:
|
||||
print("可用的语言模型配置:")
|
||||
print("=" * 50)
|
||||
for model_name, config in LLM_CONFIG.items():
|
||||
print(f"模型名称: {model_name}")
|
||||
print(f" 类别: {config['class']}")
|
||||
print(f" 模型ID: {config['params']['id']}")
|
||||
print(f" API端点: {config['params']['base_url']}")
|
||||
print("-" * 30)
|
||||
return 0
|
||||
|
||||
# 设置日志
|
||||
setup_logging(args.log_level)
|
||||
|
||||
@ -559,6 +587,13 @@ def main():
|
||||
if args.max_steps <= 0:
|
||||
raise ValueError("最大步数必须大于0")
|
||||
|
||||
# 验证模型类型
|
||||
if args.model_type not in LLM_CONFIG:
|
||||
available_models = ', '.join(LLM_CONFIG.keys())
|
||||
raise ValueError(f"不支持的模型类型: {args.model_type},可用模型: {available_models}")
|
||||
|
||||
logging.info(f"使用模型: {args.model_type} ({LLM_CONFIG[args.model_type]['class']})")
|
||||
|
||||
# 试运行模式
|
||||
if args.dry_run:
|
||||
logging.info("试运行模式:验证配置...")
|
||||
|
||||
@ -46,8 +46,11 @@ class StepExecutor:
|
||||
初始化step执行器
|
||||
|
||||
Args:
|
||||
model_type: 使用的语言模型类型
|
||||
model_type: 使用的语言模型类型(除Evaluator外的所有agent使用)
|
||||
llm_config: 语言模型配置
|
||||
|
||||
Note:
|
||||
Evaluator agent 固定使用 gpt-oss:latest 模型,不受 model_type 参数影响
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.llm_config = llm_config or {}
|
||||
@ -59,7 +62,8 @@ class StepExecutor:
|
||||
self.controller = TaskController(model_type=model_type, llm_config=self.llm_config)
|
||||
self.prompter = Prompter(model_type=model_type, llm_config=self.llm_config)
|
||||
self.virtual_patient = VirtualPatientAgent(model_type=model_type, llm_config=self.llm_config)
|
||||
self.evaluator = Evaluator(model_type=model_type, llm_config=self.llm_config)
|
||||
# Evaluator 固定使用 gpt-oss:latest 模型
|
||||
self.evaluator = Evaluator(model_type="gpt-oss:latest", llm_config=self.llm_config)
|
||||
|
||||
def execute_step(self,
|
||||
step_num: int,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user