Minimind/preprocessing/trex_to_sentences_simple.py

#!/usr/bin/env python3
"""
TREx数据集增强预处理脚本
使用agno框架和ollama qwen3:4b进行句子后处理和重要性评分
支持两个独立步骤:
1. 句子提取从TREx数据集提取句子并保存为JSON
2. LLM处理读取JSON文件进行LLM后处理和重要性评分
"""
import json
import os
import glob
from typing import List, Dict, Any, Union, Set
import re
import asyncio
import time
import logging
from datetime import datetime
import subprocess
import requests
from pydantic import BaseModel, Field
from agno.agent import Agent
from agno.models.ollama import Ollama
# 设置日志系统
def setup_logging():
"""设置日志系统"""
# 确保logs目录存在
os.makedirs('logs', exist_ok=True)
# 创建日志文件名(包含时间戳)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = f'logs/trex_processor_{timestamp}.log'
# 配置日志格式
log_format = '%(asctime)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s'
# 配置root logger
logging.basicConfig(
level=logging.INFO,
format=log_format,
handlers=[
logging.FileHandler(log_file, encoding='utf-8'),
logging.StreamHandler() # 同时输出到控制台
]
)
# 获取logger
logger = logging.getLogger(__name__)
logger.info(f"日志系统初始化完成,日志文件: {log_file}")
return logger
# 全局日志对象
logger = setup_logging()
class ProcessedSentence(BaseModel):
"""处理后的句子结构"""
corrected_sentence: str = Field(
...,
description="修正后的句子,只修正语法错误、乱码和不通顺的地方,不进行额外润色"
)
importance_score: float = Field(
...,
description="重要性评分范围0.0-10.0以0.1递进。评判这个知识在现实世界中的常用程度和重要度",
ge=0.0,
le=10.0
)
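# ProcessedSentence 即要求LLM返回的结构化结果,示例(仅作说明,并非真实模型输出):
#   ProcessedSentence(corrected_sentence="Beijing is the capital of China.", importance_score=7.5)
#   对应的JSON形式: {"corrected_sentence": "...", "importance_score": 7.5}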
class EnhancedTRExProcessor:
def __init__(self, input_dir: str = None, output_file: str = None, max_files: int = None,
sentences_json: str = None, enable_llm_processing: bool = True):
self.input_dir = input_dir
# 确保output目录存在
os.makedirs('output', exist_ok=True)
# 确保所有输出文件都在output目录中
if output_file:
if not output_file.startswith('output/'):
self.output_file = os.path.join('output', output_file)
else:
self.output_file = output_file
else:
self.output_file = None
if sentences_json:
if not sentences_json.startswith('output/'):
self.sentences_json = os.path.join('output', sentences_json)
else:
self.sentences_json = sentences_json
else:
self.sentences_json = "output/extracted_sentences.json"
self.max_files = max_files
self.enable_llm_processing = enable_llm_processing
# LLM处理配置
self.llm_timeout = 60 # 每个LLM请求的超时时间(秒)
self.max_concurrent = 8 # 最大并发请求数
self.max_retries = 2 # 重试次数,避免过长等待
self.heartbeat_interval = 30 # 心跳检测间隔(秒)
# 统计信息
self.total_requests = 0
self.successful_requests = 0
self.failed_requests = 0
self.timeout_requests = 0
self.last_successful_time = time.time()
self.last_activity_time = time.time() # 新增:最后活动时间
# 初始化agno agent(仅在需要LLM处理时)
if self.enable_llm_processing:
self.setup_agent()
logger.info(f"处理器初始化完成 - 并发数: {self.max_concurrent}, 超时时间: {self.llm_timeout}")
# 扩展的Wikidata属性映射
self.property_mappings = {
# 基本关系
"http://www.wikidata.org/prop/direct/P31": "is a",
"http://www.wikidata.org/prop/direct/P279": "is a type of",
# 人物相关
"http://www.wikidata.org/prop/direct/P106": "works as",
"http://www.wikidata.org/prop/direct/P27": "is a citizen of",
"http://www.wikidata.org/prop/direct/P19": "was born in",
"http://www.wikidata.org/prop/direct/P20": "died in",
"http://www.wikidata.org/prop/direct/P569": "was born on",
"http://www.wikidata.org/prop/direct/P570": "died on",
"http://www.wikidata.org/prop/direct/P22": "has father",
"http://www.wikidata.org/prop/direct/P25": "has mother",
"http://www.wikidata.org/prop/direct/P26": "is married to",
# 组织相关
"http://www.wikidata.org/prop/direct/P102": "is a member of",
"http://www.wikidata.org/prop/direct/P108": "works for",
"http://www.wikidata.org/prop/direct/P159": "has headquarters in",
"http://www.wikidata.org/prop/direct/P112": "was founded by",
"http://www.wikidata.org/prop/direct/P571": "was founded in",
"http://www.wikidata.org/prop/direct/P169": "has CEO",
# 地理相关
"http://www.wikidata.org/prop/direct/P17": "is located in",
"http://www.wikidata.org/prop/direct/P131": "is located in",
"http://www.wikidata.org/prop/direct/P36": "has capital",
"http://www.wikidata.org/prop/direct/P47": "borders",
# 其他关系
"http://www.wikidata.org/prop/direct/P1142": "has ideology",
"http://www.wikidata.org/prop/direct/P361": "is part of",
"http://www.wikidata.org/prop/direct/P737": "was influenced by",
"http://www.wikidata.org/prop/direct/P127": "is owned by",
"http://www.wikidata.org/prop/direct/P155": "follows",
"http://www.wikidata.org/prop/direct/P156": "is followed by",
"http://www.wikidata.org/prop/direct/P138": "is named after"
}
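# 映射用法示意:triple_to_sentence() 会把谓词URI替换为上面的英文短语,
# 例如 (France, P36, Paris) 会生成 "France has capital Paris.";未收录的谓词回退为 "is related to"。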
def setup_agent(self):
"""设置agno agent"""
try:
self.agent = Agent(
model=Ollama(
id="gemma3:latest",
# 使用options设置temperature和其他参数
options={
"temperature": 0.2,
"top_p": 0.8,
"top_k": 20,
"num_ctx": 4096,
}
),
response_model=ProcessedSentence,
instructions=[
"You are a professional text processing assistant responsible for correcting errors in sentences and evaluating the importance of knowledge.",
"",
"### Sentence Correction Rules:",
"1. Remove Wikipedia-specific markers: such as (disambiguation), (film), (band), etc. in parentheses",
"2. Ensure grammatical completeness: complete subject+predicate+object structure, avoid dangling 'and is', 'or', etc.",
"3. Fix obvious grammatical errors: tense consistency, singular/plural consistency, correct preposition usage",
"4. Clean up garbled text and special characters: such as â, €, ™ and other encoding issues",
"5. Ensure semantic fluency: if the original sentence cannot be fixed, reorganize the language to make it coherent",
"6. Do not add information not present in the original text, only correct errors",
"",
"### Correction Examples:",
"- Error: 'Argument (disambiguation) is related to philosophy, logic, and is an.'",
"- Corrected: 'Argument is related to philosophy and logic.'",
"",
"- Error: 'Beijing is a capital city and are.'",
"- Corrected: 'Beijing is a capital city.'",
"",
"Importance scoring criteria (0.0-10.0, in increments of 0.1):",
"",
"0.0 points - Completely incorrect or meaningless information",
"Examples: 'Apple is a metal', 'The sun rises from the west', '1+1=3'",
"",
"0.5 points - Almost worthless information",
"Examples: 'Color of a fictional character's socks', 'Third line of dialogue from a game NPC', 'What someone had for breakfast yesterday'",
"",
"1.0 points - Extremely rare, non-practical knowledge",
"Examples: 'Pet name of a minor novel character', 'Content of the 15th line in movie end credits', 'Nickname of website user ID 123456'",
"",
"1.5 points - Very niche detailed information",
"Examples: 'Outfit of a passerby at minute 37 in a movie', 'Duration of background music in a game's hidden level', 'Content of the 3rd dialogue box on page 200 of a manga'",
"",
"2.0 points - Details in niche professional fields",
"Examples: 'Color change of rare minerals at specific temperatures', 'Length of an insect's third antenna', 'Molecular formula of chemical reaction byproducts'",
"",
"2.5 points - Technical details only professionals care about",
"Examples: 'Release date of specific software library version', 'Time complexity coefficient of an algorithm', 'Thermal expansion coefficient of a material'",
"",
"3.0 points - Professional knowledge in specific fields",
"Examples: 'Programming language syntax features', 'Gene sequence of a virus', 'Official system of ancient dynasties'",
"",
"3.5 points - Professional information with some value",
"Examples: 'Specific system of historical dynasty', 'Mechanism of action of a drug', 'Development time of a technical standard'",
"",
"4.0 points - Meaningful knowledge known by few",
"Examples: 'Unique cultural traditions of a country', 'Important discoveries by a scientist', 'Detailed process of historical events'",
"",
"4.5 points - Knowledge of interest to some groups",
"Examples: 'Author's creative background', 'Characteristics of an art movement', 'Detailed rules of a sport'",
"",
"5.0 points - General knowledge of moderate importance",
"Examples: 'Famous attractions in cities', 'Development history of a company', 'Living habits of animals'",
"",
"5.5 points - Fairly useful common sense",
"Examples: 'Plant growth environment', 'Healthy eating common sense', 'Basic first aid knowledge'",
"",
"6.0 points - Knowledge most educated people should know",
"Examples: 'Shakespeare's representative works', 'Basic geometric theorems', 'Major world currencies'",
"",
"6.5 points - Important cultural or scientific common sense",
"Examples: 'Basic structure of DNA', 'Newton's three laws', 'Major world religions'",
"",
"7.0 points - Important foundational knowledge",
"Examples: 'Time period of World War II', 'Functions of major human organs', 'Basic mathematical operation rules'",
"",
"7.5 points - Very important common sense",
"Examples: 'Light speed is the fastest in the universe', 'Earth is round', 'Basic principles of blood circulation'",
"",
"8.0 points - Core knowledge in basic education",
"Examples: 'Earth orbits the sun', 'Principle of seasonal formation', 'Basic grammar rules'",
"",
"8.5 points - Important knowledge everyone should master",
"Examples: 'Chemical formula of water H2O', 'Basic safety common sense', 'Simple mathematical calculations'",
"",
"9.0 points - Extremely important basic concepts",
"Examples: 'Humans need oxygen to survive', 'Fire is hot', 'Basic directional concepts'",
"",
"9.5 points - Core knowledge everyone must know",
"Examples: 'A day has 24 hours', 'A year has 12 months', 'Basic number concepts'",
"",
"10.0 points - Most basic and important common sense",
"Examples: 'Humans need food and water to survive', 'The sky is blue', 'Stones are heavier than feathers'",
"",
"When scoring, please consider:",
"1. Popularity of knowledge - How many people know this knowledge",
"2. Practical value - How useful this knowledge is in daily life",
"3. Educational importance - The position of this knowledge in the education system",
"4. Cultural significance - The importance of this knowledge for understanding the world",
"",
"Please output structured results directly without showing the thinking process."
],
markdown=False
)
logger.info("LLM处理器初始化成功")
except Exception as e:
logger.error(f"LLM处理器初始化失败: {e}")
print(f"LLM处理器初始化失败: {e}")
print("将使用基础模式不使用LLM后处理")
self.enable_llm_processing = False
async def process_sentence_with_llm(self, sentence: str) -> ProcessedSentence:
"""使用LLM处理单个句子保留用于单独调用"""
for attempt in range(self.max_retries):
try:
prompt = f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"
# 使用asyncio.wait_for添加超时机制
response = await asyncio.wait_for(
self.agent.arun(prompt),
timeout=self.llm_timeout
)
# 根据agno文档response应该直接是ProcessedSentence类型
if isinstance(response, ProcessedSentence):
return response
else:
message = response.messages[-1].content
message = message.replace("```json", "").replace("```", "")
message = json.loads(message)
return ProcessedSentence(
corrected_sentence=message['corrected_sentence'],
importance_score=message['importance_score']
)
except asyncio.TimeoutError:
logger.warning(f"LLM请求超时 (尝试 {attempt + 1}/{self.max_retries}): {sentence[:50]}...")
if attempt == self.max_retries - 1:
logger.error(f"LLM请求最终超时使用默认处理: {sentence[:50]}...")
break
# 等待一段时间后重试
await asyncio.sleep(2 ** attempt) # 指数退避
except Exception as e:
logger.error(f"LLM处理句子时出错 (尝试 {attempt + 1}/{self.max_retries}): {e}")
if attempt == self.max_retries - 1:
break
await asyncio.sleep(1)
# 所有重试都失败,返回原句子和中等评分
logger.warning(f"使用默认处理: {sentence[:50]}...")
return ProcessedSentence(
corrected_sentence=sentence,
importance_score=5.0
)
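# 单句调用示意(需在异步上下文中执行,示例句子为虚构输入):
#   result = await processor.process_sentence_with_llm("Argument (disambiguation) is related to philosophy, logic, and is an.")
#   print(result.corrected_sentence, result.importance_score)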
def clean_text(self, text: str) -> str:
"""清理文本,处理特殊字符"""
if not text:
return ""
# 处理常见的Unicode字符(用转义序列书写,避免源码中的特殊字符再次丢失)
text = text.replace("\u2013", "-")  # en dash
text = text.replace("\u2014", "-")  # em dash
text = text.replace("\u2019", "'")  # right single quotation mark
text = text.replace("\u2018", "'")  # left single quotation mark
text = text.replace("\u201c", '"')  # left double quotation mark
text = text.replace("\u201d", '"')  # right double quotation mark
# 处理可能的转义序列
try:
text = text.encode('utf-8').decode('utf-8')
except:
pass
# 清理多余的空格
text = re.sub(r'\s+', ' ', text).strip()
# 移除可能的引号
text = text.strip('"\'')
return text
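# clean_text 行为示意(输入为虚构示例,\u2013为en dash):
#   clean_text('"Paris \u2013  the capital of France"') -> 'Paris - the capital of France'
#   即:替换Unicode标点、合并多余空白并去掉首尾引号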
def parse_large_json_file(self, file_path: str) -> List[Dict]:
"""解析大型JSON文件处理可能的格式问题"""
documents = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read().strip()
# 尝试不同的解析方法
if content.startswith('[') and content.endswith(']'):
# 标准JSON数组
documents = json.loads(content)
else:
# 可能是连续的JSON对象
# 尝试在}{"之间分割
if '}{"' in content:
json_strings = content.split('}{')
json_strings[0] += '}' # 第一个对象
json_strings[-1] = '{' + json_strings[-1] # 最后一个对象
for i in range(1, len(json_strings) - 1):
json_strings[i] = '{' + json_strings[i] + '}'
for json_str in json_strings:
try:
doc = json.loads(json_str)
documents.append(doc)
except json.JSONDecodeError:
continue
else:
# 尝试作为单个JSON对象
try:
documents = [json.loads(content)]
except json.JSONDecodeError:
pass
except Exception as e:
print(f"Error parsing {file_path}: {e}")
return documents
def extract_sentences_from_document(self, doc: Dict[str, Any]) -> List[str]:
"""从文档中提取句子"""
sentences = []
title = self.clean_text(doc.get('title', ''))
text = self.clean_text(doc.get('text', ''))
entities = doc.get('entities', [])
triples = doc.get('triples', [])
# 处理显式三元组
for triple in triples:
sentence = self.triple_to_sentence(triple)
if sentence:
sentences.append(sentence)
# 从实体和文本中生成基本句子(如果三元组句子不够)
if title and text and len(sentences) < 5:
# 基于标题和实体生成句子
entity_names = []
for entity in entities[:15]:
entity_name = self.clean_text(entity.get('surfaceform', ''))
if entity_name and len(entity_name) > 2:
entity_names.append(entity_name)
# 生成简单的描述句子
if entity_names:
important_entities = []
title_lower = title.lower()
for entity in entity_names:
if (entity.lower() != title_lower and
entity not in important_entities and
not any(t.lower() in entity.lower() for t in title_lower.split()[:2])):
important_entities.append(entity)
if len(important_entities) >= 6:
break
if important_entities and len(sentences) < 3:
entities_str = ', '.join(important_entities[:3])
sentences.append(f"{title} is related to {entities_str}.")
return sentences
def triple_to_sentence(self, triple: Dict[str, Any]) -> str:
"""将三元组转换为自然语言句子"""
try:
subject = triple.get('subject', {})
predicate = triple.get('predicate', {})
obj = triple.get('object', {})
subject_name = self.clean_text(subject.get('surfaceform', ''))
object_name = self.clean_text(obj.get('surfaceform', ''))
predicate_uri = predicate.get('uri', '')
# 检查是否有有效的主语和宾语
if not subject_name or not object_name:
return ""
# 检查主语和宾语是否过短或无意义
if len(subject_name) <= 2 or len(object_name) <= 2:
return ""
# 获取关系文本
relation_text = self.property_mappings.get(predicate_uri, "is related to")
# 避免重复的主语宾语
if subject_name.lower() == object_name.lower():
return ""
return f"{subject_name} {relation_text} {object_name}."
except Exception as e:
print(f"Error converting triple to sentence: {e}")
return ""
async def process_sentence_with_llm_concurrent(self, semaphore: asyncio.Semaphore, sentence: str, index: int, total_sentences: int, start_time: float) -> Dict[str, Any]:
"""使用信号量控制并发的LLM处理"""
async with semaphore:
self.total_requests += 1
self.last_activity_time = time.time() # 更新活动时间
success = False
for attempt in range(self.max_retries):
try:
prompt = f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"
# 使用asyncio.wait_for添加超时机制
response = await asyncio.wait_for(
self.agent.arun(prompt),
timeout=self.llm_timeout
)
# 根据agno文档response应该直接是ProcessedSentence类型
if isinstance(response, ProcessedSentence):
result = {
"index": index,
"original_sentence": sentence,
"corrected_sentence": response.corrected_sentence,
"importance_score": response.importance_score
}
else:
message = response.messages[-1].content
message = message.replace("```json", "").replace("```", "")
message = json.loads(message)
result = {
"index": index,
"original_sentence": sentence,
"corrected_sentence": message['corrected_sentence'],
"importance_score": message['importance_score']
}
# 成功处理
self.successful_requests += 1
self.last_successful_time = time.time()
self.last_activity_time = time.time() # 更新活动时间
success = True
# 打印详细进度信息 - 降低频率到每50个
if index % 50 == 0:
current_time = time.time()
elapsed_time = current_time - start_time
avg_time_per_sentence = elapsed_time / (index + 1) if index > 0 else elapsed_time
remaining_sentences = total_sentences - (index + 1)
estimated_remaining_time = avg_time_per_sentence * remaining_sentences
success_rate = (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0
# 格式化时间显示
def format_time(seconds):
if seconds < 60:
return f"{seconds:.1f}"
elif seconds < 3600:
minutes = seconds / 60
return f"{minutes:.1f}分钟"
else:
hours = seconds / 3600
return f"{hours:.1f}小时"
logger.info(f"已完成第 {index + 1} 个句子的处理")
logger.info(f" - 剩余句子数: {remaining_sentences}")
logger.info(f" - 平均处理时间: {avg_time_per_sentence:.2f}秒/句")
logger.info(f" - 预估剩余时间: {format_time(estimated_remaining_time)}")
logger.info(f" - 已用时间: {format_time(elapsed_time)}")
logger.info(f" - 成功率: {success_rate:.1f}% ({self.successful_requests}/{self.total_requests})")
print(f"已完成第 {index + 1} 个句子的处理")
print(f" - 剩余句子数: {remaining_sentences}")
print(f" - 平均处理时间: {avg_time_per_sentence:.2f}秒/句")
print(f" - 预估剩余时间: {format_time(estimated_remaining_time)}")
print(f" - 已用时间: {format_time(elapsed_time)}")
print(f" - 成功率: {success_rate:.1f}% ({self.successful_requests}/{self.total_requests})")
return result
except asyncio.TimeoutError:
self.timeout_requests += 1
self.last_activity_time = time.time() # 更新活动时间
logger.warning(f"{index} 个句子处理超时 (尝试 {attempt + 1}/{self.max_retries}): {sentence[:50]}...")
if attempt == self.max_retries - 1:
logger.error(f"{index} 个句子最终超时,使用默认处理")
break
# 指数退避
await asyncio.sleep(2 ** attempt)
except Exception as e:
self.last_activity_time = time.time() # 更新活动时间
logger.error(f"处理第 {index} 个句子时出错 (尝试 {attempt + 1}/{self.max_retries}): {e}")
if attempt == self.max_retries - 1:
break
await asyncio.sleep(1)
# 所有重试都失败,使用默认处理
if not success:
self.failed_requests += 1
logger.warning(f"{index} 个句子使用默认处理: {sentence[:50]}...")
return {
"index": index,
"original_sentence": sentence,
"corrected_sentence": sentence,
"importance_score": 5.0
}
async def heartbeat_monitor(self, total_sentences: int):
"""心跳监控,检测是否有长时间无响应"""
consecutive_warnings = 0
while True:
await asyncio.sleep(self.heartbeat_interval)
current_time = time.time()
time_since_last_success = current_time - self.last_successful_time
time_since_last_activity = current_time - self.last_activity_time
# 检查最后成功时间
if time_since_last_success > self.heartbeat_interval:
consecutive_warnings += 1
logger.warning(f"⚠️ 心跳检测 #{consecutive_warnings}:已有 {time_since_last_success:.1f} 秒没有成功的LLM响应")
print(f"⚠️ 心跳检测 #{consecutive_warnings}:已有 {time_since_last_success:.1f} 秒没有成功的LLM响应")
# 打印当前统计信息
if self.total_requests > 0:
success_rate = self.successful_requests / self.total_requests * 100
logger.warning(f" 当前统计:总请求 {self.total_requests},成功 {self.successful_requests} ({success_rate:.1f}%),超时 {self.timeout_requests}")
print(f" 当前统计:总请求 {self.total_requests},成功 {self.successful_requests} ({success_rate:.1f}%),超时 {self.timeout_requests}")
if time_since_last_success > self.heartbeat_interval * 3:
logger.error(f"❌ 严重警告LLM可能已卡死超过 {time_since_last_success:.1f} 秒无成功响应!")
print(f"❌ 严重警告LLM可能已卡死超过 {time_since_last_success:.1f} 秒无成功响应!")
print(f" 建议检查ollama服务状态或考虑重启程序")
# 检查ollama服务状态
if not self.check_ollama_status():
logger.critical("💀 Ollama服务异常这可能是卡死的原因")
print("💀 Ollama服务异常这可能是卡死的原因")
if consecutive_warnings >= 5:
logger.critical(f"💀 致命错误:连续 {consecutive_warnings} 次心跳警告,可能需要人工干预")
print(f"💀 致命错误:连续 {consecutive_warnings} 次心跳警告,可能需要人工干预")
else:
if consecutive_warnings > 0:
logger.info(f"✅ 心跳恢复正常:最后成功时间 {time_since_last_success:.1f} 秒前")
print(f"✅ 心跳恢复正常:最后成功时间 {time_since_last_success:.1f} 秒前")
consecutive_warnings = 0
logger.debug(f"💓 心跳正常:最后成功时间 {time_since_last_success:.1f} 秒前")
async def process_sentences_with_llm(self, sentences: List[str]) -> List[Dict[str, Any]]:
"""批量并发处理句子每2000条保存一次检查点"""
logger.info(f"开始使用LLM并发处理 {len(sentences)} 个句子(最大并发数:{self.max_concurrent}...")
print(f"开始使用LLM并发处理 {len(sentences)} 个句子(最大并发数:{self.max_concurrent}...")
# 记录开始时间
start_time = time.time()
total_sentences = len(sentences)
# 分批处理:每批1000个句子,每批结束后保存检查点
batch_size = 1000
all_processed_sentences = []
# 启动心跳监控
heartbeat_task = asyncio.create_task(self.heartbeat_monitor(total_sentences))
try:
for batch_start in range(0, total_sentences, batch_size):
batch_end = min(batch_start + batch_size, total_sentences)
batch_sentences = sentences[batch_start:batch_end]
logger.info(f"=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===")
print(f"\n=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===")
# 创建信号量,限制并发数为 self.max_concurrent(当前为8)
semaphore = asyncio.Semaphore(self.max_concurrent)
# 重置批次统计
batch_start_time = time.time()
self.total_requests = 0
self.successful_requests = 0
self.failed_requests = 0
self.timeout_requests = 0
# 创建当前批次的任务
tasks = []
for i, sentence in enumerate(batch_sentences):
global_index = batch_start + i
task = self.process_sentence_with_llm_concurrent(semaphore, sentence, global_index, total_sentences, start_time)
tasks.append(task)
# 并发执行当前批次的任务
logger.info(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...")
print(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...")
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理当前批次的结果,过滤异常
batch_processed_sentences = []
batch_error_count = 0
for result in batch_results:
if isinstance(result, Exception):
logger.error(f"任务执行异常: {result}")
print(f"任务执行异常: {result}")
batch_error_count += 1
elif isinstance(result, dict):
batch_processed_sentences.append(result)
else:
batch_error_count += 1
# 按原始顺序排序(因为并发执行可能改变顺序)
batch_processed_sentences.sort(key=lambda x: x['index'])
# 移除index字段
for item in batch_processed_sentences:
del item['index']
# 添加到总结果中
all_processed_sentences.extend(batch_processed_sentences)
# 保存检查点
checkpoint_filename = self.save_checkpoint(all_processed_sentences, batch_end)
# 打印当前批次统计信息
elapsed_time = time.time() - start_time
batch_time = time.time() - batch_start_time
completed_sentences = len(all_processed_sentences)
logger.info(f"{batch_start//batch_size + 1} 批处理完成!")
logger.info(f" - 当前批次:成功 {len(batch_processed_sentences)},失败 {batch_error_count}")
logger.info(f" - 批次用时:{batch_time/60:.1f}分钟")
logger.info(f" - LLM统计成功 {self.successful_requests},失败 {self.failed_requests},超时 {self.timeout_requests}")
logger.info(f" - 总体进度:{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
logger.info(f" - 已用时间:{elapsed_time/60:.1f}分钟")
logger.info(f" - 平均速度:{completed_sentences/elapsed_time:.2f}句/秒")
logger.info(f" - 检查点已保存:{checkpoint_filename}")
print(f"{batch_start//batch_size + 1} 批处理完成!")
print(f" - 当前批次:成功 {len(batch_processed_sentences)},失败 {batch_error_count}")
print(f" - 批次用时:{batch_time/60:.1f}分钟")
print(f" - LLM统计成功 {self.successful_requests},失败 {self.failed_requests},超时 {self.timeout_requests}")
print(f" - 总体进度:{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
print(f" - 已用时间:{elapsed_time/60:.1f}分钟")
print(f" - 平均速度:{completed_sentences/elapsed_time:.2f}句/秒")
print(f" - 检查点已保存:{checkpoint_filename}")
if batch_end < total_sentences:
remaining_sentences = total_sentences - completed_sentences
avg_time_per_sentence = elapsed_time / completed_sentences
estimated_remaining_time = avg_time_per_sentence * remaining_sentences
logger.info(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟")
print(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟")
# 在批次之间稍作休息,避免过度压力
if batch_end < total_sentences:
logger.info("批次间休息5秒...")
await asyncio.sleep(5)
finally:
# 取消心跳监控
heartbeat_task.cancel()
try:
await heartbeat_task
except asyncio.CancelledError:
pass
# 打印最终统计信息
total_time = time.time() - start_time
logger.info(f"=== 全部处理完成!===")
logger.info(f" - 总成功:{len(all_processed_sentences)}")
logger.info(f" - 总用时:{total_time/60:.1f}分钟")
logger.info(f" - 平均处理速度:{len(all_processed_sentences)/total_time:.2f}句/秒")
print(f"\n=== 全部处理完成!===")
print(f" - 总成功:{len(all_processed_sentences)}")
print(f" - 总用时:{total_time/60:.1f}分钟")
print(f" - 平均处理速度:{len(all_processed_sentences)/total_time:.2f}句/秒")
return all_processed_sentences
def save_checkpoint(self, processed_sentences: List[Dict[str, Any]], current_count: int) -> str:
"""保存检查点文件"""
# 生成检查点文件名,确保在output目录中
base_name = os.path.splitext(os.path.basename(self.output_file))[0]
checkpoint_filename = os.path.join('output', f"{base_name}_checkpoint_{current_count}.json")
# 保存检查点
with open(checkpoint_filename, 'w', encoding='utf-8') as f:
json.dump({
"metadata": {
"total_processed": len(processed_sentences),
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"checkpoint_number": current_count
},
"sentences": processed_sentences
}, f, ensure_ascii=False, indent=2)
return checkpoint_filename
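# 检查点文件示意:若 output_file 为 output/trex_sentences_enhanced.txt,
# 第一批(1000条)处理完后会生成 output/trex_sentences_enhanced_checkpoint_1000.json,
# 内容为 {"metadata": {...}, "sentences": [{"original_sentence": ..., "corrected_sentence": ..., "importance_score": ...}, ...]}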
async def process_files(self) -> List[Dict[str, Any]]:
"""处理所有文件"""
json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))
if not json_files:
print(f"No JSON files found in {self.input_dir}")
return []
# 排序文件以确保一致的处理顺序
json_files.sort()
if self.max_files:
json_files = json_files[:self.max_files]
print(f"Found {len(json_files)} JSON files to process")
all_sentences = []
for i, file_path in enumerate(json_files):
print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")
documents = self.parse_large_json_file(file_path)
print(f" Parsed {len(documents)} documents")
for doc in documents:
sentences = self.extract_sentences_from_document(doc)
all_sentences.extend(sentences)
print(f" Generated {len(all_sentences)} total raw sentences so far")
print(f"总共提取了 {len(all_sentences)} 个原始句子")
# 去重
unique_sentences = []
seen = set()
for sentence in all_sentences:
sentence = sentence.strip()
if sentence and sentence not in seen and len(sentence) > 10:
unique_sentences.append(sentence)
seen.add(sentence)
print(f"去重后剩余 {len(unique_sentences)} 个句子")
# 保存原始句子到JSON文件
sentences_data = {
"metadata": {
"total_sentences": len(unique_sentences),
"extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"source_files": len(json_files),
"max_files_limit": self.max_files
},
"sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
}
with open(self.sentences_json, 'w', encoding='utf-8') as f:
json.dump(sentences_data, f, ensure_ascii=False, indent=2)
print(f"句子提取完成!已保存到: {self.sentences_json}")
print(f"总计句子数: {len(unique_sentences)}")
return unique_sentences
def save_sentences(self, processed_sentences: List[Dict[str, Any]]):
"""保存处理后的句子到文件"""
# 确保输出目录存在
os.makedirs('output', exist_ok=True)
# 保存为JSON格式包含完整信息
json_output_file = self.output_file.replace('.txt', '.json')
with open(json_output_file, 'w', encoding='utf-8') as f:
json.dump(processed_sentences, f, ensure_ascii=False, indent=2)
# 保存为简单文本格式(仅修正后的句子)
with open(self.output_file, 'w', encoding='utf-8') as f:
for item in processed_sentences:
f.write(item['corrected_sentence'] + '\n')
# 生成重要性排序文件
importance_sorted = sorted(processed_sentences, key=lambda x: x['importance_score'], reverse=True)
importance_file = self.output_file.replace('.txt', '_sorted_by_importance.txt')
with open(importance_file, 'w', encoding='utf-8') as f:
for item in importance_sorted:
f.write(f"[{item['importance_score']:.1f}] {item['corrected_sentence']}\n")
print(f"保存了 {len(processed_sentences)} 个处理后的句子:")
print(f" - JSON格式: {json_output_file}")
print(f" - 文本格式: {self.output_file}")
print(f" - 重要性排序: {importance_file}")
# 统计信息
scores = [item['importance_score'] for item in processed_sentences]
avg_score = sum(scores) / len(scores) if scores else 0
print(f" - 平均重要性评分: {avg_score:.2f}")
print(f" - 最高评分: {max(scores):.1f}")
print(f" - 最低评分: {min(scores):.1f}")
async def run(self):
"""运行处理流程(基础模式:提取句子后直接保存,不经过LLM)"""
print("Starting enhanced TREx to sentences conversion...")
raw_sentences = await self.process_files()
# process_files 返回原始句子列表,这里包装成 save_sentences 期望的字典结构
processed_sentences = [
{"original_sentence": s, "corrected_sentence": s, "importance_score": 5.0}
for s in raw_sentences
]
self.save_sentences(processed_sentences)
print("Enhanced conversion completed!")
def find_latest_checkpoint(self) -> Union[tuple, None]:
"""查找最新的检查点文件"""
base_name = os.path.splitext(os.path.basename(self.output_file))[0]
pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
checkpoint_files = glob.glob(pattern)
if not checkpoint_files:
return None
# 按检查点编号排序,获取最新的
latest_file = None
latest_count = 0
for file in checkpoint_files:
try:
# 从文件名中提取数字
match = re.search(r'checkpoint_(\d+)\.json$', file)
if match:
count = int(match.group(1))
if count > latest_count:
latest_count = count
latest_file = file
except:
continue
if latest_file:
return latest_file, latest_count
else:
return None
def load_checkpoint(self, checkpoint_file: str) -> List[Dict[str, Any]]:
"""从检查点文件加载已处理的句子"""
try:
with open(checkpoint_file, 'r', encoding='utf-8') as f:
data = json.load(f)
if 'sentences' in data:
return data['sentences']
else:
# 旧格式的检查点文件
return data
except Exception as e:
print(f"加载检查点文件失败: {e}")
return []
def get_processed_sentences_from_checkpoints(self) -> Set[str]:
"""从检查点文件中获取已处理过的句子集合"""
if not self.output_file:
return set()
processed_sentences = set()
# 查找所有检查点文件
base_name = os.path.splitext(os.path.basename(self.output_file))[0]
pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
checkpoint_files = glob.glob(pattern)
if not checkpoint_files:
print("未找到检查点文件,将从头开始处理")
return set()
# 找到最新的检查点文件
latest_file = None
latest_count = 0
for file in checkpoint_files:
try:
match = re.search(r'checkpoint_(\d+)\.json$', file)
if match:
count = int(match.group(1))
if count > latest_count:
latest_count = count
latest_file = file
except:
continue
if latest_file:
print(f"找到最新检查点: {latest_file} (包含 {latest_count} 条记录)")
logger.info(f"找到最新检查点: {latest_file} (包含 {latest_count} 条记录)")
try:
with open(latest_file, 'r', encoding='utf-8') as f:
data = json.load(f)
sentences_data = data.get('sentences', [])
for item in sentences_data:
original_sentence = item.get('original_sentence', '')
if original_sentence:
processed_sentences.add(original_sentence)
print(f"从检查点加载了 {len(processed_sentences)} 个已处理的句子")
logger.info(f"从检查点加载了 {len(processed_sentences)} 个已处理的句子")
except Exception as e:
print(f"读取检查点文件失败: {e}")
return set()
return processed_sentences
async def process_with_llm(self):
"""步骤2从JSON文件读取句子并进行LLM处理"""
if not self.enable_llm_processing:
print("Error: LLM processing is disabled!")
return
if not self.output_file:
print("Error: output_file is required for LLM processing!")
return
print("=== 步骤2LLM处理 ===")
# 读取句子JSON文件
if not os.path.exists(self.sentences_json):
print(f"Error: Sentences file {self.sentences_json} not found!")
print("请先运行步骤1进行句子提取")
return
print(f"正在读取句子文件: {self.sentences_json}")
try:
with open(self.sentences_json, 'r', encoding='utf-8') as f:
data = json.load(f)
all_sentences = [item["sentence"] for item in data.get("sentences", [])]
print(f"从文件中读取了 {len(all_sentences)} 个句子")
except Exception as e:
print(f"读取句子文件失败: {e}")
return
# 获取已处理的句子
processed_sentences_set = self.get_processed_sentences_from_checkpoints()
# 过滤出未处理的句子
unprocessed_sentences = []
for sentence in all_sentences:
if sentence not in processed_sentences_set:
unprocessed_sentences.append(sentence)
print(f"需要处理的句子数: {len(unprocessed_sentences)} (跳过已处理: {len(processed_sentences_set)})")
logger.info(f"需要处理的句子数: {len(unprocessed_sentences)} (跳过已处理: {len(processed_sentences_set)})")
if not unprocessed_sentences:
print("所有句子都已处理完成!")
# 如果有检查点,直接从最新检查点生成最终文件
if processed_sentences_set:
latest_checkpoint = self.find_latest_checkpoint()
if latest_checkpoint:
checkpoint_file, _ = latest_checkpoint
processed_data = self.load_checkpoint(checkpoint_file)
self.save_sentences(processed_data)
print("已从检查点生成最终输出文件")
return
# 处理未处理的句子
print("开始LLM处理...")
# 检查ollama服务状态
logger.info("检查Ollama服务状态...")
if not self.check_ollama_status():
logger.error("Ollama服务状态异常无法继续处理")
print("错误Ollama服务状态异常请检查服务是否正常运行")
return
new_processed_sentences = await self.process_sentences_with_llm(unprocessed_sentences)
# 如果有之前的处理结果,合并它们
if processed_sentences_set:
latest_checkpoint = self.find_latest_checkpoint()
if latest_checkpoint:
checkpoint_file, _ = latest_checkpoint
previous_processed = self.load_checkpoint(checkpoint_file)
# 合并结果
all_processed_sentences = previous_processed + new_processed_sentences
print(f"合并了之前的 {len(previous_processed)} 条和新处理的 {len(new_processed_sentences)} 条记录")
else:
all_processed_sentences = new_processed_sentences
else:
all_processed_sentences = new_processed_sentences
# 保存最终结果
self.save_sentences(all_processed_sentences)
print("LLM处理完成")
# ==================== 新增:句子提取功能 ====================
def extract_sentences(self):
"""步骤1从TREx数据集提取句子并保存为JSON"""
if not self.input_dir:
print("Error: input_dir is required for sentence extraction!")
return
print("=== 步骤1句子提取 ===")
print("开始从TREx数据集提取句子...")
json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))
if not json_files:
print(f"No JSON files found in {self.input_dir}")
return
# 排序文件以确保一致的处理顺序
json_files.sort()
if self.max_files:
json_files = json_files[:self.max_files]
print(f"Found {len(json_files)} JSON files to process")
all_sentences = []
for i, file_path in enumerate(json_files):
print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")
documents = self.parse_large_json_file(file_path)
print(f" Parsed {len(documents)} documents")
for doc in documents:
sentences = self.extract_sentences_from_document(doc)
all_sentences.extend(sentences)
print(f" Generated {len(all_sentences)} total raw sentences so far")
print(f"总共提取了 {len(all_sentences)} 个原始句子")
# 去重
unique_sentences = []
seen = set()
for sentence in all_sentences:
sentence = sentence.strip()
if sentence and sentence not in seen and len(sentence) > 10:
unique_sentences.append(sentence)
seen.add(sentence)
print(f"去重后剩余 {len(unique_sentences)} 个句子")
# 保存原始句子到JSON文件
sentences_data = {
"metadata": {
"total_sentences": len(unique_sentences),
"extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"source_files": len(json_files),
"max_files_limit": self.max_files
},
"sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
}
with open(self.sentences_json, 'w', encoding='utf-8') as f:
json.dump(sentences_data, f, ensure_ascii=False, indent=2)
print(f"句子提取完成!已保存到: {self.sentences_json}")
print(f"总计句子数: {len(unique_sentences)}")
return unique_sentences
def check_ollama_status(self) -> bool:
"""检查ollama服务是否正常运行"""
try:
# 检查ollama进程是否运行
result = subprocess.run(['pgrep', 'ollama'], capture_output=True, text=True)
if result.returncode != 0:
logger.error("Ollama进程未运行")
return False
# 检查ollama API是否响应
response = requests.get('http://localhost:11434/api/tags', timeout=5)
if response.status_code == 200:
logger.info("Ollama服务状态正常")
return True
else:
logger.error(f"Ollama API响应异常状态码: {response.status_code}")
return False
except requests.exceptions.RequestException as e:
logger.error(f"无法连接到Ollama API: {e}")
return False
except Exception as e:
logger.error(f"检查Ollama状态时出错: {e}")
return False
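# 排查示意:本方法等价于检查 pgrep ollama 以及访问 http://localhost:11434/api/tags;
# 也可在命令行用 `curl http://localhost:11434/api/tags` 快速确认Ollama服务是否可用。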
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with LLM processing')
# 选择运行模式
parser.add_argument('--step', choices=['extract', 'llm', 'all'], default='llm',
help='运行步骤: extract=仅提取句子, llm=仅LLM处理, all=完整流程')
# 文件路径参数
parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files')
parser.add_argument('--sentences_json', default='extracted_sentences.json', help='JSON file for extracted sentences (will be saved in output/)')
parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path (will be saved in output/)')
# 处理参数
parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)')
parser.add_argument('--no_llm', action='store_true', help='Disable LLM processing (basic mode)')
args = parser.parse_args()
# 根据步骤验证参数
if args.step in ['extract', 'all']:
if not os.path.exists(args.input_dir):
print(f"Error: Input directory {args.input_dir} does not exist!")
return
if args.step in ['llm', 'all']:
if args.no_llm:
print("Error: Cannot run LLM step with --no_llm flag!")
return
# 创建处理器
processor = EnhancedTRExProcessor(
input_dir=args.input_dir,
sentences_json=args.sentences_json,
output_file=args.output_file,
max_files=args.max_files,
enable_llm_processing=not args.no_llm
)
# 根据选择的步骤运行
if args.step == 'extract':
print("=== 运行模式:仅句子提取 ===")
processor.extract_sentences()
elif args.step == 'llm':
print("=== 运行模式仅LLM处理 ===")
asyncio.run(processor.process_with_llm())
elif args.step == 'all':
print("=== 运行模式:完整流程 ===")
# 步骤1:提取句子
print("\n--- 开始步骤1:句子提取 ---")
sentences = processor.extract_sentences()
if not sentences:
print("句子提取失败,退出")
return
if args.no_llm:
print("LLM处理已禁用流程结束")
return
# 步骤2:LLM处理
print("\n--- 开始步骤2:LLM处理 ---")
asyncio.run(processor.process_with_llm())
if __name__ == "__main__":
main()