From c96a9c35d587cf3a98f475e480e0ae9902513217 Mon Sep 17 00:00:00 2001
From: iomgaa
Date: Mon, 26 May 2025 23:09:03 +0800
Subject: [PATCH] Initialize the database
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitignore                                |    4 +-
 preprocessing/README_trex_processor.md    |   97 ++
 preprocessing/trex_to_sentences_simple.py | 1016 +++++++++++++++------
 train_pretrain_accelerate.py              |  372 +++++++-
 4 files changed, 1218 insertions(+), 271 deletions(-)
 create mode 100644 preprocessing/README_trex_processor.md

diff --git a/.gitignore b/.gitignore
index 2e522ae..055631a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,6 @@
 /dataset
 /out
 wandb/
-**/*.log
\ No newline at end of file
+**/*.log
+models/sentence_transformers/
+models/sentence_transformers_cache/
\ No newline at end of file
diff --git a/preprocessing/README_trex_processor.md b/preprocessing/README_trex_processor.md
new file mode 100644
index 0000000..4947e30
--- /dev/null
+++ b/preprocessing/README_trex_processor.md
@@ -0,0 +1,97 @@
+# TREx Dataset Processing Tool Usage Guide
+
+This tool processes the TREx dataset in two steps:
+1. **Sentence extraction**: extract triples from the TREx dataset and convert them into natural-language sentences
+2. **LLM processing**: use a local ollama model (the script is currently configured to use `gemma3:latest`) to correct the sentences and score their importance
+
+## Installing Dependencies
+
+```bash
+pip install agno pydantic requests
+```
+
+Make sure ollama is installed and running, and pull the model configured in the script:
+```bash
+ollama pull gemma3:latest
+```
+
+## Usage
+
+### 1. Full pipeline (both steps in sequence)
+
+```bash
+python trex_to_sentences_simple.py --step all --input_dir dataset/TREx --max_files 2
+```
+
+### 2. Running the steps separately
+
+#### Step 1: sentence extraction only
+```bash
+python trex_to_sentences_simple.py --step extract --input_dir dataset/TREx --sentences_json my_sentences.json --max_files 2
+```
+
+#### Step 2: LLM processing only
+```bash
+python trex_to_sentences_simple.py --step llm --sentences_json my_sentences.json --output_file final_output.txt
+```
+
+## Main Options
+
+- `--step`: which step(s) to run
+  - `extract`: sentence extraction only
+  - `llm`: LLM processing only (default)
+  - `all`: full pipeline
+
+- `--input_dir`: TREx dataset directory (default: `dataset/TREx`)
+- `--sentences_json`: JSON file for the extracted sentences (default: `extracted_sentences.json`)
+- `--output_file`: final output file (default: `trex_sentences_enhanced.txt`)
+- `--max_files`: maximum number of files to process (for testing)
+- `--no_llm`: disable LLM processing
+
+## Output Files
+
+**Note: all output files are saved automatically under the `./output/` directory**
+
+### Step 1 output
+- `output/extracted_sentences.json`: the extracted raw sentences plus metadata
+
+### Step 2 output
+- `output/{output_file}.txt`: corrected sentences as plain text
+- `output/{output_file}.json`: complete processing results (original sentence, corrected sentence, score)
+- `output/{output_file}_sorted_by_importance.txt`: sentences sorted by importance score
+
+### Checkpoint files
+- `output/{output_file}_checkpoint_{count}.json`: checkpoints saved automatically after every batch of 1000 sentences
+
+## Checkpoint Recovery
+
+- Step 2 automatically detects existing checkpoint files (in the `output/` directory)
+- Only sentences that have not been processed yet are handled, avoiding duplicate work
+- If every sentence has already been processed, the final output file is generated directly
+
+## Example Workflow
+
+```bash
+# 1. Extract the sentences first (fast)
+python trex_to_sentences_simple.py --step extract --max_files 5
+
+# 2. Run the LLM processing later (slow, resumable)
+python trex_to_sentences_simple.py --step llm
+
+# If the run is interrupted, running step 2 again resumes from the latest checkpoint
+python trex_to_sentences_simple.py --step llm
+```
+
+## Performance Notes
+
+- **Concurrent processing**: up to 8 concurrent LLM requests
+- **Checkpointing**: automatic saves every 1000 sentences, so interrupted runs can resume
+- **Progress reporting**: detailed progress output and time estimates
+- **Error handling**: if an LLM request fails, the original sentence and a default score are kept
+
+## Caveats
+
+1. Step 1 must be completed before step 2 is run for the first time
+2. Checkpoint files take extra disk space (each one contains all results processed so far)
+3. LLM processing speed depends on model performance and network conditions
+4. Test on a small batch first with the `--max_files` option
\ No newline at end of file
diff --git a/preprocessing/trex_to_sentences_simple.py b/preprocessing/trex_to_sentences_simple.py
index 721b9f0..bbb415c 100644
--- a/preprocessing/trex_to_sentences_simple.py
+++ b/preprocessing/trex_to_sentences_simple.py
@@ -2,19 +2,57 @@
 """
 TREx数据集增强预处理脚本
 使用agno框架和ollama qwen3:4b进行句子后处理和重要性评分
+
+支持两个独立步骤:
+1. 句子提取:从TREx数据集提取句子并保存为JSON
+2. 
LLM处理:读取JSON文件进行LLM后处理和重要性评分 """ import json import os import glob -from typing import List, Dict, Any, Union +from typing import List, Dict, Any, Union, Set import re import asyncio import time +import logging +from datetime import datetime +import subprocess +import requests from pydantic import BaseModel, Field from agno.agent import Agent from agno.models.ollama import Ollama +# 设置日志系统 +def setup_logging(): + """设置日志系统""" + # 确保logs目录存在 + os.makedirs('logs', exist_ok=True) + + # 创建日志文件名(包含时间戳) + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + log_file = f'logs/trex_processor_{timestamp}.log' + + # 配置日志格式 + log_format = '%(asctime)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s' + + # 配置root logger + logging.basicConfig( + level=logging.INFO, + format=log_format, + handlers=[ + logging.FileHandler(log_file, encoding='utf-8'), + logging.StreamHandler() # 同时输出到控制台 + ] + ) + + # 获取logger + logger = logging.getLogger(__name__) + logger.info(f"日志系统初始化完成,日志文件: {log_file}") + return logger + +# 全局日志对象 +logger = setup_logging() class ProcessedSentence(BaseModel): """处理后的句子结构""" @@ -31,16 +69,53 @@ class ProcessedSentence(BaseModel): class EnhancedTRExProcessor: - def __init__(self, input_dir: str, output_file: str, max_files: int = None, enable_llm_processing: bool = True): + def __init__(self, input_dir: str = None, output_file: str = None, max_files: int = None, + sentences_json: str = None, enable_llm_processing: bool = True): self.input_dir = input_dir - self.output_file = output_file + + # 确保output目录存在 + os.makedirs('output', exist_ok=True) + + # 确保所有输出文件都在output目录中 + if output_file: + if not output_file.startswith('output/'): + self.output_file = os.path.join('output', output_file) + else: + self.output_file = output_file + else: + self.output_file = None + + if sentences_json: + if not sentences_json.startswith('output/'): + self.sentences_json = os.path.join('output', sentences_json) + else: + self.sentences_json = sentences_json + else: + self.sentences_json = "output/extracted_sentences.json" + self.max_files = max_files self.enable_llm_processing = enable_llm_processing - # 初始化agno agent + # LLM处理配置 + self.llm_timeout = 60 # 增加每个请求的超时时间到60秒 + self.max_concurrent = 8 # 进一步降低并发数到4 + self.max_retries = 2 # 减少重试次数避免过长等待 + self.heartbeat_interval = 30 # 缩短心跳检测间隔到30秒 + + # 统计信息 + self.total_requests = 0 + self.successful_requests = 0 + self.failed_requests = 0 + self.timeout_requests = 0 + self.last_successful_time = time.time() + self.last_activity_time = time.time() # 新增:最后活动时间 + + # 初始化agno agent(仅在需要LLM处理时) if self.enable_llm_processing: self.setup_agent() + logger.info(f"处理器初始化完成 - 并发数: {self.max_concurrent}, 超时时间: {self.llm_timeout}秒") + # 扩展的Wikidata属性映射 self.property_mappings = { # 基本关系 @@ -87,10 +162,10 @@ class EnhancedTRExProcessor: try: self.agent = Agent( model=Ollama( - id="qwen3:4b", + id="gemma3:latest", # 使用options设置temperature和其他参数 options={ - "temperature": 0.7, + "temperature": 0.2, "top_p": 0.8, "top_k": 20, "num_ctx": 4096, @@ -98,131 +173,149 @@ class EnhancedTRExProcessor: ), response_model=ProcessedSentence, instructions=[ - "你是一个专业的文本处理助手,负责修正句子中的错误并评估知识的重要性。", + "You are a professional text processing assistant responsible for correcting errors in sentences and evaluating the importance of knowledge.", "", - "### 句子修正规则:", - "1. 移除Wikipedia特有标记:如(disambiguation)、(film)、(band)等括号内容", - "2. 确保句子语法完整:主语+谓语+宾语结构完整,避免悬空的'and is'、'or'等", - "3. 修正明显的语法错误:时态一致、单复数一致、介词使用正确", - "4. 清理乱码和特殊字符:如â、€、™等编码问题", - "5. 确保句子语义通顺:如果原句无法修复,重新组织语言使其通顺", - "6. 
不要添加原文没有的信息,只修正错误", + "### Sentence Correction Rules:", + "1. Remove Wikipedia-specific markers: such as (disambiguation), (film), (band), etc. in parentheses", + "2. Ensure grammatical completeness: complete subject+predicate+object structure, avoid dangling 'and is', 'or', etc.", + "3. Fix obvious grammatical errors: tense consistency, singular/plural consistency, correct preposition usage", + "4. Clean up garbled text and special characters: such as â, €, ™ and other encoding issues", + "5. Ensure semantic fluency: if the original sentence cannot be fixed, reorganize the language to make it coherent", + "6. Do not add information not present in the original text, only correct errors", "", - "### 修正示例:", - "- 错误:'Argument (disambiguation) is related to philosophy, logic, and is an.'", - "- 修正:'Argument is related to philosophy and logic.'", + "### Correction Examples:", + "- Error: 'Argument (disambiguation) is related to philosophy, logic, and is an.'", + "- Corrected: 'Argument is related to philosophy and logic.'", "", - "- 错误:'Beijing is a capital city and are.'", - "- 修正:'Beijing is a capital city.'", + "- Error: 'Beijing is a capital city and are.'", + "- Corrected: 'Beijing is a capital city.'", "", - "重要性评分标准(0.0-10.0,以0.1递进):", + "Importance scoring criteria (0.0-10.0, in increments of 0.1):", "", - "0.0分 - 完全错误或无意义的信息", - "例:'苹果是一种金属'、'太阳从西边升起'、'1+1=3'", + "0.0 points - Completely incorrect or meaningless information", + "Examples: 'Apple is a metal', 'The sun rises from the west', '1+1=3'", "", - "0.5分 - 几乎无价值的信息", - "例:'某个虚构角色的袜子颜色'、'游戏中NPC的对话第三句话'、'某人昨天早餐吃了什么'", + "0.5 points - Almost worthless information", + "Examples: 'Color of a fictional character's socks', 'Third line of dialogue from a game NPC', 'What someone had for breakfast yesterday'", "", - "1.0分 - 极其罕见、无实用价值的知识", - "例:'某小说背景角色宠物名字'、'某部电影片尾字幕第15行内容'、'某网站用户ID为123456的昵称'", + "1.0 points - Extremely rare, non-practical knowledge", + "Examples: 'Pet name of a minor novel character', 'Content of the 15th line in movie end credits', 'Nickname of website user ID 123456'", "", - "1.5分 - 非常小众的细节信息", - "例:'某电影第37分钟路人甲服装'、'某游戏隐藏关卡的背景音乐时长'、'某漫画第200页第3个对话框内容'", + "1.5 points - Very niche detailed information", + "Examples: 'Outfit of a passerby at minute 37 in a movie', 'Duration of background music in a game's hidden level', 'Content of the 3rd dialogue box on page 200 of a manga'", "", - "2.0分 - 小众专业领域的细节", - "例:'稀有矿物在特定温度下颜色变化'、'某种昆虫的第三对触角长度'、'某化学反应的副产物分子式'", + "2.0 points - Details in niche professional fields", + "Examples: 'Color change of rare minerals at specific temperatures', 'Length of an insect's third antenna', 'Molecular formula of chemical reaction byproducts'", "", - "2.5分 - 专业人士才关心的技术细节", - "例:'软件库特定版本发布日期'、'某算法的时间复杂度系数'、'某种材料的热膨胀系数'", + "2.5 points - Technical details only professionals care about", + "Examples: 'Release date of specific software library version', 'Time complexity coefficient of an algorithm', 'Thermal expansion coefficient of a material'", "", - "3.0分 - 特定领域的专业知识", - "例:'编程语言语法特性'、'某种病毒的基因序列'、'古代某朝代的官职制度'", + "3.0 points - Professional knowledge in specific fields", + "Examples: 'Programming language syntax features', 'Gene sequence of a virus', 'Official system of ancient dynasties'", "", - "3.5分 - 有一定价值的专业信息", - "例:'某历史朝代特定制度'、'某种药物的作用机制'、'某技术标准的制定时间'", + "3.5 points - Professional information with some value", + "Examples: 'Specific system of historical dynasty', 'Mechanism of action of a drug', 'Development time of a technical standard'", "", - "4.0分 - 较少人知道但有意义的知识", - 
"例:'某国家独特文化传统'、'某科学家的重要发现'、'某历史事件的详细过程'", + "4.0 points - Meaningful knowledge known by few", + "Examples: 'Unique cultural traditions of a country', 'Important discoveries by a scientist', 'Detailed process of historical events'", "", - "4.5分 - 部分人群感兴趣的知识", - "例:'作家创作背景'、'某艺术流派特点'、'某运动项目规则细节'", + "4.5 points - Knowledge of interest to some groups", + "Examples: 'Author's creative background', 'Characteristics of an art movement', 'Detailed rules of a sport'", "", - "5.0分 - 中等重要性的一般知识", - "例:'城市著名景点'、'某企业发展历史'、'某动物生活习性'", + "5.0 points - General knowledge of moderate importance", + "Examples: 'Famous attractions in cities', 'Development history of a company', 'Living habits of animals'", "", - "5.5分 - 比较有用的常识", - "例:'植物生长环境'、'健康饮食常识'、'基本急救知识'", + "5.5 points - Fairly useful common sense", + "Examples: 'Plant growth environment', 'Healthy eating common sense', 'Basic first aid knowledge'", "", - "6.0分 - 多数受教育人群应该知道的知识", - "例:'莎士比亚代表作品'、'基本几何定理'、'世界主要货币'", + "6.0 points - Knowledge most educated people should know", + "Examples: 'Shakespeare's representative works', 'Basic geometric theorems', 'Major world currencies'", "", - "6.5分 - 重要的文化或科学常识", - "例:'DNA基本结构'、'牛顿三大定律'、'世界主要宗教'", + "6.5 points - Important cultural or scientific common sense", + "Examples: 'Basic structure of DNA', 'Newton's three laws', 'Major world religions'", "", - "7.0分 - 重要的基础知识", - "例:'二次世界大战时间'、'人体主要器官功能'、'基本数学运算规则'", + "7.0 points - Important foundational knowledge", + "Examples: 'Time period of World War II', 'Functions of major human organs', 'Basic mathematical operation rules'", "", - "7.5分 - 非常重要的常识", - "例:'光速是宇宙中最快的'、'地球是圆的'、'血液循环基本原理'", + "7.5 points - Very important common sense", + "Examples: 'Light speed is the fastest in the universe', 'Earth is round', 'Basic principles of blood circulation'", "", - "8.0分 - 基础教育中的核心知识", - "例:'地球绕太阳运行'、'四季形成原理'、'基本语法规则'", + "8.0 points - Core knowledge in basic education", + "Examples: 'Earth orbits the sun', 'Principle of seasonal formation', 'Basic grammar rules'", "", - "8.5分 - 每个人都应该掌握的重要知识", - "例:'水的化学式H2O'、'基本安全常识'、'简单数学计算'", + "8.5 points - Important knowledge everyone should master", + "Examples: 'Chemical formula of water H2O', 'Basic safety common sense', 'Simple mathematical calculations'", "", - "9.0分 - 极其重要的基础概念", - "例:'人类需要氧气生存'、'火是热的'、'基本方向概念'", + "9.0 points - Extremely important basic concepts", + "Examples: 'Humans need oxygen to survive', 'Fire is hot', 'Basic directional concepts'", "", - "9.5分 - 人人必知的核心知识", - "例:'一天有24小时'、'一年有12个月'、'基本数字概念'", + "9.5 points - Core knowledge everyone must know", + "Examples: 'A day has 24 hours', 'A year has 12 months', 'Basic number concepts'", "", - "10.0分 - 最基础、最重要的常识", - "例:'人类需要食物和水生存'、'天空是蓝色的'、'石头比羽毛重'", + "10.0 points - Most basic and important common sense", + "Examples: 'Humans need food and water to survive', 'The sky is blue', 'Stones are heavier than feathers'", "", - "评分时请考虑:", - "1. 知识的普及程度 - 有多少人知道这个知识", - "2. 实用价值 - 这个知识在日常生活中有多大用处", - "3. 教育重要性 - 这个知识在教育体系中的地位", - "4. 文化意义 - 这个知识对理解世界的重要性", + "When scoring, please consider:", + "1. Popularity of knowledge - How many people know this knowledge", + "2. Practical value - How useful this knowledge is in daily life", + "3. Educational importance - The position of this knowledge in the education system", + "4. Cultural significance - The importance of this knowledge for understanding the world", "", - "请直接输出结构化结果,不需要思考过程。" + "Please output structured results directly without showing the thinking process." 
], markdown=False ) - print("LLM处理器初始化成功") + logger.info("LLM处理器初始化成功") except Exception as e: + logger.error(f"LLM处理器初始化失败: {e}") print(f"LLM处理器初始化失败: {e}") print("将使用基础模式(不使用LLM后处理)") self.enable_llm_processing = False async def process_sentence_with_llm(self, sentence: str) -> ProcessedSentence: """使用LLM处理单个句子(保留用于单独调用)""" - try: - prompt = f"请修正以下句子中的错误并评估其重要性:{sentence}" - - # 使用agent.arun进行异步调用 - response = await self.agent.arun(prompt) - - # 根据agno文档,response应该直接是ProcessedSentence类型 - if isinstance(response, ProcessedSentence): - return response - else: - message = response.messages[-1].content - message = message.replace("```json", "").replace("```", "") - message = json.loads(message) - return ProcessedSentence( - corrected_sentence=message['corrected_sentence'], - importance_score=message['importance_score'] + for attempt in range(self.max_retries): + try: + prompt = f"Please correct the errors in the following sentence and evaluate its importance: {sentence}" + + # 使用asyncio.wait_for添加超时机制 + response = await asyncio.wait_for( + self.agent.arun(prompt), + timeout=self.llm_timeout ) - except Exception as e: - print(f"LLM处理句子时出错: {e}") - # 出错时返回原句子和中等评分 - return ProcessedSentence( - corrected_sentence=sentence, - importance_score=5.0 - ) + # 根据agno文档,response应该直接是ProcessedSentence类型 + if isinstance(response, ProcessedSentence): + return response + else: + message = response.messages[-1].content + message = message.replace("```json", "").replace("```", "") + message = json.loads(message) + return ProcessedSentence( + corrected_sentence=message['corrected_sentence'], + importance_score=message['importance_score'] + ) + + except asyncio.TimeoutError: + logger.warning(f"LLM请求超时 (尝试 {attempt + 1}/{self.max_retries}): {sentence[:50]}...") + if attempt == self.max_retries - 1: + logger.error(f"LLM请求最终超时,使用默认处理: {sentence[:50]}...") + break + # 等待一段时间后重试 + await asyncio.sleep(2 ** attempt) # 指数退避 + + except Exception as e: + logger.error(f"LLM处理句子时出错 (尝试 {attempt + 1}/{self.max_retries}): {e}") + if attempt == self.max_retries - 1: + break + await asyncio.sleep(1) + + # 所有重试都失败,返回原句子和中等评分 + logger.warning(f"使用默认处理: {sentence[:50]}...") + return ProcessedSentence( + corrected_sentence=sentence, + importance_score=5.0 + ) def clean_text(self, text: str) -> str: """清理文本,处理特殊字符""" @@ -369,146 +462,277 @@ class EnhancedTRExProcessor: async def process_sentence_with_llm_concurrent(self, semaphore: asyncio.Semaphore, sentence: str, index: int, total_sentences: int, start_time: float) -> Dict[str, Any]: """使用信号量控制并发的LLM处理""" async with semaphore: - try: - prompt = f"请修正以下句子中的错误并评估其重要性:{sentence}" - - # 使用agent.arun进行异步调用 - response = await self.agent.arun(prompt) - - # 根据agno文档,response应该直接是ProcessedSentence类型 - if isinstance(response, ProcessedSentence): - result = { - "index": index, - "original_sentence": sentence, - "corrected_sentence": response.corrected_sentence, - "importance_score": response.importance_score - } - else: - message = response.messages[-1].content - message = message.replace("```json", "").replace("```", "") - message = json.loads(message) - # print(message) - result = { - "index": index, - "original_sentence": sentence, - "corrected_sentence": message['corrected_sentence'], - "importance_score": message['importance_score'] - } - - # 打印详细进度信息 - if index % 100 == 0: - current_time = time.time() - elapsed_time = current_time - start_time - avg_time_per_sentence = elapsed_time / (index + 1) if index > 0 else elapsed_time - remaining_sentences = total_sentences - (index + 1) - 
estimated_remaining_time = avg_time_per_sentence * remaining_sentences + self.total_requests += 1 + self.last_activity_time = time.time() # 更新活动时间 + success = False + + for attempt in range(self.max_retries): + try: + prompt = f"Please correct the errors in the following sentence and evaluate its importance: {sentence}" - # 格式化时间显示 - def format_time(seconds): - if seconds < 60: - return f"{seconds:.1f}秒" - elif seconds < 3600: - minutes = seconds / 60 - return f"{minutes:.1f}分钟" - else: - hours = seconds / 3600 - return f"{hours:.1f}小时" + # 使用asyncio.wait_for添加超时机制 + response = await asyncio.wait_for( + self.agent.arun(prompt), + timeout=self.llm_timeout + ) - print(f"已完成第 {index + 1} 个句子的处理") - print(f" - 剩余句子数: {remaining_sentences}") - print(f" - 平均处理时间: {avg_time_per_sentence:.2f}秒/句") - print(f" - 预估剩余时间: {format_time(estimated_remaining_time)}") - print(f" - 已用时间: {format_time(elapsed_time)}") + # 根据agno文档,response应该直接是ProcessedSentence类型 + if isinstance(response, ProcessedSentence): + result = { + "index": index, + "original_sentence": sentence, + "corrected_sentence": response.corrected_sentence, + "importance_score": response.importance_score + } + else: + message = response.messages[-1].content + message = message.replace("```json", "").replace("```", "") + message = json.loads(message) + result = { + "index": index, + "original_sentence": sentence, + "corrected_sentence": message['corrected_sentence'], + "importance_score": message['importance_score'] + } + + # 成功处理 + self.successful_requests += 1 + self.last_successful_time = time.time() + self.last_activity_time = time.time() # 更新活动时间 + success = True + + # 打印详细进度信息 - 降低频率到每50个 + if index % 50 == 0: + current_time = time.time() + elapsed_time = current_time - start_time + avg_time_per_sentence = elapsed_time / (index + 1) if index > 0 else elapsed_time + remaining_sentences = total_sentences - (index + 1) + estimated_remaining_time = avg_time_per_sentence * remaining_sentences + success_rate = (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0 + + # 格式化时间显示 + def format_time(seconds): + if seconds < 60: + return f"{seconds:.1f}秒" + elif seconds < 3600: + minutes = seconds / 60 + return f"{minutes:.1f}分钟" + else: + hours = seconds / 3600 + return f"{hours:.1f}小时" + + logger.info(f"已完成第 {index + 1} 个句子的处理") + logger.info(f" - 剩余句子数: {remaining_sentences}") + logger.info(f" - 平均处理时间: {avg_time_per_sentence:.2f}秒/句") + logger.info(f" - 预估剩余时间: {format_time(estimated_remaining_time)}") + logger.info(f" - 已用时间: {format_time(elapsed_time)}") + logger.info(f" - 成功率: {success_rate:.1f}% ({self.successful_requests}/{self.total_requests})") + + print(f"已完成第 {index + 1} 个句子的处理") + print(f" - 剩余句子数: {remaining_sentences}") + print(f" - 平均处理时间: {avg_time_per_sentence:.2f}秒/句") + print(f" - 预估剩余时间: {format_time(estimated_remaining_time)}") + print(f" - 已用时间: {format_time(elapsed_time)}") + print(f" - 成功率: {success_rate:.1f}% ({self.successful_requests}/{self.total_requests})") + + return result + + except asyncio.TimeoutError: + self.timeout_requests += 1 + self.last_activity_time = time.time() # 更新活动时间 + logger.warning(f"第 {index} 个句子处理超时 (尝试 {attempt + 1}/{self.max_retries}): {sentence[:50]}...") + if attempt == self.max_retries - 1: + logger.error(f"第 {index} 个句子最终超时,使用默认处理") + break + # 指数退避 + await asyncio.sleep(2 ** attempt) + + except Exception as e: + self.last_activity_time = time.time() # 更新活动时间 + logger.error(f"处理第 {index} 个句子时出错 (尝试 {attempt + 1}/{self.max_retries}): {e}") + if attempt == 
self.max_retries - 1: + break + await asyncio.sleep(1) + + # 所有重试都失败,使用默认处理 + if not success: + self.failed_requests += 1 + logger.warning(f"第 {index} 个句子使用默认处理: {sentence[:50]}...") + + return { + "index": index, + "original_sentence": sentence, + "corrected_sentence": sentence, + "importance_score": 5.0 + } + + async def heartbeat_monitor(self, total_sentences: int): + """心跳监控,检测是否有长时间无响应""" + consecutive_warnings = 0 + + while True: + await asyncio.sleep(self.heartbeat_interval) + + current_time = time.time() + time_since_last_success = current_time - self.last_successful_time + time_since_last_activity = current_time - self.last_activity_time + + # 检查最后成功时间 + if time_since_last_success > self.heartbeat_interval: + consecutive_warnings += 1 + logger.warning(f"⚠️ 心跳检测 #{consecutive_warnings}:已有 {time_since_last_success:.1f} 秒没有成功的LLM响应") + print(f"⚠️ 心跳检测 #{consecutive_warnings}:已有 {time_since_last_success:.1f} 秒没有成功的LLM响应") - return result + # 打印当前统计信息 + if self.total_requests > 0: + success_rate = self.successful_requests / self.total_requests * 100 + logger.warning(f" 当前统计:总请求 {self.total_requests},成功 {self.successful_requests} ({success_rate:.1f}%),超时 {self.timeout_requests}") + print(f" 当前统计:总请求 {self.total_requests},成功 {self.successful_requests} ({success_rate:.1f}%),超时 {self.timeout_requests}") + + if time_since_last_success > self.heartbeat_interval * 3: + logger.error(f"❌ 严重警告:LLM可能已卡死,超过 {time_since_last_success:.1f} 秒无成功响应!") + print(f"❌ 严重警告:LLM可能已卡死,超过 {time_since_last_success:.1f} 秒无成功响应!") + print(f" 建议:检查ollama服务状态,或考虑重启程序") - except Exception as e: - print(f"处理第 {index} 个句子时出错: {e}") - # 出错时返回原句子和中等评分 - return { - "index": index, - "original_sentence": sentence, - "corrected_sentence": sentence, - "importance_score": 5.0 - } + # 检查ollama服务状态 + if not self.check_ollama_status(): + logger.critical("💀 Ollama服务异常,这可能是卡死的原因!") + print("💀 Ollama服务异常,这可能是卡死的原因!") + + if consecutive_warnings >= 5: + logger.critical(f"💀 致命错误:连续 {consecutive_warnings} 次心跳警告,可能需要人工干预") + print(f"💀 致命错误:连续 {consecutive_warnings} 次心跳警告,可能需要人工干预") + else: + if consecutive_warnings > 0: + logger.info(f"✅ 心跳恢复正常:最后成功时间 {time_since_last_success:.1f} 秒前") + print(f"✅ 心跳恢复正常:最后成功时间 {time_since_last_success:.1f} 秒前") + consecutive_warnings = 0 + logger.debug(f"💓 心跳正常:最后成功时间 {time_since_last_success:.1f} 秒前") async def process_sentences_with_llm(self, sentences: List[str]) -> List[Dict[str, Any]]: """批量并发处理句子,每2000条保存一次检查点""" - print(f"开始使用LLM并发处理 {len(sentences)} 个句子(最大并发数:54)...") + logger.info(f"开始使用LLM并发处理 {len(sentences)} 个句子(最大并发数:{self.max_concurrent})...") + print(f"开始使用LLM并发处理 {len(sentences)} 个句子(最大并发数:{self.max_concurrent})...") # 记录开始时间 start_time = time.time() total_sentences = len(sentences) - # 分批处理,每批2000个句子 - batch_size = 2000 + # 分批处理,每批1000个句子(减少批次大小) + batch_size = 1000 all_processed_sentences = [] - for batch_start in range(0, total_sentences, batch_size): - batch_end = min(batch_start + batch_size, total_sentences) - batch_sentences = sentences[batch_start:batch_end] - - print(f"\n=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===") - - # 创建信号量限制并发数 - semaphore = asyncio.Semaphore(54) - - # 创建当前批次的任务 - tasks = [] - for i, sentence in enumerate(batch_sentences): - global_index = batch_start + i - task = self.process_sentence_with_llm_concurrent(semaphore, sentence, global_index, total_sentences, start_time) - tasks.append(task) - - # 并发执行当前批次的任务 - print(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...") - batch_results = 
await asyncio.gather(*tasks, return_exceptions=True) - - # 处理当前批次的结果,过滤异常 - batch_processed_sentences = [] - batch_error_count = 0 - - for result in batch_results: - if isinstance(result, Exception): - print(f"任务执行异常: {result}") - batch_error_count += 1 - elif isinstance(result, dict): - batch_processed_sentences.append(result) - else: - batch_error_count += 1 - - # 按原始顺序排序(因为并发执行可能改变顺序) - batch_processed_sentences.sort(key=lambda x: x['index']) - - # 移除index字段 - for item in batch_processed_sentences: - del item['index'] - - # 添加到总结果中 - all_processed_sentences.extend(batch_processed_sentences) - - # 保存检查点 - checkpoint_filename = self.save_checkpoint(all_processed_sentences, batch_end) - - # 打印当前批次统计信息 - elapsed_time = time.time() - start_time - completed_sentences = len(all_processed_sentences) - - print(f"第 {batch_start//batch_size + 1} 批处理完成!") - print(f" - 当前批次:成功 {len(batch_processed_sentences)},失败 {batch_error_count}") - print(f" - 总体进度:{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)") - print(f" - 已用时间:{elapsed_time/60:.1f}分钟") - print(f" - 平均速度:{completed_sentences/elapsed_time:.2f}句/秒") - print(f" - 检查点已保存:{checkpoint_filename}") - - if batch_end < total_sentences: - remaining_sentences = total_sentences - completed_sentences - avg_time_per_sentence = elapsed_time / completed_sentences - estimated_remaining_time = avg_time_per_sentence * remaining_sentences - print(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟") + # 启动心跳监控 + heartbeat_task = asyncio.create_task(self.heartbeat_monitor(total_sentences)) + + try: + for batch_start in range(0, total_sentences, batch_size): + batch_end = min(batch_start + batch_size, total_sentences) + batch_sentences = sentences[batch_start:batch_end] + + logger.info(f"=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===") + print(f"\n=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===") + + # 创建信号量限制并发数(降低到8) + semaphore = asyncio.Semaphore(self.max_concurrent) + + # 重置批次统计 + batch_start_time = time.time() + self.total_requests = 0 + self.successful_requests = 0 + self.failed_requests = 0 + self.timeout_requests = 0 + + # 创建当前批次的任务 + tasks = [] + for i, sentence in enumerate(batch_sentences): + global_index = batch_start + i + task = self.process_sentence_with_llm_concurrent(semaphore, sentence, global_index, total_sentences, start_time) + tasks.append(task) + + # 并发执行当前批次的任务 + logger.info(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...") + print(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...") + + batch_results = await asyncio.gather(*tasks, return_exceptions=True) + + # 处理当前批次的结果,过滤异常 + batch_processed_sentences = [] + batch_error_count = 0 + + for result in batch_results: + if isinstance(result, Exception): + logger.error(f"任务执行异常: {result}") + print(f"任务执行异常: {result}") + batch_error_count += 1 + elif isinstance(result, dict): + batch_processed_sentences.append(result) + else: + batch_error_count += 1 + + # 按原始顺序排序(因为并发执行可能改变顺序) + batch_processed_sentences.sort(key=lambda x: x['index']) + + # 移除index字段 + for item in batch_processed_sentences: + del item['index'] + + # 添加到总结果中 + all_processed_sentences.extend(batch_processed_sentences) + + # 保存检查点 + checkpoint_filename = self.save_checkpoint(all_processed_sentences, batch_end) + + # 打印当前批次统计信息 + elapsed_time = time.time() - start_time + batch_time = time.time() - batch_start_time + completed_sentences = 
len(all_processed_sentences) + + logger.info(f"第 {batch_start//batch_size + 1} 批处理完成!") + logger.info(f" - 当前批次:成功 {len(batch_processed_sentences)},失败 {batch_error_count}") + logger.info(f" - 批次用时:{batch_time/60:.1f}分钟") + logger.info(f" - LLM统计:成功 {self.successful_requests},失败 {self.failed_requests},超时 {self.timeout_requests}") + logger.info(f" - 总体进度:{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)") + logger.info(f" - 已用时间:{elapsed_time/60:.1f}分钟") + logger.info(f" - 平均速度:{completed_sentences/elapsed_time:.2f}句/秒") + logger.info(f" - 检查点已保存:{checkpoint_filename}") + + print(f"第 {batch_start//batch_size + 1} 批处理完成!") + print(f" - 当前批次:成功 {len(batch_processed_sentences)},失败 {batch_error_count}") + print(f" - 批次用时:{batch_time/60:.1f}分钟") + print(f" - LLM统计:成功 {self.successful_requests},失败 {self.failed_requests},超时 {self.timeout_requests}") + print(f" - 总体进度:{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)") + print(f" - 已用时间:{elapsed_time/60:.1f}分钟") + print(f" - 平均速度:{completed_sentences/elapsed_time:.2f}句/秒") + print(f" - 检查点已保存:{checkpoint_filename}") + + if batch_end < total_sentences: + remaining_sentences = total_sentences - completed_sentences + avg_time_per_sentence = elapsed_time / completed_sentences + estimated_remaining_time = avg_time_per_sentence * remaining_sentences + logger.info(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟") + print(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟") + + # 在批次之间稍作休息,避免过度压力 + if batch_end < total_sentences: + logger.info("批次间休息5秒...") + await asyncio.sleep(5) + + finally: + # 取消心跳监控 + heartbeat_task.cancel() + try: + await heartbeat_task + except asyncio.CancelledError: + pass # 打印最终统计信息 total_time = time.time() - start_time + logger.info(f"=== 全部处理完成!===") + logger.info(f" - 总成功:{len(all_processed_sentences)}") + logger.info(f" - 总用时:{total_time/60:.1f}分钟") + logger.info(f" - 平均处理速度:{len(all_processed_sentences)/total_time:.2f}句/秒") + print(f"\n=== 全部处理完成!===") print(f" - 总成功:{len(all_processed_sentences)}") print(f" - 总用时:{total_time/60:.1f}分钟") @@ -518,9 +742,9 @@ class EnhancedTRExProcessor: def save_checkpoint(self, processed_sentences: List[Dict[str, Any]], current_count: int) -> str: """保存检查点文件""" - # 生成检查点文件名 - base_name = os.path.splitext(self.output_file)[0] - checkpoint_filename = f"{base_name}_checkpoint_{current_count}.json" + # 生成检查点文件名,确保在output目录中 + base_name = os.path.splitext(os.path.basename(self.output_file))[0] + checkpoint_filename = os.path.join('output', f"{base_name}_checkpoint_{current_count}.json") # 保存检查点 with open(checkpoint_filename, 'w', encoding='utf-8') as f: @@ -578,26 +802,29 @@ class EnhancedTRExProcessor: print(f"去重后剩余 {len(unique_sentences)} 个句子") - # 使用LLM处理句子 - if self.enable_llm_processing: - processed_sentences = await self.process_sentences_with_llm(unique_sentences) - else: - # 基础模式:不使用LLM - processed_sentences = [ - { - "original_sentence": sentence, - "corrected_sentence": sentence, - "importance_score": 5.0 - } - for sentence in unique_sentences - ] + # 保存原始句子到JSON文件 + sentences_data = { + "metadata": { + "total_sentences": len(unique_sentences), + "extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + "source_files": len(json_files), + "max_files_limit": self.max_files + }, + "sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences] + } - return processed_sentences + with open(self.sentences_json, 'w', encoding='utf-8') as f: + json.dump(sentences_data, f, 
ensure_ascii=False, indent=2) + + print(f"句子提取完成!已保存到: {self.sentences_json}") + print(f"总计句子数: {len(unique_sentences)}") + + return unique_sentences def save_sentences(self, processed_sentences: List[Dict[str, Any]]): """保存处理后的句子到文件""" # 确保输出目录存在 - os.makedirs(os.path.dirname(self.output_file) if os.path.dirname(self.output_file) else '.', exist_ok=True) + os.makedirs('output', exist_ok=True) # 保存为JSON格式,包含完整信息 json_output_file = self.output_file.replace('.txt', '.json') @@ -637,8 +864,8 @@ class EnhancedTRExProcessor: def find_latest_checkpoint(self) -> Union[tuple, None]: """查找最新的检查点文件""" - base_name = os.path.splitext(self.output_file)[0] - pattern = f"./output/{base_name}_checkpoint_*.json" + base_name = os.path.splitext(os.path.basename(self.output_file))[0] + pattern = os.path.join('output', f"{base_name}_checkpoint_*.json") checkpoint_files = glob.glob(pattern) if not checkpoint_files: @@ -680,54 +907,311 @@ class EnhancedTRExProcessor: print(f"加载检查点文件失败: {e}") return [] + def get_processed_sentences_from_checkpoints(self) -> Set[str]: + """从检查点文件中获取已处理过的句子集合""" + if not self.output_file: + return set() + + processed_sentences = set() + + # 查找所有检查点文件 + base_name = os.path.splitext(os.path.basename(self.output_file))[0] + pattern = os.path.join('output', f"{base_name}_checkpoint_*.json") + checkpoint_files = glob.glob(pattern) + + if not checkpoint_files: + print("未找到检查点文件,将从头开始处理") + return set() + + # 找到最新的检查点文件 + latest_file = None + latest_count = 0 + + for file in checkpoint_files: + try: + match = re.search(r'checkpoint_(\d+)\.json$', file) + if match: + count = int(match.group(1)) + if count > latest_count: + latest_count = count + latest_file = file + except: + continue + + if latest_file: + print(f"找到最新检查点: {latest_file} (包含 {latest_count} 条记录)") + logger.info(f"找到最新检查点: {latest_file} (包含 {latest_count} 条记录)") + try: + with open(latest_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + sentences_data = data.get('sentences', []) + for item in sentences_data: + original_sentence = item.get('original_sentence', '') + if original_sentence: + processed_sentences.add(original_sentence) + + print(f"从检查点加载了 {len(processed_sentences)} 个已处理的句子") + logger.info(f"从检查点加载了 {len(processed_sentences)} 个已处理的句子") + + except Exception as e: + print(f"读取检查点文件失败: {e}") + return set() + + return processed_sentences + + async def process_with_llm(self): + """步骤2:从JSON文件读取句子并进行LLM处理""" + if not self.enable_llm_processing: + print("Error: LLM processing is disabled!") + return + + if not self.output_file: + print("Error: output_file is required for LLM processing!") + return + + print("=== 步骤2:LLM处理 ===") + + # 读取句子JSON文件 + if not os.path.exists(self.sentences_json): + print(f"Error: Sentences file {self.sentences_json} not found!") + print("请先运行步骤1进行句子提取") + return + + print(f"正在读取句子文件: {self.sentences_json}") + + try: + with open(self.sentences_json, 'r', encoding='utf-8') as f: + data = json.load(f) + + all_sentences = [item["sentence"] for item in data.get("sentences", [])] + print(f"从文件中读取了 {len(all_sentences)} 个句子") + + except Exception as e: + print(f"读取句子文件失败: {e}") + return + + # 获取已处理的句子 + processed_sentences_set = self.get_processed_sentences_from_checkpoints() + + # 过滤出未处理的句子 + unprocessed_sentences = [] + for sentence in all_sentences: + if sentence not in processed_sentences_set: + unprocessed_sentences.append(sentence) + + print(f"需要处理的句子数: {len(unprocessed_sentences)} (跳过已处理: {len(processed_sentences_set)})") + logger.info(f"需要处理的句子数: {len(unprocessed_sentences)} (跳过已处理: 
{len(processed_sentences_set)})") + + if not unprocessed_sentences: + print("所有句子都已处理完成!") + + # 如果有检查点,直接从最新检查点生成最终文件 + if processed_sentences_set: + latest_checkpoint = self.find_latest_checkpoint() + if latest_checkpoint: + checkpoint_file, _ = latest_checkpoint + processed_data = self.load_checkpoint(checkpoint_file) + self.save_sentences(processed_data) + print("已从检查点生成最终输出文件") + return + + # 处理未处理的句子 + print("开始LLM处理...") + + # 检查ollama服务状态 + logger.info("检查Ollama服务状态...") + if not self.check_ollama_status(): + logger.error("Ollama服务状态异常,无法继续处理") + print("错误:Ollama服务状态异常,请检查服务是否正常运行") + return + + new_processed_sentences = await self.process_sentences_with_llm(unprocessed_sentences) + + # 如果有之前的处理结果,合并它们 + if processed_sentences_set: + latest_checkpoint = self.find_latest_checkpoint() + if latest_checkpoint: + checkpoint_file, _ = latest_checkpoint + previous_processed = self.load_checkpoint(checkpoint_file) + + # 合并结果 + all_processed_sentences = previous_processed + new_processed_sentences + print(f"合并了之前的 {len(previous_processed)} 条和新处理的 {len(new_processed_sentences)} 条记录") + else: + all_processed_sentences = new_processed_sentences + else: + all_processed_sentences = new_processed_sentences + + # 保存最终结果 + self.save_sentences(all_processed_sentences) + print("LLM处理完成!") + + # ==================== 新增:句子提取功能 ==================== + + def extract_sentences(self): + """步骤1:从TREx数据集提取句子并保存为JSON""" + if not self.input_dir: + print("Error: input_dir is required for sentence extraction!") + return + + print("=== 步骤1:句子提取 ===") + print("开始从TREx数据集提取句子...") + + json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json")) + + if not json_files: + print(f"No JSON files found in {self.input_dir}") + return + + # 排序文件以确保一致的处理顺序 + json_files.sort() + + if self.max_files: + json_files = json_files[:self.max_files] + + print(f"Found {len(json_files)} JSON files to process") + + all_sentences = [] + + for i, file_path in enumerate(json_files): + print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}") + + documents = self.parse_large_json_file(file_path) + print(f" Parsed {len(documents)} documents") + + for doc in documents: + sentences = self.extract_sentences_from_document(doc) + all_sentences.extend(sentences) + + print(f" Generated {len(all_sentences)} total raw sentences so far") + + print(f"总共提取了 {len(all_sentences)} 个原始句子") + + # 去重 + unique_sentences = [] + seen = set() + for sentence in all_sentences: + sentence = sentence.strip() + if sentence and sentence not in seen and len(sentence) > 10: + unique_sentences.append(sentence) + seen.add(sentence) + + print(f"去重后剩余 {len(unique_sentences)} 个句子") + + # 保存原始句子到JSON文件 + sentences_data = { + "metadata": { + "total_sentences": len(unique_sentences), + "extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + "source_files": len(json_files), + "max_files_limit": self.max_files + }, + "sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences] + } + + with open(self.sentences_json, 'w', encoding='utf-8') as f: + json.dump(sentences_data, f, ensure_ascii=False, indent=2) + + print(f"句子提取完成!已保存到: {self.sentences_json}") + print(f"总计句子数: {len(unique_sentences)}") + + return unique_sentences + + def check_ollama_status(self) -> bool: + """检查ollama服务是否正常运行""" + try: + # 检查ollama进程是否运行 + result = subprocess.run(['pgrep', 'ollama'], capture_output=True, text=True) + if result.returncode != 0: + logger.error("Ollama进程未运行") + return False + + # 检查ollama API是否响应 + response = 
requests.get('http://localhost:11434/api/tags', timeout=5) + if response.status_code == 200: + logger.info("Ollama服务状态正常") + return True + else: + logger.error(f"Ollama API响应异常,状态码: {response.status_code}") + return False + + except requests.exceptions.RequestException as e: + logger.error(f"无法连接到Ollama API: {e}") + return False + except Exception as e: + logger.error(f"检查Ollama状态时出错: {e}") + return False + def main(): """主函数""" import argparse parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with LLM processing') + + # 选择运行模式 + parser.add_argument('--step', choices=['extract', 'llm', 'all'], default='llm', + help='运行步骤: extract=仅提取句子, llm=仅LLM处理, all=完整流程') + + # 文件路径参数 parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files') - parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path') + parser.add_argument('--sentences_json', default='extracted_sentences.json', help='JSON file for extracted sentences (will be saved in output/)') + parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path (will be saved in output/)') + + # 处理参数 parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)') parser.add_argument('--no_llm', action='store_true', help='Disable LLM processing (basic mode)') - parser.add_argument('--resume', action='store_true', help='Resume from latest checkpoint if available') args = parser.parse_args() - if not os.path.exists(args.input_dir): - print(f"Error: Input directory {args.input_dir} does not exist!") - return + # 根据步骤验证参数 + if args.step in ['extract', 'all']: + if not os.path.exists(args.input_dir): + print(f"Error: Input directory {args.input_dir} does not exist!") + return + if args.step in ['llm', 'all']: + if args.no_llm: + print("Error: Cannot run LLM step with --no_llm flag!") + return + + # 创建处理器 processor = EnhancedTRExProcessor( - args.input_dir, - args.output_file, - args.max_files, + input_dir=args.input_dir, + sentences_json=args.sentences_json, + output_file=args.output_file, + max_files=args.max_files, enable_llm_processing=not args.no_llm ) - # 检查是否要从检查点恢复 - if args.resume: - checkpoint_result = processor.find_latest_checkpoint() - if checkpoint_result: - latest_checkpoint, latest_count = checkpoint_result - print(f"发现检查点文件: {latest_checkpoint} (包含 {latest_count} 条记录)") - confirm = input("是否从检查点恢复?(y/n): ").lower().strip() - if confirm == 'y': - processed_sentences = processor.load_checkpoint(latest_checkpoint) - if processed_sentences: - print(f"成功加载 {len(processed_sentences)} 条已处理的句子") - processor.save_sentences(processed_sentences) - print("从检查点恢复完成!") - return - else: - print("检查点文件加载失败,将重新开始处理") - else: - print("不从检查点恢复,将重新开始处理") - else: - print("未找到检查点文件,将重新开始处理") - - # 运行异步处理 - asyncio.run(processor.run()) + # 根据选择的步骤运行 + if args.step == 'extract': + print("=== 运行模式:仅句子提取 ===") + processor.extract_sentences() + + elif args.step == 'llm': + print("=== 运行模式:仅LLM处理 ===") + asyncio.run(processor.process_with_llm()) + + elif args.step == 'all': + print("=== 运行模式:完整流程 ===") + + # 步骤1:提取句子 + print("\n--- 开始步骤1:句子提取 ---") + sentences = processor.extract_sentences() + + if not sentences: + print("句子提取失败,退出") + return + + if args.no_llm: + print("LLM处理已禁用,流程结束") + return + + # 步骤2:LLM处理 + print("\n--- 开始步骤2:LLM处理 ---") + asyncio.run(processor.process_with_llm()) if __name__ == "__main__": diff --git a/train_pretrain_accelerate.py 
b/train_pretrain_accelerate.py index 6fa19b9..dd35a19 100644 --- a/train_pretrain_accelerate.py +++ b/train_pretrain_accelerate.py @@ -3,6 +3,7 @@ import os os.environ["WANDB_MODE"] = "offline" # 或者使用 "dryrun" import platform import argparse +from tqdm import tqdm import time import math import warnings @@ -18,8 +19,10 @@ from accelerate.utils import set_seed from accelerate.utils import DeepSpeedPlugin from accelerate.utils import DistributedDataParallelKwargs from transformers import AutoTokenizer, get_cosine_schedule_with_warmup +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity -from model.model import MiniMindLM +from model.model import MiniMindLM, RMSNorm from model.LMConfig import LMConfig from model.dataset import PretrainDataset @@ -41,9 +44,40 @@ def get_lr(it, num_iters, learning_rate): return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters)) # 初始化模型函数 -def init_model(lm_config, pretrained_embedding_path=None): +def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None): tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer') model = MiniMindLM(lm_config) + + # 默认模型初始化 + Logger("Performing default model initialization...") + + # 初始化嵌入层权重 + nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02) + + # 初始化输出层权重(如果不共享权重的话) + if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight: + nn.init.normal_(model.output.weight, mean=0.0, std=0.02) + + # 初始化所有线性层 + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + # 使用Xavier/Glorot初始化 + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + # 嵌入层使用正态分布初始化 + nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, RMSNorm): + # RMSNorm的权重初始化为1 + if hasattr(module, 'weight'): + nn.init.ones_(module.weight) + + # 初始化位置编码相关参数 + if hasattr(model.extract_db, 'keys'): + nn.init.normal_(model.extract_db.keys, mean=0.0, std=0.02) + + Logger("Default model initialization completed") # 如果提供了预训练的嵌入权重,加载它们 if pretrained_embedding_path: @@ -51,7 +85,335 @@ def init_model(lm_config, pretrained_embedding_path=None): pretrained_embeddings = torch.load(pretrained_embedding_path) model.tok_embeddings.weight.data.copy_(pretrained_embeddings) model.output.weight.data.copy_(pretrained_embeddings) # 共享权重 + + if database_init_path: + import json + import numpy as np + from sentence_transformers import SentenceTransformer + import os + + Logger(f"Loading database initialization data from {database_init_path}") + + # 1. 加载JSON文件并转换为字典 + with open(database_init_path, 'r', encoding='utf-8') as f: + database_data = json.load(f) + + # 提取sentences列表 + sentences_data = database_data.get('sentences', []) + Logger(f"Loaded {len(sentences_data)} sentences from database") + + # 2. 按照importance_score进行排序(从高到低) + sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True) + Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})") + + # 3. 
下载并初始化本地嵌入模型 + embedding_model_name = "sentence-transformers/all-mpnet-base-v2" # 轻量级但效果好的模型 + embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2" + embedding_cache_dir = "./models/sentence_transformers/cache" + os.makedirs(embedding_cache_dir, exist_ok=True) + + Logger(f"Loading embedding model: {embedding_model_name}") + try: + embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir) + Logger("Embedding model loaded successfully") + except Exception as e: + Logger(f"Failed to load embedding model: {e}") + Logger("Falling back to random embeddings") + embedding_model = None + + # 4. 对每个corrected_sentence进行嵌入和token长度计算 + Logger("Processing sentences for embeddings and token lengths...") + + # 提取所有句子 + sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences] + + # 批量计算token长度 + Logger("Computing token lengths...") + token_lengths = [] + for sentence in sentences: + tokens = tokenizer.encode(sentence, add_special_tokens=False) + token_lengths.append(len(tokens)) + + # 批量计算嵌入 - 大幅提升速度 + Logger("Computing embeddings in batches...") + embeddings_list = [] + batch_size = 256 # 可以根据GPU内存调整 + + if embedding_model is not None: + try: + for i in range(0, len(sentences), batch_size): + batch_sentences = sentences[i:i+batch_size] + batch_embeddings = embedding_model.encode( + batch_sentences, + convert_to_tensor=False, + show_progress_bar=True if i == 0 else False, + batch_size=batch_size + ) + embeddings_list.extend(batch_embeddings) + + if (i + batch_size) % (batch_size * 10) == 0: + Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences") + + Logger("Batch embedding computation completed") + except Exception as e: + Logger(f"Error in batch encoding: {e}") + Logger("Falling back to random embeddings") + embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences] + else: + # 使用随机嵌入 + embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences] + + # 创建处理后的句子列表 + processed_sentences = [] + for i, (sentence_data, embedding, token_length) in enumerate(zip(sorted_sentences, embeddings_list, token_lengths)): + processed_sentences.append({ + 'sentence': sentence_data.get('corrected_sentence', ''), + 'importance_score': sentence_data.get('importance_score', 0.0), + 'token_length': token_length, + 'embedding': embedding, # Convert numpy array to list + 'original_index': i + }) + + # # Create a JSON-serializable version for saving + # json_serializable_sentences = [] + # for sentence in processed_sentences: + # json_sentence = sentence.copy() + # # Convert embedding to list if it's a numpy array + # if hasattr(json_sentence['embedding'], 'tolist'): + # json_sentence['embedding'] = json_sentence['embedding'].tolist() + # json_serializable_sentences.append(json_sentence) + + # json.dump(json_serializable_sentences, open('processed_sentences.json', 'w', encoding='utf-8')) + + # processed_sentences = json.load(open('processed_sentences.json', 'r', encoding='utf-8')) + + # 转换为numpy数组以便后续处理 + embeddings_array = np.array(embeddings_list) + token_lengths_array = np.array(token_lengths) + + Logger(f"Embedding processing completed:") + Logger(f" - Total sentences: {len(processed_sentences)}") + Logger(f" - Embedding shape: {embeddings_array.shape}") + Logger(f" - Average token length: {np.mean(token_lengths_array):.2f}") + Logger(f" - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}") + + # 2. 
聚类处理 - 优化版本 + Logger("Starting optimized clustering process...") + + # 聚类参数 + knowledge_num = args.knowledge_num + knowledge_length = args.knowledge_length + min_tokens = int(0.9 * knowledge_length) + max_tokens = knowledge_length + + # 优化1: 预计算所有嵌入的相似度矩阵(如果数据量不太大) + if len(processed_sentences) <= 10000: # 只有在数据量不太大时才预计算 + Logger("Pre-computing similarity matrix for faster clustering...") + embeddings_matrix = np.array([s['embedding'] for s in processed_sentences]) + similarity_matrix = cosine_similarity(embeddings_matrix) + Logger(f"Similarity matrix computed: {similarity_matrix.shape}") + else: + similarity_matrix = None + embeddings_matrix = np.array([s['embedding'] for s in processed_sentences]) + + clustered_rows = [] + remaining_indices = list(range(len(processed_sentences))) # 使用索引而不是对象 + + Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens") + + # 选择聚类算法 + if args.fast_clustering and len(processed_sentences) > 5000: + Logger("Using ultra-fast approximate clustering algorithm...") + + # 超快速聚类:随机采样 + 批量处理 + import random + random.seed(42) # 确保可重现性 + + # 按重要性分层采样 + high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7] + medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7] + low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3] + + Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}") + + for cluster_idx in tqdm(range(knowledge_num)): + # 分层选择种子:优先选择高重要性句子 + if high_importance: + seed_pool = high_importance + elif medium_importance: + seed_pool = medium_importance + else: + seed_pool = low_importance if low_importance else list(range(len(processed_sentences))) + + if not seed_pool: + break + + # 随机选择种子(在同一重要性层级内) + seed_global_idx = random.choice(seed_pool) + seed_sentence = processed_sentences[seed_global_idx] + + # 从所有池中移除种子 + for pool in [high_importance, medium_importance, low_importance]: + if seed_global_idx in pool: + pool.remove(seed_global_idx) + + current_cluster_indices = [seed_global_idx] + current_tokens = seed_sentence['token_length'] + + if current_tokens < max_tokens: + # 快速选择:只从附近的句子中随机选择 + all_remaining = high_importance + medium_importance + low_importance + if all_remaining: + # 随机采样候选句子(而不是计算所有相似度) + sample_size = min(100, len(all_remaining)) + candidates = random.sample(all_remaining, sample_size) + + # 简单按token长度和重要性选择 + for candidate_idx in candidates: + candidate = processed_sentences[candidate_idx] + candidate_tokens = candidate['token_length'] + + if current_tokens + candidate_tokens + 1 <= max_tokens: + current_cluster_indices.append(candidate_idx) + current_tokens += candidate_tokens + 1 + + # 从池中移除 + for pool in [high_importance, medium_importance, low_importance]: + if candidate_idx in pool: + pool.remove(candidate_idx) + break + + if current_tokens >= min_tokens: + break + + # 生成聚类文本 + cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices] + cluster_text = '\n'.join(cluster_sentences) + + # 转换为tokens + cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False) + if len(cluster_tokens) > knowledge_length: + cluster_tokens = cluster_tokens[:knowledge_length] + else: + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 + cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens))) + + 
clustered_rows.append(cluster_tokens) + + if (cluster_idx + 1) % 1000 == 0: + total_remaining = len(high_importance) + len(medium_importance) + len(low_importance) + Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining") + + else: + # 原始优化算法(适用于中等规模数据集) + # 优化2: 批量处理和更高效的数据结构 + for cluster_idx in tqdm(range(knowledge_num)): + if not remaining_indices: + Logger(f"No more sentences available. Created {cluster_idx} clusters.") + break + + # 2.1 选择importance_score最高的句子作为种子 + remaining_sentences_subset = [processed_sentences[i] for i in remaining_indices] + seed_idx_in_subset = max(range(len(remaining_sentences_subset)), + key=lambda i: remaining_sentences_subset[i]['importance_score']) + seed_global_idx = remaining_indices[seed_idx_in_subset] + seed_sentence = processed_sentences[seed_global_idx] + + # 从剩余索引中移除种子 + remaining_indices.remove(seed_global_idx) + + # 当前聚类 + current_cluster_indices = [seed_global_idx] + current_tokens = seed_sentence['token_length'] + + if current_tokens >= max_tokens: + # 如果种子句子已经超过最大token数,直接作为一个聚类 + cluster_text = seed_sentence['sentence'] + else: + # 2.2 优化的相似度计算和选择 + if remaining_indices: + if similarity_matrix is not None: + # 使用预计算的相似度矩阵 + similarities = similarity_matrix[seed_global_idx][remaining_indices] + else: + # 动态计算相似度(批量) + seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx+1] + remaining_embeddings = embeddings_matrix[remaining_indices] + similarities = cosine_similarity(seed_embedding, remaining_embeddings)[0] + + # 创建(相似度, 原始索引, 在remaining_indices中的位置)的元组列表 + similarity_tuples = [(similarities[i], remaining_indices[i], i) + for i in range(len(remaining_indices))] + + # 按相似度排序(降序) + similarity_tuples.sort(key=lambda x: x[0], reverse=True) + + # 优化3: 贪心选择,但限制搜索范围以提高速度 + max_candidates = min(len(similarity_tuples), 500) # 只考虑前500个最相似的句子 + + selected_indices_in_remaining = [] + for sim_score, global_idx, pos_in_remaining in similarity_tuples[:max_candidates]: + candidate = processed_sentences[global_idx] + candidate_tokens = candidate['token_length'] + + if current_tokens + candidate_tokens + 1 <= max_tokens: # +1 for newline + current_cluster_indices.append(global_idx) + selected_indices_in_remaining.append(pos_in_remaining) + current_tokens += candidate_tokens + 1 + + if current_tokens >= min_tokens: + break + + # 批量移除选中的句子(从后往前移除以避免索引问题) + for pos in sorted(selected_indices_in_remaining, reverse=True): + remaining_indices.pop(pos) + + # 拼接句子 + cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices] + cluster_text = '\n'.join(cluster_sentences) + + # 将聚类文本转换为token + cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False) + + # 截断或填充到knowledge_length + if len(cluster_tokens) > knowledge_length: + cluster_tokens = cluster_tokens[:knowledge_length] + else: + # 用pad_token_id填充 + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 + cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens))) + + clustered_rows.append(cluster_tokens) + + # 优化4: 减少日志频率 + if (cluster_idx + 1) % 500 == 0: + Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining") + + # 如果聚类数量不足,用随机token填充 + while len(clustered_rows) < knowledge_num: + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 + random_tokens = [pad_token_id] * knowledge_length + clustered_rows.append(random_tokens) + + # 转换为tensor + clustered_tensor = 
torch.tensor(clustered_rows, dtype=torch.long) + + Logger(f"Clustering completed:") + Logger(f" - Created {len(clustered_rows)} clusters") + Logger(f" - Cluster shape: {clustered_tensor.shape}") + Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})") + + # 3. 初始化模型的weight_down_embed + if hasattr(model, 'extract_db') and hasattr(model.extract_db, 'weight_down_embed'): + model.extract_db.weight_down_embed.data.copy_(clustered_tensor) + Logger("Successfully initialized model.extract_db.weight_down_embed with clustered data") + else: + Logger("Warning: Could not find model.extract_db.weight_down_embed to initialize") + # 存储为全局变量作为备选 + globals()['clustered_database'] = clustered_tensor + Logger(f"Database embeddings and sentences stored in model") + Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万') return model, tokenizer @@ -290,7 +652,9 @@ def main(): parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)") parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention") parser.add_argument("--knowledge_num", type=int, default=64*64,help="知识库的数据数目") - parser.add_argument("--knowledge_length", type=int, default=8,help="知识库的句子长度") + parser.add_argument("--knowledge_length", type=int, default=64,help="知识库的句子长度") + parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径") + parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)") args = parser.parse_args() ######################################################### @@ -379,7 +743,7 @@ def main(): ######################################################### # 初始化模型和tokenizer ######################################################### - model, tokenizer = init_model(lm_config, args.pretrained_embedding_path) + model, tokenizer = init_model(lm_config, args.pretrained_embedding_path, args.database_init_path, args) # 将accelerator传递给init_model函数中的Logger调用 Logger(f'模型初始化完成', accelerator)
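For reference, the initialization path above expects `--database_init_path` to point at a JSON file whose `sentences` entries provide `corrected_sentence` and `importance_score`, and it ultimately copies a `(knowledge_num, knowledge_length)` tensor of token ids into `extract_db.weight_down_embed`. The snippet below is only a minimal, self-contained sketch of that packing step under simplifying assumptions: `toy_tokenize` and `pack_sentences` are illustrative names, the hash-based tokenizer stands in for `./model/minimind_tokenizer`, and a plain greedy fill replaces the importance/similarity-driven clustering used in the patch.

```python
import torch

def toy_tokenize(text):
    # Stand-in for tokenizer.encode(text, add_special_tokens=False);
    # a real run would use the MiniMind tokenizer instead.
    return [hash(word) % 30000 for word in text.split()]

def pack_sentences(sentences, knowledge_num, knowledge_length, pad_id=0):
    """Greedily pack scored sentences into fixed-length rows of token ids."""
    # Highest-importance sentences are packed first, echoing the seed
    # selection order used by the clustering code in the patch.
    ordered = sorted(sentences, key=lambda s: s["importance_score"], reverse=True)

    rows, row = [], []
    for item in ordered:
        tokens = toy_tokenize(item["corrected_sentence"])
        if row and len(row) + len(tokens) > knowledge_length:
            rows.append(row)  # current row is full, start a new one
            row = []
            if len(rows) == knowledge_num:
                break
        row.extend(tokens[:knowledge_length - len(row)])
    if row and len(rows) < knowledge_num:
        rows.append(row)

    # Truncate/pad each row to exactly knowledge_length, then pad any missing
    # rows, mirroring the pad_token_id fill in train_pretrain_accelerate.py.
    rows = [r[:knowledge_length] + [pad_id] * (knowledge_length - len(r)) for r in rows]
    while len(rows) < knowledge_num:
        rows.append([pad_id] * knowledge_length)
    return torch.tensor(rows, dtype=torch.long)

if __name__ == "__main__":
    demo = [  # same field names as the entries read from database_init.json
        {"corrected_sentence": "Water boils at 100 degrees Celsius.", "importance_score": 8.5},
        {"corrected_sentence": "Paris is the capital of France.", "importance_score": 7.0},
    ]
    packed = pack_sentences(demo, knowledge_num=4, knowledge_length=16)
    print(packed.shape)  # torch.Size([4, 16])
```

Running the sketch prints `torch.Size([4, 16])`, i.e. the `(knowledge_num, knowledge_length)` shape that the training script logs as the expected shape before copying the data into the model.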