Initialized the database

parent c09cd63794
commit c96a9c35d5

.gitignore (vendored): 2 lines changed
@@ -3,3 +3,5 @@
 /out
 wandb/
 **/*.log
+models/sentence_transformers/
+models/sentence_transformers_cache/
preprocessing/README_trex_processor.md (new file, 97 lines)
# TREx Dataset Processing Tool: Usage Guide

This tool processes the TREx dataset in two steps:

1. **Sentence extraction**: extract triples from the TREx dataset and convert them into natural-language sentences
2. **LLM processing**: use the ollama qwen3:4b model to correct the sentences and score their importance

## Installing dependencies

```bash
pip install agno asyncio pydantic
```

Make sure ollama is installed and running, and pull the qwen3:4b model:

```bash
ollama pull qwen3:4b
```
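To confirm the ollama server is actually reachable before starting a long run, you can query its tags endpoint; this mirrors the health check the processing script itself performs (port 11434 is ollama's default):

```python
import requests

def ollama_is_up(base_url: str = "http://localhost:11434") -> bool:
    """Return True if the local ollama API answers the /api/tags endpoint."""
    try:
        return requests.get(f"{base_url}/api/tags", timeout=5).status_code == 200
    except requests.exceptions.RequestException:
        return False

if __name__ == "__main__":
    print("ollama reachable:", ollama_is_up())
```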
## Usage

### 1. Full pipeline (both steps in sequence)

```bash
python trex_to_sentences_simple.py --step all --input_dir dataset/TREx --max_files 2
```

### 2. Running the steps separately

#### Step 1: sentence extraction only

```bash
python trex_to_sentences_simple.py --step extract --input_dir dataset/TREx --sentences_json my_sentences.json --max_files 2
```

#### Step 2: LLM processing only

```bash
python trex_to_sentences_simple.py --step llm --sentences_json my_sentences.json --output_file final_output.txt
```

## Main arguments

- `--step`: which step to run
  - `extract`: sentence extraction only
  - `llm`: LLM processing only
  - `all`: full pipeline (default)
- `--input_dir`: TREx dataset directory (default: `dataset/TREx`)
- `--sentences_json`: JSON file for the extracted sentences (default: `extracted_sentences.json`)
- `--output_file`: final output file (default: `trex_sentences_enhanced.txt`)
- `--max_files`: maximum number of files to process (useful for testing)
- `--no_llm`: disable LLM processing

The same options can also be set programmatically; see the sketch below.
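The command-line options above map directly onto the constructor of `EnhancedTRExProcessor` in `trex_to_sentences_simple.py`. A minimal programmatic sketch (constructor and method names are taken from the script in this commit; the paths are illustrative, and the import assumes you run from the `preprocessing/` directory):

```python
import asyncio
from trex_to_sentences_simple import EnhancedTRExProcessor

processor = EnhancedTRExProcessor(
    input_dir="dataset/TREx",            # --input_dir
    sentences_json="my_sentences.json",  # --sentences_json (stored under output/)
    output_file="final_output.txt",      # --output_file (stored under output/)
    max_files=2,                         # --max_files
    enable_llm_processing=True,          # set False for the --no_llm behavior
)

processor.extract_sentences()                  # step 1
asyncio.run(processor.process_with_llm())      # step 2
```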
## Output files

**Note: all output files are written automatically to the `./output/` directory.**

### Step 1 output
- `output/extracted_sentences.json`: the raw extracted sentences, with metadata

### Step 2 output
- `output/{output_file}.txt`: the corrected sentences as plain text
- `output/{output_file}.json`: the full processing results (original sentence, corrected sentence, importance score)
- `output/{output_file}_sorted_by_importance.txt`: sentences sorted by importance score

### Checkpoint files
- `output/{output_file}_checkpoint_{count}.json`: a checkpoint saved automatically every 2000 sentences

## Checkpoint recovery

- Step 2 automatically detects existing checkpoint files in the `output/` directory (see the sketch below)
- Only sentences that have not been processed yet are processed again, so no work is repeated
- If every sentence has already been processed, the final output files are generated directly from the checkpoint
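The recovery logic amounts to: find the newest `*_checkpoint_*.json` under `output/`, collect the sentences it already contains, and skip them. A standalone sketch of that scan (file naming and field names follow the script in this commit; error handling is trimmed):

```python
import glob
import json
import os
import re

def already_processed(output_file: str = "trex_sentences_enhanced.txt") -> set:
    """Collect the original sentences recorded in the newest checkpoint, if any."""
    base = os.path.splitext(os.path.basename(output_file))[0]
    checkpoints = glob.glob(os.path.join("output", f"{base}_checkpoint_*.json"))
    if not checkpoints:
        return set()

    # The running count is encoded in the filename, e.g. ..._checkpoint_2000.json
    def count_of(path: str) -> int:
        match = re.search(r"checkpoint_(\d+)\.json$", path)
        return int(match.group(1)) if match else 0

    latest = max(checkpoints, key=count_of)
    with open(latest, "r", encoding="utf-8") as f:
        data = json.load(f)
    return {item.get("original_sentence", "") for item in data.get("sentences", [])}
```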
## Example workflow

```bash
# 1. Extract the sentences first (this part is fast)
python trex_to_sentences_simple.py --step extract --max_files 5

# 2. Run the LLM processing later (slow, but resumable)
python trex_to_sentences_simple.py --step llm

# If the run is interrupted, running step 2 again resumes from the latest checkpoint
python trex_to_sentences_simple.py --step llm
```

## Performance characteristics

- **Concurrent processing**: up to 54 concurrent LLM requests (see the concurrency sketch below)
- **Checkpointing**: automatic saves every 2000 sentences, so interrupted runs can resume
- **Progress reporting**: detailed progress output with time estimates
- **Error handling**: if an LLM request fails, the original sentence is kept with a default score
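Concurrency is bounded with an `asyncio.Semaphore`: one task per sentence, but only a limited number of LLM calls in flight, each wrapped in a timeout and retried with exponential backoff. A stripped-down sketch of that pattern (the `call_llm` coroutine is a placeholder for the real agent call; the limits shown are illustrative):

```python
import asyncio

MAX_CONCURRENT = 8   # illustrative; the script makes this configurable
TIMEOUT_S = 60
MAX_RETRIES = 2

async def call_llm(sentence: str) -> dict:
    # Placeholder for the agent call; replace with the real request.
    await asyncio.sleep(0.1)
    return {"corrected_sentence": sentence, "importance_score": 5.0}

async def process_one(sem: asyncio.Semaphore, sentence: str) -> dict:
    async with sem:
        for attempt in range(MAX_RETRIES):
            try:
                return await asyncio.wait_for(call_llm(sentence), timeout=TIMEOUT_S)
            except asyncio.TimeoutError:
                await asyncio.sleep(2 ** attempt)  # exponential backoff before retrying
    # All retries failed: keep the original sentence with a neutral score
    return {"corrected_sentence": sentence, "importance_score": 5.0}

async def process_all(sentences: list) -> list:
    sem = asyncio.Semaphore(MAX_CONCURRENT)
    return await asyncio.gather(*(process_one(sem, s) for s in sentences))

if __name__ == "__main__":
    print(asyncio.run(process_all(["Beijing is a capital city."])))
```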
## Notes

1. Step 1 must be completed before running step 2 for the first time.
2. Checkpoint files take up extra disk space (each one contains all data processed so far).
3. LLM processing speed depends on model performance and network conditions.
4. Test on a small batch first with the `--max_files` option.
@@ -2,19 +2,57 @@
 """
 TREx数据集增强预处理脚本
 使用agno框架和ollama qwen3:4b进行句子后处理和重要性评分
+
+支持两个独立步骤:
+1. 句子提取:从TREx数据集提取句子并保存为JSON
+2. LLM处理:读取JSON文件进行LLM后处理和重要性评分
 """
 
 import json
 import os
 import glob
-from typing import List, Dict, Any, Union
+from typing import List, Dict, Any, Union, Set
 import re
 import asyncio
 import time
+import logging
+from datetime import datetime
+import subprocess
+import requests
 from pydantic import BaseModel, Field
 from agno.agent import Agent
 from agno.models.ollama import Ollama
 
+
+# 设置日志系统
+def setup_logging():
+    """设置日志系统"""
+    # 确保logs目录存在
+    os.makedirs('logs', exist_ok=True)
+
+    # 创建日志文件名(包含时间戳)
+    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+    log_file = f'logs/trex_processor_{timestamp}.log'
+
+    # 配置日志格式
+    log_format = '%(asctime)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s'
+
+    # 配置root logger
+    logging.basicConfig(
+        level=logging.INFO,
+        format=log_format,
+        handlers=[
+            logging.FileHandler(log_file, encoding='utf-8'),
+            logging.StreamHandler()  # 同时输出到控制台
+        ]
+    )
+
+    # 获取logger
+    logger = logging.getLogger(__name__)
+    logger.info(f"日志系统初始化完成,日志文件: {log_file}")
+    return logger
+
+
+# 全局日志对象
+logger = setup_logging()
+
+
 class ProcessedSentence(BaseModel):
     """处理后的句子结构"""
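For reference, the sentence-extraction step writes its output as a single JSON document whose shape is defined by `extract_sentences()` later in this diff; a minimal reader sketch (field names taken from that code, path is the script's default):

```python
import json

# Shape written by step 1:
# {
#   "metadata": {"total_sentences": ..., "extraction_timestamp": ...,
#                "source_files": ..., "max_files_limit": ...},
#   "sentences": [{"sentence": "...", "processed": false}, ...]
# }

with open("output/extracted_sentences.json", "r", encoding="utf-8") as f:
    data = json.load(f)

sentences = [item["sentence"] for item in data.get("sentences", [])]
print(f"loaded {len(sentences)} sentences")
```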
@@ -31,16 +69,53 @@ class ProcessedSentence(BaseModel):
 
 
 class EnhancedTRExProcessor:
-    def __init__(self, input_dir: str, output_file: str, max_files: int = None, enable_llm_processing: bool = True):
+    def __init__(self, input_dir: str = None, output_file: str = None, max_files: int = None,
+                 sentences_json: str = None, enable_llm_processing: bool = True):
         self.input_dir = input_dir
-        self.output_file = output_file
+
+        # 确保output目录存在
+        os.makedirs('output', exist_ok=True)
+
+        # 确保所有输出文件都在output目录中
+        if output_file:
+            if not output_file.startswith('output/'):
+                self.output_file = os.path.join('output', output_file)
+            else:
+                self.output_file = output_file
+        else:
+            self.output_file = None
+
+        if sentences_json:
+            if not sentences_json.startswith('output/'):
+                self.sentences_json = os.path.join('output', sentences_json)
+            else:
+                self.sentences_json = sentences_json
+        else:
+            self.sentences_json = "output/extracted_sentences.json"
+
         self.max_files = max_files
         self.enable_llm_processing = enable_llm_processing
 
-        # 初始化agno agent
+        # LLM处理配置
+        self.llm_timeout = 60  # 每个请求的超时时间(60秒)
+        self.max_concurrent = 8  # 最大并发数(8)
+        self.max_retries = 2  # 减少重试次数避免过长等待
+        self.heartbeat_interval = 30  # 心跳检测间隔(30秒)
+
+        # 统计信息
+        self.total_requests = 0
+        self.successful_requests = 0
+        self.failed_requests = 0
+        self.timeout_requests = 0
+        self.last_successful_time = time.time()
+        self.last_activity_time = time.time()  # 最后活动时间
+
+        # 初始化agno agent(仅在需要LLM处理时)
         if self.enable_llm_processing:
             self.setup_agent()
 
+        logger.info(f"处理器初始化完成 - 并发数: {self.max_concurrent}, 超时时间: {self.llm_timeout}秒")
+
         # 扩展的Wikidata属性映射
         self.property_mappings = {
             # 基本关系
@@ -87,10 +162,10 @@ class EnhancedTRExProcessor:
         try:
             self.agent = Agent(
                 model=Ollama(
-                    id="qwen3:4b",
+                    id="gemma3:latest",
                     # 使用options设置temperature和其他参数
                     options={
-                        "temperature": 0.7,
+                        "temperature": 0.2,
                         "top_p": 0.8,
                         "top_k": 20,
                         "num_ctx": 4096,
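Because the agent is created with `response_model=ProcessedSentence`, agno is expected to hand back the parsed pydantic object directly; the script still keeps a fallback path that strips markdown fences from the last message and parses it as JSON. A compact sketch of that normalization (it mirrors the handling later in this diff; the `ProcessedSentence` fields here are a simplified stand-in for the model defined in the script, and `response.messages[-1].content` is the raw-message path the script reads):

```python
import json
from pydantic import BaseModel, Field

class ProcessedSentence(BaseModel):
    # Simplified stand-in for the script's response model
    corrected_sentence: str = Field(...)
    importance_score: float = Field(...)

def to_processed_sentence(response, original: str) -> ProcessedSentence:
    """Normalize an agent response into a ProcessedSentence, with a JSON fallback."""
    if isinstance(response, ProcessedSentence):
        return response
    # Fallback: the last message may be a JSON string, possibly wrapped in ```json fences
    raw = response.messages[-1].content
    raw = raw.replace("```json", "").replace("```", "")
    try:
        data = json.loads(raw)
        return ProcessedSentence(
            corrected_sentence=data["corrected_sentence"],
            importance_score=data["importance_score"],
        )
    except (json.JSONDecodeError, KeyError):
        # Last resort: keep the original sentence with a neutral score
        return ProcessedSentence(corrected_sentence=original, importance_score=5.0)
```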
In the hunk below, every instruction string passed to the agent is rewritten from Chinese into English; the English text shown here is the new version and carries the same content as the Chinese strings it replaces.

@@ -98,111 +173,116 @@
                ),
                response_model=ProcessedSentence,
                instructions=[
                    "You are a professional text processing assistant responsible for correcting errors in sentences and evaluating the importance of knowledge.",
                    "",
                    "### Sentence Correction Rules:",
                    "1. Remove Wikipedia-specific markers: such as (disambiguation), (film), (band), etc. in parentheses",
                    "2. Ensure grammatical completeness: complete subject+predicate+object structure, avoid dangling 'and is', 'or', etc.",
                    "3. Fix obvious grammatical errors: tense consistency, singular/plural consistency, correct preposition usage",
                    "4. Clean up garbled text and special characters: such as â, €, ™ and other encoding issues",
                    "5. Ensure semantic fluency: if the original sentence cannot be fixed, reorganize the language to make it coherent",
                    "6. Do not add information not present in the original text, only correct errors",
                    "",
                    "### Correction Examples:",
                    "- Error: 'Argument (disambiguation) is related to philosophy, logic, and is an.'",
                    "- Corrected: 'Argument is related to philosophy and logic.'",
                    "",
                    "- Error: 'Beijing is a capital city and are.'",
                    "- Corrected: 'Beijing is a capital city.'",
                    "",
                    "Importance scoring criteria (0.0-10.0, in increments of 0.1):",
                    "",
                    "0.0 points - Completely incorrect or meaningless information",
                    "Examples: 'Apple is a metal', 'The sun rises from the west', '1+1=3'",
                    "",
                    "0.5 points - Almost worthless information",
                    "Examples: 'Color of a fictional character's socks', 'Third line of dialogue from a game NPC', 'What someone had for breakfast yesterday'",
                    "",
                    "1.0 points - Extremely rare, non-practical knowledge",
                    "Examples: 'Pet name of a minor novel character', 'Content of the 15th line in movie end credits', 'Nickname of website user ID 123456'",
                    "",
                    "1.5 points - Very niche detailed information",
                    "Examples: 'Outfit of a passerby at minute 37 in a movie', 'Duration of background music in a game's hidden level', 'Content of the 3rd dialogue box on page 200 of a manga'",
                    "",
                    "2.0 points - Details in niche professional fields",
                    "Examples: 'Color change of rare minerals at specific temperatures', 'Length of an insect's third antenna', 'Molecular formula of chemical reaction byproducts'",
                    "",
                    "2.5 points - Technical details only professionals care about",
                    "Examples: 'Release date of specific software library version', 'Time complexity coefficient of an algorithm', 'Thermal expansion coefficient of a material'",
                    "",
                    "3.0 points - Professional knowledge in specific fields",
                    "Examples: 'Programming language syntax features', 'Gene sequence of a virus', 'Official system of ancient dynasties'",
                    "",
                    "3.5 points - Professional information with some value",
                    "Examples: 'Specific system of historical dynasty', 'Mechanism of action of a drug', 'Development time of a technical standard'",
                    "",
                    "4.0 points - Meaningful knowledge known by few",
                    "Examples: 'Unique cultural traditions of a country', 'Important discoveries by a scientist', 'Detailed process of historical events'",
                    "",
                    "4.5 points - Knowledge of interest to some groups",
                    "Examples: 'Author's creative background', 'Characteristics of an art movement', 'Detailed rules of a sport'",
                    "",
                    "5.0 points - General knowledge of moderate importance",
                    "Examples: 'Famous attractions in cities', 'Development history of a company', 'Living habits of animals'",
                    "",
                    "5.5 points - Fairly useful common sense",
                    "Examples: 'Plant growth environment', 'Healthy eating common sense', 'Basic first aid knowledge'",
                    "",
                    "6.0 points - Knowledge most educated people should know",
                    "Examples: 'Shakespeare's representative works', 'Basic geometric theorems', 'Major world currencies'",
                    "",
                    "6.5 points - Important cultural or scientific common sense",
                    "Examples: 'Basic structure of DNA', 'Newton's three laws', 'Major world religions'",
                    "",
                    "7.0 points - Important foundational knowledge",
                    "Examples: 'Time period of World War II', 'Functions of major human organs', 'Basic mathematical operation rules'",
                    "",
                    "7.5 points - Very important common sense",
                    "Examples: 'Light speed is the fastest in the universe', 'Earth is round', 'Basic principles of blood circulation'",
                    "",
                    "8.0 points - Core knowledge in basic education",
                    "Examples: 'Earth orbits the sun', 'Principle of seasonal formation', 'Basic grammar rules'",
                    "",
                    "8.5 points - Important knowledge everyone should master",
                    "Examples: 'Chemical formula of water H2O', 'Basic safety common sense', 'Simple mathematical calculations'",
                    "",
                    "9.0 points - Extremely important basic concepts",
                    "Examples: 'Humans need oxygen to survive', 'Fire is hot', 'Basic directional concepts'",
                    "",
                    "9.5 points - Core knowledge everyone must know",
                    "Examples: 'A day has 24 hours', 'A year has 12 months', 'Basic number concepts'",
                    "",
                    "10.0 points - Most basic and important common sense",
                    "Examples: 'Humans need food and water to survive', 'The sky is blue', 'Stones are heavier than feathers'",
                    "",
                    "When scoring, please consider:",
                    "1. Popularity of knowledge - How many people know this knowledge",
                    "2. Practical value - How useful this knowledge is in daily life",
                    "3. Educational importance - The position of this knowledge in the education system",
                    "4. Cultural significance - The importance of this knowledge for understanding the world",
                    "",
                    "Please output structured results directly without showing the thinking process."
                ],
                markdown=False
            )
print("LLM处理器初始化成功")
|
logger.info("LLM处理器初始化成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(f"LLM处理器初始化失败: {e}")
|
||||||
print(f"LLM处理器初始化失败: {e}")
|
print(f"LLM处理器初始化失败: {e}")
|
||||||
print("将使用基础模式(不使用LLM后处理)")
|
print("将使用基础模式(不使用LLM后处理)")
|
||||||
self.enable_llm_processing = False
|
self.enable_llm_processing = False
|
||||||
|
|
||||||
async def process_sentence_with_llm(self, sentence: str) -> ProcessedSentence:
|
async def process_sentence_with_llm(self, sentence: str) -> ProcessedSentence:
|
||||||
"""使用LLM处理单个句子(保留用于单独调用)"""
|
"""使用LLM处理单个句子(保留用于单独调用)"""
|
||||||
|
for attempt in range(self.max_retries):
|
||||||
try:
|
try:
|
||||||
prompt = f"请修正以下句子中的错误并评估其重要性:{sentence}"
|
prompt = f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"
|
||||||
|
|
||||||
# 使用agent.arun进行异步调用
|
# 使用asyncio.wait_for添加超时机制
|
||||||
response = await self.agent.arun(prompt)
|
response = await asyncio.wait_for(
|
||||||
|
self.agent.arun(prompt),
|
||||||
|
timeout=self.llm_timeout
|
||||||
|
)
|
||||||
|
|
||||||
# 根据agno文档,response应该直接是ProcessedSentence类型
|
# 根据agno文档,response应该直接是ProcessedSentence类型
|
||||||
if isinstance(response, ProcessedSentence):
|
if isinstance(response, ProcessedSentence):
|
||||||
@ -216,9 +296,22 @@ class EnhancedTRExProcessor:
|
|||||||
importance_score=message['importance_score']
|
importance_score=message['importance_score']
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(f"LLM请求超时 (尝试 {attempt + 1}/{self.max_retries}): {sentence[:50]}...")
|
||||||
|
if attempt == self.max_retries - 1:
|
||||||
|
logger.error(f"LLM请求最终超时,使用默认处理: {sentence[:50]}...")
|
||||||
|
break
|
||||||
|
# 等待一段时间后重试
|
||||||
|
await asyncio.sleep(2 ** attempt) # 指数退避
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"LLM处理句子时出错: {e}")
|
logger.error(f"LLM处理句子时出错 (尝试 {attempt + 1}/{self.max_retries}): {e}")
|
||||||
# 出错时返回原句子和中等评分
|
if attempt == self.max_retries - 1:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
# 所有重试都失败,返回原句子和中等评分
|
||||||
|
logger.warning(f"使用默认处理: {sentence[:50]}...")
|
||||||
return ProcessedSentence(
|
return ProcessedSentence(
|
||||||
corrected_sentence=sentence,
|
corrected_sentence=sentence,
|
||||||
importance_score=5.0
|
importance_score=5.0
|
||||||
@ -369,11 +462,19 @@ class EnhancedTRExProcessor:
|
|||||||
async def process_sentence_with_llm_concurrent(self, semaphore: asyncio.Semaphore, sentence: str, index: int, total_sentences: int, start_time: float) -> Dict[str, Any]:
|
async def process_sentence_with_llm_concurrent(self, semaphore: asyncio.Semaphore, sentence: str, index: int, total_sentences: int, start_time: float) -> Dict[str, Any]:
|
||||||
"""使用信号量控制并发的LLM处理"""
|
"""使用信号量控制并发的LLM处理"""
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
try:
|
self.total_requests += 1
|
||||||
prompt = f"请修正以下句子中的错误并评估其重要性:{sentence}"
|
self.last_activity_time = time.time() # 更新活动时间
|
||||||
|
success = False
|
||||||
|
|
||||||
# 使用agent.arun进行异步调用
|
for attempt in range(self.max_retries):
|
||||||
response = await self.agent.arun(prompt)
|
try:
|
||||||
|
prompt = f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"
|
||||||
|
|
||||||
|
# 使用asyncio.wait_for添加超时机制
|
||||||
|
response = await asyncio.wait_for(
|
||||||
|
self.agent.arun(prompt),
|
||||||
|
timeout=self.llm_timeout
|
||||||
|
)
|
||||||
|
|
||||||
# 根据agno文档,response应该直接是ProcessedSentence类型
|
# 根据agno文档,response应该直接是ProcessedSentence类型
|
||||||
if isinstance(response, ProcessedSentence):
|
if isinstance(response, ProcessedSentence):
|
||||||
@ -387,7 +488,6 @@ class EnhancedTRExProcessor:
|
|||||||
message = response.messages[-1].content
|
message = response.messages[-1].content
|
||||||
message = message.replace("```json", "").replace("```", "")
|
message = message.replace("```json", "").replace("```", "")
|
||||||
message = json.loads(message)
|
message = json.loads(message)
|
||||||
# print(message)
|
|
||||||
result = {
|
result = {
|
||||||
"index": index,
|
"index": index,
|
||||||
"original_sentence": sentence,
|
"original_sentence": sentence,
|
||||||
@ -395,13 +495,20 @@ class EnhancedTRExProcessor:
|
|||||||
"importance_score": message['importance_score']
|
"importance_score": message['importance_score']
|
||||||
}
|
}
|
||||||
|
|
||||||
# 打印详细进度信息
|
# 成功处理
|
||||||
if index % 100 == 0:
|
self.successful_requests += 1
|
||||||
|
self.last_successful_time = time.time()
|
||||||
|
self.last_activity_time = time.time() # 更新活动时间
|
||||||
|
success = True
|
||||||
|
|
||||||
|
# 打印详细进度信息 - 降低频率到每50个
|
||||||
|
if index % 50 == 0:
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
elapsed_time = current_time - start_time
|
elapsed_time = current_time - start_time
|
||||||
avg_time_per_sentence = elapsed_time / (index + 1) if index > 0 else elapsed_time
|
avg_time_per_sentence = elapsed_time / (index + 1) if index > 0 else elapsed_time
|
||||||
remaining_sentences = total_sentences - (index + 1)
|
remaining_sentences = total_sentences - (index + 1)
|
||||||
estimated_remaining_time = avg_time_per_sentence * remaining_sentences
|
estimated_remaining_time = avg_time_per_sentence * remaining_sentences
|
||||||
|
success_rate = (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0
|
||||||
|
|
||||||
# 格式化时间显示
|
# 格式化时间显示
|
||||||
def format_time(seconds):
|
def format_time(seconds):
|
||||||
@ -414,17 +521,44 @@ class EnhancedTRExProcessor:
|
|||||||
hours = seconds / 3600
|
hours = seconds / 3600
|
||||||
return f"{hours:.1f}小时"
|
return f"{hours:.1f}小时"
|
||||||
|
|
||||||
|
logger.info(f"已完成第 {index + 1} 个句子的处理")
|
||||||
|
logger.info(f" - 剩余句子数: {remaining_sentences}")
|
||||||
|
logger.info(f" - 平均处理时间: {avg_time_per_sentence:.2f}秒/句")
|
||||||
|
logger.info(f" - 预估剩余时间: {format_time(estimated_remaining_time)}")
|
||||||
|
logger.info(f" - 已用时间: {format_time(elapsed_time)}")
|
||||||
|
logger.info(f" - 成功率: {success_rate:.1f}% ({self.successful_requests}/{self.total_requests})")
|
||||||
|
|
||||||
print(f"已完成第 {index + 1} 个句子的处理")
|
print(f"已完成第 {index + 1} 个句子的处理")
|
||||||
print(f" - 剩余句子数: {remaining_sentences}")
|
print(f" - 剩余句子数: {remaining_sentences}")
|
||||||
print(f" - 平均处理时间: {avg_time_per_sentence:.2f}秒/句")
|
print(f" - 平均处理时间: {avg_time_per_sentence:.2f}秒/句")
|
||||||
print(f" - 预估剩余时间: {format_time(estimated_remaining_time)}")
|
print(f" - 预估剩余时间: {format_time(estimated_remaining_time)}")
|
||||||
print(f" - 已用时间: {format_time(elapsed_time)}")
|
print(f" - 已用时间: {format_time(elapsed_time)}")
|
||||||
|
print(f" - 成功率: {success_rate:.1f}% ({self.successful_requests}/{self.total_requests})")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.timeout_requests += 1
|
||||||
|
self.last_activity_time = time.time() # 更新活动时间
|
||||||
|
logger.warning(f"第 {index} 个句子处理超时 (尝试 {attempt + 1}/{self.max_retries}): {sentence[:50]}...")
|
||||||
|
if attempt == self.max_retries - 1:
|
||||||
|
logger.error(f"第 {index} 个句子最终超时,使用默认处理")
|
||||||
|
break
|
||||||
|
# 指数退避
|
||||||
|
await asyncio.sleep(2 ** attempt)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"处理第 {index} 个句子时出错: {e}")
|
self.last_activity_time = time.time() # 更新活动时间
|
||||||
# 出错时返回原句子和中等评分
|
logger.error(f"处理第 {index} 个句子时出错 (尝试 {attempt + 1}/{self.max_retries}): {e}")
|
||||||
|
if attempt == self.max_retries - 1:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
# 所有重试都失败,使用默认处理
|
||||||
|
if not success:
|
||||||
|
self.failed_requests += 1
|
||||||
|
logger.warning(f"第 {index} 个句子使用默认处理: {sentence[:50]}...")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"index": index,
|
"index": index,
|
||||||
"original_sentence": sentence,
|
"original_sentence": sentence,
|
||||||
@ -432,26 +566,82 @@ class EnhancedTRExProcessor:
|
|||||||
"importance_score": 5.0
|
"importance_score": 5.0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def heartbeat_monitor(self, total_sentences: int):
|
||||||
|
"""心跳监控,检测是否有长时间无响应"""
|
||||||
|
consecutive_warnings = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(self.heartbeat_interval)
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
time_since_last_success = current_time - self.last_successful_time
|
||||||
|
time_since_last_activity = current_time - self.last_activity_time
|
||||||
|
|
||||||
|
# 检查最后成功时间
|
||||||
|
if time_since_last_success > self.heartbeat_interval:
|
||||||
|
consecutive_warnings += 1
|
||||||
|
logger.warning(f"⚠️ 心跳检测 #{consecutive_warnings}:已有 {time_since_last_success:.1f} 秒没有成功的LLM响应")
|
||||||
|
print(f"⚠️ 心跳检测 #{consecutive_warnings}:已有 {time_since_last_success:.1f} 秒没有成功的LLM响应")
|
||||||
|
|
||||||
|
# 打印当前统计信息
|
||||||
|
if self.total_requests > 0:
|
||||||
|
success_rate = self.successful_requests / self.total_requests * 100
|
||||||
|
logger.warning(f" 当前统计:总请求 {self.total_requests},成功 {self.successful_requests} ({success_rate:.1f}%),超时 {self.timeout_requests}")
|
||||||
|
print(f" 当前统计:总请求 {self.total_requests},成功 {self.successful_requests} ({success_rate:.1f}%),超时 {self.timeout_requests}")
|
||||||
|
|
||||||
|
if time_since_last_success > self.heartbeat_interval * 3:
|
||||||
|
logger.error(f"❌ 严重警告:LLM可能已卡死,超过 {time_since_last_success:.1f} 秒无成功响应!")
|
||||||
|
print(f"❌ 严重警告:LLM可能已卡死,超过 {time_since_last_success:.1f} 秒无成功响应!")
|
||||||
|
print(f" 建议:检查ollama服务状态,或考虑重启程序")
|
||||||
|
|
||||||
|
# 检查ollama服务状态
|
||||||
|
if not self.check_ollama_status():
|
||||||
|
logger.critical("💀 Ollama服务异常,这可能是卡死的原因!")
|
||||||
|
print("💀 Ollama服务异常,这可能是卡死的原因!")
|
||||||
|
|
||||||
|
if consecutive_warnings >= 5:
|
||||||
|
logger.critical(f"💀 致命错误:连续 {consecutive_warnings} 次心跳警告,可能需要人工干预")
|
||||||
|
print(f"💀 致命错误:连续 {consecutive_warnings} 次心跳警告,可能需要人工干预")
|
||||||
|
else:
|
||||||
|
if consecutive_warnings > 0:
|
||||||
|
logger.info(f"✅ 心跳恢复正常:最后成功时间 {time_since_last_success:.1f} 秒前")
|
||||||
|
print(f"✅ 心跳恢复正常:最后成功时间 {time_since_last_success:.1f} 秒前")
|
||||||
|
consecutive_warnings = 0
|
||||||
|
logger.debug(f"💓 心跳正常:最后成功时间 {time_since_last_success:.1f} 秒前")
|
||||||
|
|
||||||
async def process_sentences_with_llm(self, sentences: List[str]) -> List[Dict[str, Any]]:
|
async def process_sentences_with_llm(self, sentences: List[str]) -> List[Dict[str, Any]]:
|
||||||
"""批量并发处理句子,每2000条保存一次检查点"""
|
"""批量并发处理句子,每2000条保存一次检查点"""
|
||||||
print(f"开始使用LLM并发处理 {len(sentences)} 个句子(最大并发数:54)...")
|
logger.info(f"开始使用LLM并发处理 {len(sentences)} 个句子(最大并发数:{self.max_concurrent})...")
|
||||||
|
print(f"开始使用LLM并发处理 {len(sentences)} 个句子(最大并发数:{self.max_concurrent})...")
|
||||||
|
|
||||||
# 记录开始时间
|
# 记录开始时间
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
total_sentences = len(sentences)
|
total_sentences = len(sentences)
|
||||||
|
|
||||||
# 分批处理,每批2000个句子
|
# 分批处理,每批1000个句子(减少批次大小)
|
||||||
batch_size = 2000
|
batch_size = 1000
|
||||||
all_processed_sentences = []
|
all_processed_sentences = []
|
||||||
|
|
||||||
|
# 启动心跳监控
|
||||||
|
heartbeat_task = asyncio.create_task(self.heartbeat_monitor(total_sentences))
|
||||||
|
|
||||||
|
try:
|
||||||
for batch_start in range(0, total_sentences, batch_size):
|
for batch_start in range(0, total_sentences, batch_size):
|
||||||
batch_end = min(batch_start + batch_size, total_sentences)
|
batch_end = min(batch_start + batch_size, total_sentences)
|
||||||
batch_sentences = sentences[batch_start:batch_end]
|
batch_sentences = sentences[batch_start:batch_end]
|
||||||
|
|
||||||
|
logger.info(f"=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===")
|
||||||
print(f"\n=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===")
|
print(f"\n=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===")
|
||||||
|
|
||||||
# 创建信号量限制并发数
|
# 创建信号量限制并发数(降低到8)
|
||||||
semaphore = asyncio.Semaphore(54)
|
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||||
|
|
||||||
|
# 重置批次统计
|
||||||
|
batch_start_time = time.time()
|
||||||
|
self.total_requests = 0
|
||||||
|
self.successful_requests = 0
|
||||||
|
self.failed_requests = 0
|
||||||
|
self.timeout_requests = 0
|
||||||
|
|
||||||
# 创建当前批次的任务
|
# 创建当前批次的任务
|
||||||
tasks = []
|
tasks = []
|
||||||
@ -461,7 +651,9 @@ class EnhancedTRExProcessor:
|
|||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
|
|
||||||
# 并发执行当前批次的任务
|
# 并发执行当前批次的任务
|
||||||
|
logger.info(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...")
|
||||||
print(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...")
|
print(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...")
|
||||||
|
|
||||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
# 处理当前批次的结果,过滤异常
|
# 处理当前批次的结果,过滤异常
|
||||||
@ -470,6 +662,7 @@ class EnhancedTRExProcessor:
|
|||||||
|
|
||||||
for result in batch_results:
|
for result in batch_results:
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
|
logger.error(f"任务执行异常: {result}")
|
||||||
print(f"任务执行异常: {result}")
|
print(f"任务执行异常: {result}")
|
||||||
batch_error_count += 1
|
batch_error_count += 1
|
||||||
elif isinstance(result, dict):
|
elif isinstance(result, dict):
|
||||||
@ -492,10 +685,22 @@ class EnhancedTRExProcessor:
|
|||||||
|
|
||||||
# 打印当前批次统计信息
|
# 打印当前批次统计信息
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
batch_time = time.time() - batch_start_time
|
||||||
completed_sentences = len(all_processed_sentences)
|
completed_sentences = len(all_processed_sentences)
|
||||||
|
|
||||||
|
logger.info(f"第 {batch_start//batch_size + 1} 批处理完成!")
|
||||||
|
logger.info(f" - 当前批次:成功 {len(batch_processed_sentences)},失败 {batch_error_count}")
|
||||||
|
logger.info(f" - 批次用时:{batch_time/60:.1f}分钟")
|
||||||
|
logger.info(f" - LLM统计:成功 {self.successful_requests},失败 {self.failed_requests},超时 {self.timeout_requests}")
|
||||||
|
logger.info(f" - 总体进度:{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
|
||||||
|
logger.info(f" - 已用时间:{elapsed_time/60:.1f}分钟")
|
||||||
|
logger.info(f" - 平均速度:{completed_sentences/elapsed_time:.2f}句/秒")
|
||||||
|
logger.info(f" - 检查点已保存:{checkpoint_filename}")
|
||||||
|
|
||||||
print(f"第 {batch_start//batch_size + 1} 批处理完成!")
|
print(f"第 {batch_start//batch_size + 1} 批处理完成!")
|
||||||
print(f" - 当前批次:成功 {len(batch_processed_sentences)},失败 {batch_error_count}")
|
print(f" - 当前批次:成功 {len(batch_processed_sentences)},失败 {batch_error_count}")
|
||||||
|
print(f" - 批次用时:{batch_time/60:.1f}分钟")
|
||||||
|
print(f" - LLM统计:成功 {self.successful_requests},失败 {self.failed_requests},超时 {self.timeout_requests}")
|
||||||
print(f" - 总体进度:{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
|
print(f" - 总体进度:{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
|
||||||
print(f" - 已用时间:{elapsed_time/60:.1f}分钟")
|
print(f" - 已用时间:{elapsed_time/60:.1f}分钟")
|
||||||
print(f" - 平均速度:{completed_sentences/elapsed_time:.2f}句/秒")
|
print(f" - 平均速度:{completed_sentences/elapsed_time:.2f}句/秒")
|
||||||
@ -505,10 +710,29 @@ class EnhancedTRExProcessor:
|
|||||||
remaining_sentences = total_sentences - completed_sentences
|
remaining_sentences = total_sentences - completed_sentences
|
||||||
avg_time_per_sentence = elapsed_time / completed_sentences
|
avg_time_per_sentence = elapsed_time / completed_sentences
|
||||||
estimated_remaining_time = avg_time_per_sentence * remaining_sentences
|
estimated_remaining_time = avg_time_per_sentence * remaining_sentences
|
||||||
|
logger.info(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟")
|
||||||
print(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟")
|
print(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟")
|
||||||
|
|
||||||
|
# 在批次之间稍作休息,避免过度压力
|
||||||
|
if batch_end < total_sentences:
|
||||||
|
logger.info("批次间休息5秒...")
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 取消心跳监控
|
||||||
|
heartbeat_task.cancel()
|
||||||
|
try:
|
||||||
|
await heartbeat_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
# 打印最终统计信息
|
# 打印最终统计信息
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
|
logger.info(f"=== 全部处理完成!===")
|
||||||
|
logger.info(f" - 总成功:{len(all_processed_sentences)}")
|
||||||
|
logger.info(f" - 总用时:{total_time/60:.1f}分钟")
|
||||||
|
logger.info(f" - 平均处理速度:{len(all_processed_sentences)/total_time:.2f}句/秒")
|
||||||
|
|
||||||
print(f"\n=== 全部处理完成!===")
|
print(f"\n=== 全部处理完成!===")
|
||||||
print(f" - 总成功:{len(all_processed_sentences)}")
|
print(f" - 总成功:{len(all_processed_sentences)}")
|
||||||
print(f" - 总用时:{total_time/60:.1f}分钟")
|
print(f" - 总用时:{total_time/60:.1f}分钟")
|
||||||
@@ -518,9 +742,9 @@
 
     def save_checkpoint(self, processed_sentences: List[Dict[str, Any]], current_count: int) -> str:
         """保存检查点文件"""
-        # 生成检查点文件名
-        base_name = os.path.splitext(self.output_file)[0]
-        checkpoint_filename = f"{base_name}_checkpoint_{current_count}.json"
+        # 生成检查点文件名,确保在output目录中
+        base_name = os.path.splitext(os.path.basename(self.output_file))[0]
+        checkpoint_filename = os.path.join('output', f"{base_name}_checkpoint_{current_count}.json")
 
         # 保存检查点
         with open(checkpoint_filename, 'w', encoding='utf-8') as f:
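These checkpoint files are written once per batch by `process_sentences_with_llm`: each batch is fanned out through the semaphore, gathered with `return_exceptions=True`, appended to the running results, and then checkpointed. A condensed, self-contained sketch of that loop (the `worker` and `save_checkpoint` callables stand in for `process_sentence_with_llm_concurrent` and `save_checkpoint`; the batch size and concurrency default to the values set in this commit):

```python
import asyncio
import time
from typing import Awaitable, Callable

async def process_in_batches(
    sentences: list,
    worker: Callable[[asyncio.Semaphore, str, int, int, float], Awaitable[dict]],
    save_checkpoint: Callable[[list, int], str],
    max_concurrent: int = 8,
    batch_size: int = 1000,
) -> list:
    """Fan each batch out through a semaphore, gather results, checkpoint per batch."""
    all_processed = []
    start_time = time.time()
    total = len(sentences)

    for batch_start in range(0, total, batch_size):
        batch = sentences[batch_start:batch_start + batch_size]
        semaphore = asyncio.Semaphore(max_concurrent)

        tasks = [
            worker(semaphore, sentence, batch_start + i, total, start_time)
            for i, sentence in enumerate(batch)
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Keep only successful dict results; exceptions count as failures
        all_processed.extend(r for r in results if isinstance(r, dict))

        # One checkpoint per batch, named after the count processed so far
        save_checkpoint(all_processed, len(all_processed))

    return all_processed
```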
@ -578,26 +802,29 @@ class EnhancedTRExProcessor:
|
|||||||
|
|
||||||
print(f"去重后剩余 {len(unique_sentences)} 个句子")
|
print(f"去重后剩余 {len(unique_sentences)} 个句子")
|
||||||
|
|
||||||
# 使用LLM处理句子
|
# 保存原始句子到JSON文件
|
||||||
if self.enable_llm_processing:
|
sentences_data = {
|
||||||
processed_sentences = await self.process_sentences_with_llm(unique_sentences)
|
"metadata": {
|
||||||
else:
|
"total_sentences": len(unique_sentences),
|
||||||
# 基础模式:不使用LLM
|
"extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
processed_sentences = [
|
"source_files": len(json_files),
|
||||||
{
|
"max_files_limit": self.max_files
|
||||||
"original_sentence": sentence,
|
},
|
||||||
"corrected_sentence": sentence,
|
"sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
|
||||||
"importance_score": 5.0
|
|
||||||
}
|
}
|
||||||
for sentence in unique_sentences
|
|
||||||
]
|
|
||||||
|
|
||||||
return processed_sentences
|
with open(self.sentences_json, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(sentences_data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
print(f"句子提取完成!已保存到: {self.sentences_json}")
|
||||||
|
print(f"总计句子数: {len(unique_sentences)}")
|
||||||
|
|
||||||
|
return unique_sentences
|
||||||
|
|
||||||
def save_sentences(self, processed_sentences: List[Dict[str, Any]]):
|
def save_sentences(self, processed_sentences: List[Dict[str, Any]]):
|
||||||
"""保存处理后的句子到文件"""
|
"""保存处理后的句子到文件"""
|
||||||
# 确保输出目录存在
|
# 确保输出目录存在
|
||||||
os.makedirs(os.path.dirname(self.output_file) if os.path.dirname(self.output_file) else '.', exist_ok=True)
|
os.makedirs('output', exist_ok=True)
|
||||||
|
|
||||||
# 保存为JSON格式,包含完整信息
|
# 保存为JSON格式,包含完整信息
|
||||||
json_output_file = self.output_file.replace('.txt', '.json')
|
json_output_file = self.output_file.replace('.txt', '.json')
|
||||||
@@ -637,8 +864,8 @@
 
     def find_latest_checkpoint(self) -> Union[tuple, None]:
         """查找最新的检查点文件"""
-        base_name = os.path.splitext(self.output_file)[0]
-        pattern = f"./output/{base_name}_checkpoint_*.json"
+        base_name = os.path.splitext(os.path.basename(self.output_file))[0]
+        pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
         checkpoint_files = glob.glob(pattern)
 
         if not checkpoint_files:
@ -680,54 +907,311 @@ class EnhancedTRExProcessor:
|
|||||||
print(f"加载检查点文件失败: {e}")
|
print(f"加载检查点文件失败: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def get_processed_sentences_from_checkpoints(self) -> Set[str]:
|
||||||
|
"""从检查点文件中获取已处理过的句子集合"""
|
||||||
|
if not self.output_file:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
processed_sentences = set()
|
||||||
|
|
||||||
|
# 查找所有检查点文件
|
||||||
|
base_name = os.path.splitext(os.path.basename(self.output_file))[0]
|
||||||
|
pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
|
||||||
|
checkpoint_files = glob.glob(pattern)
|
||||||
|
|
||||||
|
if not checkpoint_files:
|
||||||
|
print("未找到检查点文件,将从头开始处理")
|
||||||
|
return set()
|
||||||
|
|
||||||
|
# 找到最新的检查点文件
|
||||||
|
latest_file = None
|
||||||
|
latest_count = 0
|
||||||
|
|
||||||
|
for file in checkpoint_files:
|
||||||
|
try:
|
||||||
|
match = re.search(r'checkpoint_(\d+)\.json$', file)
|
||||||
|
if match:
|
||||||
|
count = int(match.group(1))
|
||||||
|
if count > latest_count:
|
||||||
|
latest_count = count
|
||||||
|
latest_file = file
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if latest_file:
|
||||||
|
print(f"找到最新检查点: {latest_file} (包含 {latest_count} 条记录)")
|
||||||
|
logger.info(f"找到最新检查点: {latest_file} (包含 {latest_count} 条记录)")
|
||||||
|
try:
|
||||||
|
with open(latest_file, 'r', encoding='utf-8') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
sentences_data = data.get('sentences', [])
|
||||||
|
for item in sentences_data:
|
||||||
|
original_sentence = item.get('original_sentence', '')
|
||||||
|
if original_sentence:
|
||||||
|
processed_sentences.add(original_sentence)
|
||||||
|
|
||||||
|
print(f"从检查点加载了 {len(processed_sentences)} 个已处理的句子")
|
||||||
|
logger.info(f"从检查点加载了 {len(processed_sentences)} 个已处理的句子")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"读取检查点文件失败: {e}")
|
||||||
|
return set()
|
||||||
|
|
||||||
|
return processed_sentences
|
||||||
|
|
||||||
|
async def process_with_llm(self):
|
||||||
|
"""步骤2:从JSON文件读取句子并进行LLM处理"""
|
||||||
|
if not self.enable_llm_processing:
|
||||||
|
print("Error: LLM processing is disabled!")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.output_file:
|
||||||
|
print("Error: output_file is required for LLM processing!")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("=== 步骤2:LLM处理 ===")
|
||||||
|
|
||||||
|
# 读取句子JSON文件
|
||||||
|
if not os.path.exists(self.sentences_json):
|
||||||
|
print(f"Error: Sentences file {self.sentences_json} not found!")
|
||||||
|
print("请先运行步骤1进行句子提取")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"正在读取句子文件: {self.sentences_json}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(self.sentences_json, 'r', encoding='utf-8') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
all_sentences = [item["sentence"] for item in data.get("sentences", [])]
|
||||||
|
print(f"从文件中读取了 {len(all_sentences)} 个句子")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"读取句子文件失败: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取已处理的句子
|
||||||
|
processed_sentences_set = self.get_processed_sentences_from_checkpoints()
|
||||||
|
|
||||||
|
# 过滤出未处理的句子
|
||||||
|
unprocessed_sentences = []
|
||||||
|
for sentence in all_sentences:
|
||||||
|
if sentence not in processed_sentences_set:
|
||||||
|
unprocessed_sentences.append(sentence)
|
||||||
|
|
||||||
|
print(f"需要处理的句子数: {len(unprocessed_sentences)} (跳过已处理: {len(processed_sentences_set)})")
|
||||||
|
logger.info(f"需要处理的句子数: {len(unprocessed_sentences)} (跳过已处理: {len(processed_sentences_set)})")
|
||||||
|
|
||||||
|
if not unprocessed_sentences:
|
||||||
|
print("所有句子都已处理完成!")
|
||||||
|
|
||||||
|
# 如果有检查点,直接从最新检查点生成最终文件
|
||||||
|
if processed_sentences_set:
|
||||||
|
latest_checkpoint = self.find_latest_checkpoint()
|
||||||
|
if latest_checkpoint:
|
||||||
|
checkpoint_file, _ = latest_checkpoint
|
||||||
|
processed_data = self.load_checkpoint(checkpoint_file)
|
||||||
|
self.save_sentences(processed_data)
|
||||||
|
print("已从检查点生成最终输出文件")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 处理未处理的句子
|
||||||
|
print("开始LLM处理...")
|
||||||
|
|
||||||
|
# 检查ollama服务状态
|
||||||
|
logger.info("检查Ollama服务状态...")
|
||||||
|
if not self.check_ollama_status():
|
||||||
|
logger.error("Ollama服务状态异常,无法继续处理")
|
||||||
|
print("错误:Ollama服务状态异常,请检查服务是否正常运行")
|
||||||
|
return
|
||||||
|
|
||||||
|
new_processed_sentences = await self.process_sentences_with_llm(unprocessed_sentences)
|
||||||
|
|
||||||
|
# 如果有之前的处理结果,合并它们
|
||||||
|
if processed_sentences_set:
|
||||||
|
latest_checkpoint = self.find_latest_checkpoint()
|
||||||
|
if latest_checkpoint:
|
||||||
|
checkpoint_file, _ = latest_checkpoint
|
||||||
|
previous_processed = self.load_checkpoint(checkpoint_file)
|
||||||
|
|
||||||
|
# 合并结果
|
||||||
|
all_processed_sentences = previous_processed + new_processed_sentences
|
||||||
|
print(f"合并了之前的 {len(previous_processed)} 条和新处理的 {len(new_processed_sentences)} 条记录")
|
||||||
|
else:
|
||||||
|
all_processed_sentences = new_processed_sentences
|
||||||
|
else:
|
||||||
|
all_processed_sentences = new_processed_sentences
|
||||||
|
|
||||||
|
# 保存最终结果
|
||||||
|
self.save_sentences(all_processed_sentences)
|
||||||
|
print("LLM处理完成!")
|
||||||
|
|
||||||
|
# ==================== 新增:句子提取功能 ====================
|
||||||
|
|
||||||
|
def extract_sentences(self):
|
||||||
|
"""步骤1:从TREx数据集提取句子并保存为JSON"""
|
||||||
|
if not self.input_dir:
|
||||||
|
print("Error: input_dir is required for sentence extraction!")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("=== 步骤1:句子提取 ===")
|
||||||
|
print("开始从TREx数据集提取句子...")
|
||||||
|
|
||||||
|
json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))
|
||||||
|
|
||||||
|
if not json_files:
|
||||||
|
print(f"No JSON files found in {self.input_dir}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 排序文件以确保一致的处理顺序
|
||||||
|
json_files.sort()
|
||||||
|
|
||||||
|
if self.max_files:
|
||||||
|
json_files = json_files[:self.max_files]
|
||||||
|
|
||||||
|
print(f"Found {len(json_files)} JSON files to process")
|
||||||
|
|
||||||
|
all_sentences = []
|
||||||
|
|
||||||
|
for i, file_path in enumerate(json_files):
|
||||||
|
print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")
|
||||||
|
|
||||||
|
documents = self.parse_large_json_file(file_path)
|
||||||
|
print(f" Parsed {len(documents)} documents")
|
||||||
|
|
||||||
|
for doc in documents:
|
||||||
|
sentences = self.extract_sentences_from_document(doc)
|
||||||
|
all_sentences.extend(sentences)
|
||||||
|
|
||||||
|
print(f" Generated {len(all_sentences)} total raw sentences so far")
|
||||||
|
|
||||||
|
print(f"总共提取了 {len(all_sentences)} 个原始句子")
|
||||||
|
|
||||||
|
# 去重
|
||||||
|
unique_sentences = []
|
||||||
|
seen = set()
|
||||||
|
for sentence in all_sentences:
|
||||||
|
sentence = sentence.strip()
|
||||||
|
if sentence and sentence not in seen and len(sentence) > 10:
|
||||||
|
unique_sentences.append(sentence)
|
||||||
|
seen.add(sentence)
|
||||||
|
|
||||||
|
print(f"去重后剩余 {len(unique_sentences)} 个句子")
|
||||||
|
|
||||||
|
# 保存原始句子到JSON文件
|
||||||
|
sentences_data = {
|
||||||
|
"metadata": {
|
||||||
|
"total_sentences": len(unique_sentences),
|
||||||
|
"extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
|
"source_files": len(json_files),
|
||||||
|
"max_files_limit": self.max_files
|
||||||
|
},
|
||||||
|
"sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(self.sentences_json, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(sentences_data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
print(f"句子提取完成!已保存到: {self.sentences_json}")
|
||||||
|
print(f"总计句子数: {len(unique_sentences)}")
|
||||||
|
|
||||||
|
return unique_sentences
|
||||||
|
|
||||||
|
def check_ollama_status(self) -> bool:
|
||||||
|
"""检查ollama服务是否正常运行"""
|
||||||
|
try:
|
||||||
|
# 检查ollama进程是否运行
|
||||||
|
result = subprocess.run(['pgrep', 'ollama'], capture_output=True, text=True)
|
||||||
|
if result.returncode != 0:
|
||||||
|
logger.error("Ollama进程未运行")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查ollama API是否响应
|
||||||
|
response = requests.get('http://localhost:11434/api/tags', timeout=5)
|
||||||
|
if response.status_code == 200:
|
||||||
|
logger.info("Ollama服务状态正常")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.error(f"Ollama API响应异常,状态码: {response.status_code}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"无法连接到Ollama API: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检查Ollama状态时出错: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""主函数"""
|
"""主函数"""
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with LLM processing')
|
parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with LLM processing')
|
||||||
|
|
||||||
|
# 选择运行模式
|
||||||
|
parser.add_argument('--step', choices=['extract', 'llm', 'all'], default='llm',
|
||||||
|
help='运行步骤: extract=仅提取句子, llm=仅LLM处理, all=完整流程')
|
||||||
|
|
||||||
|
# 文件路径参数
|
||||||
parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files')
|
parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files')
|
||||||
parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path')
|
parser.add_argument('--sentences_json', default='extracted_sentences.json', help='JSON file for extracted sentences (will be saved in output/)')
|
||||||
|
parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path (will be saved in output/)')
|
||||||
|
|
||||||
|
# 处理参数
|
||||||
parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)')
|
parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)')
|
||||||
parser.add_argument('--no_llm', action='store_true', help='Disable LLM processing (basic mode)')
|
parser.add_argument('--no_llm', action='store_true', help='Disable LLM processing (basic mode)')
|
||||||
parser.add_argument('--resume', action='store_true', help='Resume from latest checkpoint if available')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 根据步骤验证参数
|
||||||
|
if args.step in ['extract', 'all']:
|
||||||
if not os.path.exists(args.input_dir):
|
if not os.path.exists(args.input_dir):
|
||||||
print(f"Error: Input directory {args.input_dir} does not exist!")
|
print(f"Error: Input directory {args.input_dir} does not exist!")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if args.step in ['llm', 'all']:
|
||||||
|
if args.no_llm:
|
||||||
|
print("Error: Cannot run LLM step with --no_llm flag!")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 创建处理器
|
||||||
processor = EnhancedTRExProcessor(
|
processor = EnhancedTRExProcessor(
|
||||||
args.input_dir,
|
input_dir=args.input_dir,
|
||||||
args.output_file,
|
sentences_json=args.sentences_json,
|
||||||
args.max_files,
|
output_file=args.output_file,
|
||||||
|
max_files=args.max_files,
|
||||||
enable_llm_processing=not args.no_llm
|
enable_llm_processing=not args.no_llm
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否要从检查点恢复
|
# 根据选择的步骤运行
|
||||||
if args.resume:
|
if args.step == 'extract':
|
||||||
checkpoint_result = processor.find_latest_checkpoint()
|
print("=== 运行模式:仅句子提取 ===")
|
||||||
if checkpoint_result:
|
processor.extract_sentences()
|
||||||
latest_checkpoint, latest_count = checkpoint_result
|
|
||||||
print(f"发现检查点文件: {latest_checkpoint} (包含 {latest_count} 条记录)")
|
|
||||||
confirm = input("是否从检查点恢复?(y/n): ").lower().strip()
|
|
||||||
if confirm == 'y':
|
|
||||||
processed_sentences = processor.load_checkpoint(latest_checkpoint)
|
|
||||||
if processed_sentences:
|
|
||||||
print(f"成功加载 {len(processed_sentences)} 条已处理的句子")
|
|
||||||
processor.save_sentences(processed_sentences)
|
|
||||||
print("从检查点恢复完成!")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
print("检查点文件加载失败,将重新开始处理")
|
|
||||||
else:
|
|
||||||
print("不从检查点恢复,将重新开始处理")
|
|
||||||
else:
|
|
||||||
print("未找到检查点文件,将重新开始处理")
|
|
||||||
|
|
||||||
# 运行异步处理
|
elif args.step == 'llm':
|
||||||
asyncio.run(processor.run())
|
print("=== 运行模式:仅LLM处理 ===")
|
||||||
|
asyncio.run(processor.process_with_llm())
|
||||||
|
|
||||||
|
elif args.step == 'all':
|
||||||
|
print("=== 运行模式:完整流程 ===")
|
||||||
|
|
||||||
|
# 步骤1:提取句子
|
||||||
|
print("\n--- 开始步骤1:句子提取 ---")
|
||||||
|
sentences = processor.extract_sentences()
|
||||||
|
|
||||||
|
if not sentences:
|
||||||
|
print("句子提取失败,退出")
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.no_llm:
|
||||||
|
print("LLM处理已禁用,流程结束")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 步骤2:LLM处理
|
||||||
|
print("\n--- 开始步骤2:LLM处理 ---")
|
||||||
|
asyncio.run(processor.process_with_llm())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@@ -3,6 +3,7 @@ import os
 os.environ["WANDB_MODE"] = "offline"  # 或者使用 "dryrun"
 import platform
 import argparse
+from tqdm import tqdm
 import time
 import math
 import warnings
@@ -18,8 +19,10 @@ from accelerate.utils import set_seed
 from accelerate.utils import DeepSpeedPlugin
 from accelerate.utils import DistributedDataParallelKwargs
 from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
 
-from model.model import MiniMindLM
+from model.model import MiniMindLM, RMSNorm
 from model.LMConfig import LMConfig
 from model.dataset import PretrainDataset
@@ -41,10 +44,41 @@ def get_lr(it, num_iters, learning_rate):
     return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters))
 
 # 初始化模型函数
-def init_model(lm_config, pretrained_embedding_path=None):
+def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
     tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
     model = MiniMindLM(lm_config)
+
+    # 默认模型初始化
+    Logger("Performing default model initialization...")
+
+    # 初始化嵌入层权重
+    nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)
+
+    # 初始化输出层权重(如果不共享权重的话)
+    if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
+        nn.init.normal_(model.output.weight, mean=0.0, std=0.02)
+
+    # 初始化所有线性层
+    for name, module in model.named_modules():
+        if isinstance(module, nn.Linear):
+            # 使用Xavier/Glorot初始化
+            nn.init.xavier_uniform_(module.weight)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            # 嵌入层使用正态分布初始化
+            nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        elif isinstance(module, RMSNorm):
+            # RMSNorm的权重初始化为1
+            if hasattr(module, 'weight'):
+                nn.init.ones_(module.weight)
+
+    # 初始化extract_db的keys参数
+    if hasattr(model.extract_db, 'keys'):
+        nn.init.normal_(model.extract_db.keys, mean=0.0, std=0.02)
+
+    Logger("Default model initialization completed")
+
     # 如果提供了预训练的嵌入权重,加载它们
     if pretrained_embedding_path:
         Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
@ -52,6 +86,334 @@ def init_model(lm_config, pretrained_embedding_path=None):
|
|||||||
model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
|
model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
|
||||||
model.output.weight.data.copy_(pretrained_embeddings) # 共享权重
|
model.output.weight.data.copy_(pretrained_embeddings) # 共享权重
|
||||||
|
|
||||||
|
if database_init_path:
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
import os
|
||||||
|
|
||||||
|
Logger(f"Loading database initialization data from {database_init_path}")
|
||||||
|
|
||||||
|
# 1. 加载JSON文件并转换为字典
|
||||||
|
with open(database_init_path, 'r', encoding='utf-8') as f:
|
||||||
|
database_data = json.load(f)
|
||||||
|
|
||||||
|
# 提取sentences列表
|
||||||
|
sentences_data = database_data.get('sentences', [])
|
||||||
|
Logger(f"Loaded {len(sentences_data)} sentences from database")
|
||||||
|
|
||||||
|
# 2. 按照importance_score进行排序(从高到低)
|
||||||
|
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
|
||||||
|
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
|
||||||
        # 3. Download and initialize the local embedding model
        embedding_model_name = "sentence-transformers/all-mpnet-base-v2"  # lightweight model with good quality
        embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2"
        embedding_cache_dir = "./models/sentence_transformers/cache"
        os.makedirs(embedding_cache_dir, exist_ok=True)

        Logger(f"Loading embedding model: {embedding_model_name}")
        try:
            embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir)
            Logger("Embedding model loaded successfully")
        except Exception as e:
            Logger(f"Failed to load embedding model: {e}")
            Logger("Falling back to random embeddings")
            embedding_model = None

        # 4. Compute an embedding and a token length for every corrected_sentence
        Logger("Processing sentences for embeddings and token lengths...")

        # Collect all sentences
        sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences]

        # Compute token lengths
        Logger("Computing token lengths...")
        token_lengths = []
        for sentence in sentences:
            tokens = tokenizer.encode(sentence, add_special_tokens=False)
            token_lengths.append(len(tokens))

        # Compute embeddings in batches -- much faster than encoding sentence by sentence
        Logger("Computing embeddings in batches...")
        embeddings_list = []
        batch_size = 256  # can be tuned to the available GPU memory

        if embedding_model is not None:
            try:
                for i in range(0, len(sentences), batch_size):
                    batch_sentences = sentences[i:i+batch_size]
                    batch_embeddings = embedding_model.encode(
                        batch_sentences,
                        convert_to_tensor=False,
                        show_progress_bar=True if i == 0 else False,
                        batch_size=batch_size
                    )
                    embeddings_list.extend(batch_embeddings)

                    if (i + batch_size) % (batch_size * 10) == 0:
                        Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences")

                Logger("Batch embedding computation completed")
            except Exception as e:
                Logger(f"Error in batch encoding: {e}")
                Logger("Falling back to random embeddings")
                # Fallback: 384-dim random vectors (note: all-mpnet-base-v2 itself outputs 768-dim embeddings)
                embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
        else:
            # Use random embeddings as a fallback
            embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]

        # Build the list of processed sentences
        processed_sentences = []
        for i, (sentence_data, embedding, token_length) in enumerate(zip(sorted_sentences, embeddings_list, token_lengths)):
            processed_sentences.append({
                'sentence': sentence_data.get('corrected_sentence', ''),
                'importance_score': sentence_data.get('importance_score', 0.0),
                'token_length': token_length,
                'embedding': embedding,  # kept as a numpy array; only converted to a list when JSON-serializing
                'original_index': i
            })

        # # Create a JSON-serializable version for saving
        # json_serializable_sentences = []
        # for sentence in processed_sentences:
        #     json_sentence = sentence.copy()
        #     # Convert embedding to list if it's a numpy array
        #     if hasattr(json_sentence['embedding'], 'tolist'):
        #         json_sentence['embedding'] = json_sentence['embedding'].tolist()
        #     json_serializable_sentences.append(json_sentence)
        # json.dump(json_serializable_sentences, open('processed_sentences.json', 'w', encoding='utf-8'))
        # processed_sentences = json.load(open('processed_sentences.json', 'r', encoding='utf-8'))

        # Convert to numpy arrays for the processing below
        embeddings_array = np.array(embeddings_list)
        token_lengths_array = np.array(token_lengths)

        Logger(f"Embedding processing completed:")
        Logger(f"  - Total sentences: {len(processed_sentences)}")
        Logger(f"  - Embedding shape: {embeddings_array.shape}")
        Logger(f"  - Average token length: {np.mean(token_lengths_array):.2f}")
        Logger(f"  - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}")
        # 2. Clustering -- optimized version
        Logger("Starting optimized clustering process...")

        # Clustering parameters
        knowledge_num = args.knowledge_num
        knowledge_length = args.knowledge_length
        min_tokens = int(0.9 * knowledge_length)
        max_tokens = knowledge_length

        # Optimization 1: pre-compute the full similarity matrix (only when the dataset is small enough)
        if len(processed_sentences) <= 10000:
            Logger("Pre-computing similarity matrix for faster clustering...")
            embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
            similarity_matrix = cosine_similarity(embeddings_matrix)
            Logger(f"Similarity matrix computed: {similarity_matrix.shape}")
        else:
            similarity_matrix = None
            embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])

        clustered_rows = []
        remaining_indices = list(range(len(processed_sentences)))  # track indices instead of whole objects

        Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens")

        # Choose the clustering algorithm
        if args.fast_clustering and len(processed_sentences) > 5000:
            Logger("Using ultra-fast approximate clustering algorithm...")

            # Ultra-fast clustering: random sampling + batch processing
            import random
            random.seed(42)  # for reproducibility

            # Stratified sampling by importance
            high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7]
            medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7]
            low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3]

            Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}")

            for cluster_idx in tqdm(range(knowledge_num)):
                # Stratified seed selection: prefer high-importance sentences
                if high_importance:
                    seed_pool = high_importance
                elif medium_importance:
                    seed_pool = medium_importance
                else:
                    seed_pool = low_importance if low_importance else list(range(len(processed_sentences)))

                if not seed_pool:
                    break

                # Pick a random seed (within the same importance tier)
                seed_global_idx = random.choice(seed_pool)
                seed_sentence = processed_sentences[seed_global_idx]

                # Remove the seed from every pool
                for pool in [high_importance, medium_importance, low_importance]:
                    if seed_global_idx in pool:
                        pool.remove(seed_global_idx)

                current_cluster_indices = [seed_global_idx]
                current_tokens = seed_sentence['token_length']

                if current_tokens < max_tokens:
                    # Fast selection: sample candidates at random instead of ranking everything
                    all_remaining = high_importance + medium_importance + low_importance
                    if all_remaining:
                        # Randomly sample candidate sentences (no similarity computation)
                        sample_size = min(100, len(all_remaining))
                        candidates = random.sample(all_remaining, sample_size)

                        # Select simply by token length and importance
                        for candidate_idx in candidates:
                            candidate = processed_sentences[candidate_idx]
                            candidate_tokens = candidate['token_length']

                            if current_tokens + candidate_tokens + 1 <= max_tokens:
                                current_cluster_indices.append(candidate_idx)
                                current_tokens += candidate_tokens + 1

                                # Remove the candidate from its pool
                                for pool in [high_importance, medium_importance, low_importance]:
                                    if candidate_idx in pool:
                                        pool.remove(candidate_idx)
                                        break

                            if current_tokens >= min_tokens:
                                break

                # Build the cluster text
                cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
                cluster_text = '\n'.join(cluster_sentences)

                # Convert to tokens
                cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
                if len(cluster_tokens) > knowledge_length:
                    cluster_tokens = cluster_tokens[:knowledge_length]
                else:
                    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
                    cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))

                clustered_rows.append(cluster_tokens)

                if (cluster_idx + 1) % 1000 == 0:
                    total_remaining = len(high_importance) + len(medium_importance) + len(low_importance)
                    Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining")
        else:
            # Original optimized algorithm (for medium-sized datasets)
            # Optimization 2: batch processing and more efficient data structures
            for cluster_idx in tqdm(range(knowledge_num)):
                if not remaining_indices:
                    Logger(f"No more sentences available. Created {cluster_idx} clusters.")
                    break

                # 2.1 Use the sentence with the highest importance_score as the seed
                remaining_sentences_subset = [processed_sentences[i] for i in remaining_indices]
                seed_idx_in_subset = max(range(len(remaining_sentences_subset)),
                                         key=lambda i: remaining_sentences_subset[i]['importance_score'])
                seed_global_idx = remaining_indices[seed_idx_in_subset]
                seed_sentence = processed_sentences[seed_global_idx]

                # Remove the seed from the remaining indices
                remaining_indices.remove(seed_global_idx)

                # Current cluster
                current_cluster_indices = [seed_global_idx]
                current_tokens = seed_sentence['token_length']

                if current_tokens >= max_tokens:
                    # If the seed sentence alone exceeds the token budget, it becomes a cluster on its own
                    cluster_text = seed_sentence['sentence']
                else:
                    # 2.2 Optimized similarity computation and selection
                    if remaining_indices:
                        if similarity_matrix is not None:
                            # Use the pre-computed similarity matrix
                            similarities = similarity_matrix[seed_global_idx][remaining_indices]
                        else:
                            # Compute similarities on the fly (in batch)
                            seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx+1]
                            remaining_embeddings = embeddings_matrix[remaining_indices]
                            similarities = cosine_similarity(seed_embedding, remaining_embeddings)[0]

                        # Build (similarity, global index, position in remaining_indices) tuples
                        similarity_tuples = [(similarities[i], remaining_indices[i], i)
                                             for i in range(len(remaining_indices))]

                        # Sort by similarity (descending)
                        similarity_tuples.sort(key=lambda x: x[0], reverse=True)

                        # Optimization 3: greedy selection, but cap the search range for speed
                        max_candidates = min(len(similarity_tuples), 500)  # only consider the 500 most similar sentences

                        selected_indices_in_remaining = []
                        for sim_score, global_idx, pos_in_remaining in similarity_tuples[:max_candidates]:
                            candidate = processed_sentences[global_idx]
                            candidate_tokens = candidate['token_length']

                            if current_tokens + candidate_tokens + 1 <= max_tokens:  # +1 for the newline separator
                                current_cluster_indices.append(global_idx)
                                selected_indices_in_remaining.append(pos_in_remaining)
                                current_tokens += candidate_tokens + 1

                                if current_tokens >= min_tokens:
                                    break

                        # Remove the selected sentences in batch (back to front to keep positions valid)
                        for pos in sorted(selected_indices_in_remaining, reverse=True):
                            remaining_indices.pop(pos)

                    # Join the sentences of the cluster
                    cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
                    cluster_text = '\n'.join(cluster_sentences)

                # Convert the cluster text to tokens
                cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)

                # Truncate or pad to knowledge_length
                if len(cluster_tokens) > knowledge_length:
                    cluster_tokens = cluster_tokens[:knowledge_length]
                else:
                    # Pad with pad_token_id
                    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
                    cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))

                clustered_rows.append(cluster_tokens)

                # Optimization 4: log less frequently
                if (cluster_idx + 1) % 500 == 0:
                    Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining")

        # If there are fewer clusters than requested, pad with rows of pad tokens
        while len(clustered_rows) < knowledge_num:
            pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
            random_tokens = [pad_token_id] * knowledge_length
            clustered_rows.append(random_tokens)

        # Convert to a tensor
        clustered_tensor = torch.tensor(clustered_rows, dtype=torch.long)

        Logger(f"Clustering completed:")
        Logger(f"  - Created {len(clustered_rows)} clusters")
        Logger(f"  - Cluster shape: {clustered_tensor.shape}")
        Logger(f"  - Expected shape: ({knowledge_num}, {knowledge_length})")
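Both clustering branches above share the same core idea: pick a seed sentence, then greedily pack further sentences into it (most-similar first in the standard path, randomly sampled in the fast path) until the cluster approaches the `knowledge_length` token budget. The snippet below is a minimal, self-contained toy illustration of that greedy packing; the four-sentence corpus, the scores, the token counts, and the random stand-in embeddings are all made up for the example and are not part of the training script.

```python
# Toy illustration of the greedy "seed + most similar until the budget is full" packing.
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

corpus = [
    {"sentence": "Paris is the capital of France.",  "importance_score": 0.9, "token_length": 8},
    {"sentence": "France is a country in Europe.",   "importance_score": 0.6, "token_length": 8},
    {"sentence": "The Louvre is a museum in Paris.", "importance_score": 0.5, "token_length": 9},
    {"sentence": "Basalt is a volcanic rock.",       "importance_score": 0.4, "token_length": 7},
]
embeddings = np.random.RandomState(0).randn(len(corpus), 8)  # stand-in for sentence embeddings
sim = cosine_similarity(embeddings)

max_tokens = 32
min_tokens = int(0.9 * max_tokens)
remaining = list(range(len(corpus)))

# Seed: the highest-importance sentence among the remaining ones
seed = max(remaining, key=lambda i: corpus[i]["importance_score"])
remaining.remove(seed)
cluster, tokens = [seed], corpus[seed]["token_length"]

# Greedily add the most similar sentences while the token budget allows
for idx in sorted(remaining, key=lambda i: sim[seed][i], reverse=True):
    if tokens + corpus[idx]["token_length"] + 1 <= max_tokens:  # +1 for the newline separator
        cluster.append(idx)
        tokens += corpus[idx]["token_length"] + 1
    if tokens >= min_tokens:
        break

print([corpus[i]["sentence"] for i in cluster], tokens)
```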
        # 3. Initialize the model's weight_down_embed
        if hasattr(model, 'extract_db') and hasattr(model.extract_db, 'weight_down_embed'):
            model.extract_db.weight_down_embed.data.copy_(clustered_tensor)
            Logger("Successfully initialized model.extract_db.weight_down_embed with clustered data")
        else:
            Logger("Warning: Could not find model.extract_db.weight_down_embed to initialize")
            # Keep the tensor in a global variable as a fallback
            globals()['clustered_database'] = clustered_tensor

        Logger(f"Database embeddings and sentences stored in model")

    Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
    return model, tokenizer
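After `init_model` returns, `model.extract_db.weight_down_embed` is expected to hold a `(knowledge_num, knowledge_length)` tensor of token ids, one packed sentence cluster per row. A quick way to sanity-check the initialization is to decode a few rows back to text, as in the sketch below; it assumes the attribute names used above and that `model` and `tokenizer` are the objects returned by `init_model`.

```python
# Sanity-check sketch: decode the first few knowledge rows back to text.
store = model.extract_db.weight_down_embed.data   # expected shape: (knowledge_num, knowledge_length)
print(store.shape, store.dtype)                   # e.g. torch.Size([4096, 64]) torch.int64

for row in store[:3]:
    text = tokenizer.decode(row.tolist(), skip_special_tokens=True)
    print(repr(text[:120]))                       # packed sentences, newline-separated, possibly truncated
```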
@@ -290,7 +652,9 @@ def main():
    parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
    parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
    parser.add_argument("--knowledge_num", type=int, default=64*64, help="知识库的数据数目")
-    parser.add_argument("--knowledge_length", type=int, default=8, help="知识库的句子长度")
+    parser.add_argument("--knowledge_length", type=int, default=64, help="知识库的句子长度")
+    parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径")
+    parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
    args = parser.parse_args()

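With these defaults the knowledge store holds `64*64 = 4096` clusters of 64 token ids each, and `init_model` packs each cluster to between `int(0.9*64) = 57` and 64 tokens. A quick back-of-the-envelope check, using only the default values above:

```python
# Size of the initialized knowledge store under the argparse defaults above.
knowledge_num = 64 * 64                    # 4096 clusters
knowledge_length = 64                      # token ids per cluster
min_tokens = int(0.9 * knowledge_length)   # 57 -- lower bound targeted when packing a cluster

total_token_slots = knowledge_num * knowledge_length
print(knowledge_num, min_tokens, total_token_slots)   # 4096 57 262144
```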
    #########################################################
@@ -379,7 +743,7 @@ def main():
    #########################################################
    # Initialize the model and tokenizer
    #########################################################
-    model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
+    model, tokenizer = init_model(lm_config, args.pretrained_embedding_path, args.database_init_path, args)
    # Pass the accelerator on to the Logger calls inside init_model
    Logger(f'模型初始化完成', accelerator)