Minimind/preprocessing/trex_to_sentences_simple.py

#!/usr/bin/env python3
"""
TREx数据集增强预处理脚本
使用agno框架和ollama qwen3:4b进行句子后处理和重要性评分

支持两个独立步骤:
1. 句子提取:从TREx数据集提取句子并保存为JSON
2. LLM处理:读取JSON文件,进行LLM后处理和重要性评分
"""
import json
import os
import glob
from typing import List, Dict, Any, Union, Set
import re
import asyncio
import time
import logging
from datetime import datetime
import subprocess
import requests
from pydantic import BaseModel, Field
from agno.agent import Agent
from agno.models.ollama import Ollama
# 设置日志系统
def setup_logging():
"""设置日志系统"""
# 确保logs目录存在
os.makedirs('logs', exist_ok=True)
# 创建日志文件名(包含时间戳)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = f'logs/trex_processor_{timestamp}.log'
# 配置日志格式
log_format = '%(asctime)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s'
# 配置root logger
logging.basicConfig(
level=logging.INFO,
format=log_format,
handlers=[
logging.FileHandler(log_file, encoding='utf-8'),
logging.StreamHandler() # 同时输出到控制台
]
)
# 获取logger
logger = logging.getLogger(__name__)
logger.info(f"日志系统初始化完成,日志文件: {log_file}")
return logger
# 全局日志对象
logger = setup_logging()
class ProcessedSentence(BaseModel):
"""处理后的句子结构"""
corrected_sentence: str = Field(
...,
description="修正后的句子,只修正语法错误、乱码和不通顺的地方,不进行额外润色"
)
importance_score: float = Field(
...,
description="重要性评分范围0.0-10.0以0.1递进。评判这个知识在现实世界中的常用程度和重要度",
ge=0.0,
le=10.0
)
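# 结构化输出示例(仅作说明,并非真实模型输出):
# ProcessedSentence(
#     corrected_sentence="Beijing is the capital of China.",
#     importance_score=8.5
# )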
class EnhancedTRExProcessor:
def __init__(self, input_dir: str = None, output_file: str = None, max_files: int = None,
sentences_json: str = None, enable_llm_processing: bool = True):
self.input_dir = input_dir
# 确保output目录存在
os.makedirs('output', exist_ok=True)
# 确保所有输出文件都在output目录中
if output_file:
if not output_file.startswith('output/'):
self.output_file = os.path.join('output', output_file)
else:
self.output_file = output_file
else:
self.output_file = None
if sentences_json:
if not sentences_json.startswith('output/'):
self.sentences_json = os.path.join('output', sentences_json)
else:
self.sentences_json = sentences_json
else:
self.sentences_json = "output/extracted_sentences.json"
self.max_files = max_files
self.enable_llm_processing = enable_llm_processing
# LLM处理配置
self.llm_timeout = 60 # 每个请求的超时时间(秒)
self.max_concurrent = 8 # 最大并发请求数
self.max_retries = 2 # 最大重试次数,避免过长等待
self.heartbeat_interval = 30 # 心跳检测间隔(秒)
# 统计信息
self.total_requests = 0
self.successful_requests = 0
self.failed_requests = 0
self.timeout_requests = 0
self.last_successful_time = time.time()
self.last_activity_time = time.time() # 新增:最后活动时间
# 初始化agno agent(仅在需要LLM处理时)
if self.enable_llm_processing:
self.setup_agent()
logger.info(f"处理器初始化完成 - 并发数: {self.max_concurrent}, 超时时间: {self.llm_timeout}")
# 扩展的Wikidata属性映射
self.property_mappings = {
# 基本关系
"http://www.wikidata.org/prop/direct/P31": "is a",
"http://www.wikidata.org/prop/direct/P279": "is a type of",
# 人物相关
"http://www.wikidata.org/prop/direct/P106": "works as",
"http://www.wikidata.org/prop/direct/P27": "is a citizen of",
"http://www.wikidata.org/prop/direct/P19": "was born in",
"http://www.wikidata.org/prop/direct/P20": "died in",
"http://www.wikidata.org/prop/direct/P569": "was born on",
"http://www.wikidata.org/prop/direct/P570": "died on",
"http://www.wikidata.org/prop/direct/P22": "has father",
"http://www.wikidata.org/prop/direct/P25": "has mother",
"http://www.wikidata.org/prop/direct/P26": "is married to",
# 组织相关
"http://www.wikidata.org/prop/direct/P102": "is a member of",
"http://www.wikidata.org/prop/direct/P108": "works for",
"http://www.wikidata.org/prop/direct/P159": "has headquarters in",
"http://www.wikidata.org/prop/direct/P112": "was founded by",
"http://www.wikidata.org/prop/direct/P571": "was founded in",
"http://www.wikidata.org/prop/direct/P169": "has CEO",
# 地理相关
"http://www.wikidata.org/prop/direct/P17": "is located in",
"http://www.wikidata.org/prop/direct/P131": "is located in",
"http://www.wikidata.org/prop/direct/P36": "has capital",
"http://www.wikidata.org/prop/direct/P47": "borders",
# 其他关系
"http://www.wikidata.org/prop/direct/P1142": "has ideology",
"http://www.wikidata.org/prop/direct/P361": "is part of",
"http://www.wikidata.org/prop/direct/P737": "was influenced by",
"http://www.wikidata.org/prop/direct/P127": "is owned by",
"http://www.wikidata.org/prop/direct/P155": "follows",
"http://www.wikidata.org/prop/direct/P156": "is followed by",
"http://www.wikidata.org/prop/direct/P138": "is named after"
}
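# 映射效果示例(示意):P19 对应 "was born in",triple_to_sentence 会把
# (Albert Einstein, P19, Ulm) 这样的三元组转换为 "Albert Einstein was born in Ulm."
# 未出现在映射表中的属性会退回到通用关系 "is related to"。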
def setup_agent(self):
"""设置agno agent"""
try:
self.agent = Agent(
model=Ollama(
id="gemma3:latest",
# 使用options设置temperature和其他参数
options={
"temperature": 0.2,
"top_p": 0.8,
"top_k": 20,
"num_ctx": 4096,
}
),
response_model=ProcessedSentence,
instructions=[
"You are a professional text processing assistant responsible for correcting errors in sentences and evaluating the importance of knowledge.",
"",
"### Sentence Correction Rules:",
"1. Remove Wikipedia-specific markers: such as (disambiguation), (film), (band), etc. in parentheses",
"2. Ensure grammatical completeness: complete subject+predicate+object structure, avoid dangling 'and is', 'or', etc.",
"3. Fix obvious grammatical errors: tense consistency, singular/plural consistency, correct preposition usage",
"4. Clean up garbled text and special characters: such as â, €, ™ and other encoding issues",
"5. Ensure semantic fluency: if the original sentence cannot be fixed, reorganize the language to make it coherent",
"6. Do not add information not present in the original text, only correct errors",
"",
"### Correction Examples:",
"- Error: 'Argument (disambiguation) is related to philosophy, logic, and is an.'",
"- Corrected: 'Argument is related to philosophy and logic.'",
"",
"- Error: 'Beijing is a capital city and are.'",
"- Corrected: 'Beijing is a capital city.'",
"",
"Importance scoring criteria (0.0-10.0, in increments of 0.1):",
"",
"0.0 points - Completely incorrect or meaningless information",
"Examples: 'Apple is a metal', 'The sun rises from the west', '1+1=3'",
"",
"0.5 points - Almost worthless information",
"Examples: 'Color of a fictional character's socks', 'Third line of dialogue from a game NPC', 'What someone had for breakfast yesterday'",
"",
"1.0 points - Extremely rare, non-practical knowledge",
"Examples: 'Pet name of a minor novel character', 'Content of the 15th line in movie end credits', 'Nickname of website user ID 123456'",
"",
"1.5 points - Very niche detailed information",
"Examples: 'Outfit of a passerby at minute 37 in a movie', 'Duration of background music in a game's hidden level', 'Content of the 3rd dialogue box on page 200 of a manga'",
"",
"2.0 points - Details in niche professional fields",
"Examples: 'Color change of rare minerals at specific temperatures', 'Length of an insect's third antenna', 'Molecular formula of chemical reaction byproducts'",
"",
"2.5 points - Technical details only professionals care about",
"Examples: 'Release date of specific software library version', 'Time complexity coefficient of an algorithm', 'Thermal expansion coefficient of a material'",
"",
"3.0 points - Professional knowledge in specific fields",
"Examples: 'Programming language syntax features', 'Gene sequence of a virus', 'Official system of ancient dynasties'",
"",
"3.5 points - Professional information with some value",
"Examples: 'Specific system of historical dynasty', 'Mechanism of action of a drug', 'Development time of a technical standard'",
"",
"4.0 points - Meaningful knowledge known by few",
"Examples: 'Unique cultural traditions of a country', 'Important discoveries by a scientist', 'Detailed process of historical events'",
"",
"4.5 points - Knowledge of interest to some groups",
"Examples: 'Author's creative background', 'Characteristics of an art movement', 'Detailed rules of a sport'",
"",
"5.0 points - General knowledge of moderate importance",
"Examples: 'Famous attractions in cities', 'Development history of a company', 'Living habits of animals'",
"",
"5.5 points - Fairly useful common sense",
"Examples: 'Plant growth environment', 'Healthy eating common sense', 'Basic first aid knowledge'",
"",
"6.0 points - Knowledge most educated people should know",
"Examples: 'Shakespeare's representative works', 'Basic geometric theorems', 'Major world currencies'",
"",
"6.5 points - Important cultural or scientific common sense",
"Examples: 'Basic structure of DNA', 'Newton's three laws', 'Major world religions'",
"",
"7.0 points - Important foundational knowledge",
"Examples: 'Time period of World War II', 'Functions of major human organs', 'Basic mathematical operation rules'",
"",
"7.5 points - Very important common sense",
"Examples: 'Light speed is the fastest in the universe', 'Earth is round', 'Basic principles of blood circulation'",
"",
"8.0 points - Core knowledge in basic education",
"Examples: 'Earth orbits the sun', 'Principle of seasonal formation', 'Basic grammar rules'",
"",
"8.5 points - Important knowledge everyone should master",
"Examples: 'Chemical formula of water H2O', 'Basic safety common sense', 'Simple mathematical calculations'",
"",
"9.0 points - Extremely important basic concepts",
"Examples: 'Humans need oxygen to survive', 'Fire is hot', 'Basic directional concepts'",
"",
"9.5 points - Core knowledge everyone must know",
"Examples: 'A day has 24 hours', 'A year has 12 months', 'Basic number concepts'",
"",
"10.0 points - Most basic and important common sense",
"Examples: 'Humans need food and water to survive', 'The sky is blue', 'Stones are heavier than feathers'",
"",
"When scoring, please consider:",
"1. Popularity of knowledge - How many people know this knowledge",
"2. Practical value - How useful this knowledge is in daily life",
"3. Educational importance - The position of this knowledge in the education system",
"4. Cultural significance - The importance of this knowledge for understanding the world",
"",
"Please output structured results directly without showing the thinking process."
],
markdown=False
)
logger.info("LLM处理器初始化成功")
except Exception as e:
logger.error(f"LLM处理器初始化失败: {e}")
print(f"LLM处理器初始化失败: {e}")
print("将使用基础模式不使用LLM后处理")
self.enable_llm_processing = False
async def process_sentence_with_llm(self, sentence: str) -> ProcessedSentence:
"""使用LLM处理单个句子保留用于单独调用"""
2025-05-26 23:09:03 +08:00
for attempt in range(self.max_retries):
try:
prompt = f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"
# 使用asyncio.wait_for添加超时机制
response = await asyncio.wait_for(
self.agent.arun(prompt),
timeout=self.llm_timeout
)
# 根据agno文档,response应该直接是ProcessedSentence类型
if isinstance(response, ProcessedSentence):
return response
else:
message = response.messages[-1].content
message = message.replace("```json", "").replace("```", "")
message = json.loads(message)
return ProcessedSentence(
corrected_sentence=message['corrected_sentence'],
importance_score=message['importance_score']
)
except asyncio.TimeoutError:
logger.warning(f"LLM请求超时 (尝试 {attempt + 1}/{self.max_retries}): {sentence[:50]}...")
if attempt == self.max_retries - 1:
logger.error(f"LLM请求最终超时使用默认处理: {sentence[:50]}...")
break
# 等待一段时间后重试
await asyncio.sleep(2 ** attempt) # 指数退避
except Exception as e:
logger.error(f"LLM处理句子时出错 (尝试 {attempt + 1}/{self.max_retries}): {e}")
if attempt == self.max_retries - 1:
break
await asyncio.sleep(1)
# 所有重试都失败,返回原句子和中等评分
logger.warning(f"使用默认处理: {sentence[:50]}...")
return ProcessedSentence(
corrected_sentence=sentence,
importance_score=5.0
)
def clean_text(self, text: str) -> str:
"""清理文本,处理特殊字符"""
if not text:
return ""
# 处理常见的Unicode字符
text = text.replace("–", "-") # en dash
text = text.replace("—", "-") # em dash
text = text.replace("’", "'") # right single quotation mark
text = text.replace("‘", "'") # left single quotation mark
text = text.replace("“", '"') # left double quotation mark
text = text.replace("”", '"') # right double quotation mark
# 处理可能的转义序列
try:
text = text.encode('utf-8').decode('utf-8')
except:
pass
# 清理多余的空格
text = re.sub(r'\s+', ' ', text).strip()
# 移除可能的引号
text = text.strip('"\'')
return text
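# 清理效果示例(示意):
#   clean_text('  "Argument  – a  test"  ')  ->  'Argument - a test'
# (en dash 被替换为 "-",多余空格被压缩,首尾引号被去除)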
def parse_large_json_file(self, file_path: str) -> List[Dict]:
"""解析大型JSON文件处理可能的格式问题"""
documents = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read().strip()
# 尝试不同的解析方法
if content.startswith('[') and content.endswith(']'):
# 标准JSON数组
documents = json.loads(content)
else:
# 可能是连续的JSON对象
# 尝试在}{"之间分割
if '}{"' in content:
json_strings = content.split('}{')
json_strings[0] += '}' # 第一个对象
json_strings[-1] = '{' + json_strings[-1] # 最后一个对象
for i in range(1, len(json_strings) - 1):
json_strings[i] = '{' + json_strings[i] + '}'
for json_str in json_strings:
try:
doc = json.loads(json_str)
documents.append(doc)
except json.JSONDecodeError:
continue
else:
# 尝试作为单个JSON对象
try:
documents = [json.loads(content)]
except json.JSONDecodeError:
pass
except Exception as e:
print(f"Error parsing {file_path}: {e}")
return documents
def extract_sentences_from_document(self, doc: Dict[str, Any]) -> List[str]:
"""从文档中提取句子"""
sentences = []
title = self.clean_text(doc.get('title', ''))
text = self.clean_text(doc.get('text', ''))
entities = doc.get('entities', [])
triples = doc.get('triples', [])
# 处理显式三元组
for triple in triples:
sentence = self.triple_to_sentence(triple)
if sentence:
sentences.append(sentence)
# 从实体和文本中生成基本句子(如果三元组句子不够)
if title and text and len(sentences) < 5:
# 基于标题和实体生成句子
entity_names = []
for entity in entities[:15]:
entity_name = self.clean_text(entity.get('surfaceform', ''))
if entity_name and len(entity_name) > 2:
entity_names.append(entity_name)
# 生成简单的描述句子
if entity_names:
important_entities = []
title_lower = title.lower()
for entity in entity_names:
if (entity.lower() != title_lower and
entity not in important_entities and
not any(t.lower() in entity.lower() for t in title_lower.split()[:2])):
important_entities.append(entity)
if len(important_entities) >= 6:
break
if important_entities and len(sentences) < 3:
entities_str = ', '.join(important_entities[:3])
sentences.append(f"{title} is related to {entities_str}.")
return sentences
def triple_to_sentence(self, triple: Dict[str, Any]) -> str:
"""将三元组转换为自然语言句子"""
try:
subject = triple.get('subject', {})
predicate = triple.get('predicate', {})
obj = triple.get('object', {})
subject_name = self.clean_text(subject.get('surfaceform', ''))
object_name = self.clean_text(obj.get('surfaceform', ''))
predicate_uri = predicate.get('uri', '')
# 检查是否有有效的主语和宾语
if not subject_name or not object_name:
return ""
# 检查主语和宾语是否过短或无意义
if len(subject_name) <= 2 or len(object_name) <= 2:
return ""
# 获取关系文本
relation_text = self.property_mappings.get(predicate_uri, "is related to")
# 避免重复的主语宾语
if subject_name.lower() == object_name.lower():
return ""
return f"{subject_name} {relation_text} {object_name}."
except Exception as e:
print(f"Error converting triple to sentence: {e}")
return ""
async def process_sentence_with_llm_concurrent(self, semaphore: asyncio.Semaphore, sentence: str, index: int, total_sentences: int, start_time: float) -> Dict[str, Any]:
"""使用信号量控制并发的LLM处理"""
async with semaphore:
self.total_requests += 1
self.last_activity_time = time.time() # 更新活动时间
success = False
for attempt in range(self.max_retries):
try:
prompt = f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"
# 使用asyncio.wait_for添加超时机制
response = await asyncio.wait_for(
self.agent.arun(prompt),
timeout=self.llm_timeout
)
# 根据agno文档response应该直接是ProcessedSentence类型
if isinstance(response, ProcessedSentence):
result = {
"index": index,
"original_sentence": sentence,
"corrected_sentence": response.corrected_sentence,
"importance_score": response.importance_score
}
else:
message = response.messages[-1].content
message = message.replace("```json", "").replace("```", "")
message = json.loads(message)
result = {
"index": index,
"original_sentence": sentence,
"corrected_sentence": message['corrected_sentence'],
"importance_score": message['importance_score']
}
# 成功处理
self.successful_requests += 1
self.last_successful_time = time.time()
self.last_activity_time = time.time() # 更新活动时间
success = True
# 打印详细进度信息 - 降低频率到每50个
if index % 50 == 0:
current_time = time.time()
elapsed_time = current_time - start_time
avg_time_per_sentence = elapsed_time / (index + 1) if index > 0 else elapsed_time
remaining_sentences = total_sentences - (index + 1)
estimated_remaining_time = avg_time_per_sentence * remaining_sentences
success_rate = (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0
# 格式化时间显示
def format_time(seconds):
if seconds < 60:
return f"{seconds:.1f}"
elif seconds < 3600:
minutes = seconds / 60
return f"{minutes:.1f}分钟"
else:
hours = seconds / 3600
return f"{hours:.1f}小时"
logger.info(f"已完成第 {index + 1} 个句子的处理")
logger.info(f" - 剩余句子数: {remaining_sentences}")
logger.info(f" - 平均处理时间: {avg_time_per_sentence:.2f}秒/句")
logger.info(f" - 预估剩余时间: {format_time(estimated_remaining_time)}")
logger.info(f" - 已用时间: {format_time(elapsed_time)}")
logger.info(f" - 成功率: {success_rate:.1f}% ({self.successful_requests}/{self.total_requests})")
print(f"已完成第 {index + 1} 个句子的处理")
print(f" - 剩余句子数: {remaining_sentences}")
print(f" - 平均处理时间: {avg_time_per_sentence:.2f}秒/句")
print(f" - 预估剩余时间: {format_time(estimated_remaining_time)}")
print(f" - 已用时间: {format_time(elapsed_time)}")
print(f" - 成功率: {success_rate:.1f}% ({self.successful_requests}/{self.total_requests})")
return result
except asyncio.TimeoutError:
self.timeout_requests += 1
self.last_activity_time = time.time() # 更新活动时间
logger.warning(f"{index} 个句子处理超时 (尝试 {attempt + 1}/{self.max_retries}): {sentence[:50]}...")
if attempt == self.max_retries - 1:
logger.error(f"{index} 个句子最终超时,使用默认处理")
break
# 指数退避
await asyncio.sleep(2 ** attempt)
except Exception as e:
self.last_activity_time = time.time() # 更新活动时间
logger.error(f"处理第 {index} 个句子时出错 (尝试 {attempt + 1}/{self.max_retries}): {e}")
if attempt == self.max_retries - 1:
break
await asyncio.sleep(1)
# 所有重试都失败,使用默认处理
if not success:
self.failed_requests += 1
logger.warning(f"{index} 个句子使用默认处理: {sentence[:50]}...")
return {
"index": index,
"original_sentence": sentence,
"corrected_sentence": sentence,
"importance_score": 5.0
}
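# 单句处理结果示例(示意,index 字段在批处理阶段按序排列后会被移除):
# {"index": 42, "original_sentence": "...", "corrected_sentence": "...", "importance_score": 7.0}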
async def heartbeat_monitor(self, total_sentences: int):
"""心跳监控,检测是否有长时间无响应"""
consecutive_warnings = 0
while True:
await asyncio.sleep(self.heartbeat_interval)
current_time = time.time()
time_since_last_success = current_time - self.last_successful_time
time_since_last_activity = current_time - self.last_activity_time
# 检查最后成功时间
if time_since_last_success > self.heartbeat_interval:
consecutive_warnings += 1
logger.warning(f"⚠️ 心跳检测 #{consecutive_warnings}:已有 {time_since_last_success:.1f} 秒没有成功的LLM响应")
print(f"⚠️ 心跳检测 #{consecutive_warnings}:已有 {time_since_last_success:.1f} 秒没有成功的LLM响应")
# 打印当前统计信息
if self.total_requests > 0:
success_rate = self.successful_requests / self.total_requests * 100
logger.warning(f" 当前统计:总请求 {self.total_requests},成功 {self.successful_requests} ({success_rate:.1f}%),超时 {self.timeout_requests}")
print(f" 当前统计:总请求 {self.total_requests},成功 {self.successful_requests} ({success_rate:.1f}%),超时 {self.timeout_requests}")
if time_since_last_success > self.heartbeat_interval * 3:
logger.error(f"❌ 严重警告:LLM可能已卡死,超过 {time_since_last_success:.1f} 秒无成功响应!")
print(f"❌ 严重警告:LLM可能已卡死,超过 {time_since_last_success:.1f} 秒无成功响应!")
print(f" 建议检查ollama服务状态或考虑重启程序")
# 检查ollama服务状态
if not self.check_ollama_status():
logger.critical("💀 Ollama服务异常,这可能是卡死的原因")
print("💀 Ollama服务异常,这可能是卡死的原因")
if consecutive_warnings >= 5:
logger.critical(f"💀 致命错误:连续 {consecutive_warnings} 次心跳警告,可能需要人工干预")
print(f"💀 致命错误:连续 {consecutive_warnings} 次心跳警告,可能需要人工干预")
else:
if consecutive_warnings > 0:
logger.info(f"✅ 心跳恢复正常:最后成功时间 {time_since_last_success:.1f} 秒前")
print(f"✅ 心跳恢复正常:最后成功时间 {time_since_last_success:.1f} 秒前")
consecutive_warnings = 0
logger.debug(f"💓 心跳正常:最后成功时间 {time_since_last_success:.1f} 秒前")
async def process_sentences_with_llm(self, sentences: List[str]) -> List[Dict[str, Any]]:
"""批量并发处理句子每2000条保存一次检查点"""
2025-05-26 23:09:03 +08:00
logger.info(f"开始使用LLM并发处理 {len(sentences)} 个句子(最大并发数:{self.max_concurrent}...")
print(f"开始使用LLM并发处理 {len(sentences)} 个句子(最大并发数:{self.max_concurrent}...")
# 记录开始时间
start_time = time.time()
total_sentences = len(sentences)
# 分批处理,每批1000个句子
batch_size = 1000
all_processed_sentences = []
# 启动心跳监控
heartbeat_task = asyncio.create_task(self.heartbeat_monitor(total_sentences))
try:
for batch_start in range(0, total_sentences, batch_size):
batch_end = min(batch_start + batch_size, total_sentences)
batch_sentences = sentences[batch_start:batch_end]
logger.info(f"=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===")
print(f"\n=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===")
# 创建信号量限制并发数降低到8
semaphore = asyncio.Semaphore(self.max_concurrent)
# 重置批次统计
batch_start_time = time.time()
self.total_requests = 0
self.successful_requests = 0
self.failed_requests = 0
self.timeout_requests = 0
# 创建当前批次的任务
tasks = []
for i, sentence in enumerate(batch_sentences):
global_index = batch_start + i
task = self.process_sentence_with_llm_concurrent(semaphore, sentence, global_index, total_sentences, start_time)
tasks.append(task)
# 并发执行当前批次的任务
logger.info(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...")
print(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...")
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理当前批次的结果,过滤异常
batch_processed_sentences = []
batch_error_count = 0
for result in batch_results:
if isinstance(result, Exception):
logger.error(f"任务执行异常: {result}")
print(f"任务执行异常: {result}")
batch_error_count += 1
elif isinstance(result, dict):
batch_processed_sentences.append(result)
else:
batch_error_count += 1
# 按原始顺序排序(因为并发执行可能改变顺序)
batch_processed_sentences.sort(key=lambda x: x['index'])
# 移除index字段
for item in batch_processed_sentences:
del item['index']
# 添加到总结果中
all_processed_sentences.extend(batch_processed_sentences)
# 保存检查点
checkpoint_filename = self.save_checkpoint(all_processed_sentences, batch_end)
# 打印当前批次统计信息
elapsed_time = time.time() - start_time
batch_time = time.time() - batch_start_time
completed_sentences = len(all_processed_sentences)
logger.info(f"{batch_start//batch_size + 1} 批处理完成!")
logger.info(f" - 当前批次:成功 {len(batch_processed_sentences)},失败 {batch_error_count}")
logger.info(f" - 批次用时:{batch_time/60:.1f}分钟")
logger.info(f" - LLM统计成功 {self.successful_requests},失败 {self.failed_requests},超时 {self.timeout_requests}")
logger.info(f" - 总体进度:{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
logger.info(f" - 已用时间:{elapsed_time/60:.1f}分钟")
logger.info(f" - 平均速度:{completed_sentences/elapsed_time:.2f}句/秒")
logger.info(f" - 检查点已保存:{checkpoint_filename}")
print(f"{batch_start//batch_size + 1} 批处理完成!")
print(f" - 当前批次:成功 {len(batch_processed_sentences)},失败 {batch_error_count}")
print(f" - 批次用时:{batch_time/60:.1f}分钟")
print(f" - LLM统计成功 {self.successful_requests},失败 {self.failed_requests},超时 {self.timeout_requests}")
print(f" - 总体进度:{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
print(f" - 已用时间:{elapsed_time/60:.1f}分钟")
print(f" - 平均速度:{completed_sentences/elapsed_time:.2f}句/秒")
print(f" - 检查点已保存:{checkpoint_filename}")
if batch_end < total_sentences:
remaining_sentences = total_sentences - completed_sentences
avg_time_per_sentence = elapsed_time / completed_sentences
estimated_remaining_time = avg_time_per_sentence * remaining_sentences
logger.info(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟")
print(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟")
# 在批次之间稍作休息,避免过度压力
if batch_end < total_sentences:
logger.info("批次间休息5秒...")
await asyncio.sleep(5)
finally:
# 取消心跳监控
heartbeat_task.cancel()
try:
await heartbeat_task
except asyncio.CancelledError:
pass
# 打印最终统计信息
total_time = time.time() - start_time
logger.info(f"=== 全部处理完成!===")
logger.info(f" - 总成功:{len(all_processed_sentences)}")
logger.info(f" - 总用时:{total_time/60:.1f}分钟")
logger.info(f" - 平均处理速度:{len(all_processed_sentences)/total_time:.2f}句/秒")
print(f"\n=== 全部处理完成!===")
print(f" - 总成功:{len(all_processed_sentences)}")
print(f" - 总用时:{total_time/60:.1f}分钟")
print(f" - 平均处理速度:{len(all_processed_sentences)/total_time:.2f}句/秒")
return all_processed_sentences
def save_checkpoint(self, processed_sentences: List[Dict[str, Any]], current_count: int) -> str:
"""保存检查点文件"""
# 生成检查点文件名,确保在output目录中
base_name = os.path.splitext(os.path.basename(self.output_file))[0]
checkpoint_filename = os.path.join('output', f"{base_name}_checkpoint_{current_count}.json")
# 保存检查点
with open(checkpoint_filename, 'w', encoding='utf-8') as f:
json.dump({
"metadata": {
"total_processed": len(processed_sentences),
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"checkpoint_number": current_count
},
"sentences": processed_sentences
}, f, ensure_ascii=False, indent=2)
return checkpoint_filename
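# 检查点文件命名示例(示意):当 output_file 为 "output/trex_sentences_enhanced.txt" 时,
# 第一批(1000句)完成后会写出 "output/trex_sentences_enhanced_checkpoint_1000.json"。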
async def process_files(self) -> List[Dict[str, Any]]:
"""处理所有文件"""
json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))
if not json_files:
print(f"No JSON files found in {self.input_dir}")
return []
# 排序文件以确保一致的处理顺序
json_files.sort()
if self.max_files:
json_files = json_files[:self.max_files]
print(f"Found {len(json_files)} JSON files to process")
all_sentences = []
for i, file_path in enumerate(json_files):
print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")
documents = self.parse_large_json_file(file_path)
print(f" Parsed {len(documents)} documents")
for doc in documents:
sentences = self.extract_sentences_from_document(doc)
all_sentences.extend(sentences)
print(f" Generated {len(all_sentences)} total raw sentences so far")
print(f"总共提取了 {len(all_sentences)} 个原始句子")
# 去重
unique_sentences = []
seen = set()
for sentence in all_sentences:
sentence = sentence.strip()
if sentence and sentence not in seen and len(sentence) > 10:
unique_sentences.append(sentence)
seen.add(sentence)
print(f"去重后剩余 {len(unique_sentences)} 个句子")
# 保存原始句子到JSON文件
sentences_data = {
"metadata": {
"total_sentences": len(unique_sentences),
"extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"source_files": len(json_files),
"max_files_limit": self.max_files
},
"sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
}
with open(self.sentences_json, 'w', encoding='utf-8') as f:
json.dump(sentences_data, f, ensure_ascii=False, indent=2)
print(f"句子提取完成!已保存到: {self.sentences_json}")
print(f"总计句子数: {len(unique_sentences)}")
return unique_sentences
def save_sentences(self, processed_sentences: List[Dict[str, Any]]):
"""保存处理后的句子到文件"""
# 确保输出目录存在
os.makedirs('output', exist_ok=True)
# 保存为JSON格式包含完整信息
json_output_file = self.output_file.replace('.txt', '.json')
with open(json_output_file, 'w', encoding='utf-8') as f:
json.dump(processed_sentences, f, ensure_ascii=False, indent=2)
# 保存为简单文本格式(仅修正后的句子)
with open(self.output_file, 'w', encoding='utf-8') as f:
for item in processed_sentences:
f.write(item['corrected_sentence'] + '\n')
# 生成重要性排序文件
importance_sorted = sorted(processed_sentences, key=lambda x: x['importance_score'], reverse=True)
importance_file = self.output_file.replace('.txt', '_sorted_by_importance.txt')
with open(importance_file, 'w', encoding='utf-8') as f:
for item in importance_sorted:
f.write(f"[{item['importance_score']:.1f}] {item['corrected_sentence']}\n")
print(f"保存了 {len(processed_sentences)} 个处理后的句子:")
print(f" - JSON格式: {json_output_file}")
print(f" - 文本格式: {self.output_file}")
print(f" - 重要性排序: {importance_file}")
# 统计信息
scores = [item['importance_score'] for item in processed_sentences]
avg_score = sum(scores) / len(scores) if scores else 0
print(f" - 平均重要性评分: {avg_score:.2f}")
print(f" - 最高评分: {max(scores):.1f}")
print(f" - 最低评分: {min(scores):.1f}")
async def run(self):
"""运行完整处理流程:先提取句子并写入JSON,再进行LLM后处理"""
print("Starting enhanced TREx to sentences conversion...")
# process_files 只负责提取句子并写入 sentences_json,
# LLM后处理与最终结果保存由 process_with_llm 完成
await self.process_files()
await self.process_with_llm()
print("Enhanced conversion completed!")
def find_latest_checkpoint(self) -> Union[tuple, None]:
"""查找最新的检查点文件"""
base_name = os.path.splitext(os.path.basename(self.output_file))[0]
pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
checkpoint_files = glob.glob(pattern)
if not checkpoint_files:
return None
# 按检查点编号排序,获取最新的
latest_file = None
latest_count = 0
for file in checkpoint_files:
try:
# 从文件名中提取数字
match = re.search(r'checkpoint_(\d+)\.json$', file)
if match:
count = int(match.group(1))
if count > latest_count:
latest_count = count
latest_file = file
except:
continue
if latest_file:
return latest_file, latest_count
else:
return None
def load_checkpoint(self, checkpoint_file: str) -> List[Dict[str, Any]]:
"""从检查点文件加载已处理的句子"""
try:
with open(checkpoint_file, 'r', encoding='utf-8') as f:
data = json.load(f)
if 'sentences' in data:
return data['sentences']
else:
# 旧格式的检查点文件
return data
except Exception as e:
print(f"加载检查点文件失败: {e}")
return []
def get_processed_sentences_from_checkpoints(self) -> Set[str]:
"""从检查点文件中获取已处理过的句子集合"""
if not self.output_file:
return set()
processed_sentences = set()
# 查找所有检查点文件
base_name = os.path.splitext(os.path.basename(self.output_file))[0]
pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
checkpoint_files = glob.glob(pattern)
if not checkpoint_files:
print("未找到检查点文件,将从头开始处理")
return set()
# 找到最新的检查点文件
latest_file = None
latest_count = 0
for file in checkpoint_files:
try:
match = re.search(r'checkpoint_(\d+)\.json$', file)
if match:
count = int(match.group(1))
if count > latest_count:
latest_count = count
latest_file = file
except:
continue
if latest_file:
print(f"找到最新检查点: {latest_file} (包含 {latest_count} 条记录)")
logger.info(f"找到最新检查点: {latest_file} (包含 {latest_count} 条记录)")
try:
with open(latest_file, 'r', encoding='utf-8') as f:
data = json.load(f)
sentences_data = data.get('sentences', [])
for item in sentences_data:
original_sentence = item.get('original_sentence', '')
if original_sentence:
processed_sentences.add(original_sentence)
print(f"从检查点加载了 {len(processed_sentences)} 个已处理的句子")
logger.info(f"从检查点加载了 {len(processed_sentences)} 个已处理的句子")
except Exception as e:
print(f"读取检查点文件失败: {e}")
return set()
return processed_sentences
async def process_with_llm(self):
"""步骤2从JSON文件读取句子并进行LLM处理"""
if not self.enable_llm_processing:
print("Error: LLM processing is disabled!")
return
if not self.output_file:
print("Error: output_file is required for LLM processing!")
return
print("=== 步骤2LLM处理 ===")
# 读取句子JSON文件
if not os.path.exists(self.sentences_json):
print(f"Error: Sentences file {self.sentences_json} not found!")
print("请先运行步骤1进行句子提取")
return
print(f"正在读取句子文件: {self.sentences_json}")
try:
with open(self.sentences_json, 'r', encoding='utf-8') as f:
data = json.load(f)
all_sentences = [item["sentence"] for item in data.get("sentences", [])]
print(f"从文件中读取了 {len(all_sentences)} 个句子")
except Exception as e:
print(f"读取句子文件失败: {e}")
return
# 获取已处理的句子
processed_sentences_set = self.get_processed_sentences_from_checkpoints()
# 过滤出未处理的句子
unprocessed_sentences = []
for sentence in all_sentences:
if sentence not in processed_sentences_set:
unprocessed_sentences.append(sentence)
print(f"需要处理的句子数: {len(unprocessed_sentences)} (跳过已处理: {len(processed_sentences_set)})")
logger.info(f"需要处理的句子数: {len(unprocessed_sentences)} (跳过已处理: {len(processed_sentences_set)})")
if not unprocessed_sentences:
print("所有句子都已处理完成!")
# 如果有检查点,直接从最新检查点生成最终文件
if processed_sentences_set:
latest_checkpoint = self.find_latest_checkpoint()
if latest_checkpoint:
checkpoint_file, _ = latest_checkpoint
processed_data = self.load_checkpoint(checkpoint_file)
self.save_sentences(processed_data)
print("已从检查点生成最终输出文件")
return
# 处理未处理的句子
print("开始LLM处理...")
# 检查ollama服务状态
logger.info("检查Ollama服务状态...")
if not self.check_ollama_status():
logger.error("Ollama服务状态异常无法继续处理")
print("错误Ollama服务状态异常请检查服务是否正常运行")
return
new_processed_sentences = await self.process_sentences_with_llm(unprocessed_sentences)
# 如果有之前的处理结果,合并它们
if processed_sentences_set:
latest_checkpoint = self.find_latest_checkpoint()
if latest_checkpoint:
checkpoint_file, _ = latest_checkpoint
previous_processed = self.load_checkpoint(checkpoint_file)
# 合并结果
all_processed_sentences = previous_processed + new_processed_sentences
print(f"合并了之前的 {len(previous_processed)} 条和新处理的 {len(new_processed_sentences)} 条记录")
else:
all_processed_sentences = new_processed_sentences
else:
all_processed_sentences = new_processed_sentences
# 保存最终结果
self.save_sentences(all_processed_sentences)
print("LLM处理完成")
# ==================== 新增:句子提取功能 ====================
def extract_sentences(self):
"""步骤1从TREx数据集提取句子并保存为JSON"""
if not self.input_dir:
print("Error: input_dir is required for sentence extraction!")
return
print("=== 步骤1句子提取 ===")
print("开始从TREx数据集提取句子...")
json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))
if not json_files:
print(f"No JSON files found in {self.input_dir}")
return
# 排序文件以确保一致的处理顺序
json_files.sort()
if self.max_files:
json_files = json_files[:self.max_files]
print(f"Found {len(json_files)} JSON files to process")
all_sentences = []
for i, file_path in enumerate(json_files):
print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")
documents = self.parse_large_json_file(file_path)
print(f" Parsed {len(documents)} documents")
for doc in documents:
sentences = self.extract_sentences_from_document(doc)
all_sentences.extend(sentences)
print(f" Generated {len(all_sentences)} total raw sentences so far")
print(f"总共提取了 {len(all_sentences)} 个原始句子")
# 去重
unique_sentences = []
seen = set()
for sentence in all_sentences:
sentence = sentence.strip()
if sentence and sentence not in seen and len(sentence) > 10:
unique_sentences.append(sentence)
seen.add(sentence)
print(f"去重后剩余 {len(unique_sentences)} 个句子")
# 保存原始句子到JSON文件
sentences_data = {
"metadata": {
"total_sentences": len(unique_sentences),
"extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"source_files": len(json_files),
"max_files_limit": self.max_files
},
"sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
}
with open(self.sentences_json, 'w', encoding='utf-8') as f:
json.dump(sentences_data, f, ensure_ascii=False, indent=2)
print(f"句子提取完成!已保存到: {self.sentences_json}")
print(f"总计句子数: {len(unique_sentences)}")
return unique_sentences
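# extracted_sentences.json 结构示例(示意,句子内容为虚构):
# {
#   "metadata": {"total_sentences": 2, "extraction_timestamp": "...", "source_files": 1, "max_files_limit": null},
#   "sentences": [
#     {"sentence": "Albert Einstein was born in Ulm.", "processed": false},
#     {"sentence": "Ulm is located in Germany.", "processed": false}
#   ]
# }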
def check_ollama_status(self) -> bool:
"""检查ollama服务是否正常运行"""
try:
# 检查ollama进程是否运行
result = subprocess.run(['pgrep', 'ollama'], capture_output=True, text=True)
if result.returncode != 0:
logger.error("Ollama进程未运行")
return False
# 检查ollama API是否响应
response = requests.get('http://localhost:11434/api/tags', timeout=5)
if response.status_code == 200:
logger.info("Ollama服务状态正常")
return True
else:
logger.error(f"Ollama API响应异常状态码: {response.status_code}")
return False
except requests.exceptions.RequestException as e:
logger.error(f"无法连接到Ollama API: {e}")
return False
except Exception as e:
logger.error(f"检查Ollama状态时出错: {e}")
return False
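# 手动排查示例(示意):若怀疑Ollama服务异常,可直接访问其API确认服务是否在默认端口响应:
#   curl http://localhost:11434/api/tags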
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with LLM processing')
# 选择运行模式
parser.add_argument('--step', choices=['extract', 'llm', 'all'], default='llm',
help='运行步骤: extract=仅提取句子, llm=仅LLM处理, all=完整流程')
# 文件路径参数
parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files')
parser.add_argument('--sentences_json', default='extracted_sentences.json', help='JSON file for extracted sentences (will be saved in output/)')
parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path (will be saved in output/)')
# 处理参数
parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)')
parser.add_argument('--no_llm', action='store_true', help='Disable LLM processing (basic mode)')
args = parser.parse_args()
# 根据步骤验证参数
if args.step in ['extract', 'all']:
if not os.path.exists(args.input_dir):
print(f"Error: Input directory {args.input_dir} does not exist!")
return
if args.step in ['llm', 'all']:
if args.no_llm:
print("Error: Cannot run LLM step with --no_llm flag!")
return
# 创建处理器
processor = EnhancedTRExProcessor(
input_dir=args.input_dir,
sentences_json=args.sentences_json,
output_file=args.output_file,
max_files=args.max_files,
enable_llm_processing=not args.no_llm
)
# 根据选择的步骤运行
if args.step == 'extract':
print("=== 运行模式:仅句子提取 ===")
processor.extract_sentences()
elif args.step == 'llm':
print("=== 运行模式仅LLM处理 ===")
asyncio.run(processor.process_with_llm())
elif args.step == 'all':
print("=== 运行模式:完整流程 ===")
# 步骤1:提取句子
print("\n--- 开始步骤1:句子提取 ---")
sentences = processor.extract_sentences()
if not sentences:
print("句子提取失败,退出")
return
if args.no_llm:
print("LLM处理已禁用流程结束")
return
# 步骤2:LLM处理
print("\n--- 开始步骤2:LLM处理 ---")
asyncio.run(processor.process_with_llm())
if __name__ == "__main__":
main()