diff --git a/CLAUDE.md b/CLAUDE.md
index e838baa..80d4a8e 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -14,6 +14,7 @@ MedResearcher 是一个给予用户输入的自动实验平台,其会给予用
 ## 各个模块的主文件
 1. 论文爬取主文件: papers_crawler.py
 2. pdf解析主文件: pdf_parser.py
-3. 实验运行主文件: experiment_runner.py
+3. 信息抽取主文件: info_extractor.py
+4. 实验运行主文件: experiment_runner.py
 
 ## 文件结构
@@ -23,6 +24,7 @@ MedResearcher 是一个给予用户输入的自动实验平台,其会给予用
 │   └── mimic.csv            # 存放所有需要处理的与mimic相关论文的基础信息
 ├── papers_crawler.py        # 论文爬取主文件
 ├── pdf_parser.py            # pdf解析主文件
+├── info_extractor.py        # 信息抽取主文件
 ├── experiment_runner.py     # 实验运行主文件
 ├── src/                     # 源代码目录
 │   └── utils/               # 工具函数目录
diff --git a/info_extractor.py b/info_extractor.py
new file mode 100644
index 0000000..c74ac72
--- /dev/null
+++ b/info_extractor.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python3
+"""
+基于LangExtract的MIMIC论文信息提取器
+从医学论文中提取结构化的复现任务信息
+
+作者:MedResearcher项目
+创建时间:2025-01-25
+"""
+
+import argparse
+import logging
+
+from src.extractor import MIMICLangExtractBuilder
+
+# 配置日志
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+
+
+def setup_args():
+    """设置命令行参数解析
+
+    Returns:
+        argparse.Namespace: 解析后的命令行参数
+    """
+    parser = argparse.ArgumentParser(
+        description='MIMIC论文信息提取工具 - 基于LangExtract从医学论文中提取结构化复现信息',
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog='''
+使用示例:
+  %(prog)s                                      # 使用默认参数
+  %(prog)s --papers_dir dataset/markdowns       # 指定论文目录
+  %(prog)s --output_file results/dataset.json   # 指定输出文件
+  %(prog)s --test_mode --max_papers 5           # 测试模式,只处理5篇论文
+        '''
+    )
+
+    parser.add_argument(
+        '--papers_dir',
+        type=str,
+        default='dataset/markdowns',
+        help='markdown论文文件目录 (默认: dataset/markdowns)'
+    )
+
+    parser.add_argument(
+        '--output_file',
+        type=str,
+        default='dataset/reproduction_tasks/mimic_langextract_dataset.json',
+        help='输出数据集文件路径 (默认: dataset/reproduction_tasks/mimic_langextract_dataset.json)'
+    )
+
+    parser.add_argument(
+        '--test_mode',
+        action='store_true',
+        help='测试模式,只处理少量论文进行验证'
+    )
+
+    parser.add_argument(
+        '--max_papers',
+        type=int,
+        default=None,
+        help='最大处理论文数量,用于测试 (默认: 处理所有论文)'
+    )
+
+    parser.add_argument(
+        '--log_level',
+        type=str,
+        default='INFO',
+        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
+        help='日志级别 (默认: INFO)'
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    """主函数 - 执行MIMIC论文信息提取任务"""
+    try:
+        # 解析命令行参数
+        args = setup_args()
+
+        # 设置日志级别
+        logging.getLogger().setLevel(getattr(logging, args.log_level))
+
+        # 初始化信息提取器
+        builder = MIMICLangExtractBuilder()
+
+        print(f"=== MIMIC论文信息提取工具启动 ===")
+        print(f"论文目录: {args.papers_dir}")
+        print(f"输出文件: {args.output_file}")
+        print(f"测试模式: {'是' if args.test_mode else '否'}")
+        if args.max_papers:
+            print(f"最大论文数: {args.max_papers}")
+        print(f"日志级别: {args.log_level}")
+        print(f"========================")
+
+        # 构建复现数据集
+        print("\n开始构建MIMIC复现数据集...")
+        dataset = builder.build_reproduction_dataset(
+            papers_dir=args.papers_dir,
+            output_file=args.output_file,
+            max_papers=args.max_papers if args.test_mode or args.max_papers else None
+        )
+
+        # 统计结果
+        total_papers = dataset['metadata']['total_papers']
+        successful_extractions = sum(
+            1 for paper in dataset['papers'].values()
+            if any(module.get('extraction_count', 0) > 0
+                   for module in paper.get('modules', {}).values())
+        )
+
+        print(f"\n=== 构建完成 ===")
+        print(f"总论文数: {total_papers}")
+        print(f"成功提取: {successful_extractions}/{total_papers}")
+        print(f"成功率: {successful_extractions/total_papers*100:.1f}%")
+        print(f"结果保存至: {args.output_file}")
+        print(f"交互式报告: {args.output_file.replace('.json', '.html')}")
+        print(f"===============")
+
+        return 0
+
+    except
FileNotFoundError as e: + print(f"错误: 找不到指定的文件或目录 - {e}") + return 1 + except ValueError as e: + print(f"错误: 参数值无效 - {e}") + return 1 + except Exception as e: + print(f"错误: 程序执行异常 - {e}") + logging.exception("详细错误信息:") + return 1 + + +if __name__ == "__main__": + exit_code = main() + exit(exit_code) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4937299..e458a57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,8 @@ readme = "README.md" requires-python = ">=3.13" dependencies = [ "agno>=1.7.12", + "httpx[socks]>=0.28.1", + "langextract>=1.0.8", "ollama>=0.5.3", "openai>=1.101.0", "pydantic", diff --git a/src/extractor.py b/src/extractor.py new file mode 100644 index 0000000..8932151 --- /dev/null +++ b/src/extractor.py @@ -0,0 +1,809 @@ +#!/usr/bin/env python3 +""" +基于LangExtract的MIMIC论文信息提取器 - 核心实现 +从医学论文中提取结构化的复现任务信息 + +作者:MedResearcher项目 +创建时间:2025-01-25 +""" + +import langextract as lx +import textwrap +from pathlib import Path +import json +from datetime import datetime +from typing import List, Dict, Any, Optional +import logging + +# 配置日志 +logger = logging.getLogger(__name__) + + +class MIMICLangExtractBuilder: + """基于LangExtract的MIMIC论文信息提取器""" + + def __init__(self): + """初始化提取器,配置vllm API服务""" + try: + # 配置LangExtract使用vllm API(通过OpenAI兼容接口) + import os + os.environ["LANGEXTRACT_API_KEY"] = "dummy" + + # 创建ModelConfig,强制使用OpenAI提供者访问vllm端点 + self.model_config = lx.factory.ModelConfig( + model_id="gpt-oss-20b", # 使用vllm中实际部署的模型名称 + provider="OpenAILanguageModel", # 强制指定OpenAI提供者 + provider_kwargs={ + "base_url": "http://100.82.33.121:11001/v1", # vllm API端点 + "api_key": "dummy", + "model_id": "gpt-oss-20b" # 确保使用正确的模型ID + } + ) + + # LangExtract通用配置参数 + self.extract_config = { + "config": self.model_config, + "max_workers": 3, # 降低并发,避免过载vllm服务 + "max_char_buffer": 6000, # 适合医学论文的上下文长度 + "extraction_passes": 1, # 单次提取,避免过多API调用 + "temperature": 0.1, # 较低温度确保一致性 + "fence_output": True, # 期望代码围栏格式输出 + "use_schema_constraints": False # vllm可能不支持严格schema + } + + # 加载所有模块的提取配置 + self.module_configs = { + "data": self._load_data_config(), + "model": self._load_model_config(), + "training": self._load_training_config(), + "evaluation": self._load_evaluation_config(), + "environment": self._load_environment_config() + } + + logger.info("MIMICLangExtractBuilder初始化成功") + + except Exception as e: + logger.error(f"初始化失败: {e}") + raise + + def _load_data_config(self) -> Dict[str, Any]: + """加载数据模块的LangExtract配置""" + return { + "prompt": textwrap.dedent(""" + 从医学论文中提取数据处理相关的具体信息。严格按照以下规则: + + 1. dataset_source: 提取明确提到的数据集名称(如"MIMIC-IV", "Stanford EHR") + 2. data_scale: 提取具体的数据规模数字(如"135,483 patients", "2015-2023") + 3. preprocessing_step: 提取数据预处理的具体步骤描述 + 4. feature_type: 提取特征类型和编码方法的描述 + 5. inclusion_criteria: 提取患者纳入标准的确切文本 + 6. exclusion_criteria: 提取患者排除标准的确切文本 + + 使用exact text进行提取,不要释义。为每个提取项提供有意义的属性。 + """), + "examples": [ + lx.data.ExampleData( + text="We analyzed 135,483 ED blood culture orders from Stanford Medicine EHR between 2015-2023. Adult patients (≥18 years) with blood culture collection in the ED were included. Patients with positive blood cultures within 14 days were excluded. 
Features were one-hot encoded for ML compatibility.", + extractions=[ + lx.data.Extraction( + extraction_class="dataset_source", + extraction_text="Stanford Medicine EHR", + attributes={ + "data_type": "electronic health records", + "institution": "Stanford Medicine" + } + ), + lx.data.Extraction( + extraction_class="data_scale", + extraction_text="135,483 ED blood culture orders", + attributes={ + "sample_size": "135,483", + "time_range": "2015-2023", + "data_unit": "blood culture orders" + } + ), + lx.data.Extraction( + extraction_class="inclusion_criteria", + extraction_text="Adult patients (≥18 years) with blood culture collection in the ED", + attributes={ + "age_limit": "≥18 years", + "setting": "Emergency Department", + "requirement": "blood culture collection" + } + ), + lx.data.Extraction( + extraction_class="exclusion_criteria", + extraction_text="Patients with positive blood cultures within 14 days were excluded", + attributes={ + "timeframe": "within 14 days", + "condition": "positive blood cultures" + } + ), + lx.data.Extraction( + extraction_class="feature_type", + extraction_text="Features were one-hot encoded for ML compatibility", + attributes={ + "encoding_method": "one-hot encoding", + "purpose": "ML compatibility" + } + ) + ] + ), + lx.data.ExampleData( + text="This study utilized MIMIC-IV database, including CHARTEVENTS and LABEVENTS tables. We extracted hourly vital signs and laboratory values for ICU patients. Missing values were imputed using forward-fill method. Outliers beyond 3 standard deviations were removed.", + extractions=[ + lx.data.Extraction( + extraction_class="dataset_source", + extraction_text="MIMIC-IV database", + attributes={ + "data_type": "public clinical database", + "tables": "CHARTEVENTS, LABEVENTS" + } + ), + lx.data.Extraction( + extraction_class="preprocessing_step", + extraction_text="Missing values were imputed using forward-fill method", + attributes={ + "method": "forward-fill", + "target": "missing values" + } + ), + lx.data.Extraction( + extraction_class="preprocessing_step", + extraction_text="Outliers beyond 3 standard deviations were removed", + attributes={ + "method": "outlier removal", + "threshold": "3 standard deviations" + } + ) + ] + ) + ] + } + + def _load_model_config(self) -> Dict[str, Any]: + """加载模型模块的LangExtract配置""" + return { + "prompt": textwrap.dedent(""" + 从医学论文中提取机器学习模型的具体信息。严格按照以下规则: + + 1. model_name: 提取明确提到的模型名称(如"XGBoost", "LSTM", "GPT-4") + 2. architecture_detail: 提取架构描述的具体文本 + 3. hyperparameter: 提取超参数设置的具体数值 + 4. feature_processing: 提取特征处理方法的描述 + 5. model_component: 提取模型组件或模块的描述 + + 使用exact text进行提取,不要释义。为每个提取项提供有意义的属性。 + """), + "examples": [ + lx.data.ExampleData( + text="We employed XGBoost classifier with max depth of 4 and 30 boosting iterations. Class weights were used to handle imbalanced data. 
STELLA 1.5B model was used for text embeddings with attention-weighted average pooling.", + extractions=[ + lx.data.Extraction( + extraction_class="model_name", + extraction_text="XGBoost classifier", + attributes={ + "model_type": "gradient boosting", + "task": "classification" + } + ), + lx.data.Extraction( + extraction_class="hyperparameter", + extraction_text="max depth of 4 and 30 boosting iterations", + attributes={ + "max_depth": "4", + "n_estimators": "30", + "parameter_type": "tree_structure" + } + ), + lx.data.Extraction( + extraction_class="model_name", + extraction_text="STELLA 1.5B model", + attributes={ + "model_type": "pretrained language model", + "parameters": "1.5B", + "purpose": "text embeddings" + } + ), + lx.data.Extraction( + extraction_class="feature_processing", + extraction_text="attention-weighted average pooling", + attributes={ + "technique": "pooling", + "method": "attention-weighted" + } + ) + ] + ) + ] + } + + def _load_training_config(self) -> Dict[str, Any]: + """加载训练模块的LangExtract配置""" + return { + "prompt": textwrap.dedent(""" + 从医学论文中提取模型训练相关的具体信息。严格按照以下规则: + + 1. data_split_method: 提取数据分割方法的具体描述 + 2. validation_approach: 提取验证策略的具体描述 + 3. hyperparameter_tuning: 提取超参数调优方法 + 4. stopping_condition: 提取训练停止条件 + 5. optimizer_config: 提取优化器配置信息 + + 使用exact text进行提取,不要释义。为每个提取项提供有意义的属性。 + """), + "examples": [ + lx.data.ExampleData( + text="Data was split temporally: training set (2015-2022), development set (2022-2023) for hyperparameter tuning, and evaluation set (2023+). Grid search was performed on the development set to optimize AUC performance. Early stopping was applied when validation loss did not improve for 10 epochs.", + extractions=[ + lx.data.Extraction( + extraction_class="data_split_method", + extraction_text="Data was split temporally: training set (2015-2022), development set (2022-2023), and evaluation set (2023+)", + attributes={ + "split_type": "temporal", + "train_period": "2015-2022", + "dev_period": "2022-2023", + "eval_period": "2023+" + } + ), + lx.data.Extraction( + extraction_class="hyperparameter_tuning", + extraction_text="Grid search was performed on the development set to optimize AUC performance", + attributes={ + "method": "grid search", + "metric": "AUC", + "dataset": "development set" + } + ), + lx.data.Extraction( + extraction_class="stopping_condition", + extraction_text="Early stopping was applied when validation loss did not improve for 10 epochs", + attributes={ + "method": "early stopping", + "patience": "10 epochs", + "monitor": "validation loss" + } + ) + ] + ) + ] + } + + def _load_evaluation_config(self) -> Dict[str, Any]: + """加载评估模块的LangExtract配置""" + return { + "prompt": textwrap.dedent(""" + 从医学论文中提取模型评估相关的具体信息。严格按照以下规则: + + 1. evaluation_metric: 提取具体的评估指标名称(如"AUC", "F1-score", "sensitivity") + 2. baseline_comparison: 提取基线模型或方法的描述 + 3. performance_result: 提取具体的性能数值结果 + 4. statistical_test: 提取统计检验方法的描述 + 5. experimental_setting: 提取实验设置的具体信息 + + 使用exact text进行提取,不要释义。为每个提取项提供有意义的属性。 + """), + "examples": [ + lx.data.ExampleData( + text="The model achieved ROC-AUC of 0.85 (95% CI: 0.82-0.88) on the test set. We compared against three baselines: expert framework (manual assessment), structured-only model, and LLM-automated framework. 
At 90% sensitivity, our model achieved 45% specificity versus 32% for the baseline.", + extractions=[ + lx.data.Extraction( + extraction_class="evaluation_metric", + extraction_text="ROC-AUC", + attributes={ + "metric_type": "discriminative performance", + "range": "0-1" + } + ), + lx.data.Extraction( + extraction_class="performance_result", + extraction_text="ROC-AUC of 0.85 (95% CI: 0.82-0.88)", + attributes={ + "metric": "ROC-AUC", + "value": "0.85", + "confidence_interval": "0.82-0.88", + "confidence_level": "95%" + } + ), + lx.data.Extraction( + extraction_class="baseline_comparison", + extraction_text="expert framework (manual assessment), structured-only model, and LLM-automated framework", + attributes={ + "baseline_count": "3", + "comparison_type": "multiple baselines" + } + ) + ] + ) + ] + } + + def _load_environment_config(self) -> Dict[str, Any]: + """加载环境模块的LangExtract配置""" + return { + "prompt": textwrap.dedent(""" + 从医学论文中提取实验环境相关的具体信息。严格按照以下规则: + + 1. software_library: 提取具体的软件工具和库名称 + 2. hardware_resource: 提取硬件资源需求的描述 + 3. data_repository: 提取数据存储和访问的具体信息 + 4. code_availability: 提取代码可用性的具体描述 + 5. compliance_requirement: 提取合规性和部署要求 + + 使用exact text进行提取,不要释义。为每个提取项提供有意义的属性。 + """), + "examples": [ + lx.data.ExampleData( + text="We implemented the models using Python 3.8 with scikit-learn 1.0.2 and XGBoost 1.5.0. Training was performed on NVIDIA A100 GPU with 40GB memory. Code is available at GitHub: https://github.com/HealthRex/CDSS. The study was approved by Stanford IRB.", + extractions=[ + lx.data.Extraction( + extraction_class="software_library", + extraction_text="Python 3.8 with scikit-learn 1.0.2 and XGBoost 1.5.0", + attributes={ + "language": "Python", + "version": "3.8", + "libraries": "scikit-learn, XGBoost" + } + ), + lx.data.Extraction( + extraction_class="hardware_resource", + extraction_text="NVIDIA A100 GPU with 40GB memory", + attributes={ + "gpu_type": "NVIDIA A100", + "memory": "40GB", + "resource_type": "GPU" + } + ), + lx.data.Extraction( + extraction_class="code_availability", + extraction_text="Code is available at GitHub: https://github.com/HealthRex/CDSS", + attributes={ + "platform": "GitHub", + "url": "https://github.com/HealthRex/CDSS", + "access_type": "public" + } + ), + lx.data.Extraction( + extraction_class="compliance_requirement", + extraction_text="The study was approved by Stanford IRB", + attributes={ + "approval_type": "IRB", + "institution": "Stanford" + } + ) + ] + ) + ] + } + + def extract_paper_modules(self, paper_content: str, paper_id: str) -> Dict[str, Any]: + """使用LangExtract提取论文的所有模块信息""" + + results = { + "paper_id": paper_id, + "extraction_metadata": { + "timestamp": datetime.now().isoformat(), + "method": "langextract_with_source_grounding", + "model": "gpt-oss-20b" + }, + "modules": {} + } + + # 逐个提取每个模块 + for module_name, config in self.module_configs.items(): + try: + logger.info(f" 提取{module_name}模块...") + + # 使用LangExtract进行结构化提取 + extraction_result = lx.extract( + text_or_documents=paper_content, + prompt_description=config["prompt"], + examples=config["examples"], + **self.extract_config + ) + + # 处理提取结果 - extraction_result是AnnotatedDocument对象 + if extraction_result and hasattr(extraction_result, 'extractions') and extraction_result.extractions: + results["modules"][module_name] = { + "extractions": [ + { + "extraction_class": ext.extraction_class, + "extraction_text": ext.extraction_text, + "start_index": getattr(ext, 'start_index', None), + "end_index": getattr(ext, 'end_index', None), + "attributes": getattr(ext, 
'attributes', {}), + "confidence": getattr(ext, 'confidence', None) + } + for ext in extraction_result.extractions + ], + "extraction_count": len(extraction_result.extractions), + "quality_score": self._calculate_quality_score(extraction_result) + } + else: + results["modules"][module_name] = { + "extractions": [], + "extraction_count": 0, + "quality_score": 0.0, + "error": "No valid extractions found" + } + + except Exception as e: + logger.error(f" {module_name}模块提取失败: {e}") + results["modules"][module_name] = { + "extractions": [], + "extraction_count": 0, + "quality_score": 0.0, + "error": str(e) + } + + return results + + def build_reproduction_dataset(self, papers_dir: str, output_file: str, max_papers: Optional[int] = None) -> Dict[str, Any]: + """构建完整的复现数据集""" + papers = self._load_markdown_papers(papers_dir) + + dataset = { + "metadata": { + "creation_date": datetime.now().isoformat(), + "total_papers": len(papers), + "extraction_method": "langextract_source_grounded", + "api_endpoint": "http://100.82.33.121:11001/v1", + "model": "gpt-oss-20b", + "langextract_version": getattr(lx, '__version__', 'unknown') + }, + "papers": {} + } + + # 如果指定了最大处理数量,限制论文数量 + if max_papers and max_papers < len(papers): + papers_items = list(papers.items())[:max_papers] + papers = dict(papers_items) + dataset["metadata"]["total_papers"] = len(papers) + dataset["metadata"]["note"] = f"测试模式: 只处理前{max_papers}篇论文" + logger.info(f"测试模式: 只处理前 {max_papers} 篇论文") + + logger.info(f"开始处理 {len(papers)} 篇论文...") + + for i, (paper_id, content) in enumerate(papers.items()): + logger.info(f"[{i+1}/{len(papers)}] 处理论文: {paper_id}") + + paper_result = self.extract_paper_modules(content, paper_id) + dataset["papers"][paper_id] = paper_result + + # 为每个论文单独保存结果到其子文件夹 + self._save_individual_paper_result(papers_dir, paper_id, paper_result) + + # 定期保存全局进度 + if (i + 1) % 10 == 0: + self._save_progress(dataset, output_file) + + # 保存最终结果 + self._save_dataset(dataset, output_file) + + # 生成交互式HTML报告 + self._generate_html_report(dataset, output_file.replace('.json', '.html')) + + return dataset + + def _load_markdown_papers(self, papers_dir: str) -> Dict[str, str]: + """加载markdown论文文件""" + papers = {} + papers_path = Path(papers_dir) + + if not papers_path.exists(): + raise FileNotFoundError(f"论文目录不存在: {papers_dir}") + + # 修改加载逻辑:从子目录中读取.md文件 + markdown_files = [] + for subdir in papers_path.iterdir(): + if subdir.is_dir(): + md_files = list(subdir.glob("*.md")) + markdown_files.extend(md_files) + + if not markdown_files: + raise ValueError(f"在 {papers_dir} 目录中未找到markdown文件") + + logger.info(f"发现 {len(markdown_files)} 个markdown文件") + + for file_path in markdown_files: + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + papers[file_path.stem] = content + except Exception as e: + logger.warning(f"读取文件 {file_path} 失败: {e}") + + return papers + + def _calculate_quality_score(self, extraction_result) -> float: + """计算提取质量分数""" + if not extraction_result or not hasattr(extraction_result, 'extractions'): + return 0.0 + + if not extraction_result.extractions: + return 0.0 + + # 基于提取数量和属性丰富度计算质量分数 + total_score = 0.0 + for ext in extraction_result.extractions: + score = 0.3 # 基础分数 + + # 有源文本定位加分 + if hasattr(ext, 'start_index') and ext.start_index is not None: + score += 0.2 + + # 属性丰富度加分 + if ext.attributes and len(ext.attributes) > 0: + score += min(0.3, len(ext.attributes) * 0.1) + + # 置信度加分 + if hasattr(ext, 'confidence') and ext.confidence: + score += 0.2 * ext.confidence + + total_score += score + + 
return min(1.0, total_score / len(extraction_result.extractions)) + + def _save_progress(self, dataset: Dict[str, Any], output_file: str): + """保存处理进度""" + try: + progress_file = output_file.replace('.json', '_progress.json') + with open(progress_file, 'w', encoding='utf-8') as f: + json.dump(dataset, f, ensure_ascii=False, indent=2) + logger.info(f"进度已保存至: {progress_file}") + except Exception as e: + logger.error(f"保存进度失败: {e}") + + def _save_individual_paper_result(self, papers_dir: str, paper_id: str, paper_result: Dict[str, Any]): + """为单个论文保存提取结果到其对应的子文件夹""" + try: + # 构建论文子文件夹路径 + paper_subdir = Path(papers_dir) / paper_id + if not paper_subdir.exists(): + logger.warning(f"论文子文件夹不存在: {paper_subdir}") + return + + # 准备单个论文的数据集格式 + individual_dataset = { + "metadata": { + "creation_date": datetime.now().isoformat(), + "total_papers": 1, + "extraction_method": "langextract_source_grounded", + "api_endpoint": "http://100.82.33.121:11001/v1", + "model": "gpt-oss-20b", + "langextract_version": getattr(lx, '__version__', 'unknown'), + "paper_id": paper_id + }, + "paper": paper_result # 注意:这里是单个论文,所以用"paper"而不是"papers" + } + + # 保存JSON文件 + json_file = paper_subdir / "mimic_langextract_dataset.json" + with open(json_file, 'w', encoding='utf-8') as f: + json.dump(individual_dataset, f, ensure_ascii=False, indent=2) + + # 生成HTML报告 + html_file = paper_subdir / "mimic_langextract_dataset.html" + self._generate_individual_html_report(individual_dataset, html_file) + + logger.info(f"已保存论文 {paper_id} 的结果到: {paper_subdir}") + + except Exception as e: + logger.error(f"保存单个论文结果失败 ({paper_id}): {e}") + + def _save_dataset(self, dataset: Dict[str, Any], output_file: str): + """保存最终数据集""" + try: + # 确保输出目录存在 + output_path = Path(output_file) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_file, 'w', encoding='utf-8') as f: + json.dump(dataset, f, ensure_ascii=False, indent=2) + logger.info(f"数据集已保存至: {output_file}") + except Exception as e: + logger.error(f"保存数据集失败: {e}") + raise + + def _generate_html_report(self, dataset: Dict[str, Any], output_file: str): + """生成LangExtract风格的交互式HTML报告""" + try: + # 合并所有提取结果用于可视化 + all_extractions = [] + for paper_id, paper_data in dataset["papers"].items(): + for module_name, module_data in paper_data.get("modules", {}).items(): + all_extractions.extend(module_data.get("extractions", [])) + + # 基础HTML模板(简化版可视化) + html_content = f""" + + +
+<html>
+<head><meta charset="utf-8"><title>MIMIC论文信息提取报告</title></head>
+<body>
+<h1>MIMIC论文信息提取报告</h1>
+<p>生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
+<p>总论文数: {dataset['metadata']['total_papers']}</p>
+<p>提取方法: {dataset['metadata']['extraction_method']}</p>
+<p>总提取项: {len(all_extractions)}</p>
+<p>平均每篇: {len(all_extractions)/dataset['metadata']['total_papers']:.1f}</p>
+<p>处理成功: {len([p for p in dataset['papers'].values() if any(m.get('extraction_count', 0) > 0 for m in p.get('modules', {}).values())])}/{dataset['metadata']['total_papers']}</p>
+<h2>提取明细</h2>
+"""
+
+            # 逐条列出所有提取结果(简化版:仅展示提取文本、属性与置信度)
+            for ext in all_extractions:
+                html_content += f"""
+<div class="extraction">
+<p>提取文本: "{ext.get('extraction_text', 'N/A')}"</p>
+<p>属性: {ext.get('attributes', {})}</p>
+<p>置信度: {ext.get('confidence', 'N/A')}</p>
+</div>
+"""
+
+            html_content += "</body></html>"
+
+            with open(output_file, 'w', encoding='utf-8') as f:
+                f.write(html_content)
+            logger.info(f"交互式HTML报告已生成: {output_file}")
+
+        except Exception as e:
+            logger.error(f"生成HTML报告失败: {e}")
+
+    def _generate_individual_html_report(self, dataset: Dict[str, Any], output_file: Path):
+        """为单个论文生成交互式HTML报告(简化版可视化)"""
+        try:
+            paper_result = dataset["paper"]
+            modules = paper_result.get("modules", {})
+            total_modules = len(modules)
+            successful_modules = sum(
+                1 for m in modules.values() if m.get("extraction_count", 0) > 0
+            )
+            all_extractions = [
+                ext for m in modules.values() for ext in m.get("extractions", [])
+            ]
+
+            # 单篇论文的HTML报告,结构与全局报告一致
+            html_content = f"""
+<html>
+<head><meta charset="utf-8"><title>论文提取报告 - {dataset['metadata']['paper_id']}</title></head>
+<body>
+<h1>论文提取报告: {dataset['metadata']['paper_id']}</h1>
+<p>总提取项: {len(all_extractions)}</p>
+<p>成功模块: {successful_modules}/{total_modules}</p>
+"""
+
+            # 各模块的提取数量
+            for module_name, module_data in modules.items():
+                extraction_count = module_data.get("extraction_count", 0)
+                html_content += f"<p>{module_name}: {extraction_count} 项</p>\n"
+
+            if all_extractions:
+                html_content += "<h2>提取明细</h2>\n"
+                for ext in all_extractions:
+                    html_content += f"""
+<div class="extraction">
+<p>提取文本: "{ext.get('extraction_text', 'N/A')}"</p>
+"""
+                    # 添加属性信息
+                    attributes = ext.get('attributes', {})
+                    if attributes:
+                        html_content += f"<p>属性: {attributes}</p>\n"
+                    html_content += "</div>\n"
+            else:
+                html_content += """
+<p>未找到任何提取结果</p>
+<p>可能的原因:模型无法识别相关信息,或者文本内容不包含目标信息类型</p>
+"""
+
+            html_content += "</body></html>"
+
+            with open(output_file, 'w', encoding='utf-8') as f:
+                f.write(html_content)
+
+        except Exception as e:
+            logger.error(f"生成单篇论文HTML报告失败: {e}")
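
A minimal usage sketch for the extractor added above (hypothetical invocation, not part of the diff itself): the import path, constructor, and build_reproduction_dataset signature are exactly those introduced in src/extractor.py, and the paths mirror the CLI defaults declared in info_extractor.py; max_papers=5 corresponds to the test-mode example in its epilog.

    # Hypothetical example: drive the new builder directly from Python
    from src.extractor import MIMICLangExtractBuilder

    builder = MIMICLangExtractBuilder()
    dataset = builder.build_reproduction_dataset(
        papers_dir="dataset/markdowns",  # same as the --papers_dir default
        output_file="dataset/reproduction_tasks/mimic_langextract_dataset.json",  # same as the --output_file default
        max_papers=5,  # process only a few papers, as with --test_mode --max_papers 5
    )
    print(dataset["metadata"]["total_papers"])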