#!/usr/bin/env python3
"""
Enhanced preprocessing script for the TREx dataset.

Post-processes sentences and scores their importance with a locally served LLM.
Several method names keep the original "vLLM" wording for compatibility, but the
requests below target the Ollama chat API.

Two independent steps are supported:
1. Sentence extraction: extract sentences from the TREx dataset and save them as JSON.
2. LLM processing: read the JSON file, post-process the sentences with the LLM, and score their importance.
"""
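
# Example usage (illustrative; "trex_processor.py" stands in for whatever this
# file is actually named, and the paths below are just the argparse defaults):
#
#   python trex_processor.py --step extract --input_dir dataset/TREx --max_files 5
#   python trex_processor.py --step llm --output_file trex_sentences_enhanced.txt
#   python trex_processor.py --step all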

import json
import os
import glob
from typing import List, Dict, Any, Union, Set
import re
import asyncio
import time
import logging
from datetime import datetime
import requests
from pydantic import BaseModel, Field
import aiohttp
import concurrent.futures


# Logging setup
def setup_logging():
    """Set up the logging system."""
    # Make sure the logs directory exists
    os.makedirs('logs', exist_ok=True)

    # Build a timestamped log file name
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = f'logs/trex_processor_{timestamp}.log'

    # Log format
    log_format = '%(asctime)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s'

    # Configure the root logger
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()  # also log to the console
        ]
    )

    # Get the module logger
    logger = logging.getLogger(__name__)
    logger.info(f"Logging initialized, log file: {log_file}")
    return logger


# Global logger
logger = setup_logging()


class ProcessedSentence(BaseModel):
    """Structure of a processed sentence."""
    corrected_sentence: str = Field(
        ...,
        description="The corrected sentence. Only fix grammatical errors, garbled text, and awkward wording; do not add extra polish."
    )
    importance_score: float = Field(
        ...,
        description="Importance score from 0.0 to 10.0 in steps of 0.1, judging how common and important this knowledge is in the real world.",
        ge=0.0,
        le=10.0
    )
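
# Minimal sketch of how the schema above validates a parsed LLM reply
# (values are made up for illustration; pydantic raises a ValidationError
# if importance_score falls outside 0.0-10.0):
#
#   example = ProcessedSentence(
#       corrected_sentence="Beijing is the capital of China.",
#       importance_score=8.5,
#   )
#   # example.importance_score -> 8.5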


class EnhancedTRExProcessor:
    def __init__(self, input_dir: str = None, output_file: str = None, max_files: int = None,
                 sentences_json: str = None, enable_llm_processing: bool = True):
        self.input_dir = input_dir

        # Make sure the output directory exists
        os.makedirs('output', exist_ok=True)

        # Keep all output files inside the output directory
        if output_file:
            if not output_file.startswith('output/'):
                self.output_file = os.path.join('output', output_file)
            else:
                self.output_file = output_file
        else:
            self.output_file = None

        if sentences_json:
            if not sentences_json.startswith('output/'):
                self.sentences_json = os.path.join('output', sentences_json)
            else:
                self.sentences_json = sentences_json
        else:
            self.sentences_json = "output/extracted_sentences.json"

        self.max_files = max_files
        self.enable_llm_processing = enable_llm_processing

        # Ollama API configuration
        self.model_name = "gemma3:latest"  # Ollama model name
        self.ollama_base_url = "http://localhost:11434"  # Ollama server address
        self.batch_size_per_request = 8  # sentences per API request (Ollama works better with small batches)
        self.max_concurrent_requests = 2  # maximum concurrent requests (keep low for Ollama)
        self.request_timeout = 180  # request timeout in seconds
        self.retry_attempts = 3  # number of retries

        # Statistics
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0

        logger.info(f"Processor initialized - model: {self.model_name}, batch size: {self.batch_size_per_request}, concurrency: {self.max_concurrent_requests}")

        # Extended Wikidata property mappings
        self.property_mappings = {
            # Basic relations
            "http://www.wikidata.org/prop/direct/P31": "is a",
            "http://www.wikidata.org/prop/direct/P279": "is a type of",

            # People
            "http://www.wikidata.org/prop/direct/P106": "works as",
            "http://www.wikidata.org/prop/direct/P27": "is a citizen of",
            "http://www.wikidata.org/prop/direct/P19": "was born in",
            "http://www.wikidata.org/prop/direct/P20": "died in",
            "http://www.wikidata.org/prop/direct/P569": "was born on",
            "http://www.wikidata.org/prop/direct/P570": "died on",
            "http://www.wikidata.org/prop/direct/P22": "has father",
            "http://www.wikidata.org/prop/direct/P25": "has mother",
            "http://www.wikidata.org/prop/direct/P26": "is married to",

            # Organizations
            "http://www.wikidata.org/prop/direct/P102": "is a member of",
            "http://www.wikidata.org/prop/direct/P108": "works for",
            "http://www.wikidata.org/prop/direct/P159": "has headquarters in",
            "http://www.wikidata.org/prop/direct/P112": "was founded by",
            "http://www.wikidata.org/prop/direct/P571": "was founded in",
            "http://www.wikidata.org/prop/direct/P169": "has CEO",

            # Geography
            "http://www.wikidata.org/prop/direct/P17": "is located in",
            "http://www.wikidata.org/prop/direct/P131": "is located in",
            "http://www.wikidata.org/prop/direct/P36": "has capital",
            "http://www.wikidata.org/prop/direct/P47": "borders",

            # Other relations
            "http://www.wikidata.org/prop/direct/P1142": "has ideology",
            "http://www.wikidata.org/prop/direct/P361": "is part of",
            "http://www.wikidata.org/prop/direct/P737": "was influenced by",
            "http://www.wikidata.org/prop/direct/P127": "is owned by",
            "http://www.wikidata.org/prop/direct/P155": "follows",
            "http://www.wikidata.org/prop/direct/P156": "is followed by",
            "http://www.wikidata.org/prop/direct/P138": "is named after"
        }

    def get_system_prompt(self) -> str:
        """Return the system prompt."""
        return """You are a professional text processing assistant responsible for correcting errors in sentences and evaluating the importance of knowledge.

### Sentence Correction Rules:
1. Remove Wikipedia-specific markers: such as (disambiguation), (film), (band), etc. in parentheses
2. Ensure grammatical completeness: complete subject+predicate+object structure, avoid dangling 'and is', 'or', etc.
3. Fix obvious grammatical errors: tense consistency, singular/plural consistency, correct preposition usage
4. Clean up garbled text and special characters: such as â, €, ™ and other encoding issues
5. Ensure semantic fluency: if the original sentence cannot be fixed, reorganize the language to make it coherent
6. Do not add information not present in the original text, only correct errors

### Correction Examples:
- Error: 'Argument (disambiguation) is related to philosophy, logic, and is an.'
- Corrected: 'Argument is related to philosophy and logic.'

- Error: 'Beijing is a capital city and are.'
- Corrected: 'Beijing is a capital city.'

Importance scoring criteria (0.0-10.0, in increments of 0.1):

0.0 points - Completely incorrect or meaningless information
Examples: 'Apple is a metal', 'The sun rises from the west', '1+1=3'

0.5 points - Almost worthless information
Examples: 'Color of a fictional character's socks', 'Third line of dialogue from a game NPC', 'What someone had for breakfast yesterday'

1.0 points - Extremely rare, non-practical knowledge
Examples: 'Pet name of a minor novel character', 'Content of the 15th line in movie end credits', 'Nickname of website user ID 123456'

1.5 points - Very niche detailed information
Examples: 'Outfit of a passerby at minute 37 in a movie', 'Duration of background music in a game's hidden level', 'Content of the 3rd dialogue box on page 200 of a manga'

2.0 points - Details in niche professional fields
Examples: 'Color change of rare minerals at specific temperatures', 'Length of an insect's third antenna', 'Molecular formula of chemical reaction byproducts'

2.5 points - Technical details only professionals care about
Examples: 'Release date of specific software library version', 'Time complexity coefficient of an algorithm', 'Thermal expansion coefficient of a material'

3.0 points - Professional knowledge in specific fields
Examples: 'Programming language syntax features', 'Gene sequence of a virus', 'Official system of ancient dynasties'

3.5 points - Professional information with some value
Examples: 'Specific system of historical dynasty', 'Mechanism of action of a drug', 'Development time of a technical standard'

4.0 points - Meaningful knowledge known by few
Examples: 'Unique cultural traditions of a country', 'Important discoveries by a scientist', 'Detailed process of historical events'

4.5 points - Knowledge of interest to some groups
Examples: 'Author's creative background', 'Characteristics of an art movement', 'Detailed rules of a sport'

5.0 points - General knowledge of moderate importance
Examples: 'Famous attractions in cities', 'Development history of a company', 'Living habits of animals'

5.5 points - Fairly useful common sense
Examples: 'Plant growth environment', 'Healthy eating common sense', 'Basic first aid knowledge'

6.0 points - Knowledge most educated people should know
Examples: 'Shakespeare's representative works', 'Basic geometric theorems', 'Major world currencies'

6.5 points - Important cultural or scientific common sense
Examples: 'Basic structure of DNA', 'Newton's three laws', 'Major world religions'

7.0 points - Important foundational knowledge
Examples: 'Time period of World War II', 'Functions of major human organs', 'Basic mathematical operation rules'

7.5 points - Very important common sense
Examples: 'Light speed is the fastest in the universe', 'Earth is round', 'Basic principles of blood circulation'

8.0 points - Core knowledge in basic education
Examples: 'Earth orbits the sun', 'Principle of seasonal formation', 'Basic grammar rules'

8.5 points - Important knowledge everyone should master
Examples: 'Chemical formula of water H2O', 'Basic safety common sense', 'Simple mathematical calculations'

9.0 points - Extremely important basic concepts
Examples: 'Humans need oxygen to survive', 'Fire is hot', 'Basic directional concepts'

9.5 points - Core knowledge everyone must know
Examples: 'A day has 24 hours', 'A year has 12 months', 'Basic number concepts'

10.0 points - Most basic and important common sense
Examples: 'Humans need food and water to survive', 'The sky is blue', 'Stones are heavier than feathers'

When scoring, please consider:
1. Popularity of knowledge - How many people know this knowledge
2. Practical value - How useful this knowledge is in daily life
3. Educational importance - The position of this knowledge in the education system
4. Cultural significance - The importance of this knowledge for understanding the world

Please respond with valid JSON in the following format:
{
    "corrected_sentence": "corrected sentence here",
    "importance_score": evaluation score
}"""

    async def process_batch_with_vllm_api(self, sentences: List[str]) -> List[Dict[str, Any]]:
        """Process a batch of sentences through the LLM API (Ollama backend)."""
        processed_sentences = []

        async with aiohttp.ClientSession() as session:
            # Create concurrent tasks
            semaphore = asyncio.Semaphore(self.max_concurrent_requests)
            tasks = []

            # Split the sentences into small per-request batches
            for i in range(0, len(sentences), self.batch_size_per_request):
                batch_sentences = sentences[i:i + self.batch_size_per_request]
                task = self.process_single_batch_request(session, semaphore, batch_sentences, i)
                tasks.append(task)

            # Wait for all tasks to finish
            batch_results = await asyncio.gather(*tasks, return_exceptions=True)

            # Collect the results
            for result in batch_results:
                if isinstance(result, Exception):
                    logger.error(f"Batch processing error: {result}")
                    continue
                if result:
                    processed_sentences.extend(result)

        return processed_sentences
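
    # Minimal sketch of the fan-out above (hypothetical numbers): with
    # batch_size_per_request = 8, a list of 20 sentences becomes requests of
    # 8, 8, and 4 sentences, and max_concurrent_requests = 2 means at most
    # two of them are in flight at once.
    #
    #   >>> [len(list(range(20))[i:i + 8]) for i in range(0, 20, 8)]
    #   [8, 8, 4]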

    async def process_single_batch_request(self, session: aiohttp.ClientSession, semaphore: asyncio.Semaphore,
                                           sentences: List[str], batch_index: int) -> List[Dict[str, Any]]:
        """Send a single batched API request."""
        async with semaphore:
            for attempt in range(self.retry_attempts):
                try:
                    # Create a separate message for each sentence
                    messages = []
                    for sentence in sentences:
                        messages.append({
                            "role": "user",
                            "content": f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"
                        })

                    # Build the Ollama request payload
                    request_data = {
                        "model": self.model_name,
                        "messages": [
                            {"role": "system", "content": self.get_system_prompt()}
                        ] + messages,
                        "stream": False,
                        "options": {
                            "temperature": 0.2,
                            "num_predict": 500 * len(sentences)  # allow enough tokens per sentence
                        },
                        "format": "json"  # Ollama's JSON output mode
                    }

                    # Send the request to Ollama
                    async with session.post(
                        f'{self.ollama_base_url}/api/chat',
                        json=request_data,
                        timeout=aiohttp.ClientTimeout(total=self.request_timeout)
                    ) as response:

                        if response.status == 200:
                            result = await response.json()
                            return self.parse_ollama_response(result, sentences, batch_index)
                        else:
                            error_text = await response.text()
                            logger.warning(f"API request failed (batch {batch_index}, attempt {attempt + 1}/{self.retry_attempts}): {response.status} - {error_text}")

                            if attempt == self.retry_attempts - 1:  # last attempt
                                logger.error(f"Batch {batch_index} still failed after {self.retry_attempts} attempts")
                                self.failed_requests += len(sentences)
                                return self.create_default_responses(sentences)
                            else:
                                # Wait before retrying
                                await asyncio.sleep(2 ** attempt)  # exponential backoff
                                continue

                except asyncio.TimeoutError:
                    logger.warning(f"Batch {batch_index} request timed out (attempt {attempt + 1}/{self.retry_attempts})")
                    if attempt == self.retry_attempts - 1:
                        logger.error(f"Batch {batch_index} still timed out after {self.retry_attempts} attempts")
                        self.failed_requests += len(sentences)
                        return self.create_default_responses(sentences)
                    else:
                        await asyncio.sleep(2 ** attempt)
                        continue

                except Exception as e:
                    logger.warning(f"Error while processing batch {batch_index} (attempt {attempt + 1}/{self.retry_attempts}): {e}")
                    if attempt == self.retry_attempts - 1:
                        logger.error(f"Batch {batch_index} still failed after {self.retry_attempts} attempts")
                        self.failed_requests += len(sentences)
                        return self.create_default_responses(sentences)
                    else:
                        await asyncio.sleep(2 ** attempt)
                        continue

            # All retries failed
            return self.create_default_responses(sentences)

    def parse_ollama_response(self, response: Dict[str, Any], original_sentences: List[str], batch_index: int) -> List[Dict[str, Any]]:
        """Parse an Ollama chat response."""
        processed_sentences = []

        try:
            # Ollama response format
            message = response.get('message', {})
            content = message.get('content', '')

            if not content:
                logger.warning(f"Batch {batch_index} returned no content")
                return self.create_default_responses(original_sentences)

            # Try to parse the JSON response
            try:
                # Case 1: a single JSON object was returned
                if content.strip().startswith('{') and content.strip().endswith('}'):
                    response_data = json.loads(content)
                    processed_sentence = ProcessedSentence(
                        corrected_sentence=response_data.get('corrected_sentence', original_sentences[0] if original_sentences else ""),
                        importance_score=float(response_data.get('importance_score', 5.0))
                    )

                    processed_sentences.append({
                        "original_sentence": original_sentences[0] if original_sentences else "",
                        "corrected_sentence": processed_sentence.corrected_sentence,
                        "importance_score": processed_sentence.importance_score
                    })
                    self.successful_requests += 1

                    # If several sentences were sent but only one result came back,
                    # fall back to default responses for the remaining sentences
                    for i in range(1, len(original_sentences)):
                        processed_sentences.append({
                            "original_sentence": original_sentences[i],
                            "corrected_sentence": original_sentences[i],
                            "importance_score": 5.0
                        })
                        self.failed_requests += 1

                else:
                    # Case 2: try to parse multiple JSON objects, one per line
                    json_objects = []
                    for line in content.split('\n'):
                        line = line.strip()
                        if line.startswith('{') and line.endswith('}'):
                            try:
                                json_objects.append(json.loads(line))
                            except json.JSONDecodeError:
                                continue

                    if json_objects:
                        for i, (sentence, json_obj) in enumerate(zip(original_sentences, json_objects)):
                            try:
                                processed_sentence = ProcessedSentence(
                                    corrected_sentence=json_obj.get('corrected_sentence', sentence),
                                    importance_score=float(json_obj.get('importance_score', 5.0))
                                )

                                processed_sentences.append({
                                    "original_sentence": sentence,
                                    "corrected_sentence": processed_sentence.corrected_sentence,
                                    "importance_score": processed_sentence.importance_score
                                })
                                self.successful_requests += 1
                            except Exception as e:
                                logger.warning(f"Failed to parse JSON object: {e}")
                                processed_sentences.append({
                                    "original_sentence": sentence,
                                    "corrected_sentence": sentence,
                                    "importance_score": 5.0
                                })
                                self.failed_requests += 1

                        # Create default responses for any remaining sentences
                        for i in range(len(json_objects), len(original_sentences)):
                            processed_sentences.append({
                                "original_sentence": original_sentences[i],
                                "corrected_sentence": original_sentences[i],
                                "importance_score": 5.0
                            })
                            self.failed_requests += 1
                    else:
                        logger.warning(f"Batch {batch_index} returned an unparseable JSON response: {content}")
                        return self.create_default_responses(original_sentences)

            except (json.JSONDecodeError, ValueError) as e:
                logger.warning(f"Batch {batch_index} failed to parse response JSON: {e}")
                logger.warning(f"Raw content: {content}")
                return self.create_default_responses(original_sentences)

        except Exception as e:
            logger.error(f"Error while parsing response for batch {batch_index}: {e}")
            return self.create_default_responses(original_sentences)

        return processed_sentences
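
    # For reference, a non-streaming Ollama /api/chat reply has roughly the
    # shape sketched below (abridged, values hypothetical); the parser above
    # only relies on message.content:
    #
    #   {
    #       "model": "gemma3:latest",
    #       "message": {"role": "assistant",
    #                   "content": "{\"corrected_sentence\": \"...\", \"importance_score\": 7.5}"},
    #       "done": true
    #   }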

    def create_default_responses(self, sentences: List[str]) -> List[Dict[str, Any]]:
        """Create default responses for failed requests."""
        default_responses = []
        for sentence in sentences:
            default_responses.append({
                "original_sentence": sentence,
                "corrected_sentence": sentence,
                "importance_score": 5.0
            })
        return default_responses

    async def process_sentences_with_vllm_api(self, sentences: List[str]) -> List[Dict[str, Any]]:
        """Process sentences through the Ollama API."""
        logger.info(f"Starting Ollama API processing of {len(sentences)} sentences...")
        print(f"Starting Ollama API processing of {len(sentences)} sentences...")

        start_time = time.time()
        total_sentences = len(sentences)
        total_processed_count = 0

        # Check that the Ollama service is up
        if not self.check_ollama_status():
            logger.error("Ollama service is not healthy, cannot continue")
            print("Error: Ollama service is not healthy, please check that it is running")
            return []

        # Process in large batches (used for checkpointing)
        large_batch_size = 1000  # save a checkpoint every 1000 sentences
        all_processed_sentences = []

        for large_batch_start in range(0, total_sentences, large_batch_size):
            large_batch_end = min(large_batch_start + large_batch_size, total_sentences)
            large_batch_sentences = sentences[large_batch_start:large_batch_end]
            large_batch_number = large_batch_start // large_batch_size + 1

            logger.info(f"=== Processing large batch {large_batch_number} ({large_batch_start + 1}-{large_batch_end}/{total_sentences}) ===")
            print(f"\n=== Processing large batch {large_batch_number} ({large_batch_start + 1}-{large_batch_end}/{total_sentences}) ===")

            large_batch_start_time = time.time()

            # Process the current large batch
            batch_processed = await self.process_batch_with_vllm_api(large_batch_sentences)
            all_processed_sentences.extend(batch_processed)
            total_processed_count += len(batch_processed)

            # Save a checkpoint for the current large batch
            checkpoint_filename = self.save_batch_checkpoint(batch_processed, large_batch_number, total_processed_count)

            # Report progress
            large_batch_time = time.time() - large_batch_start_time
            elapsed_time = time.time() - start_time

            logger.info(f"Large batch {large_batch_number} finished!")
            logger.info(f" - This batch: {len(batch_processed)} succeeded, took {large_batch_time/60:.1f} minutes")
            logger.info(f" - Overall progress: {total_processed_count}/{total_sentences} ({total_processed_count/total_sentences*100:.1f}%)")
            logger.info(f" - Elapsed time: {elapsed_time/60:.1f} minutes")
            logger.info(f" - Batch checkpoint saved: {checkpoint_filename}")

            print(f"Large batch {large_batch_number} finished!")
            print(f" - This batch: {len(batch_processed)} succeeded, took {large_batch_time/60:.1f} minutes")
            print(f" - Overall progress: {total_processed_count}/{total_sentences} ({total_processed_count/total_sentences*100:.1f}%)")
            print(f" - Elapsed time: {elapsed_time/60:.1f} minutes")
            print(f" - Batch checkpoint saved: {checkpoint_filename}")

            if large_batch_end < total_sentences:
                remaining_sentences = total_sentences - total_processed_count
                avg_time_per_sentence = elapsed_time / total_processed_count
                estimated_remaining_time = avg_time_per_sentence * remaining_sentences
                logger.info(f" - Estimated time remaining: {estimated_remaining_time/60:.1f} minutes")
                print(f" - Estimated time remaining: {estimated_remaining_time/60:.1f} minutes")

        # Final statistics
        total_time = time.time() - start_time
        logger.info("=== All processing finished! ===")
        logger.info(f" - Total succeeded: {self.successful_requests}")
        logger.info(f" - Total failed: {self.failed_requests}")
        logger.info(f" - Total time: {total_time/60:.1f} minutes")
        logger.info(f" - Average speed: {total_processed_count/total_time:.2f} sentences/second")

        print("\n=== All processing finished! ===")
        print(f" - Total succeeded: {self.successful_requests}")
        print(f" - Total failed: {self.failed_requests}")
        print(f" - Total time: {total_time/60:.1f} minutes")
        print(f" - Average speed: {total_processed_count/total_time:.2f} sentences/second")

        return all_processed_sentences

    def check_ollama_status(self) -> bool:
        """Check whether the Ollama service is up and the target model is available."""
        try:
            # Check that the Ollama API responds
            response = requests.get(f'{self.ollama_base_url}/api/tags', timeout=10)

            if response.status_code == 200:
                models = response.json()
                model_names = [model.get('name', 'unknown') for model in models.get('models', [])]
                logger.info(f"Ollama service is healthy, available models: {model_names}")

                # Check whether the target model is available
                if self.model_name in model_names:
                    logger.info(f"Target model {self.model_name} is available")
                    return True
                else:
                    logger.warning(f"Target model {self.model_name} is not in the list of available models: {model_names}")
                    logger.info("Trying to pull the model...")
                    # Try to pull the model
                    try:
                        pull_response = requests.post(
                            f'{self.ollama_base_url}/api/pull',
                            json={"name": self.model_name},
                            timeout=300  # 5-minute timeout
                        )
                        if pull_response.status_code == 200:
                            logger.info(f"Successfully pulled model {self.model_name}")
                            return True
                        else:
                            logger.error(f"Failed to pull model: {pull_response.status_code}")
                            return False
                    except Exception as e:
                        logger.error(f"Error while pulling model: {e}")
                        return False
            else:
                logger.error(f"Unexpected Ollama API response, status code: {response.status_code}")
                return False

        except requests.exceptions.RequestException as e:
            logger.error(f"Cannot connect to the Ollama API: {e}")
            return False
        except Exception as e:
            logger.error(f"Error while checking Ollama status: {e}")
            return False
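
    # For reference, /api/tags returns JSON roughly shaped like the sketch
    # below (abridged, values hypothetical); only models[*].name is used here:
    #
    #   {"models": [{"name": "gemma3:latest", "size": 3338801804}]}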

    def clean_text(self, text: str) -> str:
        """Clean text and normalize special characters."""
        if not text:
            return ""

        # Normalize common Unicode punctuation
        text = text.replace("\u2013", "-")   # en dash
        text = text.replace("\u2014", "-")   # em dash
        text = text.replace("\u2019", "'")   # right single quotation mark
        text = text.replace("\u2018", "'")   # left single quotation mark
        text = text.replace("\u201c", '"')   # left double quotation mark
        text = text.replace("\u201d", '"')   # right double quotation mark

        # Handle possible escape sequences
        try:
            text = text.encode('utf-8').decode('utf-8')
        except (UnicodeEncodeError, UnicodeDecodeError):
            pass

        # Collapse redundant whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Strip surrounding quotes
        text = text.strip('"\'')

        return text

    def parse_large_json_file(self, file_path: str) -> List[Dict]:
        """Parse a large JSON file, tolerating formatting quirks."""
        documents = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read().strip()

            # Try different parsing strategies
            if content.startswith('[') and content.endswith(']'):
                # Standard JSON array
                documents = json.loads(content)
            else:
                # Possibly a stream of concatenated JSON objects;
                # try splitting on the }{ boundary
                if '}{"' in content:
                    json_strings = content.split('}{')
                    json_strings[0] += '}'  # first object
                    json_strings[-1] = '{' + json_strings[-1]  # last object

                    for i in range(1, len(json_strings) - 1):
                        json_strings[i] = '{' + json_strings[i] + '}'

                    for json_str in json_strings:
                        try:
                            doc = json.loads(json_str)
                            documents.append(doc)
                        except json.JSONDecodeError:
                            continue
                else:
                    # Fall back to parsing the whole content as a single JSON object
                    try:
                        documents = [json.loads(content)]
                    except json.JSONDecodeError:
                        pass

        except Exception as e:
            print(f"Error parsing {file_path}: {e}")

        return documents

    def extract_sentences_from_document(self, doc: Dict[str, Any]) -> List[str]:
        """Extract sentences from a single document."""
        sentences = []

        title = self.clean_text(doc.get('title', ''))
        text = self.clean_text(doc.get('text', ''))
        entities = doc.get('entities', [])
        triples = doc.get('triples', [])

        # Convert the explicit triples
        for triple in triples:
            sentence = self.triple_to_sentence(triple)
            if sentence:
                sentences.append(sentence)

        # Generate basic sentences from the title and entities
        # when the triples did not yield enough sentences
        if title and text and len(sentences) < 5:
            # Build sentences from the title and entities
            entity_names = []
            for entity in entities[:15]:
                entity_name = self.clean_text(entity.get('surfaceform', ''))
                if entity_name and len(entity_name) > 2:
                    entity_names.append(entity_name)

            # Generate a simple descriptive sentence
            if entity_names:
                important_entities = []
                title_lower = title.lower()
                for entity in entity_names:
                    if (entity.lower() != title_lower and
                        entity not in important_entities and
                        not any(t.lower() in entity.lower() for t in title_lower.split()[:2])):
                        important_entities.append(entity)
                        if len(important_entities) >= 6:
                            break

                if important_entities and len(sentences) < 3:
                    entities_str = ', '.join(important_entities[:3])
                    sentences.append(f"{title} is related to {entities_str}.")

        return sentences

    def triple_to_sentence(self, triple: Dict[str, Any]) -> str:
        """Convert a triple into a natural-language sentence."""
        try:
            subject = triple.get('subject', {})
            predicate = triple.get('predicate', {})
            obj = triple.get('object', {})

            subject_name = self.clean_text(subject.get('surfaceform', ''))
            object_name = self.clean_text(obj.get('surfaceform', ''))
            predicate_uri = predicate.get('uri', '')

            # Require a valid subject and object
            if not subject_name or not object_name:
                return ""

            # Skip subjects or objects that are too short to be meaningful
            if len(subject_name) <= 2 or len(object_name) <= 2:
                return ""

            # Look up the relation text
            relation_text = self.property_mappings.get(predicate_uri, "is related to")

            # Avoid sentences whose subject and object are identical
            if subject_name.lower() == object_name.lower():
                return ""

            return f"{subject_name} {relation_text} {object_name}."

        except Exception as e:
            print(f"Error converting triple to sentence: {e}")
            return ""
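
    # Minimal sketch of the conversion above on a hypothetical TREx-style
    # triple (surface forms invented for illustration):
    #
    #   triple = {
    #       "subject": {"surfaceform": "Barack Obama"},
    #       "predicate": {"uri": "http://www.wikidata.org/prop/direct/P19"},
    #       "object": {"surfaceform": "Honolulu"},
    #   }
    #   # triple_to_sentence(triple) -> "Barack Obama was born in Honolulu."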

    def save_batch_checkpoint(self, processed_sentences: List[Dict[str, Any]], batch_number: int, total_processed_count: int) -> str:
        """Save a checkpoint file for the current batch."""
        # Build the checkpoint file name inside the output directory
        base_name = os.path.splitext(os.path.basename(self.output_file))[0]
        checkpoint_filename = os.path.join('output', f"{base_name}_batch_{batch_number}.json")

        # Write the checkpoint
        with open(checkpoint_filename, 'w', encoding='utf-8') as f:
            json.dump({
                "metadata": {
                    "batch_number": batch_number,
                    "batch_size": len(processed_sentences),
                    "total_processed_count": total_processed_count,
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
                },
                "sentences": processed_sentences
            }, f, ensure_ascii=False, indent=2)

        return checkpoint_filename

    async def process_files(self) -> List[str]:
        """Process all input files and extract raw sentences."""
        json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))

        if not json_files:
            print(f"No JSON files found in {self.input_dir}")
            return []

        # Sort the files for a consistent processing order
        json_files.sort()

        if self.max_files:
            json_files = json_files[:self.max_files]

        print(f"Found {len(json_files)} JSON files to process")

        all_sentences = []

        for i, file_path in enumerate(json_files):
            print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")

            documents = self.parse_large_json_file(file_path)
            print(f"  Parsed {len(documents)} documents")

            for doc in documents:
                sentences = self.extract_sentences_from_document(doc)
                all_sentences.extend(sentences)

            print(f"  Generated {len(all_sentences)} total raw sentences so far")

        print(f"Extracted {len(all_sentences)} raw sentences in total")

        # Deduplicate
        unique_sentences = []
        seen = set()
        for sentence in all_sentences:
            sentence = sentence.strip()
            if sentence and sentence not in seen and len(sentence) > 10:
                unique_sentences.append(sentence)
                seen.add(sentence)

        print(f"{len(unique_sentences)} sentences remain after deduplication")

        # Save the raw sentences to a JSON file
        sentences_data = {
            "metadata": {
                "total_sentences": len(unique_sentences),
                "extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "source_files": len(json_files),
                "max_files_limit": self.max_files
            },
            "sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
        }

        with open(self.sentences_json, 'w', encoding='utf-8') as f:
            json.dump(sentences_data, f, ensure_ascii=False, indent=2)

        print(f"Sentence extraction finished! Saved to: {self.sentences_json}")
        print(f"Total sentences: {len(unique_sentences)}")

        return unique_sentences

    def save_sentences(self, processed_sentences: List[Dict[str, Any]]):
        """Save the processed sentences to the output files."""
        # Make sure the output directory exists
        os.makedirs('output', exist_ok=True)

        # Save as JSON with the full information
        json_output_file = self.output_file.replace('.txt', '.json')
        with open(json_output_file, 'w', encoding='utf-8') as f:
            json.dump(processed_sentences, f, ensure_ascii=False, indent=2)

        # Save as plain text (corrected sentences only)
        with open(self.output_file, 'w', encoding='utf-8') as f:
            for item in processed_sentences:
                f.write(item['corrected_sentence'] + '\n')

        # Write a file sorted by importance
        importance_sorted = sorted(processed_sentences, key=lambda x: x['importance_score'], reverse=True)
        importance_file = self.output_file.replace('.txt', '_sorted_by_importance.txt')
        with open(importance_file, 'w', encoding='utf-8') as f:
            for item in importance_sorted:
                f.write(f"[{item['importance_score']:.1f}] {item['corrected_sentence']}\n")

        print(f"Saved {len(processed_sentences)} processed sentences:")
        print(f"  - JSON format: {json_output_file}")
        print(f"  - Text format: {self.output_file}")
        print(f"  - Sorted by importance: {importance_file}")

        # Statistics (guarded so an empty result list does not crash max()/min())
        scores = [item['importance_score'] for item in processed_sentences]
        if scores:
            avg_score = sum(scores) / len(scores)
            print(f"  - Average importance score: {avg_score:.2f}")
            print(f"  - Highest score: {max(scores):.1f}")
            print(f"  - Lowest score: {min(scores):.1f}")

    async def run(self):
        """Run the basic pipeline: extract sentences and save them."""
        print("Starting enhanced TREx to sentences conversion...")
        # process_files() returns raw sentence strings; wrap them in the dict
        # format that save_sentences() expects, with a neutral default score.
        raw_sentences = await self.process_files()
        processed = [{"original_sentence": s, "corrected_sentence": s, "importance_score": 5.0}
                     for s in raw_sentences]
        self.save_sentences(processed)
        print("Enhanced conversion completed!")

    def find_latest_checkpoint(self) -> Union[tuple, None]:
        """Find the most recent checkpoint file."""
        base_name = os.path.splitext(os.path.basename(self.output_file))[0]
        pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
        checkpoint_files = glob.glob(pattern)

        if not checkpoint_files:
            return None

        # Sort by checkpoint number and keep the latest
        latest_file = None
        latest_count = 0

        for file in checkpoint_files:
            try:
                # Extract the number from the file name
                match = re.search(r'checkpoint_(\d+)\.json$', file)
                if match:
                    count = int(match.group(1))
                    if count > latest_count:
                        latest_count = count
                        latest_file = file
            except (ValueError, AttributeError):
                continue

        if latest_file:
            return latest_file, latest_count
        else:
            return None

    def load_checkpoint(self, checkpoint_file: str) -> List[Dict[str, Any]]:
        """Load processed sentences from a checkpoint file."""
        try:
            with open(checkpoint_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if 'sentences' in data:
                return data['sentences']
            else:
                # Old-style checkpoint file
                return data
        except Exception as e:
            print(f"Failed to load checkpoint file: {e}")
            return []

    def get_processed_sentences_from_checkpoints(self) -> Set[str]:
        """Collect the set of already-processed sentences from checkpoint files."""
        if not self.output_file:
            return set()

        processed_sentences = set()
        base_name = os.path.splitext(os.path.basename(self.output_file))[0]

        # First look for new-style batch files
        batch_pattern = os.path.join('output', f"{base_name}_batch_*.json")
        batch_files = glob.glob(batch_pattern)

        if batch_files:
            print(f"Found {len(batch_files)} batch checkpoint files")
            batch_files.sort()  # process them in order

            for batch_file in batch_files:
                try:
                    with open(batch_file, 'r', encoding='utf-8') as f:
                        data = json.load(f)

                    sentences_data = data.get('sentences', [])
                    for item in sentences_data:
                        original_sentence = item.get('original_sentence', '')
                        if original_sentence:
                            processed_sentences.add(original_sentence)

                    batch_number = data.get('metadata', {}).get('batch_number', 'unknown')
                    print(f"  - Loaded {len(sentences_data)} sentences from batch {batch_number}")

                except Exception as e:
                    print(f"Failed to read batch file {batch_file}: {e}")
                    continue

            print(f"Loaded {len(processed_sentences)} already-processed sentences from batch files")
            logger.info(f"Loaded {len(processed_sentences)} already-processed sentences from batch files")
            return processed_sentences

        # If there are no batch files, fall back to old-style checkpoint files
        old_pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
        checkpoint_files = glob.glob(old_pattern)

        if not checkpoint_files:
            print("No checkpoint files found, starting from scratch")
            return set()

        # Find the latest checkpoint file
        latest_file = None
        latest_count = 0

        for file in checkpoint_files:
            try:
                match = re.search(r'checkpoint_(\d+)\.json$', file)
                if match:
                    count = int(match.group(1))
                    if count > latest_count:
                        latest_count = count
                        latest_file = file
            except (ValueError, AttributeError):
                continue

        if latest_file:
            print(f"Found old-style checkpoint: {latest_file} (contains {latest_count} records)")
            logger.info(f"Found old-style checkpoint: {latest_file} (contains {latest_count} records)")
            try:
                with open(latest_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                sentences_data = data.get('sentences', [])
                for item in sentences_data:
                    original_sentence = item.get('original_sentence', '')
                    if original_sentence:
                        processed_sentences.add(original_sentence)

                print(f"Loaded {len(processed_sentences)} already-processed sentences from the old-style checkpoint")
                logger.info(f"Loaded {len(processed_sentences)} already-processed sentences from the old-style checkpoint")

            except Exception as e:
                print(f"Failed to read checkpoint file: {e}")
                return set()

        return processed_sentences

    async def process_with_llm(self):
        """Step 2 (compatibility wrapper): read sentences from JSON and run LLM processing."""
        await self.process_with_vllm_api()

    async def process_with_vllm_api(self):
        """Step 2: read sentences from the JSON file and run LLM processing."""
        if not self.enable_llm_processing:
            print("Error: LLM processing is disabled!")
            return

        if not self.output_file:
            print("Error: output_file is required for LLM processing!")
            return

        print("=== Step 2: LLM processing ===")

        # Read the sentences JSON file
        if not os.path.exists(self.sentences_json):
            print(f"Error: Sentences file {self.sentences_json} not found!")
            print("Please run step 1 (sentence extraction) first")
            return

        print(f"Reading sentences file: {self.sentences_json}")

        try:
            with open(self.sentences_json, 'r', encoding='utf-8') as f:
                data = json.load(f)

            all_sentences = [item["sentence"] for item in data.get("sentences", [])]
            print(f"Read {len(all_sentences)} sentences from the file")

        except Exception as e:
            print(f"Failed to read the sentences file: {e}")
            return

        # Gather the sentences that were already processed
        processed_sentences_set = self.get_processed_sentences_from_checkpoints()

        # Keep only the sentences that still need processing
        unprocessed_sentences = []
        for sentence in all_sentences:
            if sentence not in processed_sentences_set:
                unprocessed_sentences.append(sentence)

        print(f"Sentences to process: {len(unprocessed_sentences)} (skipping {len(processed_sentences_set)} already processed)")
        logger.info(f"Sentences to process: {len(unprocessed_sentences)} (skipping {len(processed_sentences_set)} already processed)")

        if not unprocessed_sentences:
            print("All sentences have already been processed!")

            # If checkpoints exist, build the final file directly from the latest one
            if processed_sentences_set:
                latest_checkpoint = self.find_latest_checkpoint()
                if latest_checkpoint:
                    checkpoint_file, _ = latest_checkpoint
                    processed_data = self.load_checkpoint(checkpoint_file)
                    self.save_sentences(processed_data)
                    print("Final output file generated from the checkpoint")
            return

        # Process the remaining sentences
        print("Starting LLM processing...")

        # Process the new sentences (results are also saved in batch checkpoints)
        await self.process_sentences_with_vllm_api(unprocessed_sentences)

        # After processing, merge all batch checkpoints into the final file
        print("Merging all batch checkpoints into the final file...")
        all_processed_sentences = self.merge_all_batch_checkpoints()

        if all_processed_sentences:
            # Save the final result
            self.save_sentences(all_processed_sentences)
            print("LLM processing finished!")
        else:
            print("Warning: no processing results were found")

    def merge_all_batch_checkpoints(self) -> List[Dict[str, Any]]:
        """Merge all batch checkpoint files."""
        if not self.output_file:
            return []

        base_name = os.path.splitext(os.path.basename(self.output_file))[0]

        # Find all batch checkpoint files
        batch_pattern = os.path.join('output', f"{base_name}_batch_*.json")
        batch_files = glob.glob(batch_pattern)

        if not batch_files:
            # If there are no batch files, fall back to old-style checkpoint files
            old_pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
            old_files = glob.glob(old_pattern)
            if old_files:
                print("Found old-style checkpoint files, trying to read the latest one...")
                latest_checkpoint = self.find_latest_checkpoint()
                if latest_checkpoint:
                    checkpoint_file, _ = latest_checkpoint
                    return self.load_checkpoint(checkpoint_file)
            return []

        print(f"Found {len(batch_files)} batch checkpoint files")

        all_sentences = []
        batch_files.sort()  # process them in order

        for batch_file in batch_files:
            try:
                with open(batch_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                batch_sentences = data.get('sentences', [])
                all_sentences.extend(batch_sentences)

                batch_number = data.get('metadata', {}).get('batch_number', 'unknown')
                batch_size = len(batch_sentences)
                print(f"  - Batch {batch_number}: {batch_size} sentences")

            except Exception as e:
                print(f"Failed to read batch file {batch_file}: {e}")
                continue

        print(f"Merged {len(all_sentences)} sentences in total")
        return all_sentences

    def extract_sentences(self):
        """Step 1: extract sentences from the TREx dataset and save them as JSON."""
        if not self.input_dir:
            print("Error: input_dir is required for sentence extraction!")
            return

        print("=== Step 1: sentence extraction ===")
        print("Extracting sentences from the TREx dataset...")

        json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))

        if not json_files:
            print(f"No JSON files found in {self.input_dir}")
            return

        # Sort the files for a consistent processing order
        json_files.sort()

        if self.max_files:
            json_files = json_files[:self.max_files]

        print(f"Found {len(json_files)} JSON files to process")

        all_sentences = []

        for i, file_path in enumerate(json_files):
            print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")

            documents = self.parse_large_json_file(file_path)
            print(f"  Parsed {len(documents)} documents")

            for doc in documents:
                sentences = self.extract_sentences_from_document(doc)
                all_sentences.extend(sentences)

            print(f"  Generated {len(all_sentences)} total raw sentences so far")

        print(f"Extracted {len(all_sentences)} raw sentences in total")

        # Deduplicate
        unique_sentences = []
        seen = set()
        for sentence in all_sentences:
            sentence = sentence.strip()
            if sentence and sentence not in seen and len(sentence) > 10:
                unique_sentences.append(sentence)
                seen.add(sentence)

        print(f"{len(unique_sentences)} sentences remain after deduplication")

        # Save the raw sentences to a JSON file
        sentences_data = {
            "metadata": {
                "total_sentences": len(unique_sentences),
                "extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "source_files": len(json_files),
                "max_files_limit": self.max_files
            },
            "sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
        }

        with open(self.sentences_json, 'w', encoding='utf-8') as f:
            json.dump(sentences_data, f, ensure_ascii=False, indent=2)

        print(f"Sentence extraction finished! Saved to: {self.sentences_json}")
        print(f"Total sentences: {len(unique_sentences)}")

        return unique_sentences


def main():
    """Entry point."""
    import argparse

    parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with vLLM processing')

    # Choose the run mode
    parser.add_argument('--step', choices=['extract', 'llm', 'all'], default='llm',
                        help='Which step to run: extract=sentence extraction only, llm=LLM processing only, all=full pipeline')

    # File path arguments
    parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files')
    parser.add_argument('--sentences_json', default='extracted_sentences.json', help='JSON file for extracted sentences (will be saved in output/)')
    parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path (will be saved in output/)')

    # Processing arguments
    parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)')
    parser.add_argument('--no_llm', action='store_true', help='Disable vLLM processing (basic mode)')

    args = parser.parse_args()

    # Validate arguments for the selected step
    if args.step in ['extract', 'all']:
        if not os.path.exists(args.input_dir):
            print(f"Error: Input directory {args.input_dir} does not exist!")
            return

    if args.step in ['llm', 'all']:
        if args.no_llm:
            print("Error: Cannot run vLLM step with --no_llm flag!")
            return

    # Create the processor
    processor = EnhancedTRExProcessor(
        input_dir=args.input_dir,
        sentences_json=args.sentences_json,
        output_file=args.output_file,
        max_files=args.max_files,
        enable_llm_processing=not args.no_llm
    )

    # Run the selected step
    if args.step == 'extract':
        print("=== Mode: sentence extraction only ===")
        processor.extract_sentences()

    elif args.step == 'llm':
        print("=== Mode: LLM processing only ===")
        asyncio.run(processor.process_with_vllm_api())

    elif args.step == 'all':
        print("=== Mode: full pipeline ===")

        # Step 1: extract sentences
        print("\n--- Starting step 1: sentence extraction ---")
        sentences = processor.extract_sentences()

        if not sentences:
            print("Sentence extraction failed, exiting")
            return

        if args.no_llm:
            print("LLM processing is disabled, stopping here")
            return

        # Step 2: LLM processing
        print("\n--- Starting step 2: LLM processing ---")
        asyncio.run(processor.process_with_vllm_api())
if __name__ == "__main__":
|
||
main() |