Compare commits

..

10 Commits

Author SHA1 Message Date
76c04eae4a feat: 优化信息提取系统并行处理和错误重试机制
- info_extractor.py: 添加文档并行处理线程数配置参数
- papers_crawler.py: 优化默认参数配置和数据文件路径
- src/crawler.py: 精确化MIMIC-IV关键词搜索和扩大爬取范围
- src/extractor.py: 实现并行文档处理、提取重试机制和内容预处理
- src/parse.py: 小幅优化解析逻辑

主要改进:
1. 支持多线程并行处理文档,提升提取效率
2. 增加API调用重试机制,提高稳定性
3. 优化论文内容预处理,去除无关信息
4. 完善进度跟踪和错误日志记录
2025-08-26 22:19:28 +08:00
d1f7a27b1b cleanup: 移除过时实验文件并更新依赖锁定文件
- 删除experiment_runner.py和number_extraction_models.py旧实验文件
- 更新uv.lock以反映新增的langextract和httpx[socks]依赖
2025-08-25 20:51:41 +08:00
c4037325ed feat: 实现基于LangExtract框架的MIMIC论文信息提取系统
- 新增info_extractor.py主文件,支持命令行参数和测试模式
- 实现src/extractor.py核心MIMICLangExtractBuilder类
- 集成vllm API服务(OpenAI兼容格式)进行结构化信息提取
- 支持5大模块提取:数据集、模型、训练、评估、环境配置
- 实现源文本定位和交互式HTML可视化
- 添加langextract和httpx[socks]依赖
- 支持个性化论文子目录结果保存
- 清理过时的experiment_runner.py和number_extraction_models.py文件
2025-08-25 20:51:30 +08:00
1b652502d5 docs: 新增SubAgent系统完整示例和说明文档
- 添加详细的SubAgent使用指南(README.md)
- 创建完整的Pydantic模型示例(example_models.py)
- 实现基础使用示例,展示核心功能(basic_example.py)
- 构建复杂文本分析应用示例(text_analysis_example.py)
- 提供数字提取实验运行器作为参考示例
- 包含多Agent协作、批量处理、性能监控等高级功能
- 支持交互式演示和完整的错误处理机制
2025-08-25 17:33:20 +08:00
f7a06775ca feat: 实现基于Agno框架的SubAgent系统
- 新增SubAgent核心类,支持多LLM提供商
- 实现动态prompt模板构建功能
- 添加JSON结构化输出和零容错解析
- 集成配置管理和模型工厂模式
- 提供完整的错误处理和日志系统
- 支持阿里云、DeepSeek、OpenAI等主流LLM服务
2025-08-25 17:33:11 +08:00
099159dfb7 feat: 新增PDF解析功能模块
- pdf_parser.py: PDF解析主程序,支持命令行参数和并发处理
- src/parse.py: PDF解析核心模块,提供PDFParser类
  * 支持OCR API调用,将PDF转换为Markdown格式
  * 内置HTTP会话管理、连接池优化和重试机制
  * 支持并发处理和详细进度显示
  * 完善的错误处理和日志记录功能
2025-08-24 15:07:42 +08:00
8d6d217c2f fix: 优化论文爬取功能
- papers_crawler.py: 优化CSV下载参数默认值为"yes",提升用户体验
- src/crawler.py:
  * 修复摘要字段换行符处理,确保数据清洁性
  * 增强MedRxiv PDF链接获取策略,支持多种URL格式和版本号
2025-08-24 15:07:34 +08:00
367696788b config: 更新开发环境配置
- .gitignore: 添加日志文件忽略规则(**/*.log)
- .vscode/launch.json: 为PDF解析器添加调试配置,支持不同参数测试
2025-08-24 15:07:26 +08:00
41e5fd1543 feat: 实现PDF下载功能
- 新增 download_pdfs_from_csv() 方法支持从CSV文件批量下载论文PDF
- 支持ArXiv和MedRxiv两种数据源的PDF链接解析和下载
- 实现并发下载控制、失败重试机制和PDF完整性验证
- 添加实时下载进度显示和详细的错误日志记录
- 更新命令行参数支持PDF下载测试功能
- 清理临时文件和更新.gitignore规则
2025-08-23 19:42:47 +08:00
802fe4b239 config: 更新.gitignore忽略macOS系统文件
- 添加.DS_Store到忽略列表
2025-08-23 16:33:51 +08:00
23 changed files with 7109 additions and 315 deletions

4
.gitignore vendored
View File

@ -10,4 +10,6 @@ wheels/
.venv
.claude
dataset/
docs/CLAUDE*
docs/CLAUDE*
.DS_Store
**/*.log

18
.vscode/launch.json vendored
View File

@ -31,6 +31,24 @@
"console": "integratedTerminal",
"python": "${workspaceFolder}/.venv/bin/python",
"args": ["--mimic_paper_path", "dataset/mimic.csv", "--parallel", "5"]
},
{
"name": "调试 pdf_parser.py",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/pdf_parser.py",
"console": "integratedTerminal",
"python": "${workspaceFolder}/.venv/bin/python",
"args": []
},
{
"name": "调试 pdf_parser.py (指定参数)",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/pdf_parser.py",
"console": "integratedTerminal",
"python": "${workspaceFolder}/.venv/bin/python",
"args": ["--pdf-dir", "dataset/pdfs", "--parallel", "3", "--markdown-dir", "dataset/markdowns"]
}
]
}

View File

@ -1,283 +0,0 @@
# AI指导规范构建任务 - 深度分析
## 任务理解
用户希望为MedResearcher项目构建一套完整的AI协作指导规范核心思想是"充分讨论后再修改",以避免代码难以维护。
## 核心设计理念
1. **讨论优先**:任何修改前必须充分讨论,达成共识
2. **上下文明确**:不仅指出需要什么,更要具体到哪个文件的哪个函数
3. **渐进式实施**:通过子任务分解,逐步完成复杂需求
4. **可追溯性**:所有决策和修改都有明确记录
## 详细工作流程设计
### 阶段1需求理解与记录必须执行
**触发条件**:用户提出任何代码修改需求
**具体步骤**
1. **立即**在CLAUDE-temp.md中撰写
```markdown
## 任务理解
原始需求:[用户的原话]
我的理解:[用我的话重述一遍]
## 收集的上下文
### 相关文件和函数
- `papers_crawler.py::line_45-67::fetch_papers()` - 当前的爬取实现
- `config.py::line_12-15::RETRY_CONFIG` - 现有重试配置
### 现有代码分析
[贴出关键代码片段并分析]
### 潜在影响
- 影响文件papers_crawler.py, test_crawler.py
- 影响功能:论文爬取的稳定性
- 风险评估:可能影响爬取速度
## 任务复杂度判断
- [ ] 单一功能修改影响1-2个函数 → 简单任务
- [x] 需要修改3个以上函数或添加新模块 → 复杂任务
```
2. **等待用户反馈**
- "理解正确" → 继续阶段2
- "理解有偏差" → 修正理解,重新记录
- "补充需求" → 更新CLAUDE-temp.md
### 阶段2任务规划根据复杂度差异化
**简单任务处理**
- **判断标准**
- 修改不超过2个函数
- 不需要新建文件
- 逻辑改动在50行以内
- **严格要求**:绝对不允许拆分子任务
- **直接输出**:一个完整的执行计划
**复杂任务处理**
- **判断标准**
- 修改3个以上函数
- 需要新建模块或文件
- 涉及多个功能模块交互
- **拆分原则**
- 必须拆分为3-5个子任务不能少于3个不能多于5个
- 每个子任务可独立验证
- 子任务之间有清晰的依赖关系
- **拆分示例**
```
子任务1创建重试机制基础设施
子任务2集成重试机制到爬虫
子任务3添加重试相关配置
子任务4更新错误处理逻辑
子任务5添加重试日志记录
```
### 阶段3计划确认与正式化用户确认后执行
**创建CLAUDE-plan.md严格按照以下格式**
#### 简单任务格式:
```markdown
## 任务:[具体任务名称]
创建时间2025-08-22 10:30
### 目标
[30-50字描述要实现的功能必须具体且可验证]
### 所需上下文
- `papers_crawler.py::45-67行::fetch_papers()` - 需要添加异常处理
- `papers_crawler.py::120-135行::parse_response()` - 需要理解返回格式
- `config.py::全文` - 了解现有配置结构
### 拟修改内容
1. 修改 `papers_crawler.py::fetch_papers()` 第50-55行 - 添加try-except块
2. 修改 `papers_crawler.py::fetch_papers()` 第65行 - 添加重试逻辑
3. 修改 `config.py` 末尾 - 添加RETRY_TIMES常量
### 测试指令
```bash
# 主功能测试
uv run papers_crawler.py --keyword "machine learning" --limit 5
# 异常情况测试(模拟网络错误)
uv run papers_crawler.py --test-mode --simulate-error
```
### 验收标准
- [ ] 正常爬取功能不受影响
- [ ] 网络异常时能正确重试
- [ ] 日志正确记录重试次数
```
#### 复杂任务格式:
```markdown
## 任务:[具体任务名称]
创建时间2025-08-22 10:30
### 总体目标
[50-100字描述整体要实现的功能]
### 子任务分解
#### 子任务1[名称]
**目标**[20-30字描述]
**所需上下文**
- `papers_crawler.py::45-67行::fetch_papers()` - 理解当前实现
- `utils/__init__.py::全文` - 确认工具模块结构
**拟修改内容**
- 新建 `utils/retry.py` - 创建RetryDecorator类
- 修改 `utils/__init__.py` - 导出retry装饰器
**测试指令**
```bash
# 单元测试
python -c "from utils.retry import retry; print('导入成功')"
```
#### 子任务2[名称]
[格式同上]
### 整体验收标准
- [ ] 所有子任务独立测试通过
- [ ] 集成测试:完整流程测试通过
- [ ] 性能测试:重试不影响正常爬取速度
```
### 阶段4实施与验证严格按计划执行
**执行要求**
1. **开始前确认**
- 再次阅读CLAUDE-plan.md
- 确认所有依赖文件存在
- 确认测试环境就绪
2. **执行中记录**
- 在CLAUDE-activeContext.md实时更新进度
- 遇到计划外情况立即停止并讨论
3. **完成后验证**
- 运行所有测试指令
- 检查验收标准
- 记录任何偏差或问题
## Memory Bank系统设计
### 文件结构
```
/docs/
├── CLAUDE-temp.md # 临时讨论和分析
├── CLAUDE-plan.md # 当前任务的正式计划
├── CLAUDE-activeContext.md # 会话状态和进度跟踪
├── CLAUDE-patterns.md # 项目代码模式记录
├── CLAUDE-decisions.md # 重要决策和理由记录
├── CLAUDE-troubleshooting.md # 问题和解决方案库
└── CLAUDE-config-variables.md # 配置变量参考
```
### 使用原则
1. **docs/CLAUDE-temp.md**
- 每次新任务开始时清空或归档
- 用于快速记录和思考
- 不需要结构化
2. **docs/CLAUDE-plan.md**
- 结构化的任务计划
- 用户确认后才写入
- 作为实施的指导文档
3. **docs/CLAUDE-activeContext.md**
- 记录当前进度
- 标记已完成/进行中/待完成
- 会话恢复时的参考
### Memory Bank更新机制
**使用专门的SubAgent管理**
```
Task: memory-bank-updater
Description: "更新Memory Bank文件"
Prompt: "任务已完成请更新以下Memory Bank文件
1. CLAUDE-activeContext.md - 标记任务完成,记录最终状态
2. CLAUDE-patterns.md - 如有新的代码模式,记录下来
3. CLAUDE-decisions.md - 记录本次任务的关键决策
4. CLAUDE-troubleshooting.md - 如遇到问题,记录解决方案
5. CLAUDE-config-variables.md - 如有新配置,更新文档
具体完成内容:[任务摘要]
遇到的问题:[如有]
采用的解决方案:[如有]"
```
**调用时机**
- 每个任务完成后必须调用
- 遇到重要决策时调用
- 发现新的最佳实践时调用
## 工具使用优化原则
### 1. 批量操作原则
**场景**:需要读取多个文件或执行多个独立搜索时
**做法**
```python
# 同时执行多个工具调用
parallel_calls = [
Read("papers_crawler.py"),
Read("pdf_parser.py"),
Grep("retry", "*.py"),
LS("./utils/")
]
```
**禁止**:顺序执行可并行的操作
### 2. 上下文管理策略
**主上下文保留**
- 用户对话
- 关键决策点
- 当前任务计划
**委托给subagent**
- 大规模代码搜索:"搜索所有使用requests库的地方"
- 代码模式分析:"分析项目中的错误处理模式"
- 依赖关系梳理:"找出papers_crawler.py的所有依赖"
**subagent使用示例**
```
Task: code-searcher
Prompt: "在整个项目中搜索所有异常处理相关的代码,
重点关注papers_crawler.py和pdf_parser.py
总结当前的错误处理模式和改进建议"
```
### 3. 文件操作最佳实践
**读取顺序**
1. 先读CLAUDE-activeContext.md如果存在了解当前状态
2. 读取主文件了解整体结构
3. 读取相关依赖文件
**修改原则**
- 优先使用Edit而非Write
- 使用MultiEdit处理同文件多处修改
- 新文件创建需明确理由
## 与现有编程规范的协同
### 层次关系
1. **编程规范**已在CLAUDE.md中定义
- 定义"怎么写代码"
- 包括:命名、注释、代码风格等
2. **AI指导规范**(本规范):
- 定义"怎么理解和修改代码"
- 包括:工作流程、沟通方式、工具使用等
### 执行优先级
1. 遵守编程规范的硬性要求(如单次修改限制)
2. 按AI指导流程进行任务
3. 发生冲突时,编程规范优先
## 规范更新机制
- 每次遇到新的最佳实践记录到CLAUDE-patterns.md
- 定期回顾CLAUDE-troubleshooting.md提炼通用规则
- 用户可随时提出规范优化建议

View File

@ -14,6 +14,7 @@ MedResearcher 是一个给予用户输入的自动实验平台,其会给予用
## 各个模块的主文件
1. 论文爬取主文件: papers_crawler.py
2. pdf解析主文件: pdf_parser.py
3. 信息抽取主文件: info_extractor.py
3. 实验运行主文件: experiment_runner.py
## 文件结构
@ -23,6 +24,7 @@ MedResearcher 是一个给予用户输入的自动实验平台,其会给予用
│ └── mimic.csv # 存放所有需要处理的与mimic相关论文的基础信息
├── papers_crawler.py # 论文爬取主文件
├── pdf_parser.py # pdf解析主文件
├── info_extractor.py # 信息抽取主文件
├── experiment_runner.py # 实验运行主文件
├── src/ # 源代码目录
│ └── utils/ # 工具函数目录

View File

@ -0,0 +1,428 @@
# SubAgent系统使用指南
## 📚 概述
SubAgent是基于Agno框架构建的智能代理系统为MedResearcher项目提供强大的AI功能。它支持多种LLM提供商提供动态prompt构建、JSON结构化输出和零容错解析等核心功能。
## ✨ 核心特性
### 🤖 智能代理核心
- **多提供商支持**: 阿里云(qwen)、DeepSeek、OpenAI等主流LLM服务
- **动态Prompt**: 支持模板变量替换的灵活prompt构建系统
- **结构化输出**: 基于Pydantic模型的JSON格式化响应
- **零容错解析**: 多策略JSON解析确保即使不完美输出也能解析
### 🔧 配置管理
- **YAML配置**: 统一的配置文件管理,支持环境变量
- **模型工厂**: 自动化的模型实例创建和参数管理
- **灵活配置**: 支持运行时参数覆盖和动态配置
### 🛠 开发便利性
- **类型安全**: 完整的类型提示支持
- **异常处理**: 详细的错误信息和异常层级
- **调试支持**: 内置日志和调试模式
## 🚀 快速开始
### 1. 基础设置
首先确保已安装依赖:
```bash
uv add agno pydantic pyyaml
```
### 2. 配置LLM服务
`src/config/llm_config.yaml`中配置你的LLM服务
```yaml
aliyun:
base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
api_key: "${DASHSCOPE_API_KEY}"
models:
qwen-max:
class: "OpenAILike"
params:
id: "qwen-max"
temperature: 0.3
```
### 3. 设置环境变量
创建`.env`文件或设置环境变量:
```bash
export DASHSCOPE_API_KEY="your_api_key_here"
```
### 4. 创建你的第一个Agent
```python
from src.agent_system import SubAgent
from pydantic import BaseModel, Field
# 定义响应模型
class TaskResult(BaseModel):
summary: str = Field(description="任务总结")
confidence: float = Field(description="置信度", ge=0.0, le=1.0)
# 创建SubAgent
agent = SubAgent(
provider="aliyun",
model_name="qwen-max",
name="task_agent",
instructions=["你是一个专业的任务处理专家"],
prompt_template="请分析以下任务: {task_description}",
response_model=TaskResult
)
# 执行任务
result = agent.run(template_vars={"task_description": "数据分析项目"})
print(f"总结: {result.summary}")
print(f"置信度: {result.confidence}")
```
## 📖 详细使用指南
### SubAgent核心类
#### 初始化参数
```python
SubAgent(
provider: str, # LLM提供商名称
model_name: str, # 模型名称
instructions: List[str], # 指令列表(可选)
name: str, # Agent名称(可选)
description: str, # Agent描述(可选)
prompt_template: str, # 动态prompt模板(可选)
response_model: BaseModel, # Pydantic响应模型(可选)
config: Dict[str, Any], # 自定义配置(可选)
**agent_kwargs # 传递给Agno Agent的额外参数
)
```
#### 核心方法
##### 1. build_prompt() - 构建动态Prompt
```python
# 设置带变量的prompt模板
agent.update_prompt_template("""
请分析以下{data_type}数据:
数据内容: {data_content}
分析目标: {analysis_goal}
请提供详细的分析结果。
""")
# 构建具体prompt
prompt = agent.build_prompt({
"data_type": "销售",
"data_content": "Q1销售数据...",
"analysis_goal": "找出增长趋势"
})
```
##### 2. run() - 执行推理
```python
# 方式1: 使用模板变量
result = agent.run(template_vars={
"input_text": "待分析的文本内容"
})
# 方式2: 直接提供prompt
result = agent.run(prompt="请分析这段文本的情感倾向")
# 方式3: 带额外参数
result = agent.run(
template_vars={"data": "测试数据"},
temperature=0.7,
max_tokens=1000
)
```
##### 3. get_model_info() - 获取模型信息
```python
info = agent.get_model_info()
print(f"Agent名称: {info['name']}")
print(f"提供商: {info['provider']}")
print(f"模型: {info['model_name']}")
print(f"是否有prompt模板: {info['has_prompt_template']}")
```
### Pydantic响应模型
#### 基础模型定义
```python
from pydantic import BaseModel, Field
from typing import List, Optional
class AnalysisResult(BaseModel):
"""分析结果模型"""
summary: str = Field(description="分析总结")
key_points: List[str] = Field(description="关键要点列表")
confidence: float = Field(description="置信度", ge=0.0, le=1.0)
recommendations: Optional[List[str]] = Field(default=None, description="建议列表")
class Config:
json_encoders = {
float: lambda v: round(v, 3) if v is not None else None
}
```
#### 复杂嵌套模型
```python
class DetailedItem(BaseModel):
name: str = Field(description="项目名称")
value: float = Field(description="数值")
category: str = Field(description="分类")
class ComprehensiveResult(BaseModel):
items: List[DetailedItem] = Field(description="详细项目列表")
total_count: int = Field(description="总数量", ge=0)
summary: str = Field(description="整体总结")
```
### 配置管理详解
#### LLM配置文件结构 (llm_config.yaml)
```yaml
# 阿里云配置
aliyun:
base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
api_key: "${DASHSCOPE_API_KEY}"
models:
qwen-max:
class: "OpenAILike"
params:
id: "qwen-max"
temperature: 0.3
max_tokens: 4000
qwen-plus:
class: "OpenAILike"
params:
id: "qwen-plus"
temperature: 0.5
# DeepSeek配置
deepseek:
base_url: "https://api.deepseek.com/v1"
api_key: "${DEEPSEEK_API_KEY}"
models:
deepseek-v3:
class: "OpenAILike"
params:
id: "deepseek-chat"
temperature: 0.3
# OpenAI配置
openai:
api_key: "${OPENAI_API_KEY}"
models:
gpt-4o:
class: "OpenAIChat"
params:
model: "gpt-4o"
temperature: 0.3
```
#### 环境变量配置 (.env)
```bash
# 阿里云API密钥
DASHSCOPE_API_KEY=sk-your-dashscope-key
# DeepSeek API密钥
DEEPSEEK_API_KEY=sk-your-deepseek-key
# OpenAI API密钥
OPENAI_API_KEY=sk-your-openai-key
```
### 便捷函数使用
#### create_json_agent() - 快速创建JSON Agent
```python
from src.agent_system import create_json_agent
# 快速创建支持JSON输出的Agent
agent = create_json_agent(
provider="aliyun",
model_name="qwen-max",
name="json_extractor",
prompt_template="从以下文本提取信息: {text}",
response_model="MyModel", # 可以是字符串或类
instructions=["你是信息提取专家"]
)
```
## 🎯 实际应用示例
### 示例1: 情感分析Agent
```python
from pydantic import BaseModel, Field
from typing import Literal
from src.agent_system import SubAgent
class SentimentResult(BaseModel):
sentiment: Literal["positive", "negative", "neutral"] = Field(description="情感倾向")
confidence: float = Field(description="置信度", ge=0.0, le=1.0)
explanation: str = Field(description="分析说明")
sentiment_agent = SubAgent(
provider="aliyun",
model_name="qwen-max",
name="sentiment_analyzer",
instructions=[
"你是专业的文本情感分析专家",
"请准确识别文本的情感倾向",
"提供详细的分析依据"
],
prompt_template="""
请分析以下文本的情感倾向:
文本内容: {text}
请识别情感倾向(positive/negative/neutral)、置信度(0-1)和分析说明。
""",
response_model=SentimentResult
)
# 使用示例
result = sentiment_agent.run(template_vars={
"text": "这个产品质量很好,我非常满意!"
})
print(f"情感: {result.sentiment}")
print(f"置信度: {result.confidence}")
print(f"说明: {result.explanation}")
```
### 示例2: 数据提取Agent
```python
class DataExtraction(BaseModel):
extracted_data: Dict[str, Any] = Field(description="提取的数据")
extraction_count: int = Field(description="提取项目数量")
data_quality: Literal["high", "medium", "low"] = Field(description="数据质量评估")
extractor_agent = SubAgent(
provider="aliyun",
model_name="qwen-plus",
name="data_extractor",
instructions=[
"你是数据提取专家",
"从非结构化文本中提取结构化数据",
"确保提取的数据准确完整"
],
prompt_template="""
从以下{data_type}文档中提取关键数据:
文档内容:
{document}
提取要求:
{requirements}
请提取所有相关数据并评估数据质量。
""",
response_model=DataExtraction
)
```
## ⚠️ 注意事项与最佳实践
### 1. 配置管理
- **API密钥安全**: 始终使用环境变量存储API密钥切勿在代码中硬编码
- **配置验证**: 程序启动时验证配置文件完整性
- **环境隔离**: 开发、测试、生产环境使用不同的配置文件
### 2. Prompt设计
- **明确指令**: 提供清晰、具体的任务指令
- **示例驱动**: 在prompt中包含输入输出示例
- **结构化模板**: 使用结构化的prompt模板提高一致性
### 3. 错误处理
- **异常捕获**: 对Agent调用进行适当的异常处理
- **重试机制**: 对网络错误实现重试逻辑
- **降级策略**: 准备备用模型或简化输出格式
### 4. 性能优化
- **缓存机制**: 对相同输入实现结果缓存
- **批处理**: 将多个小任务合并为大任务处理
- **模型选择**: 根据任务复杂度选择合适的模型
## 🔧 故障排除
### 常见问题
#### 1. 配置文件不存在
```
错误: FileNotFoundError: LLM配置文件不存在
解决: 确保 src/config/llm_config.yaml 文件存在且格式正确
```
#### 2. API密钥未设置
```
错误: 环境变量 DASHSCOPE_API_KEY 未定义
解决: 设置相应的环境变量或在.env文件中配置
```
#### 3. JSON解析失败
```
错误: JSONParseError: 所有解析策略都失败了
解决: 检查prompt设计确保要求明确的JSON格式输出
```
#### 4. 模型验证失败
```
错误: Pydantic模型验证失败
解决: 检查响应模型定义与实际输出是否匹配
```
### 调试技巧
#### 启用调试模式
```python
agent = SubAgent(
provider="aliyun",
model_name="qwen-max",
debug_mode=True, # 启用调试输出
# ... 其他参数
)
```
#### 查看生成的Prompt
```python
# 构建并查看最终的prompt
prompt = agent.build_prompt({"key": "value"})
print(f"生成的prompt: {prompt}")
```
#### 捕获详细错误信息
```python
try:
result = agent.run(template_vars={"text": "测试"})
except Exception as e:
print(f"错误类型: {type(e)}")
print(f"错误信息: {e}")
import traceback
traceback.print_exc()
```
## 🚦 版本信息
- **当前版本**: 0.1.0
- **依赖要求**:
- Python >= 3.8
- agno >= 0.1.0
- pydantic >= 2.0.0
- pyyaml >= 6.0.0
## 📞 支持与反馈
如遇到问题或有功能建议请联系开发团队或提交issue。
---
*MedResearcher SubAgent系统 - 让AI更智能让开发更简单* 🎉

View File

@ -0,0 +1,403 @@
#!/usr/bin/env python3
"""
SubAgent基础使用示例
展示SubAgent系统的基本功能:
1. 创建简单的对话Agent
2. 使用动态prompt模板
3. 结构化JSON输出
4. 错误处理和调试
"""
import sys
import os
from typing import Optional
# 添加项目根路径到Python路径
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
sys.path.append(project_root)
from src.agent_system import SubAgent
from example_models import BasicResponse, SentimentAnalysis
def create_simple_chat_agent() -> SubAgent:
"""创建简单的聊天Agent"""
print("🤖 创建简单聊天Agent...")
try:
agent = SubAgent(
provider="aliyun",
model_name="qwen-turbo",
name="simple_chat",
description="简单的聊天助手",
instructions=[
"你是一个友好的AI助手",
"请用简洁明了的语言回答问题",
"保持积极正面的态度"
]
)
print("✅ 简单聊天Agent创建成功")
return agent
except Exception as e:
print(f"❌ Agent创建失败: {e}")
raise
def create_structured_output_agent() -> SubAgent:
"""创建支持结构化输出的Agent"""
print("\n🔧 创建结构化输出Agent...")
try:
agent = SubAgent(
provider="aliyun",
model_name="qwen-max",
name="structured_responder",
description="提供结构化响应的智能助手",
instructions=[
"你是一个专业的响应助手",
"始终提供结构化的JSON格式输出",
"确保响应准确和有用"
],
response_model=BasicResponse
)
print("✅ 结构化输出Agent创建成功")
return agent
except Exception as e:
print(f"❌ 结构化输出Agent创建失败: {e}")
raise
def create_sentiment_agent() -> SubAgent:
"""创建情感分析Agent"""
print("\n💭 创建情感分析Agent...")
instructions = [
"你是专业的文本情感分析专家",
"准确识别文本的情感倾向positive积极、negative消极、neutral中性",
"提供0-1范围的置信度评分",
"给出详细的分析说明和关键词"
]
prompt_template = """
请对以下文本进行情感分析
文本内容"{text}"
分析要求
1. 识别情感倾向positive/negative/neutral
2. 评估分析置信度0-1之间的浮点数
3. 提供分析说明和依据
4. 提取影响情感判断的关键词
请严格按照指定的JSON格式返回结果
"""
try:
agent = SubAgent(
provider="aliyun",
model_name="qwen-max",
name="sentiment_analyzer",
description="专业的文本情感分析系统",
instructions=instructions,
prompt_template=prompt_template,
response_model=SentimentAnalysis
)
print("✅ 情感分析Agent创建成功")
return agent
except Exception as e:
print(f"❌ 情感分析Agent创建失败: {e}")
raise
def demo_simple_chat():
"""演示简单对话功能"""
print("\n" + "="*50)
print("🗣️ 简单对话演示")
print("="*50)
try:
agent = create_simple_chat_agent()
# 测试问题列表
test_questions = [
"你好!请介绍一下你自己",
"什么是人工智能?",
"请给我一些学习Python的建议"
]
for i, question in enumerate(test_questions, 1):
print(f"\n问题 {i}: {question}")
try:
response = agent.run(prompt=question)
print(f"回答: {response}")
except Exception as e:
print(f"❌ 回答失败: {e}")
return True
except Exception as e:
print(f"❌ 简单对话演示失败: {e}")
return False
def demo_structured_response():
"""演示结构化响应功能"""
print("\n" + "="*50)
print("📋 结构化响应演示")
print("="*50)
try:
agent = create_structured_output_agent()
# 测试请求列表
test_requests = [
"请解释什么是机器学习",
"介绍Python编程语言的特点",
"如何提高工作效率?"
]
for i, request in enumerate(test_requests, 1):
print(f"\n请求 {i}: {request}")
try:
result = agent.run(prompt=request)
print(f"✅ 响应成功:")
print(f" 消息: {result.message}")
print(f" 成功: {result.success}")
print(f" 时间: {result.timestamp}")
except Exception as e:
print(f"❌ 响应失败: {e}")
return True
except Exception as e:
print(f"❌ 结构化响应演示失败: {e}")
return False
def demo_dynamic_prompt():
"""演示动态prompt模板功能"""
print("\n" + "="*50)
print("🎭 动态Prompt模板演示")
print("="*50)
try:
agent = create_sentiment_agent()
# 测试文本列表
test_texts = [
"这个产品质量非常好,我很满意!强烈推荐给大家。",
"服务态度差,产品质量也不行,非常失望。",
"产品功能还可以,价格也合理,算是中规中矩的选择。",
"今天天气不错,适合出去散步。",
"这部电影真是太精彩了!演员演技出色,剧情引人入胜。"
]
for i, text in enumerate(test_texts, 1):
print(f"\n文本 {i}: {text}")
try:
# 使用动态模板构建prompt
result = agent.run(template_vars={"text": text})
print(f"✅ 分析结果:")
print(f" 情感: {result.sentiment}")
print(f" 置信度: {result.confidence}")
print(f" 说明: {result.explanation}")
print(f" 关键词: {result.keywords}")
except Exception as e:
print(f"❌ 分析失败: {e}")
return True
except Exception as e:
print(f"❌ 动态prompt演示失败: {e}")
return False
def demo_error_handling():
"""演示错误处理功能"""
print("\n" + "="*50)
print("⚠️ 错误处理演示")
print("="*50)
# 测试各种错误情况
error_tests = [
{
"name": "无效的提供商",
"params": {"provider": "invalid_provider", "model_name": "test"},
"expected_error": "ValueError"
},
{
"name": "空的prompt模板变量",
"params": {"provider": "aliyun", "model_name": "qwen-turbo"},
"template_vars": {},
"prompt_template": "分析这个文本: {missing_var}",
"expected_error": "SubAgentError"
}
]
for test in error_tests:
print(f"\n测试: {test['name']}")
try:
if 'prompt_template' in test:
agent = SubAgent(**test['params'])
agent.update_prompt_template(test['prompt_template'])
agent.run(template_vars=test.get('template_vars', {}))
else:
agent = SubAgent(**test['params'])
print(f"❌ 预期错误未发生")
except Exception as e:
print(f"✅ 捕获到预期错误: {type(e).__name__}: {e}")
return True
def interactive_demo():
"""交互式演示"""
print("\n" + "="*50)
print("💬 交互式演示")
print("="*50)
print("输入文本进行情感分析,输入'quit'退出")
try:
agent = create_sentiment_agent()
while True:
user_input = input("\n请输入要分析的文本: ").strip()
if user_input.lower() == 'quit':
print("再见!")
break
if not user_input:
continue
try:
print(f"正在分析: {user_input}")
result = agent.run(template_vars={"text": user_input})
print(f"\n📊 分析结果:")
print(f"情感倾向: {result.sentiment}")
print(f"置信度: {result.confidence:.3f}")
print(f"分析说明: {result.explanation}")
if result.keywords:
print(f"关键词: {', '.join(result.keywords)}")
except Exception as e:
print(f"❌ 分析失败: {e}")
except KeyboardInterrupt:
print("\n程序已中断")
except Exception as e:
print(f"❌ 交互式演示失败: {e}")
def show_agent_info(agent: SubAgent):
"""显示Agent信息"""
info = agent.get_model_info()
print(f"\n📋 Agent信息:")
for key, value in info.items():
print(f" {key}: {value}")
def main():
"""主函数 - 运行所有演示"""
print("🚀 SubAgent基础使用示例")
print("=" * 60)
# 运行所有演示
demos = [
("简单对话", demo_simple_chat),
("结构化响应", demo_structured_response),
("动态Prompt", demo_dynamic_prompt),
("错误处理", demo_error_handling),
]
results = {}
for name, demo_func in demos:
print(f"\n开始演示: {name}")
try:
success = demo_func()
results[name] = success
print(f"{'' if success else ''} {name}演示{'成功' if success else '失败'}")
except Exception as e:
print(f"{name}演示异常: {e}")
results[name] = False
# 显示总结
print("\n" + "="*60)
print("📊 演示总结")
print("="*60)
total_demos = len(results)
successful_demos = sum(results.values())
for name, success in results.items():
status = "✅ 成功" if success else "❌ 失败"
print(f" {name}: {status}")
print(f"\n🎯 总计: {successful_demos}/{total_demos} 个演示成功")
# 询问是否运行交互式演示
print(f"\n是否运行交互式演示?(y/n): ", end="")
try:
choice = input().strip().lower()
if choice in ['y', 'yes', '']:
interactive_demo()
except (KeyboardInterrupt, EOFError):
print("\n程序结束")
return successful_demos == total_demos
def test_basic_functionality():
"""测试基础功能"""
print("正在测试SubAgent基础功能...")
try:
# 创建基本Agent
agent = SubAgent(
provider="aliyun",
model_name="qwen-turbo",
name="test_agent"
)
print(f"✅ Agent创建成功: {agent}")
# 显示Agent信息
show_agent_info(agent)
# 测试简单对话
response = agent.run(prompt="请简单介绍一下你自己")
print(f"✅ 对话测试成功,响应长度: {len(str(response))}字符")
return True
except Exception as e:
print(f"❌ 基础功能测试失败: {e}")
return False
if __name__ == "__main__":
# 可以选择运行测试或完整演示
import sys
if len(sys.argv) > 1 and sys.argv[1] == "--test":
# 仅运行基础测试
success = test_basic_functionality()
exit(0 if success else 1)
else:
# 运行完整演示
main()

View File

@ -0,0 +1,377 @@
"""
SubAgent示例Pydantic模型定义
提供各种场景下的结构化输出模型展示SubAgent系统的灵活性
"""
from typing import List, Dict, Any, Optional, Literal, Union
from pydantic import BaseModel, Field, validator
from datetime import datetime
class BasicResponse(BaseModel):
"""基础响应模型"""
message: str = Field(description="响应消息")
success: bool = Field(description="处理是否成功", default=True)
timestamp: datetime = Field(default_factory=datetime.now, description="响应时间")
class SentimentAnalysis(BaseModel):
"""情感分析结果模型"""
sentiment: Literal["positive", "negative", "neutral"] = Field(description="情感倾向")
confidence: float = Field(description="置信度", ge=0.0, le=1.0)
explanation: str = Field(description="分析说明")
keywords: List[str] = Field(description="关键词列表", default_factory=list)
@validator('confidence')
def validate_confidence(cls, v):
"""验证置信度范围"""
if not 0.0 <= v <= 1.0:
raise ValueError('置信度必须在0.0到1.0之间')
return round(v, 3)
class KeywordExtraction(BaseModel):
"""关键词提取项"""
keyword: str = Field(description="关键词")
frequency: int = Field(description="出现频次", ge=1)
importance: float = Field(description="重要性评分", ge=0.0, le=1.0)
category: str = Field(description="关键词分类", default="general")
class Config:
json_encoders = {
float: lambda v: round(v, 3) if v is not None else None
}
class TextAnalysisResult(BaseModel):
"""文本分析完整结果"""
# 基本信息
text_length: int = Field(description="文本长度(字符数)", ge=0)
word_count: int = Field(description="词汇数量", ge=0)
language: str = Field(description="检测到的语言", default="zh")
# 分析结果
summary: str = Field(description="文本摘要")
sentiment: SentimentAnalysis = Field(description="情感分析结果")
keywords: List[KeywordExtraction] = Field(description="关键词提取结果")
# 质量评估
readability: Literal["high", "medium", "low"] = Field(description="可读性评估")
complexity: float = Field(description="复杂度评分", ge=0.0, le=1.0)
@validator('text_length')
def validate_text_length(cls, v):
"""验证文本长度"""
if v < 0:
raise ValueError('文本长度不能为负数')
return v
@validator('keywords')
def validate_keywords(cls, v):
"""验证关键词列表"""
if len(v) > 20:
# 只保留前20个最重要的关键词
v = sorted(v, key=lambda x: x.importance, reverse=True)[:20]
return v
class CategoryClassification(BaseModel):
"""分类结果项"""
category: str = Field(description="分类名称")
confidence: float = Field(description="分类置信度", ge=0.0, le=1.0)
probability: float = Field(description="分类概率", ge=0.0, le=1.0)
@validator('confidence', 'probability')
def round_float_values(cls, v):
"""保留3位小数"""
return round(v, 3)
class DocumentClassificationResult(BaseModel):
"""文档分类结果"""
primary_category: str = Field(description="主要分类")
confidence: float = Field(description="主分类置信度", ge=0.0, le=1.0)
all_categories: List[CategoryClassification] = Field(
description="所有分类结果(按置信度排序)"
)
features_used: List[str] = Field(
description="使用的特征列表",
default_factory=list
)
processing_time: Optional[float] = Field(
description="处理时间(秒)",
default=None
)
@validator('all_categories')
def sort_categories(cls, v):
"""按置信度降序排列"""
return sorted(v, key=lambda x: x.confidence, reverse=True)
class DataExtractionItem(BaseModel):
"""数据提取项"""
field_name: str = Field(description="字段名称")
field_value: Union[str, int, float, bool, None] = Field(description="字段值")
confidence: float = Field(description="提取置信度", ge=0.0, le=1.0)
source_text: str = Field(description="来源文本片段")
extraction_method: str = Field(description="提取方法", default="llm")
class StructuredDataExtraction(BaseModel):
"""结构化数据提取结果"""
extracted_data: Dict[str, Any] = Field(description="提取的结构化数据")
extraction_items: List[DataExtractionItem] = Field(description="详细提取项目")
# 质量评估
extraction_quality: Literal["excellent", "good", "fair", "poor"] = Field(
description="提取质量评估"
)
completeness: float = Field(
description="完整性评分",
ge=0.0,
le=1.0
)
accuracy: float = Field(
description="准确性评分",
ge=0.0,
le=1.0
)
# 统计信息
total_fields: int = Field(description="总字段数", ge=0)
extracted_fields: int = Field(description="成功提取字段数", ge=0)
failed_fields: int = Field(description="提取失败字段数", ge=0)
@validator('extracted_fields', 'failed_fields')
def validate_field_counts(cls, v, values):
"""验证字段计数"""
total = values.get('total_fields', 0)
if v > total:
raise ValueError('提取字段数不能超过总字段数')
return v
class TaskExecutionResult(BaseModel):
"""任务执行结果"""
# 任务信息
task_id: str = Field(description="任务ID")
task_type: str = Field(description="任务类型")
status: Literal["completed", "failed", "partial"] = Field(description="执行状态")
# 执行详情
result_data: Optional[Dict[str, Any]] = Field(description="结果数据", default=None)
error_message: Optional[str] = Field(description="错误信息", default=None)
warnings: List[str] = Field(description="警告信息", default_factory=list)
# 性能指标
execution_time: float = Field(description="执行时间(秒)", ge=0.0)
memory_usage: Optional[float] = Field(description="内存使用量MB", default=None)
# 质量评估
success_rate: float = Field(description="成功率", ge=0.0, le=1.0)
quality_score: float = Field(description="质量评分", ge=0.0, le=1.0)
@validator('success_rate', 'quality_score')
def round_scores(cls, v):
"""保留3位小数"""
return round(v, 3)
class ComprehensiveAnalysisResult(BaseModel):
"""综合分析结果(组合多个分析)"""
# 基本信息
analysis_id: str = Field(description="分析ID")
input_summary: str = Field(description="输入数据摘要")
analysis_timestamp: datetime = Field(default_factory=datetime.now)
# 各项分析结果
text_analysis: Optional[TextAnalysisResult] = Field(
description="文本分析结果",
default=None
)
classification: Optional[DocumentClassificationResult] = Field(
description="分类结果",
default=None
)
data_extraction: Optional[StructuredDataExtraction] = Field(
description="数据提取结果",
default=None
)
# 综合评估
overall_quality: Literal["excellent", "good", "fair", "poor"] = Field(
description="整体质量评估"
)
confidence_level: float = Field(
description="整体置信度",
ge=0.0,
le=1.0
)
# 处理统计
total_processing_time: float = Field(description="总处理时间(秒)", ge=0.0)
components_completed: int = Field(description="完成的组件数量", ge=0)
components_failed: int = Field(description="失败的组件数量", ge=0)
recommendations: List[str] = Field(
description="改进建议",
default_factory=list
)
@validator('confidence_level')
def validate_confidence(cls, v):
"""验证并格式化置信度"""
return round(v, 3)
# 测试模型定义
def test_models():
"""测试所有模型定义"""
print("正在测试示例模型定义...")
try:
# 测试基础响应模型
basic = BasicResponse(message="测试消息")
print(f"✅ BasicResponse模型测试成功: {basic.message}")
# 测试情感分析模型
sentiment = SentimentAnalysis(
sentiment="positive",
confidence=0.95,
explanation="积极的情感表达",
keywords=["", "满意", "推荐"]
)
print(f"✅ SentimentAnalysis模型测试成功: {sentiment.sentiment}")
# 测试关键词提取模型
keyword = KeywordExtraction(
keyword="人工智能",
frequency=5,
importance=0.8,
category="technology"
)
print(f"✅ KeywordExtraction模型测试成功: {keyword.keyword}")
# 测试文本分析结果模型
text_result = TextAnalysisResult(
text_length=150,
word_count=25,
language="zh",
summary="这是一个测试文本摘要",
sentiment=sentiment,
keywords=[keyword],
readability="high",
complexity=0.3
)
print(f"✅ TextAnalysisResult模型测试成功: {text_result.summary}")
# 测试分类结果模型
classification = DocumentClassificationResult(
primary_category="技术文档",
confidence=0.92,
all_categories=[
CategoryClassification(
category="技术文档",
confidence=0.92,
probability=0.87
)
]
)
print(f"✅ DocumentClassificationResult模型测试成功: {classification.primary_category}")
# 测试数据提取模型
extraction_item = DataExtractionItem(
field_name="标题",
field_value="SubAgent系统指南",
confidence=0.98,
source_text="# SubAgent系统指南",
extraction_method="pattern_matching"
)
data_extraction = StructuredDataExtraction(
extracted_data={"title": "SubAgent系统指南"},
extraction_items=[extraction_item],
extraction_quality="excellent",
completeness=1.0,
accuracy=0.95,
total_fields=1,
extracted_fields=1,
failed_fields=0
)
print(f"✅ StructuredDataExtraction模型测试成功: {data_extraction.extraction_quality}")
# 测试任务执行结果模型
task_result = TaskExecutionResult(
task_id="task_001",
task_type="text_analysis",
status="completed",
result_data={"status": "success"},
execution_time=2.5,
success_rate=1.0,
quality_score=0.95
)
print(f"✅ TaskExecutionResult模型测试成功: {task_result.status}")
# 测试综合分析结果模型
comprehensive = ComprehensiveAnalysisResult(
analysis_id="analysis_001",
input_summary="测试输入摘要",
text_analysis=text_result,
classification=classification,
data_extraction=data_extraction,
overall_quality="excellent",
confidence_level=0.93,
total_processing_time=5.2,
components_completed=3,
components_failed=0,
recommendations=["继续保持高质量输出"]
)
print(f"✅ ComprehensiveAnalysisResult模型测试成功: {comprehensive.overall_quality}")
# 测试JSON序列化
json_str = comprehensive.model_dump_json(indent=2)
print(f"✅ JSON序列化测试成功长度: {len(json_str)}字符")
# 列出所有模型字段
print("\n📋 模型字段信息:")
for model_name, model_class in [
("BasicResponse", BasicResponse),
("SentimentAnalysis", SentimentAnalysis),
("TextAnalysisResult", TextAnalysisResult),
("DocumentClassificationResult", DocumentClassificationResult),
("StructuredDataExtraction", StructuredDataExtraction),
("ComprehensiveAnalysisResult", ComprehensiveAnalysisResult),
]:
fields = list(model_class.model_fields.keys())
print(f" {model_name}: {len(fields)} 个字段 - {fields[:3]}{'...' if len(fields) > 3 else ''}")
return True
except Exception as e:
print(f"❌ 模型测试失败: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
success = test_models()
if success:
print("\n🎉 所有示例模型测试通过!")
else:
print("\n💥 模型测试失败,请检查定义")

View File

@ -0,0 +1,762 @@
#!/usr/bin/env python3
"""
文本分析综合示例
基于SubAgent系统的复杂应用示例展示
1. 多Agent协作系统
2. 复杂的数据处理pipeline
3. 结构化输出和错误恢复
4. 性能监控和质量评估
"""
import sys
import os
import time
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
import uuid
# 添加项目根路径到Python路径
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
sys.path.append(project_root)
from src.agent_system import SubAgent, create_json_agent
from example_models import (
TextAnalysisResult,
DocumentClassificationResult,
StructuredDataExtraction,
ComprehensiveAnalysisResult,
SentimentAnalysis,
KeywordExtraction,
DataExtractionItem,
CategoryClassification,
TaskExecutionResult
)
class TextAnalysisEngine:
"""文本分析引擎 - 多Agent协作系统"""
def __init__(self):
"""初始化文本分析引擎"""
self.agents = {}
self.processing_stats = {
"total_processed": 0,
"successful_analyses": 0,
"failed_analyses": 0,
"average_processing_time": 0.0
}
# 初始化所有Agent
self._initialize_agents()
def _initialize_agents(self):
"""初始化所有分析Agent"""
print("🔧 初始化文本分析引擎...")
try:
# 1. 情感分析Agent
self.agents['sentiment'] = self._create_sentiment_agent()
# 2. 关键词提取Agent
self.agents['keywords'] = self._create_keyword_agent()
# 3. 文本分类Agent
self.agents['classification'] = self._create_classification_agent()
# 4. 数据提取Agent
self.agents['extraction'] = self._create_extraction_agent()
# 5. 综合分析Agent
self.agents['comprehensive'] = self._create_comprehensive_agent()
print(f"✅ 成功初始化 {len(self.agents)} 个Agent")
except Exception as e:
print(f"❌ Agent初始化失败: {e}")
raise
def _create_sentiment_agent(self) -> SubAgent:
"""创建情感分析Agent"""
instructions = [
"你是专业的文本情感分析专家",
"准确识别文本的情感倾向和情感强度",
"提供详细的分析依据和相关关键词",
"对分析结果给出可信度评估"
]
prompt_template = """
请对以下文本进行深入的情感分析
文本内容
{text}
分析要求
1. 识别主要情感倾向positive/negative/neutral
2. 评估情感强度和置信度0-1
3. 提供分析说明和判断依据
4. 提取影响情感判断的关键词和短语
5. 考虑语言的细微差别和上下文含义
请提供准确专业的分析结果
"""
return SubAgent(
provider="aliyun",
model_name="qwen-max",
name="sentiment_analyzer",
description="专业的情感分析系统",
instructions=instructions,
prompt_template=prompt_template,
response_model=SentimentAnalysis
)
def _create_keyword_agent(self) -> SubAgent:
"""创建关键词提取Agent"""
instructions = [
"你是专业的关键词提取专家",
"从文本中识别最重要和最相关的关键词",
"评估关键词的重要性和频率",
"对关键词进行合理的分类"
]
prompt_template = """
请从以下文本中提取关键词
文本内容
{text}
提取要求
1. 识别最重要的关键词和短语
2. 统计关键词出现频率
3. 评估每个关键词的重要性0-1
4. 对关键词进行分类人物地点概念技术等
5. 排除停用词和无意义词汇
请提供结构化的关键词提取结果
"""
# 创建专门处理关键词列表的响应模型
from pydantic import BaseModel, Field
from typing import List
class KeywordExtractionResult(BaseModel):
keywords: List[KeywordExtraction] = Field(description="提取的关键词列表")
total_count: int = Field(description="关键词总数", ge=0)
text_complexity: float = Field(description="文本复杂度", ge=0.0, le=1.0)
return SubAgent(
provider="aliyun",
model_name="qwen-max",
name="keyword_extractor",
description="智能关键词提取系统",
instructions=instructions,
prompt_template=prompt_template,
response_model=KeywordExtractionResult
)
def _create_classification_agent(self) -> SubAgent:
"""创建文档分类Agent"""
instructions = [
"你是专业的文档分类专家",
"准确识别文档的类型和主题",
"提供多级分类和置信度评估",
"考虑文档的内容、风格和用途"
]
prompt_template = """
请对以下文档进行分类
文档内容
{text}
分类体系
主要分类技术文档商业文档学术论文新闻报道个人写作法律文档医学文档等
详细分类根据具体内容进一步细分
分类要求
1. 确定主要分类和置信度
2. 提供所有可能分类的概率分布
3. 识别用于分类判断的关键特征
4. 评估分类的可信度
请提供准确的分类结果
"""
return SubAgent(
provider="aliyun",
model_name="qwen-max",
name="document_classifier",
description="智能文档分类系统",
instructions=instructions,
prompt_template=prompt_template,
response_model=DocumentClassificationResult
)
def _create_extraction_agent(self) -> SubAgent:
"""创建数据提取Agent"""
instructions = [
"你是专业的结构化数据提取专家",
"从非结构化文本中提取有价值的信息",
"确保提取的数据准确性和完整性",
"评估提取质量和可靠性"
]
prompt_template = """
请从以下文本中提取结构化数据
文本内容
{text}
提取目标
根据文本内容自动识别可提取的数据类型可能包括
- 人名地名机构名
- 日期时间数量
- 联系方式地址
- 专业术语概念
- 关键指标统计数据
提取要求
1. 自动识别文本中的结构化信息
2. 为每个提取项提供置信度评估
3. 记录提取依据和来源文本片段
4. 评估整体提取质量和完整性
请提供详细的数据提取结果
"""
return SubAgent(
provider="aliyun",
model_name="qwen-max",
name="data_extractor",
description="智能数据提取系统",
instructions=instructions,
prompt_template=prompt_template,
response_model=StructuredDataExtraction
)
def _create_comprehensive_agent(self) -> SubAgent:
"""创建综合分析Agent"""
instructions = [
"你是文本综合分析专家",
"整合多种分析结果提供整体评估",
"识别分析中的一致性和矛盾之处",
"提供改进建议和深度见解"
]
prompt_template = """
基于以下多维度分析结果请提供综合评估
原始文本
{original_text}
分析结果
情感分析: {sentiment_result}
关键词提取: {keyword_result}
文档分类: {classification_result}
数据提取: {extraction_result}
综合评估要求
1. 评估各项分析结果的一致性
2. 识别潜在的分析矛盾或问题
3. 提供整体质量评估
4. 给出置信度评估
5. 提出改进建议
请提供专业的综合分析报告
"""
return SubAgent(
provider="aliyun",
model_name="qwen-max",
name="comprehensive_analyzer",
description="文本综合分析系统",
instructions=instructions,
prompt_template=prompt_template,
response_model=ComprehensiveAnalysisResult
)
def analyze_text(self, text: str, analysis_id: Optional[str] = None) -> ComprehensiveAnalysisResult:
"""执行完整的文本分析"""
if analysis_id is None:
analysis_id = f"analysis_{uuid.uuid4().hex[:8]}"
start_time = time.time()
self.processing_stats["total_processed"] += 1
print(f"\n🔍 开始分析 [{analysis_id}]")
print(f"文本长度: {len(text)} 字符")
try:
# 阶段1: 基础分析
print("📊 执行基础分析...")
sentiment_result = self._analyze_sentiment(text)
keyword_result = self._extract_keywords(text)
# 阶段2: 高级分析
print("🧠 执行高级分析...")
classification_result = self._classify_document(text)
extraction_result = self._extract_data(text)
# 阶段3: 综合分析
print("🎯 执行综合分析...")
comprehensive_result = self._comprehensive_analysis(
text, sentiment_result, keyword_result,
classification_result, extraction_result, analysis_id
)
# 更新统计信息
processing_time = time.time() - start_time
self.processing_stats["successful_analyses"] += 1
self._update_processing_stats(processing_time)
print(f"✅ 分析完成 [{analysis_id}] - 耗时: {processing_time:.2f}")
return comprehensive_result
except Exception as e:
self.processing_stats["failed_analyses"] += 1
print(f"❌ 分析失败 [{analysis_id}]: {e}")
raise
def _analyze_sentiment(self, text: str) -> SentimentAnalysis:
"""执行情感分析"""
try:
result = self.agents['sentiment'].run(template_vars={"text": text})
print(f" 情感: {result.sentiment} (置信度: {result.confidence:.3f})")
return result
except Exception as e:
print(f" ⚠️ 情感分析失败: {e}")
# 返回默认结果
return SentimentAnalysis(
sentiment="neutral",
confidence=0.0,
explanation=f"分析失败: {e}",
keywords=[]
)
def _extract_keywords(self, text: str):
"""提取关键词"""
try:
result = self.agents['keywords'].run(template_vars={"text": text})
print(f" 关键词: {result.total_count}")
return result
except Exception as e:
print(f" ⚠️ 关键词提取失败: {e}")
# 返回空结果
from pydantic import BaseModel, Field
from typing import List
class KeywordExtractionResult(BaseModel):
keywords: List[KeywordExtraction] = Field(default_factory=list)
total_count: int = Field(default=0)
text_complexity: float = Field(default=0.5)
return KeywordExtractionResult()
def _classify_document(self, text: str) -> DocumentClassificationResult:
"""执行文档分类"""
try:
result = self.agents['classification'].run(template_vars={"text": text})
print(f" 分类: {result.primary_category} (置信度: {result.confidence:.3f})")
return result
except Exception as e:
print(f" ⚠️ 文档分类失败: {e}")
# 返回默认结果
return DocumentClassificationResult(
primary_category="未知",
confidence=0.0,
all_categories=[
CategoryClassification(
category="未知",
confidence=0.0,
probability=0.0
)
]
)
def _extract_data(self, text: str) -> StructuredDataExtraction:
"""提取结构化数据"""
try:
result = self.agents['extraction'].run(template_vars={"text": text})
print(f" 数据提取: {result.extracted_fields}/{result.total_fields} 字段")
return result
except Exception as e:
print(f" ⚠️ 数据提取失败: {e}")
# 返回空结果
return StructuredDataExtraction(
extracted_data={},
extraction_items=[],
extraction_quality="poor",
completeness=0.0,
accuracy=0.0,
total_fields=0,
extracted_fields=0,
failed_fields=0
)
def _comprehensive_analysis(
self,
original_text: str,
sentiment_result: SentimentAnalysis,
keyword_result,
classification_result: DocumentClassificationResult,
extraction_result: StructuredDataExtraction,
analysis_id: str
) -> ComprehensiveAnalysisResult:
"""执行综合分析"""
try:
# 准备模板变量
template_vars = {
"original_text": original_text[:500] + ("..." if len(original_text) > 500 else ""),
"sentiment_result": f"情感:{sentiment_result.sentiment}, 置信度:{sentiment_result.confidence}",
"keyword_result": f"关键词数量:{getattr(keyword_result, 'total_count', 0)}",
"classification_result": f"分类:{classification_result.primary_category}, 置信度:{classification_result.confidence}",
"extraction_result": f"提取质量:{extraction_result.extraction_quality}"
}
result = self.agents['comprehensive'].run(template_vars=template_vars)
# 补充一些字段
result.analysis_id = analysis_id
result.input_summary = f"长度:{len(original_text)}字符, 类型:{classification_result.primary_category}"
result.text_analysis = self._build_text_analysis_result(
original_text, sentiment_result, keyword_result
)
result.classification = classification_result
result.data_extraction = extraction_result
print(f" 综合评估: {result.overall_quality}")
return result
except Exception as e:
print(f" ⚠️ 综合分析失败: {e}")
# 构建基本的综合结果
return ComprehensiveAnalysisResult(
analysis_id=analysis_id,
input_summary=f"长度:{len(original_text)}字符",
overall_quality="poor",
confidence_level=0.0,
total_processing_time=0.0,
components_completed=0,
components_failed=4,
recommendations=["分析失败,请检查输入文本和系统配置"]
)
def _build_text_analysis_result(
self,
text: str,
sentiment: SentimentAnalysis,
keyword_result
) -> TextAnalysisResult:
"""构建文本分析结果"""
# 获取关键词列表
keywords = getattr(keyword_result, 'keywords', [])
return TextAnalysisResult(
text_length=len(text),
word_count=len(text.split()),
language="zh",
summary=f"文本分析摘要: 情感倾向为{sentiment.sentiment}",
sentiment=sentiment,
keywords=keywords,
readability="medium",
complexity=getattr(keyword_result, 'text_complexity', 0.5)
)
def _update_processing_stats(self, processing_time: float):
"""更新处理统计信息"""
total = self.processing_stats["total_processed"]
current_avg = self.processing_stats["average_processing_time"]
# 计算新的平均处理时间
new_avg = ((current_avg * (total - 1)) + processing_time) / total
self.processing_stats["average_processing_time"] = new_avg
def get_processing_stats(self) -> Dict[str, Any]:
"""获取处理统计信息"""
return self.processing_stats.copy()
def display_stats(self):
"""显示处理统计"""
stats = self.get_processing_stats()
print("\n📈 处理统计信息:")
print(f" 总处理数: {stats['total_processed']}")
print(f" 成功数: {stats['successful_analyses']}")
print(f" 失败数: {stats['failed_analyses']}")
if stats['total_processed'] > 0:
success_rate = stats['successful_analyses'] / stats['total_processed'] * 100
print(f" 成功率: {success_rate:.1f}%")
print(f" 平均处理时间: {stats['average_processing_time']:.2f}")
def demo_single_text_analysis():
"""演示单个文本分析"""
print("🔍 单文本分析演示")
print("="*50)
# 创建分析引擎
engine = TextAnalysisEngine()
# 测试文本
test_text = """
人工智能技术正在快速发展深度学习和机器学习算法在各个领域都取得了显著的进展
从自然语言处理到计算机视觉从推荐系统到自动驾驶AI技术正在改变我们的生活方式
然而我们也需要关注AI发展带来的挑战包括隐私保护算法偏见就业影响等问题
只有在技术发展和社会责任之间找到平衡AI才能真正造福人类社会
总的来说人工智能的未来充满希望但也需要我们谨慎对待确保技术发展的方向符合人类的长远利益
"""
try:
# 执行分析
result = engine.analyze_text(test_text)
# 显示详细结果
display_analysis_result(result)
# 显示统计信息
engine.display_stats()
return True
except Exception as e:
print(f"❌ 演示失败: {e}")
return False
def demo_batch_analysis():
"""演示批量文本分析"""
print("\n🔄 批量文本分析演示")
print("="*50)
# 创建分析引擎
engine = TextAnalysisEngine()
# 测试文本集合
test_texts = [
"今天天气真好,阳光明媚,心情特别愉快!",
"公司最新发布的季度财报显示营收同比增长15%净利润达到2.3亿元。董事会决定向股东分红每股0.5元。",
"机器学习是人工智能的一个重要分支,通过算法让计算机能够从数据中学习模式。常见的机器学习算法包括线性回归、决策树、神经网络等。",
"服务态度恶劣,产品质量很差,完全不值这个价格。强烈不推荐大家购买!",
"根据《合同法》第一百二十一条规定,当事人一方因第三人的原因造成违约的,应当向对方承担违约责任。"
]
results = []
start_time = time.time()
for i, text in enumerate(test_texts, 1):
print(f"\n处理第 {i}/{len(test_texts)} 个文本...")
try:
result = engine.analyze_text(text, f"batch_{i}")
results.append(result)
# 显示简要结果
print(f" 结果: {result.overall_quality} | 置信度: {result.confidence_level:.3f}")
except Exception as e:
print(f" ❌ 处理失败: {e}")
results.append(None)
total_time = time.time() - start_time
# 显示批量处理总结
print(f"\n📊 批量处理完成 - 总耗时: {total_time:.2f}")
engine.display_stats()
# 显示成功处理的结果统计
successful_results = [r for r in results if r is not None]
if successful_results:
print(f"\n🎯 成功处理 {len(successful_results)} 个文本:")
quality_stats = {}
for result in successful_results:
quality = result.overall_quality
quality_stats[quality] = quality_stats.get(quality, 0) + 1
for quality, count in quality_stats.items():
print(f" {quality}: {count}")
return len(successful_results) > 0
def display_analysis_result(result: ComprehensiveAnalysisResult):
"""显示详细的分析结果"""
print(f"\n📋 详细分析结果 [{result.analysis_id}]")
print("="*60)
print(f"输入摘要: {result.input_summary}")
print(f"分析时间: {result.analysis_timestamp}")
print(f"整体质量: {result.overall_quality}")
print(f"置信度: {result.confidence_level:.3f}")
print(f"处理时间: {result.total_processing_time:.2f}")
# 文本分析结果
if result.text_analysis:
ta = result.text_analysis
print(f"\n📝 文本分析:")
print(f" 长度: {ta.text_length} 字符")
print(f" 词数: {ta.word_count}")
print(f" 摘要: {ta.summary}")
print(f" 情感: {ta.sentiment.sentiment} (置信度: {ta.sentiment.confidence:.3f})")
print(f" 可读性: {ta.readability}")
if ta.keywords:
top_keywords = ta.keywords[:5]
print(f" 关键词: {[k.keyword for k in top_keywords]}")
# 分类结果
if result.classification:
cls = result.classification
print(f"\n🏷️ 文档分类:")
print(f" 主分类: {cls.primary_category}")
print(f" 置信度: {cls.confidence:.3f}")
if len(cls.all_categories) > 1:
other_cats = cls.all_categories[1:3]
print(f" 其他可能: {[c.category for c in other_cats]}")
# 数据提取结果
if result.data_extraction:
de = result.data_extraction
print(f"\n🔍 数据提取:")
print(f" 质量: {de.extraction_quality}")
print(f" 完整性: {de.completeness:.3f}")
print(f" 准确性: {de.accuracy:.3f}")
print(f" 字段统计: {de.extracted_fields}/{de.total_fields}")
if de.extraction_items:
print(" 提取项目:")
for item in de.extraction_items[:3]: # 显示前3个
print(f" {item.field_name}: {item.field_value} (置信度: {item.confidence:.3f})")
# 改进建议
if result.recommendations:
print(f"\n💡 改进建议:")
for i, rec in enumerate(result.recommendations[:3], 1):
print(f" {i}. {rec}")
def interactive_analysis():
"""交互式分析功能"""
print("\n💬 交互式文本分析")
print("="*50)
print("输入文本进行综合分析,输入'quit'退出")
try:
engine = TextAnalysisEngine()
while True:
print("\n" + "-"*30)
user_input = input("请输入要分析的文本: ").strip()
if user_input.lower() == 'quit':
print("分析结束,再见!")
break
if not user_input:
continue
if len(user_input) < 10:
print("⚠️ 文本太短请输入至少10个字符")
continue
try:
result = engine.analyze_text(user_input)
display_analysis_result(result)
except Exception as e:
print(f"❌ 分析失败: {e}")
# 显示最终统计
engine.display_stats()
except KeyboardInterrupt:
print("\n程序已中断")
except Exception as e:
print(f"❌ 交互式分析失败: {e}")
def test_engine_initialization():
"""测试引擎初始化"""
print("正在测试文本分析引擎初始化...")
try:
engine = TextAnalysisEngine()
print(f"✅ 引擎初始化成功,包含 {len(engine.agents)} 个Agent")
# 显示Agent信息
for name, agent in engine.agents.items():
info = agent.get_model_info()
print(f" {name}: {info['model_name']} ({info['provider']})")
return True
except Exception as e:
print(f"❌ 引擎初始化失败: {e}")
return False
def main():
"""主函数"""
print("🚀 文本分析综合示例")
print("="*60)
# 运行各种演示
demos = [
("引擎初始化测试", test_engine_initialization),
("单文本分析", demo_single_text_analysis),
("批量文本分析", demo_batch_analysis),
]
results = {}
for name, demo_func in demos:
print(f"\n开始: {name}")
try:
success = demo_func()
results[name] = success
print(f"{'' if success else ''} {name} {'成功' if success else '失败'}")
except Exception as e:
print(f"{name} 异常: {e}")
results[name] = False
# 显示总结
print(f"\n📊 演示总结")
print("="*60)
successful_demos = sum(results.values())
total_demos = len(results)
for name, success in results.items():
status = "✅ 成功" if success else "❌ 失败"
print(f" {name}: {status}")
print(f"\n🎯 总计: {successful_demos}/{total_demos} 个演示成功")
# 询问是否运行交互式演示
if successful_demos > 0:
try:
choice = input("\n是否运行交互式分析?(y/n): ").strip().lower()
if choice in ['y', 'yes', '']:
interactive_analysis()
except (KeyboardInterrupt, EOFError):
print("\n程序结束")
return successful_demos == total_demos
if __name__ == "__main__":
# 可以选择运行测试或完整演示
import sys
if len(sys.argv) > 1 and sys.argv[1] == "--test":
# 仅运行初始化测试
success = test_engine_initialization()
exit(0 if success else 1)
else:
# 运行完整演示
main()

147
info_extractor.py Normal file
View File

@ -0,0 +1,147 @@
#!/usr/bin/env python3
"""
基于LangExtract的MIMIC论文信息提取器
从医学论文中提取结构化的复现任务信息
作者MedResearcher项目
创建时间2025-01-25
"""
import argparse
import logging
from src.extractor import MIMICLangExtractBuilder
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
def setup_args():
"""设置命令行参数解析
Returns:
argparse.Namespace: 解析后的命令行参数
"""
parser = argparse.ArgumentParser(
description='MIMIC论文信息提取工具 - 基于LangExtract从医学论文中提取结构化复现信息',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog='''
使用示例:
%(prog)s # 使用默认参数
%(prog)s --papers_dir dataset/markdowns # 指定论文目录
%(prog)s --output_file results/dataset.json # 指定输出文件
%(prog)s --test_mode --max_papers 5 # 测试模式只处理5篇论文
'''
)
parser.add_argument(
'--papers_dir',
type=str,
default='dataset/markdowns',
help='markdown论文文件目录 (默认: dataset/markdowns)'
)
parser.add_argument(
'--output_file',
type=str,
default='dataset/reproduction_tasks/mimic_langextract_dataset.json',
help='输出数据集文件路径 (默认: dataset/reproduction_tasks/mimic_langextract_dataset.json)'
)
parser.add_argument(
'--test_mode',
action='store_true',
help='测试模式,只处理少量论文进行验证'
)
parser.add_argument(
'--max_papers',
type=int,
default=None,
help='最大处理论文数量,用于测试 (默认: 处理所有论文)'
)
parser.add_argument(
'--log_level',
type=str,
default='INFO',
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
help='日志级别 (默认: INFO)'
)
parser.add_argument(
'--doc_workers',
type=int,
default=50,
help='文档并行处理工作线程数 (默认: 4)'
)
return parser.parse_args()
def main():
"""主函数 - 执行MIMIC论文信息提取任务"""
try:
# 解析命令行参数
args = setup_args()
# 设置日志级别
logging.getLogger().setLevel(getattr(logging, args.log_level))
# 初始化信息提取器
builder = MIMICLangExtractBuilder(doc_workers=args.doc_workers)
print(f"=== MIMIC论文信息提取工具启动 ===")
print(f"论文目录: {args.papers_dir}")
print(f"输出文件: {args.output_file}")
print(f"测试模式: {'' if args.test_mode else ''}")
if args.max_papers:
print(f"最大论文数: {args.max_papers}")
print(f"文档并行度: {args.doc_workers} 线程")
print(f"日志级别: {args.log_level}")
print(f"========================")
# 构建复现数据集
print("\n开始构建MIMIC复现数据集...")
dataset = builder.build_reproduction_dataset(
papers_dir=args.papers_dir,
output_file=args.output_file,
max_papers=args.max_papers if args.test_mode or args.max_papers else None
)
# 统计结果
total_papers = dataset['metadata']['total_papers']
successful_extractions = sum(
1 for paper in dataset['papers'].values()
if any(module.get('extraction_count', 0) > 0
for module in paper.get('modules', {}).values())
)
print(f"\n=== 构建完成 ===")
print(f"总论文数: {total_papers}")
print(f"成功提取: {successful_extractions}/{total_papers}")
print(f"成功率: {successful_extractions/total_papers*100:.1f}%")
print(f"结果保存至: {args.output_file}")
print(f"交互式报告: {args.output_file.replace('.json', '.html')}")
print(f"===============")
return 0
except FileNotFoundError as e:
print(f"错误: 找不到指定的文件或目录 - {e}")
return 1
except ValueError as e:
print(f"错误: 参数值无效 - {e}")
return 1
except Exception as e:
print(f"错误: 程序执行异常 - {e}")
logging.exception("详细错误信息:")
return 1
if __name__ == "__main__":
exit_code = main()
exit(exit_code)

View File

@ -22,7 +22,7 @@ def setup_args():
parser.add_argument(
'--paper_website',
default=["arxiv","medrxiv"],
default=["medrxiv"],
help='论文网站 (默认: arxiv,medrxiv)',
nargs='+',
choices=["arxiv","medrxiv"]
@ -34,6 +34,20 @@ def setup_args():
default=20,
help='并行处理线程数 (默认: 20)'
)
parser.add_argument(
'--csv-download',
type=str,
default="yes",
help='指定CSV文件路径'
)
parser.add_argument(
'--pdf_download_list',
type=str,
default='dataset/mimic_papers_20250825.csv',
help='指定PDF下载目录'
)
return parser.parse_args()
@ -45,30 +59,49 @@ def main():
# 解析命令行参数
args = setup_args()
print(f"=== 论文爬取工具启动 ===")
print(f"论文数据源: {args.paper_website}")
print(f"并行处理数: {args.parallel}")
print(f"========================")
# 初始化论文爬取器
crawler = PaperCrawler(
websites=args.paper_website,
parallel=args.parallel
)
# 执行论文爬取
print("开始爬取MIMIC-4相关论文...")
papers = crawler.crawl_papers()
if papers:
# 保存到CSV文件
csv_file_path = crawler.save_to_csv(papers)
print(f"\n=== 爬取完成 ===")
print(f"成功爬取: {len(papers)} 篇论文")
print(f"保存位置: {csv_file_path}")
print(f"================")
else:
print("未找到相关论文,请检查网络连接或关键词设置")
print(f"=== 论文爬取工具启动 ===")
print(f"论文数据源: {args.paper_website}")
print(f"并行处理数: {args.parallel}")
print(f"========================")
# 执行论文爬取
if args.csv_download:
print("开始爬取MIMIC-4相关论文...")
papers = crawler.crawl_papers()
if papers:
# 保存到CSV文件
csv_file_path = crawler.save_to_csv(papers)
print(f"\n=== 爬取完成 ===")
print(f"成功爬取: {len(papers)} 篇论文")
print(f"保存位置: {csv_file_path}")
print(f"================")
else:
print("未找到相关论文,请检查网络连接或关键词设置")
# 如果指定了PDF下载测试执行测试
if args.pdf_download_list:
print(f"=== PDF下载功能测试 ===")
print(f"CSV文件: {args.pdf_download_list}")
print(f"并发数: {args.parallel}")
print(f"========================")
# 执行PDF下载
stats = crawler.download_pdfs_from_csv(args.pdf_download_list)
print(f"\n=== PDF下载测试完成 ===")
print(f"总数: {stats['total']} 篇论文")
print(f"成功: {stats['success']} 篇 ({stats['success']/stats['total']*100:.1f}%)")
print(f"失败: {stats['failed']} 篇 ({stats['failed']/stats['total']*100:.1f}%)")
print(f"========================")
return 0
except FileNotFoundError as e:
print(f"错误: 找不到指定的文件 - {e}")

90
pdf_parser.py Normal file
View File

@ -0,0 +1,90 @@
import argparse
from src.parse import PDFParser
def setup_args():
"""设置命令行参数解析
Returns:
argparse.Namespace: 解析后的命令行参数
"""
parser = argparse.ArgumentParser(
description='PDF解析工具 - 用于将PDF文件通过OCR API转换为Markdown格式',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog='''
使用示例:
%(prog)s # 使用默认参数
%(prog)s --pdf-dir dataset/pdfs # 指定PDF目录
%(prog)s --parallel 10 # 设置并行度为10
%(prog)s --markdown-dir output/markdowns # 指定输出目录
'''
)
parser.add_argument(
'--pdf-dir',
default="dataset/pdfs",
help='PDF文件目录 (默认: dataset/pdfs)'
)
parser.add_argument(
'--parallel',
type=int,
default=5,
help='并发处理线程数 (默认: 5降低并发避免服务器过载)'
)
parser.add_argument(
'--markdown-dir',
default="dataset/markdowns",
help='Markdown输出目录 (默认: dataset/markdowns)'
)
return parser.parse_args()
def main():
"""主函数 - 执行PDF解析任务"""
try:
# 解析命令行参数
args = setup_args()
# 初始化PDF解析器
parser = PDFParser(
pdf_dir=args.pdf_dir,
parallel=args.parallel,
markdown_dir=args.markdown_dir
)
print(f"=== PDF解析工具启动 ===")
print(f"PDF目录: {args.pdf_dir}")
print(f"并发数: {args.parallel}")
print(f"输出目录: {args.markdown_dir}")
print(f"========================")
# 执行PDF解析
print("开始处理PDF文件...")
stats = parser.parse_all_pdfs()
print(f"\n=== 解析完成 ===")
print(f"总数: {stats['total']} 个文件")
print(f"成功: {stats['success']} 个 ({stats['success']/stats['total']*100:.1f}%)" if stats['total'] > 0 else "成功: 0 个")
print(f"失败: {stats['failed']} 个 ({stats['failed']/stats['total']*100:.1f}%)" if stats['total'] > 0 else "失败: 0 个")
print(f"================")
return 0
except FileNotFoundError as e:
print(f"错误: 找不到指定的目录 - {e}")
return 1
except ValueError as e:
print(f"错误: 参数值无效 - {e}")
return 1
except Exception as e:
print(f"错误: 程序执行异常 - {e}")
return 1
if __name__ == "__main__":
exit_code = main()
exit(exit_code)

View File

@ -5,5 +5,12 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"agno>=1.7.12",
"httpx[socks]>=0.28.1",
"langextract>=1.0.8",
"ollama>=0.5.3",
"openai>=1.101.0",
"pydantic",
"pyyaml>=6.0.2",
"requests>=2.32.5",
]

View File

@ -0,0 +1,38 @@
"""
Agent System Module
基于Agno框架的SubAgent系统实现
为MedResearcher项目提供智能Agent功能
"""
__version__ = "0.1.0"
# 导入核心类
from .subagent import SubAgent, create_json_agent
from .config_loader import load_llm_config, get_model_config
from .model_factory import create_agno_model, list_available_models
from .json_processor import parse_json_response, JSONProcessor
# 导入异常类
from .subagent import SubAgentError, ConfigurationError, ModelError
from .json_processor import JSONParseError
__all__ = [
# 核心类
'SubAgent',
'JSONProcessor',
# 便捷函数
'create_json_agent',
'load_llm_config',
'get_model_config',
'create_agno_model',
'list_available_models',
'parse_json_response',
# 异常类
'SubAgentError',
'ConfigurationError',
'ModelError',
'JSONParseError',
]

View File

@ -0,0 +1,292 @@
"""
配置加载模块
负责解析LLM配置文件和环境变量为SubAgent提供统一的配置接口
"""
import os
import yaml
from typing import Dict, Any, Optional
from pathlib import Path
def load_llm_config(config_path: Optional[str] = None) -> Dict[str, Any]:
"""
加载LLM配置文件和环境变量
Args:
config_path: 配置文件路径默认为 src/config/llm_config.yaml
Returns:
完整的LLM配置字典包含所有提供商和环境变量
Raises:
FileNotFoundError: 配置文件不存在
yaml.YAMLError: YAML解析错误
ValueError: 配置格式错误
"""
# 确定配置文件路径
if config_path is None:
# 获取项目根目录
current_dir = Path(__file__).parent
project_root = current_dir.parent.parent
config_path = project_root / "src" / "config" / "llm_config.yaml"
config_path = Path(config_path)
# 检查配置文件是否存在
if not config_path.exists():
raise FileNotFoundError(f"LLM配置文件不存在: {config_path}")
# 读取YAML配置文件
try:
with open(config_path, 'r', encoding='utf-8') as file:
config = yaml.safe_load(file)
except yaml.YAMLError as e:
raise yaml.YAMLError(f"YAML配置文件解析失败: {e}")
except Exception as e:
raise ValueError(f"读取配置文件失败: {e}")
# 验证配置结构
if not isinstance(config, dict):
raise ValueError("配置文件格式错误:根元素必须是字典")
# 加载环境变量配置
env_config = load_env_config()
# 处理环境变量替换
config = _resolve_env_variables(config, env_config)
return config
def load_env_config(env_path: Optional[str] = None) -> Dict[str, str]:
"""
加载环境变量配置
Args:
env_path: .env文件路径默认为 src/config/.env
Returns:
环境变量字典
"""
# 确定环境变量文件路径
if env_path is None:
current_dir = Path(__file__).parent
project_root = current_dir.parent.parent
env_path = project_root / "src" / "config" / ".env"
env_config = {}
# 尝试加载.env文件
env_path = Path(env_path)
if env_path.exists():
try:
with open(env_path, 'r', encoding='utf-8') as file:
for line in file:
line = line.strip()
if line and not line.startswith('#') and '=' in line:
key, value = line.split('=', 1)
env_config[key.strip()] = value.strip()
except Exception as e:
print(f"警告: 读取.env文件失败: {e}")
# 同时从系统环境变量中加载
env_keys = ['DASHSCOPE_API_KEY', 'OPENAI_API_KEY', 'DEEPSEEK_API_KEY']
for key in env_keys:
if key in os.environ:
env_config[key] = os.environ[key]
return env_config
def _resolve_env_variables(config: Dict[str, Any], env_config: Dict[str, str]) -> Dict[str, Any]:
"""
解析配置中的环境变量占位符
Args:
config: 原始配置
env_config: 环境变量配置
Returns:
解析后的配置
"""
def resolve_value(value):
if isinstance(value, str) and value.startswith('${') and value.endswith('}'):
# 提取环境变量名称
env_var = value[2:-1] # 去掉 ${ 和 }
if env_var in env_config:
return env_config[env_var]
elif env_var in os.environ:
return os.environ[env_var]
else:
print(f"警告: 环境变量 {env_var} 未定义,保持原值")
return value
elif isinstance(value, dict):
return {k: resolve_value(v) for k, v in value.items()}
elif isinstance(value, list):
return [resolve_value(item) for item in value]
else:
return value
return resolve_value(config)
def get_provider_config(config: Dict[str, Any], provider: str) -> Dict[str, Any]:
"""
获取特定提供商的配置
Args:
config: 完整的LLM配置
provider: 提供商名称 ( 'aliyun', 'deepseek', 'openai')
Returns:
提供商配置字典
Raises:
ValueError: 提供商不存在
"""
if provider not in config:
available_providers = list(config.keys())
raise ValueError(f"提供商 '{provider}' 不存在,可用提供商: {available_providers}")
return config[provider]
def get_model_config(config: Dict[str, Any], provider: str, model_name: str) -> Dict[str, Any]:
"""
获取特定模型的配置
Args:
config: 完整的LLM配置
provider: 提供商名称
model_name: 模型名称
Returns:
模型配置字典
Raises:
ValueError: 提供商或模型不存在
"""
provider_config = get_provider_config(config, provider)
if 'models' not in provider_config:
raise ValueError(f"提供商 '{provider}' 配置中缺少 models 字段")
models = provider_config['models']
if model_name not in models:
available_models = list(models.keys())
raise ValueError(f"模型 '{model_name}' 在提供商 '{provider}' 中不存在,可用模型: {available_models}")
model_config = models[model_name].copy()
# 添加提供商级别的配置
for key in ['base_url', 'api_key']:
if key in provider_config:
model_config[key] = provider_config[key]
return model_config
def validate_config(config: Dict[str, Any]) -> bool:
"""
验证配置文件的完整性
Args:
config: LLM配置字典
Returns:
验证是否通过
"""
required_providers = ['aliyun', 'deepseek', 'openai']
missing_providers = []
for provider in required_providers:
if provider not in config:
missing_providers.append(provider)
if missing_providers:
print(f"警告: 缺少提供商配置: {missing_providers}")
# 验证每个提供商的配置结构
valid = True
for provider_name, provider_config in config.items():
if not isinstance(provider_config, dict):
print(f"错误: 提供商 '{provider_name}' 配置必须是字典")
valid = False
continue
if 'models' not in provider_config:
print(f"错误: 提供商 '{provider_name}' 缺少 models 字段")
valid = False
continue
models = provider_config['models']
if not isinstance(models, dict) or not models:
print(f"错误: 提供商 '{provider_name}' 的 models 必须是非空字典")
valid = False
continue
# 验证每个模型配置
for model_name, model_config in models.items():
if not isinstance(model_config, dict):
print(f"错误: 模型 '{model_name}' 配置必须是字典")
valid = False
continue
required_fields = ['class', 'params']
for field in required_fields:
if field not in model_config:
print(f"错误: 模型 '{model_name}' 缺少必需字段: {field}")
valid = False
return valid
# 测试函数
def test_config_loading():
"""测试配置加载功能"""
try:
print("正在测试配置加载...")
# 加载配置
config = load_llm_config()
print(f"✅ 配置加载成功,找到 {len(config)} 个提供商")
# 验证配置
is_valid = validate_config(config)
print(f"✅ 配置验证: {'通过' if is_valid else '失败'}")
# 测试提供商配置获取
try:
aliyun_config = get_provider_config(config, 'aliyun')
print(f"✅ 阿里云配置获取成功,包含 {len(aliyun_config.get('models', {}))} 个模型")
except ValueError as e:
print(f"❌ 阿里云配置获取失败: {e}")
# 测试模型配置获取
try:
qwen_config = get_model_config(config, 'aliyun', 'qwen-max')
print(f"✅ qwen-max模型配置获取成功")
print(f" 模型类: {qwen_config.get('class', 'N/A')}")
print(f" 模型ID: {qwen_config.get('params', {}).get('id', 'N/A')}")
except ValueError as e:
print(f"❌ qwen-max模型配置获取失败: {e}")
return True
except Exception as e:
print(f"❌ 配置加载测试失败: {e}")
return False
if __name__ == "__main__":
test_config_loading()

View File

@ -0,0 +1,404 @@
"""
JSON处理器模块
提供强大的JSON解析和验证功能支持零容错解析和Pydantic模型集成
"""
import json
import re
from typing import Any, Dict, List, Optional, Type, Union, get_origin, get_args
from pydantic import BaseModel, ValidationError
class JSONParseError(Exception):
"""JSON解析错误"""
pass
class JSONProcessor:
"""
JSON处理器类
提供多种JSON解析策略确保即使在不完美的JSON输出下也能成功解析
"""
def __init__(self):
"""初始化JSON处理器"""
pass
def parse_json_response(
self,
response: str,
response_model: Optional[Type[BaseModel]] = None
) -> Union[Dict, List, BaseModel]:
"""
解析JSON响应
Args:
response: 响应字符串
response_model: 可选的Pydantic模型类
Returns:
解析后的数据对象或模型实例
Raises:
JSONParseError: 解析失败时抛出
"""
if not response or not response.strip():
raise JSONParseError("响应为空无法解析JSON")
# 清理响应字符串
cleaned_response = self._clean_response(response)
# 尝试多种解析策略
parsed_data = self._try_multiple_parsing_strategies(cleaned_response)
if parsed_data is None:
raise JSONParseError(f"所有解析策略都失败了,响应内容: {response[:200]}...")
# 如果指定了响应模型,尝试创建模型实例
if response_model is not None:
return self._create_model_instance(parsed_data, response_model)
return parsed_data
def _clean_response(self, response: str) -> str:
"""
清理响应字符串
Args:
response: 原始响应
Returns:
清理后的响应
"""
# 去除首尾空白
cleaned = response.strip()
# 移除可能的markdown代码块标记
if cleaned.startswith('```json'):
cleaned = cleaned[7:]
elif cleaned.startswith('```'):
cleaned = cleaned[3:]
if cleaned.endswith('```'):
cleaned = cleaned[:-3]
# 移除可能的额外引号包装
if cleaned.startswith('"') and cleaned.endswith('"'):
try:
# 尝试解析为字符串看是否是被引号包装的JSON
unquoted = json.loads(cleaned)
if isinstance(unquoted, str):
cleaned = unquoted
except json.JSONDecodeError:
pass
return cleaned.strip()
def _try_multiple_parsing_strategies(self, response: str) -> Optional[Union[Dict, List]]:
"""
尝试多种解析策略
Args:
response: 清理后的响应
Returns:
解析成功的数据或None
"""
strategies = [
self._strategy_direct_parse,
self._strategy_extract_json_block,
self._strategy_find_json_structure,
self._strategy_regex_extract,
self._strategy_fix_common_errors,
]
for strategy in strategies:
try:
result = strategy(response)
if result is not None:
return result
except Exception:
continue
return None
def _strategy_direct_parse(self, response: str) -> Optional[Union[Dict, List]]:
"""策略1: 直接解析JSON"""
try:
return json.loads(response)
except json.JSONDecodeError:
return None
def _strategy_extract_json_block(self, response: str) -> Optional[Union[Dict, List]]:
"""策略2: 提取JSON代码块"""
# 查找被代码块包围的JSON
code_block_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```'
matches = re.findall(code_block_pattern, response, re.IGNORECASE)
for match in matches:
try:
return json.loads(match.strip())
except json.JSONDecodeError:
continue
return None
def _strategy_find_json_structure(self, response: str) -> Optional[Union[Dict, List]]:
"""策略3: 查找JSON结构"""
# 查找第一个完整的JSON对象或数组
for start_char, end_char in [("{", "}"), ("[", "]")]:
json_str = self._extract_complete_json_structure(response, start_char, end_char)
if json_str:
try:
return json.loads(json_str)
except json.JSONDecodeError:
continue
return None
def _strategy_regex_extract(self, response: str) -> Optional[Union[Dict, List]]:
"""策略4: 正则表达式提取"""
# 使用正则表达式查找JSON模式
patterns = [
r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', # 嵌套对象
r'\[[^\[\]]*(?:\[[^\[\]]*\][^\[\]]*)*\]', # 嵌套数组
]
for pattern in patterns:
matches = re.findall(pattern, response)
for match in matches:
try:
return json.loads(match)
except json.JSONDecodeError:
continue
return None
def _strategy_fix_common_errors(self, response: str) -> Optional[Union[Dict, List]]:
"""策略5: 修复常见JSON错误"""
# 修复常见的JSON格式错误
fixed_response = response
# 修复单引号为双引号
fixed_response = re.sub(r"'([^']*)':", r'"\1":', fixed_response)
fixed_response = re.sub(r":\s*'([^']*)'", r': "\1"', fixed_response)
# 修复末尾逗号
fixed_response = re.sub(r',(\s*[}\]])', r'\1', fixed_response)
# 修复未引用的键
fixed_response = re.sub(r'(\w+):', r'"\1":', fixed_response)
try:
return json.loads(fixed_response)
except json.JSONDecodeError:
return None
def _extract_complete_json_structure(
self,
text: str,
start_char: str,
end_char: str
) -> Optional[str]:
"""
提取完整的JSON结构
Args:
text: 文本内容
start_char: 开始字符 ('{' '[')
end_char: 结束字符 ('}' ']')
Returns:
提取的JSON字符串或None
"""
start_idx = text.find(start_char)
if start_idx == -1:
return None
# 使用计数器匹配嵌套结构
count = 0
in_string = False
escape_next = False
for i, char in enumerate(text[start_idx:], start_idx):
if escape_next:
escape_next = False
continue
if char == '\\' and in_string:
escape_next = True
continue
if char == '"' and not escape_next:
in_string = not in_string
continue
if not in_string:
if char == start_char:
count += 1
elif char == end_char:
count -= 1
if count == 0:
return text[start_idx:i+1]
# 如果没有找到匹配的结束符,返回到字符串末尾
if count > 0:
return text[start_idx:]
return None
def _create_model_instance(
self,
data: Union[Dict, List],
response_model: Type[BaseModel]
) -> BaseModel:
"""
创建Pydantic模型实例
Args:
data: 解析后的数据
response_model: Pydantic模型类
Returns:
模型实例
Raises:
JSONParseError: 模型验证失败
"""
try:
# 如果数据是列表但模型期望对象,尝试包装
if isinstance(data, list) and not self._is_list_model(response_model):
# 尝试将列表作为某个字段的值
field_names = list(response_model.model_fields.keys())
if field_names:
# 使用第一个字段名作为包装器
data = {field_names[0]: data}
# 创建模型实例
return response_model(**data)
except ValidationError as e:
# 详细的验证错误信息
error_msg = f"Pydantic模型验证失败: {e}"
raise JSONParseError(error_msg)
except TypeError as e:
error_msg = f"模型实例化失败: {e}"
raise JSONParseError(error_msg)
def _is_list_model(self, model: Type[BaseModel]) -> bool:
"""
检查模型是否期望列表类型
Args:
model: Pydantic模型类
Returns:
是否为列表模型
"""
# 检查模型的字段类型
for field_info in model.model_fields.values():
annotation = field_info.annotation
origin = get_origin(annotation)
# 如果有字段是List类型认为这是一个列表模型
if origin is list or origin is List:
return True
return False
def validate_json_schema(self, data: Dict, schema: Dict) -> bool:
"""
验证JSON数据是否符合指定的schema
Args:
data: JSON数据
schema: JSON schema
Returns:
是否符合schema
"""
try:
import jsonschema
jsonschema.validate(data, schema)
return True
except ImportError:
# 如果没有安装jsonschema库跳过验证
return True
except Exception:
return False
# 便捷函数
def parse_json_response(
response: str,
response_model: Optional[Type[BaseModel]] = None
) -> Union[Dict, List, BaseModel]:
"""
便捷函数解析JSON响应
Args:
response: 响应字符串
response_model: 可选的Pydantic模型类
Returns:
解析后的数据或模型实例
"""
processor = JSONProcessor()
return processor.parse_json_response(response, response_model)
# 测试函数
def test_json_processor():
"""测试JSON处理器功能"""
print("正在测试JSON处理器...")
processor = JSONProcessor()
# 测试用例
test_cases = [
# 标准JSON
('{"name": "test", "value": 123}', None),
# 带代码块标记的JSON
('```json\n{"name": "test", "value": 123}\n```', None),
# 不完整的JSON
('{"name": "test", "value": 123', None),
# 单引号JSON
("{'name': 'test', 'value': 123}", None),
# 包含额外文本的JSON
('这是一个JSON响应: {"name": "test", "value": 123} 解析完成', None),
# 数组JSON
('[{"name": "test1"}, {"name": "test2"}]', None),
]
success_count = 0
for i, (test_input, model) in enumerate(test_cases, 1):
try:
result = processor.parse_json_response(test_input, model)
print(f"✅ 测试用例 {i}: 解析成功 - {type(result)}")
success_count += 1
except Exception as e:
print(f"❌ 测试用例 {i}: 解析失败 - {e}")
print(f"\n🎯 JSON处理器测试完成: {success_count}/{len(test_cases)} 通过")
return success_count == len(test_cases)
if __name__ == "__main__":
test_json_processor()

View File

@ -0,0 +1,312 @@
"""
模型工厂模块
基于配置创建各种Agno模型实例支持多种LLM提供商
"""
from typing import Dict, Any, Optional, Union
from agno.models.openai import OpenAILike, OpenAIChat
# 尝试导入其他可选模型
try:
from agno.models.ollama import Ollama
OLLAMA_AVAILABLE = True
except ImportError:
OLLAMA_AVAILABLE = False
Ollama = None
from .config_loader import load_llm_config, get_model_config
class ModelFactory:
"""
模型工厂类
负责根据配置创建不同类型的Agno模型实例
"""
# 支持的模型类映射
MODEL_CLASSES = {
"OpenAILike": OpenAILike,
"OpenAIChat": OpenAIChat,
}
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""
初始化模型工厂
Args:
config: LLM配置如果为None则自动加载
"""
self.config = config if config is not None else load_llm_config()
# 如果Ollama可用添加到支持列表
if OLLAMA_AVAILABLE:
self.MODEL_CLASSES["Ollama"] = Ollama
def create_model(
self,
provider: str,
model_name: str,
**override_params
) -> Union[OpenAILike, OpenAIChat]:
"""
创建模型实例
Args:
provider: 提供商名称 ( 'aliyun', 'deepseek', 'openai')
model_name: 模型名称 ( 'qwen-max', 'deepseek-v3')
**override_params: 覆盖配置的参数
Returns:
创建的模型实例
Raises:
ValueError: 配置错误或不支持的模型类型
ImportError: 缺少必需的依赖
"""
# 获取模型配置
model_config = get_model_config(self.config, provider, model_name)
# 获取模型类名
model_class_name = model_config.get('class')
if not model_class_name:
raise ValueError(f"模型配置中缺少 'class' 字段: {provider}.{model_name}")
# 检查模型类是否支持
if model_class_name not in self.MODEL_CLASSES:
available_classes = list(self.MODEL_CLASSES.keys())
raise ValueError(f"不支持的模型类 '{model_class_name}',支持的类型: {available_classes}")
# 特殊处理Ollama
if model_class_name == "Ollama" and not OLLAMA_AVAILABLE:
raise ImportError("Ollama模型需要安装ollama包: pip install ollama")
# 获取模型类
model_class = self.MODEL_CLASSES[model_class_name]
# 准备初始化参数
init_params = self._prepare_init_params(model_config, model_class_name, override_params)
# 创建模型实例
try:
model = model_class(**init_params)
return model
except Exception as e:
raise ValueError(f"创建模型实例失败 ({provider}.{model_name}): {e}")
def _prepare_init_params(
self,
model_config: Dict[str, Any],
model_class_name: str,
override_params: Dict[str, Any]
) -> Dict[str, Any]:
"""
准备模型初始化参数
Args:
model_config: 模型配置
model_class_name: 模型类名
override_params: 覆盖参数
Returns:
初始化参数字典
"""
# 基础参数从params字段获取
init_params = model_config.get('params', {}).copy()
# 添加API相关参数
if 'api_key' in model_config:
init_params['api_key'] = model_config['api_key']
if 'base_url' in model_config:
init_params['base_url'] = model_config['base_url']
# 根据模型类型调整参数
if model_class_name == "OpenAILike":
# OpenAILike模型特殊处理
self._adjust_openai_like_params(init_params)
elif model_class_name == "OpenAIChat":
# OpenAIChat模型特殊处理
self._adjust_openai_chat_params(init_params)
elif model_class_name == "Ollama":
# Ollama模型特殊处理
self._adjust_ollama_params(init_params)
# 应用覆盖参数
init_params.update(override_params)
return init_params
def _adjust_openai_like_params(self, params: Dict[str, Any]) -> None:
"""调整OpenAILike模型参数"""
# OpenAILike使用id参数不需要model参数
# id参数已经在params中无需额外处理
pass
def _adjust_openai_chat_params(self, params: Dict[str, Any]) -> None:
"""调整OpenAIChat模型参数"""
# OpenAIChat通常需要id参数作为model参数
if 'id' in params and 'model' not in params:
params['model'] = params['id']
def _adjust_ollama_params(self, params: Dict[str, Any]) -> None:
"""调整Ollama模型参数"""
# Ollama模型通常不需要api_key和base_url
params.pop('api_key', None)
params.pop('base_url', None)
# 使用id作为model参数
if 'id' in params and 'model' not in params:
params['model'] = params['id']
def list_available_models(self) -> Dict[str, list]:
"""
列出所有可用的模型
Returns:
按提供商分组的模型列表
"""
available_models = {}
for provider, provider_config in self.config.items():
if 'models' in provider_config:
models = list(provider_config['models'].keys())
available_models[provider] = models
return available_models
def validate_model_exists(self, provider: str, model_name: str) -> bool:
"""
验证模型是否存在
Args:
provider: 提供商名称
model_name: 模型名称
Returns:
模型是否存在
"""
try:
get_model_config(self.config, provider, model_name)
return True
except ValueError:
return False
# 便捷函数
def create_agno_model(
provider: str,
model_name: str,
config: Optional[Dict[str, Any]] = None,
**override_params
) -> Union[OpenAILike, OpenAIChat]:
"""
便捷函数创建Agno模型实例
Args:
provider: 提供商名称
model_name: 模型名称
config: 可选的配置字典
**override_params: 覆盖参数
Returns:
创建的模型实例
"""
factory = ModelFactory(config)
return factory.create_model(provider, model_name, **override_params)
def list_available_models(config: Optional[Dict[str, Any]] = None) -> Dict[str, list]:
"""
便捷函数列出可用模型
Args:
config: 可选的配置字典
Returns:
按提供商分组的模型列表
"""
factory = ModelFactory(config)
return factory.list_available_models()
# 测试函数
def test_model_creation():
"""测试模型创建功能"""
print("正在测试模型创建...")
try:
# 创建模型工厂
factory = ModelFactory()
# 列出可用模型
available_models = factory.list_available_models()
print(f"✅ 发现可用模型:")
for provider, models in available_models.items():
print(f" {provider}: {models}")
# 测试阿里云qwen-max模型创建
try:
qwen_model = factory.create_model('aliyun', 'qwen-max')
print(f"✅ qwen-max模型创建成功: {type(qwen_model)}")
except Exception as e:
print(f"❌ qwen-max模型创建失败: {e}")
# 测试DeepSeek模型创建
try:
deepseek_model = factory.create_model('deepseek', 'deepseek-v3')
print(f"✅ deepseek-v3模型创建成功: {type(deepseek_model)}")
except Exception as e:
print(f"❌ deepseek-v3模型创建失败: {e}")
# 测试参数覆盖
try:
custom_model = factory.create_model(
'aliyun',
'qwen-plus',
temperature=0.5,
max_tokens=1000
)
print(f"✅ 参数覆盖测试成功: {type(custom_model)}")
except Exception as e:
print(f"❌ 参数覆盖测试失败: {e}")
return True
except Exception as e:
print(f"❌ 模型创建测试失败: {e}")
return False
def test_convenience_functions():
"""测试便捷函数"""
print("\n正在测试便捷函数...")
try:
# 测试便捷创建函数
model = create_agno_model('aliyun', 'qwen-turbo')
print(f"✅ 便捷函数创建模型成功: {type(model)}")
# 测试列出模型函数
models = list_available_models()
total_models = sum(len(model_list) for model_list in models.values())
print(f"✅ 便捷函数列出模型成功,总计 {total_models} 个模型")
return True
except Exception as e:
print(f"❌ 便捷函数测试失败: {e}")
return False
if __name__ == "__main__":
success1 = test_model_creation()
success2 = test_convenience_functions()
if success1 and success2:
print("\n🎉 所有模型工厂测试通过!")
else:
print("\n💥 部分测试失败,请检查配置")

View File

@ -0,0 +1,434 @@
"""
SubAgent核心类
基于Agno框架的智能代理实现提供动态prompt构建JSON解析模型管理等功能
"""
import logging
from typing import Any, Dict, List, Optional, Type, Union
from agno.agent import Agent, RunResponse
from pydantic import BaseModel
from .config_loader import load_llm_config
from .model_factory import ModelFactory
from .json_processor import JSONProcessor, JSONParseError
# 设置日志
logger = logging.getLogger(__name__)
class SubAgentError(Exception):
"""SubAgent相关错误的基类"""
pass
class ConfigurationError(SubAgentError):
"""配置相关错误"""
pass
class ModelError(SubAgentError):
"""模型相关错误"""
pass
class SubAgent:
"""
SubAgent核心类
基于Agno框架构建的智能代理支持:
- 动态prompt模板构建
- 多LLM提供商支持
- JSON结构化输出
- 零容错解析
- 灵活的配置管理
"""
def __init__(
self,
provider: str,
model_name: str,
instructions: Optional[List[str]] = None,
name: Optional[str] = None,
description: Optional[str] = None,
prompt_template: Optional[str] = None,
response_model: Optional[Type[BaseModel]] = None,
config: Optional[Dict[str, Any]] = None,
**agent_kwargs
):
"""
初始化SubAgent
Args:
provider: LLM提供商名称 ( 'aliyun', 'deepseek', 'openai')
model_name: 模型名称 ( 'qwen-max', 'deepseek-v3')
instructions: 指令列表
name: Agent名称
description: Agent描述
prompt_template: 动态prompt模板
response_model: Pydantic响应模型类
config: 自定义LLM配置
**agent_kwargs: 传递给Agno Agent的额外参数
Raises:
ConfigurationError: 配置错误
ModelError: 模型创建错误
"""
# 基础属性设置
self.provider = provider
self.model_name = model_name
self.name = name or f"{provider}_{model_name}_agent"
self.description = description or f"基于{provider}{model_name}模型的智能代理"
self.instructions = instructions or []
self.prompt_template = prompt_template
self.response_model = response_model
# 初始化组件
try:
# 加载配置
self.config = config if config is not None else load_llm_config()
# 创建模型工厂和JSON处理器
self.model_factory = ModelFactory(self.config)
self.json_processor = JSONProcessor()
# 创建Agno模型
self.model = self._create_model()
# 创建Agno Agent
self.agent = self._create_agent(**agent_kwargs)
logger.info(f"SubAgent {self.name} 初始化成功")
except Exception as e:
logger.error(f"SubAgent初始化失败: {e}")
raise ConfigurationError(f"SubAgent初始化失败: {e}")
def _create_model(self):
"""创建Agno模型实例"""
try:
return self.model_factory.create_model(self.provider, self.model_name)
except Exception as e:
raise ModelError(f"模型创建失败: {e}")
def _create_agent(self, **agent_kwargs):
"""创建Agno Agent实例"""
try:
agent_params = {
"model": self.model,
"name": self.name,
"description": self.description,
"instructions": self.instructions,
"markdown": agent_kwargs.pop("markdown", True),
"debug_mode": agent_kwargs.pop("debug_mode", False),
**agent_kwargs
}
return Agent(**agent_params)
except Exception as e:
raise ConfigurationError(f"Agent创建失败: {e}")
def build_prompt(self, template_vars: Optional[Dict[str, Any]] = None) -> str:
"""
构建动态prompt
Args:
template_vars: 模板变量字典用于替换模板中的占位符
Returns:
构建完成的prompt字符串
Raises:
SubAgentError: prompt构建失败
"""
if not self.prompt_template:
raise SubAgentError("未设置prompt模板无法构建动态prompt")
if not template_vars:
template_vars = {}
try:
# 使用str.format()进行模板替换
prompt = self.prompt_template.format(**template_vars)
# 如果需要JSON输出添加JSON格式要求
if self.response_model:
json_instruction = self._build_json_instruction()
prompt = f"{prompt}\n\n{json_instruction}"
return prompt
except KeyError as e:
raise SubAgentError(f"模板变量缺失: {e}")
except Exception as e:
raise SubAgentError(f"prompt构建失败: {e}")
def _build_json_instruction(self) -> str:
"""构建JSON格式指令"""
json_instruction = """
请严格按照以下JSON格式返回结果不要添加任何额外的文字说明
"""
if self.response_model:
# 生成JSON schema示例
try:
schema = self.response_model.model_json_schema()
json_instruction += f"JSON Schema:\n{schema}\n\n"
# 生成示例
example = self._generate_model_example()
if example:
json_instruction += f"示例格式:\n{example}"
except Exception:
# 如果schema生成失败使用通用指令
json_instruction += "请返回有效的JSON格式数据。"
return json_instruction
def _generate_model_example(self) -> Optional[str]:
"""生成Pydantic模型的示例JSON"""
try:
# 创建一个示例实例
field_values = {}
for field_name, field_info in self.response_model.model_fields.items():
# 根据字段类型生成示例值
field_type = field_info.annotation
field_values[field_name] = self._generate_example_value(field_type)
example_instance = self.response_model(**field_values)
return example_instance.model_dump_json(indent=2)
except Exception:
return None
def _generate_example_value(self, field_type: Type) -> Any:
"""根据字段类型生成示例值"""
if field_type == str:
return "示例文本"
elif field_type == int:
return 0
elif field_type == float:
return 0.0
elif field_type == bool:
return True
elif hasattr(field_type, '__origin__'):
origin = field_type.__origin__
if origin == list:
return []
elif origin == dict:
return {}
return "示例值"
def run(
self,
prompt: Optional[str] = None,
template_vars: Optional[Dict[str, Any]] = None,
**run_kwargs
) -> Any:
"""
执行Agent推理
Args:
prompt: 直接prompt文本如果提供将忽略template_vars
template_vars: 模板变量用于动态构建prompt
**run_kwargs: 传递给Agent.run的额外参数
Returns:
如果指定了response_model返回解析后的Pydantic模型实例
否则返回原始响应内容
Raises:
SubAgentError: 执行失败
JSONParseError: JSON解析失败
"""
try:
# 构建最终prompt
if prompt is not None:
final_prompt = prompt
else:
final_prompt = self.build_prompt(template_vars)
logger.debug(f"Agent {self.name} 执行推理prompt长度: {len(final_prompt)}")
# 执行Agent推理
response: RunResponse = self.agent.run(final_prompt, **run_kwargs)
# 获取响应内容
content = response.content
if self.response_model:
# 如果指定了响应模型进行JSON解析
return self._parse_structured_response(content)
else:
# 返回原始内容
return content
except JSONParseError:
# JSON解析错误直接抛出
raise
except Exception as e:
logger.error(f"Agent执行失败: {e}")
raise SubAgentError(f"Agent执行失败: {e}")
def _parse_structured_response(self, content: str) -> BaseModel:
"""解析结构化响应"""
try:
return self.json_processor.parse_json_response(content, self.response_model)
except JSONParseError as e:
logger.error(f"JSON解析失败: {e}")
logger.debug(f"响应内容: {content}")
raise
def get_model_info(self) -> Dict[str, Any]:
"""
获取模型信息
Returns:
包含模型信息的字典
"""
return {
"name": self.name,
"provider": self.provider,
"model_name": self.model_name,
"model_type": type(self.model).__name__,
"has_prompt_template": self.prompt_template is not None,
"has_response_model": self.response_model is not None,
"instructions_count": len(self.instructions),
}
def update_instructions(self, instructions: List[str]) -> None:
"""
更新指令列表
Args:
instructions: 新的指令列表
"""
self.instructions = instructions
# 注意这里不会更新已创建的Agent实例如需更新需要重新创建Agent
logger.info(f"Agent {self.name} 指令已更新")
def update_prompt_template(self, template: str) -> None:
"""
更新prompt模板
Args:
template: 新的prompt模板
"""
self.prompt_template = template
logger.info(f"Agent {self.name} prompt模板已更新")
def __str__(self) -> str:
return f"SubAgent({self.name}, {self.provider}.{self.model_name})"
def __repr__(self) -> str:
return self.__str__()
# 便捷函数
def create_json_agent(
provider: str,
model_name: str,
name: str,
prompt_template: str,
response_model: Union[str, Type[BaseModel]],
instructions: Optional[List[str]] = None,
**kwargs
) -> SubAgent:
"""
便捷函数创建支持JSON输出的SubAgent
Args:
provider: 提供商名称
model_name: 模型名称
name: Agent名称
prompt_template: prompt模板
response_model: 响应模型类或模块路径字符串
instructions: 指令列表
**kwargs: 额外参数
Returns:
配置好的SubAgent实例
"""
# 如果response_model是字符串尝试导入
if isinstance(response_model, str):
try:
# 解析模块路径和类名
if '.' in response_model:
module_path, class_name = response_model.rsplit('.', 1)
else:
# 如果没有模块路径,假设在当前模块
module_path = '__main__'
class_name = response_model
# 动态导入
import importlib
module = importlib.import_module(module_path)
response_model = getattr(module, class_name)
except Exception as e:
raise ValueError(f"无法导入响应模型 {response_model}: {e}")
return SubAgent(
provider=provider,
model_name=model_name,
name=name,
prompt_template=prompt_template,
response_model=response_model,
instructions=instructions,
**kwargs
)
# 测试函数
def test_subagent():
"""测试SubAgent功能"""
print("正在测试SubAgent...")
try:
# 创建基础SubAgent
agent = SubAgent(
provider="aliyun",
model_name="qwen-turbo",
name="test_agent",
instructions=["你是一个测试助手", "请简洁回答问题"]
)
print(f"✅ SubAgent创建成功: {agent}")
print(f" 模型信息: {agent.get_model_info()}")
# 测试基础对话
try:
response = agent.run("请简单介绍一下Python语言")
print(f"✅ 基础对话测试成功,响应长度: {len(str(response))}字符")
except Exception as e:
print(f"❌ 基础对话测试失败: {e}")
# 测试动态prompt构建
agent.update_prompt_template("请回答关于{topic}的问题:{question}")
try:
prompt = agent.build_prompt({
"topic": "编程",
"question": "什么是函数?"
})
print(f"✅ 动态prompt构建成功长度: {len(prompt)}字符")
except Exception as e:
print(f"❌ 动态prompt构建失败: {e}")
return True
except Exception as e:
print(f"❌ SubAgent测试失败: {e}")
return False
if __name__ == "__main__":
test_subagent()

1
src/config/.env Normal file
View File

@ -0,0 +1 @@
DASHSCOPE_API_KEY=sk-5c7f9dc33e0a43738d415a0432452b93

View File

@ -0,0 +1,93 @@
# LLM配置文件
# 定义所有Agent可用的LLM模型配置
# 阿里云通义千问配置
aliyun:
base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
api_key: ${DASHSCOPE_API_KEY} # 从环境变量读取
models:
qwen-max:
class: "OpenAILike"
params:
id: "qwen-max"
temperature: 0.7
max_tokens: 3000
qwen-plus:
class: "OpenAILike"
params:
id: "qwen-plus"
temperature: 0.7
max_tokens: 2000
qwen-turbo:
class: "OpenAILike"
params:
id: "qwen-turbo"
temperature: 0.7
max_tokens: 1500
# DeepSeek配置
deepseek:
base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
api_key: ${DASHSCOPE_API_KEY}
models:
deepseek-v3:
class: "OpenAILike"
params:
id: "deepseek-v3"
temperature: 0.7
max_tokens: 4000
deepseek-r1:
class: "OpenAILike"
params:
id: "deepseek-r1"
temperature: 0.5
max_tokens: 8000
# DeepSeek R1特有的推理模式
reasoning_effort: "high"
# OpenAI配置备用
openai:
base_url: "https://api.openai.com/v1"
api_key: ${OPENAI_API_KEY}
models:
gpt-4:
class: "OpenAIChat"
params:
id: "gpt-4"
temperature: 0.7
max_tokens: 4000
gpt-3.5-turbo:
class: "OpenAIChat"
params:
id: "gpt-3.5-turbo"
temperature: 0.7
max_tokens: 2000
# 本地Ollama配置开发测试用
ollama:
host: "127.0.0.1"
port: 11434
models:
qwen2.5:
class: "Ollama"
params:
id: "qwen2.5:latest"
temperature: 0.7
max_tokens: 2000
llama3:
class: "Ollama"
params:
id: "llama3:latest"
temperature: 0.7
max_tokens: 2000
vllm:
base_url: "http://100.82.33.121:11001/v1"
models:
gpt-oss:
class: "OpenAIChat"
params:
id: "gpt-oss-20b"
temperature: 0.7
max_tokens: 2000

View File

@ -8,19 +8,20 @@ import requests
import xml.etree.ElementTree as ET
import logging
import time
import re
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Optional
from pathlib import Path
from src.utils.csv_utils import write_dict_to_csv
from src.utils.csv_utils import write_dict_to_csv, read_csv_to_dict
class PaperCrawler:
"""论文爬取类 - 用于从ArXiv和MedRxiv爬取MIMIC 4相关论文"""
def __init__(self, websites: List[str], parallel: int = 20,
arxiv_max_results: int = 200, medrxiv_days_range: int = 730):
arxiv_max_results: int = 2000, medrxiv_days_range: int = 1825):
"""初始化爬虫配置
Args:
@ -34,12 +35,11 @@ class PaperCrawler:
self.arxiv_max_results = arxiv_max_results # ArXiv最大爬取数量
self.medrxiv_days_range = medrxiv_days_range # MedRxiv爬取时间范围(天)
# MIMIC关键词配置
# MIMIC-IV精确关键词配置 - 只包含明确引用MIMIC-IV数据集的论文
self.mimic_keywords = [
"MIMIC-IV", "MIMIC 4", "MIMIC IV",
"Medical Information Mart",
"intensive care", "ICU database",
"critical care database", "electronic health record"
"MIMIC-IV", "MIMIC 4", "MIMIC IV", "MIMIC-4",
"Medical Information Mart Intensive Care IV",
"MIMIC-IV dataset", "MIMIC-IV database", "MIMIC"
]
# HTTP会话配置
@ -104,8 +104,8 @@ class PaperCrawler:
papers = []
try:
# 构建关键词搜索查询
keywords_query = " OR ".join([f'ti:"{kw}"' for kw in self.mimic_keywords[:3]])
# 构建MIMIC-IV精确关键词搜索查询 - 标题和摘要都使用所有关键词
keywords_query = " OR ".join([f'ti:"{kw}"' for kw in self.mimic_keywords])
abstract_query = " OR ".join([f'abs:"{kw}"' for kw in self.mimic_keywords])
search_query = f"({keywords_query}) OR ({abstract_query})"
@ -239,7 +239,7 @@ class PaperCrawler:
return {
'title': paper_data.get('title', '').strip(),
'authors': ', '.join(paper_data.get('authors', [])),
'abstract': paper_data.get('summary', '').strip(),
'abstract': paper_data.get('summary', '').strip().replace('\n', ' ').replace('\r', ' '),
'doi': paper_data.get('doi', ''),
'published_date': paper_data.get('published', '').split('T')[0] if 'T' in paper_data.get('published', '') else paper_data.get('published', ''),
'url': paper_data.get('link', ''),
@ -250,7 +250,7 @@ class PaperCrawler:
return {
'title': paper_data.get('title', '').strip(),
'authors': paper_data.get('authors', ''),
'abstract': paper_data.get('abstract', '').strip(),
'abstract': paper_data.get('abstract', '').strip().replace('\n', ' ').replace('\r', ' '),
'doi': paper_data.get('doi', ''),
'published_date': paper_data.get('date', ''),
'url': f"https://doi.org/{paper_data.get('doi', '')}" if paper_data.get('doi') else '',
@ -418,4 +418,404 @@ class PaperCrawler:
except Exception as e:
logging.error(f"保存CSV文件时出错: {e}")
raise
raise
def download_pdfs_from_csv(self, csv_file_path: str) -> Dict[str, int]:
"""从CSV文件下载论文PDF
Args:
csv_file_path (str): 包含论文信息的CSV文件路径
Returns:
Dict[str, int]: 下载统计信息 {'success': 成功数, 'failed': 失败数, 'total': 总数}
Raises:
FileNotFoundError: CSV文件不存在
ValueError: CSV文件格式错误
"""
try:
# 读取CSV文件中的论文信息
papers_data = self._read_papers_csv(csv_file_path)
if not papers_data:
logging.warning("CSV文件中没有论文数据")
return {'success': 0, 'failed': 0, 'total': 0}
# 准备PDF存储目录
pdf_dir = self._prepare_pdf_storage()
# 初始化统计
total_papers = len(papers_data)
success_count = 0
failed_count = 0
failed_papers = []
logging.info(f"开始并发下载 {total_papers} 篇论文的PDF文件并发数: {self.parallel}")
# 使用并发执行器下载PDF
with ThreadPoolExecutor(max_workers=self.parallel) as executor:
# 提交所有下载任务
future_to_paper = {
executor.submit(self._download_single_pdf, paper_data, pdf_dir): paper_data
for paper_data in papers_data
}
# 处理完成的任务,实时显示进度
completed_count = 0
for future in as_completed(future_to_paper):
paper_data = future_to_paper[future]
title = paper_data.get('title', 'Unknown')[:50] + '...' if len(paper_data.get('title', '')) > 50 else paper_data.get('title', 'Unknown')
try:
success = future.result()
completed_count += 1
if success:
success_count += 1
status = ""
else:
failed_count += 1
failed_papers.append({
'title': paper_data.get('title', ''),
'source': paper_data.get('source', ''),
'url': paper_data.get('url', ''),
'doi': paper_data.get('doi', '')
})
status = ""
# 显示进度
progress = (completed_count / total_papers) * 100
print(f"\r[{completed_count:3d}/{total_papers}] {progress:5.1f}% {status} {title}", end='', flush=True)
except Exception as e:
failed_count += 1
completed_count += 1
failed_papers.append({
'title': paper_data.get('title', ''),
'source': paper_data.get('source', ''),
'error': str(e)
})
progress = (completed_count / total_papers) * 100
print(f"\r[{completed_count:3d}/{total_papers}] {progress:5.1f}% ✗ {title} (Error: {str(e)[:30]})", end='', flush=True)
print() # 换行
# 记录失败详情
if failed_papers:
logging.warning(f"以下 {len(failed_papers)} 篇论文下载失败:")
for paper in failed_papers:
logging.warning(f" - {paper.get('title', 'Unknown')} [{paper.get('source', 'unknown')}]")
if 'error' in paper:
logging.warning(f" 错误: {paper['error']}")
# 生成下载报告
stats = {
'success': success_count,
'failed': failed_count,
'total': total_papers
}
logging.info(f"PDF下载完成! 成功: {success_count}/{total_papers} ({success_count/total_papers*100:.1f}%)")
if failed_count > 0:
logging.warning(f"失败: {failed_count}/{total_papers} ({failed_count/total_papers*100:.1f}%)")
return stats
except Exception as e:
logging.error(f"下载PDF文件时发生错误: {e}")
raise
def _prepare_pdf_storage(self) -> Path:
"""准备PDF存储目录
Returns:
Path: PDF存储目录路径
"""
pdf_dir = Path("dataset") / "pdfs"
pdf_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"PDF存储目录已准备: {pdf_dir}")
return pdf_dir
def _read_papers_csv(self, csv_file_path: str) -> List[Dict[str, str]]:
"""读取论文CSV文件
Args:
csv_file_path (str): CSV文件路径
Returns:
List[Dict[str, str]]: 论文数据列表
Raises:
FileNotFoundError: 文件不存在
ValueError: 文件格式错误
"""
try:
papers_data = read_csv_to_dict(csv_file_path)
# 验证必要字段
required_fields = ['title', 'url', 'source', 'doi']
if papers_data:
missing_fields = [field for field in required_fields
if field not in papers_data[0]]
if missing_fields:
raise ValueError(f"CSV文件缺少必要字段: {missing_fields}")
logging.info(f"成功读取CSV文件{len(papers_data)} 篇论文")
return papers_data
except Exception as e:
logging.error(f"读取CSV文件失败: {e}")
raise
def _get_pdf_url(self, paper_data: Dict[str, str]) -> Optional[str]:
"""获取论文PDF下载链接
Args:
paper_data (Dict[str, str]): 论文数据
Returns:
Optional[str]: PDF下载链接如果无法获取返回None
"""
try:
source = paper_data.get('source', '')
if not source:
logging.warning("论文缺少source字段")
return None
source = source.lower()
url = paper_data.get('url', '')
doi = paper_data.get('doi', '')
if source == 'arxiv':
return self._get_arxiv_pdf_url(url)
elif source == 'medrxiv':
return self._get_medrxiv_pdf_url(doi, url)
else:
logging.warning(f"不支持的数据源: {source}")
return None
except Exception as e:
logging.error(f"获取PDF链接失败: {e}")
return None
def _get_arxiv_pdf_url(self, url: str) -> Optional[str]:
"""获取ArXiv论文PDF链接
Args:
url (str): ArXiv论文页面URL
Returns:
Optional[str]: PDF下载链接
"""
try:
if not url:
return None
# 从URL中提取论文ID
# 格式: http://arxiv.org/abs/2301.12345 -> 2301.12345
if '/abs/' in url:
paper_id = url.split('/abs/')[-1]
pdf_url = f"http://arxiv.org/pdf/{paper_id}.pdf"
logging.debug(f"ArXiv PDF链接: {pdf_url}")
return pdf_url
else:
logging.warning(f"无法解析ArXiv URL: {url}")
return None
except Exception as e:
logging.error(f"获取ArXiv PDF链接失败: {e}")
return None
def _get_medrxiv_pdf_url(self, doi: str, url: str) -> Optional[str]:
"""获取MedRxiv论文PDF链接 - 支持多种URL格式策略
Args:
doi (str): 论文DOI
url (str): DOI链接备用
Returns:
Optional[str]: PDF下载链接
"""
try:
if not doi:
logging.warning("MedRxiv论文缺少DOI")
return None
if not doi.startswith('10.1101/'):
logging.warning(f"不支持的MedRxiv DOI格式: {doi}")
return None
# 提取DOI后缀部分
paper_part = doi.replace('10.1101/', '')
# 策略1尝试简洁版本号格式优先级最高
# 格式https://www.medrxiv.org/content/10.1101/yyyy.mm.dd.xxxxxxxvN.full.pdf
for version in ['v1', 'v2', 'v3']: # 常见版本号
simple_url = f"https://www.medrxiv.org/content/10.1101/{paper_part}{version}.full.pdf"
logging.debug(f"尝试MedRxiv简洁格式: {simple_url}")
# 这里可以添加URL可用性检查暂时先返回第一个尝试
# 实际使用时下载函数会验证URL的有效性
if version == 'v1': # 优先返回v1版本
return simple_url
# 策略2回退到早期访问格式复杂路径
# 格式https://www.medrxiv.org/content/medrxiv/early/yyyy/mm/dd/yyyy.mm.dd.xxxxxxx.full.pdf
parts = paper_part.split('.')
if len(parts) >= 4:
year = parts[0]
month = parts[1].zfill(2) # 确保两位数
day = parts[2].zfill(2) # 确保两位数
# 构造复杂格式PDF URL
complex_url = f"https://www.medrxiv.org/content/medrxiv/early/{year}/{month}/{day}/{paper_part}.full.pdf"
logging.debug(f"回退使用MedRxiv复杂格式: {complex_url}")
return complex_url
else:
logging.warning(f"无法解析MedRxiv DOI日期格式: {doi}")
return None
except Exception as e:
logging.error(f"获取MedRxiv PDF链接失败: {e}")
return None
def _download_single_pdf(self, paper_data: Dict[str, str], pdf_dir: Path) -> bool:
"""下载单个论文PDF
Args:
paper_data (Dict[str, str]): 论文数据
pdf_dir (Path): PDF存储目录
Returns:
bool: 下载是否成功
"""
try:
# 获取PDF下载链接
pdf_url = self._get_pdf_url(paper_data)
if not pdf_url:
logging.warning(f"无法获取PDF链接: {paper_data.get('title', 'Unknown')}")
return False
# 生成安全的文件名
filename = self._generate_safe_filename(paper_data)
file_path = pdf_dir / filename
# 如果文件已存在且有效,跳过下载
if file_path.exists() and self._validate_pdf_file(file_path):
logging.info(f"PDF文件已存在且有效跳过下载: {filename}")
return True
# 下载PDF文件最多重试3次
for attempt in range(3):
try:
response = self._make_request_with_retry(pdf_url, max_retries=1)
if response.status_code == 200:
# 写入文件
with open(file_path, 'wb') as f:
f.write(response.content)
# 验证PDF完整性
if self._validate_pdf_file(file_path):
logging.info(f"成功下载PDF: {filename}")
return True
else:
logging.warning(f"PDF文件损坏删除并重试: {filename}")
file_path.unlink(missing_ok=True)
else:
logging.warning(f"PDF下载失败状态码 {response.status_code}: {pdf_url}")
except Exception as e:
logging.warning(f"PDF下载第{attempt + 1}次尝试失败: {e}")
# 重试前等待
if attempt < 2:
time.sleep(2 ** attempt)
logging.error(f"PDF下载最终失败: {paper_data.get('title', 'Unknown')}")
return False
except Exception as e:
logging.error(f"下载PDF时发生错误: {e}")
return False
def _validate_pdf_file(self, file_path: Path) -> bool:
"""验证PDF文件完整性
Args:
file_path (Path): PDF文件路径
Returns:
bool: PDF文件是否有效
"""
try:
if not file_path.exists():
return False
# 检查文件大小
if file_path.stat().st_size < 1024: # 至少1KB
logging.warning(f"PDF文件太小可能无效: {file_path.name}")
return False
# 检查PDF文件头和结构
with open(file_path, 'rb') as f:
# 读取文件头
header = f.read(8)
if not header.startswith(b'%PDF-'):
logging.warning(f"文件不是有效的PDF格式: {file_path.name}")
return False
# 检查文件尾部读取最后1KB
f.seek(-min(1024, file_path.stat().st_size), 2)
trailer = f.read()
if b'%%EOF' not in trailer and b'endobj' not in trailer:
logging.warning(f"PDF文件可能不完整: {file_path.name}")
return False
logging.debug(f"PDF文件验证通过: {file_path.name}")
return True
except Exception as e:
logging.error(f"验证PDF文件时发生错误: {e}")
return False
def _generate_safe_filename(self, paper_data: Dict[str, str]) -> str:
"""生成安全的PDF文件名
Args:
paper_data (Dict[str, str]): 论文数据
Returns:
str: 安全的文件名
"""
try:
source = paper_data.get('source', 'unknown').lower()
title = paper_data.get('title', 'untitled')
url = paper_data.get('url', '')
doi = paper_data.get('doi', '')
# 提取paper_id
paper_id = 'unknown'
if source == 'arxiv' and '/abs/' in url:
paper_id = url.split('/abs/')[-1]
elif source == 'medrxiv' and doi:
paper_id = doi.split('/')[-1] if '/' in doi else doi
# 清理标题,保留主要信息
safe_title = re.sub(r'[^\w\s-]', '', title) # 移除特殊字符
safe_title = re.sub(r'\s+', '_', safe_title.strip()) # 空格转下划线
safe_title = safe_title.lower()[:50] # 限制长度并转小写
# 构造文件名: source_paperid_title.pdf
filename = f"{source}_{paper_id}_{safe_title}.pdf"
# 确保文件名长度合理
if len(filename) > 255: # 大多数文件系统的限制
filename = f"{source}_{paper_id}.pdf"
return filename
except Exception as e:
logging.error(f"生成文件名时发生错误: {e}")
# 回退方案
timestamp = int(time.time())
return f"paper_{timestamp}.pdf"

1073
src/extractor.py Normal file

File diff suppressed because it is too large Load Diff

741
src/parse.py Normal file
View File

@ -0,0 +1,741 @@
"""PDF解析模块
该模块提供PDFParser类用于将PDF文件通过OCR API转换为Markdown格式
支持并发处理进度显示错误处理等功能
"""
import requests
import logging
import os
import time
import zipfile
import tempfile
import re
import json
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Optional, Tuple
class PDFParser:
"""PDF解析类 - 用于将PDF文件转换为Markdown格式并按任务类型筛选
支持的任务类型
- prediction: 预测任务 (PRED_)
- classification: 分类任务 (CLAS_)
- time_series: 时间序列分析 (TIME_)
- correlation: 关联性分析 (CORR_)
"""
def __init__(self, pdf_dir: str = "dataset/pdfs", parallel: int = 3,
markdown_dir: str = "dataset/markdowns"):
"""初始化解析器配置
Args:
pdf_dir (str): PDF文件目录默认dataset/pdfs
parallel (int): 并发处理数默认3降低并发以避免服务器过载
markdown_dir (str): Markdown输出目录默认dataset/markdowns
"""
self.pdf_dir = Path(pdf_dir)
self.parallel = parallel
self.markdown_dir = Path(markdown_dir)
# OCR API配置
self.ocr_api_url = "http://100.106.4.14:7861/parse"
# AI模型API配置用于四类任务识别prediction/classification/time_series/correlation
self.ai_api_url = "http://100.82.33.121:11001/v1/chat/completions"
self.ai_model = "gpt-oss-20b"
# MIMIC-IV关键词配置用于内容筛选
self.mimic_keywords = [
"MIMIC-IV", "MIMIC 4", "MIMIC IV", "MIMIC-4",
"Medical Information Mart Intensive Care IV",
"MIMIC-IV dataset", "MIMIC-IV database"
]
# 任务类型到前缀的映射配置
self.task_type_prefixes = {
"prediction": "PRED_",
"classification": "CLAS_",
"time_series": "TIME_",
"correlation": "CORR_",
"none": None # 不符合任何类型,不标记
}
# HTTP会话配置增加连接池大小和超时时间
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
self.session = requests.Session()
self.session.headers.update({
'User-Agent': 'MedResearcher-PDFParser/1.0'
})
# 配置连接池适配器(增加连接池大小)
adapter = HTTPAdapter(
pool_connections=10, # 连接池数量
pool_maxsize=20, # 最大连接数
max_retries=0 # 禁用自动重试,使用自定义重试逻辑
)
self.session.mount('http://', adapter)
self.session.mount('https://', adapter)
# 配置日志
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s')
def _scan_pdf_files(self) -> List[Path]:
"""扫描PDF文件目录获取所有PDF文件
Returns:
List[Path]: PDF文件路径列表
Raises:
FileNotFoundError: PDF目录不存在
"""
if not self.pdf_dir.exists():
raise FileNotFoundError(f"PDF目录不存在: {self.pdf_dir}")
pdf_files = []
for pdf_file in self.pdf_dir.glob("*.pdf"):
if pdf_file.is_file():
pdf_files.append(pdf_file)
logging.info(f"发现 {len(pdf_files)} 个PDF文件待处理")
return pdf_files
def _check_mimic_keywords(self, output_subdir: Path) -> bool:
"""检查Markdown文件是否包含MIMIC-IV关键词
Args:
output_subdir (Path): 包含Markdown文件的输出子目录
Returns:
bool: 是否包含MIMIC-IV关键词
"""
try:
# 查找所有.md文件
md_files = list(output_subdir.glob("*.md"))
if not md_files:
logging.warning(f"未找到Markdown文件进行MIMIC关键词检查: {output_subdir}")
return False
# 检查每个Markdown文件的内容
for md_file in md_files:
try:
with open(md_file, 'r', encoding='utf-8') as f:
content = f.read().lower() # 转换为小写进行不区分大小写匹配
# 检查是否包含任何MIMIC-IV关键词
for keyword in self.mimic_keywords:
if keyword.lower() in content:
logging.info(f"发现MIMIC-IV关键词 '{keyword}' 在文件 {md_file.name}")
return True
except Exception as e:
logging.error(f"读取Markdown文件时发生错误: {md_file.name} - {e}")
continue
logging.info(f"未发现MIMIC-IV关键词: {output_subdir.name}")
return False
except Exception as e:
logging.error(f"检查MIMIC关键词时发生错误: {output_subdir} - {e}")
return False
def _extract_introduction(self, output_subdir: Path) -> Optional[str]:
"""从Markdown文件中提取Introduction部分
Args:
output_subdir (Path): 包含Markdown文件的输出子目录
Returns:
Optional[str]: 提取的Introduction内容失败时返回None
"""
try:
# 查找所有.md文件
md_files = list(output_subdir.glob("*.md"))
if not md_files:
logging.warning(f"未找到Markdown文件进行Introduction提取: {output_subdir}")
return None
# 通常使用第一个md文件
md_file = md_files[0]
try:
with open(md_file, 'r', encoding='utf-8') as f:
content = f.read()
# 使用正则表达式提取Introduction部分
# 匹配各种可能的Introduction标题格式
patterns = [
r'(?i)#\s*Introduction\s*\n(.*?)(?=\n#|\n\n#|$)',
r'(?i)##\s*Introduction\s*\n(.*?)(?=\n##|\n\n##|$)',
r'(?i)###\s*Introduction\s*\n(.*?)(?=\n###|\n\n###|$)',
r'(?i)\*\*Introduction\*\*\s*\n(.*?)(?=\n\*\*|\n\n\*\*|$)',
r'(?i)Introduction\s*\n(.*?)(?=\n[A-Z][a-z]+\s*\n|$)'
]
for pattern in patterns:
match = re.search(pattern, content, re.DOTALL)
if match:
introduction = match.group(1).strip()
if len(introduction) > 100: # 确保有足够的内容进行分析
logging.info(f"成功提取Introduction部分 ({len(introduction)} 字符): {md_file.name}")
return introduction
# 如果没有明确的Introduction标题尝试提取前几段作为近似的introduction
paragraphs = content.split('\n\n')
introduction_candidates = []
for para in paragraphs[:5]: # 取前5段
para = para.strip()
if len(para) > 50 and not para.startswith('#'): # 过滤掉标题和过短段落
introduction_candidates.append(para)
if introduction_candidates:
introduction = '\n\n'.join(introduction_candidates[:3]) # 最多取前3段
if len(introduction) > 200:
logging.info(f"提取近似Introduction部分 ({len(introduction)} 字符): {md_file.name}")
return introduction
logging.warning(f"未能提取到有效的Introduction内容: {md_file.name}")
return None
except Exception as e:
logging.error(f"读取Markdown文件时发生错误: {md_file.name} - {e}")
return None
except Exception as e:
logging.error(f"提取Introduction时发生错误: {output_subdir} - {e}")
return None
def _analyze_research_task(self, introduction: str) -> str:
"""使用AI模型分析论文的研究任务类型
Args:
introduction (str): 论文的Introduction内容
Returns:
str: 任务类型 ('prediction', 'classification', 'time_series', 'correlation', 'none')
"""
try:
# 构造AI分析的提示词
system_prompt = """你是一个医学研究专家。请分析给定的论文Introduction部分判断该研究属于以下哪种任务类型
1. prediction - 预测任务预测未来事件结局或数值如死亡率预测住院时长预测疾病进展预测
2. classification - 分类任务将患者或病例分类到不同类别如疾病诊断分类风险等级分类药物反应分类
3. time_series - 时间序列分析分析随时间变化的医疗数据如生命体征趋势分析病情演进分析纵向队列研究
4. correlation - 关联性分析研究变量间的关系或关联如痾病与人口特征关系药物与副作用关联风险因素识别
5. none - 不属于以上任何类型
请以JSON格式回答包含任务类型和置信度
{\"task_type\": \"prediction\", \"confidence\": 0.85}
task_type必须是以下选项之一predictionclassificationtime_seriescorrelationnone
confidence为0-1之间的数值表示判断的置信度
只返回JSON不要添加其他文字"""
user_prompt = f"请分析以下论文Introduction判断属于哪种任务类型\n\n{introduction[:2000]}" # 限制长度避免token过多
# 构造API请求数据
api_data = {
"model": self.ai_model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
"max_tokens": 50, # 需要返回JSON格式
"temperature": 0.1 # 降低随机性
}
# 调用AI API
response = self.session.post(
self.ai_api_url,
json=api_data,
headers={"Content-Type": "application/json"},
timeout=30
)
if response.status_code == 200:
result = response.json()
ai_response = result['choices'][0]['message']['content'].strip()
try:
# 解析JSON响应
parsed_response = json.loads(ai_response)
task_type = parsed_response.get('task_type', 'none').lower()
confidence = parsed_response.get('confidence', 0.0)
# 验证任务类型是否有效
valid_types = ['prediction', 'classification', 'time_series', 'correlation', 'none']
if task_type not in valid_types:
logging.warning(f"AI返回了无效的任务类型: {task_type},使用默认值 'none'")
task_type = "none"
confidence = 0.0
# 只接受高置信度的结果
if confidence < 0.7:
logging.info(f"AI分析置信度过低 ({confidence:.2f}),归类为 'none'")
task_type = "none"
logging.info(f"AI分析结果: 任务类型={task_type}, 置信度={confidence:.2f}")
return task_type
except json.JSONDecodeError as e:
logging.error(f"解析AI JSON响应失败: {ai_response} - 错误: {e}")
return "none"
else:
logging.error(f"AI API调用失败状态码: {response.status_code}")
return "none"
except Exception as e:
logging.error(f"AI分析研究任务时发生错误: {e}")
return "none"
def _mark_valid_folder(self, output_subdir: Path, task_type: str) -> bool:
"""为通过筛选的文件夹添加任务类型前缀标记
Args:
output_subdir (Path): 需要标记的输出子目录
task_type (str): 任务类型 ('prediction', 'classification', 'time_series', 'correlation')
Returns:
bool: 标记是否成功
"""
try:
# 获取任务类型对应的前缀
prefix = self.task_type_prefixes.get(task_type)
if not prefix:
logging.info(f"任务类型 '{task_type}' 不需要标记文件夹")
return True # 不需要标记,但认为成功
# 检查文件夹是否已经有相应的任务类型前缀
if output_subdir.name.startswith(prefix):
logging.info(f"文件夹已标记为{task_type}任务: {output_subdir.name}")
return True
# 检查是否已经有其他任务类型的前缀
for existing_type, existing_prefix in self.task_type_prefixes.items():
if existing_prefix and output_subdir.name.startswith(existing_prefix):
logging.info(f"文件夹已有{existing_type}任务标记,不需要重新标记: {output_subdir.name}")
return True
# 生成新的文件夹名
new_folder_name = prefix + output_subdir.name
new_folder_path = output_subdir.parent / new_folder_name
# 重命名文件夹
output_subdir.rename(new_folder_path)
logging.info(f"文件夹标记成功: {output_subdir.name} -> {new_folder_name} (任务类型: {task_type})")
return True
except Exception as e:
logging.error(f"标记文件夹时发生错误: {output_subdir} - {e}")
return False
def _prepare_output_dir(self) -> Path:
"""准备Markdown输出目录
Returns:
Path: Markdown输出目录路径
"""
self.markdown_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Markdown输出目录已准备: {self.markdown_dir}")
return self.markdown_dir
def _call_ocr_api(self, pdf_file: Path) -> Optional[Dict]:
"""调用OCR API解析PDF文件
Args:
pdf_file (Path): PDF文件路径
Returns:
Optional[Dict]: API响应数据失败时返回None
"""
try:
with open(pdf_file, 'rb') as f:
files = {
'file': (pdf_file.name, f, 'application/pdf')
}
response = self._make_request_with_retry(
self.ocr_api_url,
files=files,
timeout=1800 # 增加到3分钟匹配服务器处理时间
)
if response.status_code == 200:
response_data = response.json()
if response_data.get('success', False):
logging.debug(f"OCR API调用成功: {pdf_file.name}")
return response_data
else:
logging.warning(f"OCR API处理失败: {pdf_file.name} - {response_data.get('message', 'Unknown error')}")
return None
else:
logging.error(f"OCR API请求失败状态码: {response.status_code} - {pdf_file.name}")
return None
except Exception as e:
logging.error(f"调用OCR API时发生错误: {pdf_file.name} - {e}")
return None
def _download_and_extract_zip(self, download_url: str, pdf_file: Path) -> bool:
"""从API响应中下载ZIP文件并解压到子文件夹
Args:
download_url (str): 完整的下载URL
pdf_file (Path): 原始PDF文件路径用于生成输出文件夹名
Returns:
bool: 下载和解压是否成功
"""
try:
# 下载ZIP文件
response = self._make_request_with_retry(download_url, timeout=60)
if response.status_code != 200:
logging.error(f"下载ZIP失败状态码: {response.status_code} - {pdf_file.name}")
return False
# 创建以PDF文件名命名的输出子文件夹
output_subdir = self.markdown_dir / pdf_file.stem
output_subdir.mkdir(parents=True, exist_ok=True)
# 使用临时文件保存ZIP内容
with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as temp_zip:
temp_zip.write(response.content)
temp_zip_path = temp_zip.name
try:
# 解压ZIP文件到输出子文件夹
with zipfile.ZipFile(temp_zip_path, 'r') as zip_ref:
zip_ref.extractall(output_subdir)
logging.debug(f"ZIP文件解压成功: {pdf_file.name} -> {output_subdir}")
# 清洗解压后的Markdown文件
if not self._clean_markdown_files(output_subdir):
logging.warning(f"Markdown文件清洗失败但解压成功: {pdf_file.name}")
return True
finally:
# 清理临时ZIP文件
os.unlink(temp_zip_path)
except zipfile.BadZipFile as e:
logging.error(f"ZIP文件损坏: {pdf_file.name} - {e}")
return False
except Exception as e:
logging.error(f"下载或解压ZIP时发生错误: {pdf_file.name} - {e}")
return False
def _clean_markdown_files(self, output_subdir: Path) -> bool:
"""清洗输出目录中的Markdown文件去除数字编号和空行
Args:
output_subdir (Path): 包含Markdown文件的输出子目录
Returns:
bool: 清洗是否成功
"""
try:
# 查找所有.md文件
md_files = list(output_subdir.glob("*.md"))
if not md_files:
logging.debug(f"未找到Markdown文件进行清洗: {output_subdir}")
return True
for md_file in md_files:
try:
# 读取原文件内容
with open(md_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
# 清洗每一行
cleaned_lines = []
for line in lines:
# 去除行尾换行符
line_content = line.rstrip('\n\r')
# 跳过纯数字行(如 "2", "30"
if re.match(r'^\d+$', line_content):
continue
# 跳过数字+空格行(如 "30 "
if re.match(r'^\d+\s*$', line_content):
continue
# 去除行首的数字+空格模式(如 "1 Title:" -> "Title:"
cleaned_line = re.sub(r'^\d+\s+', '', line_content)
# 如果清洗后行不为空,则保留
if cleaned_line.strip():
cleaned_lines.append(cleaned_line + '\n')
else:
# 保留空行以维护文档结构
cleaned_lines.append('\n')
# 写回清洗后的内容
with open(md_file, 'w', encoding='utf-8') as f:
f.writelines(cleaned_lines)
logging.debug(f"Markdown文件清洗完成: {md_file.name}")
except Exception as e:
logging.error(f"清洗Markdown文件时发生错误: {md_file.name} - {e}")
return False
logging.info(f"成功清洗 {len(md_files)} 个Markdown文件: {output_subdir}")
return True
except Exception as e:
logging.error(f"清洗Markdown文件时发生错误: {output_subdir} - {e}")
return False
def _process_single_pdf(self, pdf_file: Path) -> bool:
"""处理单个PDF文件的完整流程
Args:
pdf_file (Path): PDF文件路径
Returns:
bool: 处理是否成功
"""
try:
# 检查PDF文件是否存在且有效
if not pdf_file.exists() or pdf_file.stat().st_size == 0:
logging.warning(f"PDF文件不存在或为空: {pdf_file}")
return False
# 检查是否已存在对应的输出子文件夹
output_subdir = self.markdown_dir / pdf_file.stem
if output_subdir.exists() and any(output_subdir.iterdir()):
logging.info(f"输出文件夹已存在且非空,跳过处理: {pdf_file.stem}")
return True
# 调用OCR API
api_response = self._call_ocr_api(pdf_file)
if not api_response:
return False
# 获取下载URL并拼接完整地址
download_url = api_response.get('download_url')
if not download_url:
logging.error(f"API响应中缺少下载URL: {pdf_file.name}")
return False
# 拼接完整的下载URL
full_download_url = f"http://100.106.4.14:7861{download_url}"
logging.debug(f"完整下载URL: {full_download_url}")
# 下载并解压ZIP文件
success = self._download_and_extract_zip(full_download_url, pdf_file)
if not success:
return False
# 获取解压后的文件夹路径
output_subdir = self.markdown_dir / pdf_file.stem
# 第一层筛选检查MIMIC-IV关键词
logging.info(f"开始MIMIC-IV关键词筛选: {pdf_file.stem}")
if not self._check_mimic_keywords(output_subdir):
logging.info(f"未通过MIMIC-IV关键词筛选跳过: {pdf_file.stem}")
return True # 处理成功但未通过筛选
# 第二层筛选AI分析研究任务
logging.info(f"开始AI研究任务分析: {pdf_file.stem}")
introduction = self._extract_introduction(output_subdir)
if not introduction:
logging.warning(f"无法提取Introduction跳过AI分析: {pdf_file.stem}")
return True # 处理成功但无法进行任务分析
task_type = self._analyze_research_task(introduction)
if task_type == "none":
logging.info(f"未通过研究任务筛选 (task_type=none),跳过: {pdf_file.stem}")
return True # 处理成功但未通过筛选
# 两层筛选都通过,根据任务类型标记文件夹
logging.info(f"通过所有筛选,标记为{task_type}任务论文: {pdf_file.stem}")
if self._mark_valid_folder(output_subdir, task_type):
logging.info(f"论文筛选完成,已标记为{task_type}任务: {pdf_file.stem}")
else:
logging.warning(f"文件夹标记失败: {pdf_file.stem}")
return True
except Exception as e:
logging.error(f"处理PDF文件时发生错误: {pdf_file.name} - {e}")
return False
def _make_request_with_retry(self, url: str, files: Optional[Dict] = None,
max_retries: int = 5, timeout: int = 180) -> requests.Response:
"""带智能重试策略的HTTP请求
Args:
url (str): 请求URL
files (Optional[Dict]): 文件数据用于POST请求
max_retries (int): 最大重试次数增加到5次
timeout (int): 请求超时时间
Returns:
requests.Response: HTTP响应
Raises:
requests.RequestException: 当所有重试都失败时抛出
"""
for attempt in range(max_retries):
try:
if files:
response = self.session.post(url, files=files, timeout=timeout)
else:
response = self.session.get(url, timeout=timeout)
# 检查响应状态针对500错误进行重试
if response.status_code == 500:
if attempt == max_retries - 1:
logging.error(f"服务器内部错误,已达到最大重试次数: HTTP {response.status_code}")
return response # 返回错误响应而不是抛出异常
# 500错误使用较长的等待时间
wait_time = min(30, 10 + (attempt * 5)) # 10s, 15s, 20s, 25s, 30s
logging.warning(f"服务器内部错误,{wait_time}秒后重试 (第{attempt + 1}次)")
time.sleep(wait_time)
continue
return response
except requests.exceptions.Timeout as e:
if attempt == max_retries - 1:
logging.error(f"请求超时,已达到最大重试次数: {e}")
raise
# 超时错误使用较短的等待时间
wait_time = min(15, 5 + (attempt * 2)) # 5s, 7s, 9s, 11s, 13s
logging.warning(f"请求超时,{wait_time}秒后重试 (第{attempt + 1}次): {e}")
time.sleep(wait_time)
except requests.exceptions.ConnectionError as e:
if attempt == max_retries - 1:
logging.error(f"连接错误,已达到最大重试次数: {e}")
raise
# 连接错误使用指数退避
wait_time = min(60, 5 * (2 ** attempt)) # 5s, 10s, 20s, 40s, 60s
logging.warning(f"连接错误,{wait_time}秒后重试 (第{attempt + 1}次): {e}")
time.sleep(wait_time)
except requests.RequestException as e:
if attempt == max_retries - 1:
logging.error(f"请求失败,已达到最大重试次数: {e}")
raise
# 其他错误使用标准指数退避
wait_time = min(30, 3 * (2 ** attempt)) # 3s, 6s, 12s, 24s, 30s
logging.warning(f"请求失败,{wait_time}秒后重试 (第{attempt + 1}次): {e}")
time.sleep(wait_time)
def parse_all_pdfs(self) -> Dict[str, int]:
"""批量处理所有PDF文件转换为Markdown格式
Returns:
Dict[str, int]: 处理统计信息 {'success': 成功数, 'failed': 失败数, 'total': 总数}
Raises:
FileNotFoundError: PDF目录不存在
"""
try:
# 扫描PDF文件
pdf_files = self._scan_pdf_files()
if not pdf_files:
logging.warning("未找到PDF文件")
return {'success': 0, 'failed': 0, 'total': 0}
# 准备输出目录
self._prepare_output_dir()
# 初始化统计
total_files = len(pdf_files)
success_count = 0
failed_count = 0
failed_files = []
logging.info(f"开始并发处理 {total_files} 个PDF文件")
logging.info(f"并发数: {self.parallel} (降低并发数以避免服务器过载)")
logging.info(f"请求超时: 1800秒 (适配服务器处理时间)")
logging.info(f"重试次数: 5次 (智能重试策略)")
# 使用并发执行器处理PDF
with ThreadPoolExecutor(max_workers=self.parallel) as executor:
# 提交所有处理任务
future_to_pdf = {
executor.submit(self._process_single_pdf, pdf_file): pdf_file
for pdf_file in pdf_files
}
# 处理完成的任务,实时显示进度
completed_count = 0
for future in as_completed(future_to_pdf):
pdf_file = future_to_pdf[future]
filename = pdf_file.name[:50] + '...' if len(pdf_file.name) > 50 else pdf_file.name
try:
success = future.result()
completed_count += 1
if success:
success_count += 1
status = ""
else:
failed_count += 1
failed_files.append({
'filename': pdf_file.name,
'path': str(pdf_file)
})
status = ""
# 显示进度
progress = (completed_count / total_files) * 100
print(f"\r[{completed_count:3d}/{total_files}] {progress:5.1f}% {status} {filename}", end='', flush=True)
except Exception as e:
failed_count += 1
completed_count += 1
failed_files.append({
'filename': pdf_file.name,
'path': str(pdf_file),
'error': str(e)
})
progress = (completed_count / total_files) * 100
print(f"\r[{completed_count:3d}/{total_files}] {progress:5.1f}% ✗ {filename} (Error: {str(e)[:30]})", end='', flush=True)
print() # 换行
# 记录失败详情
if failed_files:
logging.warning(f"以下 {len(failed_files)} 个PDF文件处理失败:")
for file_info in failed_files:
logging.warning(f" - {file_info['filename']}")
if 'error' in file_info:
logging.warning(f" 错误: {file_info['error']}")
# 生成处理报告
stats = {
'success': success_count,
'failed': failed_count,
'total': total_files
}
logging.info(f"PDF解析完成! 成功: {success_count}/{total_files} ({success_count/total_files*100:.1f}%)")
if failed_count > 0:
logging.warning(f"失败: {failed_count}/{total_files} ({failed_count/total_files*100:.1f}%)")
return stats
except Exception as e:
logging.error(f"批量处理PDF文件时发生错误: {e}")
raise

1022
uv.lock generated

File diff suppressed because it is too large Load Diff