feat: improve parallel processing and error-retry handling in the information extraction system

- info_extractor.py: add a CLI parameter for the number of parallel document-processing worker threads
- papers_crawler.py: update default argument values and the data file path
- src/crawler.py: make the MIMIC-IV keyword search more precise and widen the crawl range
- src/extractor.py: implement parallel document processing, extraction retries, and content preprocessing
- src/parse.py: minor refinements to the parsing logic

Key improvements:
1. Process documents in parallel across multiple worker threads to speed up extraction (sketched below)
2. Add a retry mechanism for API calls to improve stability
3. Preprocess paper content to strip irrelevant material
4. Improve progress tracking and error logging
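The heart of improvements 1 and 2 is a standard fan-out-with-retry pattern: a ThreadPoolExecutor spreads papers across doc_workers threads, and each extraction call is retried a few times before giving up. The sketch below is illustrative only, not the committed code; extract_fn stands in for the real lx.extract call, and the helper names extract_with_retry and process_papers are hypothetical.

import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

logger = logging.getLogger(__name__)

def extract_with_retry(extract_fn, content, max_retries=3, delay=1.0):
    """Call extract_fn(content); retry on exception or an empty result (sketch only)."""
    errors = []
    for attempt in range(1, max_retries + 1):
        try:
            result = extract_fn(content)
            if result:  # a non-empty extraction counts as success
                return result, errors
            errors.append(f"empty result (attempt {attempt})")
        except Exception as exc:  # e.g. API or network failure
            errors.append(f"{exc} (attempt {attempt})")
        if attempt < max_retries:
            time.sleep(delay)  # brief pause before the next attempt
    return None, errors

def process_papers(papers, extract_fn, doc_workers=4):
    """Run extract_with_retry over a {paper_id: content} dict in parallel threads."""
    results = {}
    with ThreadPoolExecutor(max_workers=doc_workers) as executor:
        futures = {
            executor.submit(extract_with_retry, extract_fn, content): paper_id
            for paper_id, content in papers.items()
        }
        for done, future in enumerate(as_completed(futures), start=1):
            paper_id = futures[future]
            result, errors = future.result()
            results[paper_id] = {"result": result, "errors": errors}
            logger.info("[%d/%d] finished %s", done, len(papers), paper_id)
    return results

In the commit itself the unit of work is one paper, each of the five extraction modules is retried up to 3 times with a 1-second pause, the worker count comes from --doc_workers, and global progress is flushed every 10 completed papers under a threading.Lock.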
iomgaa 2025-08-26 22:19:28 +08:00
parent d1f7a27b1b
commit 76c04eae4a
5 changed files with 669 additions and 108 deletions

info_extractor.py

@@ -72,6 +72,13 @@ def setup_args():
help='日志级别 (默认: INFO)'
)
parser.add_argument(
'--doc_workers',
type=int,
default=50,
help='文档并行处理工作线程数 (默认: 50)'
)
return parser.parse_args()
@@ -85,7 +92,7 @@ def main():
logging.getLogger().setLevel(getattr(logging, args.log_level))
# 初始化信息提取器
builder = MIMICLangExtractBuilder()
builder = MIMICLangExtractBuilder(doc_workers=args.doc_workers)
print(f"=== MIMIC论文信息提取工具启动 ===")
print(f"论文目录: {args.papers_dir}")
@@ -93,6 +100,7 @@ def main():
print(f"测试模式: {'' if args.test_mode else ''}")
if args.max_papers:
print(f"最大论文数: {args.max_papers}")
print(f"文档并行度: {args.doc_workers} 线程")
print(f"日志级别: {args.log_level}")
print(f"========================")

papers_crawler.py

@@ -22,7 +22,7 @@ def setup_args():
parser.add_argument(
'--paper_website',
default=["arxiv","medrxiv"],
default=["medrxiv"],
help='论文网站 (默认: arxiv,medrxiv)',
nargs='+',
choices=["arxiv","medrxiv"]
@@ -45,7 +45,7 @@ def setup_args():
parser.add_argument(
'--pdf_download_list',
type=str,
default='dataset/mimic_papers_20250823.csv',
default='dataset/mimic_papers_20250825.csv',
help='指定PDF下载目录'
)

src/crawler.py

@@ -21,7 +21,7 @@ class PaperCrawler:
"""论文爬取类 - 用于从ArXiv和MedRxiv爬取MIMIC 4相关论文"""
def __init__(self, websites: List[str], parallel: int = 20,
arxiv_max_results: int = 200, medrxiv_days_range: int = 730):
arxiv_max_results: int = 2000, medrxiv_days_range: int = 1825):
"""初始化爬虫配置
Args:
@@ -35,12 +35,11 @@ class PaperCrawler:
self.arxiv_max_results = arxiv_max_results # ArXiv最大爬取数量
self.medrxiv_days_range = medrxiv_days_range # MedRxiv爬取时间范围(天)
# MIMIC关键词配置
# MIMIC-IV精确关键词配置 - 只包含明确引用MIMIC-IV数据集的论文
self.mimic_keywords = [
"MIMIC-IV", "MIMIC 4", "MIMIC IV",
"Medical Information Mart",
"intensive care", "ICU database",
"critical care database", "electronic health record"
"MIMIC-IV", "MIMIC 4", "MIMIC IV", "MIMIC-4",
"Medical Information Mart Intensive Care IV",
"MIMIC-IV dataset", "MIMIC-IV database", "MIMIC"
]
# HTTP会话配置
@@ -105,8 +104,8 @@ class PaperCrawler:
papers = []
try:
# 构建关键词搜索查询
keywords_query = " OR ".join([f'ti:"{kw}"' for kw in self.mimic_keywords[:3]])
# 构建MIMIC-IV精确关键词搜索查询 - 标题和摘要都使用所有关键词
keywords_query = " OR ".join([f'ti:"{kw}"' for kw in self.mimic_keywords])
abstract_query = " OR ".join([f'abs:"{kw}"' for kw in self.mimic_keywords])
search_query = f"({keywords_query}) OR ({abstract_query})"

src/extractor.py

@@ -14,6 +14,8 @@ import json
from datetime import datetime
from typing import List, Dict, Any, Optional
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
# 配置日志
logger = logging.getLogger(__name__)
@@ -22,8 +24,12 @@ logger = logging.getLogger(__name__)
class MIMICLangExtractBuilder:
"""基于LangExtract的MIMIC论文信息提取器"""
def __init__(self):
"""初始化提取器配置vllm API服务"""
def __init__(self, doc_workers: int = 4):
"""初始化提取器配置vllm API服务
Args:
doc_workers: 文档并行处理工作线程数默认为4
"""
try:
# 配置LangExtract使用vllm API通过OpenAI兼容接口
import os
@@ -43,12 +49,13 @@ class MIMICLangExtractBuilder:
# LangExtract通用配置参数
self.extract_config = {
"config": self.model_config,
"max_workers": 3, # 降低并发避免过载vllm服务
"max_workers": 5, # 降低并发避免过载vllm服务
"max_char_buffer": 6000, # 适合医学论文的上下文长度
"extraction_passes": 1, # 单次提取避免过多API调用
"temperature": 0.1, # 较低温度确保一致性
"fence_output": True, # 期望代码围栏格式输出
"use_schema_constraints": False # vllm可能不支持严格schema
"use_schema_constraints": False, # vllm可能不支持严格schema
"debug": False
}
# 加载所有模块的提取配置
@@ -60,7 +67,11 @@ class MIMICLangExtractBuilder:
"environment": self._load_environment_config()
}
logger.info("MIMICLangExtractBuilder初始化成功")
# 文档并行处理配置
self.doc_workers = max(1, doc_workers) # 确保至少有1个工作线程
self.progress_lock = threading.Lock() # 保护进度保存操作的线程锁
logger.info(f"MIMICLangExtractBuilder初始化成功 (文档并行度: {self.doc_workers})")
except Exception as e:
logger.error(f"初始化失败: {e}")
@@ -70,16 +81,16 @@ class MIMICLangExtractBuilder:
"""加载数据模块的LangExtract配置"""
return {
"prompt": textwrap.dedent("""
从医学论文中提取数据处理相关的具体信息严格按照以下规则
Extract specific data processing information from medical papers. Follow these rules strictly:
1. dataset_source: 提取明确提到的数据集名称"MIMIC-IV", "Stanford EHR"
2. data_scale: 提取具体的数据规模数字"135,483 patients", "2015-2023"
3. preprocessing_step: 提取数据预处理的具体步骤描述
4. feature_type: 提取特征类型和编码方法的描述
5. inclusion_criteria: 提取患者纳入标准的确切文本
6. exclusion_criteria: 提取患者排除标准的确切文本
1. dataset_source: Extract clearly mentioned dataset names (e.g., "MIMIC-IV", "Stanford EHR")
2. data_scale: Extract specific data scale numbers (e.g., "135,483 patients", "2015-2023")
3. preprocessing_step: Extract specific descriptions of data preprocessing steps
4. feature_type: Extract descriptions of feature types and encoding methods
5. inclusion_criteria: Extract exact text of patient inclusion criteria
6. exclusion_criteria: Extract exact text of patient exclusion criteria
使用exact text进行提取不要释义为每个提取项提供有意义的属性
Use exact text for extraction, do not paraphrase. Provide meaningful attributes for each extraction.
"""),
"examples": [
lx.data.ExampleData(
@@ -165,15 +176,15 @@ class MIMICLangExtractBuilder:
"""加载模型模块的LangExtract配置"""
return {
"prompt": textwrap.dedent("""
从医学论文中提取机器学习模型的具体信息严格按照以下规则
Extract specific machine learning model information from medical papers. Follow these rules strictly:
1. model_name: 提取明确提到的模型名称"XGBoost", "LSTM", "GPT-4"
2. architecture_detail: 提取架构描述的具体文本
3. hyperparameter: 提取超参数设置的具体数值
4. feature_processing: 提取特征处理方法的描述
5. model_component: 提取模型组件或模块的描述
1. model_name: Extract clearly mentioned model names (e.g., "XGBoost", "LSTM", "GPT-4")
2. architecture_detail: Extract specific text describing architecture
3. hyperparameter: Extract specific numerical values of hyperparameter settings
4. feature_processing: Extract descriptions of feature processing methods
5. model_component: Extract descriptions of model components or modules
使用exact text进行提取不要释义为每个提取项提供有意义的属性
Use exact text for extraction, do not paraphrase. Provide meaningful attributes for each extraction.
"""),
"examples": [
lx.data.ExampleData(
@@ -222,15 +233,15 @@ class MIMICLangExtractBuilder:
"""加载训练模块的LangExtract配置"""
return {
"prompt": textwrap.dedent("""
从医学论文中提取模型训练相关的具体信息严格按照以下规则
Extract specific model training information from medical papers. Follow these rules strictly:
1. data_split_method: 提取数据分割方法的具体描述
2. validation_approach: 提取验证策略的具体描述
3. hyperparameter_tuning: 提取超参数调优方法
4. stopping_condition: 提取训练停止条件
5. optimizer_config: 提取优化器配置信息
1. data_split_method: Extract specific descriptions of data splitting methods
2. validation_approach: Extract specific descriptions of validation strategies
3. hyperparameter_tuning: Extract hyperparameter tuning methods
4. stopping_condition: Extract training stopping conditions
5. optimizer_config: Extract optimizer configuration information
使用exact text进行提取不要释义为每个提取项提供有意义的属性
Use exact text for extraction, do not paraphrase. Provide meaningful attributes for each extraction.
"""),
"examples": [
lx.data.ExampleData(
@@ -273,15 +284,15 @@ class MIMICLangExtractBuilder:
"""加载评估模块的LangExtract配置"""
return {
"prompt": textwrap.dedent("""
从医学论文中提取模型评估相关的具体信息严格按照以下规则
Extract specific model evaluation information from medical papers. Follow these rules strictly:
1. evaluation_metric: 提取具体的评估指标名称"AUC", "F1-score", "sensitivity"
2. baseline_comparison: 提取基线模型或方法的描述
3. performance_result: 提取具体的性能数值结果
4. statistical_test: 提取统计检验方法的描述
5. experimental_setting: 提取实验设置的具体信息
1. evaluation_metric: Extract specific evaluation metric names (e.g., "AUC", "F1-score", "sensitivity")
2. baseline_comparison: Extract descriptions of baseline models or methods
3. performance_result: Extract specific numerical performance results
4. statistical_test: Extract descriptions of statistical testing methods
5. experimental_setting: Extract specific information about experimental settings
使用exact text进行提取不要释义为每个提取项提供有意义的属性
Use exact text for extraction, do not paraphrase. Provide meaningful attributes for each extraction.
"""),
"examples": [
lx.data.ExampleData(
@@ -322,15 +333,15 @@ class MIMICLangExtractBuilder:
"""加载环境模块的LangExtract配置"""
return {
"prompt": textwrap.dedent("""
从医学论文中提取实验环境相关的具体信息严格按照以下规则
Extract specific experimental environment information from medical papers. Follow these rules strictly:
1. software_library: 提取具体的软件工具和库名称
2. hardware_resource: 提取硬件资源需求的描述
3. data_repository: 提取数据存储和访问的具体信息
4. code_availability: 提取代码可用性的具体描述
5. compliance_requirement: 提取合规性和部署要求
1. software_library: Extract specific software tools and library names
2. hardware_resource: Extract descriptions of hardware resource requirements
3. data_repository: Extract specific information about data storage and access
4. code_availability: Extract specific descriptions of code availability
5. compliance_requirement: Extract compliance and deployment requirements
使用exact text进行提取不要释义为每个提取项提供有意义的属性
Use exact text for extraction, do not paraphrase. Provide meaningful attributes for each extraction.
"""),
"examples": [
lx.data.ExampleData(
@@ -391,53 +402,225 @@ class MIMICLangExtractBuilder:
# 逐个提取每个模块
for module_name, config in self.module_configs.items():
try:
logger.info(f" 提取{module_name}模块...")
# 使用LangExtract进行结构化提取
extraction_result = lx.extract(
text_or_documents=paper_content,
prompt_description=config["prompt"],
examples=config["examples"],
**self.extract_config
)
# 处理提取结果 - extraction_result是AnnotatedDocument对象
if extraction_result and hasattr(extraction_result, 'extractions') and extraction_result.extractions:
results["modules"][module_name] = {
"extractions": [
{
"extraction_class": ext.extraction_class,
"extraction_text": ext.extraction_text,
"start_index": getattr(ext, 'start_index', None),
"end_index": getattr(ext, 'end_index', None),
"attributes": getattr(ext, 'attributes', {}),
"confidence": getattr(ext, 'confidence', None)
}
for ext in extraction_result.extractions
],
"extraction_count": len(extraction_result.extractions),
"quality_score": self._calculate_quality_score(extraction_result)
}
else:
results["modules"][module_name] = {
"extractions": [],
"extraction_count": 0,
"quality_score": 0.0,
"error": "No valid extractions found"
}
# 模块提取重试机制最多重试3次
max_retries = 3
extraction_result = None
retry_errors = []
for attempt in range(max_retries):
try:
if attempt == 0:
logger.info(f" 提取{module_name}模块...")
else:
logger.info(f" 重试{module_name}模块... (尝试 {attempt + 1}/{max_retries})")
except Exception as e:
logger.error(f" {module_name}模块提取失败: {e}")
# 使用LangExtract进行结构化提取
extraction_result = lx.extract(
text_or_documents=paper_content,
prompt_description=config["prompt"],
examples=config["examples"],
**self.extract_config
)
# 检查提取是否成功
if extraction_result and hasattr(extraction_result, 'extractions') and extraction_result.extractions:
logger.info(f" {module_name}模块提取成功 (尝试 {attempt + 1})")
break # 成功,跳出重试循环
else:
error_msg = f"No valid extractions found (attempt {attempt + 1})"
retry_errors.append(error_msg)
logger.warning(f" {module_name}模块提取失败: {error_msg}")
except Exception as e:
error_msg = f"API call failed (attempt {attempt + 1}): {str(e)}"
retry_errors.append(error_msg)
logger.error(f" {module_name}模块提取异常: {error_msg}")
# 如果还有重试机会,稍作等待
if attempt < max_retries - 1:
import time
time.sleep(1) # 等待1秒再重试
# 处理最终结果
if extraction_result and hasattr(extraction_result, 'extractions') and extraction_result.extractions:
results["modules"][module_name] = {
"extractions": [
{
"extraction_class": ext.extraction_class,
"extraction_text": ext.extraction_text,
"start_index": getattr(ext, 'start_index', None),
"end_index": getattr(ext, 'end_index', None),
"attributes": getattr(ext, 'attributes', {}),
"confidence": getattr(ext, 'confidence', None)
}
for ext in extraction_result.extractions
],
"extraction_count": len(extraction_result.extractions),
"quality_score": self._calculate_quality_score(extraction_result),
"retry_attempts": len([e for e in retry_errors if e]) + 1 # 记录总尝试次数
}
else:
# 所有重试都失败,使用默认值
results["modules"][module_name] = {
"extractions": [],
"extraction_count": 0,
"quality_score": 0.0,
"error": str(e)
"error": f"All {max_retries} attempts failed",
"retry_errors": retry_errors,
"retry_attempts": max_retries
}
return results
def _check_paper_already_extracted(self, papers_dir: str, paper_id: str) -> bool:
"""检查论文是否已经提取过,避免重复处理
Args:
papers_dir: 论文目录路径
paper_id: 论文ID
Returns:
bool: True表示已提取过False表示需要处理
"""
paper_subdir = Path(papers_dir) / paper_id
# 检查两个关键文件是否都存在
json_file = paper_subdir / "mimic_langextract_dataset.json"
html_file = paper_subdir / "mimic_langextract_dataset.html"
return json_file.exists() and html_file.exists()
def _preprocess_paper_content(self, content: str) -> str:
"""预处理论文内容,去除无关信息
Args:
content: 原始论文内容
Returns:
str: 处理后的论文内容
"""
import re
try:
# 1. 去除Abstract之前的内容如果没有Abstract则尝试Introduction
# 优先寻找Abstract部分
abstract_pattern = r'((?:abstract|ABSTRACT|Abstract)\s*:?\s*\n.*?)$'
abstract_match = re.search(abstract_pattern, content, re.DOTALL | re.IGNORECASE)
if abstract_match:
content = abstract_match.group(1)
logger.info("已保留Abstract及之后的内容")
else:
# 如果没有Abstract尝试寻找Introduction
intro_pattern = r'((?:introduction|INTRODUCTION|Introduction)\s*:?\s*\n.*?)$'
intro_match = re.search(intro_pattern, content, re.DOTALL | re.IGNORECASE)
if intro_match:
content = intro_match.group(1)
logger.info("已保留Introduction及之后的内容")
else:
logger.info("未找到Abstract或Introduction标识保持原内容")
# 2. 去除References部分
# 匹配References/REFERENCES/Bibliography等开始的部分到文末
ref_patterns = [
r'\n\s*(references|REFERENCES|References|bibliography|BIBLIOGRAPHY|Bibliography)\s*:?\s*\n.*$',
r'\n\s*\d+\.\s*References\s*\n.*$',
r'\n\s*参考文献\s*\n.*$'
]
original_content_length = len(content)
for pattern in ref_patterns:
content = re.sub(pattern, '', content, flags=re.DOTALL | re.IGNORECASE)
if len(content) != original_content_length: # 检查是否有修改
logger.info("已移除References部分")
# 3. 去除所有URL链接
url_patterns = [
r'https?://[^\s\]\)]+', # http/https链接
r'www\.[^\s\]\)]+', # www链接
r'doi:[^\s\]\)]+', # doi链接
r'arxiv:[^\s\]\)]+', # arxiv链接
]
original_length = len(content)
for pattern in url_patterns:
content = re.sub(pattern, '[URL_REMOVED]', content, flags=re.IGNORECASE)
if len(content) != original_length:
logger.info("已移除URL链接")
# 清理多余的空行
content = re.sub(r'\n\s*\n\s*\n+', '\n\n', content)
content = content.strip()
return content
except Exception as e:
logger.warning(f"论文内容预处理失败: {e},使用原始内容")
return content
def _process_single_paper(self, paper_item: tuple, papers_dir: str, total_papers: int) -> Dict[str, Any]:
"""处理单个论文的辅助方法,用于并行处理
Args:
paper_item: (paper_id, content) 元组
papers_dir: 论文目录路径
total_papers: 总论文数用于进度显示
Returns:
Dict[str, Any]: 包含论文ID和提取结果的字典
"""
paper_id, content = paper_item
try:
# 检查是否已经提取过,避免重复处理
if self._check_paper_already_extracted(papers_dir, paper_id):
logger.info(f"跳过已处理论文: {paper_id} (输出文件已存在)")
return {
"paper_id": paper_id,
"result": None,
"status": "skipped",
"reason": "已提取过,输出文件已存在"
}
logger.info(f"开始处理论文: {paper_id}")
# 预处理论文内容,去除无关信息
processed_content = self._preprocess_paper_content(content)
logger.info(f"论文内容预处理完成: {paper_id}")
# 提取论文模块信息
paper_result = self.extract_paper_modules(processed_content, paper_id)
# 为单个论文保存结果(这个操作应该是线程安全的,因为每个论文有独立的子目录)
self._save_individual_paper_result(papers_dir, paper_id, paper_result)
# 记录论文提取完成的进度日志
successful_modules = sum(1 for module_data in paper_result.get('modules', {}).values()
if module_data.get('extraction_count', 0) > 0)
total_modules = len(paper_result.get('modules', {}))
total_extractions = sum(module_data.get('extraction_count', 0)
for module_data in paper_result.get('modules', {}).values())
logger.info(f"✓ 论文提取完成: {paper_id} - 成功模块: {successful_modules}/{total_modules} - 总提取项: {total_extractions}")
return {
"paper_id": paper_id,
"result": paper_result,
"status": "success"
}
except Exception as e:
logger.error(f"处理论文 {paper_id} 失败: {e}")
return {
"paper_id": paper_id,
"result": None,
"status": "failed",
"error": str(e)
}
def build_reproduction_dataset(self, papers_dir: str, output_file: str, max_papers: Optional[int] = None) -> Dict[str, Any]:
"""构建完整的复现数据集"""
papers = self._load_markdown_papers(papers_dir)
@@ -462,20 +645,80 @@ class MIMICLangExtractBuilder:
dataset["metadata"]["note"] = f"测试模式: 只处理前{max_papers}篇论文"
logger.info(f"测试模式: 只处理前 {max_papers} 篇论文")
logger.info(f"开始处理 {len(papers)} 篇论文...")
# 统计需要处理的论文数(排除已处理的)
papers_to_process = 0
already_processed = 0
for i, (paper_id, content) in enumerate(papers.items()):
logger.info(f"[{i+1}/{len(papers)}] 处理论文: {paper_id}")
for paper_id in papers.keys():
if self._check_paper_already_extracted(papers_dir, paper_id):
already_processed += 1
else:
papers_to_process += 1
logger.info(f"发现 {len(papers)} 篇论文,已处理 {already_processed} 篇,待处理 {papers_to_process}")
logger.info(f"开始处理论文... (并行度: {self.doc_workers})")
if papers_to_process == 0:
logger.info("所有论文都已处理完成,无需重新提取")
return dataset
# 并行处理所有论文
completed_count = 0
paper_items = list(papers.items())
with ThreadPoolExecutor(max_workers=self.doc_workers) as executor:
# 提交所有任务
future_to_paper = {
executor.submit(self._process_single_paper, paper_item, papers_dir, len(papers)): paper_item[0]
for paper_item in paper_items
}
paper_result = self.extract_paper_modules(content, paper_id)
dataset["papers"][paper_id] = paper_result
# 为每个论文单独保存结果到其子文件夹
self._save_individual_paper_result(papers_dir, paper_id, paper_result)
# 定期保存全局进度
if (i + 1) % 10 == 0:
self._save_progress(dataset, output_file)
# 处理完成的任务
for future in as_completed(future_to_paper):
completed_count += 1
paper_id = future_to_paper[future]
try:
result = future.result()
if result["status"] == "success":
dataset["papers"][paper_id] = result["result"]
logger.info(f"[{completed_count}/{len(papers)}] 完成论文: {paper_id}")
elif result["status"] == "skipped":
# 跳过的论文不计入失败,但需要记录日志
logger.info(f"[{completed_count}/{len(papers)}] 跳过论文: {paper_id} - {result.get('reason', '已处理')}")
# 跳过的论文可以选择不加入最终数据集或加入但标记为跳过
continue
else:
logger.error(f"[{completed_count}/{len(papers)}] 失败论文: {paper_id} - {result.get('error', '未知错误')}")
# 即使处理失败也要在数据集中记录
dataset["papers"][paper_id] = {
"paper_id": paper_id,
"extraction_metadata": {
"timestamp": datetime.now().isoformat(),
"method": "langextract_with_source_grounding",
"model": "gpt-oss-20b",
"error": result.get("error", "未知错误")
},
"modules": {}
}
except Exception as e:
logger.error(f"[{completed_count}/{len(papers)}] 处理论文 {paper_id} 时发生异常: {e}")
# 记录异常情况
dataset["papers"][paper_id] = {
"paper_id": paper_id,
"extraction_metadata": {
"timestamp": datetime.now().isoformat(),
"method": "langextract_with_source_grounding",
"model": "gpt-oss-20b",
"error": str(e)
},
"modules": {}
}
# 定期保存全局进度(线程安全)
if completed_count % 10 == 0:
with self.progress_lock:
self._save_progress(dataset, output_file)
# 保存最终结果
self._save_dataset(dataset, output_file)
@@ -493,15 +736,36 @@ class MIMICLangExtractBuilder:
if not papers_path.exists():
raise FileNotFoundError(f"论文目录不存在: {papers_dir}")
# 修改加载逻辑:从子目录中读取.md文件
# 修改加载逻辑:从所有任务类型前缀的子目录中读取.md文件
task_prefixes = ["PRED_", "CLAS_", "TIME_", "CORR_"]
markdown_files = []
valid_subdirs = []
for subdir in papers_path.iterdir():
if subdir.is_dir():
md_files = list(subdir.glob("*.md"))
markdown_files.extend(md_files)
# 检查是否以任何任务类型前缀开头
has_task_prefix = any(subdir.name.startswith(prefix) for prefix in task_prefixes)
if has_task_prefix:
valid_subdirs.append(subdir)
md_files = list(subdir.glob("*.md"))
markdown_files.extend(md_files)
logger.info(f"发现 {len(valid_subdirs)} 个通过筛选的有效论文文件夹 (支持的任务类型前缀: {task_prefixes})")
logger.info(f"有效文件夹列表: {[d.name for d in valid_subdirs[:5]]}") # 显示前5个作为示例
# 统计各类任务的数量
task_counts = {prefix.rstrip('_').lower(): 0 for prefix in task_prefixes}
for subdir in valid_subdirs:
for prefix in task_prefixes:
if subdir.name.startswith(prefix):
task_name = prefix.rstrip('_').lower()
task_counts[task_name] += 1
break
logger.info(f"任务类型分布: {dict(task_counts)}")
if not markdown_files:
raise ValueError(f"{papers_dir} 目录中未找到markdown文件")
total_subdirs = len([d for d in papers_path.iterdir() if d.is_dir()])
raise ValueError(f"{papers_dir} 目录中未找到有效的markdown文件 (总文件夹: {total_subdirs}, 有效文件夹: {len(valid_subdirs)}, 支持的前缀: {task_prefixes})")
logger.info(f"发现 {len(markdown_files)} 个markdown文件")

src/parse.py

@@ -11,13 +11,21 @@ import time
import zipfile
import tempfile
import re
import json
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Optional, Tuple
class PDFParser:
"""PDF解析类 - 用于将PDF文件转换为Markdown格式"""
"""PDF解析类 - 用于将PDF文件转换为Markdown格式并按任务类型筛选
支持的任务类型
- prediction: 预测任务 (PRED_)
- classification: 分类任务 (CLAS_)
- time_series: 时间序列分析 (TIME_)
- correlation: 关联性分析 (CORR_)
"""
def __init__(self, pdf_dir: str = "dataset/pdfs", parallel: int = 3,
markdown_dir: str = "dataset/markdowns"):
@@ -35,6 +43,26 @@ class PDFParser:
# OCR API配置
self.ocr_api_url = "http://100.106.4.14:7861/parse"
# AI模型API配置用于四类任务识别prediction/classification/time_series/correlation
self.ai_api_url = "http://100.82.33.121:11001/v1/chat/completions"
self.ai_model = "gpt-oss-20b"
# MIMIC-IV关键词配置用于内容筛选
self.mimic_keywords = [
"MIMIC-IV", "MIMIC 4", "MIMIC IV", "MIMIC-4",
"Medical Information Mart Intensive Care IV",
"MIMIC-IV dataset", "MIMIC-IV database"
]
# 任务类型到前缀的映射配置
self.task_type_prefixes = {
"prediction": "PRED_",
"classification": "CLAS_",
"time_series": "TIME_",
"correlation": "CORR_",
"none": None # 不符合任何类型,不标记
}
# HTTP会话配置增加连接池大小和超时时间
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
@@ -77,6 +105,237 @@ class PDFParser:
logging.info(f"发现 {len(pdf_files)} 个PDF文件待处理")
return pdf_files
def _check_mimic_keywords(self, output_subdir: Path) -> bool:
"""检查Markdown文件是否包含MIMIC-IV关键词
Args:
output_subdir (Path): 包含Markdown文件的输出子目录
Returns:
bool: 是否包含MIMIC-IV关键词
"""
try:
# 查找所有.md文件
md_files = list(output_subdir.glob("*.md"))
if not md_files:
logging.warning(f"未找到Markdown文件进行MIMIC关键词检查: {output_subdir}")
return False
# 检查每个Markdown文件的内容
for md_file in md_files:
try:
with open(md_file, 'r', encoding='utf-8') as f:
content = f.read().lower() # 转换为小写进行不区分大小写匹配
# 检查是否包含任何MIMIC-IV关键词
for keyword in self.mimic_keywords:
if keyword.lower() in content:
logging.info(f"发现MIMIC-IV关键词 '{keyword}' 在文件 {md_file.name}")
return True
except Exception as e:
logging.error(f"读取Markdown文件时发生错误: {md_file.name} - {e}")
continue
logging.info(f"未发现MIMIC-IV关键词: {output_subdir.name}")
return False
except Exception as e:
logging.error(f"检查MIMIC关键词时发生错误: {output_subdir} - {e}")
return False
def _extract_introduction(self, output_subdir: Path) -> Optional[str]:
"""从Markdown文件中提取Introduction部分
Args:
output_subdir (Path): 包含Markdown文件的输出子目录
Returns:
Optional[str]: 提取的Introduction内容失败时返回None
"""
try:
# 查找所有.md文件
md_files = list(output_subdir.glob("*.md"))
if not md_files:
logging.warning(f"未找到Markdown文件进行Introduction提取: {output_subdir}")
return None
# 通常使用第一个md文件
md_file = md_files[0]
try:
with open(md_file, 'r', encoding='utf-8') as f:
content = f.read()
# 使用正则表达式提取Introduction部分
# 匹配各种可能的Introduction标题格式
patterns = [
r'(?i)#\s*Introduction\s*\n(.*?)(?=\n#|\n\n#|$)',
r'(?i)##\s*Introduction\s*\n(.*?)(?=\n##|\n\n##|$)',
r'(?i)###\s*Introduction\s*\n(.*?)(?=\n###|\n\n###|$)',
r'(?i)\*\*Introduction\*\*\s*\n(.*?)(?=\n\*\*|\n\n\*\*|$)',
r'(?i)Introduction\s*\n(.*?)(?=\n[A-Z][a-z]+\s*\n|$)'
]
for pattern in patterns:
match = re.search(pattern, content, re.DOTALL)
if match:
introduction = match.group(1).strip()
if len(introduction) > 100: # 确保有足够的内容进行分析
logging.info(f"成功提取Introduction部分 ({len(introduction)} 字符): {md_file.name}")
return introduction
# 如果没有明确的Introduction标题尝试提取前几段作为近似的introduction
paragraphs = content.split('\n\n')
introduction_candidates = []
for para in paragraphs[:5]: # 取前5段
para = para.strip()
if len(para) > 50 and not para.startswith('#'): # 过滤掉标题和过短段落
introduction_candidates.append(para)
if introduction_candidates:
introduction = '\n\n'.join(introduction_candidates[:3]) # 最多取前3段
if len(introduction) > 200:
logging.info(f"提取近似Introduction部分 ({len(introduction)} 字符): {md_file.name}")
return introduction
logging.warning(f"未能提取到有效的Introduction内容: {md_file.name}")
return None
except Exception as e:
logging.error(f"读取Markdown文件时发生错误: {md_file.name} - {e}")
return None
except Exception as e:
logging.error(f"提取Introduction时发生错误: {output_subdir} - {e}")
return None
def _analyze_research_task(self, introduction: str) -> str:
"""使用AI模型分析论文的研究任务类型
Args:
introduction (str): 论文的Introduction内容
Returns:
str: 任务类型 ('prediction', 'classification', 'time_series', 'correlation', 'none')
"""
try:
# 构造AI分析的提示词
system_prompt = """你是一个医学研究专家。请分析给定的论文Introduction部分,判断该研究属于以下哪种任务类型:
1. prediction - 预测任务:预测未来事件、结局或数值,如死亡率预测、住院时长预测、疾病进展预测
2. classification - 分类任务:将患者或病例分类到不同类别,如疾病诊断分类、风险等级分类、药物反应分类
3. time_series - 时间序列分析:分析随时间变化的医疗数据,如生命体征趋势分析、病情演进分析、纵向队列研究
4. correlation - 关联性分析:研究变量间的关系或关联,如疾病与人口特征关系、药物与副作用关联、风险因素识别
5. none - 不属于以上任何类型
请以JSON格式回答,包含任务类型和置信度:
{\"task_type\": \"prediction\", \"confidence\": 0.85}
task_type必须是以下选项之一:prediction、classification、time_series、correlation、none
confidence为0-1之间的数值,表示判断的置信度
只返回JSON,不要添加其他文字"""
user_prompt = f"请分析以下论文Introduction判断属于哪种任务类型\n\n{introduction[:2000]}" # 限制长度避免token过多
# 构造API请求数据
api_data = {
"model": self.ai_model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
"max_tokens": 50, # 需要返回JSON格式
"temperature": 0.1 # 降低随机性
}
# 调用AI API
response = self.session.post(
self.ai_api_url,
json=api_data,
headers={"Content-Type": "application/json"},
timeout=30
)
if response.status_code == 200:
result = response.json()
ai_response = result['choices'][0]['message']['content'].strip()
try:
# 解析JSON响应
parsed_response = json.loads(ai_response)
task_type = parsed_response.get('task_type', 'none').lower()
confidence = parsed_response.get('confidence', 0.0)
# 验证任务类型是否有效
valid_types = ['prediction', 'classification', 'time_series', 'correlation', 'none']
if task_type not in valid_types:
logging.warning(f"AI返回了无效的任务类型: {task_type},使用默认值 'none'")
task_type = "none"
confidence = 0.0
# 只接受高置信度的结果
if confidence < 0.7:
logging.info(f"AI分析置信度过低 ({confidence:.2f}),归类为 'none'")
task_type = "none"
logging.info(f"AI分析结果: 任务类型={task_type}, 置信度={confidence:.2f}")
return task_type
except json.JSONDecodeError as e:
logging.error(f"解析AI JSON响应失败: {ai_response} - 错误: {e}")
return "none"
else:
logging.error(f"AI API调用失败状态码: {response.status_code}")
return "none"
except Exception as e:
logging.error(f"AI分析研究任务时发生错误: {e}")
return "none"
def _mark_valid_folder(self, output_subdir: Path, task_type: str) -> bool:
"""为通过筛选的文件夹添加任务类型前缀标记
Args:
output_subdir (Path): 需要标记的输出子目录
task_type (str): 任务类型 ('prediction', 'classification', 'time_series', 'correlation')
Returns:
bool: 标记是否成功
"""
try:
# 获取任务类型对应的前缀
prefix = self.task_type_prefixes.get(task_type)
if not prefix:
logging.info(f"任务类型 '{task_type}' 不需要标记文件夹")
return True # 不需要标记,但认为成功
# 检查文件夹是否已经有相应的任务类型前缀
if output_subdir.name.startswith(prefix):
logging.info(f"文件夹已标记为{task_type}任务: {output_subdir.name}")
return True
# 检查是否已经有其他任务类型的前缀
for existing_type, existing_prefix in self.task_type_prefixes.items():
if existing_prefix and output_subdir.name.startswith(existing_prefix):
logging.info(f"文件夹已有{existing_type}任务标记,不需要重新标记: {output_subdir.name}")
return True
# 生成新的文件夹名
new_folder_name = prefix + output_subdir.name
new_folder_path = output_subdir.parent / new_folder_name
# 重命名文件夹
output_subdir.rename(new_folder_path)
logging.info(f"文件夹标记成功: {output_subdir.name} -> {new_folder_name} (任务类型: {task_type})")
return True
except Exception as e:
logging.error(f"标记文件夹时发生错误: {output_subdir} - {e}")
return False
def _prepare_output_dir(self) -> Path:
"""准备Markdown输出目录
@@ -276,7 +535,38 @@ class PDFParser:
# 下载并解压ZIP文件
success = self._download_and_extract_zip(full_download_url, pdf_file)
return success
if not success:
return False
# 获取解压后的文件夹路径
output_subdir = self.markdown_dir / pdf_file.stem
# 第一层筛选检查MIMIC-IV关键词
logging.info(f"开始MIMIC-IV关键词筛选: {pdf_file.stem}")
if not self._check_mimic_keywords(output_subdir):
logging.info(f"未通过MIMIC-IV关键词筛选跳过: {pdf_file.stem}")
return True # 处理成功但未通过筛选
# 第二层筛选AI分析研究任务
logging.info(f"开始AI研究任务分析: {pdf_file.stem}")
introduction = self._extract_introduction(output_subdir)
if not introduction:
logging.warning(f"无法提取Introduction跳过AI分析: {pdf_file.stem}")
return True # 处理成功但无法进行任务分析
task_type = self._analyze_research_task(introduction)
if task_type == "none":
logging.info(f"未通过研究任务筛选 (task_type=none),跳过: {pdf_file.stem}")
return True # 处理成功但未通过筛选
# 两层筛选都通过,根据任务类型标记文件夹
logging.info(f"通过所有筛选,标记为{task_type}任务论文: {pdf_file.stem}")
if self._mark_valid_folder(output_subdir, task_type):
logging.info(f"论文筛选完成,已标记为{task_type}任务: {pdf_file.stem}")
else:
logging.warning(f"文件夹标记失败: {pdf_file.stem}")
return True
except Exception as e:
logging.error(f"处理PDF文件时发生错误: {pdf_file.name} - {e}")