Compare commits


No commits in common. "67c632d010198a12d1b70bde4ea3d5127dbeac5d" and "c09cd637942373fe5422cd8af7b6895b033a73f9" have entirely different histories.

5 changed files with 272 additions and 1219 deletions

.gitignore vendored
View File

@@ -3,5 +3,3 @@
/out
wandb/
**/*.log
models/sentence_transformers/
models/sentence_transformers_cache/

View File

@@ -703,7 +703,7 @@ class MiniMindLM(PreTrainedModel):
# Process query path as before
z_q = self.downsample_q_specific(shared_features)
z_k = self.extract_db.q_to_k(z_q)
# self.extract_db.updata_value(z_k, token_indices)
self.extract_db.updata_value(z_k, token_indices)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.output(self.norm(h)[:, slice_indices, :])

View File

@@ -1,97 +0,0 @@
# TREx Dataset Processing Tool: Usage Guide
This tool processes the TREx dataset in two steps:
1. **Sentence extraction**: extract triples from the TREx dataset and convert them into natural-language sentences
2. **LLM processing**: use the ollama qwen3:4b model to correct the sentences and score their importance
## Install dependencies
```bash
pip install agno asyncio pydantic
```
Make sure ollama is installed and running, and pull the qwen3:4b model:
```bash
ollama pull qwen3:4b
```
## Usage
### 1. Full pipeline (both steps in sequence)
```bash
python trex_to_sentences_simple.py --step all --input_dir dataset/TREx --max_files 2
```
### 2. Running the steps separately
#### Step 1: sentence extraction only
```bash
python trex_to_sentences_simple.py --step extract --input_dir dataset/TREx --sentences_json my_sentences.json --max_files 2
```
#### Step 2: LLM processing only
```bash
python trex_to_sentences_simple.py --step llm --sentences_json my_sentences.json --output_file final_output.txt
```
## Main parameters
- `--step`: which step to run
  - `extract`: sentence extraction only
  - `llm`: LLM processing only
  - `all`: full pipeline (default)
- `--input_dir`: TREx dataset directory (default: `dataset/TREx`)
- `--sentences_json`: JSON file for the extracted sentences (default: `extracted_sentences.json`)
- `--output_file`: final output file (default: `trex_sentences_enhanced.txt`)
- `--max_files`: maximum number of files to process (useful for testing)
- `--no_llm`: disable LLM processing
## Output files
**Note: all output files are written to the `./output/` directory.**
### Step 1 output
- `output/extracted_sentences.json`: the raw extracted sentences, with metadata
### Step 2 output
- `output/{output_file}.txt`: text file containing the corrected sentences
- `output/{output_file}.json`: full processing results (original sentence, corrected sentence, and importance score)
- `output/{output_file}_sorted_by_importance.txt`: sentences sorted by importance score
### Checkpoint files
- `output/{output_file}_checkpoint_{count}.json`: checkpoint saved automatically every 2000 sentences
## Checkpoint recovery
- Step 2 automatically detects existing checkpoint files in the `output/` directory (see the sketch below)
- Only sentences that have not been processed yet are handled, so no work is repeated
- If every sentence has already been processed, the final output file is generated directly
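The recovery scan is essentially a glob over the checkpoint files plus a de-duplication set. A simplified sketch of that logic, mirroring `get_processed_sentences_from_checkpoints` in `trex_to_sentences_simple.py` (file-name pattern and field names taken from the script shown further down in this diff):
```python
import glob
import json
import os
import re

def latest_checkpoint(output_file="trex_sentences_enhanced.txt"):
    """Return (path, count) for the newest checkpoint in ./output, or None if there is none."""
    base = os.path.splitext(os.path.basename(output_file))[0]
    best = None
    for path in glob.glob(os.path.join("output", f"{base}_checkpoint_*.json")):
        match = re.search(r"checkpoint_(\d+)\.json$", path)
        if match:
            count = int(match.group(1))
            if best is None or count > best[1]:
                best = (path, count)
    return best

def already_processed(checkpoint_path):
    """Collect the original sentences recorded in a checkpoint so step 2 can skip them."""
    with open(checkpoint_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return {item.get("original_sentence", "") for item in data.get("sentences", [])}
```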
## Example workflow
```bash
# 1. Extract the sentences first (this step is quick)
python trex_to_sentences_simple.py --step extract --max_files 5
# 2. Run LLM processing afterwards (slow, supports resuming)
python trex_to_sentences_simple.py --step llm
# If the run is interrupted, running step 2 again resumes from the latest checkpoint
python trex_to_sentences_simple.py --step llm
```
## Performance characteristics
- **Concurrent processing**: up to 54 concurrent LLM requests (see the sketch below)
- **Checkpointing**: a checkpoint is saved automatically every 2000 sentences, so interrupted runs can resume
- **Progress reporting**: detailed progress output with time estimates
- **Error handling**: if an LLM request fails, the original sentence is kept with a default score
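Concurrency control amounts to an `asyncio.Semaphore` wrapped around each request, with results gathered batch by batch. A minimal sketch under that assumption; `process_one` is a placeholder for the actual agent call:
```python
import asyncio

async def process_all(sentences, process_one, max_concurrent=54, batch_size=2000):
    """Process sentences with bounded concurrency and batch-level checkpoints (sketch only)."""
    semaphore = asyncio.Semaphore(max_concurrent)
    results = []

    async def bounded(sentence):
        async with semaphore:  # at most max_concurrent requests in flight
            return await process_one(sentence)

    for start in range(0, len(sentences), batch_size):
        batch = sentences[start:start + batch_size]
        batch_results = await asyncio.gather(*(bounded(s) for s in batch), return_exceptions=True)
        results.extend(r for r in batch_results if not isinstance(r, Exception))
        # a real run would write a checkpoint JSON for `results` here
    return results
```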
## Notes
1. Step 1 must be completed before step 2 is run for the first time
2. Checkpoint files take extra disk space (each one contains all data processed so far)
3. LLM processing speed depends on model performance and network conditions
4. Test on a small batch first using `--max_files`

View File

@@ -2,57 +2,19 @@
"""
Enhanced TREx dataset preprocessing script.
Uses the agno framework with ollama qwen3:4b for sentence post-processing and importance scoring.
Two independent steps are supported:
1. Sentence extraction: extract sentences from the TREx dataset and save them as JSON
2. LLM processing: read the JSON file and run LLM post-processing and importance scoring
"""
import json
import os
import glob
from typing import List, Dict, Any, Union, Set
from typing import List, Dict, Any, Union
import re
import asyncio
import time
import logging
from datetime import datetime
import subprocess
import requests
from pydantic import BaseModel, Field
from agno.agent import Agent
from agno.models.ollama import Ollama
# 设置日志系统
def setup_logging():
"""设置日志系统"""
# 确保logs目录存在
os.makedirs('logs', exist_ok=True)
# 创建日志文件名(包含时间戳)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = f'logs/trex_processor_{timestamp}.log'
# 配置日志格式
log_format = '%(asctime)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s'
# 配置root logger
logging.basicConfig(
level=logging.INFO,
format=log_format,
handlers=[
logging.FileHandler(log_file, encoding='utf-8'),
logging.StreamHandler() # 同时输出到控制台
]
)
# 获取logger
logger = logging.getLogger(__name__)
logger.info(f"日志系统初始化完成,日志文件: {log_file}")
return logger
# 全局日志对象
logger = setup_logging()
class ProcessedSentence(BaseModel):
"""处理后的句子结构"""
@@ -69,53 +31,16 @@ class ProcessedSentence(BaseModel):
class EnhancedTRExProcessor:
def __init__(self, input_dir: str = None, output_file: str = None, max_files: int = None,
sentences_json: str = None, enable_llm_processing: bool = True):
def __init__(self, input_dir: str, output_file: str, max_files: int = None, enable_llm_processing: bool = True):
self.input_dir = input_dir
# 确保output目录存在
os.makedirs('output', exist_ok=True)
# 确保所有输出文件都在output目录中
if output_file:
if not output_file.startswith('output/'):
self.output_file = os.path.join('output', output_file)
else:
self.output_file = output_file
else:
self.output_file = None
if sentences_json:
if not sentences_json.startswith('output/'):
self.sentences_json = os.path.join('output', sentences_json)
else:
self.sentences_json = sentences_json
else:
self.sentences_json = "output/extracted_sentences.json"
self.max_files = max_files
self.enable_llm_processing = enable_llm_processing
# LLM处理配置
self.llm_timeout = 60 # 增加每个请求的超时时间到60秒
self.max_concurrent = 8 # 进一步降低并发数到4
self.max_retries = 2 # 减少重试次数避免过长等待
self.heartbeat_interval = 30 # 缩短心跳检测间隔到30秒
# 统计信息
self.total_requests = 0
self.successful_requests = 0
self.failed_requests = 0
self.timeout_requests = 0
self.last_successful_time = time.time()
self.last_activity_time = time.time() # 新增:最后活动时间
# 初始化agno agent仅在需要LLM处理时
# 初始化agno agent
if self.enable_llm_processing:
self.setup_agent()
logger.info(f"处理器初始化完成 - 并发数: {self.max_concurrent}, 超时时间: {self.llm_timeout}")
# 扩展的Wikidata属性映射
self.property_mappings = {
# 基本关系
@@ -162,10 +87,10 @@ class EnhancedTRExProcessor:
try:
self.agent = Agent(
model=Ollama(
id="gemma3:latest",
id="qwen3:4b",
# 使用options设置temperature和其他参数
options={
"temperature": 0.2,
"temperature": 0.7,
"top_p": 0.8,
"top_k": 20,
"num_ctx": 4096,
@@ -173,116 +98,111 @@
),
response_model=ProcessedSentence,
instructions=[
"You are a professional text processing assistant responsible for correcting errors in sentences and evaluating the importance of knowledge.",
"你是一个专业的文本处理助手,负责修正句子中的错误并评估知识的重要性。",
"",
"### Sentence Correction Rules:",
"1. Remove Wikipedia-specific markers: such as (disambiguation), (film), (band), etc. in parentheses",
"2. Ensure grammatical completeness: complete subject+predicate+object structure, avoid dangling 'and is', 'or', etc.",
"3. Fix obvious grammatical errors: tense consistency, singular/plural consistency, correct preposition usage",
"4. Clean up garbled text and special characters: such as â, €, ™ and other encoding issues",
"5. Ensure semantic fluency: if the original sentence cannot be fixed, reorganize the language to make it coherent",
"6. Do not add information not present in the original text, only correct errors",
"### 句子修正规则:",
"1. 移除Wikipedia特有标记如(disambiguation)、(film)、(band)等括号内容",
"2. 确保句子语法完整:主语+谓语+宾语结构完整,避免悬空的'and is''or'",
"3. 修正明显的语法错误:时态一致、单复数一致、介词使用正确",
"4. 清理乱码和特殊字符:如â、€、™等编码问题",
"5. 确保句子语义通顺:如果原句无法修复,重新组织语言使其通顺",
"6. 不要添加原文没有的信息,只修正错误",
"",
"### Correction Examples:",
"- Error: 'Argument (disambiguation) is related to philosophy, logic, and is an.'",
"- Corrected: 'Argument is related to philosophy and logic.'",
"### 修正示例:",
"- 错误:'Argument (disambiguation) is related to philosophy, logic, and is an.'",
"- 修正:'Argument is related to philosophy and logic.'",
"",
"- Error: 'Beijing is a capital city and are.'",
"- Corrected: 'Beijing is a capital city.'",
"- 错误:'Beijing is a capital city and are.'",
"- 修正:'Beijing is a capital city.'",
"",
"Importance scoring criteria (0.0-10.0, in increments of 0.1):",
"重要性评分标准0.0-10.0以0.1递进):",
"",
"0.0 points - Completely incorrect or meaningless information",
"Examples: 'Apple is a metal', 'The sun rises from the west', '1+1=3'",
"0.0分 - 完全错误或无意义的信息",
"例:'苹果是一种金属''太阳从西边升起''1+1=3'",
"",
"0.5 points - Almost worthless information",
"Examples: 'Color of a fictional character's socks', 'Third line of dialogue from a game NPC', 'What someone had for breakfast yesterday'",
"0.5分 - 几乎无价值的信息",
"例:'某个虚构角色的袜子颜色''游戏中NPC的对话第三句话''某人昨天早餐吃了什么'",
"",
"1.0 points - Extremely rare, non-practical knowledge",
"Examples: 'Pet name of a minor novel character', 'Content of the 15th line in movie end credits', 'Nickname of website user ID 123456'",
"1.0分 - 极其罕见、无实用价值的知识",
"例:'某小说背景角色宠物名字''某部电影片尾字幕第15行内容''某网站用户ID为123456的昵称'",
"",
"1.5 points - Very niche detailed information",
"Examples: 'Outfit of a passerby at minute 37 in a movie', 'Duration of background music in a game's hidden level', 'Content of the 3rd dialogue box on page 200 of a manga'",
"1.5分 - 非常小众的细节信息",
"例:'某电影第37分钟路人甲服装''某游戏隐藏关卡的背景音乐时长''某漫画第200页第3个对话框内容'",
"",
"2.0 points - Details in niche professional fields",
"Examples: 'Color change of rare minerals at specific temperatures', 'Length of an insect's third antenna', 'Molecular formula of chemical reaction byproducts'",
"2.0分 - 小众专业领域的细节",
"例:'稀有矿物在特定温度下颜色变化''某种昆虫的第三对触角长度''某化学反应的副产物分子式'",
"",
"2.5 points - Technical details only professionals care about",
"Examples: 'Release date of specific software library version', 'Time complexity coefficient of an algorithm', 'Thermal expansion coefficient of a material'",
"2.5分 - 专业人士才关心的技术细节",
"例:'软件库特定版本发布日期''某算法的时间复杂度系数''某种材料的热膨胀系数'",
"",
"3.0 points - Professional knowledge in specific fields",
"Examples: 'Programming language syntax features', 'Gene sequence of a virus', 'Official system of ancient dynasties'",
"3.0分 - 特定领域的专业知识",
"例:'编程语言语法特性''某种病毒的基因序列''古代某朝代的官职制度'",
"",
"3.5 points - Professional information with some value",
"Examples: 'Specific system of historical dynasty', 'Mechanism of action of a drug', 'Development time of a technical standard'",
"3.5分 - 有一定价值的专业信息",
"例:'某历史朝代特定制度''某种药物的作用机制''某技术标准的制定时间'",
"",
"4.0 points - Meaningful knowledge known by few",
"Examples: 'Unique cultural traditions of a country', 'Important discoveries by a scientist', 'Detailed process of historical events'",
"4.0分 - 较少人知道但有意义的知识",
"例:'某国家独特文化传统''某科学家的重要发现''某历史事件的详细过程'",
"",
"4.5 points - Knowledge of interest to some groups",
"Examples: 'Author's creative background', 'Characteristics of an art movement', 'Detailed rules of a sport'",
"4.5分 - 部分人群感兴趣的知识",
"例:'作家创作背景''某艺术流派特点''某运动项目规则细节'",
"",
"5.0 points - General knowledge of moderate importance",
"Examples: 'Famous attractions in cities', 'Development history of a company', 'Living habits of animals'",
"5.0分 - 中等重要性的一般知识",
"例:'城市著名景点''某企业发展历史''某动物生活习性'",
"",
"5.5 points - Fairly useful common sense",
"Examples: 'Plant growth environment', 'Healthy eating common sense', 'Basic first aid knowledge'",
"5.5分 - 比较有用的常识",
"例:'植物生长环境''健康饮食常识''基本急救知识'",
"",
"6.0 points - Knowledge most educated people should know",
"Examples: 'Shakespeare's representative works', 'Basic geometric theorems', 'Major world currencies'",
"6.0分 - 多数受教育人群应该知道的知识",
"例:'莎士比亚代表作品''基本几何定理''世界主要货币'",
"",
"6.5 points - Important cultural or scientific common sense",
"Examples: 'Basic structure of DNA', 'Newton's three laws', 'Major world religions'",
"6.5分 - 重要的文化或科学常识",
"例:'DNA基本结构''牛顿三大定律''世界主要宗教'",
"",
"7.0 points - Important foundational knowledge",
"Examples: 'Time period of World War II', 'Functions of major human organs', 'Basic mathematical operation rules'",
"7.0分 - 重要的基础知识",
"例:'二次世界大战时间''人体主要器官功能''基本数学运算规则'",
"",
"7.5 points - Very important common sense",
"Examples: 'Light speed is the fastest in the universe', 'Earth is round', 'Basic principles of blood circulation'",
"7.5分 - 非常重要的常识",
"例:'光速是宇宙中最快的''地球是圆的''血液循环基本原理'",
"",
"8.0 points - Core knowledge in basic education",
"Examples: 'Earth orbits the sun', 'Principle of seasonal formation', 'Basic grammar rules'",
"8.0分 - 基础教育中的核心知识",
"例:'地球绕太阳运行''四季形成原理''基本语法规则'",
"",
"8.5 points - Important knowledge everyone should master",
"Examples: 'Chemical formula of water H2O', 'Basic safety common sense', 'Simple mathematical calculations'",
"8.5分 - 每个人都应该掌握的重要知识",
"例:'水的化学式H2O''基本安全常识''简单数学计算'",
"",
"9.0 points - Extremely important basic concepts",
"Examples: 'Humans need oxygen to survive', 'Fire is hot', 'Basic directional concepts'",
"9.0分 - 极其重要的基础概念",
"例:'人类需要氧气生存''火是热的''基本方向概念'",
"",
"9.5 points - Core knowledge everyone must know",
"Examples: 'A day has 24 hours', 'A year has 12 months', 'Basic number concepts'",
"9.5分 - 人人必知的核心知识",
"例:'一天有24小时''一年有12个月''基本数字概念'",
"",
"10.0 points - Most basic and important common sense",
"Examples: 'Humans need food and water to survive', 'The sky is blue', 'Stones are heavier than feathers'",
"10.0分 - 最基础、最重要的常识",
"例:'人类需要食物和水生存''天空是蓝色的''石头比羽毛重'",
"",
"When scoring, please consider:",
"1. Popularity of knowledge - How many people know this knowledge",
"2. Practical value - How useful this knowledge is in daily life",
"3. Educational importance - The position of this knowledge in the education system",
"4. Cultural significance - The importance of this knowledge for understanding the world",
"评分时请考虑:",
"1. 知识的普及程度 - 有多少人知道这个知识",
"2. 实用价值 - 这个知识在日常生活中有多大用处",
"3. 教育重要性 - 这个知识在教育体系中的地位",
"4. 文化意义 - 这个知识对理解世界的重要性",
"",
"Please output structured results directly without showing the thinking process."
"请直接输出结构化结果,不需要思考过程。"
],
markdown=False
)
logger.info("LLM处理器初始化成功")
print("LLM处理器初始化成功")
except Exception as e:
logger.error(f"LLM处理器初始化失败: {e}")
print(f"LLM处理器初始化失败: {e}")
print("将使用基础模式不使用LLM后处理")
self.enable_llm_processing = False
async def process_sentence_with_llm(self, sentence: str) -> ProcessedSentence:
"""使用LLM处理单个句子保留用于单独调用"""
for attempt in range(self.max_retries):
try:
prompt = f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"
prompt = f"请修正以下句子中的错误并评估其重要性:{sentence}"
# 使用asyncio.wait_for添加超时机制
response = await asyncio.wait_for(
self.agent.arun(prompt),
timeout=self.llm_timeout
)
# 使用agent.arun进行异步调用
response = await self.agent.arun(prompt)
# 根据agno文档response应该直接是ProcessedSentence类型
if isinstance(response, ProcessedSentence):
@@ -296,22 +216,9 @@ class EnhancedTRExProcessor:
importance_score=message['importance_score']
)
except asyncio.TimeoutError:
logger.warning(f"LLM请求超时 (尝试 {attempt + 1}/{self.max_retries}): {sentence[:50]}...")
if attempt == self.max_retries - 1:
logger.error(f"LLM请求最终超时使用默认处理: {sentence[:50]}...")
break
# 等待一段时间后重试
await asyncio.sleep(2 ** attempt) # 指数退避
except Exception as e:
logger.error(f"LLM处理句子时出错 (尝试 {attempt + 1}/{self.max_retries}): {e}")
if attempt == self.max_retries - 1:
break
await asyncio.sleep(1)
# 所有重试都失败,返回原句子和中等评分
logger.warning(f"使用默认处理: {sentence[:50]}...")
print(f"LLM处理句子时出错: {e}")
# 出错时返回原句子和中等评分
return ProcessedSentence(
corrected_sentence=sentence,
importance_score=5.0
@@ -462,19 +369,11 @@ class EnhancedTRExProcessor:
async def process_sentence_with_llm_concurrent(self, semaphore: asyncio.Semaphore, sentence: str, index: int, total_sentences: int, start_time: float) -> Dict[str, Any]:
"""使用信号量控制并发的LLM处理"""
async with semaphore:
self.total_requests += 1
self.last_activity_time = time.time() # 更新活动时间
success = False
for attempt in range(self.max_retries):
try:
prompt = f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"
prompt = f"请修正以下句子中的错误并评估其重要性:{sentence}"
# 使用asyncio.wait_for添加超时机制
response = await asyncio.wait_for(
self.agent.arun(prompt),
timeout=self.llm_timeout
)
# 使用agent.arun进行异步调用
response = await self.agent.arun(prompt)
# 根据agno文档response应该直接是ProcessedSentence类型
if isinstance(response, ProcessedSentence):
@@ -488,6 +387,7 @@ class EnhancedTRExProcessor:
message = response.messages[-1].content
message = message.replace("```json", "").replace("```", "")
message = json.loads(message)
# print(message)
result = {
"index": index,
"original_sentence": sentence,
@@ -495,20 +395,13 @@ class EnhancedTRExProcessor:
"importance_score": message['importance_score']
}
# 成功处理
self.successful_requests += 1
self.last_successful_time = time.time()
self.last_activity_time = time.time() # 更新活动时间
success = True
# 打印详细进度信息 - 降低频率到每50个
if index % 50 == 0:
# 打印详细进度信息
if index % 100 == 0:
current_time = time.time()
elapsed_time = current_time - start_time
avg_time_per_sentence = elapsed_time / (index + 1) if index > 0 else elapsed_time
remaining_sentences = total_sentences - (index + 1)
estimated_remaining_time = avg_time_per_sentence * remaining_sentences
success_rate = (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0
# 格式化时间显示
def format_time(seconds):
@@ -521,44 +414,17 @@ class EnhancedTRExProcessor:
hours = seconds / 3600
return f"{hours:.1f}小时"
logger.info(f"已完成第 {index + 1} 个句子的处理")
logger.info(f" - 剩余句子数: {remaining_sentences}")
logger.info(f" - 平均处理时间: {avg_time_per_sentence:.2f}秒/句")
logger.info(f" - 预估剩余时间: {format_time(estimated_remaining_time)}")
logger.info(f" - 已用时间: {format_time(elapsed_time)}")
logger.info(f" - 成功率: {success_rate:.1f}% ({self.successful_requests}/{self.total_requests})")
print(f"已完成第 {index + 1} 个句子的处理")
print(f" - 剩余句子数: {remaining_sentences}")
print(f" - 平均处理时间: {avg_time_per_sentence:.2f}秒/句")
print(f" - 预估剩余时间: {format_time(estimated_remaining_time)}")
print(f" - 已用时间: {format_time(elapsed_time)}")
print(f" - 成功率: {success_rate:.1f}% ({self.successful_requests}/{self.total_requests})")
return result
except asyncio.TimeoutError:
self.timeout_requests += 1
self.last_activity_time = time.time() # 更新活动时间
logger.warning(f"{index} 个句子处理超时 (尝试 {attempt + 1}/{self.max_retries}): {sentence[:50]}...")
if attempt == self.max_retries - 1:
logger.error(f"{index} 个句子最终超时,使用默认处理")
break
# 指数退避
await asyncio.sleep(2 ** attempt)
except Exception as e:
self.last_activity_time = time.time() # 更新活动时间
logger.error(f"处理第 {index} 个句子时出错 (尝试 {attempt + 1}/{self.max_retries}): {e}")
if attempt == self.max_retries - 1:
break
await asyncio.sleep(1)
# 所有重试都失败,使用默认处理
if not success:
self.failed_requests += 1
logger.warning(f"{index} 个句子使用默认处理: {sentence[:50]}...")
print(f"处理第 {index} 个句子时出错: {e}")
# 出错时返回原句子和中等评分
return {
"index": index,
"original_sentence": sentence,
@@ -566,82 +432,26 @@ class EnhancedTRExProcessor:
"importance_score": 5.0
}
async def heartbeat_monitor(self, total_sentences: int):
"""心跳监控,检测是否有长时间无响应"""
consecutive_warnings = 0
while True:
await asyncio.sleep(self.heartbeat_interval)
current_time = time.time()
time_since_last_success = current_time - self.last_successful_time
time_since_last_activity = current_time - self.last_activity_time
# 检查最后成功时间
if time_since_last_success > self.heartbeat_interval:
consecutive_warnings += 1
logger.warning(f"⚠️ 心跳检测 #{consecutive_warnings}:已有 {time_since_last_success:.1f} 秒没有成功的LLM响应")
print(f"⚠️ 心跳检测 #{consecutive_warnings}:已有 {time_since_last_success:.1f} 秒没有成功的LLM响应")
# 打印当前统计信息
if self.total_requests > 0:
success_rate = self.successful_requests / self.total_requests * 100
logger.warning(f" 当前统计:总请求 {self.total_requests},成功 {self.successful_requests} ({success_rate:.1f}%),超时 {self.timeout_requests}")
print(f" 当前统计:总请求 {self.total_requests},成功 {self.successful_requests} ({success_rate:.1f}%),超时 {self.timeout_requests}")
if time_since_last_success > self.heartbeat_interval * 3:
logger.error(f"❌ 严重警告LLM可能已卡死超过 {time_since_last_success:.1f} 秒无成功响应!")
print(f"❌ 严重警告LLM可能已卡死超过 {time_since_last_success:.1f} 秒无成功响应!")
print(f" 建议检查ollama服务状态或考虑重启程序")
# 检查ollama服务状态
if not self.check_ollama_status():
logger.critical("💀 Ollama服务异常这可能是卡死的原因")
print("💀 Ollama服务异常这可能是卡死的原因")
if consecutive_warnings >= 5:
logger.critical(f"💀 致命错误:连续 {consecutive_warnings} 次心跳警告,可能需要人工干预")
print(f"💀 致命错误:连续 {consecutive_warnings} 次心跳警告,可能需要人工干预")
else:
if consecutive_warnings > 0:
logger.info(f"✅ 心跳恢复正常:最后成功时间 {time_since_last_success:.1f} 秒前")
print(f"✅ 心跳恢复正常:最后成功时间 {time_since_last_success:.1f} 秒前")
consecutive_warnings = 0
logger.debug(f"💓 心跳正常:最后成功时间 {time_since_last_success:.1f} 秒前")
async def process_sentences_with_llm(self, sentences: List[str]) -> List[Dict[str, Any]]:
"""批量并发处理句子每2000条保存一次检查点"""
logger.info(f"开始使用LLM并发处理 {len(sentences)} 个句子(最大并发数:{self.max_concurrent}...")
print(f"开始使用LLM并发处理 {len(sentences)} 个句子(最大并发数:{self.max_concurrent}...")
print(f"开始使用LLM并发处理 {len(sentences)} 个句子最大并发数54...")
# 记录开始时间
start_time = time.time()
total_sentences = len(sentences)
# 分批处理,每批1000个句子减少批次大小
batch_size = 1000
# 分批处理每批2000个句子
batch_size = 2000
all_processed_sentences = []
# 启动心跳监控
heartbeat_task = asyncio.create_task(self.heartbeat_monitor(total_sentences))
try:
for batch_start in range(0, total_sentences, batch_size):
batch_end = min(batch_start + batch_size, total_sentences)
batch_sentences = sentences[batch_start:batch_end]
logger.info(f"=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===")
print(f"\n=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===")
# 创建信号量限制并发数降低到8
semaphore = asyncio.Semaphore(self.max_concurrent)
# 重置批次统计
batch_start_time = time.time()
self.total_requests = 0
self.successful_requests = 0
self.failed_requests = 0
self.timeout_requests = 0
# 创建信号量限制并发数
semaphore = asyncio.Semaphore(54)
# 创建当前批次的任务
tasks = []
@@ -651,9 +461,7 @@ class EnhancedTRExProcessor:
tasks.append(task)
# 并发执行当前批次的任务
logger.info(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...")
print(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...")
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理当前批次的结果,过滤异常
@@ -662,7 +470,6 @@ class EnhancedTRExProcessor:
for result in batch_results:
if isinstance(result, Exception):
logger.error(f"任务执行异常: {result}")
print(f"任务执行异常: {result}")
batch_error_count += 1
elif isinstance(result, dict):
@@ -685,22 +492,10 @@ class EnhancedTRExProcessor:
# 打印当前批次统计信息
elapsed_time = time.time() - start_time
batch_time = time.time() - batch_start_time
completed_sentences = len(all_processed_sentences)
logger.info(f"{batch_start//batch_size + 1} 批处理完成!")
logger.info(f" - 当前批次:成功 {len(batch_processed_sentences)},失败 {batch_error_count}")
logger.info(f" - 批次用时:{batch_time/60:.1f}分钟")
logger.info(f" - LLM统计成功 {self.successful_requests},失败 {self.failed_requests},超时 {self.timeout_requests}")
logger.info(f" - 总体进度:{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
logger.info(f" - 已用时间:{elapsed_time/60:.1f}分钟")
logger.info(f" - 平均速度:{completed_sentences/elapsed_time:.2f}句/秒")
logger.info(f" - 检查点已保存:{checkpoint_filename}")
print(f"{batch_start//batch_size + 1} 批处理完成!")
print(f" - 当前批次:成功 {len(batch_processed_sentences)},失败 {batch_error_count}")
print(f" - 批次用时:{batch_time/60:.1f}分钟")
print(f" - LLM统计成功 {self.successful_requests},失败 {self.failed_requests},超时 {self.timeout_requests}")
print(f" - 总体进度:{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
print(f" - 已用时间:{elapsed_time/60:.1f}分钟")
print(f" - 平均速度:{completed_sentences/elapsed_time:.2f}句/秒")
@@ -710,29 +505,10 @@ class EnhancedTRExProcessor:
remaining_sentences = total_sentences - completed_sentences
avg_time_per_sentence = elapsed_time / completed_sentences
estimated_remaining_time = avg_time_per_sentence * remaining_sentences
logger.info(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟")
print(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟")
# 在批次之间稍作休息,避免过度压力
if batch_end < total_sentences:
logger.info("批次间休息5秒...")
await asyncio.sleep(5)
finally:
# 取消心跳监控
heartbeat_task.cancel()
try:
await heartbeat_task
except asyncio.CancelledError:
pass
# 打印最终统计信息
total_time = time.time() - start_time
logger.info(f"=== 全部处理完成!===")
logger.info(f" - 总成功:{len(all_processed_sentences)}")
logger.info(f" - 总用时:{total_time/60:.1f}分钟")
logger.info(f" - 平均处理速度:{len(all_processed_sentences)/total_time:.2f}句/秒")
print(f"\n=== 全部处理完成!===")
print(f" - 总成功:{len(all_processed_sentences)}")
print(f" - 总用时:{total_time/60:.1f}分钟")
@@ -742,9 +518,9 @@ class EnhancedTRExProcessor:
def save_checkpoint(self, processed_sentences: List[Dict[str, Any]], current_count: int) -> str:
"""保存检查点文件"""
# 生成检查点文件名确保在output目录中
base_name = os.path.splitext(os.path.basename(self.output_file))[0]
checkpoint_filename = os.path.join('output', f"{base_name}_checkpoint_{current_count}.json")
# 生成检查点文件名
base_name = os.path.splitext(self.output_file)[0]
checkpoint_filename = f"{base_name}_checkpoint_{current_count}.json"
# 保存检查点
with open(checkpoint_filename, 'w', encoding='utf-8') as f:
@@ -802,29 +578,26 @@ class EnhancedTRExProcessor:
print(f"去重后剩余 {len(unique_sentences)} 个句子")
# 保存原始句子到JSON文件
sentences_data = {
"metadata": {
"total_sentences": len(unique_sentences),
"extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"source_files": len(json_files),
"max_files_limit": self.max_files
},
"sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
# 使用LLM处理句子
if self.enable_llm_processing:
processed_sentences = await self.process_sentences_with_llm(unique_sentences)
else:
# 基础模式不使用LLM
processed_sentences = [
{
"original_sentence": sentence,
"corrected_sentence": sentence,
"importance_score": 5.0
}
for sentence in unique_sentences
]
with open(self.sentences_json, 'w', encoding='utf-8') as f:
json.dump(sentences_data, f, ensure_ascii=False, indent=2)
print(f"句子提取完成!已保存到: {self.sentences_json}")
print(f"总计句子数: {len(unique_sentences)}")
return unique_sentences
return processed_sentences
def save_sentences(self, processed_sentences: List[Dict[str, Any]]):
"""保存处理后的句子到文件"""
# 确保输出目录存在
os.makedirs('output', exist_ok=True)
os.makedirs(os.path.dirname(self.output_file) if os.path.dirname(self.output_file) else '.', exist_ok=True)
# 保存为JSON格式包含完整信息
json_output_file = self.output_file.replace('.txt', '.json')
@@ -864,8 +637,8 @@ class EnhancedTRExProcessor:
def find_latest_checkpoint(self) -> Union[tuple, None]:
"""查找最新的检查点文件"""
base_name = os.path.splitext(os.path.basename(self.output_file))[0]
pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
base_name = os.path.splitext(self.output_file)[0]
pattern = f"./output/{base_name}_checkpoint_*.json"
checkpoint_files = glob.glob(pattern)
if not checkpoint_files:
@@ -907,311 +680,54 @@ class EnhancedTRExProcessor:
print(f"加载检查点文件失败: {e}")
return []
def get_processed_sentences_from_checkpoints(self) -> Set[str]:
"""从检查点文件中获取已处理过的句子集合"""
if not self.output_file:
return set()
processed_sentences = set()
# 查找所有检查点文件
base_name = os.path.splitext(os.path.basename(self.output_file))[0]
pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
checkpoint_files = glob.glob(pattern)
if not checkpoint_files:
print("未找到检查点文件,将从头开始处理")
return set()
# 找到最新的检查点文件
latest_file = None
latest_count = 0
for file in checkpoint_files:
try:
match = re.search(r'checkpoint_(\d+)\.json$', file)
if match:
count = int(match.group(1))
if count > latest_count:
latest_count = count
latest_file = file
except:
continue
if latest_file:
print(f"找到最新检查点: {latest_file} (包含 {latest_count} 条记录)")
logger.info(f"找到最新检查点: {latest_file} (包含 {latest_count} 条记录)")
try:
with open(latest_file, 'r', encoding='utf-8') as f:
data = json.load(f)
sentences_data = data.get('sentences', [])
for item in sentences_data:
original_sentence = item.get('original_sentence', '')
if original_sentence:
processed_sentences.add(original_sentence)
print(f"从检查点加载了 {len(processed_sentences)} 个已处理的句子")
logger.info(f"从检查点加载了 {len(processed_sentences)} 个已处理的句子")
except Exception as e:
print(f"读取检查点文件失败: {e}")
return set()
return processed_sentences
async def process_with_llm(self):
"""步骤2从JSON文件读取句子并进行LLM处理"""
if not self.enable_llm_processing:
print("Error: LLM processing is disabled!")
return
if not self.output_file:
print("Error: output_file is required for LLM processing!")
return
print("=== 步骤2LLM处理 ===")
# 读取句子JSON文件
if not os.path.exists(self.sentences_json):
print(f"Error: Sentences file {self.sentences_json} not found!")
print("请先运行步骤1进行句子提取")
return
print(f"正在读取句子文件: {self.sentences_json}")
try:
with open(self.sentences_json, 'r', encoding='utf-8') as f:
data = json.load(f)
all_sentences = [item["sentence"] for item in data.get("sentences", [])]
print(f"从文件中读取了 {len(all_sentences)} 个句子")
except Exception as e:
print(f"读取句子文件失败: {e}")
return
# 获取已处理的句子
processed_sentences_set = self.get_processed_sentences_from_checkpoints()
# 过滤出未处理的句子
unprocessed_sentences = []
for sentence in all_sentences:
if sentence not in processed_sentences_set:
unprocessed_sentences.append(sentence)
print(f"需要处理的句子数: {len(unprocessed_sentences)} (跳过已处理: {len(processed_sentences_set)})")
logger.info(f"需要处理的句子数: {len(unprocessed_sentences)} (跳过已处理: {len(processed_sentences_set)})")
if not unprocessed_sentences:
print("所有句子都已处理完成!")
# 如果有检查点,直接从最新检查点生成最终文件
if processed_sentences_set:
latest_checkpoint = self.find_latest_checkpoint()
if latest_checkpoint:
checkpoint_file, _ = latest_checkpoint
processed_data = self.load_checkpoint(checkpoint_file)
self.save_sentences(processed_data)
print("已从检查点生成最终输出文件")
return
# 处理未处理的句子
print("开始LLM处理...")
# 检查ollama服务状态
logger.info("检查Ollama服务状态...")
if not self.check_ollama_status():
logger.error("Ollama服务状态异常无法继续处理")
print("错误Ollama服务状态异常请检查服务是否正常运行")
return
new_processed_sentences = await self.process_sentences_with_llm(unprocessed_sentences)
# 如果有之前的处理结果,合并它们
if processed_sentences_set:
latest_checkpoint = self.find_latest_checkpoint()
if latest_checkpoint:
checkpoint_file, _ = latest_checkpoint
previous_processed = self.load_checkpoint(checkpoint_file)
# 合并结果
all_processed_sentences = previous_processed + new_processed_sentences
print(f"合并了之前的 {len(previous_processed)} 条和新处理的 {len(new_processed_sentences)} 条记录")
else:
all_processed_sentences = new_processed_sentences
else:
all_processed_sentences = new_processed_sentences
# 保存最终结果
self.save_sentences(all_processed_sentences)
print("LLM处理完成")
# ==================== 新增:句子提取功能 ====================
def extract_sentences(self):
"""步骤1从TREx数据集提取句子并保存为JSON"""
if not self.input_dir:
print("Error: input_dir is required for sentence extraction!")
return
print("=== 步骤1句子提取 ===")
print("开始从TREx数据集提取句子...")
json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))
if not json_files:
print(f"No JSON files found in {self.input_dir}")
return
# 排序文件以确保一致的处理顺序
json_files.sort()
if self.max_files:
json_files = json_files[:self.max_files]
print(f"Found {len(json_files)} JSON files to process")
all_sentences = []
for i, file_path in enumerate(json_files):
print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")
documents = self.parse_large_json_file(file_path)
print(f" Parsed {len(documents)} documents")
for doc in documents:
sentences = self.extract_sentences_from_document(doc)
all_sentences.extend(sentences)
print(f" Generated {len(all_sentences)} total raw sentences so far")
print(f"总共提取了 {len(all_sentences)} 个原始句子")
# 去重
unique_sentences = []
seen = set()
for sentence in all_sentences:
sentence = sentence.strip()
if sentence and sentence not in seen and len(sentence) > 10:
unique_sentences.append(sentence)
seen.add(sentence)
print(f"去重后剩余 {len(unique_sentences)} 个句子")
# 保存原始句子到JSON文件
sentences_data = {
"metadata": {
"total_sentences": len(unique_sentences),
"extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"source_files": len(json_files),
"max_files_limit": self.max_files
},
"sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
}
with open(self.sentences_json, 'w', encoding='utf-8') as f:
json.dump(sentences_data, f, ensure_ascii=False, indent=2)
print(f"句子提取完成!已保存到: {self.sentences_json}")
print(f"总计句子数: {len(unique_sentences)}")
return unique_sentences
def check_ollama_status(self) -> bool:
"""检查ollama服务是否正常运行"""
try:
# 检查ollama进程是否运行
result = subprocess.run(['pgrep', 'ollama'], capture_output=True, text=True)
if result.returncode != 0:
logger.error("Ollama进程未运行")
return False
# 检查ollama API是否响应
response = requests.get('http://localhost:11434/api/tags', timeout=5)
if response.status_code == 200:
logger.info("Ollama服务状态正常")
return True
else:
logger.error(f"Ollama API响应异常状态码: {response.status_code}")
return False
except requests.exceptions.RequestException as e:
logger.error(f"无法连接到Ollama API: {e}")
return False
except Exception as e:
logger.error(f"检查Ollama状态时出错: {e}")
return False
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with LLM processing')
# 选择运行模式
parser.add_argument('--step', choices=['extract', 'llm', 'all'], default='llm',
help='运行步骤: extract=仅提取句子, llm=仅LLM处理, all=完整流程')
# 文件路径参数
parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files')
parser.add_argument('--sentences_json', default='extracted_sentences.json', help='JSON file for extracted sentences (will be saved in output/)')
parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path (will be saved in output/)')
# 处理参数
parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path')
parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)')
parser.add_argument('--no_llm', action='store_true', help='Disable LLM processing (basic mode)')
parser.add_argument('--resume', action='store_true', help='Resume from latest checkpoint if available')
args = parser.parse_args()
# 根据步骤验证参数
if args.step in ['extract', 'all']:
if not os.path.exists(args.input_dir):
print(f"Error: Input directory {args.input_dir} does not exist!")
return
if args.step in ['llm', 'all']:
if args.no_llm:
print("Error: Cannot run LLM step with --no_llm flag!")
return
# 创建处理器
processor = EnhancedTRExProcessor(
input_dir=args.input_dir,
sentences_json=args.sentences_json,
output_file=args.output_file,
max_files=args.max_files,
args.input_dir,
args.output_file,
args.max_files,
enable_llm_processing=not args.no_llm
)
# 根据选择的步骤运行
if args.step == 'extract':
print("=== 运行模式:仅句子提取 ===")
processor.extract_sentences()
elif args.step == 'llm':
print("=== 运行模式仅LLM处理 ===")
asyncio.run(processor.process_with_llm())
elif args.step == 'all':
print("=== 运行模式:完整流程 ===")
# 步骤1提取句子
print("\n--- 开始步骤1句子提取 ---")
sentences = processor.extract_sentences()
if not sentences:
print("句子提取失败,退出")
# 检查是否要从检查点恢复
if args.resume:
checkpoint_result = processor.find_latest_checkpoint()
if checkpoint_result:
latest_checkpoint, latest_count = checkpoint_result
print(f"发现检查点文件: {latest_checkpoint} (包含 {latest_count} 条记录)")
confirm = input("是否从检查点恢复?(y/n): ").lower().strip()
if confirm == 'y':
processed_sentences = processor.load_checkpoint(latest_checkpoint)
if processed_sentences:
print(f"成功加载 {len(processed_sentences)} 条已处理的句子")
processor.save_sentences(processed_sentences)
print("从检查点恢复完成!")
return
else:
print("检查点文件加载失败,将重新开始处理")
else:
print("不从检查点恢复,将重新开始处理")
else:
print("未找到检查点文件,将重新开始处理")
if args.no_llm:
print("LLM处理已禁用流程结束")
return
# 步骤2LLM处理
print("\n--- 开始步骤2LLM处理 ---")
asyncio.run(processor.process_with_llm())
# 运行异步处理
asyncio.run(processor.run())
if __name__ == "__main__":

View File

@@ -3,7 +3,6 @@ import os
os.environ["WANDB_MODE"] = "offline" # 或者使用 "dryrun"
import platform
import argparse
from tqdm import tqdm
import time
import math
import warnings
@@ -19,10 +18,8 @@ from accelerate.utils import set_seed
from accelerate.utils import DeepSpeedPlugin
from accelerate.utils import DistributedDataParallelKwargs
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from model.model import MiniMindLM, RMSNorm
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import PretrainDataset
@@ -44,41 +41,10 @@ def get_lr(it, num_iters, learning_rate):
return learning_rate * 0.5 * (1.0 + math.cos(math.pi * it / num_iters))
# 初始化模型函数
def init_model(lm_config, pretrained_embedding_path=None, database_init_path=None, args=None):
def init_model(lm_config, pretrained_embedding_path=None):
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
model = MiniMindLM(lm_config)
# 默认模型初始化
Logger("Performing default model initialization...")
# 初始化嵌入层权重
nn.init.normal_(model.tok_embeddings.weight, mean=0.0, std=0.02)
# 初始化输出层权重(如果不共享权重的话)
if not hasattr(model.tok_embeddings, 'weight') or model.output.weight is not model.tok_embeddings.weight:
nn.init.normal_(model.output.weight, mean=0.0, std=0.02)
# 初始化所有线性层
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
# 使用Xavier/Glorot初始化
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
# 嵌入层使用正态分布初始化
nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, RMSNorm):
# RMSNorm的权重初始化为1
if hasattr(module, 'weight'):
nn.init.ones_(module.weight)
# 初始化位置编码相关参数
if hasattr(model.extract_db, 'keys'):
nn.init.normal_(model.extract_db.keys, mean=0.0, std=0.02)
Logger("Default model initialization completed")
# 如果提供了预训练的嵌入权重,加载它们
if pretrained_embedding_path:
Logger(f"Loading pretrained token embeddings from {pretrained_embedding_path}")
@@ -86,334 +52,6 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
model.tok_embeddings.weight.data.copy_(pretrained_embeddings)
model.output.weight.data.copy_(pretrained_embeddings) # 共享权重
if database_init_path:
import json
import numpy as np
from sentence_transformers import SentenceTransformer
import os
Logger(f"Loading database initialization data from {database_init_path}")
# 1. 加载JSON文件并转换为字典
with open(database_init_path, 'r', encoding='utf-8') as f:
database_data = json.load(f)
# 提取sentences列表
sentences_data = database_data.get('sentences', [])
Logger(f"Loaded {len(sentences_data)} sentences from database")
# 2. 按照importance_score进行排序从高到低
sorted_sentences = sorted(sentences_data, key=lambda x: x.get('importance_score', 0.0), reverse=True)
Logger(f"Sorted sentences by importance score (highest: {sorted_sentences[0].get('importance_score', 0.0)}, lowest: {sorted_sentences[-1].get('importance_score', 0.0)})")
# 3. 下载并初始化本地嵌入模型
embedding_model_name = "sentence-transformers/all-mpnet-base-v2" # 轻量级但效果好的模型
embedding_model_dir = "./models/sentence_transformers/models--sentence-transformers--all-mpnet-base-v2"
embedding_cache_dir = "./models/sentence_transformers/cache"
os.makedirs(embedding_cache_dir, exist_ok=True)
Logger(f"Loading embedding model: {embedding_model_name}")
try:
embedding_model = SentenceTransformer(embedding_model_dir, cache_folder=embedding_cache_dir)
Logger("Embedding model loaded successfully")
except Exception as e:
Logger(f"Failed to load embedding model: {e}")
Logger("Falling back to random embeddings")
embedding_model = None
# 4. 对每个corrected_sentence进行嵌入和token长度计算
Logger("Processing sentences for embeddings and token lengths...")
# 提取所有句子
sentences = [sentence_data.get('corrected_sentence', '') for sentence_data in sorted_sentences]
# 批量计算token长度
Logger("Computing token lengths...")
token_lengths = []
for sentence in sentences:
tokens = tokenizer.encode(sentence, add_special_tokens=False)
token_lengths.append(len(tokens))
# 批量计算嵌入 - 大幅提升速度
Logger("Computing embeddings in batches...")
embeddings_list = []
batch_size = 256 # 可以根据GPU内存调整
if embedding_model is not None:
try:
for i in range(0, len(sentences), batch_size):
batch_sentences = sentences[i:i+batch_size]
batch_embeddings = embedding_model.encode(
batch_sentences,
convert_to_tensor=False,
show_progress_bar=True if i == 0 else False,
batch_size=batch_size
)
embeddings_list.extend(batch_embeddings)
if (i + batch_size) % (batch_size * 10) == 0:
Logger(f"Processed {min(i + batch_size, len(sentences))}/{len(sentences)} sentences")
Logger("Batch embedding computation completed")
except Exception as e:
Logger(f"Error in batch encoding: {e}")
Logger("Falling back to random embeddings")
embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
else:
# 使用随机嵌入
embeddings_list = [np.random.randn(384).astype(np.float32) for _ in sentences]
# 创建处理后的句子列表
processed_sentences = []
for i, (sentence_data, embedding, token_length) in enumerate(zip(sorted_sentences, embeddings_list, token_lengths)):
processed_sentences.append({
'sentence': sentence_data.get('corrected_sentence', ''),
'importance_score': sentence_data.get('importance_score', 0.0),
'token_length': token_length,
'embedding': embedding, # Convert numpy array to list
'original_index': i
})
# # Create a JSON-serializable version for saving
# json_serializable_sentences = []
# for sentence in processed_sentences:
# json_sentence = sentence.copy()
# # Convert embedding to list if it's a numpy array
# if hasattr(json_sentence['embedding'], 'tolist'):
# json_sentence['embedding'] = json_sentence['embedding'].tolist()
# json_serializable_sentences.append(json_sentence)
# json.dump(json_serializable_sentences, open('processed_sentences.json', 'w', encoding='utf-8'))
# processed_sentences = json.load(open('processed_sentences.json', 'r', encoding='utf-8'))
# 转换为numpy数组以便后续处理
embeddings_array = np.array(embeddings_list)
token_lengths_array = np.array(token_lengths)
Logger(f"Embedding processing completed:")
Logger(f" - Total sentences: {len(processed_sentences)}")
Logger(f" - Embedding shape: {embeddings_array.shape}")
Logger(f" - Average token length: {np.mean(token_lengths_array):.2f}")
Logger(f" - Token length range: {np.min(token_lengths_array)} - {np.max(token_lengths_array)}")
# 2. 聚类处理 - 优化版本
Logger("Starting optimized clustering process...")
# 聚类参数
knowledge_num = args.knowledge_num
knowledge_length = args.knowledge_length
min_tokens = int(0.85 * knowledge_length)
max_tokens = int(0.95 * knowledge_length)
# 优化1: 预计算所有嵌入的相似度矩阵(如果数据量不太大)
if len(processed_sentences) <= 10000: # 只有在数据量不太大时才预计算
Logger("Pre-computing similarity matrix for faster clustering...")
embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
similarity_matrix = cosine_similarity(embeddings_matrix)
Logger(f"Similarity matrix computed: {similarity_matrix.shape}")
else:
similarity_matrix = None
embeddings_matrix = np.array([s['embedding'] for s in processed_sentences])
clustered_rows = []
remaining_indices = list(range(len(processed_sentences))) # 使用索引而不是对象
Logger(f"Target: {knowledge_num} clusters, each with {min_tokens}-{max_tokens} tokens")
# 选择聚类算法
if args.fast_clustering and len(processed_sentences) > 5000:
Logger("Using ultra-fast approximate clustering algorithm...")
# 超快速聚类:随机采样 + 批量处理
import random
random.seed(42) # 确保可重现性
# 按重要性分层采样
high_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] > 0.7]
medium_importance = [i for i, s in enumerate(processed_sentences) if 0.3 <= s['importance_score'] <= 0.7]
low_importance = [i for i, s in enumerate(processed_sentences) if s['importance_score'] < 0.3]
Logger(f"Importance distribution: High={len(high_importance)}, Medium={len(medium_importance)}, Low={len(low_importance)}")
for cluster_idx in tqdm(range(knowledge_num)):
# 分层选择种子:优先选择高重要性句子
if high_importance:
seed_pool = high_importance
elif medium_importance:
seed_pool = medium_importance
else:
seed_pool = low_importance if low_importance else list(range(len(processed_sentences)))
if not seed_pool:
break
# 随机选择种子(在同一重要性层级内)
seed_global_idx = random.choice(seed_pool)
seed_sentence = processed_sentences[seed_global_idx]
# 从所有池中移除种子
for pool in [high_importance, medium_importance, low_importance]:
if seed_global_idx in pool:
pool.remove(seed_global_idx)
current_cluster_indices = [seed_global_idx]
current_tokens = seed_sentence['token_length']
if current_tokens < max_tokens:
# 快速选择:只从附近的句子中随机选择
all_remaining = high_importance + medium_importance + low_importance
if all_remaining:
# 随机采样候选句子(而不是计算所有相似度)
sample_size = min(100, len(all_remaining))
candidates = random.sample(all_remaining, sample_size)
# 简单按token长度和重要性选择
for candidate_idx in candidates:
candidate = processed_sentences[candidate_idx]
candidate_tokens = candidate['token_length']
if current_tokens + candidate_tokens + 1 <= max_tokens:
current_cluster_indices.append(candidate_idx)
current_tokens += candidate_tokens + 1
# 从池中移除
for pool in [high_importance, medium_importance, low_importance]:
if candidate_idx in pool:
pool.remove(candidate_idx)
break
if current_tokens >= min_tokens:
break
# 生成聚类文本
cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
cluster_text = '\n '.join(cluster_sentences)
# 转换为tokens
cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
if len(cluster_tokens) > knowledge_length:
cluster_tokens = cluster_tokens[:knowledge_length]
else:
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
clustered_rows.append(cluster_tokens)
if (cluster_idx + 1) % 1000 == 0:
total_remaining = len(high_importance) + len(medium_importance) + len(low_importance)
Logger(f"Fast clustering: {cluster_idx + 1}/{knowledge_num} clusters, {total_remaining} sentences remaining")
else:
# 原始优化算法(适用于中等规模数据集)
# 优化2: 批量处理和更高效的数据结构
for cluster_idx in tqdm(range(knowledge_num)):
if not remaining_indices:
Logger(f"No more sentences available. Created {cluster_idx} clusters.")
break
# 2.1 选择importance_score最高的句子作为种子
remaining_sentences_subset = [processed_sentences[i] for i in remaining_indices]
seed_idx_in_subset = max(range(len(remaining_sentences_subset)),
key=lambda i: remaining_sentences_subset[i]['importance_score'])
seed_global_idx = remaining_indices[seed_idx_in_subset]
seed_sentence = processed_sentences[seed_global_idx]
# 从剩余索引中移除种子
remaining_indices.remove(seed_global_idx)
# 当前聚类
current_cluster_indices = [seed_global_idx]
current_tokens = seed_sentence['token_length']
if current_tokens >= max_tokens:
# 如果种子句子已经超过最大token数直接作为一个聚类
cluster_text = seed_sentence['sentence']
else:
# 2.2 优化的相似度计算和选择
if remaining_indices:
if similarity_matrix is not None:
# 使用预计算的相似度矩阵
similarities = similarity_matrix[seed_global_idx][remaining_indices]
else:
# 动态计算相似度(批量)
seed_embedding = embeddings_matrix[seed_global_idx:seed_global_idx+1]
remaining_embeddings = embeddings_matrix[remaining_indices]
similarities = cosine_similarity(seed_embedding, remaining_embeddings)[0]
# 创建(相似度, 原始索引, 在remaining_indices中的位置)的元组列表
similarity_tuples = [(similarities[i], remaining_indices[i], i)
for i in range(len(remaining_indices))]
# 按相似度排序(降序)
similarity_tuples.sort(key=lambda x: x[0], reverse=True)
# 优化3: 贪心选择,但限制搜索范围以提高速度
max_candidates = min(len(similarity_tuples), 500) # 只考虑前500个最相似的句子
selected_indices_in_remaining = []
for sim_score, global_idx, pos_in_remaining in similarity_tuples[:max_candidates]:
candidate = processed_sentences[global_idx]
candidate_tokens = candidate['token_length']
if current_tokens + candidate_tokens + 1 <= max_tokens: # +1 for newline
current_cluster_indices.append(global_idx)
selected_indices_in_remaining.append(pos_in_remaining)
current_tokens += candidate_tokens + 1
if current_tokens >= min_tokens:
break
# 批量移除选中的句子(从后往前移除以避免索引问题)
for pos in sorted(selected_indices_in_remaining, reverse=True):
remaining_indices.pop(pos)
# 拼接句子
cluster_sentences = [processed_sentences[idx]['sentence'] for idx in current_cluster_indices]
cluster_text = '\n'.join(cluster_sentences)
# 将聚类文本转换为token
cluster_tokens = tokenizer.encode(cluster_text, add_special_tokens=False)
# 截断或填充到knowledge_length
if len(cluster_tokens) > knowledge_length:
cluster_tokens = cluster_tokens[:knowledge_length]
else:
# 用pad_token_id填充
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
cluster_tokens.extend([pad_token_id] * (knowledge_length - len(cluster_tokens)))
clustered_rows.append(cluster_tokens)
# 优化4: 减少日志频率
if (cluster_idx + 1) % 500 == 0:
Logger(f"Created {cluster_idx + 1}/{knowledge_num} clusters, {len(remaining_indices)} sentences remaining")
# 如果聚类数量不足用随机token填充
while len(clustered_rows) < knowledge_num:
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
random_tokens = [pad_token_id] * knowledge_length
clustered_rows.append(random_tokens)
# 转换为tensor
clustered_tensor = torch.tensor(clustered_rows, dtype=torch.long)
Logger(f"Clustering completed:")
Logger(f" - Created {len(clustered_rows)} clusters")
Logger(f" - Cluster shape: {clustered_tensor.shape}")
Logger(f" - Expected shape: ({knowledge_num}, {knowledge_length})")
# 3. 初始化模型的weight_down_embed
if hasattr(model, 'extract_db') and hasattr(model.extract_db, 'weight_down_embed'):
model.extract_db.weight_down_embed.data.copy_(clustered_tensor)
Logger("Successfully initialized model.extract_db.weight_down_embed with clustered data")
else:
Logger("Warning: Could not find model.extract_db.weight_down_embed to initialize")
# 存储为全局变量作为备选
globals()['clustered_database'] = clustered_tensor
Logger(f"Database embeddings and sentences stored in model")
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
return model, tokenizer
@@ -652,9 +290,7 @@ def main():
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
parser.add_argument("--knowledge_num", type=int, default=64*64,help="知识库的数据数目")
parser.add_argument("--knowledge_length", type=int, default=64,help="知识库的句子长度")
parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径")
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
parser.add_argument("--knowledge_length", type=int, default=8,help="知识库的句子长度")
args = parser.parse_args()
#########################################################
@@ -743,7 +379,7 @@ def main():
#########################################################
# 初始化模型和tokenizer
#########################################################
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path, args.database_init_path, args)
model, tokenizer = init_model(lm_config, args.pretrained_embedding_path)
# 将accelerator传递给init_model函数中的Logger调用
Logger(f'模型初始化完成', accelerator)