#!/usr/bin/env python3
"""
Enhanced preprocessing script for the TREx dataset.

Uses the agno framework with a local Ollama model (configured in setup_agent())
for sentence post-processing and importance scoring.

Two independent steps are supported:
1. Sentence extraction: extract sentences from the TREx dataset and save them as JSON.
2. LLM processing: read the JSON file, post-process the sentences with the LLM and score their importance.
"""

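# Illustrative usage, a sketch based on the argparse flags defined in main() below
# ("this_script.py" is only a stand-in for this file's actual name; paths use the defaults):
#
#   python this_script.py --step extract --input_dir dataset/TREx --max_files 2
#   python this_script.py --step llm --sentences_json extracted_sentences.json --output_file trex_sentences_enhanced.txt
#   python this_script.py --step all --input_dir dataset/TREx
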
import json
import os
import glob
from typing import List, Dict, Any, Union, Set
import re
import asyncio
import time
import logging
from datetime import datetime
import subprocess
import requests
from pydantic import BaseModel, Field
from agno.agent import Agent
from agno.models.ollama import Ollama


# Logging setup
def setup_logging():
    """Configure the logging system."""
    # Make sure the logs directory exists
    os.makedirs('logs', exist_ok=True)

    # Build a timestamped log file name
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = f'logs/trex_processor_{timestamp}.log'

    # Log format
    log_format = '%(asctime)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s'

    # Configure the root logger
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()  # also log to the console
        ]
    )

    # Get the module logger
    logger = logging.getLogger(__name__)
    logger.info(f"Logging initialized, log file: {log_file}")
    return logger


# Global logger
logger = setup_logging()


class ProcessedSentence(BaseModel):
    """Structured result for a processed sentence."""
    corrected_sentence: str = Field(
        ...,
        description="The corrected sentence; only fix grammatical errors, garbled text and awkward wording, without extra polishing"
    )
    importance_score: float = Field(
        ...,
        description="Importance score from 0.0 to 10.0 in steps of 0.1, judging how common and important this knowledge is in the real world",
        ge=0.0,
        le=10.0
    )

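# Minimal sketch of the structured output described above (the values are made up):
# the agent is expected to return either a ProcessedSentence instance or a JSON object
# with the same two fields, which the fallback path in process_sentence_with_llm() parses.
#
#   ProcessedSentence(
#       corrected_sentence="Beijing is the capital of China.",
#       importance_score=8.5,
#   )
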
class EnhancedTRExProcessor:
    def __init__(self, input_dir: str = None, output_file: str = None, max_files: int = None,
                 sentences_json: str = None, enable_llm_processing: bool = True):
        self.input_dir = input_dir

        # Make sure the output directory exists
        os.makedirs('output', exist_ok=True)

        # Keep all output files inside the output/ directory
        if output_file:
            if not output_file.startswith('output/'):
                self.output_file = os.path.join('output', output_file)
            else:
                self.output_file = output_file
        else:
            self.output_file = None

        if sentences_json:
            if not sentences_json.startswith('output/'):
                self.sentences_json = os.path.join('output', sentences_json)
            else:
                self.sentences_json = sentences_json
        else:
            self.sentences_json = "output/extracted_sentences.json"

        self.max_files = max_files
        self.enable_llm_processing = enable_llm_processing

        # LLM processing configuration
        self.llm_timeout = 60  # per-request timeout in seconds
        self.max_concurrent = 8  # maximum number of concurrent LLM requests
        self.max_retries = 2  # keep retries low to avoid overly long waits
        self.heartbeat_interval = 30  # heartbeat check interval in seconds

        # Statistics
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.timeout_requests = 0
        self.last_successful_time = time.time()
        self.last_activity_time = time.time()  # time of the most recent activity

        # Initialize the agno agent (only when LLM processing is enabled)
        if self.enable_llm_processing:
            self.setup_agent()

        logger.info(f"Processor initialized - concurrency: {self.max_concurrent}, timeout: {self.llm_timeout}s")

        # Extended Wikidata property mappings
        self.property_mappings = {
            # Basic relations
            "http://www.wikidata.org/prop/direct/P31": "is a",
            "http://www.wikidata.org/prop/direct/P279": "is a type of",

            # People
            "http://www.wikidata.org/prop/direct/P106": "works as",
            "http://www.wikidata.org/prop/direct/P27": "is a citizen of",
            "http://www.wikidata.org/prop/direct/P19": "was born in",
            "http://www.wikidata.org/prop/direct/P20": "died in",
            "http://www.wikidata.org/prop/direct/P569": "was born on",
            "http://www.wikidata.org/prop/direct/P570": "died on",
            "http://www.wikidata.org/prop/direct/P22": "has father",
            "http://www.wikidata.org/prop/direct/P25": "has mother",
            "http://www.wikidata.org/prop/direct/P26": "is married to",

            # Organizations
            "http://www.wikidata.org/prop/direct/P102": "is a member of",
            "http://www.wikidata.org/prop/direct/P108": "works for",
            "http://www.wikidata.org/prop/direct/P159": "has headquarters in",
            "http://www.wikidata.org/prop/direct/P112": "was founded by",
            "http://www.wikidata.org/prop/direct/P571": "was founded in",
            "http://www.wikidata.org/prop/direct/P169": "has CEO",

            # Geography
            "http://www.wikidata.org/prop/direct/P17": "is located in",
            "http://www.wikidata.org/prop/direct/P131": "is located in",
            "http://www.wikidata.org/prop/direct/P36": "has capital",
            "http://www.wikidata.org/prop/direct/P47": "borders",

            # Other relations
            "http://www.wikidata.org/prop/direct/P1142": "has ideology",
            "http://www.wikidata.org/prop/direct/P361": "is part of",
            "http://www.wikidata.org/prop/direct/P737": "was influenced by",
            "http://www.wikidata.org/prop/direct/P127": "is owned by",
            "http://www.wikidata.org/prop/direct/P155": "follows",
            "http://www.wikidata.org/prop/direct/P156": "is followed by",
            "http://www.wikidata.org/prop/direct/P138": "is named after"
        }

    def setup_agent(self):
        """Set up the agno agent."""
        try:
            self.agent = Agent(
                model=Ollama(
                    id="gemma3:latest",
                    # Use options to set temperature and other sampling parameters
                    options={
                        "temperature": 0.2,
                        "top_p": 0.8,
                        "top_k": 20,
                        "num_ctx": 4096,
                    }
                ),
                response_model=ProcessedSentence,
                instructions=[
                    "You are a professional text processing assistant responsible for correcting errors in sentences and evaluating the importance of knowledge.",
                    "",
                    "### Sentence Correction Rules:",
                    "1. Remove Wikipedia-specific markers: such as (disambiguation), (film), (band), etc. in parentheses",
                    "2. Ensure grammatical completeness: complete subject+predicate+object structure, avoid dangling 'and is', 'or', etc.",
                    "3. Fix obvious grammatical errors: tense consistency, singular/plural consistency, correct preposition usage",
                    "4. Clean up garbled text and special characters: such as â, €, ™ and other encoding issues",
                    "5. Ensure semantic fluency: if the original sentence cannot be fixed, reorganize the language to make it coherent",
                    "6. Do not add information not present in the original text, only correct errors",
                    "",
                    "### Correction Examples:",
                    "- Error: 'Argument (disambiguation) is related to philosophy, logic, and is an.'",
                    "- Corrected: 'Argument is related to philosophy and logic.'",
                    "",
                    "- Error: 'Beijing is a capital city and are.'",
                    "- Corrected: 'Beijing is a capital city.'",
                    "",
                    "Importance scoring criteria (0.0-10.0, in increments of 0.1):",
                    "",
                    "0.0 points - Completely incorrect or meaningless information",
                    "Examples: 'Apple is a metal', 'The sun rises from the west', '1+1=3'",
                    "",
                    "0.5 points - Almost worthless information",
                    "Examples: 'Color of a fictional character's socks', 'Third line of dialogue from a game NPC', 'What someone had for breakfast yesterday'",
                    "",
                    "1.0 points - Extremely rare, non-practical knowledge",
                    "Examples: 'Pet name of a minor novel character', 'Content of the 15th line in movie end credits', 'Nickname of website user ID 123456'",
                    "",
                    "1.5 points - Very niche detailed information",
                    "Examples: 'Outfit of a passerby at minute 37 in a movie', 'Duration of background music in a game's hidden level', 'Content of the 3rd dialogue box on page 200 of a manga'",
                    "",
                    "2.0 points - Details in niche professional fields",
                    "Examples: 'Color change of rare minerals at specific temperatures', 'Length of an insect's third antenna', 'Molecular formula of chemical reaction byproducts'",
                    "",
                    "2.5 points - Technical details only professionals care about",
                    "Examples: 'Release date of specific software library version', 'Time complexity coefficient of an algorithm', 'Thermal expansion coefficient of a material'",
                    "",
                    "3.0 points - Professional knowledge in specific fields",
                    "Examples: 'Programming language syntax features', 'Gene sequence of a virus', 'Official system of ancient dynasties'",
                    "",
                    "3.5 points - Professional information with some value",
                    "Examples: 'Specific system of historical dynasty', 'Mechanism of action of a drug', 'Development time of a technical standard'",
                    "",
                    "4.0 points - Meaningful knowledge known by few",
                    "Examples: 'Unique cultural traditions of a country', 'Important discoveries by a scientist', 'Detailed process of historical events'",
                    "",
                    "4.5 points - Knowledge of interest to some groups",
                    "Examples: 'Author's creative background', 'Characteristics of an art movement', 'Detailed rules of a sport'",
                    "",
                    "5.0 points - General knowledge of moderate importance",
                    "Examples: 'Famous attractions in cities', 'Development history of a company', 'Living habits of animals'",
                    "",
                    "5.5 points - Fairly useful common sense",
                    "Examples: 'Plant growth environment', 'Healthy eating common sense', 'Basic first aid knowledge'",
                    "",
                    "6.0 points - Knowledge most educated people should know",
                    "Examples: 'Shakespeare's representative works', 'Basic geometric theorems', 'Major world currencies'",
                    "",
                    "6.5 points - Important cultural or scientific common sense",
                    "Examples: 'Basic structure of DNA', 'Newton's three laws', 'Major world religions'",
                    "",
                    "7.0 points - Important foundational knowledge",
                    "Examples: 'Time period of World War II', 'Functions of major human organs', 'Basic mathematical operation rules'",
                    "",
                    "7.5 points - Very important common sense",
                    "Examples: 'Light speed is the fastest in the universe', 'Earth is round', 'Basic principles of blood circulation'",
                    "",
                    "8.0 points - Core knowledge in basic education",
                    "Examples: 'Earth orbits the sun', 'Principle of seasonal formation', 'Basic grammar rules'",
                    "",
                    "8.5 points - Important knowledge everyone should master",
                    "Examples: 'Chemical formula of water H2O', 'Basic safety common sense', 'Simple mathematical calculations'",
                    "",
                    "9.0 points - Extremely important basic concepts",
                    "Examples: 'Humans need oxygen to survive', 'Fire is hot', 'Basic directional concepts'",
                    "",
                    "9.5 points - Core knowledge everyone must know",
                    "Examples: 'A day has 24 hours', 'A year has 12 months', 'Basic number concepts'",
                    "",
                    "10.0 points - Most basic and important common sense",
                    "Examples: 'Humans need food and water to survive', 'The sky is blue', 'Stones are heavier than feathers'",
                    "",
                    "When scoring, please consider:",
                    "1. Popularity of knowledge - How many people know this knowledge",
                    "2. Practical value - How useful this knowledge is in daily life",
                    "3. Educational importance - The position of this knowledge in the education system",
                    "4. Cultural significance - The importance of this knowledge for understanding the world",
                    "",
                    "Please output structured results directly without showing the thinking process."
                ],
                markdown=False
            )
            logger.info("LLM processor initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize the LLM processor: {e}")
            print(f"Failed to initialize the LLM processor: {e}")
            print("Falling back to basic mode (no LLM post-processing)")
            self.enable_llm_processing = False

    async def process_sentence_with_llm(self, sentence: str) -> ProcessedSentence:
        """Process a single sentence with the LLM (kept for standalone calls)."""
        for attempt in range(self.max_retries):
            try:
                prompt = f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"

                # Add a timeout via asyncio.wait_for
                response = await asyncio.wait_for(
                    self.agent.arun(prompt),
                    timeout=self.llm_timeout
                )

                # According to the agno docs, the response should already be a ProcessedSentence
                if isinstance(response, ProcessedSentence):
                    return response
                else:
                    message = response.messages[-1].content
                    message = message.replace("```json", "").replace("```", "")
                    message = json.loads(message)
                    return ProcessedSentence(
                        corrected_sentence=message['corrected_sentence'],
                        importance_score=message['importance_score']
                    )

            except asyncio.TimeoutError:
                logger.warning(f"LLM request timed out (attempt {attempt + 1}/{self.max_retries}): {sentence[:50]}...")
                if attempt == self.max_retries - 1:
                    logger.error(f"LLM request timed out after all retries, using default handling: {sentence[:50]}...")
                    break
                # Wait a moment before retrying
                await asyncio.sleep(2 ** attempt)  # exponential backoff

            except Exception as e:
                logger.error(f"Error while processing sentence with the LLM (attempt {attempt + 1}/{self.max_retries}): {e}")
                if attempt == self.max_retries - 1:
                    break
                await asyncio.sleep(1)

        # All retries failed: return the original sentence with a medium score
        logger.warning(f"Using default handling: {sentence[:50]}...")
        return ProcessedSentence(
            corrected_sentence=sentence,
            importance_score=5.0
        )

    def clean_text(self, text: str) -> str:
        """Clean text and normalize special characters."""
        if not text:
            return ""

        # Normalize common Unicode punctuation
        text = text.replace("\u2013", "-")   # en dash
        text = text.replace("\u2014", "-")   # em dash
        text = text.replace("\u2019", "'")   # right single quotation mark
        text = text.replace("\u2018", "'")   # left single quotation mark
        text = text.replace("\u201c", '"')   # left double quotation mark
        text = text.replace("\u201d", '"')   # right double quotation mark

        # Guard against unexpected encoding issues
        try:
            text = text.encode('utf-8').decode('utf-8')
        except Exception:
            pass

        # Collapse extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Strip surrounding quotes
        text = text.strip('"\'')

        return text

    def parse_large_json_file(self, file_path: str) -> List[Dict]:
        """Parse a large JSON file, tolerating possible formatting problems."""
        documents = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read().strip()

            # Try different parsing strategies
            if content.startswith('[') and content.endswith(']'):
                # Standard JSON array
                documents = json.loads(content)
            else:
                # Possibly a stream of concatenated JSON objects:
                # try splitting on the '}{' boundaries
                if '}{"' in content:
                    json_strings = content.split('}{')
                    json_strings[0] += '}'  # first object
                    json_strings[-1] = '{' + json_strings[-1]  # last object

                    for i in range(1, len(json_strings) - 1):
                        json_strings[i] = '{' + json_strings[i] + '}'

                    for json_str in json_strings:
                        try:
                            doc = json.loads(json_str)
                            documents.append(doc)
                        except json.JSONDecodeError:
                            continue
                else:
                    # Try parsing as a single JSON object
                    try:
                        documents = [json.loads(content)]
                    except json.JSONDecodeError:
                        pass

        except Exception as e:
            print(f"Error parsing {file_path}: {e}")

        return documents

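    # Sketch of the fallback path above, using a hypothetical input: a file containing
    #   {"title": "A"}{"title": "B"}
    # is split on '}{' into ['{"title": "A"', '"title": "B"}'], the braces are restored,
    # and each piece is parsed into its own document dict.
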
    def extract_sentences_from_document(self, doc: Dict[str, Any]) -> List[str]:
        """Extract sentences from a single document."""
        sentences = []

        title = self.clean_text(doc.get('title', ''))
        text = self.clean_text(doc.get('text', ''))
        entities = doc.get('entities', [])
        triples = doc.get('triples', [])

        # Convert the explicit triples
        for triple in triples:
            sentence = self.triple_to_sentence(triple)
            if sentence:
                sentences.append(sentence)

        # Generate basic sentences from entities and text if the triples yielded too few
        if title and text and len(sentences) < 5:
            # Build sentences from the title and entities
            entity_names = []
            for entity in entities[:15]:
                entity_name = self.clean_text(entity.get('surfaceform', ''))
                if entity_name and len(entity_name) > 2:
                    entity_names.append(entity_name)

            # Generate a simple descriptive sentence
            if entity_names:
                important_entities = []
                title_lower = title.lower()
                for entity in entity_names:
                    if (entity.lower() != title_lower and
                            entity not in important_entities and
                            not any(t.lower() in entity.lower() for t in title_lower.split()[:2])):
                        important_entities.append(entity)
                        if len(important_entities) >= 6:
                            break

                if important_entities and len(sentences) < 3:
                    entities_str = ', '.join(important_entities[:3])
                    sentences.append(f"{title} is related to {entities_str}.")

        return sentences

    def triple_to_sentence(self, triple: Dict[str, Any]) -> str:
        """Convert a triple into a natural-language sentence."""
        try:
            subject = triple.get('subject', {})
            predicate = triple.get('predicate', {})
            obj = triple.get('object', {})

            subject_name = self.clean_text(subject.get('surfaceform', ''))
            object_name = self.clean_text(obj.get('surfaceform', ''))
            predicate_uri = predicate.get('uri', '')

            # Require a valid subject and object
            if not subject_name or not object_name:
                return ""

            # Skip subjects/objects that are too short to be meaningful
            if len(subject_name) <= 2 or len(object_name) <= 2:
                return ""

            # Look up the relation text
            relation_text = self.property_mappings.get(predicate_uri, "is related to")

            # Skip triples whose subject and object are identical
            if subject_name.lower() == object_name.lower():
                return ""

            return f"{subject_name} {relation_text} {object_name}."

        except Exception as e:
            print(f"Error converting triple to sentence: {e}")
            return ""

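    # Sketch with hypothetical surface forms: for a TREx triple such as
    #   {"subject": {"surfaceform": "Barack Obama"},
    #    "predicate": {"uri": "http://www.wikidata.org/prop/direct/P19"},
    #    "object": {"surfaceform": "Honolulu"}}
    # triple_to_sentence() looks up P19 in self.property_mappings ("was born in")
    # and returns "Barack Obama was born in Honolulu."
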
    async def process_sentence_with_llm_concurrent(self, semaphore: asyncio.Semaphore, sentence: str, index: int, total_sentences: int, start_time: float) -> Dict[str, Any]:
        """LLM processing with concurrency limited by a semaphore."""
        async with semaphore:
            self.total_requests += 1
            self.last_activity_time = time.time()  # update the activity timestamp
            success = False

            for attempt in range(self.max_retries):
                try:
                    prompt = f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"

                    # Add a timeout via asyncio.wait_for
                    response = await asyncio.wait_for(
                        self.agent.arun(prompt),
                        timeout=self.llm_timeout
                    )

                    # According to the agno docs, the response should already be a ProcessedSentence
                    if isinstance(response, ProcessedSentence):
                        result = {
                            "index": index,
                            "original_sentence": sentence,
                            "corrected_sentence": response.corrected_sentence,
                            "importance_score": response.importance_score
                        }
                    else:
                        message = response.messages[-1].content
                        message = message.replace("```json", "").replace("```", "")
                        message = json.loads(message)
                        result = {
                            "index": index,
                            "original_sentence": sentence,
                            "corrected_sentence": message['corrected_sentence'],
                            "importance_score": message['importance_score']
                        }

                    # Successfully processed
                    self.successful_requests += 1
                    self.last_successful_time = time.time()
                    self.last_activity_time = time.time()  # update the activity timestamp
                    success = True

                    # Print detailed progress information every 50 sentences
                    if index % 50 == 0:
                        current_time = time.time()
                        elapsed_time = current_time - start_time
                        avg_time_per_sentence = elapsed_time / (index + 1) if index > 0 else elapsed_time
                        remaining_sentences = total_sentences - (index + 1)
                        estimated_remaining_time = avg_time_per_sentence * remaining_sentences
                        success_rate = (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0

                        # Helper to format durations
                        def format_time(seconds):
                            if seconds < 60:
                                return f"{seconds:.1f}s"
                            elif seconds < 3600:
                                minutes = seconds / 60
                                return f"{minutes:.1f} min"
                            else:
                                hours = seconds / 3600
                                return f"{hours:.1f} h"

                        logger.info(f"Finished processing sentence {index + 1}")
                        logger.info(f" - sentences remaining: {remaining_sentences}")
                        logger.info(f" - average time per sentence: {avg_time_per_sentence:.2f}s")
                        logger.info(f" - estimated time remaining: {format_time(estimated_remaining_time)}")
                        logger.info(f" - elapsed time: {format_time(elapsed_time)}")
                        logger.info(f" - success rate: {success_rate:.1f}% ({self.successful_requests}/{self.total_requests})")

                        print(f"Finished processing sentence {index + 1}")
                        print(f" - sentences remaining: {remaining_sentences}")
                        print(f" - average time per sentence: {avg_time_per_sentence:.2f}s")
                        print(f" - estimated time remaining: {format_time(estimated_remaining_time)}")
                        print(f" - elapsed time: {format_time(elapsed_time)}")
                        print(f" - success rate: {success_rate:.1f}% ({self.successful_requests}/{self.total_requests})")

                    return result

                except asyncio.TimeoutError:
                    self.timeout_requests += 1
                    self.last_activity_time = time.time()  # update the activity timestamp
                    logger.warning(f"Sentence {index} timed out (attempt {attempt + 1}/{self.max_retries}): {sentence[:50]}...")
                    if attempt == self.max_retries - 1:
                        logger.error(f"Sentence {index} timed out after all retries, using default handling")
                        break
                    # Exponential backoff
                    await asyncio.sleep(2 ** attempt)

                except Exception as e:
                    self.last_activity_time = time.time()  # update the activity timestamp
                    logger.error(f"Error while processing sentence {index} (attempt {attempt + 1}/{self.max_retries}): {e}")
                    if attempt == self.max_retries - 1:
                        break
                    await asyncio.sleep(1)

            # All retries failed: fall back to default handling
            if not success:
                self.failed_requests += 1
                logger.warning(f"Sentence {index} using default handling: {sentence[:50]}...")

            return {
                "index": index,
                "original_sentence": sentence,
                "corrected_sentence": sentence,
                "importance_score": 5.0
            }

    async def heartbeat_monitor(self, total_sentences: int):
        """Heartbeat monitor that detects long stretches without a successful response."""
        consecutive_warnings = 0

        while True:
            await asyncio.sleep(self.heartbeat_interval)

            current_time = time.time()
            time_since_last_success = current_time - self.last_successful_time
            time_since_last_activity = current_time - self.last_activity_time

            # Check the time of the last successful response
            if time_since_last_success > self.heartbeat_interval:
                consecutive_warnings += 1
                logger.warning(f"⚠️ Heartbeat check #{consecutive_warnings}: no successful LLM response for {time_since_last_success:.1f} seconds")
                print(f"⚠️ Heartbeat check #{consecutive_warnings}: no successful LLM response for {time_since_last_success:.1f} seconds")

                # Print current statistics
                if self.total_requests > 0:
                    success_rate = self.successful_requests / self.total_requests * 100
                    logger.warning(f" Current stats: {self.total_requests} requests, {self.successful_requests} successful ({success_rate:.1f}%), {self.timeout_requests} timed out")
                    print(f" Current stats: {self.total_requests} requests, {self.successful_requests} successful ({success_rate:.1f}%), {self.timeout_requests} timed out")

                if time_since_last_success > self.heartbeat_interval * 3:
                    logger.error(f"❌ Severe warning: the LLM may be stuck, no successful response for {time_since_last_success:.1f} seconds!")
                    print(f"❌ Severe warning: the LLM may be stuck, no successful response for {time_since_last_success:.1f} seconds!")
                    print(" Suggestion: check the Ollama service status or consider restarting the program")

                    # Check the Ollama service status
                    if not self.check_ollama_status():
                        logger.critical("💀 The Ollama service looks unhealthy; this may be why processing is stuck!")
                        print("💀 The Ollama service looks unhealthy; this may be why processing is stuck!")

                if consecutive_warnings >= 5:
                    logger.critical(f"💀 Fatal: {consecutive_warnings} consecutive heartbeat warnings, manual intervention may be required")
                    print(f"💀 Fatal: {consecutive_warnings} consecutive heartbeat warnings, manual intervention may be required")
            else:
                if consecutive_warnings > 0:
                    logger.info(f"✅ Heartbeat back to normal: last success {time_since_last_success:.1f} seconds ago")
                    print(f"✅ Heartbeat back to normal: last success {time_since_last_success:.1f} seconds ago")
                    consecutive_warnings = 0
                logger.debug(f"💓 Heartbeat OK: last success {time_since_last_success:.1f} seconds ago")

    async def process_sentences_with_llm(self, sentences: List[str]) -> List[Dict[str, Any]]:
        """Process sentences concurrently in batches, saving a checkpoint after each batch."""
        logger.info(f"Starting concurrent LLM processing of {len(sentences)} sentences (max concurrency: {self.max_concurrent})...")
        print(f"Starting concurrent LLM processing of {len(sentences)} sentences (max concurrency: {self.max_concurrent})...")

        # Record the start time
        start_time = time.time()
        total_sentences = len(sentences)

        # Process in batches of 1000 sentences
        batch_size = 1000
        all_processed_sentences = []

        # Start the heartbeat monitor
        heartbeat_task = asyncio.create_task(self.heartbeat_monitor(total_sentences))

        try:
            for batch_start in range(0, total_sentences, batch_size):
                batch_end = min(batch_start + batch_size, total_sentences)
                batch_sentences = sentences[batch_start:batch_end]

                logger.info(f"=== Processing batch {batch_start//batch_size + 1} ({batch_start + 1}-{batch_end}/{total_sentences}) ===")
                print(f"\n=== Processing batch {batch_start//batch_size + 1} ({batch_start + 1}-{batch_end}/{total_sentences}) ===")

                # Semaphore limiting the number of concurrent requests
                semaphore = asyncio.Semaphore(self.max_concurrent)

                # Reset per-batch statistics
                batch_start_time = time.time()
                self.total_requests = 0
                self.successful_requests = 0
                self.failed_requests = 0
                self.timeout_requests = 0

                # Create the tasks for this batch
                tasks = []
                for i, sentence in enumerate(batch_sentences):
                    global_index = batch_start + i
                    task = self.process_sentence_with_llm_concurrent(semaphore, sentence, global_index, total_sentences, start_time)
                    tasks.append(task)

                # Run this batch's tasks concurrently
                logger.info(f"Concurrently processing the {len(batch_sentences)} sentences of batch {batch_start//batch_size + 1}...")
                print(f"Concurrently processing the {len(batch_sentences)} sentences of batch {batch_start//batch_size + 1}...")

                batch_results = await asyncio.gather(*tasks, return_exceptions=True)

                # Collect this batch's results, filtering out exceptions
                batch_processed_sentences = []
                batch_error_count = 0

                for result in batch_results:
                    if isinstance(result, Exception):
                        logger.error(f"Task raised an exception: {result}")
                        print(f"Task raised an exception: {result}")
                        batch_error_count += 1
                    elif isinstance(result, dict):
                        batch_processed_sentences.append(result)
                    else:
                        batch_error_count += 1

                # Restore the original order (concurrent execution may shuffle results)
                batch_processed_sentences.sort(key=lambda x: x['index'])

                # Drop the index field
                for item in batch_processed_sentences:
                    del item['index']

                # Append to the overall results
                all_processed_sentences.extend(batch_processed_sentences)

                # Save a checkpoint
                checkpoint_filename = self.save_checkpoint(all_processed_sentences, batch_end)

                # Print statistics for this batch
                elapsed_time = time.time() - start_time
                batch_time = time.time() - batch_start_time
                completed_sentences = len(all_processed_sentences)

                logger.info(f"Batch {batch_start//batch_size + 1} finished!")
                logger.info(f" - this batch: {len(batch_processed_sentences)} succeeded, {batch_error_count} failed")
                logger.info(f" - batch time: {batch_time/60:.1f} min")
                logger.info(f" - LLM stats: {self.successful_requests} succeeded, {self.failed_requests} failed, {self.timeout_requests} timed out")
                logger.info(f" - overall progress: {completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
                logger.info(f" - elapsed time: {elapsed_time/60:.1f} min")
                logger.info(f" - average speed: {completed_sentences/elapsed_time:.2f} sentences/s")
                logger.info(f" - checkpoint saved: {checkpoint_filename}")

                print(f"Batch {batch_start//batch_size + 1} finished!")
                print(f" - this batch: {len(batch_processed_sentences)} succeeded, {batch_error_count} failed")
                print(f" - batch time: {batch_time/60:.1f} min")
                print(f" - LLM stats: {self.successful_requests} succeeded, {self.failed_requests} failed, {self.timeout_requests} timed out")
                print(f" - overall progress: {completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
                print(f" - elapsed time: {elapsed_time/60:.1f} min")
                print(f" - average speed: {completed_sentences/elapsed_time:.2f} sentences/s")
                print(f" - checkpoint saved: {checkpoint_filename}")

                if batch_end < total_sentences:
                    remaining_sentences = total_sentences - completed_sentences
                    avg_time_per_sentence = elapsed_time / completed_sentences
                    estimated_remaining_time = avg_time_per_sentence * remaining_sentences
                    logger.info(f" - estimated time remaining: {estimated_remaining_time/60:.1f} min")
                    print(f" - estimated time remaining: {estimated_remaining_time/60:.1f} min")

                # Rest briefly between batches to avoid overloading the service
                if batch_end < total_sentences:
                    logger.info("Resting 5 seconds between batches...")
                    await asyncio.sleep(5)

        finally:
            # Cancel the heartbeat monitor
            heartbeat_task.cancel()
            try:
                await heartbeat_task
            except asyncio.CancelledError:
                pass

        # Print final statistics
        total_time = time.time() - start_time
        logger.info("=== All processing finished! ===")
        logger.info(f" - total succeeded: {len(all_processed_sentences)}")
        logger.info(f" - total time: {total_time/60:.1f} min")
        logger.info(f" - average speed: {len(all_processed_sentences)/total_time:.2f} sentences/s")

        print("\n=== All processing finished! ===")
        print(f" - total succeeded: {len(all_processed_sentences)}")
        print(f" - total time: {total_time/60:.1f} min")
        print(f" - average speed: {len(all_processed_sentences)/total_time:.2f} sentences/s")

        return all_processed_sentences

    def save_checkpoint(self, processed_sentences: List[Dict[str, Any]], current_count: int) -> str:
        """Save a checkpoint file and return its path."""
        # Build the checkpoint filename inside the output/ directory
        base_name = os.path.splitext(os.path.basename(self.output_file))[0]
        checkpoint_filename = os.path.join('output', f"{base_name}_checkpoint_{current_count}.json")

        # Write the checkpoint
        with open(checkpoint_filename, 'w', encoding='utf-8') as f:
            json.dump({
                "metadata": {
                    "total_processed": len(processed_sentences),
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                    "checkpoint_number": current_count
                },
                "sentences": processed_sentences
            }, f, ensure_ascii=False, indent=2)

        return checkpoint_filename

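    # For reference, a checkpoint written above lands in output/ as
    # "{output_file_stem}_checkpoint_{N}.json" and has the shape
    #   {"metadata": {"total_processed": ..., "timestamp": ..., "checkpoint_number": N},
    #    "sentences": [{"original_sentence": ..., "corrected_sentence": ..., "importance_score": ...}, ...]}
    # find_latest_checkpoint() later picks the file with the largest N.
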
    async def process_files(self) -> List[str]:
        """Process all input files and save the extracted sentences."""
        json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))

        if not json_files:
            print(f"No JSON files found in {self.input_dir}")
            return []

        # Sort files for a consistent processing order
        json_files.sort()

        if self.max_files:
            json_files = json_files[:self.max_files]

        print(f"Found {len(json_files)} JSON files to process")

        all_sentences = []

        for i, file_path in enumerate(json_files):
            print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")

            documents = self.parse_large_json_file(file_path)
            print(f"  Parsed {len(documents)} documents")

            for doc in documents:
                sentences = self.extract_sentences_from_document(doc)
                all_sentences.extend(sentences)

            print(f"  Generated {len(all_sentences)} total raw sentences so far")

        print(f"Extracted {len(all_sentences)} raw sentences in total")

        # Deduplicate
        unique_sentences = []
        seen = set()
        for sentence in all_sentences:
            sentence = sentence.strip()
            if sentence and sentence not in seen and len(sentence) > 10:
                unique_sentences.append(sentence)
                seen.add(sentence)

        print(f"{len(unique_sentences)} sentences remain after deduplication")

        # Save the raw sentences to a JSON file
        sentences_data = {
            "metadata": {
                "total_sentences": len(unique_sentences),
                "extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "source_files": len(json_files),
                "max_files_limit": self.max_files
            },
            "sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
        }

        with open(self.sentences_json, 'w', encoding='utf-8') as f:
            json.dump(sentences_data, f, ensure_ascii=False, indent=2)

        print(f"Sentence extraction finished! Saved to: {self.sentences_json}")
        print(f"Total sentences: {len(unique_sentences)}")

        return unique_sentences

    def save_sentences(self, processed_sentences: List[Dict[str, Any]]):
        """Save the processed sentences to the output files."""
        # Make sure the output directory exists
        os.makedirs('output', exist_ok=True)

        # Save the full records as JSON
        json_output_file = self.output_file.replace('.txt', '.json')
        with open(json_output_file, 'w', encoding='utf-8') as f:
            json.dump(processed_sentences, f, ensure_ascii=False, indent=2)

        # Save a plain-text version (corrected sentences only)
        with open(self.output_file, 'w', encoding='utf-8') as f:
            for item in processed_sentences:
                f.write(item['corrected_sentence'] + '\n')

        # Write a version sorted by importance score
        importance_sorted = sorted(processed_sentences, key=lambda x: x['importance_score'], reverse=True)
        importance_file = self.output_file.replace('.txt', '_sorted_by_importance.txt')
        with open(importance_file, 'w', encoding='utf-8') as f:
            for item in importance_sorted:
                f.write(f"[{item['importance_score']:.1f}] {item['corrected_sentence']}\n")

        print(f"Saved {len(processed_sentences)} processed sentences:")
        print(f"  - JSON format: {json_output_file}")
        print(f"  - text format: {self.output_file}")
        print(f"  - sorted by importance: {importance_file}")

        # Statistics
        scores = [item['importance_score'] for item in processed_sentences]
        avg_score = sum(scores) / len(scores) if scores else 0
        print(f"  - average importance score: {avg_score:.2f}")
        if scores:
            print(f"  - highest score: {max(scores):.1f}")
            print(f"  - lowest score: {min(scores):.1f}")

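    # save_sentences() above therefore leaves three artifacts in output/: a .json dump of
    # the full records, a plain .txt file with one corrected sentence per line, and a
    # *_sorted_by_importance.txt file whose lines look like
    #   [8.5] Beijing is the capital of China.
    # (the sentence and score in this line are made-up examples).
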
    async def run(self):
        """Run the full pipeline: extract sentences, then post-process them with the LLM."""
        print("Starting enhanced TREx to sentences conversion...")
        await self.process_files()
        if self.enable_llm_processing:
            await self.process_with_llm()
        print("Enhanced conversion completed!")

    def find_latest_checkpoint(self) -> Union[tuple, None]:
        """Find the most recent checkpoint file, if any."""
        base_name = os.path.splitext(os.path.basename(self.output_file))[0]
        pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
        checkpoint_files = glob.glob(pattern)

        if not checkpoint_files:
            return None

        # Pick the checkpoint with the highest number
        latest_file = None
        latest_count = 0

        for file in checkpoint_files:
            try:
                # Extract the number from the filename
                match = re.search(r'checkpoint_(\d+)\.json$', file)
                if match:
                    count = int(match.group(1))
                    if count > latest_count:
                        latest_count = count
                        latest_file = file
            except Exception:
                continue

        if latest_file:
            return latest_file, latest_count
        else:
            return None

    def load_checkpoint(self, checkpoint_file: str) -> List[Dict[str, Any]]:
        """Load previously processed sentences from a checkpoint file."""
        try:
            with open(checkpoint_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if 'sentences' in data:
                return data['sentences']
            else:
                # Old-style checkpoint file
                return data
        except Exception as e:
            print(f"Failed to load the checkpoint file: {e}")
            return []

    def get_processed_sentences_from_checkpoints(self) -> Set[str]:
        """Collect the set of already-processed sentences from checkpoint files."""
        if not self.output_file:
            return set()

        processed_sentences = set()

        # Find all checkpoint files
        base_name = os.path.splitext(os.path.basename(self.output_file))[0]
        pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
        checkpoint_files = glob.glob(pattern)

        if not checkpoint_files:
            print("No checkpoint files found; starting from scratch")
            return set()

        # Locate the most recent checkpoint
        latest_file = None
        latest_count = 0

        for file in checkpoint_files:
            try:
                match = re.search(r'checkpoint_(\d+)\.json$', file)
                if match:
                    count = int(match.group(1))
                    if count > latest_count:
                        latest_count = count
                        latest_file = file
            except Exception:
                continue

        if latest_file:
            print(f"Found latest checkpoint: {latest_file} (contains {latest_count} records)")
            logger.info(f"Found latest checkpoint: {latest_file} (contains {latest_count} records)")
            try:
                with open(latest_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                sentences_data = data.get('sentences', [])
                for item in sentences_data:
                    original_sentence = item.get('original_sentence', '')
                    if original_sentence:
                        processed_sentences.add(original_sentence)

                print(f"Loaded {len(processed_sentences)} already-processed sentences from the checkpoint")
                logger.info(f"Loaded {len(processed_sentences)} already-processed sentences from the checkpoint")

            except Exception as e:
                print(f"Failed to read the checkpoint file: {e}")
                return set()

        return processed_sentences

    async def process_with_llm(self):
        """Step 2: read sentences from the JSON file and post-process them with the LLM."""
        if not self.enable_llm_processing:
            print("Error: LLM processing is disabled!")
            return

        if not self.output_file:
            print("Error: output_file is required for LLM processing!")
            return

        print("=== Step 2: LLM processing ===")

        # Read the sentences JSON file
        if not os.path.exists(self.sentences_json):
            print(f"Error: Sentences file {self.sentences_json} not found!")
            print("Please run step 1 (sentence extraction) first")
            return

        print(f"Reading sentences file: {self.sentences_json}")

        try:
            with open(self.sentences_json, 'r', encoding='utf-8') as f:
                data = json.load(f)

            all_sentences = [item["sentence"] for item in data.get("sentences", [])]
            print(f"Read {len(all_sentences)} sentences from the file")

        except Exception as e:
            print(f"Failed to read the sentences file: {e}")
            return

        # Gather the sentences that were already processed
        processed_sentences_set = self.get_processed_sentences_from_checkpoints()

        # Keep only the sentences that still need processing
        unprocessed_sentences = []
        for sentence in all_sentences:
            if sentence not in processed_sentences_set:
                unprocessed_sentences.append(sentence)

        print(f"Sentences to process: {len(unprocessed_sentences)} (skipping {len(processed_sentences_set)} already processed)")
        logger.info(f"Sentences to process: {len(unprocessed_sentences)} (skipping {len(processed_sentences_set)} already processed)")

        if not unprocessed_sentences:
            print("All sentences have already been processed!")

            # If checkpoints exist, regenerate the final output from the latest one
            if processed_sentences_set:
                latest_checkpoint = self.find_latest_checkpoint()
                if latest_checkpoint:
                    checkpoint_file, _ = latest_checkpoint
                    processed_data = self.load_checkpoint(checkpoint_file)
                    self.save_sentences(processed_data)
                    print("Final output files regenerated from the checkpoint")
            return

        # Process the remaining sentences
        print("Starting LLM processing...")

        # Check the Ollama service status
        logger.info("Checking Ollama service status...")
        if not self.check_ollama_status():
            logger.error("The Ollama service is unhealthy; cannot continue")
            print("Error: the Ollama service is unhealthy, please check that it is running")
            return

        new_processed_sentences = await self.process_sentences_with_llm(unprocessed_sentences)

        # Merge with previous results if there are any
        if processed_sentences_set:
            latest_checkpoint = self.find_latest_checkpoint()
            if latest_checkpoint:
                checkpoint_file, _ = latest_checkpoint
                previous_processed = self.load_checkpoint(checkpoint_file)

                # Merge the results
                all_processed_sentences = previous_processed + new_processed_sentences
                print(f"Merged {len(previous_processed)} previous records with {len(new_processed_sentences)} newly processed ones")
            else:
                all_processed_sentences = new_processed_sentences
        else:
            all_processed_sentences = new_processed_sentences

        # Save the final results
        self.save_sentences(all_processed_sentences)
        print("LLM processing finished!")

    # ==================== Sentence extraction ====================

    def extract_sentences(self):
        """Step 1: extract sentences from the TREx dataset and save them as JSON."""
        if not self.input_dir:
            print("Error: input_dir is required for sentence extraction!")
            return

        print("=== Step 1: sentence extraction ===")
        print("Extracting sentences from the TREx dataset...")

        json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))

        if not json_files:
            print(f"No JSON files found in {self.input_dir}")
            return

        # Sort files for a consistent processing order
        json_files.sort()

        if self.max_files:
            json_files = json_files[:self.max_files]

        print(f"Found {len(json_files)} JSON files to process")

        all_sentences = []

        for i, file_path in enumerate(json_files):
            print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")

            documents = self.parse_large_json_file(file_path)
            print(f"  Parsed {len(documents)} documents")

            for doc in documents:
                sentences = self.extract_sentences_from_document(doc)
                all_sentences.extend(sentences)

            print(f"  Generated {len(all_sentences)} total raw sentences so far")

        print(f"Extracted {len(all_sentences)} raw sentences in total")

        # Deduplicate
        unique_sentences = []
        seen = set()
        for sentence in all_sentences:
            sentence = sentence.strip()
            if sentence and sentence not in seen and len(sentence) > 10:
                unique_sentences.append(sentence)
                seen.add(sentence)

        print(f"{len(unique_sentences)} sentences remain after deduplication")

        # Save the raw sentences to a JSON file
        sentences_data = {
            "metadata": {
                "total_sentences": len(unique_sentences),
                "extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "source_files": len(json_files),
                "max_files_limit": self.max_files
            },
            "sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
        }

        with open(self.sentences_json, 'w', encoding='utf-8') as f:
            json.dump(sentences_data, f, ensure_ascii=False, indent=2)

        print(f"Sentence extraction finished! Saved to: {self.sentences_json}")
        print(f"Total sentences: {len(unique_sentences)}")

        return unique_sentences

    def check_ollama_status(self) -> bool:
        """Check whether the Ollama service is up and responding."""
        try:
            # Check whether the ollama process is running
            result = subprocess.run(['pgrep', 'ollama'], capture_output=True, text=True)
            if result.returncode != 0:
                logger.error("The Ollama process is not running")
                return False

            # Check whether the Ollama API responds
            response = requests.get('http://localhost:11434/api/tags', timeout=5)
            if response.status_code == 200:
                logger.info("Ollama service is healthy")
                return True
            else:
                logger.error(f"Unexpected Ollama API response, status code: {response.status_code}")
                return False

        except requests.exceptions.RequestException as e:
            logger.error(f"Cannot connect to the Ollama API: {e}")
            return False
        except Exception as e:
            logger.error(f"Error while checking Ollama status: {e}")
            return False


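# Quick manual health check mirroring check_ollama_status() above (this assumes the
# default Ollama port 11434 that the method queries):
#
#   curl http://localhost:11434/api/tags
#
# which should list the locally installed models as JSON when the service is up.
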
def main():
    """Entry point."""
    import argparse

    parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with LLM processing')

    # Select the run mode
    parser.add_argument('--step', choices=['extract', 'llm', 'all'], default='llm',
                        help='Which step to run: extract=sentence extraction only, llm=LLM processing only, all=full pipeline')

    # File path arguments
    parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files')
    parser.add_argument('--sentences_json', default='extracted_sentences.json', help='JSON file for extracted sentences (will be saved in output/)')
    parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path (will be saved in output/)')

    # Processing arguments
    parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)')
    parser.add_argument('--no_llm', action='store_true', help='Disable LLM processing (basic mode)')

    args = parser.parse_args()

    # Validate arguments for the selected step
    if args.step in ['extract', 'all']:
        if not os.path.exists(args.input_dir):
            print(f"Error: Input directory {args.input_dir} does not exist!")
            return

    if args.step in ['llm', 'all']:
        if args.no_llm:
            print("Error: Cannot run LLM step with --no_llm flag!")
            return

    # Create the processor
    processor = EnhancedTRExProcessor(
        input_dir=args.input_dir,
        sentences_json=args.sentences_json,
        output_file=args.output_file,
        max_files=args.max_files,
        enable_llm_processing=not args.no_llm
    )

    # Run the selected step(s)
    if args.step == 'extract':
        print("=== Run mode: sentence extraction only ===")
        processor.extract_sentences()

    elif args.step == 'llm':
        print("=== Run mode: LLM processing only ===")
        asyncio.run(processor.process_with_llm())

    elif args.step == 'all':
        print("=== Run mode: full pipeline ===")

        # Step 1: extract sentences
        print("\n--- Starting step 1: sentence extraction ---")
        sentences = processor.extract_sentences()

        if not sentences:
            print("Sentence extraction failed, exiting")
            return

        if args.no_llm:
            print("LLM processing is disabled, stopping here")
            return

        # Step 2: LLM processing
        print("\n--- Starting step 2: LLM processing ---")
        asyncio.run(processor.process_with_llm())


if __name__ == "__main__":
    main()