Compare commits
10 Commits: 27398dc890 ... 76c04eae4a

| SHA1 |
|---|
| 76c04eae4a |
| d1f7a27b1b |
| c4037325ed |
| 1b652502d5 |
| f7a06775ca |
| 099159dfb7 |
| 8d6d217c2f |
| 367696788b |
| 41e5fd1543 |
| 802fe4b239 |
.gitignore (vendored): 2 changes
@@ -11,3 +11,5 @@ wheels/
.claude
dataset/
docs/CLAUDE*
.DS_Store
**/*.log
.vscode/launch.json (vendored): 18 changes
@@ -31,6 +31,24 @@
        "console": "integratedTerminal",
        "python": "${workspaceFolder}/.venv/bin/python",
        "args": ["--mimic_paper_path", "dataset/mimic.csv", "--parallel", "5"]
    },
    {
        "name": "Debug pdf_parser.py",
        "type": "debugpy",
        "request": "launch",
        "program": "${workspaceFolder}/pdf_parser.py",
        "console": "integratedTerminal",
        "python": "${workspaceFolder}/.venv/bin/python",
        "args": []
    },
    {
        "name": "Debug pdf_parser.py (with arguments)",
        "type": "debugpy",
        "request": "launch",
        "program": "${workspaceFolder}/pdf_parser.py",
        "console": "integratedTerminal",
        "python": "${workspaceFolder}/.venv/bin/python",
        "args": ["--pdf-dir", "dataset/pdfs", "--parallel", "3", "--markdown-dir", "dataset/markdowns"]
    }
    ]
}
CLAUDE-temp.md: 283 changes (file deleted)
@@ -1,283 +0,0 @@
# AI Guidance Specification Task - In-Depth Analysis

## Task Understanding
The user wants a complete set of AI collaboration guidelines for the MedResearcher project. The core idea is "discuss thoroughly before modifying", so the codebase does not become hard to maintain.

## Core Design Principles
1. **Discussion first**: every modification must be fully discussed and agreed upon beforehand
2. **Explicit context**: state not just what is needed, but which function in which file
3. **Incremental implementation**: complete complex requirements step by step via subtask decomposition
4. **Traceability**: every decision and modification is explicitly recorded

## Detailed Workflow Design

### Phase 1: Requirement Understanding and Recording (mandatory)
**Trigger**: the user raises any code-modification request

**Steps**:
1. **Immediately** write the following in CLAUDE-temp.md:
```markdown
## Task Understanding
Original request: [the user's exact words]
My understanding: [restated in my own words]

## Collected Context
### Relevant Files and Functions
- `papers_crawler.py::line_45-67::fetch_papers()` - current crawling implementation
- `config.py::line_12-15::RETRY_CONFIG` - existing retry configuration

### Analysis of Existing Code
[paste and analyze the key code snippets]

### Potential Impact
- Affected files: papers_crawler.py, test_crawler.py
- Affected functionality: stability of paper crawling
- Risk assessment: may affect crawl speed

## Task Complexity Assessment
- [ ] Single-feature change affecting 1-2 functions → simple task
- [x] Changes to 3+ functions or a new module → complex task
```

2. **Wait for user feedback**:
   - "Understanding is correct" → proceed to Phase 2
   - "Understanding is off" → correct the understanding and re-record it
   - "Additional requirements" → update CLAUDE-temp.md

### Phase 2: Task Planning (differentiated by complexity)

**Simple tasks**:
- **Criteria**:
  - No more than 2 functions modified
  - No new files needed
  - Logic changes within 50 lines
- **Hard rule**: never split into subtasks
- **Direct output**: one complete execution plan

**Complex tasks**:
- **Criteria**:
  - 3 or more functions modified
  - New modules or files required
  - Interaction across multiple functional modules
- **Decomposition rules**:
  - Must be split into 3-5 subtasks (no fewer than 3, no more than 5)
  - Each subtask independently verifiable
  - Clear dependencies between subtasks
- **Decomposition example** (a retry-decorator sketch for this running example follows the Phase 3 templates below):
```
Subtask 1: build the retry-mechanism infrastructure
Subtask 2: integrate the retry mechanism into the crawler
Subtask 3: add retry-related configuration
Subtask 4: update the error-handling logic
Subtask 5: add retry logging
```

### Phase 3: Plan Confirmation and Formalization (execute after user confirmation)

**Create CLAUDE-plan.md, strictly in the following format**:

#### Simple task format:
```markdown
## Task: [specific task name]
Created: 2025-08-22 10:30

### Goal
[30-50 character description of the feature to implement; must be specific and verifiable]

### Required Context
- `papers_crawler.py::lines 45-67::fetch_papers()` - needs exception handling added
- `papers_crawler.py::lines 120-135::parse_response()` - need to understand the return format
- `config.py::entire file` - understand the existing configuration structure

### Proposed Changes
1. Modify `papers_crawler.py::fetch_papers()` lines 50-55 - add a try-except block
2. Modify `papers_crawler.py::fetch_papers()` line 65 - add retry logic
3. Modify the end of `config.py` - add a RETRY_TIMES constant

### Test Commands
```bash
# Main functionality test
uv run papers_crawler.py --keyword "machine learning" --limit 5

# Failure-case test (simulated network error)
uv run papers_crawler.py --test-mode --simulate-error
```

### Acceptance Criteria
- [ ] Normal crawling is unaffected
- [ ] Retries correctly on network failures
- [ ] Retry counts are logged correctly
```

#### Complex task format:
```markdown
## Task: [specific task name]
Created: 2025-08-22 10:30

### Overall Goal
[50-100 character description of the overall feature to implement]

### Subtask Breakdown

#### Subtask 1: [name]
**Goal**: [20-30 character description]

**Required context**:
- `papers_crawler.py::lines 45-67::fetch_papers()` - understand the current implementation
- `utils/__init__.py::entire file` - confirm the utility module structure

**Proposed changes**:
- New file `utils/retry.py` - create a RetryDecorator class
- Modify `utils/__init__.py` - export the retry decorator

**Test commands**:
```bash
# Unit test
python -c "from utils.retry import retry; print('import OK')"
```

#### Subtask 2: [name]
[same format as above]

### Overall Acceptance Criteria
- [ ] All subtasks pass their independent tests
- [ ] Integration test: the full pipeline passes
- [ ] Performance test: retries do not slow down normal crawling
```

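The plan templates above use a retry mechanism as their running example, including a planned `utils/retry.py` with a RetryDecorator class. For concreteness only, a minimal sketch of what that decorator might look like; the module is only planned, so every name and default here is an assumption, not existing project code:

```python
# utils/retry.py - illustrative sketch of the decorator from Subtask 1.
# Hypothetical: this module is only proposed in the plan above.
import functools
import logging
import time

logger = logging.getLogger(__name__)


def retry(times: int = 3, delay: float = 1.0, backoff: float = 2.0,
          exceptions: tuple = (ConnectionError, TimeoutError)):
    """Retry the wrapped function on transient errors with exponential backoff."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            wait = delay
            for attempt in range(1, times + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    if attempt == times:
                        raise  # out of attempts; propagate the last error
                    logger.warning("attempt %d/%d failed: %s; retrying in %.1fs",
                                   attempt, times, e, wait)
                    time.sleep(wait)
                    wait *= backoff
        return wrapper
    return decorator
```
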
### Phase 4: Implementation and Verification (follow the plan strictly)

**Execution requirements**:
1. **Confirm before starting**:
   - Re-read CLAUDE-plan.md
   - Confirm all dependency files exist
   - Confirm the test environment is ready

2. **Record during execution**:
   - Update progress in CLAUDE-activeContext.md in real time
   - Stop immediately and discuss when anything off-plan occurs

3. **Verify after completion**:
   - Run all test commands
   - Check the acceptance criteria
   - Record any deviations or problems

## Memory Bank System Design

### File Structure
```
/docs/
├── CLAUDE-temp.md              # Scratch discussion and analysis
├── CLAUDE-plan.md              # Formal plan for the current task
├── CLAUDE-activeContext.md     # Session state and progress tracking
├── CLAUDE-patterns.md          # Project code patterns
├── CLAUDE-decisions.md         # Key decisions and their rationale
├── CLAUDE-troubleshooting.md   # Problem and solution library
└── CLAUDE-config-variables.md  # Configuration variable reference
```

### Usage Principles
1. **docs/CLAUDE-temp.md**:
   - Cleared or archived at the start of each new task
   - For quick notes and thinking
   - No structure required

2. **docs/CLAUDE-plan.md**:
   - A structured task plan
   - Written only after user confirmation
   - Serves as the guiding document for implementation

3. **docs/CLAUDE-activeContext.md**:
   - Records current progress
   - Marks items as done / in progress / pending
   - Reference point when resuming a session

### Memory Bank Update Mechanism

**Managed by a dedicated SubAgent**:
```
Task: memory-bank-updater
Description: "Update the Memory Bank files"
Prompt: "The task is complete; please update the following Memory Bank files:
1. CLAUDE-activeContext.md - mark the task complete and record the final state
2. CLAUDE-patterns.md - record any new code patterns
3. CLAUDE-decisions.md - record the key decisions of this task
4. CLAUDE-troubleshooting.md - record solutions to any problems encountered
5. CLAUDE-config-variables.md - document any new configuration

Work completed: [task summary]
Problems encountered: [if any]
Solutions adopted: [if any]"
```

**When to invoke**:
- Must be invoked after every completed task
- Invoke when a significant decision is made
- Invoke when a new best practice is discovered

## Tool Usage Optimization Principles

### 1. Batch Operations
**Scenario**: reading multiple files or running multiple independent searches
**Approach**:
```python
# Issue several tool calls at the same time
parallel_calls = [
    Read("papers_crawler.py"),
    Read("pdf_parser.py"),
    Grep("retry", "*.py"),
    LS("./utils/")
]
```
**Forbidden**: executing parallelizable operations sequentially

### 2. Context Management Strategy
**Keep in the main context**:
- The user conversation
- Key decision points
- The current task plan

**Delegate to a subagent**:
- Large-scale code search: "find every place the requests library is used"
- Code pattern analysis: "analyze the project's error-handling patterns"
- Dependency mapping: "find all dependencies of papers_crawler.py"

**Subagent usage example**:
```
Task: code-searcher
Prompt: "Search the whole project for exception-handling code,
focusing on papers_crawler.py and pdf_parser.py,
and summarize the current error-handling patterns with improvement suggestions"
```

### 3. File Operation Best Practices
**Reading order**:
1. Read CLAUDE-activeContext.md first (if present) to understand the current state
2. Read the main file to understand the overall structure
3. Read the related dependency files

**Modification principles**:
- Prefer Edit over Write
- Use MultiEdit for multiple changes to the same file
- Creating a new file requires an explicit justification

## Coordination with the Existing Coding Standards

### Hierarchy
1. **Coding standards** (already defined in CLAUDE.md):
   - Define "how to write code"
   - Cover naming, comments, code style, etc.

2. **AI guidance specification** (this document):
   - Defines "how to understand and modify code"
   - Covers workflow, communication, tool usage, etc.

### Precedence
1. Obey the hard requirements of the coding standards (e.g., the per-change size limit)
2. Carry out tasks following the AI guidance workflow
3. When they conflict, the coding standards win

## Specification Update Mechanism
- Record each newly encountered best practice in CLAUDE-patterns.md
- Periodically review CLAUDE-troubleshooting.md and distill general rules
- The user may propose specification improvements at any time
CLAUDE.md: 2 changes
@@ -14,6 +14,7 @@ MedResearcher is an automated experimentation platform based on user input; it will, based on the user
## Main File of Each Module
1. Paper crawling entry point: papers_crawler.py
2. PDF parsing entry point: pdf_parser.py
3. Information extraction entry point: info_extractor.py
4. Experiment running entry point: experiment_runner.py

## File Structure
@@ -23,6 +24,7 @@ MedResearcher is an automated experimentation platform based on user input; it will, based on the user
│   └── mimic.csv            # Basic information for all MIMIC-related papers to process
├── papers_crawler.py        # Paper crawling entry point
├── pdf_parser.py            # PDF parsing entry point
├── info_extractor.py        # Information extraction entry point
├── experiment_runner.py     # Experiment running entry point
├── src/                     # Source code directory
│   └── utils/               # Utility function directory
example/subagent-example/README.md: 428 lines (new file)
@@ -0,0 +1,428 @@
# SubAgent System User Guide

## 📚 Overview

SubAgent is an intelligent agent system built on the Agno framework that provides the core AI functionality for the MedResearcher project. It supports multiple LLM providers and offers dynamic prompt construction, structured JSON output, and fault-tolerant parsing.

## ✨ Core Features

### 🤖 Agent Core
- **Multi-provider support**: mainstream LLM services such as Aliyun (qwen), DeepSeek, and OpenAI
- **Dynamic prompts**: flexible prompt construction with template variable substitution
- **Structured output**: JSON-formatted responses based on Pydantic models
- **Fault-tolerant parsing**: multi-strategy JSON parsing, so even imperfect output can still be parsed (see the sketch below)

### 🔧 Configuration Management
- **YAML configuration**: unified configuration files with environment variable support
- **Model factory**: automated model instantiation and parameter management
- **Flexible configuration**: runtime parameter overrides and dynamic configuration

### 🛠 Developer Convenience
- **Type safety**: full type-hint support
- **Exception handling**: detailed error messages and an exception hierarchy
- **Debugging support**: built-in logging and a debug mode

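To make the "fault-tolerant parsing" feature above concrete, here is a minimal sketch of the multi-strategy idea: try strict parsing first, then progressively more forgiving recovery. This is illustrative only; `parse_json_lenient` is a hypothetical helper, not the package's actual API.

```python
# Illustrative sketch of multi-strategy JSON parsing
# (hypothetical helper, not the package's real API).
import json
import re


def parse_json_lenient(text: str) -> dict:
    """Try progressively more forgiving strategies to recover a JSON object."""
    # Strategy 1: the text is already valid JSON.
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # Strategy 2: strip a markdown code fence around the JSON.
    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
    if fenced:
        try:
            return json.loads(fenced.group(1))
        except json.JSONDecodeError:
            pass
    # Strategy 3: take the outermost {...} span and drop trailing commas.
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        candidate = re.sub(r",\s*([}\]])", r"\1", text[start:end + 1])
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            pass
    raise ValueError("all JSON parsing strategies failed")
```
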
## 🚀 Quick Start

### 1. Basic Setup

First make sure the dependencies are installed:
```bash
uv add agno pydantic pyyaml
```

### 2. Configure an LLM Service

Configure your LLM service in `src/config/llm_config.yaml`:
```yaml
aliyun:
  base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
  api_key: "${DASHSCOPE_API_KEY}"
  models:
    qwen-max:
      class: "OpenAILike"
      params:
        id: "qwen-max"
        temperature: 0.3
```

### 3. Set Environment Variables

Create a `.env` file or set the variables directly:
```bash
export DASHSCOPE_API_KEY="your_api_key_here"
```

### 4. Create Your First Agent

```python
from src.agent_system import SubAgent
from pydantic import BaseModel, Field

# Define the response model
class TaskResult(BaseModel):
    summary: str = Field(description="Task summary")
    confidence: float = Field(description="Confidence", ge=0.0, le=1.0)

# Create the SubAgent
agent = SubAgent(
    provider="aliyun",
    model_name="qwen-max",
    name="task_agent",
    instructions=["You are a professional task-processing expert"],
    prompt_template="Please analyze the following task: {task_description}",
    response_model=TaskResult
)

# Run the task
result = agent.run(template_vars={"task_description": "data analysis project"})
print(f"Summary: {result.summary}")
print(f"Confidence: {result.confidence}")
```

## 📖 Detailed Usage Guide

### The SubAgent Core Class

#### Constructor Parameters

```python
SubAgent(
    provider: str,              # LLM provider name
    model_name: str,            # Model name
    instructions: List[str],    # Instruction list (optional)
    name: str,                  # Agent name (optional)
    description: str,           # Agent description (optional)
    prompt_template: str,       # Dynamic prompt template (optional)
    response_model: BaseModel,  # Pydantic response model (optional)
    config: Dict[str, Any],     # Custom configuration (optional)
    **agent_kwargs              # Extra arguments passed to the Agno Agent
)
```

#### Core Methods

##### 1. build_prompt() - build a dynamic prompt
```python
# Set a prompt template with variables
agent.update_prompt_template("""
Please analyze the following {data_type} data:

Data content: {data_content}
Analysis goal: {analysis_goal}

Please provide a detailed analysis.
""")

# Build a concrete prompt
prompt = agent.build_prompt({
    "data_type": "sales",
    "data_content": "Q1 sales data...",
    "analysis_goal": "identify growth trends"
})
```

##### 2. run() - execute inference
```python
# Option 1: use template variables
result = agent.run(template_vars={
    "input_text": "text to analyze"
})

# Option 2: provide the prompt directly
result = agent.run(prompt="Please analyze the sentiment of this text")

# Option 3: with extra parameters
result = agent.run(
    template_vars={"data": "test data"},
    temperature=0.7,
    max_tokens=1000
)
```

##### 3. get_model_info() - fetch model information
```python
info = agent.get_model_info()
print(f"Agent name: {info['name']}")
print(f"Provider: {info['provider']}")
print(f"Model: {info['model_name']}")
print(f"Has prompt template: {info['has_prompt_template']}")
```

### Pydantic Response Models

#### Basic model definition
```python
from pydantic import BaseModel, Field
from typing import List, Optional

class AnalysisResult(BaseModel):
    """Analysis result model"""

    summary: str = Field(description="Analysis summary")
    key_points: List[str] = Field(description="List of key points")
    confidence: float = Field(description="Confidence", ge=0.0, le=1.0)
    recommendations: Optional[List[str]] = Field(default=None, description="List of recommendations")

    class Config:
        json_encoders = {
            float: lambda v: round(v, 3) if v is not None else None
        }
```

#### Complex nested models
```python
class DetailedItem(BaseModel):
    name: str = Field(description="Item name")
    value: float = Field(description="Numeric value")
    category: str = Field(description="Category")

class ComprehensiveResult(BaseModel):
    items: List[DetailedItem] = Field(description="List of detailed items")
    total_count: int = Field(description="Total count", ge=0)
    summary: str = Field(description="Overall summary")
```

### Configuration Management in Detail

#### LLM configuration file structure (llm_config.yaml)
```yaml
# Aliyun configuration
aliyun:
  base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
  api_key: "${DASHSCOPE_API_KEY}"
  models:
    qwen-max:
      class: "OpenAILike"
      params:
        id: "qwen-max"
        temperature: 0.3
        max_tokens: 4000
    qwen-plus:
      class: "OpenAILike"
      params:
        id: "qwen-plus"
        temperature: 0.5

# DeepSeek configuration
deepseek:
  base_url: "https://api.deepseek.com/v1"
  api_key: "${DEEPSEEK_API_KEY}"
  models:
    deepseek-v3:
      class: "OpenAILike"
      params:
        id: "deepseek-chat"
        temperature: 0.3

# OpenAI configuration
openai:
  api_key: "${OPENAI_API_KEY}"
  models:
    gpt-4o:
      class: "OpenAIChat"
      params:
        model: "gpt-4o"
        temperature: 0.3
```

#### Environment variable configuration (.env)
```bash
# Aliyun API key
DASHSCOPE_API_KEY=sk-your-dashscope-key

# DeepSeek API key
DEEPSEEK_API_KEY=sk-your-deepseek-key

# OpenAI API key
OPENAI_API_KEY=sk-your-openai-key
```

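The `${DASHSCOPE_API_KEY}` syntax above means the loader substitutes environment variables into the YAML at load time. A minimal sketch of how such substitution can be implemented; `load_llm_config` is a hypothetical helper, and the package's actual loader may differ:

```python
# Illustrative sketch: load llm_config.yaml and expand ${VAR} references
# from the environment. Not the package's actual loader.
import os
import re
import yaml  # pyyaml, already listed as a dependency above


def load_llm_config(path: str = "src/config/llm_config.yaml") -> dict:
    with open(path, "r", encoding="utf-8") as f:
        raw = f.read()

    def expand(match: re.Match) -> str:
        var = match.group(1)
        value = os.environ.get(var)
        if value is None:
            raise KeyError(f"environment variable {var} is not defined")
        return value

    # Replace every ${VAR} occurrence before parsing the YAML.
    expanded = re.sub(r"\$\{(\w+)\}", expand, raw)
    return yaml.safe_load(expanded)
```
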
### Convenience Functions

#### create_json_agent() - quickly create a JSON agent
```python
from src.agent_system import create_json_agent

# Quickly create an agent with JSON output support
agent = create_json_agent(
    provider="aliyun",
    model_name="qwen-max",
    name="json_extractor",
    prompt_template="Extract information from the following text: {text}",
    response_model="MyModel",  # may be a string or a class
    instructions=["You are an information extraction expert"]
)
```

## 🎯 Practical Examples

### Example 1: sentiment analysis agent

```python
from pydantic import BaseModel, Field
from typing import Literal
from src.agent_system import SubAgent

class SentimentResult(BaseModel):
    sentiment: Literal["positive", "negative", "neutral"] = Field(description="Sentiment")
    confidence: float = Field(description="Confidence", ge=0.0, le=1.0)
    explanation: str = Field(description="Analysis explanation")

sentiment_agent = SubAgent(
    provider="aliyun",
    model_name="qwen-max",
    name="sentiment_analyzer",
    instructions=[
        "You are a professional text sentiment analysis expert",
        "Identify the sentiment of the text accurately",
        "Provide detailed reasoning for the analysis"
    ],
    prompt_template="""
Please analyze the sentiment of the following text:

Text: {text}

Identify the sentiment (positive/negative/neutral), a confidence score (0-1), and an explanation.
""",
    response_model=SentimentResult
)

# Usage example
result = sentiment_agent.run(template_vars={
    "text": "The product quality is excellent; I am very satisfied!"
})

print(f"Sentiment: {result.sentiment}")
print(f"Confidence: {result.confidence}")
print(f"Explanation: {result.explanation}")
```

### Example 2: data extraction agent

```python
class DataExtraction(BaseModel):
    extracted_data: Dict[str, Any] = Field(description="Extracted data")
    extraction_count: int = Field(description="Number of extracted items")
    data_quality: Literal["high", "medium", "low"] = Field(description="Data quality assessment")

extractor_agent = SubAgent(
    provider="aliyun",
    model_name="qwen-plus",
    name="data_extractor",
    instructions=[
        "You are a data extraction expert",
        "Extract structured data from unstructured text",
        "Ensure the extracted data is accurate and complete"
    ],
    prompt_template="""
Extract the key data from the following {data_type} document:

Document content:
{document}

Extraction requirements:
{requirements}

Extract all relevant data and assess its quality.
""",
    response_model=DataExtraction
)
```

## ⚠️ Notes and Best Practices

### 1. Configuration Management
- **API key safety**: always store API keys in environment variables; never hard-code them
- **Configuration validation**: validate configuration-file integrity at startup
- **Environment isolation**: use separate configuration files for development, testing, and production

### 2. Prompt Design
- **Clear instructions**: give clear, specific task instructions
- **Example-driven**: include input/output examples in the prompt
- **Structured templates**: use structured prompt templates for consistency

### 3. Error Handling
- **Exception capture**: wrap agent calls in appropriate exception handling
- **Retry mechanism**: implement retry logic for network errors (a sketch follows this list)
- **Fallback strategy**: keep a backup model or a simplified output format ready

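A minimal sketch combining the retry and fallback advice above when calling an agent. Illustrative only: `run_with_retry` is a hypothetical helper that assumes the `SubAgent.run` interface shown earlier, and it treats every exception as retryable for brevity.

```python
# Illustrative: retry an agent call, then fall back to a cheaper model.
# Assumes the SubAgent.run(...) interface shown earlier in this guide.
import time


def run_with_retry(agent, fallback_agent=None, attempts=3, delay=1.0, **run_kwargs):
    last_error = None
    for attempt in range(attempts):
        try:
            return agent.run(**run_kwargs)
        except Exception as e:  # narrow to network/parse errors in real code
            last_error = e
            time.sleep(delay * (2 ** attempt))  # exponential backoff
    if fallback_agent is not None:
        return fallback_agent.run(**run_kwargs)  # degrade to the backup model
    raise last_error
```
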
### 4. Performance Optimization
- **Caching**: cache results for identical inputs (see the sketch below)
- **Batching**: merge many small tasks into one larger task
- **Model selection**: choose the model that fits the task's complexity

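A minimal sketch of the caching idea for repeated prompts. Illustrative only: `run_cached` is a hypothetical helper, and a production cache would also key on the model and its parameters and handle eviction.

```python
# Illustrative: memoize agent results for identical prompts.
import hashlib

_cache: dict = {}


def run_cached(agent, prompt: str):
    """Return a cached result for a previously seen prompt."""
    key = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
    if key not in _cache:
        _cache[key] = agent.run(prompt=prompt)
    return _cache[key]
```
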
## 🔧 Troubleshooting

### Common Problems

#### 1. Missing configuration file
```
Error: FileNotFoundError: LLM configuration file not found
Fix: make sure src/config/llm_config.yaml exists and is well-formed
```

#### 2. API key not set
```
Error: environment variable DASHSCOPE_API_KEY is not defined
Fix: set the environment variable or configure it in the .env file
```

#### 3. JSON parsing failure
```
Error: JSONParseError: all parsing strategies failed
Fix: review the prompt design and require an explicit JSON output format
```

#### 4. Model validation failure
```
Error: Pydantic model validation failed
Fix: check that the response model definition matches the actual output
```

### Debugging Tips

#### Enable debug mode
```python
agent = SubAgent(
    provider="aliyun",
    model_name="qwen-max",
    debug_mode=True,  # enable debug output
    # ... other parameters
)
```

#### Inspect the generated prompt
```python
# Build and inspect the final prompt
prompt = agent.build_prompt({"key": "value"})
print(f"Generated prompt: {prompt}")
```

#### Capture detailed error information
```python
try:
    result = agent.run(template_vars={"text": "test"})
except Exception as e:
    print(f"Error type: {type(e)}")
    print(f"Error message: {e}")
    import traceback
    traceback.print_exc()
```

## 🚦 Version Information

- **Current version**: 0.1.0
- **Requirements**:
  - Python >= 3.8
  - agno >= 0.1.0
  - pydantic >= 2.0.0
  - pyyaml >= 6.0.0

## 📞 Support and Feedback

If you run into problems or have feature suggestions, contact the development team or open an issue.

---

*The MedResearcher SubAgent system - smarter AI, simpler development* 🎉
example/subagent-example/basic_example.py: 403 lines (new file)
@@ -0,0 +1,403 @@
#!/usr/bin/env python3
"""
Basic SubAgent usage examples

Demonstrates the core features of the SubAgent system:
1. Creating a simple chat agent
2. Using dynamic prompt templates
3. Structured JSON output
4. Error handling and debugging
"""

import sys
import os
from typing import Optional

# Add the project root to the Python path
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
sys.path.append(project_root)

from src.agent_system import SubAgent
from example_models import BasicResponse, SentimentAnalysis


def create_simple_chat_agent() -> SubAgent:
    """Create a simple chat agent"""
    print("🤖 Creating a simple chat agent...")

    try:
        agent = SubAgent(
            provider="aliyun",
            model_name="qwen-turbo",
            name="simple_chat",
            description="A simple chat assistant",
            instructions=[
                "You are a friendly AI assistant",
                "Answer questions in clear, concise language",
                "Keep a positive attitude"
            ]
        )

        print("✅ Simple chat agent created")
        return agent

    except Exception as e:
        print(f"❌ Agent creation failed: {e}")
        raise


def create_structured_output_agent() -> SubAgent:
    """Create an agent with structured output"""
    print("\n🔧 Creating a structured-output agent...")

    try:
        agent = SubAgent(
            provider="aliyun",
            model_name="qwen-max",
            name="structured_responder",
            description="An assistant that returns structured responses",
            instructions=[
                "You are a professional response assistant",
                "Always return structured JSON output",
                "Make the responses accurate and useful"
            ],
            response_model=BasicResponse
        )

        print("✅ Structured-output agent created")
        return agent

    except Exception as e:
        print(f"❌ Structured-output agent creation failed: {e}")
        raise


def create_sentiment_agent() -> SubAgent:
    """Create a sentiment analysis agent"""
    print("\n💭 Creating a sentiment analysis agent...")

    instructions = [
        "You are a professional text sentiment analysis expert",
        "Accurately identify the sentiment: positive, negative, or neutral",
        "Provide a confidence score in the 0-1 range",
        "Give a detailed explanation and the key words"
    ]

    prompt_template = """
Please perform sentiment analysis on the following text:

Text: "{text}"

Requirements:
1. Identify the sentiment (positive/negative/neutral)
2. Estimate the confidence (a float between 0 and 1)
3. Provide an explanation with supporting evidence
4. Extract the keywords that drive the sentiment judgment

Return the result strictly in the specified JSON format.
"""

    try:
        agent = SubAgent(
            provider="aliyun",
            model_name="qwen-max",
            name="sentiment_analyzer",
            description="A professional text sentiment analysis system",
            instructions=instructions,
            prompt_template=prompt_template,
            response_model=SentimentAnalysis
        )

        print("✅ Sentiment analysis agent created")
        return agent

    except Exception as e:
        print(f"❌ Sentiment analysis agent creation failed: {e}")
        raise


def demo_simple_chat():
    """Demonstrate simple conversation"""
    print("\n" + "="*50)
    print("🗣️ Simple Conversation Demo")
    print("="*50)

    try:
        agent = create_simple_chat_agent()

        # Test questions
        test_questions = [
            "Hello! Please introduce yourself",
            "What is artificial intelligence?",
            "Please give me some advice on learning Python"
        ]

        for i, question in enumerate(test_questions, 1):
            print(f"\nQuestion {i}: {question}")
            try:
                response = agent.run(prompt=question)
                print(f"Answer: {response}")

            except Exception as e:
                print(f"❌ Answer failed: {e}")

        return True

    except Exception as e:
        print(f"❌ Simple conversation demo failed: {e}")
        return False


def demo_structured_response():
    """Demonstrate structured responses"""
    print("\n" + "="*50)
    print("📋 Structured Response Demo")
    print("="*50)

    try:
        agent = create_structured_output_agent()

        # Test requests
        test_requests = [
            "Please explain what machine learning is",
            "Describe the main characteristics of the Python language",
            "How can I improve my productivity?"
        ]

        for i, request in enumerate(test_requests, 1):
            print(f"\nRequest {i}: {request}")
            try:
                result = agent.run(prompt=request)
                print(f"✅ Response received:")
                print(f"  Message: {result.message}")
                print(f"  Success: {result.success}")
                print(f"  Time: {result.timestamp}")

            except Exception as e:
                print(f"❌ Response failed: {e}")

        return True

    except Exception as e:
        print(f"❌ Structured response demo failed: {e}")
        return False


def demo_dynamic_prompt():
    """Demonstrate dynamic prompt templates"""
    print("\n" + "="*50)
    print("🎭 Dynamic Prompt Template Demo")
    print("="*50)

    try:
        agent = create_sentiment_agent()

        # Test texts
        test_texts = [
            "The product quality is excellent and I am very satisfied! Highly recommended.",
            "Terrible service and poor product quality; very disappointed.",
            "The features are okay and the price is reasonable; a middle-of-the-road choice.",
            "The weather is nice today, good for a walk.",
            "What a fantastic movie! Great acting and a gripping plot."
        ]

        for i, text in enumerate(test_texts, 1):
            print(f"\nText {i}: {text}")
            try:
                # Build the prompt from the dynamic template
                result = agent.run(template_vars={"text": text})

                print(f"✅ Analysis result:")
                print(f"  Sentiment: {result.sentiment}")
                print(f"  Confidence: {result.confidence}")
                print(f"  Explanation: {result.explanation}")
                print(f"  Keywords: {result.keywords}")

            except Exception as e:
                print(f"❌ Analysis failed: {e}")

        return True

    except Exception as e:
        print(f"❌ Dynamic prompt demo failed: {e}")
        return False


def demo_error_handling():
    """Demonstrate error handling"""
    print("\n" + "="*50)
    print("⚠️ Error Handling Demo")
    print("="*50)

    # Exercise various failure cases
    error_tests = [
        {
            "name": "invalid provider",
            "params": {"provider": "invalid_provider", "model_name": "test"},
            "expected_error": "ValueError"
        },
        {
            "name": "missing prompt template variable",
            "params": {"provider": "aliyun", "model_name": "qwen-turbo"},
            "template_vars": {},
            "prompt_template": "Analyze this text: {missing_var}",
            "expected_error": "SubAgentError"
        }
    ]

    for test in error_tests:
        print(f"\nTest: {test['name']}")
        try:
            if 'prompt_template' in test:
                agent = SubAgent(**test['params'])
                agent.update_prompt_template(test['prompt_template'])
                agent.run(template_vars=test.get('template_vars', {}))
            else:
                agent = SubAgent(**test['params'])

            print(f"❌ Expected error did not occur")

        except Exception as e:
            print(f"✅ Caught expected error: {type(e).__name__}: {e}")

    return True


def interactive_demo():
    """Interactive demo"""
    print("\n" + "="*50)
    print("💬 Interactive Demo")
    print("="*50)
    print("Enter text for sentiment analysis; type 'quit' to exit")

    try:
        agent = create_sentiment_agent()

        while True:
            user_input = input("\nEnter the text to analyze: ").strip()

            if user_input.lower() == 'quit':
                print("Goodbye!")
                break

            if not user_input:
                continue

            try:
                print(f"Analyzing: {user_input}")
                result = agent.run(template_vars={"text": user_input})

                print(f"\n📊 Analysis result:")
                print(f"Sentiment: {result.sentiment}")
                print(f"Confidence: {result.confidence:.3f}")
                print(f"Explanation: {result.explanation}")
                if result.keywords:
                    print(f"Keywords: {', '.join(result.keywords)}")

            except Exception as e:
                print(f"❌ Analysis failed: {e}")

    except KeyboardInterrupt:
        print("\nInterrupted")
    except Exception as e:
        print(f"❌ Interactive demo failed: {e}")


def show_agent_info(agent: SubAgent):
    """Display agent information"""
    info = agent.get_model_info()
    print(f"\n📋 Agent info:")
    for key, value in info.items():
        print(f"  {key}: {value}")


def main():
    """Main entry point - run all demos"""
    print("🚀 Basic SubAgent usage examples")
    print("=" * 60)

    # Run all demos
    demos = [
        ("Simple conversation", demo_simple_chat),
        ("Structured response", demo_structured_response),
        ("Dynamic prompt", demo_dynamic_prompt),
        ("Error handling", demo_error_handling),
    ]

    results = {}

    for name, demo_func in demos:
        print(f"\nStarting demo: {name}")
        try:
            success = demo_func()
            results[name] = success
            print(f"{'✅' if success else '❌'} {name} demo {'succeeded' if success else 'failed'}")
        except Exception as e:
            print(f"❌ {name} demo raised an exception: {e}")
            results[name] = False

    # Show the summary
    print("\n" + "="*60)
    print("📊 Demo Summary")
    print("="*60)

    total_demos = len(results)
    successful_demos = sum(results.values())

    for name, success in results.items():
        status = "✅ success" if success else "❌ failure"
        print(f"  {name}: {status}")

    print(f"\n🎯 Total: {successful_demos}/{total_demos} demos succeeded")

    # Ask whether to run the interactive demo
    print(f"\nRun the interactive demo? (y/n): ", end="")
    try:
        choice = input().strip().lower()
        if choice in ['y', 'yes']:
            interactive_demo()
    except (KeyboardInterrupt, EOFError):
        print("\nProgram finished")

    return successful_demos == total_demos


def test_basic_functionality():
    """Test basic functionality"""
    print("Testing basic SubAgent functionality...")

    try:
        # Create a basic agent
        agent = SubAgent(
            provider="aliyun",
            model_name="qwen-turbo",
            name="test_agent"
        )

        print(f"✅ Agent created: {agent}")

        # Display the agent info
        show_agent_info(agent)

        # Test a simple conversation
        response = agent.run(prompt="Please briefly introduce yourself")
        print(f"✅ Conversation test passed; response length: {len(str(response))} characters")

        return True

    except Exception as e:
        print(f"❌ Basic functionality test failed: {e}")
        return False


if __name__ == "__main__":
    # Run either the quick test or the full demo
    import sys

    if len(sys.argv) > 1 and sys.argv[1] == "--test":
        # Run only the basic test
        success = test_basic_functionality()
        exit(0 if success else 1)
    else:
        # Run the full demo
        main()
example/subagent-example/example_models.py: 377 lines (new file)
@@ -0,0 +1,377 @@
"""
|
||||
SubAgent示例Pydantic模型定义
|
||||
|
||||
提供各种场景下的结构化输出模型,展示SubAgent系统的灵活性
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional, Literal, Union
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class BasicResponse(BaseModel):
|
||||
"""基础响应模型"""
|
||||
|
||||
message: str = Field(description="响应消息")
|
||||
success: bool = Field(description="处理是否成功", default=True)
|
||||
timestamp: datetime = Field(default_factory=datetime.now, description="响应时间")
|
||||
|
||||
|
||||
class SentimentAnalysis(BaseModel):
|
||||
"""情感分析结果模型"""
|
||||
|
||||
sentiment: Literal["positive", "negative", "neutral"] = Field(description="情感倾向")
|
||||
confidence: float = Field(description="置信度", ge=0.0, le=1.0)
|
||||
explanation: str = Field(description="分析说明")
|
||||
keywords: List[str] = Field(description="关键词列表", default_factory=list)
|
||||
|
||||
@validator('confidence')
|
||||
def validate_confidence(cls, v):
|
||||
"""验证置信度范围"""
|
||||
if not 0.0 <= v <= 1.0:
|
||||
raise ValueError('置信度必须在0.0到1.0之间')
|
||||
return round(v, 3)
|
||||
|
||||
|
||||
class KeywordExtraction(BaseModel):
|
||||
"""关键词提取项"""
|
||||
|
||||
keyword: str = Field(description="关键词")
|
||||
frequency: int = Field(description="出现频次", ge=1)
|
||||
importance: float = Field(description="重要性评分", ge=0.0, le=1.0)
|
||||
category: str = Field(description="关键词分类", default="general")
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
float: lambda v: round(v, 3) if v is not None else None
|
||||
}
|
||||
|
||||
|
||||
class TextAnalysisResult(BaseModel):
|
||||
"""文本分析完整结果"""
|
||||
|
||||
# 基本信息
|
||||
text_length: int = Field(description="文本长度(字符数)", ge=0)
|
||||
word_count: int = Field(description="词汇数量", ge=0)
|
||||
language: str = Field(description="检测到的语言", default="zh")
|
||||
|
||||
# 分析结果
|
||||
summary: str = Field(description="文本摘要")
|
||||
sentiment: SentimentAnalysis = Field(description="情感分析结果")
|
||||
keywords: List[KeywordExtraction] = Field(description="关键词提取结果")
|
||||
|
||||
# 质量评估
|
||||
readability: Literal["high", "medium", "low"] = Field(description="可读性评估")
|
||||
complexity: float = Field(description="复杂度评分", ge=0.0, le=1.0)
|
||||
|
||||
@validator('text_length')
|
||||
def validate_text_length(cls, v):
|
||||
"""验证文本长度"""
|
||||
if v < 0:
|
||||
raise ValueError('文本长度不能为负数')
|
||||
return v
|
||||
|
||||
@validator('keywords')
|
||||
def validate_keywords(cls, v):
|
||||
"""验证关键词列表"""
|
||||
if len(v) > 20:
|
||||
# 只保留前20个最重要的关键词
|
||||
v = sorted(v, key=lambda x: x.importance, reverse=True)[:20]
|
||||
return v
|
||||
|
||||
|
||||
class CategoryClassification(BaseModel):
|
||||
"""分类结果项"""
|
||||
|
||||
category: str = Field(description="分类名称")
|
||||
confidence: float = Field(description="分类置信度", ge=0.0, le=1.0)
|
||||
probability: float = Field(description="分类概率", ge=0.0, le=1.0)
|
||||
|
||||
@validator('confidence', 'probability')
|
||||
def round_float_values(cls, v):
|
||||
"""保留3位小数"""
|
||||
return round(v, 3)
|
||||
|
||||
|
||||
class DocumentClassificationResult(BaseModel):
|
||||
"""文档分类结果"""
|
||||
|
||||
primary_category: str = Field(description="主要分类")
|
||||
confidence: float = Field(description="主分类置信度", ge=0.0, le=1.0)
|
||||
|
||||
all_categories: List[CategoryClassification] = Field(
|
||||
description="所有分类结果(按置信度排序)"
|
||||
)
|
||||
|
||||
features_used: List[str] = Field(
|
||||
description="使用的特征列表",
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
processing_time: Optional[float] = Field(
|
||||
description="处理时间(秒)",
|
||||
default=None
|
||||
)
|
||||
|
||||
@validator('all_categories')
|
||||
def sort_categories(cls, v):
|
||||
"""按置信度降序排列"""
|
||||
return sorted(v, key=lambda x: x.confidence, reverse=True)
|
||||
|
||||
|
||||
class DataExtractionItem(BaseModel):
|
||||
"""数据提取项"""
|
||||
|
||||
field_name: str = Field(description="字段名称")
|
||||
field_value: Union[str, int, float, bool, None] = Field(description="字段值")
|
||||
confidence: float = Field(description="提取置信度", ge=0.0, le=1.0)
|
||||
source_text: str = Field(description="来源文本片段")
|
||||
extraction_method: str = Field(description="提取方法", default="llm")
|
||||
|
||||
|
||||
class StructuredDataExtraction(BaseModel):
|
||||
"""结构化数据提取结果"""
|
||||
|
||||
extracted_data: Dict[str, Any] = Field(description="提取的结构化数据")
|
||||
extraction_items: List[DataExtractionItem] = Field(description="详细提取项目")
|
||||
|
||||
# 质量评估
|
||||
extraction_quality: Literal["excellent", "good", "fair", "poor"] = Field(
|
||||
description="提取质量评估"
|
||||
)
|
||||
completeness: float = Field(
|
||||
description="完整性评分",
|
||||
ge=0.0,
|
||||
le=1.0
|
||||
)
|
||||
accuracy: float = Field(
|
||||
description="准确性评分",
|
||||
ge=0.0,
|
||||
le=1.0
|
||||
)
|
||||
|
||||
# 统计信息
|
||||
total_fields: int = Field(description="总字段数", ge=0)
|
||||
extracted_fields: int = Field(description="成功提取字段数", ge=0)
|
||||
failed_fields: int = Field(description="提取失败字段数", ge=0)
|
||||
|
||||
@validator('extracted_fields', 'failed_fields')
|
||||
def validate_field_counts(cls, v, values):
|
||||
"""验证字段计数"""
|
||||
total = values.get('total_fields', 0)
|
||||
if v > total:
|
||||
raise ValueError('提取字段数不能超过总字段数')
|
||||
return v
|
||||
|
||||
|
||||
class TaskExecutionResult(BaseModel):
|
||||
"""任务执行结果"""
|
||||
|
||||
# 任务信息
|
||||
task_id: str = Field(description="任务ID")
|
||||
task_type: str = Field(description="任务类型")
|
||||
status: Literal["completed", "failed", "partial"] = Field(description="执行状态")
|
||||
|
||||
# 执行详情
|
||||
result_data: Optional[Dict[str, Any]] = Field(description="结果数据", default=None)
|
||||
error_message: Optional[str] = Field(description="错误信息", default=None)
|
||||
warnings: List[str] = Field(description="警告信息", default_factory=list)
|
||||
|
||||
# 性能指标
|
||||
execution_time: float = Field(description="执行时间(秒)", ge=0.0)
|
||||
memory_usage: Optional[float] = Field(description="内存使用量(MB)", default=None)
|
||||
|
||||
# 质量评估
|
||||
success_rate: float = Field(description="成功率", ge=0.0, le=1.0)
|
||||
quality_score: float = Field(description="质量评分", ge=0.0, le=1.0)
|
||||
|
||||
@validator('success_rate', 'quality_score')
|
||||
def round_scores(cls, v):
|
||||
"""保留3位小数"""
|
||||
return round(v, 3)
|
||||
|
||||
|
||||
class ComprehensiveAnalysisResult(BaseModel):
|
||||
"""综合分析结果(组合多个分析)"""
|
||||
|
||||
# 基本信息
|
||||
analysis_id: str = Field(description="分析ID")
|
||||
input_summary: str = Field(description="输入数据摘要")
|
||||
analysis_timestamp: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
# 各项分析结果
|
||||
text_analysis: Optional[TextAnalysisResult] = Field(
|
||||
description="文本分析结果",
|
||||
default=None
|
||||
)
|
||||
classification: Optional[DocumentClassificationResult] = Field(
|
||||
description="分类结果",
|
||||
default=None
|
||||
)
|
||||
data_extraction: Optional[StructuredDataExtraction] = Field(
|
||||
description="数据提取结果",
|
||||
default=None
|
||||
)
|
||||
|
||||
# 综合评估
|
||||
overall_quality: Literal["excellent", "good", "fair", "poor"] = Field(
|
||||
description="整体质量评估"
|
||||
)
|
||||
confidence_level: float = Field(
|
||||
description="整体置信度",
|
||||
ge=0.0,
|
||||
le=1.0
|
||||
)
|
||||
|
||||
# 处理统计
|
||||
total_processing_time: float = Field(description="总处理时间(秒)", ge=0.0)
|
||||
components_completed: int = Field(description="完成的组件数量", ge=0)
|
||||
components_failed: int = Field(description="失败的组件数量", ge=0)
|
||||
|
||||
recommendations: List[str] = Field(
|
||||
description="改进建议",
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
@validator('confidence_level')
|
||||
def validate_confidence(cls, v):
|
||||
"""验证并格式化置信度"""
|
||||
return round(v, 3)
|
||||
|
||||
|
||||
# 测试模型定义
|
||||
def test_models():
|
||||
"""测试所有模型定义"""
|
||||
print("正在测试示例模型定义...")
|
||||
|
||||
try:
|
||||
# 测试基础响应模型
|
||||
basic = BasicResponse(message="测试消息")
|
||||
print(f"✅ BasicResponse模型测试成功: {basic.message}")
|
||||
|
||||
# 测试情感分析模型
|
||||
sentiment = SentimentAnalysis(
|
||||
sentiment="positive",
|
||||
confidence=0.95,
|
||||
explanation="积极的情感表达",
|
||||
keywords=["好", "满意", "推荐"]
|
||||
)
|
||||
print(f"✅ SentimentAnalysis模型测试成功: {sentiment.sentiment}")
|
||||
|
||||
# 测试关键词提取模型
|
||||
keyword = KeywordExtraction(
|
||||
keyword="人工智能",
|
||||
frequency=5,
|
||||
importance=0.8,
|
||||
category="technology"
|
||||
)
|
||||
print(f"✅ KeywordExtraction模型测试成功: {keyword.keyword}")
|
||||
|
||||
# 测试文本分析结果模型
|
||||
text_result = TextAnalysisResult(
|
||||
text_length=150,
|
||||
word_count=25,
|
||||
language="zh",
|
||||
summary="这是一个测试文本摘要",
|
||||
sentiment=sentiment,
|
||||
keywords=[keyword],
|
||||
readability="high",
|
||||
complexity=0.3
|
||||
)
|
||||
print(f"✅ TextAnalysisResult模型测试成功: {text_result.summary}")
|
||||
|
||||
# 测试分类结果模型
|
||||
classification = DocumentClassificationResult(
|
||||
primary_category="技术文档",
|
||||
confidence=0.92,
|
||||
all_categories=[
|
||||
CategoryClassification(
|
||||
category="技术文档",
|
||||
confidence=0.92,
|
||||
probability=0.87
|
||||
)
|
||||
]
|
||||
)
|
||||
print(f"✅ DocumentClassificationResult模型测试成功: {classification.primary_category}")
|
||||
|
||||
# 测试数据提取模型
|
||||
extraction_item = DataExtractionItem(
|
||||
field_name="标题",
|
||||
field_value="SubAgent系统指南",
|
||||
confidence=0.98,
|
||||
source_text="# SubAgent系统指南",
|
||||
extraction_method="pattern_matching"
|
||||
)
|
||||
|
||||
data_extraction = StructuredDataExtraction(
|
||||
extracted_data={"title": "SubAgent系统指南"},
|
||||
extraction_items=[extraction_item],
|
||||
extraction_quality="excellent",
|
||||
completeness=1.0,
|
||||
accuracy=0.95,
|
||||
total_fields=1,
|
||||
extracted_fields=1,
|
||||
failed_fields=0
|
||||
)
|
||||
print(f"✅ StructuredDataExtraction模型测试成功: {data_extraction.extraction_quality}")
|
||||
|
||||
# 测试任务执行结果模型
|
||||
task_result = TaskExecutionResult(
|
||||
task_id="task_001",
|
||||
task_type="text_analysis",
|
||||
status="completed",
|
||||
result_data={"status": "success"},
|
||||
execution_time=2.5,
|
||||
success_rate=1.0,
|
||||
quality_score=0.95
|
||||
)
|
||||
print(f"✅ TaskExecutionResult模型测试成功: {task_result.status}")
|
||||
|
||||
# 测试综合分析结果模型
|
||||
comprehensive = ComprehensiveAnalysisResult(
|
||||
analysis_id="analysis_001",
|
||||
input_summary="测试输入摘要",
|
||||
text_analysis=text_result,
|
||||
classification=classification,
|
||||
data_extraction=data_extraction,
|
||||
overall_quality="excellent",
|
||||
confidence_level=0.93,
|
||||
total_processing_time=5.2,
|
||||
components_completed=3,
|
||||
components_failed=0,
|
||||
recommendations=["继续保持高质量输出"]
|
||||
)
|
||||
print(f"✅ ComprehensiveAnalysisResult模型测试成功: {comprehensive.overall_quality}")
|
||||
|
||||
# 测试JSON序列化
|
||||
json_str = comprehensive.model_dump_json(indent=2)
|
||||
print(f"✅ JSON序列化测试成功,长度: {len(json_str)}字符")
|
||||
|
||||
# 列出所有模型字段
|
||||
print("\n📋 模型字段信息:")
|
||||
for model_name, model_class in [
|
||||
("BasicResponse", BasicResponse),
|
||||
("SentimentAnalysis", SentimentAnalysis),
|
||||
("TextAnalysisResult", TextAnalysisResult),
|
||||
("DocumentClassificationResult", DocumentClassificationResult),
|
||||
("StructuredDataExtraction", StructuredDataExtraction),
|
||||
("ComprehensiveAnalysisResult", ComprehensiveAnalysisResult),
|
||||
]:
|
||||
fields = list(model_class.model_fields.keys())
|
||||
print(f" {model_name}: {len(fields)} 个字段 - {fields[:3]}{'...' if len(fields) > 3 else ''}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 模型测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_models()
|
||||
if success:
|
||||
print("\n🎉 所有示例模型测试通过!")
|
||||
else:
|
||||
print("\n💥 模型测试失败,请检查定义")
|
||||
example/subagent-example/text_analysis_example.py: 762 lines (new file)
@@ -0,0 +1,762 @@
#!/usr/bin/env python3
"""
Comprehensive text analysis example

A more complex application of the SubAgent system, demonstrating:
1. A multi-agent collaboration system
2. A multi-stage data processing pipeline
3. Structured output and error recovery
4. Performance monitoring and quality assessment
"""

import sys
import os
import time
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
import uuid

# Add the project root to the Python path
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
sys.path.append(project_root)

from src.agent_system import SubAgent, create_json_agent
from example_models import (
    TextAnalysisResult,
    DocumentClassificationResult,
    StructuredDataExtraction,
    ComprehensiveAnalysisResult,
    SentimentAnalysis,
    KeywordExtraction,
    DataExtractionItem,
    CategoryClassification,
    TaskExecutionResult
)


class TextAnalysisEngine:
    """Text analysis engine - a multi-agent collaboration system"""

    def __init__(self):
        """Initialize the text analysis engine"""
        self.agents = {}
        self.processing_stats = {
            "total_processed": 0,
            "successful_analyses": 0,
            "failed_analyses": 0,
            "average_processing_time": 0.0
        }

        # Initialize all agents
        self._initialize_agents()

    def _initialize_agents(self):
        """Initialize every analysis agent"""
        print("🔧 Initializing the text analysis engine...")

        try:
            # 1. Sentiment analysis agent
            self.agents['sentiment'] = self._create_sentiment_agent()

            # 2. Keyword extraction agent
            self.agents['keywords'] = self._create_keyword_agent()

            # 3. Text classification agent
            self.agents['classification'] = self._create_classification_agent()

            # 4. Data extraction agent
            self.agents['extraction'] = self._create_extraction_agent()

            # 5. Comprehensive analysis agent
            self.agents['comprehensive'] = self._create_comprehensive_agent()

            print(f"✅ Initialized {len(self.agents)} agents")

        except Exception as e:
            print(f"❌ Agent initialization failed: {e}")
            raise

    def _create_sentiment_agent(self) -> SubAgent:
        """Create the sentiment analysis agent"""
        instructions = [
            "You are a professional text sentiment analysis expert",
            "Accurately identify the sentiment and its intensity",
            "Provide detailed reasoning and the relevant keywords",
            "Attach a confidence assessment to the result"
        ]

        prompt_template = """
Please perform an in-depth sentiment analysis of the following text:

[Text]
{text}

[Requirements]
1. Identify the dominant sentiment (positive/negative/neutral)
2. Assess the sentiment intensity and confidence (0-1)
3. Explain the analysis and the supporting evidence
4. Extract the keywords and phrases that drive the judgment
5. Account for linguistic nuance and context

Provide an accurate, professional analysis.
"""

        return SubAgent(
            provider="aliyun",
            model_name="qwen-max",
            name="sentiment_analyzer",
            description="A professional sentiment analysis system",
            instructions=instructions,
            prompt_template=prompt_template,
            response_model=SentimentAnalysis
        )

    def _create_keyword_agent(self) -> SubAgent:
        """Create the keyword extraction agent"""
        instructions = [
            "You are a professional keyword extraction expert",
            "Identify the most important and relevant keywords in the text",
            "Assess keyword importance and frequency",
            "Categorize the keywords sensibly"
        ]

        prompt_template = """
Please extract keywords from the following text:

[Text]
{text}

[Requirements]
1. Identify the most important keywords and phrases
2. Count keyword frequencies
3. Score the importance of each keyword (0-1)
4. Categorize the keywords (e.g., person, place, concept, technology)
5. Exclude stop words and meaningless tokens

Provide a structured keyword extraction result.
"""

        # Dedicated response model for keyword lists
        from pydantic import BaseModel, Field
        from typing import List

        class KeywordExtractionResult(BaseModel):
            keywords: List[KeywordExtraction] = Field(description="Extracted keywords")
            total_count: int = Field(description="Total keyword count", ge=0)
            text_complexity: float = Field(description="Text complexity", ge=0.0, le=1.0)

        return SubAgent(
            provider="aliyun",
            model_name="qwen-max",
            name="keyword_extractor",
            description="An intelligent keyword extraction system",
            instructions=instructions,
            prompt_template=prompt_template,
            response_model=KeywordExtractionResult
        )

    def _create_classification_agent(self) -> SubAgent:
        """Create the document classification agent"""
        instructions = [
            "You are a professional document classification expert",
            "Accurately identify the document's type and topic",
            "Provide multi-level classification with confidence estimates",
            "Consider the document's content, style, and purpose"
        ]

        prompt_template = """
Please classify the following document:

[Document]
{text}

[Taxonomy]
Primary categories: technical documentation, business documents, academic papers, news reports, personal writing, legal documents, medical documents, etc.
Detailed categories: refine further based on the actual content

[Requirements]
1. Determine the primary category and its confidence
2. Provide a probability distribution over all plausible categories
3. Identify the key features used for the decision
4. Assess the reliability of the classification

Provide an accurate classification result.
"""

        return SubAgent(
            provider="aliyun",
            model_name="qwen-max",
            name="document_classifier",
            description="An intelligent document classification system",
            instructions=instructions,
            prompt_template=prompt_template,
            response_model=DocumentClassificationResult
        )

    def _create_extraction_agent(self) -> SubAgent:
        """Create the data extraction agent"""
        instructions = [
            "You are a professional structured-data extraction expert",
            "Extract valuable information from unstructured text",
            "Ensure the extracted data is accurate and complete",
            "Assess the extraction quality and reliability"
        ]

        prompt_template = """
Please extract structured data from the following text:

[Text]
{text}

[Extraction targets]
Automatically identify the extractable data types, which may include:
- Person, place, and organization names
- Dates, times, quantities
- Contact details, addresses
- Technical terms, concepts
- Key metrics, statistics

[Requirements]
1. Automatically identify structured information in the text
2. Attach a confidence estimate to every extracted item
3. Record the supporting source text fragment for each item
4. Assess the overall extraction quality and completeness

Provide a detailed data extraction result.
"""

        return SubAgent(
            provider="aliyun",
            model_name="qwen-max",
            name="data_extractor",
            description="An intelligent data extraction system",
            instructions=instructions,
            prompt_template=prompt_template,
            response_model=StructuredDataExtraction
        )

    def _create_comprehensive_agent(self) -> SubAgent:
        """Create the comprehensive analysis agent"""
        instructions = [
            "You are a comprehensive text analysis expert",
            "Integrate multiple analysis results into an overall assessment",
            "Identify agreements and contradictions across the analyses",
            "Provide improvement suggestions and deeper insights"
        ]

        prompt_template = """
Based on the multi-dimensional analysis results below, provide an overall assessment:

[Original text]
{original_text}

[Analysis results]
Sentiment analysis: {sentiment_result}
Keyword extraction: {keyword_result}
Document classification: {classification_result}
Data extraction: {extraction_result}

[Assessment requirements]
1. Evaluate the consistency of the individual results
2. Identify potential contradictions or problems
3. Provide an overall quality assessment
4. Provide a confidence estimate
5. Suggest improvements

Provide a professional, comprehensive analysis report.
"""

        return SubAgent(
            provider="aliyun",
            model_name="qwen-max",
            name="comprehensive_analyzer",
            description="A comprehensive text analysis system",
            instructions=instructions,
            prompt_template=prompt_template,
            response_model=ComprehensiveAnalysisResult
        )

    def analyze_text(self, text: str, analysis_id: Optional[str] = None) -> ComprehensiveAnalysisResult:
        """Run the full text analysis"""
        if analysis_id is None:
            analysis_id = f"analysis_{uuid.uuid4().hex[:8]}"

        start_time = time.time()
        self.processing_stats["total_processed"] += 1

        print(f"\n🔍 Starting analysis [{analysis_id}]")
        print(f"Text length: {len(text)} characters")

        try:
            # Stage 1: basic analysis
            print("📊 Running basic analysis...")
            sentiment_result = self._analyze_sentiment(text)
            keyword_result = self._extract_keywords(text)

            # Stage 2: advanced analysis
            print("🧠 Running advanced analysis...")
            classification_result = self._classify_document(text)
            extraction_result = self._extract_data(text)

            # Stage 3: comprehensive analysis
            print("🎯 Running comprehensive analysis...")
            comprehensive_result = self._comprehensive_analysis(
                text, sentiment_result, keyword_result,
                classification_result, extraction_result, analysis_id
            )

            # Update the statistics
            processing_time = time.time() - start_time
            self.processing_stats["successful_analyses"] += 1
            self._update_processing_stats(processing_time)

            print(f"✅ Analysis complete [{analysis_id}] - took {processing_time:.2f}s")
            return comprehensive_result

        except Exception as e:
            self.processing_stats["failed_analyses"] += 1
            print(f"❌ Analysis failed [{analysis_id}]: {e}")
            raise

    def _analyze_sentiment(self, text: str) -> SentimentAnalysis:
        """Run sentiment analysis"""
        try:
            result = self.agents['sentiment'].run(template_vars={"text": text})
            print(f"  Sentiment: {result.sentiment} (confidence: {result.confidence:.3f})")
            return result
        except Exception as e:
            print(f"  ⚠️ Sentiment analysis failed: {e}")
            # Return a default result
            return SentimentAnalysis(
                sentiment="neutral",
                confidence=0.0,
                explanation=f"analysis failed: {e}",
                keywords=[]
            )

    def _extract_keywords(self, text: str):
        """Extract keywords"""
        try:
            result = self.agents['keywords'].run(template_vars={"text": text})
            print(f"  Keywords: {result.total_count}")
            return result
        except Exception as e:
            print(f"  ⚠️ Keyword extraction failed: {e}")
            # Return an empty result
            from pydantic import BaseModel, Field
            from typing import List

            class KeywordExtractionResult(BaseModel):
                keywords: List[KeywordExtraction] = Field(default_factory=list)
                total_count: int = Field(default=0)
                text_complexity: float = Field(default=0.5)

            return KeywordExtractionResult()

    def _classify_document(self, text: str) -> DocumentClassificationResult:
        """Run document classification"""
        try:
            result = self.agents['classification'].run(template_vars={"text": text})
            print(f"  Category: {result.primary_category} (confidence: {result.confidence:.3f})")
            return result
        except Exception as e:
            print(f"  ⚠️ Document classification failed: {e}")
            # Return a default result
            return DocumentClassificationResult(
                primary_category="unknown",
                confidence=0.0,
                all_categories=[
                    CategoryClassification(
                        category="unknown",
                        confidence=0.0,
                        probability=0.0
                    )
                ]
            )

    def _extract_data(self, text: str) -> StructuredDataExtraction:
        """Extract structured data"""
        try:
            result = self.agents['extraction'].run(template_vars={"text": text})
            print(f"  Data extraction: {result.extracted_fields}/{result.total_fields} fields")
            return result
        except Exception as e:
            print(f"  ⚠️ Data extraction failed: {e}")
            # Return an empty result
            return StructuredDataExtraction(
                extracted_data={},
                extraction_items=[],
                extraction_quality="poor",
                completeness=0.0,
                accuracy=0.0,
                total_fields=0,
                extracted_fields=0,
                failed_fields=0
            )

    def _comprehensive_analysis(
        self,
        original_text: str,
        sentiment_result: SentimentAnalysis,
        keyword_result,
        classification_result: DocumentClassificationResult,
        extraction_result: StructuredDataExtraction,
        analysis_id: str
    ) -> ComprehensiveAnalysisResult:
        """Run the comprehensive analysis"""
        try:
            # Prepare the template variables
            template_vars = {
                "original_text": original_text[:500] + ("..." if len(original_text) > 500 else ""),
                "sentiment_result": f"sentiment: {sentiment_result.sentiment}, confidence: {sentiment_result.confidence}",
                "keyword_result": f"keyword count: {getattr(keyword_result, 'total_count', 0)}",
                "classification_result": f"category: {classification_result.primary_category}, confidence: {classification_result.confidence}",
                "extraction_result": f"extraction quality: {extraction_result.extraction_quality}"
            }

            result = self.agents['comprehensive'].run(template_vars=template_vars)

            # Fill in some of the fields
            result.analysis_id = analysis_id
            result.input_summary = f"length: {len(original_text)} characters, type: {classification_result.primary_category}"
            result.text_analysis = self._build_text_analysis_result(
                original_text, sentiment_result, keyword_result
            )
            result.classification = classification_result
            result.data_extraction = extraction_result

            print(f"  Overall assessment: {result.overall_quality}")
            return result

        except Exception as e:
            print(f"  ⚠️ Comprehensive analysis failed: {e}")
            # Build a minimal fallback result
            return ComprehensiveAnalysisResult(
                analysis_id=analysis_id,
                input_summary=f"length: {len(original_text)} characters",
                overall_quality="poor",
                confidence_level=0.0,
                total_processing_time=0.0,
                components_completed=0,
                components_failed=4,
                recommendations=["analysis failed; check the input text and system configuration"]
            )

    def _build_text_analysis_result(
        self,
        text: str,
        sentiment: SentimentAnalysis,
        keyword_result
    ) -> TextAnalysisResult:
        """Build the text analysis result"""

        # Fetch the keyword list
        keywords = getattr(keyword_result, 'keywords', [])

        return TextAnalysisResult(
            text_length=len(text),
            word_count=len(text.split()),
            language="zh",
            summary=f"Text analysis summary: the sentiment is {sentiment.sentiment}",
            sentiment=sentiment,
            keywords=keywords,
            readability="medium",
            complexity=getattr(keyword_result, 'text_complexity', 0.5)
        )

    def _update_processing_stats(self, processing_time: float):
        """Update the processing statistics"""
        total = self.processing_stats["total_processed"]
        current_avg = self.processing_stats["average_processing_time"]

        # Compute the new average processing time
        new_avg = ((current_avg * (total - 1)) + processing_time) / total
        self.processing_stats["average_processing_time"] = new_avg

    def get_processing_stats(self) -> Dict[str, Any]:
        """Return a copy of the processing statistics"""
        return self.processing_stats.copy()

    def display_stats(self):
        """Display the processing statistics"""
        stats = self.get_processing_stats()
        print("\n📈 Processing statistics:")
        print(f"  Total processed: {stats['total_processed']}")
        print(f"  Successful: {stats['successful_analyses']}")
        print(f"  Failed: {stats['failed_analyses']}")
        if stats['total_processed'] > 0:
            success_rate = stats['successful_analyses'] / stats['total_processed'] * 100
            print(f"  Success rate: {success_rate:.1f}%")
        print(f"  Average processing time: {stats['average_processing_time']:.2f}s")


def demo_single_text_analysis():
    """Demonstrate analysis of a single text"""
    print("🔍 Single-Text Analysis Demo")
    print("="*50)

    # Create the analysis engine
    engine = TextAnalysisEngine()

    # Test text
    test_text = """
    AI technology is advancing rapidly; deep learning and machine learning algorithms have made notable progress across many fields.
    From natural language processing to computer vision, from recommender systems to autonomous driving, AI is changing how we live.

    At the same time, we need to pay attention to the challenges that come with AI, including privacy protection, algorithmic bias, and effects on employment.
    Only by balancing technical progress against social responsibility can AI genuinely benefit society.

    Overall, the future of AI is promising, but it calls for caution to keep the direction of development aligned with humanity's long-term interests.
    """

    try:
        # Run the analysis
        result = engine.analyze_text(test_text)

        # Show the detailed result
        display_analysis_result(result)

        # Show the statistics
        engine.display_stats()

        return True

    except Exception as e:
        print(f"❌ Demo failed: {e}")
        return False


def demo_batch_analysis():
    """Demonstrate batch text analysis"""
    print("\n🔄 Batch Text Analysis Demo")
    print("="*50)

    # Create the analysis engine
    engine = TextAnalysisEngine()

    # Test text collection
    test_texts = [
        "The weather is lovely today, bright and sunny; I am in a great mood!",

        "The company's latest quarterly report shows revenue up 15% year over year, with net profit reaching 230 million yuan. The board approved a dividend of 0.5 yuan per share.",

        "Machine learning is an important branch of artificial intelligence that lets computers learn patterns from data. Common algorithms include linear regression, decision trees, and neural networks.",

        "Terrible service and poor product quality; not remotely worth the price. I strongly advise against buying it!",

        "Under Article 121 of the Contract Law, a party that breaches a contract because of a third party shall be liable to the other party for the breach."
    ]

    results = []
    start_time = time.time()

    for i, text in enumerate(test_texts, 1):
        print(f"\nProcessing text {i}/{len(test_texts)}...")
        try:
            result = engine.analyze_text(text, f"batch_{i}")
            results.append(result)

            # Show a brief result
            print(f"  Result: {result.overall_quality} | confidence: {result.confidence_level:.3f}")

        except Exception as e:
            print(f"  ❌ Processing failed: {e}")
            results.append(None)

    total_time = time.time() - start_time

    # Show the batch summary
    print(f"\n📊 Batch processing finished - total time: {total_time:.2f}s")
    engine.display_stats()

    # Summarize the successfully processed results
    successful_results = [r for r in results if r is not None]
    if successful_results:
        print(f"\n🎯 Successfully processed {len(successful_results)} texts:")

        quality_stats = {}
        for result in successful_results:
            quality = result.overall_quality
            quality_stats[quality] = quality_stats.get(quality, 0) + 1

        for quality, count in quality_stats.items():
            print(f"  {quality}: {count}")

    return len(successful_results) > 0


def display_analysis_result(result: ComprehensiveAnalysisResult):
    """Display a detailed analysis result"""
    print(f"\n📋 Detailed analysis result [{result.analysis_id}]")
    print("="*60)

    print(f"Input summary: {result.input_summary}")
    print(f"Analysis time: {result.analysis_timestamp}")
    print(f"Overall quality: {result.overall_quality}")
    print(f"Confidence: {result.confidence_level:.3f}")
    print(f"Processing time: {result.total_processing_time:.2f}s")

    # Text analysis result
    if result.text_analysis:
        ta = result.text_analysis
        print(f"\n📝 Text analysis:")
        print(f"  Length: {ta.text_length} characters")
        print(f"  Words: {ta.word_count}")
        print(f"  Summary: {ta.summary}")
        print(f"  Sentiment: {ta.sentiment.sentiment} (confidence: {ta.sentiment.confidence:.3f})")
        print(f"  Readability: {ta.readability}")
        if ta.keywords:
            top_keywords = ta.keywords[:5]
            print(f"  Keywords: {[k.keyword for k in top_keywords]}")

    # Classification result
    if result.classification:
        cls = result.classification
        print(f"\n🏷️ Document classification:")
        print(f"  Primary category: {cls.primary_category}")
        print(f"  Confidence: {cls.confidence:.3f}")
        if len(cls.all_categories) > 1:
            other_cats = cls.all_categories[1:3]
            print(f"  Other candidates: {[c.category for c in other_cats]}")

    # Data extraction result
    if result.data_extraction:
        de = result.data_extraction
        print(f"\n🔍 Data extraction:")
        print(f"  Quality: {de.extraction_quality}")
        print(f"  Completeness: {de.completeness:.3f}")
        print(f"  Accuracy: {de.accuracy:.3f}")
        print(f"  Field counts: {de.extracted_fields}/{de.total_fields}")

        if de.extraction_items:
            print("  Extracted items:")
            for item in de.extraction_items[:3]:  # show the first 3
                print(f"    {item.field_name}: {item.field_value} (confidence: {item.confidence:.3f})")

    # Improvement suggestions
    if result.recommendations:
        print(f"\n💡 Suggestions:")
        for i, rec in enumerate(result.recommendations[:3], 1):
            print(f"  {i}. {rec}")


def interactive_analysis():
    """Interactive analysis"""
    print("\n💬 Interactive Text Analysis")
    print("="*50)
    print("Enter text for comprehensive analysis; type 'quit' to exit")

    try:
        engine = TextAnalysisEngine()

        while True:
            print("\n" + "-"*30)
            user_input = input("Enter the text to analyze: ").strip()

            if user_input.lower() == 'quit':
                print("Analysis finished, goodbye!")
                break

            if not user_input:
                continue

            if len(user_input) < 10:
                print("⚠️ The text is too short; enter at least 10 characters")
                continue

            try:
                result = engine.analyze_text(user_input)
                display_analysis_result(result)

            except Exception as e:
                print(f"❌ Analysis failed: {e}")

        # Show the final statistics
        engine.display_stats()

    except KeyboardInterrupt:
        print("\nInterrupted")
    except Exception as e:
        print(f"❌ Interactive analysis failed: {e}")


def test_engine_initialization():
    """Test engine initialization"""
    print("Testing text analysis engine initialization...")

    try:
        engine = TextAnalysisEngine()
        print(f"✅ Engine initialized with {len(engine.agents)} agents")
|
||||
|
||||
# 显示Agent信息
|
||||
for name, agent in engine.agents.items():
|
||||
info = agent.get_model_info()
|
||||
print(f" {name}: {info['model_name']} ({info['provider']})")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 引擎初始化失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("🚀 文本分析综合示例")
|
||||
print("="*60)
|
||||
|
||||
# 运行各种演示
|
||||
demos = [
|
||||
("引擎初始化测试", test_engine_initialization),
|
||||
("单文本分析", demo_single_text_analysis),
|
||||
("批量文本分析", demo_batch_analysis),
|
||||
]
|
||||
|
||||
results = {}
|
||||
|
||||
for name, demo_func in demos:
|
||||
print(f"\n开始: {name}")
|
||||
try:
|
||||
success = demo_func()
|
||||
results[name] = success
|
||||
print(f"{'✅' if success else '❌'} {name} {'成功' if success else '失败'}")
|
||||
except Exception as e:
|
||||
print(f"❌ {name} 异常: {e}")
|
||||
results[name] = False
|
||||
|
||||
# 显示总结
|
||||
print(f"\n📊 演示总结")
|
||||
print("="*60)
|
||||
|
||||
successful_demos = sum(results.values())
|
||||
total_demos = len(results)
|
||||
|
||||
for name, success in results.items():
|
||||
status = "✅ 成功" if success else "❌ 失败"
|
||||
print(f" {name}: {status}")
|
||||
|
||||
print(f"\n🎯 总计: {successful_demos}/{total_demos} 个演示成功")
|
||||
|
||||
# 询问是否运行交互式演示
|
||||
if successful_demos > 0:
|
||||
try:
|
||||
choice = input("\n是否运行交互式分析?(y/n): ").strip().lower()
|
||||
if choice in ['y', 'yes', '是']:
|
||||
interactive_analysis()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\n程序结束")
|
||||
|
||||
return successful_demos == total_demos
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 可以选择运行测试或完整演示
|
||||
import sys
|
||||
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "--test":
|
||||
# 仅运行初始化测试
|
||||
success = test_engine_initialization()
|
||||
exit(0 if success else 1)
|
||||
else:
|
||||
# 运行完整演示
|
||||
main()
|
||||
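Reviewer note: the running average in `_update_processing_stats` above folds each new timing into the mean without keeping per-sample history. A standalone sketch of that arithmetic (the sample values are made up):

```python
# Incremental mean, same formula as in _update_processing_stats above
samples = [0.8, 1.2, 1.0]

avg = 0.0
for n, t in enumerate(samples, start=1):
    avg = ((avg * (n - 1)) + t) / n

# The incremental form matches the batch mean
assert abs(avg - sum(samples) / len(samples)) < 1e-9
print(f"running average: {avg:.2f}s")
```

One caveat visible in the engine code: the update divides by `total_processed`, so it assumes that counter was incremented before `_update_processing_stats` runs; with a count of zero it would raise `ZeroDivisionError`.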
147 info_extractor.py Normal file
@ -0,0 +1,147 @@
#!/usr/bin/env python3
"""
LangExtract-based MIMIC paper information extractor.
Extracts structured reproduction-task information from medical papers.

Author: MedResearcher project
Created: 2025-01-25
"""

import argparse
import logging

from src.extractor import MIMICLangExtractBuilder

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)


def setup_args():
    """Set up command-line argument parsing.

    Returns:
        argparse.Namespace: the parsed command-line arguments
    """
    parser = argparse.ArgumentParser(
        description='MIMIC paper information extraction tool - uses LangExtract to pull structured reproduction information from medical papers',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
Usage examples:
  %(prog)s                                      # use the defaults
  %(prog)s --papers_dir dataset/markdowns       # specify the papers directory
  %(prog)s --output_file results/dataset.json   # specify the output file
  %(prog)s --test_mode --max_papers 5           # test mode, process only 5 papers
'''
    )

    parser.add_argument(
        '--papers_dir',
        type=str,
        default='dataset/markdowns',
        help='directory of markdown paper files (default: dataset/markdowns)'
    )

    parser.add_argument(
        '--output_file',
        type=str,
        default='dataset/reproduction_tasks/mimic_langextract_dataset.json',
        help='output dataset file path (default: dataset/reproduction_tasks/mimic_langextract_dataset.json)'
    )

    parser.add_argument(
        '--test_mode',
        action='store_true',
        help='test mode: process only a few papers for validation'
    )

    parser.add_argument(
        '--max_papers',
        type=int,
        default=None,
        help='maximum number of papers to process, for testing (default: process all papers)'
    )

    parser.add_argument(
        '--log_level',
        type=str,
        default='INFO',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
        help='log level (default: INFO)'
    )

    parser.add_argument(
        '--doc_workers',
        type=int,
        default=50,
        help='number of parallel document-processing worker threads (default: 50)'
    )

    return parser.parse_args()


def main():
    """Entry point - run the MIMIC paper information extraction task."""
    try:
        # Parse the command-line arguments
        args = setup_args()

        # Set the log level
        logging.getLogger().setLevel(getattr(logging, args.log_level))

        # Initialize the information extractor
        builder = MIMICLangExtractBuilder(doc_workers=args.doc_workers)

        print("=== MIMIC paper information extraction tool ===")
        print(f"Papers directory: {args.papers_dir}")
        print(f"Output file: {args.output_file}")
        print(f"Test mode: {'yes' if args.test_mode else 'no'}")
        if args.max_papers:
            print(f"Max papers: {args.max_papers}")
        print(f"Document parallelism: {args.doc_workers} threads")
        print(f"Log level: {args.log_level}")
        print("========================")

        # Build the reproduction dataset
        print("\nBuilding the MIMIC reproduction dataset...")
        dataset = builder.build_reproduction_dataset(
            papers_dir=args.papers_dir,
            output_file=args.output_file,
            max_papers=args.max_papers if args.test_mode or args.max_papers else None
        )

        # Summarize the results
        total_papers = dataset['metadata']['total_papers']
        successful_extractions = sum(
            1 for paper in dataset['papers'].values()
            if any(module.get('extraction_count', 0) > 0
                   for module in paper.get('modules', {}).values())
        )

        print("\n=== Build complete ===")
        print(f"Total papers: {total_papers}")
        print(f"Successful extractions: {successful_extractions}/{total_papers}")
        print(f"Success rate: {successful_extractions/total_papers*100:.1f}%")
        print(f"Results saved to: {args.output_file}")
        print(f"Interactive report: {args.output_file.replace('.json', '.html')}")
        print("===============")

        return 0

    except FileNotFoundError as e:
        print(f"Error: file or directory not found - {e}")
        return 1
    except ValueError as e:
        print(f"Error: invalid argument value - {e}")
        return 1
    except Exception as e:
        print(f"Error: unexpected failure - {e}")
        logging.exception("Details:")
        return 1


if __name__ == "__main__":
    exit_code = main()
    exit(exit_code)
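Reviewer note: for test automation it can be handy to drive the extractor without argparse. A minimal sketch, assuming only the `MIMICLangExtractBuilder` constructor and the `build_reproduction_dataset` signature used in `main()` above; the paths and worker count here are illustrative:

```python
# Hypothetical programmatic driver for the extractor (not part of the diff)
from src.extractor import MIMICLangExtractBuilder

builder = MIMICLangExtractBuilder(doc_workers=8)
dataset = builder.build_reproduction_dataset(
    papers_dir="dataset/markdowns",
    output_file="dataset/reproduction_tasks/smoke_test.json",
    max_papers=3,  # small smoke-test run
)
print(dataset["metadata"]["total_papers"])
```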
@ -22,7 +22,7 @@ def setup_args():

     parser.add_argument(
         '--paper_website',
-        default=["arxiv","medrxiv"],
+        default=["medrxiv"],
         help='paper website(s) (default: medrxiv)',
         nargs='+',
         choices=["arxiv","medrxiv"]
@ -35,6 +35,20 @@ def setup_args():
         help='number of parallel worker threads (default: 20)'
     )

+    parser.add_argument(
+        '--csv-download',
+        type=str,
+        default="yes",
+        help='whether to run the crawl step and write the paper CSV (any non-empty value enables it)'
+    )
+
+    parser.add_argument(
+        '--pdf_download_list',
+        type=str,
+        default='dataset/mimic_papers_20250825.csv',
+        help='CSV file listing the papers whose PDFs should be downloaded'
+    )
+
     return parser.parse_args()


@ -45,18 +59,20 @@ def main():
         # Parse the command-line arguments
         args = setup_args()

-        print("=== Paper crawling tool started ===")
-        print(f"Paper sources: {args.paper_website}")
-        print(f"Parallel workers: {args.parallel}")
-        print("========================")
-
         # Initialize the paper crawler
         crawler = PaperCrawler(
             websites=args.paper_website,
             parallel=args.parallel
         )

+        print("=== Paper crawling tool started ===")
+        print(f"Paper sources: {args.paper_website}")
+        print(f"Parallel workers: {args.parallel}")
+        print("========================")
+
         # Crawl the papers
+        if args.csv_download:
             print("Crawling MIMIC-4 related papers...")
             papers = crawler.crawl_papers()

@ -70,6 +86,23 @@ def main():
         else:
             print("No relevant papers found; check the network connection or keyword settings")

+        # If a PDF download list was given, run the download test
+        if args.pdf_download_list:
+            print("=== PDF download test ===")
+            print(f"CSV file: {args.pdf_download_list}")
+            print(f"Parallel workers: {args.parallel}")
+            print("========================")
+
+            # Download the PDFs
+            stats = crawler.download_pdfs_from_csv(args.pdf_download_list)
+
+            print("\n=== PDF download test finished ===")
+            print(f"Total: {stats['total']} papers")
+            print(f"Succeeded: {stats['success']} ({stats['success']/stats['total']*100:.1f}%)")
+            print(f"Failed: {stats['failed']} ({stats['failed']/stats['total']*100:.1f}%)")
+            print("========================")
         return 0

     except FileNotFoundError as e:
         print(f"Error: the specified file was not found - {e}")
         return 1
90 pdf_parser.py Normal file
@ -0,0 +1,90 @@
import argparse

from src.parse import PDFParser


def setup_args():
    """Set up command-line argument parsing.

    Returns:
        argparse.Namespace: the parsed command-line arguments
    """
    parser = argparse.ArgumentParser(
        description='PDF parsing tool - converts PDF files to Markdown through an OCR API',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
Usage examples:
  %(prog)s                                    # use the defaults
  %(prog)s --pdf-dir dataset/pdfs             # specify the PDF directory
  %(prog)s --parallel 10                      # set the parallelism to 10
  %(prog)s --markdown-dir output/markdowns    # specify the output directory
'''
    )

    parser.add_argument(
        '--pdf-dir',
        default="dataset/pdfs",
        help='PDF file directory (default: dataset/pdfs)'
    )

    parser.add_argument(
        '--parallel',
        type=int,
        default=5,
        help='number of concurrent worker threads (default: 5; kept low to avoid overloading the server)'
    )

    parser.add_argument(
        '--markdown-dir',
        default="dataset/markdowns",
        help='Markdown output directory (default: dataset/markdowns)'
    )

    return parser.parse_args()


def main():
    """Entry point - run the PDF parsing task."""
    try:
        # Parse the command-line arguments
        args = setup_args()

        # Initialize the PDF parser
        parser = PDFParser(
            pdf_dir=args.pdf_dir,
            parallel=args.parallel,
            markdown_dir=args.markdown_dir
        )

        print("=== PDF parsing tool started ===")
        print(f"PDF directory: {args.pdf_dir}")
        print(f"Parallel workers: {args.parallel}")
        print(f"Output directory: {args.markdown_dir}")
        print("========================")

        # Parse the PDFs
        print("Processing PDF files...")
        stats = parser.parse_all_pdfs()

        print("\n=== Parsing complete ===")
        print(f"Total: {stats['total']} files")
        print(f"Succeeded: {stats['success']} ({stats['success']/stats['total']*100:.1f}%)" if stats['total'] > 0 else "Succeeded: 0")
        print(f"Failed: {stats['failed']} ({stats['failed']/stats['total']*100:.1f}%)" if stats['total'] > 0 else "Failed: 0")
        print("================")

        return 0

    except FileNotFoundError as e:
        print(f"Error: the specified directory was not found - {e}")
        return 1
    except ValueError as e:
        print(f"Error: invalid argument value - {e}")
        return 1
    except Exception as e:
        print(f"Error: unexpected failure - {e}")
        return 1


if __name__ == "__main__":
    exit_code = main()
    exit(exit_code)
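Reviewer note: the same parser can be embedded in other scripts. A minimal sketch, assuming only the `PDFParser` constructor arguments and the `parse_all_pdfs()` stats dict used in `main()` above:

```python
# Minimal programmatic use of PDFParser (mirrors main() above)
from src.parse import PDFParser

parser = PDFParser(pdf_dir="dataset/pdfs", parallel=3, markdown_dir="dataset/markdowns")
stats = parser.parse_all_pdfs()
print(f"{stats['success']}/{stats['total']} PDFs converted")
```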
@ -5,5 +5,12 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
    "agno>=1.7.12",
    "httpx[socks]>=0.28.1",
    "langextract>=1.0.8",
    "ollama>=0.5.3",
    "openai>=1.101.0",
    "pydantic",
    "pyyaml>=6.0.2",
    "requests>=2.32.5",
]
38 src/agent_system/__init__.py Normal file
@ -0,0 +1,38 @@
"""
Agent System Module

SubAgent system implementation built on the Agno framework,
providing intelligent-agent functionality for the MedResearcher project.
"""

__version__ = "0.1.0"

# Import the core classes
from .subagent import SubAgent, create_json_agent
from .config_loader import load_llm_config, get_model_config
from .model_factory import create_agno_model, list_available_models
from .json_processor import parse_json_response, JSONProcessor

# Import the exception classes
from .subagent import SubAgentError, ConfigurationError, ModelError
from .json_processor import JSONParseError

__all__ = [
    # Core classes
    'SubAgent',
    'JSONProcessor',

    # Convenience functions
    'create_json_agent',
    'load_llm_config',
    'get_model_config',
    'create_agno_model',
    'list_available_models',
    'parse_json_response',

    # Exception classes
    'SubAgentError',
    'ConfigurationError',
    'ModelError',
    'JSONParseError',
]
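Reviewer note: with these re-exports in place, callers can import from the package root instead of the submodules. A short sketch:

```python
from src.agent_system import load_llm_config, list_available_models

config = load_llm_config()
print(list_available_models(config))  # model names grouped by provider
```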
292 src/agent_system/config_loader.py Normal file
@ -0,0 +1,292 @@
"""
Configuration loading module.

Parses the LLM configuration file and environment variables and exposes
a unified configuration interface for SubAgent.
"""

import os
import yaml
from typing import Dict, Any, Optional
from pathlib import Path


def load_llm_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """
    Load the LLM configuration file and environment variables.

    Args:
        config_path: configuration file path; defaults to src/config/llm_config.yaml

    Returns:
        The full LLM configuration dict, covering all providers with environment variables resolved.

    Raises:
        FileNotFoundError: the configuration file does not exist
        yaml.YAMLError: YAML parsing error
        ValueError: malformed configuration
    """

    # Determine the configuration file path
    if config_path is None:
        # Locate the project root
        current_dir = Path(__file__).parent
        project_root = current_dir.parent.parent
        config_path = project_root / "src" / "config" / "llm_config.yaml"

    config_path = Path(config_path)

    # Make sure the configuration file exists
    if not config_path.exists():
        raise FileNotFoundError(f"LLM configuration file not found: {config_path}")

    # Read the YAML configuration file
    try:
        with open(config_path, 'r', encoding='utf-8') as file:
            config = yaml.safe_load(file)
    except yaml.YAMLError as e:
        raise yaml.YAMLError(f"Failed to parse the YAML configuration file: {e}")
    except Exception as e:
        raise ValueError(f"Failed to read the configuration file: {e}")

    # Validate the structure
    if not isinstance(config, dict):
        raise ValueError("Malformed configuration file: the root element must be a mapping")

    # Load the environment-variable configuration
    env_config = load_env_config()

    # Substitute the environment variables
    config = _resolve_env_variables(config, env_config)

    return config


def load_env_config(env_path: Optional[str] = None) -> Dict[str, str]:
    """
    Load the environment-variable configuration.

    Args:
        env_path: .env file path; defaults to src/config/.env

    Returns:
        A dict of environment variables.
    """

    # Determine the .env file path
    if env_path is None:
        current_dir = Path(__file__).parent
        project_root = current_dir.parent.parent
        env_path = project_root / "src" / "config" / ".env"

    env_config = {}

    # Try to load the .env file
    env_path = Path(env_path)
    if env_path.exists():
        try:
            with open(env_path, 'r', encoding='utf-8') as file:
                for line in file:
                    line = line.strip()
                    if line and not line.startswith('#') and '=' in line:
                        key, value = line.split('=', 1)
                        env_config[key.strip()] = value.strip()
        except Exception as e:
            print(f"Warning: failed to read the .env file: {e}")

    # Also pick up values from the process environment
    env_keys = ['DASHSCOPE_API_KEY', 'OPENAI_API_KEY', 'DEEPSEEK_API_KEY']
    for key in env_keys:
        if key in os.environ:
            env_config[key] = os.environ[key]

    return env_config


def _resolve_env_variables(config: Dict[str, Any], env_config: Dict[str, str]) -> Dict[str, Any]:
    """
    Resolve environment-variable placeholders in the configuration.

    Args:
        config: the raw configuration
        env_config: the environment-variable configuration

    Returns:
        The resolved configuration.
    """

    def resolve_value(value):
        if isinstance(value, str) and value.startswith('${') and value.endswith('}'):
            # Extract the environment-variable name
            env_var = value[2:-1]  # strip the ${ and }

            if env_var in env_config:
                return env_config[env_var]
            elif env_var in os.environ:
                return os.environ[env_var]
            else:
                print(f"Warning: environment variable {env_var} is undefined; keeping the placeholder")
                return value
        elif isinstance(value, dict):
            return {k: resolve_value(v) for k, v in value.items()}
        elif isinstance(value, list):
            return [resolve_value(item) for item in value]
        else:
            return value

    return resolve_value(config)


def get_provider_config(config: Dict[str, Any], provider: str) -> Dict[str, Any]:
    """
    Get the configuration of a specific provider.

    Args:
        config: the full LLM configuration
        provider: provider name (e.g. 'aliyun', 'deepseek', 'openai')

    Returns:
        The provider configuration dict.

    Raises:
        ValueError: the provider does not exist
    """

    if provider not in config:
        available_providers = list(config.keys())
        raise ValueError(f"Provider '{provider}' does not exist; available providers: {available_providers}")

    return config[provider]


def get_model_config(config: Dict[str, Any], provider: str, model_name: str) -> Dict[str, Any]:
    """
    Get the configuration of a specific model.

    Args:
        config: the full LLM configuration
        provider: provider name
        model_name: model name

    Returns:
        The model configuration dict.

    Raises:
        ValueError: the provider or model does not exist
    """

    provider_config = get_provider_config(config, provider)

    if 'models' not in provider_config:
        raise ValueError(f"Provider '{provider}' configuration is missing the models field")

    models = provider_config['models']

    if model_name not in models:
        available_models = list(models.keys())
        raise ValueError(f"Model '{model_name}' does not exist under provider '{provider}'; available models: {available_models}")

    model_config = models[model_name].copy()

    # Merge in the provider-level settings
    for key in ['base_url', 'api_key']:
        if key in provider_config:
            model_config[key] = provider_config[key]

    return model_config


def validate_config(config: Dict[str, Any]) -> bool:
    """
    Validate the integrity of the configuration file.

    Args:
        config: the LLM configuration dict

    Returns:
        Whether validation passed.
    """

    required_providers = ['aliyun', 'deepseek', 'openai']
    missing_providers = []

    for provider in required_providers:
        if provider not in config:
            missing_providers.append(provider)

    if missing_providers:
        print(f"Warning: missing provider configurations: {missing_providers}")

    # Validate each provider's configuration structure
    valid = True
    for provider_name, provider_config in config.items():
        if not isinstance(provider_config, dict):
            print(f"Error: provider '{provider_name}' configuration must be a mapping")
            valid = False
            continue

        if 'models' not in provider_config:
            print(f"Error: provider '{provider_name}' is missing the models field")
            valid = False
            continue

        models = provider_config['models']
        if not isinstance(models, dict) or not models:
            print(f"Error: provider '{provider_name}' models must be a non-empty mapping")
            valid = False
            continue

        # Validate each model configuration
        for model_name, model_config in models.items():
            if not isinstance(model_config, dict):
                print(f"Error: model '{model_name}' configuration must be a mapping")
                valid = False
                continue

            required_fields = ['class', 'params']
            for field in required_fields:
                if field not in model_config:
                    print(f"Error: model '{model_name}' is missing required field: {field}")
                    valid = False

    return valid


# Test helper
def test_config_loading():
    """Test the configuration loading functionality."""
    try:
        print("Testing configuration loading...")

        # Load the configuration
        config = load_llm_config()
        print(f"✅ Configuration loaded; found {len(config)} providers")

        # Validate the configuration
        is_valid = validate_config(config)
        print(f"✅ Configuration validation: {'passed' if is_valid else 'failed'}")

        # Test provider configuration lookup
        try:
            aliyun_config = get_provider_config(config, 'aliyun')
            print(f"✅ Aliyun configuration fetched; contains {len(aliyun_config.get('models', {}))} models")
        except ValueError as e:
            print(f"❌ Failed to fetch the Aliyun configuration: {e}")

        # Test model configuration lookup
        try:
            qwen_config = get_model_config(config, 'aliyun', 'qwen-max')
            print("✅ qwen-max model configuration fetched")
            print(f"   Model class: {qwen_config.get('class', 'N/A')}")
            print(f"   Model id: {qwen_config.get('params', {}).get('id', 'N/A')}")
        except ValueError as e:
            print(f"❌ Failed to fetch the qwen-max model configuration: {e}")

        return True

    except Exception as e:
        print(f"❌ Configuration loading test failed: {e}")
        return False


if __name__ == "__main__":
    test_config_loading()
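Reviewer note: a short sketch tying the loaders together; the provider and model names are the ones defined in `llm_config.yaml` later in this diff:

```python
from src.agent_system.config_loader import load_llm_config, get_model_config

config = load_llm_config()  # reads llm_config.yaml and resolves ${...} placeholders
qwen = get_model_config(config, 'aliyun', 'qwen-max')

# base_url and api_key are merged in from the provider level
print(qwen['class'], qwen['base_url'])
print(qwen['params']['id'], qwen['params']['max_tokens'])
```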
404 src/agent_system/json_processor.py Normal file
@ -0,0 +1,404 @@
"""
JSON processor module.

Provides robust JSON parsing and validation, with fault-tolerant parsing
strategies and Pydantic model integration.
"""

import json
import re
from typing import Any, Dict, List, Optional, Type, Union, get_origin, get_args
from pydantic import BaseModel, ValidationError


class JSONParseError(Exception):
    """JSON parsing error."""
    pass


class JSONProcessor:
    """
    JSON processor.

    Offers several JSON parsing strategies so that parsing succeeds even
    when the JSON output is imperfect.
    """

    def __init__(self):
        """Initialize the JSON processor."""
        pass

    def parse_json_response(
        self,
        response: str,
        response_model: Optional[Type[BaseModel]] = None
    ) -> Union[Dict, List, BaseModel]:
        """
        Parse a JSON response.

        Args:
            response: the response string
            response_model: optional Pydantic model class

        Returns:
            The parsed data object or model instance.

        Raises:
            JSONParseError: raised when parsing fails
        """

        if not response or not response.strip():
            raise JSONParseError("Empty response; nothing to parse")

        # Clean up the response string
        cleaned_response = self._clean_response(response)

        # Try the parsing strategies in turn
        parsed_data = self._try_multiple_parsing_strategies(cleaned_response)

        if parsed_data is None:
            raise JSONParseError(f"All parsing strategies failed; response content: {response[:200]}...")

        # If a response model was given, build a model instance
        if response_model is not None:
            return self._create_model_instance(parsed_data, response_model)

        return parsed_data

    def _clean_response(self, response: str) -> str:
        """
        Clean up the response string.

        Args:
            response: the raw response

        Returns:
            The cleaned response.
        """

        # Strip surrounding whitespace
        cleaned = response.strip()

        # Remove any markdown code-fence markers
        if cleaned.startswith('```json'):
            cleaned = cleaned[7:]
        elif cleaned.startswith('```'):
            cleaned = cleaned[3:]

        if cleaned.endswith('```'):
            cleaned = cleaned[:-3]

        # Remove an extra layer of quoting, if present
        if cleaned.startswith('"') and cleaned.endswith('"'):
            try:
                # Parse as a string to see whether this is quote-wrapped JSON
                unquoted = json.loads(cleaned)
                if isinstance(unquoted, str):
                    cleaned = unquoted
            except json.JSONDecodeError:
                pass

        return cleaned.strip()

    def _try_multiple_parsing_strategies(self, response: str) -> Optional[Union[Dict, List]]:
        """
        Try multiple parsing strategies.

        Args:
            response: the cleaned response

        Returns:
            The successfully parsed data, or None.
        """

        strategies = [
            self._strategy_direct_parse,
            self._strategy_extract_json_block,
            self._strategy_find_json_structure,
            self._strategy_regex_extract,
            self._strategy_fix_common_errors,
        ]

        for strategy in strategies:
            try:
                result = strategy(response)
                if result is not None:
                    return result
            except Exception:
                continue

        return None

    def _strategy_direct_parse(self, response: str) -> Optional[Union[Dict, List]]:
        """Strategy 1: parse the JSON directly."""
        try:
            return json.loads(response)
        except json.JSONDecodeError:
            return None

    def _strategy_extract_json_block(self, response: str) -> Optional[Union[Dict, List]]:
        """Strategy 2: extract a JSON code block."""
        # Look for JSON wrapped in a code fence
        code_block_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```'
        matches = re.findall(code_block_pattern, response, re.IGNORECASE)

        for match in matches:
            try:
                return json.loads(match.strip())
            except json.JSONDecodeError:
                continue

        return None

    def _strategy_find_json_structure(self, response: str) -> Optional[Union[Dict, List]]:
        """Strategy 3: locate a JSON structure."""
        # Find the first complete JSON object or array
        for start_char, end_char in [("{", "}"), ("[", "]")]:
            json_str = self._extract_complete_json_structure(response, start_char, end_char)
            if json_str:
                try:
                    return json.loads(json_str)
                except json.JSONDecodeError:
                    continue

        return None

    def _strategy_regex_extract(self, response: str) -> Optional[Union[Dict, List]]:
        """Strategy 4: regex extraction."""
        # Search for JSON-shaped patterns with regular expressions
        patterns = [
            r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}',  # nested objects
            r'\[[^\[\]]*(?:\[[^\[\]]*\][^\[\]]*)*\]',  # nested arrays
        ]

        for pattern in patterns:
            matches = re.findall(pattern, response)
            for match in matches:
                try:
                    return json.loads(match)
                except json.JSONDecodeError:
                    continue

        return None

    def _strategy_fix_common_errors(self, response: str) -> Optional[Union[Dict, List]]:
        """Strategy 5: repair common JSON mistakes."""
        # Fix frequent JSON formatting errors
        fixed_response = response

        # Replace single quotes with double quotes
        fixed_response = re.sub(r"'([^']*)':", r'"\1":', fixed_response)
        fixed_response = re.sub(r":\s*'([^']*)'", r': "\1"', fixed_response)

        # Remove trailing commas
        fixed_response = re.sub(r',(\s*[}\]])', r'\1', fixed_response)

        # Quote bare keys
        fixed_response = re.sub(r'(\w+):', r'"\1":', fixed_response)

        try:
            return json.loads(fixed_response)
        except json.JSONDecodeError:
            return None

    def _extract_complete_json_structure(
        self,
        text: str,
        start_char: str,
        end_char: str
    ) -> Optional[str]:
        """
        Extract a complete JSON structure.

        Args:
            text: the text content
            start_char: opening character ('{' or '[')
            end_char: closing character ('}' or ']')

        Returns:
            The extracted JSON string, or None.
        """

        start_idx = text.find(start_char)
        if start_idx == -1:
            return None

        # Match nested structures with a depth counter
        count = 0
        in_string = False
        escape_next = False

        for i, char in enumerate(text[start_idx:], start_idx):
            if escape_next:
                escape_next = False
                continue

            if char == '\\' and in_string:
                escape_next = True
                continue

            if char == '"' and not escape_next:
                in_string = not in_string
                continue

            if not in_string:
                if char == start_char:
                    count += 1
                elif char == end_char:
                    count -= 1

                    if count == 0:
                        return text[start_idx:i+1]

        # If no matching closer was found, return everything to the end of the string
        if count > 0:
            return text[start_idx:]

        return None

    def _create_model_instance(
        self,
        data: Union[Dict, List],
        response_model: Type[BaseModel]
    ) -> BaseModel:
        """
        Create a Pydantic model instance.

        Args:
            data: the parsed data
            response_model: the Pydantic model class

        Returns:
            The model instance.

        Raises:
            JSONParseError: model validation failed
        """

        try:
            # If the data is a list but the model expects an object, try wrapping it
            if isinstance(data, list) and not self._is_list_model(response_model):
                # Try using the list as the value of one of the fields
                field_names = list(response_model.model_fields.keys())
                if field_names:
                    # Use the first field name as the wrapper
                    data = {field_names[0]: data}

            # Create the model instance
            return response_model(**data)

        except ValidationError as e:
            # Detailed validation error message
            error_msg = f"Pydantic model validation failed: {e}"
            raise JSONParseError(error_msg)
        except TypeError as e:
            error_msg = f"Model instantiation failed: {e}"
            raise JSONParseError(error_msg)

    def _is_list_model(self, model: Type[BaseModel]) -> bool:
        """
        Check whether the model expects a list type.

        Args:
            model: the Pydantic model class

        Returns:
            Whether this is a list model.
        """

        # Inspect the model's field types
        for field_info in model.model_fields.values():
            annotation = field_info.annotation
            origin = get_origin(annotation)

            # If any field is a List type, treat this as a list model
            if origin is list or origin is List:
                return True

        return False

    def validate_json_schema(self, data: Dict, schema: Dict) -> bool:
        """
        Validate JSON data against a given schema.

        Args:
            data: the JSON data
            schema: the JSON schema

        Returns:
            Whether the data conforms to the schema.
        """

        try:
            import jsonschema
            jsonschema.validate(data, schema)
            return True
        except ImportError:
            # Skip validation when the jsonschema library is not installed
            return True
        except Exception:
            return False


# Convenience function
def parse_json_response(
    response: str,
    response_model: Optional[Type[BaseModel]] = None
) -> Union[Dict, List, BaseModel]:
    """
    Convenience function: parse a JSON response.

    Args:
        response: the response string
        response_model: optional Pydantic model class

    Returns:
        The parsed data or model instance.
    """
    processor = JSONProcessor()
    return processor.parse_json_response(response, response_model)


# Test helper
def test_json_processor():
    """Test the JSON processor."""
    print("Testing the JSON processor...")

    processor = JSONProcessor()

    # Test cases
    test_cases = [
        # Standard JSON
        ('{"name": "test", "value": 123}', None),

        # JSON wrapped in a code fence
        ('```json\n{"name": "test", "value": 123}\n```', None),

        # Incomplete JSON
        ('{"name": "test", "value": 123', None),

        # Single-quoted JSON
        ("{'name': 'test', 'value': 123}", None),

        # JSON surrounded by extra text
        ('Here is a JSON response: {"name": "test", "value": 123} parsing done', None),

        # JSON array
        ('[{"name": "test1"}, {"name": "test2"}]', None),
    ]

    success_count = 0

    for i, (test_input, model) in enumerate(test_cases, 1):
        try:
            result = processor.parse_json_response(test_input, model)
            print(f"✅ Test case {i}: parsed successfully - {type(result)}")
            success_count += 1
        except Exception as e:
            print(f"❌ Test case {i}: parsing failed - {e}")

    print(f"\n🎯 JSON processor tests finished: {success_count}/{len(test_cases)} passed")

    return success_count == len(test_cases)


if __name__ == "__main__":
    test_json_processor()
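Reviewer note: a usage sketch of the fallback chain with a Pydantic model; `Paper` is a throwaway model defined here for illustration. Direct parsing fails on the fenced, prose-wrapped output, and the code-block strategy recovers it:

```python
from pydantic import BaseModel
from src.agent_system.json_processor import parse_json_response

class Paper(BaseModel):
    title: str
    year: int

# Fence markers plus leading prose defeat json.loads, but not the fallbacks
raw = 'Model output:\n```json\n{"title": "MIMIC-IV study", "year": 2024}\n```'
paper = parse_json_response(raw, Paper)
print(paper.title, paper.year)  # -> MIMIC-IV study 2024
```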
312 src/agent_system/model_factory.py Normal file
@ -0,0 +1,312 @@
"""
Model factory module.

Creates Agno model instances from configuration, supporting multiple LLM providers.
"""

from typing import Dict, Any, Optional, Union
from agno.models.openai import OpenAILike, OpenAIChat

# Try to import the optional model backends
try:
    from agno.models.ollama import Ollama
    OLLAMA_AVAILABLE = True
except ImportError:
    OLLAMA_AVAILABLE = False
    Ollama = None

from .config_loader import load_llm_config, get_model_config


class ModelFactory:
    """
    Model factory.

    Creates the different kinds of Agno model instances according to the configuration.
    """

    # Supported model-class mapping
    MODEL_CLASSES = {
        "OpenAILike": OpenAILike,
        "OpenAIChat": OpenAIChat,
    }

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the model factory.

        Args:
            config: the LLM configuration; loaded automatically when None
        """
        self.config = config if config is not None else load_llm_config()

        # Register Ollama when it is available
        if OLLAMA_AVAILABLE:
            self.MODEL_CLASSES["Ollama"] = Ollama

    def create_model(
        self,
        provider: str,
        model_name: str,
        **override_params
    ) -> Union[OpenAILike, OpenAIChat]:
        """
        Create a model instance.

        Args:
            provider: provider name (e.g. 'aliyun', 'deepseek', 'openai')
            model_name: model name (e.g. 'qwen-max', 'deepseek-v3')
            **override_params: parameters that override the configuration

        Returns:
            The created model instance.

        Raises:
            ValueError: misconfiguration or unsupported model type
            ImportError: a required dependency is missing
        """

        # Fetch the model configuration
        model_config = get_model_config(self.config, provider, model_name)

        # Fetch the model class name
        model_class_name = model_config.get('class')
        if not model_class_name:
            raise ValueError(f"Model configuration is missing the 'class' field: {provider}.{model_name}")

        # Check that the model class is supported
        if model_class_name not in self.MODEL_CLASSES:
            available_classes = list(self.MODEL_CLASSES.keys())
            raise ValueError(f"Unsupported model class '{model_class_name}'; supported classes: {available_classes}")

        # Special handling for Ollama
        if model_class_name == "Ollama" and not OLLAMA_AVAILABLE:
            raise ImportError("Ollama models require the ollama package: pip install ollama")

        # Fetch the model class
        model_class = self.MODEL_CLASSES[model_class_name]

        # Prepare the initialization parameters
        init_params = self._prepare_init_params(model_config, model_class_name, override_params)

        # Create the model instance
        try:
            model = model_class(**init_params)
            return model
        except Exception as e:
            raise ValueError(f"Failed to create the model instance ({provider}.{model_name}): {e}")

    def _prepare_init_params(
        self,
        model_config: Dict[str, Any],
        model_class_name: str,
        override_params: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Prepare the model initialization parameters.

        Args:
            model_config: the model configuration
            model_class_name: the model class name
            override_params: override parameters

        Returns:
            The initialization-parameter dict.
        """

        # Base parameters come from the params field
        init_params = model_config.get('params', {}).copy()

        # Add the API-related parameters
        if 'api_key' in model_config:
            init_params['api_key'] = model_config['api_key']

        if 'base_url' in model_config:
            init_params['base_url'] = model_config['base_url']

        # Adjust the parameters by model type
        if model_class_name == "OpenAILike":
            # Special handling for OpenAILike models
            self._adjust_openai_like_params(init_params)
        elif model_class_name == "OpenAIChat":
            # Special handling for OpenAIChat models
            self._adjust_openai_chat_params(init_params)
        elif model_class_name == "Ollama":
            # Special handling for Ollama models
            self._adjust_ollama_params(init_params)

        # Apply the override parameters
        init_params.update(override_params)

        return init_params

    def _adjust_openai_like_params(self, params: Dict[str, Any]) -> None:
        """Adjust the parameters for OpenAILike models."""
        # OpenAILike takes an id parameter rather than a model parameter;
        # id is already present in params, so nothing needs to change
        pass

    def _adjust_openai_chat_params(self, params: Dict[str, Any]) -> None:
        """Adjust the parameters for OpenAIChat models."""
        # OpenAIChat usually needs id passed along as the model parameter
        if 'id' in params and 'model' not in params:
            params['model'] = params['id']

    def _adjust_ollama_params(self, params: Dict[str, Any]) -> None:
        """Adjust the parameters for Ollama models."""
        # Ollama models usually take neither api_key nor base_url
        params.pop('api_key', None)
        params.pop('base_url', None)

        # Use id as the model parameter
        if 'id' in params and 'model' not in params:
            params['model'] = params['id']

    def list_available_models(self) -> Dict[str, list]:
        """
        List all available models.

        Returns:
            The model lists grouped by provider.
        """
        available_models = {}

        for provider, provider_config in self.config.items():
            if 'models' in provider_config:
                models = list(provider_config['models'].keys())
                available_models[provider] = models

        return available_models

    def validate_model_exists(self, provider: str, model_name: str) -> bool:
        """
        Check whether a model exists.

        Args:
            provider: provider name
            model_name: model name

        Returns:
            Whether the model exists.
        """
        try:
            get_model_config(self.config, provider, model_name)
            return True
        except ValueError:
            return False


# Convenience functions
def create_agno_model(
    provider: str,
    model_name: str,
    config: Optional[Dict[str, Any]] = None,
    **override_params
) -> Union[OpenAILike, OpenAIChat]:
    """
    Convenience function: create an Agno model instance.

    Args:
        provider: provider name
        model_name: model name
        config: optional configuration dict
        **override_params: override parameters

    Returns:
        The created model instance.
    """
    factory = ModelFactory(config)
    return factory.create_model(provider, model_name, **override_params)


def list_available_models(config: Optional[Dict[str, Any]] = None) -> Dict[str, list]:
    """
    Convenience function: list the available models.

    Args:
        config: optional configuration dict

    Returns:
        The model lists grouped by provider.
    """
    factory = ModelFactory(config)
    return factory.list_available_models()


# Test helpers
def test_model_creation():
    """Test model creation."""
    print("Testing model creation...")

    try:
        # Create the model factory
        factory = ModelFactory()

        # List the available models
        available_models = factory.list_available_models()
        print("✅ Available models found:")
        for provider, models in available_models.items():
            print(f"   {provider}: {models}")

        # Test creating the Aliyun qwen-max model
        try:
            qwen_model = factory.create_model('aliyun', 'qwen-max')
            print(f"✅ qwen-max model created: {type(qwen_model)}")
        except Exception as e:
            print(f"❌ qwen-max model creation failed: {e}")

        # Test creating the DeepSeek model
        try:
            deepseek_model = factory.create_model('deepseek', 'deepseek-v3')
            print(f"✅ deepseek-v3 model created: {type(deepseek_model)}")
        except Exception as e:
            print(f"❌ deepseek-v3 model creation failed: {e}")

        # Test parameter overrides
        try:
            custom_model = factory.create_model(
                'aliyun',
                'qwen-plus',
                temperature=0.5,
                max_tokens=1000
            )
            print(f"✅ Parameter-override test passed: {type(custom_model)}")
        except Exception as e:
            print(f"❌ Parameter-override test failed: {e}")

        return True

    except Exception as e:
        print(f"❌ Model creation test failed: {e}")
        return False


def test_convenience_functions():
    """Test the convenience functions."""
    print("\nTesting the convenience functions...")

    try:
        # Test the convenience creation function
        model = create_agno_model('aliyun', 'qwen-turbo')
        print(f"✅ Convenience function created a model: {type(model)}")

        # Test the model-listing function
        models = list_available_models()
        total_models = sum(len(model_list) for model_list in models.values())
        print(f"✅ Convenience function listed the models; {total_models} in total")

        return True

    except Exception as e:
        print(f"❌ Convenience-function test failed: {e}")
        return False


if __name__ == "__main__":
    success1 = test_model_creation()
    success2 = test_convenience_functions()

    if success1 and success2:
        print("\n🎉 All model-factory tests passed!")
    else:
        print("\n💥 Some tests failed; check the configuration")
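Reviewer note: a sketch of the convenience API with a parameter override; per `llm_config.yaml` below, the resulting instance should be an `OpenAILike` model:

```python
from src.agent_system.model_factory import create_agno_model

# Build a qwen-plus model and override two sampling parameters
model = create_agno_model('aliyun', 'qwen-plus', temperature=0.2, max_tokens=512)
print(type(model).__name__)  # expected: OpenAILike
```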
434 src/agent_system/subagent.py Normal file
@ -0,0 +1,434 @@
"""
SubAgent core class.

An intelligent-agent implementation built on the Agno framework, with dynamic
prompt construction, JSON parsing, and model management.
"""

import logging
from typing import Any, Dict, List, Optional, Type, Union
from agno.agent import Agent, RunResponse
from pydantic import BaseModel

from .config_loader import load_llm_config
from .model_factory import ModelFactory
from .json_processor import JSONProcessor, JSONParseError

# Set up logging
logger = logging.getLogger(__name__)


class SubAgentError(Exception):
    """Base class for SubAgent-related errors."""
    pass


class ConfigurationError(SubAgentError):
    """Configuration-related error."""
    pass


class ModelError(SubAgentError):
    """Model-related error."""
    pass


class SubAgent:
    """
    SubAgent core class.

    An intelligent agent built on the Agno framework, supporting:
    - dynamic prompt templates
    - multiple LLM providers
    - structured JSON output
    - fault-tolerant parsing
    - flexible configuration management
    """

    def __init__(
        self,
        provider: str,
        model_name: str,
        instructions: Optional[List[str]] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
        prompt_template: Optional[str] = None,
        response_model: Optional[Type[BaseModel]] = None,
        config: Optional[Dict[str, Any]] = None,
        **agent_kwargs
    ):
        """
        Initialize the SubAgent.

        Args:
            provider: LLM provider name (e.g. 'aliyun', 'deepseek', 'openai')
            model_name: model name (e.g. 'qwen-max', 'deepseek-v3')
            instructions: list of instructions
            name: agent name
            description: agent description
            prompt_template: dynamic prompt template
            response_model: Pydantic response-model class
            config: custom LLM configuration
            **agent_kwargs: extra arguments forwarded to the Agno Agent

        Raises:
            ConfigurationError: misconfiguration
            ModelError: model creation failure
        """

        # Basic attributes
        self.provider = provider
        self.model_name = model_name
        self.name = name or f"{provider}_{model_name}_agent"
        self.description = description or f"Agent backed by the {provider} {model_name} model"
        self.instructions = instructions or []
        self.prompt_template = prompt_template
        self.response_model = response_model

        # Initialize the components
        try:
            # Load the configuration
            self.config = config if config is not None else load_llm_config()

            # Create the model factory and JSON processor
            self.model_factory = ModelFactory(self.config)
            self.json_processor = JSONProcessor()

            # Create the Agno model
            self.model = self._create_model()

            # Create the Agno agent
            self.agent = self._create_agent(**agent_kwargs)

            logger.info(f"SubAgent {self.name} initialized")

        except Exception as e:
            logger.error(f"SubAgent initialization failed: {e}")
            raise ConfigurationError(f"SubAgent initialization failed: {e}")

    def _create_model(self):
        """Create the Agno model instance."""
        try:
            return self.model_factory.create_model(self.provider, self.model_name)
        except Exception as e:
            raise ModelError(f"Model creation failed: {e}")

    def _create_agent(self, **agent_kwargs):
        """Create the Agno Agent instance."""
        try:
            agent_params = {
                "model": self.model,
                "name": self.name,
                "description": self.description,
                "instructions": self.instructions,
                "markdown": agent_kwargs.pop("markdown", True),
                "debug_mode": agent_kwargs.pop("debug_mode", False),
                **agent_kwargs
            }

            return Agent(**agent_params)

        except Exception as e:
            raise ConfigurationError(f"Agent creation failed: {e}")

    def build_prompt(self, template_vars: Optional[Dict[str, Any]] = None) -> str:
        """
        Build the dynamic prompt.

        Args:
            template_vars: template variables used to fill the placeholders

        Returns:
            The assembled prompt string.

        Raises:
            SubAgentError: prompt construction failed
        """

        if not self.prompt_template:
            raise SubAgentError("No prompt template set; cannot build a dynamic prompt")

        if not template_vars:
            template_vars = {}

        try:
            # Substitute with str.format()
            prompt = self.prompt_template.format(**template_vars)

            # When JSON output is expected, append the JSON format requirements
            if self.response_model:
                json_instruction = self._build_json_instruction()
                prompt = f"{prompt}\n\n{json_instruction}"

            return prompt

        except KeyError as e:
            raise SubAgentError(f"Missing template variable: {e}")
        except Exception as e:
            raise SubAgentError(f"Prompt construction failed: {e}")

    def _build_json_instruction(self) -> str:
        """Build the JSON format instruction."""
        json_instruction = """
Return the result strictly in the following JSON format, without any extra explanatory text:

"""

        if self.response_model:
            # Emit the JSON schema as a reference
            try:
                schema = self.response_model.model_json_schema()
                json_instruction += f"JSON Schema:\n{schema}\n\n"

                # Generate an example
                example = self._generate_model_example()
                if example:
                    json_instruction += f"Example format:\n{example}"

            except Exception:
                # Fall back to a generic instruction when schema generation fails
                json_instruction += "Return valid JSON data."

        return json_instruction

    def _generate_model_example(self) -> Optional[str]:
        """Generate an example JSON for the Pydantic model."""
        try:
            # Build an example instance
            field_values = {}
            for field_name, field_info in self.response_model.model_fields.items():
                # Produce an example value based on the field type
                field_type = field_info.annotation
                field_values[field_name] = self._generate_example_value(field_type)

            example_instance = self.response_model(**field_values)
            return example_instance.model_dump_json(indent=2)

        except Exception:
            return None

    def _generate_example_value(self, field_type: Type) -> Any:
        """Produce an example value for a field type."""
        if field_type == str:
            return "example text"
        elif field_type == int:
            return 0
        elif field_type == float:
            return 0.0
        elif field_type == bool:
            return True
        elif hasattr(field_type, '__origin__'):
            origin = field_type.__origin__
            if origin == list:
                return []
            elif origin == dict:
                return {}

        return "example value"

    def run(
        self,
        prompt: Optional[str] = None,
        template_vars: Optional[Dict[str, Any]] = None,
        **run_kwargs
    ) -> Any:
        """
        Run agent inference.

        Args:
            prompt: literal prompt text (when given, template_vars is ignored)
            template_vars: template variables (for dynamic prompt construction)
            **run_kwargs: extra arguments forwarded to Agent.run

        Returns:
            The parsed Pydantic model instance when response_model is set,
            otherwise the raw response content.

        Raises:
            SubAgentError: execution failed
            JSONParseError: JSON parsing failed
        """

        try:
            # Build the final prompt
            if prompt is not None:
                final_prompt = prompt
            else:
                final_prompt = self.build_prompt(template_vars)

            logger.debug(f"Agent {self.name} running inference, prompt length: {len(final_prompt)}")

            # Run the inference
            response: RunResponse = self.agent.run(final_prompt, **run_kwargs)

            # Extract the response content
            content = response.content

            if self.response_model:
                # Parse the JSON when a response model was specified
                return self._parse_structured_response(content)
            else:
                # Return the raw content
                return content

        except JSONParseError:
            # Re-raise JSON parsing errors unchanged
            raise
        except Exception as e:
            logger.error(f"Agent execution failed: {e}")
            raise SubAgentError(f"Agent execution failed: {e}")

    def _parse_structured_response(self, content: str) -> BaseModel:
        """Parse a structured response."""
        try:
            return self.json_processor.parse_json_response(content, self.response_model)
        except JSONParseError as e:
            logger.error(f"JSON parsing failed: {e}")
            logger.debug(f"Response content: {content}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get the model information.

        Returns:
            A dict with the model information.
        """
        return {
            "name": self.name,
            "provider": self.provider,
            "model_name": self.model_name,
            "model_type": type(self.model).__name__,
            "has_prompt_template": self.prompt_template is not None,
            "has_response_model": self.response_model is not None,
            "instructions_count": len(self.instructions),
        }

    def update_instructions(self, instructions: List[str]) -> None:
        """
        Update the instruction list.

        Args:
            instructions: the new instruction list
        """
        self.instructions = instructions
        # Note: this does not update the already-created Agent instance; recreate the Agent if needed
        logger.info(f"Agent {self.name} instructions updated")

    def update_prompt_template(self, template: str) -> None:
        """
        Update the prompt template.

        Args:
            template: the new prompt template
        """
        self.prompt_template = template
        logger.info(f"Agent {self.name} prompt template updated")

    def __str__(self) -> str:
        return f"SubAgent({self.name}, {self.provider}.{self.model_name})"

    def __repr__(self) -> str:
        return self.__str__()


# Convenience function
def create_json_agent(
    provider: str,
    model_name: str,
    name: str,
    prompt_template: str,
    response_model: Union[str, Type[BaseModel]],
    instructions: Optional[List[str]] = None,
    **kwargs
) -> SubAgent:
    """
    Convenience function: create a SubAgent with JSON output.

    Args:
        provider: provider name
        model_name: model name
        name: agent name
        prompt_template: prompt template
        response_model: response-model class or module-path string
        instructions: instruction list
        **kwargs: extra arguments

    Returns:
        The configured SubAgent instance.
    """

    # When response_model is a string, try to import it
    if isinstance(response_model, str):
        try:
            # Split into module path and class name
            if '.' in response_model:
                module_path, class_name = response_model.rsplit('.', 1)
            else:
                # Without a module path, assume the current module
                module_path = '__main__'
                class_name = response_model

            # Import dynamically
            import importlib
            module = importlib.import_module(module_path)
            response_model = getattr(module, class_name)

        except Exception as e:
            raise ValueError(f"Could not import the response model {response_model}: {e}")

    return SubAgent(
        provider=provider,
        model_name=model_name,
        name=name,
        prompt_template=prompt_template,
        response_model=response_model,
        instructions=instructions,
        **kwargs
    )


# Test helper
def test_subagent():
    """Test the SubAgent."""
    print("Testing SubAgent...")

    try:
        # Create a basic SubAgent
        agent = SubAgent(
            provider="aliyun",
            model_name="qwen-turbo",
            name="test_agent",
            instructions=["You are a test assistant", "Answer questions concisely"]
        )

        print(f"✅ SubAgent created: {agent}")
        print(f"   Model info: {agent.get_model_info()}")

        # Test a basic conversation
        try:
            response = agent.run("Give a short introduction to the Python language")
            print(f"✅ Basic conversation test passed, response length: {len(str(response))} characters")
        except Exception as e:
            print(f"❌ Basic conversation test failed: {e}")

        # Test dynamic prompt construction
        agent.update_prompt_template("Answer this question about {topic}: {question}")

        try:
            prompt = agent.build_prompt({
                "topic": "programming",
                "question": "What is a function?"
            })
            print(f"✅ Dynamic prompt built, length: {len(prompt)} characters")
        except Exception as e:
            print(f"❌ Dynamic prompt construction failed: {e}")

        return True

    except Exception as e:
        print(f"❌ SubAgent test failed: {e}")
        return False


if __name__ == "__main__":
    test_subagent()
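Reviewer note: putting the pieces together, a sketch of a JSON-returning agent built with `create_json_agent`; `KeywordResult`, the prompt, and the abstract text are illustrative, not part of the repo:

```python
from typing import List
from pydantic import BaseModel
from src.agent_system import create_json_agent

class KeywordResult(BaseModel):
    keywords: List[str]

agent = create_json_agent(
    provider="aliyun",
    model_name="qwen-turbo",
    name="keyword_agent",
    prompt_template="Extract the key terms from this abstract: {abstract}",
    response_model=KeywordResult,
    instructions=["Return 3-5 keywords"],
)

# run() appends the JSON-format instruction, calls the model,
# and parses the reply through JSONProcessor into KeywordResult
result = agent.run(template_vars={"abstract": "MIMIC-IV based sepsis prediction..."})
print(result.keywords)
```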
1 src/config/.env Normal file
@ -0,0 +1 @@
DASHSCOPE_API_KEY=sk-5c7f9dc33e0a43738d415a0432452b93
93 src/config/llm_config.yaml Normal file
@ -0,0 +1,93 @@
# LLM configuration file
# Defines the LLM model configurations available to all agents


# Aliyun Tongyi Qianwen configuration
aliyun:
  base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
  api_key: ${DASHSCOPE_API_KEY}  # read from the environment
  models:
    qwen-max:
      class: "OpenAILike"
      params:
        id: "qwen-max"
        temperature: 0.7
        max_tokens: 3000
    qwen-plus:
      class: "OpenAILike"
      params:
        id: "qwen-plus"
        temperature: 0.7
        max_tokens: 2000
    qwen-turbo:
      class: "OpenAILike"
      params:
        id: "qwen-turbo"
        temperature: 0.7
        max_tokens: 1500

# DeepSeek configuration
deepseek:
  base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
  api_key: ${DASHSCOPE_API_KEY}
  models:
    deepseek-v3:
      class: "OpenAILike"
      params:
        id: "deepseek-v3"
        temperature: 0.7
        max_tokens: 4000
    deepseek-r1:
      class: "OpenAILike"
      params:
        id: "deepseek-r1"
        temperature: 0.5
        max_tokens: 8000
        # Reasoning mode specific to DeepSeek R1
        reasoning_effort: "high"

# OpenAI configuration (fallback)
openai:
  base_url: "https://api.openai.com/v1"
  api_key: ${OPENAI_API_KEY}
  models:
    gpt-4:
      class: "OpenAIChat"
      params:
        id: "gpt-4"
        temperature: 0.7
        max_tokens: 4000
    gpt-3.5-turbo:
      class: "OpenAIChat"
      params:
        id: "gpt-3.5-turbo"
        temperature: 0.7
        max_tokens: 2000

# Local Ollama configuration (for development and testing)
ollama:
  host: "127.0.0.1"
  port: 11434
  models:
    qwen2.5:
      class: "Ollama"
      params:
        id: "qwen2.5:latest"
        temperature: 0.7
        max_tokens: 2000
    llama3:
      class: "Ollama"
      params:
        id: "llama3:latest"
        temperature: 0.7
        max_tokens: 2000

vllm:
  base_url: "http://100.82.33.121:11001/v1"
  models:
    gpt-oss:
      class: "OpenAIChat"
      params:
        id: "gpt-oss-20b"
        temperature: 0.7
        max_tokens: 2000
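Reviewer note: the `${DASHSCOPE_API_KEY}` placeholders above are resolved by `_resolve_env_variables` in `config_loader.py`; for the keys it scans, values in the process environment take precedence over `src/config/.env`. A quick check of that behaviour:

```python
import os
from src.agent_system.config_loader import load_llm_config

os.environ['DASHSCOPE_API_KEY'] = 'sk-test'  # or put it in src/config/.env
config = load_llm_config()
assert config['aliyun']['api_key'] == 'sk-test'  # placeholder was substituted
```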
422 src/crawler.py
@ -8,19 +8,20 @@ import requests
|
||||
import xml.etree.ElementTree as ET
import logging
import time
import re
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Optional
from pathlib import Path

-from src.utils.csv_utils import write_dict_to_csv
+from src.utils.csv_utils import write_dict_to_csv, read_csv_to_dict


class PaperCrawler:
    """Paper crawler - fetches MIMIC 4-related papers from ArXiv and MedRxiv"""

    def __init__(self, websites: List[str], parallel: int = 20,
-                 arxiv_max_results: int = 200, medrxiv_days_range: int = 730):
+                 arxiv_max_results: int = 2000, medrxiv_days_range: int = 1825):
        """Initialize crawler configuration

        Args:
@@ -34,12 +35,11 @@ class PaperCrawler:
        self.arxiv_max_results = arxiv_max_results  # maximum number of ArXiv results to fetch
        self.medrxiv_days_range = medrxiv_days_range  # MedRxiv crawl window (days)

-        # MIMIC keyword configuration
+        # Precise MIMIC-IV keyword configuration - only match papers that explicitly reference the MIMIC-IV dataset
        self.mimic_keywords = [
-            "MIMIC-IV", "MIMIC 4", "MIMIC IV",
-            "Medical Information Mart",
-            "intensive care", "ICU database",
-            "critical care database", "electronic health record"
+            "MIMIC-IV", "MIMIC 4", "MIMIC IV", "MIMIC-4",
+            "Medical Information Mart Intensive Care IV",
+            "MIMIC-IV dataset", "MIMIC-IV database", "MIMIC"
        ]

        # HTTP session configuration
@@ -104,8 +104,8 @@ class PaperCrawler:
        papers = []

        try:
-            # Build the keyword search query
-            keywords_query = " OR ".join([f'ti:"{kw}"' for kw in self.mimic_keywords[:3]])
+            # Build the precise MIMIC-IV keyword search query - use every keyword for both title and abstract
+            keywords_query = " OR ".join([f'ti:"{kw}"' for kw in self.mimic_keywords])
            abstract_query = " OR ".join([f'abs:"{kw}"' for kw in self.mimic_keywords])
            search_query = f"({keywords_query}) OR ({abstract_query})"

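To make the effect of this change concrete, here is a minimal standalone sketch (not part of the diff) of the query string the new code composes from the expanded keyword list:

```python
# Illustration only: compose the ArXiv search query the same way the new code does
mimic_keywords = [
    "MIMIC-IV", "MIMIC 4", "MIMIC IV", "MIMIC-4",
    "Medical Information Mart Intensive Care IV",
    "MIMIC-IV dataset", "MIMIC-IV database", "MIMIC",
]
keywords_query = " OR ".join([f'ti:"{kw}"' for kw in mimic_keywords])
abstract_query = " OR ".join([f'abs:"{kw}"' for kw in mimic_keywords])
search_query = f"({keywords_query}) OR ({abstract_query})"
print(search_query[:60], "...")
# (ti:"MIMIC-IV" OR ti:"MIMIC 4" OR ti:"MIMIC IV" OR ... ) OR (abs:"MIMIC-IV" OR ...)
```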
@@ -239,7 +239,7 @@ class PaperCrawler:
            return {
                'title': paper_data.get('title', '').strip(),
                'authors': ', '.join(paper_data.get('authors', [])),
-                'abstract': paper_data.get('summary', '').strip(),
+                'abstract': paper_data.get('summary', '').strip().replace('\n', ' ').replace('\r', ' '),
                'doi': paper_data.get('doi', ''),
                'published_date': paper_data.get('published', '').split('T')[0] if 'T' in paper_data.get('published', '') else paper_data.get('published', ''),
                'url': paper_data.get('link', ''),
@@ -250,7 +250,7 @@ class PaperCrawler:
            return {
                'title': paper_data.get('title', '').strip(),
                'authors': paper_data.get('authors', ''),
-                'abstract': paper_data.get('abstract', '').strip(),
+                'abstract': paper_data.get('abstract', '').strip().replace('\n', ' ').replace('\r', ' '),
                'doi': paper_data.get('doi', ''),
                'published_date': paper_data.get('date', ''),
                'url': f"https://doi.org/{paper_data.get('doi', '')}" if paper_data.get('doi') else '',
@@ -419,3 +419,403 @@ class PaperCrawler:
        except Exception as e:
            logging.error(f"Error while saving CSV file: {e}")
            raise

    def download_pdfs_from_csv(self, csv_file_path: str) -> Dict[str, int]:
        """Download the paper PDFs listed in a CSV file

        Args:
            csv_file_path (str): Path to the CSV file with paper metadata

        Returns:
            Dict[str, int]: Download statistics {'success': count, 'failed': count, 'total': count}

        Raises:
            FileNotFoundError: CSV file does not exist
            ValueError: CSV file has an invalid format
        """
        try:
            # Read the paper metadata from the CSV file
            papers_data = self._read_papers_csv(csv_file_path)
            if not papers_data:
                logging.warning("No paper data in the CSV file")
                return {'success': 0, 'failed': 0, 'total': 0}

            # Prepare the PDF storage directory
            pdf_dir = self._prepare_pdf_storage()

            # Initialize statistics
            total_papers = len(papers_data)
            success_count = 0
            failed_count = 0
            failed_papers = []

            logging.info(f"Starting concurrent download of {total_papers} paper PDFs, concurrency: {self.parallel}")

            # Download the PDFs with a concurrent executor
            with ThreadPoolExecutor(max_workers=self.parallel) as executor:
                # Submit all download tasks
                future_to_paper = {
                    executor.submit(self._download_single_pdf, paper_data, pdf_dir): paper_data
                    for paper_data in papers_data
                }

                # Handle completed tasks and show live progress
                completed_count = 0
                for future in as_completed(future_to_paper):
                    paper_data = future_to_paper[future]
                    title = paper_data.get('title', 'Unknown')[:50] + '...' if len(paper_data.get('title', '')) > 50 else paper_data.get('title', 'Unknown')

                    try:
                        success = future.result()
                        completed_count += 1

                        if success:
                            success_count += 1
                            status = "✓"
                        else:
                            failed_count += 1
                            failed_papers.append({
                                'title': paper_data.get('title', ''),
                                'source': paper_data.get('source', ''),
                                'url': paper_data.get('url', ''),
                                'doi': paper_data.get('doi', '')
                            })
                            status = "✗"

                        # Show progress
                        progress = (completed_count / total_papers) * 100
                        print(f"\r[{completed_count:3d}/{total_papers}] {progress:5.1f}% {status} {title}", end='', flush=True)

                    except Exception as e:
                        failed_count += 1
                        completed_count += 1
                        failed_papers.append({
                            'title': paper_data.get('title', ''),
                            'source': paper_data.get('source', ''),
                            'error': str(e)
                        })
                        progress = (completed_count / total_papers) * 100
                        print(f"\r[{completed_count:3d}/{total_papers}] {progress:5.1f}% ✗ {title} (Error: {str(e)[:30]})", end='', flush=True)

            print()  # newline after the progress bar

            # Log failure details
            if failed_papers:
                logging.warning(f"The following {len(failed_papers)} papers failed to download:")
                for paper in failed_papers:
                    logging.warning(f"  - {paper.get('title', 'Unknown')} [{paper.get('source', 'unknown')}]")
                    if 'error' in paper:
                        logging.warning(f"    Error: {paper['error']}")

            # Build the download report
            stats = {
                'success': success_count,
                'failed': failed_count,
                'total': total_papers
            }

            logging.info(f"PDF download finished! Success: {success_count}/{total_papers} ({success_count/total_papers*100:.1f}%)")
            if failed_count > 0:
                logging.warning(f"Failed: {failed_count}/{total_papers} ({failed_count/total_papers*100:.1f}%)")

            return stats

        except Exception as e:
            logging.error(f"Error while downloading PDF files: {e}")
            raise

    def _prepare_pdf_storage(self) -> Path:
        """Prepare the PDF storage directory

        Returns:
            Path: PDF storage directory path
        """
        pdf_dir = Path("dataset") / "pdfs"
        pdf_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"PDF storage directory ready: {pdf_dir}")
        return pdf_dir

    def _read_papers_csv(self, csv_file_path: str) -> List[Dict[str, str]]:
        """Read the papers CSV file

        Args:
            csv_file_path (str): CSV file path

        Returns:
            List[Dict[str, str]]: List of paper records

        Raises:
            FileNotFoundError: File does not exist
            ValueError: File has an invalid format
        """
        try:
            papers_data = read_csv_to_dict(csv_file_path)

            # Validate the required fields
            required_fields = ['title', 'url', 'source', 'doi']
            if papers_data:
                missing_fields = [field for field in required_fields
                                  if field not in papers_data[0]]
                if missing_fields:
                    raise ValueError(f"CSV file is missing required fields: {missing_fields}")

            logging.info(f"Successfully read CSV file, {len(papers_data)} papers in total")
            return papers_data

        except Exception as e:
            logging.error(f"Failed to read CSV file: {e}")
            raise

    def _get_pdf_url(self, paper_data: Dict[str, str]) -> Optional[str]:
        """Get the PDF download link for a paper

        Args:
            paper_data (Dict[str, str]): Paper record

        Returns:
            Optional[str]: PDF download link, or None if unavailable
        """
        try:
            source = paper_data.get('source', '')
            if not source:
                logging.warning("Paper record is missing the source field")
                return None
            source = source.lower()
            url = paper_data.get('url', '')
            doi = paper_data.get('doi', '')

            if source == 'arxiv':
                return self._get_arxiv_pdf_url(url)
            elif source == 'medrxiv':
                return self._get_medrxiv_pdf_url(doi, url)
            else:
                logging.warning(f"Unsupported source: {source}")
                return None

        except Exception as e:
            logging.error(f"Failed to get PDF link: {e}")
            return None

    def _get_arxiv_pdf_url(self, url: str) -> Optional[str]:
        """Get the PDF link for an ArXiv paper

        Args:
            url (str): ArXiv paper page URL

        Returns:
            Optional[str]: PDF download link
        """
        try:
            if not url:
                return None

            # Extract the paper ID from the URL
            # Format: http://arxiv.org/abs/2301.12345 -> 2301.12345
            if '/abs/' in url:
                paper_id = url.split('/abs/')[-1]
                pdf_url = f"http://arxiv.org/pdf/{paper_id}.pdf"
                logging.debug(f"ArXiv PDF link: {pdf_url}")
                return pdf_url
            else:
                logging.warning(f"Could not parse ArXiv URL: {url}")
                return None

        except Exception as e:
            logging.error(f"Failed to get ArXiv PDF link: {e}")
            return None

    def _get_medrxiv_pdf_url(self, doi: str, url: str) -> Optional[str]:
        """Get the PDF link for a MedRxiv paper - supports multiple URL format strategies

        Args:
            doi (str): Paper DOI
            url (str): DOI link (fallback)

        Returns:
            Optional[str]: PDF download link
        """
        try:
            if not doi:
                logging.warning("MedRxiv paper is missing a DOI")
                return None

            if not doi.startswith('10.1101/'):
                logging.warning(f"Unsupported MedRxiv DOI format: {doi}")
                return None

            # Extract the DOI suffix
            paper_part = doi.replace('10.1101/', '')

            # Strategy 1: try the short versioned format (highest priority)
            # Format: https://www.medrxiv.org/content/10.1101/yyyy.mm.dd.xxxxxxxvN.full.pdf
            for version in ['v1', 'v2', 'v3']:  # common version numbers
                simple_url = f"https://www.medrxiv.org/content/10.1101/{paper_part}{version}.full.pdf"
                logging.debug(f"Trying MedRxiv short format: {simple_url}")

                # A URL availability check could be added here; for now return the first attempt
                # The download function will verify the URL's validity in practice
                if version == 'v1':  # prefer the v1 version
                    return simple_url

            # Strategy 2: fall back to the early-access format (long path)
            # Format: https://www.medrxiv.org/content/medrxiv/early/yyyy/mm/dd/yyyy.mm.dd.xxxxxxx.full.pdf
            parts = paper_part.split('.')
            if len(parts) >= 4:
                year = parts[0]
                month = parts[1].zfill(2)  # ensure two digits
                day = parts[2].zfill(2)  # ensure two digits

                # Build the long-format PDF URL
                complex_url = f"https://www.medrxiv.org/content/medrxiv/early/{year}/{month}/{day}/{paper_part}.full.pdf"
                logging.debug(f"Falling back to MedRxiv long format: {complex_url}")
                return complex_url
            else:
                logging.warning(f"Could not parse the date in MedRxiv DOI: {doi}")
                return None

        except Exception as e:
            logging.error(f"Failed to get MedRxiv PDF link: {e}")
            return None

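As an illustration of the two URL strategies, here is what the method produces for a hypothetical DOI (the DOI value below is made up for the example):

```python
# Hypothetical example DOI, invented for illustration only
doi = "10.1101/2023.05.12.23289876"
paper_part = doi.replace("10.1101/", "")

# Strategy 1: short versioned format (returned first, with v1)
short_url = f"https://www.medrxiv.org/content/10.1101/{paper_part}v1.full.pdf"
# -> https://www.medrxiv.org/content/10.1101/2023.05.12.23289876v1.full.pdf

# Strategy 2: early-access fallback, built from the date embedded in the DOI
year, month, day = paper_part.split(".")[:3]
long_url = (f"https://www.medrxiv.org/content/medrxiv/early/"
            f"{year}/{month.zfill(2)}/{day.zfill(2)}/{paper_part}.full.pdf")
# -> https://www.medrxiv.org/content/medrxiv/early/2023/05/12/2023.05.12.23289876.full.pdf
```

Note that as written, the `for version in [...]` loop always returns on `v1`, so the fallback path only becomes reachable if that early return is later made conditional on an actual availability check.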
    def _download_single_pdf(self, paper_data: Dict[str, str], pdf_dir: Path) -> bool:
        """Download a single paper PDF

        Args:
            paper_data (Dict[str, str]): Paper record
            pdf_dir (Path): PDF storage directory

        Returns:
            bool: Whether the download succeeded
        """
        try:
            # Get the PDF download link
            pdf_url = self._get_pdf_url(paper_data)
            if not pdf_url:
                logging.warning(f"Could not get PDF link: {paper_data.get('title', 'Unknown')}")
                return False

            # Generate a safe filename
            filename = self._generate_safe_filename(paper_data)
            file_path = pdf_dir / filename

            # Skip the download if the file already exists and is valid
            if file_path.exists() and self._validate_pdf_file(file_path):
                logging.info(f"PDF file already exists and is valid, skipping download: {filename}")
                return True

            # Download the PDF file, retrying up to 3 times
            for attempt in range(3):
                try:
                    response = self._make_request_with_retry(pdf_url, max_retries=1)

                    if response.status_code == 200:
                        # Write the file
                        with open(file_path, 'wb') as f:
                            f.write(response.content)

                        # Validate PDF integrity
                        if self._validate_pdf_file(file_path):
                            logging.info(f"Successfully downloaded PDF: {filename}")
                            return True
                        else:
                            logging.warning(f"PDF file is corrupted, deleting and retrying: {filename}")
                            file_path.unlink(missing_ok=True)

                    else:
                        logging.warning(f"PDF download failed, status code {response.status_code}: {pdf_url}")

                except Exception as e:
                    logging.warning(f"PDF download attempt {attempt + 1} failed: {e}")

                # Wait before retrying
                if attempt < 2:
                    time.sleep(2 ** attempt)

            logging.error(f"PDF download ultimately failed: {paper_data.get('title', 'Unknown')}")
            return False

        except Exception as e:
            logging.error(f"Error while downloading PDF: {e}")
            return False

    def _validate_pdf_file(self, file_path: Path) -> bool:
        """Validate the integrity of a PDF file

        Args:
            file_path (Path): PDF file path

        Returns:
            bool: Whether the PDF file is valid
        """
        try:
            if not file_path.exists():
                return False

            # Check the file size
            if file_path.stat().st_size < 1024:  # at least 1KB
                logging.warning(f"PDF file is too small and may be invalid: {file_path.name}")
                return False

            # Check the PDF header and structure
            with open(file_path, 'rb') as f:
                # Read the file header
                header = f.read(8)
                if not header.startswith(b'%PDF-'):
                    logging.warning(f"File is not a valid PDF: {file_path.name}")
                    return False

                # Check the trailer (read the last 1KB)
                f.seek(-min(1024, file_path.stat().st_size), 2)
                trailer = f.read()
                if b'%%EOF' not in trailer and b'endobj' not in trailer:
                    logging.warning(f"PDF file may be incomplete: {file_path.name}")
                    return False

            logging.debug(f"PDF file validation passed: {file_path.name}")
            return True

        except Exception as e:
            logging.error(f"Error while validating PDF file: {e}")
            return False

    def _generate_safe_filename(self, paper_data: Dict[str, str]) -> str:
        """Generate a safe PDF filename

        Args:
            paper_data (Dict[str, str]): Paper record

        Returns:
            str: Safe filename
        """
        try:
            source = paper_data.get('source', 'unknown').lower()
            title = paper_data.get('title', 'untitled')
            url = paper_data.get('url', '')
            doi = paper_data.get('doi', '')

            # Extract the paper_id
            paper_id = 'unknown'
            if source == 'arxiv' and '/abs/' in url:
                paper_id = url.split('/abs/')[-1]
            elif source == 'medrxiv' and doi:
                paper_id = doi.split('/')[-1] if '/' in doi else doi

            # Clean the title, keeping the essential information
            safe_title = re.sub(r'[^\w\s-]', '', title)  # strip special characters
            safe_title = re.sub(r'\s+', '_', safe_title.strip())  # spaces to underscores
            safe_title = safe_title.lower()[:50]  # limit the length and lowercase

            # Build the filename: source_paperid_title.pdf
            filename = f"{source}_{paper_id}_{safe_title}.pdf"

            # Keep the filename length reasonable
            if len(filename) > 255:  # limit on most filesystems
                filename = f"{source}_{paper_id}.pdf"

            return filename

        except Exception as e:
            logging.error(f"Error while generating filename: {e}")
            # Fallback
            timestamp = int(time.time())
            return f"paper_{timestamp}.pdf"
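Putting the new crawler pieces together, a hedged usage sketch follows; the `websites` values and the CSV path are assumptions for illustration, not fixed by this diff:

```python
from src.crawler import PaperCrawler

# Sketch only: assumes the crawl has already written dataset/mimic.csv
crawler = PaperCrawler(websites=["arxiv", "medrxiv"], parallel=20)
stats = crawler.download_pdfs_from_csv("dataset/mimic.csv")
print(f"downloaded {stats['success']}/{stats['total']} PDFs, {stats['failed']} failed")
```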
1073 src/extractor.py Normal file
File diff suppressed because it is too large
741 src/parse.py Normal file
@@ -0,0 +1,741 @@
"""PDF解析模块
|
||||
|
||||
该模块提供PDFParser类,用于将PDF文件通过OCR API转换为Markdown格式。
|
||||
支持并发处理、进度显示、错误处理等功能。
|
||||
"""
|
||||
|
||||
import requests
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import zipfile
|
||||
import tempfile
|
||||
import re
|
||||
import json
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
|
||||
|
||||
class PDFParser:
    """PDF parser - converts PDF files to Markdown and filters them by task type

    Supported task types:
    - prediction: prediction tasks (PRED_)
    - classification: classification tasks (CLAS_)
    - time_series: time-series analysis (TIME_)
    - correlation: correlation analysis (CORR_)
    """

    def __init__(self, pdf_dir: str = "dataset/pdfs", parallel: int = 3,
                 markdown_dir: str = "dataset/markdowns"):
        """Initialize parser configuration

        Args:
            pdf_dir (str): PDF file directory, default dataset/pdfs
            parallel (int): Concurrency level, default 3 (kept low to avoid overloading the server)
            markdown_dir (str): Markdown output directory, default dataset/markdowns
        """
        self.pdf_dir = Path(pdf_dir)
        self.parallel = parallel
        self.markdown_dir = Path(markdown_dir)

        # OCR API configuration
        self.ocr_api_url = "http://100.106.4.14:7861/parse"

        # AI model API configuration (for the four-way task classification: prediction/classification/time_series/correlation)
        self.ai_api_url = "http://100.82.33.121:11001/v1/chat/completions"
        self.ai_model = "gpt-oss-20b"

        # MIMIC-IV keyword configuration (for content filtering)
        self.mimic_keywords = [
            "MIMIC-IV", "MIMIC 4", "MIMIC IV", "MIMIC-4",
            "Medical Information Mart Intensive Care IV",
            "MIMIC-IV dataset", "MIMIC-IV database"
        ]

        # Mapping from task type to folder-name prefix
        self.task_type_prefixes = {
            "prediction": "PRED_",
            "classification": "CLAS_",
            "time_series": "TIME_",
            "correlation": "CORR_",
            "none": None  # matches no type, folder is left unmarked
        }

        # HTTP session configuration (larger connection pool and timeouts)
        from requests.adapters import HTTPAdapter
        from urllib3.util.retry import Retry

        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'MedResearcher-PDFParser/1.0'
        })

        # Configure the connection-pool adapter (larger pool)
        adapter = HTTPAdapter(
            pool_connections=10,  # number of connection pools
            pool_maxsize=20,  # maximum connections per pool
            max_retries=0  # disable automatic retries; custom retry logic is used instead
        )
        self.session.mount('http://', adapter)
        self.session.mount('https://', adapter)

        # Configure logging
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(levelname)s - %(message)s')

    def _scan_pdf_files(self) -> List[Path]:
        """Scan the PDF directory and collect all PDF files

        Returns:
            List[Path]: List of PDF file paths

        Raises:
            FileNotFoundError: PDF directory does not exist
        """
        if not self.pdf_dir.exists():
            raise FileNotFoundError(f"PDF directory does not exist: {self.pdf_dir}")

        pdf_files = []
        for pdf_file in self.pdf_dir.glob("*.pdf"):
            if pdf_file.is_file():
                pdf_files.append(pdf_file)

        logging.info(f"Found {len(pdf_files)} PDF files to process")
        return pdf_files

    def _check_mimic_keywords(self, output_subdir: Path) -> bool:
        """Check whether the Markdown files contain MIMIC-IV keywords

        Args:
            output_subdir (Path): Output subdirectory containing Markdown files

        Returns:
            bool: Whether MIMIC-IV keywords were found
        """
        try:
            # Find all .md files
            md_files = list(output_subdir.glob("*.md"))
            if not md_files:
                logging.warning(f"No Markdown files found for the MIMIC keyword check: {output_subdir}")
                return False

            # Check the content of each Markdown file
            for md_file in md_files:
                try:
                    with open(md_file, 'r', encoding='utf-8') as f:
                        content = f.read().lower()  # lowercase for case-insensitive matching

                    # Check for any MIMIC-IV keyword
                    for keyword in self.mimic_keywords:
                        if keyword.lower() in content:
                            logging.info(f"Found MIMIC-IV keyword '{keyword}' in file {md_file.name}")
                            return True

                except Exception as e:
                    logging.error(f"Error while reading Markdown file: {md_file.name} - {e}")
                    continue

            logging.info(f"No MIMIC-IV keywords found: {output_subdir.name}")
            return False

        except Exception as e:
            logging.error(f"Error while checking MIMIC keywords: {output_subdir} - {e}")
            return False

    def _extract_introduction(self, output_subdir: Path) -> Optional[str]:
        """Extract the Introduction section from a Markdown file

        Args:
            output_subdir (Path): Output subdirectory containing Markdown files

        Returns:
            Optional[str]: The extracted Introduction content, or None on failure
        """
        try:
            # Find all .md files
            md_files = list(output_subdir.glob("*.md"))
            if not md_files:
                logging.warning(f"No Markdown files found for Introduction extraction: {output_subdir}")
                return None

            # Usually use the first md file
            md_file = md_files[0]

            try:
                with open(md_file, 'r', encoding='utf-8') as f:
                    content = f.read()

                # Use regular expressions to extract the Introduction section
                # Match the various possible Introduction heading formats
                patterns = [
                    r'(?i)#\s*Introduction\s*\n(.*?)(?=\n#|\n\n#|$)',
                    r'(?i)##\s*Introduction\s*\n(.*?)(?=\n##|\n\n##|$)',
                    r'(?i)###\s*Introduction\s*\n(.*?)(?=\n###|\n\n###|$)',
                    r'(?i)\*\*Introduction\*\*\s*\n(.*?)(?=\n\*\*|\n\n\*\*|$)',
                    r'(?i)Introduction\s*\n(.*?)(?=\n[A-Z][a-z]+\s*\n|$)'
                ]

                for pattern in patterns:
                    match = re.search(pattern, content, re.DOTALL)
                    if match:
                        introduction = match.group(1).strip()
                        if len(introduction) > 100:  # ensure enough content for analysis
                            logging.info(f"Successfully extracted the Introduction section ({len(introduction)} characters): {md_file.name}")
                            return introduction

                # Without an explicit Introduction heading, approximate it with the first few paragraphs
                paragraphs = content.split('\n\n')
                introduction_candidates = []
                for para in paragraphs[:5]:  # take the first 5 paragraphs
                    para = para.strip()
                    if len(para) > 50 and not para.startswith('#'):  # filter out headings and very short paragraphs
                        introduction_candidates.append(para)

                if introduction_candidates:
                    introduction = '\n\n'.join(introduction_candidates[:3])  # at most the first 3 paragraphs
                    if len(introduction) > 200:
                        logging.info(f"Extracted an approximate Introduction section ({len(introduction)} characters): {md_file.name}")
                        return introduction

                logging.warning(f"Could not extract meaningful Introduction content: {md_file.name}")
                return None

            except Exception as e:
                logging.error(f"Error while reading Markdown file: {md_file.name} - {e}")
                return None

        except Exception as e:
            logging.error(f"Error while extracting the Introduction: {output_subdir} - {e}")
            return None

    def _analyze_research_task(self, introduction: str) -> str:
        """Use the AI model to classify the paper's research task type

        Args:
            introduction (str): The paper's Introduction content

        Returns:
            str: Task type ('prediction', 'classification', 'time_series', 'correlation', 'none')
        """
        try:
            # Build the system prompt for the AI analysis
            system_prompt = """You are a medical research expert. Analyze the given Introduction section of a paper and decide which of the following task types the study belongs to:

1. prediction - prediction tasks: predicting future events, outcomes, or values (e.g. mortality prediction, length-of-stay prediction, disease-progression prediction)
2. classification - classification tasks: assigning patients or cases to categories (e.g. diagnostic classification, risk stratification, drug-response classification)
3. time_series - time-series analysis: analyzing medical data that changes over time (e.g. vital-sign trend analysis, disease-course analysis, longitudinal cohort studies)
4. correlation - correlation analysis: studying relationships or associations between variables (e.g. disease vs. demographics, drug vs. side effects, risk-factor identification)
5. none - none of the above

Answer in JSON with the task type and a confidence score:
{\"task_type\": \"prediction\", \"confidence\": 0.85}

task_type must be one of: prediction, classification, time_series, correlation, none.
confidence is a value between 0 and 1 indicating how confident the judgment is.
Return only the JSON, with no extra text."""

            user_prompt = f"Analyze the following paper Introduction and decide which task type it belongs to:\n\n{introduction[:2000]}"  # cap the length to limit tokens

            # Build the API request payload
            api_data = {
                "model": self.ai_model,
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                "max_tokens": 50,  # only the JSON response is needed
                "temperature": 0.1  # reduce randomness
            }

            # Call the AI API
            response = self.session.post(
                self.ai_api_url,
                json=api_data,
                headers={"Content-Type": "application/json"},
                timeout=30
            )

            if response.status_code == 200:
                result = response.json()
                ai_response = result['choices'][0]['message']['content'].strip()

                try:
                    # Parse the JSON response
                    parsed_response = json.loads(ai_response)
                    task_type = parsed_response.get('task_type', 'none').lower()
                    confidence = parsed_response.get('confidence', 0.0)

                    # Validate the task type
                    valid_types = ['prediction', 'classification', 'time_series', 'correlation', 'none']
                    if task_type not in valid_types:
                        logging.warning(f"AI returned an invalid task type: {task_type}, falling back to 'none'")
                        task_type = "none"
                        confidence = 0.0

                    # Only accept high-confidence results
                    if confidence < 0.7:
                        logging.info(f"AI analysis confidence too low ({confidence:.2f}), classifying as 'none'")
                        task_type = "none"

                    logging.info(f"AI analysis result: task type={task_type}, confidence={confidence:.2f}")
                    return task_type

                except json.JSONDecodeError as e:
                    logging.error(f"Failed to parse the AI JSON response: {ai_response} - error: {e}")
                    return "none"

            else:
                logging.error(f"AI API call failed, status code: {response.status_code}")
                return "none"

        except Exception as e:
            logging.error(f"Error during AI research-task analysis: {e}")
            return "none"

    def _mark_valid_folder(self, output_subdir: Path, task_type: str) -> bool:
        """Add the task-type prefix to a folder that passed the filters

        Args:
            output_subdir (Path): Output subdirectory to mark
            task_type (str): Task type ('prediction', 'classification', 'time_series', 'correlation')

        Returns:
            bool: Whether the marking succeeded
        """
        try:
            # Get the prefix for the task type
            prefix = self.task_type_prefixes.get(task_type)
            if not prefix:
                logging.info(f"Task type '{task_type}' does not require folder marking")
                return True  # no marking needed, still counts as success

            # Check whether the folder already carries this task-type prefix
            if output_subdir.name.startswith(prefix):
                logging.info(f"Folder already marked as a {task_type} task: {output_subdir.name}")
                return True

            # Check whether the folder already carries another task-type prefix
            for existing_type, existing_prefix in self.task_type_prefixes.items():
                if existing_prefix and output_subdir.name.startswith(existing_prefix):
                    logging.info(f"Folder already marked as a {existing_type} task, no re-marking needed: {output_subdir.name}")
                    return True

            # Build the new folder name
            new_folder_name = prefix + output_subdir.name
            new_folder_path = output_subdir.parent / new_folder_name

            # Rename the folder
            output_subdir.rename(new_folder_path)
            logging.info(f"Folder marked successfully: {output_subdir.name} -> {new_folder_name} (task type: {task_type})")

            return True

        except Exception as e:
            logging.error(f"Error while marking folder: {output_subdir} - {e}")
            return False

    def _prepare_output_dir(self) -> Path:
        """Prepare the Markdown output directory

        Returns:
            Path: Markdown output directory path
        """
        self.markdown_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f"Markdown output directory ready: {self.markdown_dir}")
        return self.markdown_dir

    def _call_ocr_api(self, pdf_file: Path) -> Optional[Dict]:
        """Call the OCR API to parse a PDF file

        Args:
            pdf_file (Path): PDF file path

        Returns:
            Optional[Dict]: API response data, or None on failure
        """
        try:
            with open(pdf_file, 'rb') as f:
                files = {
                    'file': (pdf_file.name, f, 'application/pdf')
                }

                response = self._make_request_with_retry(
                    self.ocr_api_url,
                    files=files,
                    timeout=1800  # long timeout (30 minutes) to match server-side processing time
                )

            if response.status_code == 200:
                response_data = response.json()
                if response_data.get('success', False):
                    logging.debug(f"OCR API call succeeded: {pdf_file.name}")
                    return response_data
                else:
                    logging.warning(f"OCR API processing failed: {pdf_file.name} - {response_data.get('message', 'Unknown error')}")
                    return None
            else:
                logging.error(f"OCR API request failed, status code: {response.status_code} - {pdf_file.name}")
                return None

        except Exception as e:
            logging.error(f"Error while calling OCR API: {pdf_file.name} - {e}")
            return None

    def _download_and_extract_zip(self, download_url: str, pdf_file: Path) -> bool:
        """Download the ZIP file from the API response and extract it into a subfolder

        Args:
            download_url (str): Full download URL
            pdf_file (Path): Original PDF file path (used to name the output folder)

        Returns:
            bool: Whether the download and extraction succeeded
        """
        try:
            # Download the ZIP file
            response = self._make_request_with_retry(download_url, timeout=60)

            if response.status_code != 200:
                logging.error(f"ZIP download failed, status code: {response.status_code} - {pdf_file.name}")
                return False

            # Create an output subfolder named after the PDF file
            output_subdir = self.markdown_dir / pdf_file.stem
            output_subdir.mkdir(parents=True, exist_ok=True)

            # Save the ZIP content to a temporary file
            with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as temp_zip:
                temp_zip.write(response.content)
                temp_zip_path = temp_zip.name

            try:
                # Extract the ZIP file into the output subfolder
                with zipfile.ZipFile(temp_zip_path, 'r') as zip_ref:
                    zip_ref.extractall(output_subdir)

                logging.debug(f"ZIP extracted successfully: {pdf_file.name} -> {output_subdir}")

                # Clean the extracted Markdown files
                if not self._clean_markdown_files(output_subdir):
                    logging.warning(f"Markdown cleaning failed, but extraction succeeded: {pdf_file.name}")

                return True

            finally:
                # Remove the temporary ZIP file
                os.unlink(temp_zip_path)

        except zipfile.BadZipFile as e:
            logging.error(f"Corrupted ZIP file: {pdf_file.name} - {e}")
            return False
        except Exception as e:
            logging.error(f"Error while downloading or extracting ZIP: {pdf_file.name} - {e}")
            return False

    def _clean_markdown_files(self, output_subdir: Path) -> bool:
        """Clean the Markdown files in the output directory, removing numeric line labels

        Args:
            output_subdir (Path): Output subdirectory containing Markdown files

        Returns:
            bool: Whether the cleaning succeeded
        """
        try:
            # Find all .md files
            md_files = list(output_subdir.glob("*.md"))
            if not md_files:
                logging.debug(f"No Markdown files found to clean: {output_subdir}")
                return True

            for md_file in md_files:
                try:
                    # Read the original file content
                    with open(md_file, 'r', encoding='utf-8') as f:
                        lines = f.readlines()

                    # Clean each line
                    cleaned_lines = []
                    for line in lines:
                        # Strip the trailing newline
                        line_content = line.rstrip('\n\r')

                        # Skip pure-number lines (e.g. "2", "30")
                        if re.match(r'^\d+$', line_content):
                            continue

                        # Skip number-plus-whitespace lines (e.g. "30 ")
                        if re.match(r'^\d+\s*$', line_content):
                            continue

                        # Strip leading number-plus-space patterns (e.g. "1 Title:" -> "Title:")
                        cleaned_line = re.sub(r'^\d+\s+', '', line_content)

                        # Keep the line if it is non-empty after cleaning
                        if cleaned_line.strip():
                            cleaned_lines.append(cleaned_line + '\n')
                        else:
                            # Keep blank lines to preserve the document structure
                            cleaned_lines.append('\n')

                    # Write the cleaned content back
                    with open(md_file, 'w', encoding='utf-8') as f:
                        f.writelines(cleaned_lines)

                    logging.debug(f"Markdown file cleaned: {md_file.name}")

                except Exception as e:
                    logging.error(f"Error while cleaning Markdown file: {md_file.name} - {e}")
                    return False

            logging.info(f"Successfully cleaned {len(md_files)} Markdown files: {output_subdir}")
            return True

        except Exception as e:
            logging.error(f"Error while cleaning Markdown files: {output_subdir} - {e}")
            return False

    def _process_single_pdf(self, pdf_file: Path) -> bool:
        """Run the full pipeline for a single PDF file

        Args:
            pdf_file (Path): PDF file path

        Returns:
            bool: Whether processing succeeded
        """
        try:
            # Check that the PDF file exists and is non-empty
            if not pdf_file.exists() or pdf_file.stat().st_size == 0:
                logging.warning(f"PDF file does not exist or is empty: {pdf_file}")
                return False

            # Check whether a matching output subfolder already exists
            output_subdir = self.markdown_dir / pdf_file.stem
            if output_subdir.exists() and any(output_subdir.iterdir()):
                logging.info(f"Output folder already exists and is non-empty, skipping: {pdf_file.stem}")
                return True

            # Call the OCR API
            api_response = self._call_ocr_api(pdf_file)
            if not api_response:
                return False

            # Get the download URL and build the full address
            download_url = api_response.get('download_url')
            if not download_url:
                logging.error(f"API response is missing the download URL: {pdf_file.name}")
                return False

            # Build the full download URL
            full_download_url = f"http://100.106.4.14:7861{download_url}"
            logging.debug(f"Full download URL: {full_download_url}")

            # Download and extract the ZIP file
            success = self._download_and_extract_zip(full_download_url, pdf_file)
            if not success:
                return False

            # Get the extracted folder path
            output_subdir = self.markdown_dir / pdf_file.stem

            # First filter: check MIMIC-IV keywords
            logging.info(f"Starting MIMIC-IV keyword filtering: {pdf_file.stem}")
            if not self._check_mimic_keywords(output_subdir):
                logging.info(f"Did not pass MIMIC-IV keyword filtering, skipping: {pdf_file.stem}")
                return True  # processed successfully but did not pass the filter

            # Second filter: AI research-task analysis
            logging.info(f"Starting AI research-task analysis: {pdf_file.stem}")
            introduction = self._extract_introduction(output_subdir)
            if not introduction:
                logging.warning(f"Could not extract an Introduction, skipping AI analysis: {pdf_file.stem}")
                return True  # processed successfully but task analysis was impossible

            task_type = self._analyze_research_task(introduction)
            if task_type == "none":
                logging.info(f"Did not pass research-task filtering (task_type=none), skipping: {pdf_file.stem}")
                return True  # processed successfully but did not pass the filter

            # Both filters passed; mark the folder by task type
            logging.info(f"Passed all filters, marking as a {task_type} task paper: {pdf_file.stem}")
            if self._mark_valid_folder(output_subdir, task_type):
                logging.info(f"Paper filtering finished, marked as a {task_type} task: {pdf_file.stem}")
            else:
                logging.warning(f"Folder marking failed: {pdf_file.stem}")

            return True

        except Exception as e:
            logging.error(f"Error while processing PDF file: {pdf_file.name} - {e}")
            return False

    def _make_request_with_retry(self, url: str, files: Optional[Dict] = None,
                                 max_retries: int = 5, timeout: int = 180) -> requests.Response:
        """HTTP request with an adaptive retry strategy

        Args:
            url (str): Request URL
            files (Optional[Dict]): File payload (for POST requests)
            max_retries (int): Maximum retries, raised to 5
            timeout (int): Request timeout in seconds

        Returns:
            requests.Response: HTTP response

        Raises:
            requests.RequestException: Raised when all retries fail
        """
        for attempt in range(max_retries):
            try:
                if files:
                    response = self.session.post(url, files=files, timeout=timeout)
                else:
                    response = self.session.get(url, timeout=timeout)

                # Check the response status; retry on 500 errors
                if response.status_code == 500:
                    if attempt == max_retries - 1:
                        logging.error(f"Server internal error, maximum retries reached: HTTP {response.status_code}")
                        return response  # return the error response instead of raising

                    # Use longer waits for 500 errors
                    wait_time = min(30, 10 + (attempt * 5))  # 10s, 15s, 20s, 25s, 30s
                    logging.warning(f"Server internal error, retrying in {wait_time}s (attempt {attempt + 1})")
                    time.sleep(wait_time)
                    continue

                return response

            except requests.exceptions.Timeout as e:
                if attempt == max_retries - 1:
                    logging.error(f"Request timed out, maximum retries reached: {e}")
                    raise

                # Use shorter waits for timeouts
                wait_time = min(15, 5 + (attempt * 2))  # 5s, 7s, 9s, 11s, 13s
                logging.warning(f"Request timed out, retrying in {wait_time}s (attempt {attempt + 1}): {e}")
                time.sleep(wait_time)

            except requests.exceptions.ConnectionError as e:
                if attempt == max_retries - 1:
                    logging.error(f"Connection error, maximum retries reached: {e}")
                    raise

                # Use exponential backoff for connection errors
                wait_time = min(60, 5 * (2 ** attempt))  # 5s, 10s, 20s, 40s, 60s
                logging.warning(f"Connection error, retrying in {wait_time}s (attempt {attempt + 1}): {e}")
                time.sleep(wait_time)

            except requests.RequestException as e:
                if attempt == max_retries - 1:
                    logging.error(f"Request failed, maximum retries reached: {e}")
                    raise

                # Use standard exponential backoff for other errors
                wait_time = min(30, 3 * (2 ** attempt))  # 3s, 6s, 12s, 24s, 30s
                logging.warning(f"Request failed, retrying in {wait_time}s (attempt {attempt + 1}): {e}")
                time.sleep(wait_time)

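The four backoff formulas above give each error class its own wait schedule; a quick standalone check of the sequences (illustration only, not part of the diff):

```python
# Print the wait schedule each error class sees across the 5 attempts
for label, fn in [
    ("HTTP 500  ", lambda a: min(30, 10 + a * 5)),  # 10, 15, 20, 25, 30
    ("timeout   ", lambda a: min(15, 5 + a * 2)),   # 5, 7, 9, 11, 13
    ("conn error", lambda a: min(60, 5 * 2 ** a)),  # 5, 10, 20, 40, 60
    ("other     ", lambda a: min(30, 3 * 2 ** a)),  # 3, 6, 12, 24, 30
]:
    print(label, [fn(a) for a in range(5)])
```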
    def parse_all_pdfs(self) -> Dict[str, int]:
        """Process all PDF files in bulk, converting them to Markdown

        Returns:
            Dict[str, int]: Processing statistics {'success': count, 'failed': count, 'total': count}

        Raises:
            FileNotFoundError: PDF directory does not exist
        """
        try:
            # Scan for PDF files
            pdf_files = self._scan_pdf_files()
            if not pdf_files:
                logging.warning("No PDF files found")
                return {'success': 0, 'failed': 0, 'total': 0}

            # Prepare the output directory
            self._prepare_output_dir()

            # Initialize statistics
            total_files = len(pdf_files)
            success_count = 0
            failed_count = 0
            failed_files = []

            logging.info(f"Starting concurrent processing of {total_files} PDF files")
            logging.info(f"Concurrency: {self.parallel} (kept low to avoid overloading the server)")
            logging.info(f"Request timeout: 1800 seconds (matches server processing time)")
            logging.info(f"Retries: 5 (adaptive retry strategy)")

            # Process the PDFs with a concurrent executor
            with ThreadPoolExecutor(max_workers=self.parallel) as executor:
                # Submit all processing tasks
                future_to_pdf = {
                    executor.submit(self._process_single_pdf, pdf_file): pdf_file
                    for pdf_file in pdf_files
                }

                # Handle completed tasks and show live progress
                completed_count = 0
                for future in as_completed(future_to_pdf):
                    pdf_file = future_to_pdf[future]
                    filename = pdf_file.name[:50] + '...' if len(pdf_file.name) > 50 else pdf_file.name

                    try:
                        success = future.result()
                        completed_count += 1

                        if success:
                            success_count += 1
                            status = "✓"
                        else:
                            failed_count += 1
                            failed_files.append({
                                'filename': pdf_file.name,
                                'path': str(pdf_file)
                            })
                            status = "✗"

                        # Show progress
                        progress = (completed_count / total_files) * 100
                        print(f"\r[{completed_count:3d}/{total_files}] {progress:5.1f}% {status} {filename}", end='', flush=True)

                    except Exception as e:
                        failed_count += 1
                        completed_count += 1
                        failed_files.append({
                            'filename': pdf_file.name,
                            'path': str(pdf_file),
                            'error': str(e)
                        })
                        progress = (completed_count / total_files) * 100
                        print(f"\r[{completed_count:3d}/{total_files}] {progress:5.1f}% ✗ {filename} (Error: {str(e)[:30]})", end='', flush=True)

            print()  # newline after the progress bar

            # Log failure details
            if failed_files:
                logging.warning(f"The following {len(failed_files)} PDF files failed to process:")
                for file_info in failed_files:
                    logging.warning(f"  - {file_info['filename']}")
                    if 'error' in file_info:
                        logging.warning(f"    Error: {file_info['error']}")

            # Build the processing report
            stats = {
                'success': success_count,
                'failed': failed_count,
                'total': total_files
            }

            logging.info(f"PDF parsing finished! Success: {success_count}/{total_files} ({success_count/total_files*100:.1f}%)")
            if failed_count > 0:
                logging.warning(f"Failed: {failed_count}/{total_files} ({failed_count/total_files*100:.1f}%)")

            return stats

        except Exception as e:
            logging.error(f"Error during bulk PDF processing: {e}")
            raise
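Finally, a hedged end-to-end usage sketch for the parser; all argument values are the constructor defaults shown above:

```python
from src.parse import PDFParser

# Sketch only: parse every PDF under dataset/pdfs into dataset/markdowns,
# then rely on the two-stage filtering to prefix qualifying folders (PRED_/CLAS_/TIME_/CORR_).
parser = PDFParser(pdf_dir="dataset/pdfs", parallel=3, markdown_dir="dataset/markdowns")
stats = parser.parse_all_pdfs()
print(f"parsed {stats['success']}/{stats['total']} PDFs, {stats['failed']} failed")
```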