feat: implement a MIMIC paper information extraction system based on the LangExtract framework

- Add the info_extractor.py entry point with command-line arguments and a test mode
- Implement the core MIMICLangExtractBuilder class in src/extractor.py
- Integrate the vllm API service (OpenAI-compatible format) for structured information extraction
- Support extraction of 5 modules: dataset, model, training, evaluation, and environment configuration
- Implement source-text grounding and interactive HTML visualization
- Add the langextract and httpx[socks] dependencies
- Support saving results into per-paper subdirectories
- Clean up the obsolete experiment_runner.py and number_extraction_models.py files
parent: 1b652502d5
commit: c4037325ed
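A minimal sketch of the programmatic flow this commit adds, mirroring main() in info_extractor.py below (the paths are the repository defaults; the max_papers cap is optional and shown only for illustration):

from src.extractor import MIMICLangExtractBuilder

# Initialize the extractor (configures the vllm endpoint internally)
builder = MIMICLangExtractBuilder()

# Build the reproduction dataset; per-paper results are also written under papers_dir
dataset = builder.build_reproduction_dataset(
    papers_dir="dataset/markdowns",
    output_file="dataset/reproduction_tasks/mimic_langextract_dataset.json",
    max_papers=5,  # optional cap, e.g. for a quick test run
)
print(dataset["metadata"]["total_papers"])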
README.md
@@ -14,6 +14,7 @@ MedResearcher is an automated experimentation platform based on user input; it provides the user with …

## Main file of each module

1. Paper crawling main file: papers_crawler.py
2. PDF parsing main file: pdf_parser.py
3. Information extraction main file: info_extractor.py
4. Experiment running main file: experiment_runner.py

## File structure

@@ -23,6 +24,7 @@ MedResearcher is an automated experimentation platform based on user input; it provides the user with …

│ └── mimic.csv               # Basic information for all MIMIC-related papers to be processed
├── papers_crawler.py         # Paper crawling main file
├── pdf_parser.py             # PDF parsing main file
├── info_extractor.py         # Information extraction main file
├── experiment_runner.py      # Experiment running main file
├── src/                      # Source code directory
│ └── utils/                  # Utility functions directory
info_extractor.py · new file · 139 lines
@@ -0,0 +1,139 @@
#!/usr/bin/env python3
"""
LangExtract-based MIMIC paper information extractor.
Extracts structured reproduction-task information from medical papers.

Author: MedResearcher project
Created: 2025-01-25
"""

import argparse
import logging
import sys

from src.extractor import MIMICLangExtractBuilder

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)


def setup_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace: the parsed command-line arguments
    """
    parser = argparse.ArgumentParser(
        description='MIMIC paper information extraction tool - extracts structured reproduction information from medical papers with LangExtract',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
Usage examples:
  %(prog)s                                      # use the default arguments
  %(prog)s --papers_dir dataset/markdowns       # specify the paper directory
  %(prog)s --output_file results/dataset.json   # specify the output file
  %(prog)s --test_mode --max_papers 5           # test mode, only process 5 papers
        '''
    )

    parser.add_argument(
        '--papers_dir',
        type=str,
        default='dataset/markdowns',
        help='directory of markdown papers (default: dataset/markdowns)'
    )

    parser.add_argument(
        '--output_file',
        type=str,
        default='dataset/reproduction_tasks/mimic_langextract_dataset.json',
        help='output dataset file path (default: dataset/reproduction_tasks/mimic_langextract_dataset.json)'
    )

    parser.add_argument(
        '--test_mode',
        action='store_true',
        help='test mode: only process a few papers for validation'
    )

    parser.add_argument(
        '--max_papers',
        type=int,
        default=None,
        help='maximum number of papers to process, for testing (default: process all papers)'
    )

    parser.add_argument(
        '--log_level',
        type=str,
        default='INFO',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
        help='logging level (default: INFO)'
    )

    return parser.parse_args()


def main():
    """Main entry point - run the MIMIC paper information extraction task."""
    try:
        # Parse command-line arguments
        args = setup_args()

        # Set the logging level
        logging.getLogger().setLevel(getattr(logging, args.log_level))

        # Initialize the information extractor
        builder = MIMICLangExtractBuilder()

        print("=== MIMIC paper information extraction tool started ===")
        print(f"Papers directory: {args.papers_dir}")
        print(f"Output file: {args.output_file}")
        print(f"Test mode: {'yes' if args.test_mode else 'no'}")
        if args.max_papers:
            print(f"Max papers: {args.max_papers}")
        print(f"Log level: {args.log_level}")
        print("========================")

        # In test mode without --max_papers, cap at 5 papers (matches the epilog example);
        # otherwise honour --max_papers or process everything
        max_papers = args.max_papers or (5 if args.test_mode else None)

        # Build the reproduction dataset
        print("\nBuilding the MIMIC reproduction dataset...")
        dataset = builder.build_reproduction_dataset(
            papers_dir=args.papers_dir,
            output_file=args.output_file,
            max_papers=max_papers
        )

        # Summarize the results
        total_papers = dataset['metadata']['total_papers']
        successful_extractions = sum(
            1 for paper in dataset['papers'].values()
            if any(module.get('extraction_count', 0) > 0
                   for module in paper.get('modules', {}).values())
        )

        print("\n=== Build complete ===")
        print(f"Total papers: {total_papers}")
        print(f"Successful extractions: {successful_extractions}/{total_papers}")
        print(f"Success rate: {successful_extractions/total_papers*100:.1f}%")
        print(f"Results saved to: {args.output_file}")
        print(f"Interactive report: {args.output_file.replace('.json', '.html')}")
        print("===============")

        return 0

    except FileNotFoundError as e:
        print(f"Error: the specified file or directory was not found - {e}")
        return 1
    except ValueError as e:
        print(f"Error: invalid argument value - {e}")
        return 1
    except Exception as e:
        print(f"Error: unexpected exception - {e}")
        logging.exception("Details:")
        return 1


if __name__ == "__main__":
    sys.exit(main())
pyproject.toml
@@ -6,6 +6,8 @@ readme = "README.md"
requires-python = ">=3.13"
dependencies = [
    "agno>=1.7.12",
    "httpx[socks]>=0.28.1",
    "langextract>=1.0.8",
    "ollama>=0.5.3",
    "openai>=1.101.0",
    "pydantic",
src/extractor.py · new file · 809 lines
@@ -0,0 +1,809 @@
#!/usr/bin/env python3
"""
LangExtract-based MIMIC paper information extractor - core implementation.
Extracts structured reproduction-task information from medical papers.

Author: MedResearcher project
Created: 2025-01-25
"""

import langextract as lx
import textwrap
from pathlib import Path
import json
from datetime import datetime
from typing import List, Dict, Any, Optional
import logging

# Configure logging
logger = logging.getLogger(__name__)


class MIMICLangExtractBuilder:
    """LangExtract-based MIMIC paper information extractor."""

    def __init__(self):
        """Initialize the extractor and configure the vllm API service."""
        try:
            # Configure LangExtract to use the vllm API (via the OpenAI-compatible interface)
            import os
            os.environ["LANGEXTRACT_API_KEY"] = "dummy"

            # Create a ModelConfig that forces the OpenAI provider to target the vllm endpoint
            self.model_config = lx.factory.ModelConfig(
                model_id="gpt-oss-20b",  # the model name actually deployed in vllm
                provider="OpenAILanguageModel",  # force the OpenAI provider
                provider_kwargs={
                    "base_url": "http://100.82.33.121:11001/v1",  # vllm API endpoint
                    "api_key": "dummy",
                    "model_id": "gpt-oss-20b"  # make sure the correct model ID is used
                }
            )

            # Common LangExtract configuration parameters
            self.extract_config = {
                "config": self.model_config,
                "max_workers": 3,  # keep concurrency low to avoid overloading the vllm service
                "max_char_buffer": 6000,  # context length suited to medical papers
                "extraction_passes": 1,  # a single pass to avoid excessive API calls
                "temperature": 0.1,  # low temperature for consistency
                "fence_output": True,  # expect fenced (code-block) output
                "use_schema_constraints": False  # vllm may not support strict schemas
            }

            # Load the extraction configuration for every module
            self.module_configs = {
                "data": self._load_data_config(),
                "model": self._load_model_config(),
                "training": self._load_training_config(),
                "evaluation": self._load_evaluation_config(),
                "environment": self._load_environment_config()
            }

            logger.info("MIMICLangExtractBuilder initialized successfully")

        except Exception as e:
            logger.error(f"Initialization failed: {e}")
            raise

    def _load_data_config(self) -> Dict[str, Any]:
        """Load the LangExtract configuration for the data module."""
        return {
            "prompt": textwrap.dedent("""
                Extract concrete data-processing information from the medical paper. Follow these rules strictly:

                1. dataset_source: extract explicitly mentioned dataset names (e.g. "MIMIC-IV", "Stanford EHR")
                2. data_scale: extract concrete data-scale numbers (e.g. "135,483 patients", "2015-2023")
                3. preprocessing_step: extract concrete descriptions of data preprocessing steps
                4. feature_type: extract descriptions of feature types and encoding methods
                5. inclusion_criteria: extract the exact text of patient inclusion criteria
                6. exclusion_criteria: extract the exact text of patient exclusion criteria

                Use the exact text for extraction; do not paraphrase. Provide meaningful attributes for each extraction.
            """),
            "examples": [
                lx.data.ExampleData(
                    text="We analyzed 135,483 ED blood culture orders from Stanford Medicine EHR between 2015-2023. Adult patients (≥18 years) with blood culture collection in the ED were included. Patients with positive blood cultures within 14 days were excluded. Features were one-hot encoded for ML compatibility.",
                    extractions=[
                        lx.data.Extraction(
                            extraction_class="dataset_source",
                            extraction_text="Stanford Medicine EHR",
                            attributes={
                                "data_type": "electronic health records",
                                "institution": "Stanford Medicine"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="data_scale",
                            extraction_text="135,483 ED blood culture orders",
                            attributes={
                                "sample_size": "135,483",
                                "time_range": "2015-2023",
                                "data_unit": "blood culture orders"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="inclusion_criteria",
                            extraction_text="Adult patients (≥18 years) with blood culture collection in the ED",
                            attributes={
                                "age_limit": "≥18 years",
                                "setting": "Emergency Department",
                                "requirement": "blood culture collection"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="exclusion_criteria",
                            extraction_text="Patients with positive blood cultures within 14 days were excluded",
                            attributes={
                                "timeframe": "within 14 days",
                                "condition": "positive blood cultures"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="feature_type",
                            extraction_text="Features were one-hot encoded for ML compatibility",
                            attributes={
                                "encoding_method": "one-hot encoding",
                                "purpose": "ML compatibility"
                            }
                        )
                    ]
                ),
                lx.data.ExampleData(
                    text="This study utilized MIMIC-IV database, including CHARTEVENTS and LABEVENTS tables. We extracted hourly vital signs and laboratory values for ICU patients. Missing values were imputed using forward-fill method. Outliers beyond 3 standard deviations were removed.",
                    extractions=[
                        lx.data.Extraction(
                            extraction_class="dataset_source",
                            extraction_text="MIMIC-IV database",
                            attributes={
                                "data_type": "public clinical database",
                                "tables": "CHARTEVENTS, LABEVENTS"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="preprocessing_step",
                            extraction_text="Missing values were imputed using forward-fill method",
                            attributes={
                                "method": "forward-fill",
                                "target": "missing values"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="preprocessing_step",
                            extraction_text="Outliers beyond 3 standard deviations were removed",
                            attributes={
                                "method": "outlier removal",
                                "threshold": "3 standard deviations"
                            }
                        )
                    ]
                )
            ]
        }

    def _load_model_config(self) -> Dict[str, Any]:
        """Load the LangExtract configuration for the model module."""
        return {
            "prompt": textwrap.dedent("""
                Extract concrete machine-learning model information from the medical paper. Follow these rules strictly:

                1. model_name: extract explicitly mentioned model names (e.g. "XGBoost", "LSTM", "GPT-4")
                2. architecture_detail: extract the exact text describing the architecture
                3. hyperparameter: extract concrete hyperparameter values
                4. feature_processing: extract descriptions of feature-processing methods
                5. model_component: extract descriptions of model components or modules

                Use the exact text for extraction; do not paraphrase. Provide meaningful attributes for each extraction.
            """),
            "examples": [
                lx.data.ExampleData(
                    text="We employed XGBoost classifier with max depth of 4 and 30 boosting iterations. Class weights were used to handle imbalanced data. STELLA 1.5B model was used for text embeddings with attention-weighted average pooling.",
                    extractions=[
                        lx.data.Extraction(
                            extraction_class="model_name",
                            extraction_text="XGBoost classifier",
                            attributes={
                                "model_type": "gradient boosting",
                                "task": "classification"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="hyperparameter",
                            extraction_text="max depth of 4 and 30 boosting iterations",
                            attributes={
                                "max_depth": "4",
                                "n_estimators": "30",
                                "parameter_type": "tree_structure"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="model_name",
                            extraction_text="STELLA 1.5B model",
                            attributes={
                                "model_type": "pretrained language model",
                                "parameters": "1.5B",
                                "purpose": "text embeddings"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="feature_processing",
                            extraction_text="attention-weighted average pooling",
                            attributes={
                                "technique": "pooling",
                                "method": "attention-weighted"
                            }
                        )
                    ]
                )
            ]
        }

    def _load_training_config(self) -> Dict[str, Any]:
        """Load the LangExtract configuration for the training module."""
        return {
            "prompt": textwrap.dedent("""
                Extract concrete model-training information from the medical paper. Follow these rules strictly:

                1. data_split_method: extract the exact description of the data splitting method
                2. validation_approach: extract the exact description of the validation strategy
                3. hyperparameter_tuning: extract the hyperparameter tuning method
                4. stopping_condition: extract the training stopping condition
                5. optimizer_config: extract optimizer configuration details

                Use the exact text for extraction; do not paraphrase. Provide meaningful attributes for each extraction.
            """),
            "examples": [
                lx.data.ExampleData(
                    text="Data was split temporally: training set (2015-2022), development set (2022-2023) for hyperparameter tuning, and evaluation set (2023+). Grid search was performed on the development set to optimize AUC performance. Early stopping was applied when validation loss did not improve for 10 epochs.",
                    extractions=[
                        lx.data.Extraction(
                            extraction_class="data_split_method",
                            extraction_text="Data was split temporally: training set (2015-2022), development set (2022-2023), and evaluation set (2023+)",
                            attributes={
                                "split_type": "temporal",
                                "train_period": "2015-2022",
                                "dev_period": "2022-2023",
                                "eval_period": "2023+"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="hyperparameter_tuning",
                            extraction_text="Grid search was performed on the development set to optimize AUC performance",
                            attributes={
                                "method": "grid search",
                                "metric": "AUC",
                                "dataset": "development set"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="stopping_condition",
                            extraction_text="Early stopping was applied when validation loss did not improve for 10 epochs",
                            attributes={
                                "method": "early stopping",
                                "patience": "10 epochs",
                                "monitor": "validation loss"
                            }
                        )
                    ]
                )
            ]
        }

    def _load_evaluation_config(self) -> Dict[str, Any]:
        """Load the LangExtract configuration for the evaluation module."""
        return {
            "prompt": textwrap.dedent("""
                Extract concrete model-evaluation information from the medical paper. Follow these rules strictly:

                1. evaluation_metric: extract specific evaluation metric names (e.g. "AUC", "F1-score", "sensitivity")
                2. baseline_comparison: extract descriptions of baseline models or methods
                3. performance_result: extract concrete numeric performance results
                4. statistical_test: extract descriptions of statistical testing methods
                5. experimental_setting: extract concrete details of the experimental setup

                Use the exact text for extraction; do not paraphrase. Provide meaningful attributes for each extraction.
            """),
            "examples": [
                lx.data.ExampleData(
                    text="The model achieved ROC-AUC of 0.85 (95% CI: 0.82-0.88) on the test set. We compared against three baselines: expert framework (manual assessment), structured-only model, and LLM-automated framework. At 90% sensitivity, our model achieved 45% specificity versus 32% for the baseline.",
                    extractions=[
                        lx.data.Extraction(
                            extraction_class="evaluation_metric",
                            extraction_text="ROC-AUC",
                            attributes={
                                "metric_type": "discriminative performance",
                                "range": "0-1"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="performance_result",
                            extraction_text="ROC-AUC of 0.85 (95% CI: 0.82-0.88)",
                            attributes={
                                "metric": "ROC-AUC",
                                "value": "0.85",
                                "confidence_interval": "0.82-0.88",
                                "confidence_level": "95%"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="baseline_comparison",
                            extraction_text="expert framework (manual assessment), structured-only model, and LLM-automated framework",
                            attributes={
                                "baseline_count": "3",
                                "comparison_type": "multiple baselines"
                            }
                        )
                    ]
                )
            ]
        }

    def _load_environment_config(self) -> Dict[str, Any]:
        """Load the LangExtract configuration for the environment module."""
        return {
            "prompt": textwrap.dedent("""
                Extract concrete experimental-environment information from the medical paper. Follow these rules strictly:

                1. software_library: extract specific software tool and library names
                2. hardware_resource: extract descriptions of hardware requirements
                3. data_repository: extract concrete data storage and access information
                4. code_availability: extract the exact description of code availability
                5. compliance_requirement: extract compliance and deployment requirements

                Use the exact text for extraction; do not paraphrase. Provide meaningful attributes for each extraction.
            """),
            "examples": [
                lx.data.ExampleData(
                    text="We implemented the models using Python 3.8 with scikit-learn 1.0.2 and XGBoost 1.5.0. Training was performed on NVIDIA A100 GPU with 40GB memory. Code is available at GitHub: https://github.com/HealthRex/CDSS. The study was approved by Stanford IRB.",
                    extractions=[
                        lx.data.Extraction(
                            extraction_class="software_library",
                            extraction_text="Python 3.8 with scikit-learn 1.0.2 and XGBoost 1.5.0",
                            attributes={
                                "language": "Python",
                                "version": "3.8",
                                "libraries": "scikit-learn, XGBoost"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="hardware_resource",
                            extraction_text="NVIDIA A100 GPU with 40GB memory",
                            attributes={
                                "gpu_type": "NVIDIA A100",
                                "memory": "40GB",
                                "resource_type": "GPU"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="code_availability",
                            extraction_text="Code is available at GitHub: https://github.com/HealthRex/CDSS",
                            attributes={
                                "platform": "GitHub",
                                "url": "https://github.com/HealthRex/CDSS",
                                "access_type": "public"
                            }
                        ),
                        lx.data.Extraction(
                            extraction_class="compliance_requirement",
                            extraction_text="The study was approved by Stanford IRB",
                            attributes={
                                "approval_type": "IRB",
                                "institution": "Stanford"
                            }
                        )
                    ]
                )
            ]
        }

    def extract_paper_modules(self, paper_content: str, paper_id: str) -> Dict[str, Any]:
        """Extract all module information from one paper with LangExtract."""

        results = {
            "paper_id": paper_id,
            "extraction_metadata": {
                "timestamp": datetime.now().isoformat(),
                "method": "langextract_with_source_grounding",
                "model": "gpt-oss-20b"
            },
            "modules": {}
        }

        # Extract each module in turn
        for module_name, config in self.module_configs.items():
            try:
                logger.info(f"  Extracting the {module_name} module...")

                # Run structured extraction with LangExtract
                extraction_result = lx.extract(
                    text_or_documents=paper_content,
                    prompt_description=config["prompt"],
                    examples=config["examples"],
                    **self.extract_config
                )

                # Process the result - extraction_result is an AnnotatedDocument object
                if extraction_result and hasattr(extraction_result, 'extractions') and extraction_result.extractions:
                    results["modules"][module_name] = {
                        "extractions": [
                            {
                                "extraction_class": ext.extraction_class,
                                "extraction_text": ext.extraction_text,
                                "start_index": getattr(ext, 'start_index', None),
                                "end_index": getattr(ext, 'end_index', None),
                                "attributes": getattr(ext, 'attributes', {}),
                                "confidence": getattr(ext, 'confidence', None)
                            }
                            for ext in extraction_result.extractions
                        ],
                        "extraction_count": len(extraction_result.extractions),
                        "quality_score": self._calculate_quality_score(extraction_result)
                    }
                else:
                    results["modules"][module_name] = {
                        "extractions": [],
                        "extraction_count": 0,
                        "quality_score": 0.0,
                        "error": "No valid extractions found"
                    }

            except Exception as e:
                logger.error(f"  Extraction of the {module_name} module failed: {e}")
                results["modules"][module_name] = {
                    "extractions": [],
                    "extraction_count": 0,
                    "quality_score": 0.0,
                    "error": str(e)
                }

        return results

    def build_reproduction_dataset(self, papers_dir: str, output_file: str, max_papers: Optional[int] = None) -> Dict[str, Any]:
        """Build the complete reproduction dataset."""
        papers = self._load_markdown_papers(papers_dir)

        dataset = {
            "metadata": {
                "creation_date": datetime.now().isoformat(),
                "total_papers": len(papers),
                "extraction_method": "langextract_source_grounded",
                "api_endpoint": "http://100.82.33.121:11001/v1",
                "model": "gpt-oss-20b",
                "langextract_version": getattr(lx, '__version__', 'unknown')
            },
            "papers": {}
        }

        # If a maximum is given, limit the number of papers
        if max_papers and max_papers < len(papers):
            papers_items = list(papers.items())[:max_papers]
            papers = dict(papers_items)
            dataset["metadata"]["total_papers"] = len(papers)
            dataset["metadata"]["note"] = f"Test mode: only the first {max_papers} papers are processed"
            logger.info(f"Test mode: only processing the first {max_papers} papers")

        logger.info(f"Processing {len(papers)} papers...")

        for i, (paper_id, content) in enumerate(papers.items()):
            logger.info(f"[{i+1}/{len(papers)}] Processing paper: {paper_id}")

            paper_result = self.extract_paper_modules(content, paper_id)
            dataset["papers"][paper_id] = paper_result

            # Save each paper's result separately into its own subfolder
            self._save_individual_paper_result(papers_dir, paper_id, paper_result)

            # Periodically save global progress
            if (i + 1) % 10 == 0:
                self._save_progress(dataset, output_file)

        # Save the final result
        self._save_dataset(dataset, output_file)

        # Generate the interactive HTML report
        self._generate_html_report(dataset, output_file.replace('.json', '.html'))

        return dataset

    def _load_markdown_papers(self, papers_dir: str) -> Dict[str, str]:
        """Load the markdown paper files."""
        papers = {}
        papers_path = Path(papers_dir)

        if not papers_path.exists():
            raise FileNotFoundError(f"Papers directory does not exist: {papers_dir}")

        # Loading logic: read .md files from each paper's subdirectory
        markdown_files = []
        for subdir in papers_path.iterdir():
            if subdir.is_dir():
                md_files = list(subdir.glob("*.md"))
                markdown_files.extend(md_files)

        if not markdown_files:
            raise ValueError(f"No markdown files found in directory {papers_dir}")

        logger.info(f"Found {len(markdown_files)} markdown files")

        for file_path in markdown_files:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                papers[file_path.stem] = content
            except Exception as e:
                logger.warning(f"Failed to read file {file_path}: {e}")

        return papers

    def _calculate_quality_score(self, extraction_result) -> float:
        """Compute a quality score for the extraction result."""
        if not extraction_result or not hasattr(extraction_result, 'extractions'):
            return 0.0

        if not extraction_result.extractions:
            return 0.0

        # Score based on the number of extractions and the richness of their attributes
        total_score = 0.0
        for ext in extraction_result.extractions:
            score = 0.3  # base score

            # Bonus for source-text grounding
            if hasattr(ext, 'start_index') and ext.start_index is not None:
                score += 0.2

            # Bonus for attribute richness
            if ext.attributes and len(ext.attributes) > 0:
                score += min(0.3, len(ext.attributes) * 0.1)

            # Bonus for confidence
            if hasattr(ext, 'confidence') and ext.confidence:
                score += 0.2 * ext.confidence

            total_score += score

        return min(1.0, total_score / len(extraction_result.extractions))

    def _save_progress(self, dataset: Dict[str, Any], output_file: str):
        """Save intermediate progress."""
        try:
            progress_file = output_file.replace('.json', '_progress.json')
            with open(progress_file, 'w', encoding='utf-8') as f:
                json.dump(dataset, f, ensure_ascii=False, indent=2)
            logger.info(f"Progress saved to: {progress_file}")
        except Exception as e:
            logger.error(f"Failed to save progress: {e}")

    def _save_individual_paper_result(self, papers_dir: str, paper_id: str, paper_result: Dict[str, Any]):
        """Save one paper's extraction result into its corresponding subfolder."""
        try:
            # Build the path to the paper's subfolder
            paper_subdir = Path(papers_dir) / paper_id
            if not paper_subdir.exists():
                logger.warning(f"Paper subfolder does not exist: {paper_subdir}")
                return

            # Prepare the single-paper dataset format
            individual_dataset = {
                "metadata": {
                    "creation_date": datetime.now().isoformat(),
                    "total_papers": 1,
                    "extraction_method": "langextract_source_grounded",
                    "api_endpoint": "http://100.82.33.121:11001/v1",
                    "model": "gpt-oss-20b",
                    "langextract_version": getattr(lx, '__version__', 'unknown'),
                    "paper_id": paper_id
                },
                "paper": paper_result  # note: a single paper, hence "paper" rather than "papers"
            }

            # Save the JSON file
            json_file = paper_subdir / "mimic_langextract_dataset.json"
            with open(json_file, 'w', encoding='utf-8') as f:
                json.dump(individual_dataset, f, ensure_ascii=False, indent=2)

            # Generate the HTML report
            html_file = paper_subdir / "mimic_langextract_dataset.html"
            self._generate_individual_html_report(individual_dataset, html_file)

            logger.info(f"Saved results for paper {paper_id} to: {paper_subdir}")

        except Exception as e:
            logger.error(f"Failed to save individual paper result ({paper_id}): {e}")

    def _save_dataset(self, dataset: Dict[str, Any], output_file: str):
        """Save the final dataset."""
        try:
            # Make sure the output directory exists
            output_path = Path(output_file)
            output_path.parent.mkdir(parents=True, exist_ok=True)

            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(dataset, f, ensure_ascii=False, indent=2)
            logger.info(f"Dataset saved to: {output_file}")
        except Exception as e:
            logger.error(f"Failed to save dataset: {e}")
            raise

    def _generate_html_report(self, dataset: Dict[str, Any], output_file: str):
        """Generate a LangExtract-style interactive HTML report."""
        try:
            # Merge all extraction results for visualization
            all_extractions = []
            for paper_id, paper_data in dataset["papers"].items():
                for module_name, module_data in paper_data.get("modules", {}).items():
                    all_extractions.extend(module_data.get("extractions", []))

            # Basic HTML template (simplified visualization)
            html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>MIMIC reproduction dataset - LangExtract report</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 20px; }}
        .header {{ background: #f0f8ff; padding: 20px; border-radius: 5px; }}
        .stats {{ display: flex; gap: 20px; margin: 20px 0; }}
        .stat-card {{ background: #e6f3ff; padding: 15px; border-radius: 5px; }}
        .extraction {{ border: 1px solid #ddd; margin: 10px 0; padding: 15px; border-radius: 5px; }}
        .class-tag {{ background: #007acc; color: white; padding: 3px 8px; border-radius: 3px; font-size: 12px; }}
    </style>
</head>
<body>
    <div class="header">
        <h1>MIMIC reproduction dataset - LangExtract extraction report</h1>
        <p>Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
        <p>Total papers: {dataset['metadata']['total_papers']}</p>
        <p>Extraction method: {dataset['metadata']['extraction_method']}</p>
    </div>

    <div class="stats">
        <div class="stat-card">
            <h3>Extraction statistics</h3>
            <p>Total extractions: {len(all_extractions)}</p>
            <p>Average per paper: {len(all_extractions)/dataset['metadata']['total_papers']:.1f}</p>
        </div>
        <div class="stat-card">
            <h3>Success rate</h3>
            <p>Successfully processed: {len([p for p in dataset['papers'].values() if any(m.get('extraction_count', 0) > 0 for m in p.get('modules', {}).values())])}/{dataset['metadata']['total_papers']}</p>
        </div>
    </div>

    <div class="extractions">
        <h2>Sample extraction results</h2>
"""

            # Add the first 20 extraction results as examples
            for i, ext in enumerate(all_extractions[:20]):
                html_content += f"""
        <div class="extraction">
            <span class="class-tag">{ext.get('extraction_class', 'unknown')}</span>
            <p><strong>Extracted text:</strong> "{ext.get('extraction_text', 'N/A')}"</p>
            <p><strong>Attributes:</strong> {ext.get('attributes', {})}</p>
            <p><strong>Confidence:</strong> {ext.get('confidence', 'N/A')}</p>
        </div>
"""

            html_content += """
    </div>
</body>
</html>
"""

            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(html_content)

            logger.info(f"Interactive report generated: {output_file}")

        except Exception as e:
            logger.error(f"HTML report generation failed: {e}")

    def _generate_individual_html_report(self, individual_dataset: Dict[str, Any], output_file: Path):
        """Generate a LangExtract-style interactive HTML report for a single paper."""
        try:
            # Collect all extraction results from the single paper
            paper_data = individual_dataset["paper"]
            all_extractions = []
            for module_name, module_data in paper_data.get("modules", {}).items():
                for ext in module_data.get("extractions", []):
                    ext["module"] = module_name  # tag the extraction with its module
                    all_extractions.append(ext)

            # Compute statistics
            successful_modules = len([
                module for module in paper_data.get("modules", {}).values()
                if module.get("extraction_count", 0) > 0
            ])
            total_modules = len(paper_data.get("modules", {}))

            # Generate the HTML content
            html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>{individual_dataset['metadata']['paper_id']} - LangExtract extraction report</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 20px; background-color: #f9f9f9; }}
        .header {{ background: #e3f2fd; padding: 20px; border-radius: 8px; margin-bottom: 20px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }}
        .stats {{ display: flex; gap: 20px; margin: 20px 0; }}
        .stat-card {{ background: #ffffff; padding: 15px; border-radius: 8px; flex: 1; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }}
        .extraction {{ border: 1px solid #e0e0e0; margin: 15px 0; padding: 15px; border-radius: 8px; background: white; box-shadow: 0 1px 3px rgba(0,0,0,0.1); }}
        .class-tag {{ background: #1976d2; color: white; padding: 4px 10px; border-radius: 12px; font-size: 12px; margin-right: 10px; }}
        .module-tag {{ background: #388e3c; color: white; padding: 2px 8px; border-radius: 10px; font-size: 11px; margin-left: 10px; }}
        .attributes {{ background: #f5f5f5; padding: 10px; border-radius: 4px; margin-top: 10px; font-size: 13px; }}
        .no-extractions {{ text-align: center; color: #666; padding: 40px; background: #f0f0f0; border-radius: 8px; }}
        h1 {{ color: #1565c0; margin: 0; }}
        h2 {{ color: #424242; }}
        h3 {{ color: #1976d2; margin: 0; }}
        .meta-info {{ color: #666; font-size: 14px; }}
    </style>
</head>
<body>
    <div class="header">
        <h1>MIMIC paper information extraction report</h1>
        <h2>{individual_dataset['metadata']['paper_id']}</h2>
        <div class="meta-info">
            <p><strong>Generated:</strong> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
            <p><strong>Extraction method:</strong> {individual_dataset['metadata']['extraction_method']}</p>
            <p><strong>Model:</strong> {individual_dataset['metadata']['model']}</p>
        </div>
    </div>

    <div class="stats">
        <div class="stat-card">
            <h3>Extraction statistics</h3>
            <p><strong>Total extractions:</strong> {len(all_extractions)}</p>
            <p><strong>Successful modules:</strong> {successful_modules}/{total_modules}</p>
        </div>
        <div class="stat-card">
            <h3>Module breakdown</h3>
"""

            # Add per-module statistics
            for module_name, module_data in paper_data.get("modules", {}).items():
                extraction_count = module_data.get("extraction_count", 0)
                html_content += f"            <p><strong>{module_name}:</strong> {extraction_count} item(s)</p>\n"

            html_content += """
        </div>
    </div>

    <div class="extractions">
        <h2>Detailed extraction results</h2>
"""

            if all_extractions:
                # Show the extraction results grouped by module
                for module_name in ["data", "model", "training", "evaluation", "environment"]:
                    module_extractions = [ext for ext in all_extractions if ext.get("module") == module_name]
                    if module_extractions:
                        html_content += f"""        <h3>{module_name.title()} module ({len(module_extractions)} item(s))</h3>\n"""

                        for ext in module_extractions:
                            confidence_text = f" (confidence: {ext.get('confidence', 'N/A')})" if ext.get('confidence') else ""
                            html_content += f"""
        <div class="extraction">
            <span class="class-tag">{ext.get('extraction_class', 'unknown')}</span>
            <span class="module-tag">{module_name}</span>
            <p><strong>Extracted text:</strong> "{ext.get('extraction_text', 'N/A')}"</p>
"""
                            # Add the attribute information
                            attributes = ext.get('attributes', {})
                            if attributes:
                                html_content += """            <div class="attributes">
                <strong>Attributes:</strong> """
                                for key, value in attributes.items():
                                    html_content += f"<span><strong>{key}:</strong> {value}</span> "
                                html_content += """
            </div>"""

                            # Add the source-position information
                            if ext.get('start_index') is not None and ext.get('end_index') is not None:
                                html_content += f"""            <p class="meta-info">Position: {ext.get('start_index')}-{ext.get('end_index')}{confidence_text}</p>"""

                            html_content += """        </div>
"""
            else:
                html_content += """
        <div class="no-extractions">
            <p>No extractions were found</p>
            <p>Possible reasons: the model could not identify relevant information, or the text does not contain the target information types</p>
        </div>
"""

            html_content += """
    </div>
</body>
</html>
"""

            # Write the HTML file
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(html_content)

            logger.info(f"Individual paper HTML report generated: {output_file}")

        except Exception as e:
            logger.error(f"Individual paper HTML report generation failed: {e}")
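As a usage note, a minimal sketch of how the per-paper JSON written by _save_individual_paper_result could be read back afterwards (the path follows the defaults above; "<paper_id>" is a placeholder for an actual paper subfolder, not a real directory):

import json
from pathlib import Path

# One subfolder per paper under dataset/markdowns, as written by _save_individual_paper_result
result_file = Path("dataset/markdowns") / "<paper_id>" / "mimic_langextract_dataset.json"
data = json.loads(result_file.read_text(encoding="utf-8"))

# "paper" holds the single-paper result produced by extract_paper_modules()
for module_name, module in data["paper"]["modules"].items():
    for ext in module["extractions"]:
        print(module_name, ext["extraction_class"], "->", ext["extraction_text"])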