Minimind/preprocessing/trex_to_sentences_simple.py

#!/usr/bin/env python3
"""
TREx dataset enhanced preprocessing script.
Uses a local Ollama chat API for sentence post-processing and importance scoring
(several method names and messages still say "vLLM", but requests are sent to Ollama).
Supports two independent steps:
1. Sentence extraction: extract sentences from the TREx dataset and save them as JSON.
2. LLM processing: read the JSON file and run LLM post-processing and importance scoring.
"""
import json
import os
import glob
from typing import List, Dict, Any, Union, Set
import re
import asyncio
import time
import logging
from datetime import datetime
import requests
from pydantic import BaseModel, Field
import aiohttp
import concurrent.futures
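# Illustrative CLI usage (a sketch based on the argparse flags defined in main(); the paths are the
# argparse defaults and may need adjusting for your local dataset layout):
#   python preprocessing/trex_to_sentences_simple.py --step extract --input_dir dataset/TREx --max_files 5
#   python preprocessing/trex_to_sentences_simple.py --step llm --sentences_json extracted_sentences.json --output_file trex_sentences_enhanced.txt
#   python preprocessing/trex_to_sentences_simple.py --step all --input_dir dataset/TREx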
# 设置日志系统
def setup_logging():
"""设置日志系统"""
# 确保logs目录存在
os.makedirs('logs', exist_ok=True)
# 创建日志文件名(包含时间戳)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = f'logs/trex_processor_{timestamp}.log'
# 配置日志格式
log_format = '%(asctime)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s'
# 配置root logger
logging.basicConfig(
level=logging.INFO,
format=log_format,
handlers=[
logging.FileHandler(log_file, encoding='utf-8'),
logging.StreamHandler() # 同时输出到控制台
]
)
# 获取logger
logger = logging.getLogger(__name__)
logger.info(f"日志系统初始化完成,日志文件: {log_file}")
return logger
# 全局日志对象
logger = setup_logging()
class ProcessedSentence(BaseModel):
"""处理后的句子结构"""
corrected_sentence: str = Field(
...,
description="修正后的句子,只修正语法错误、乱码和不通顺的地方,不进行额外润色"
)
importance_score: float = Field(
...,
description="重要性评分,范围0.0-10.0,以0.1递进。评判这个知识在现实世界中的常用程度和重要度",
ge=0.0,
le=10.0
)
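# Illustrative example (not executed): the LLM response is expected to validate against this model,
# e.g. ProcessedSentence(corrected_sentence="Beijing is a capital city.", importance_score=7.5);
# scores outside the 0.0-10.0 range are rejected by the ge/le constraints above.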
class EnhancedTRExProcessor:
def __init__(self, input_dir: str = None, output_file: str = None, max_files: int = None,
sentences_json: str = None, enable_llm_processing: bool = True):
self.input_dir = input_dir
# 确保output目录存在
os.makedirs('output', exist_ok=True)
# 确保所有输出文件都在output目录中
if output_file:
if not output_file.startswith('output/'):
self.output_file = os.path.join('output', output_file)
else:
self.output_file = output_file
else:
self.output_file = None
if sentences_json:
if not sentences_json.startswith('output/'):
self.sentences_json = os.path.join('output', sentences_json)
else:
self.sentences_json = sentences_json
else:
self.sentences_json = "output/extracted_sentences.json"
self.max_files = max_files
self.enable_llm_processing = enable_llm_processing
# Ollama API配置
self.model_name = "gemma3:latest" # Ollama模型名称
self.ollama_base_url = "http://localhost:11434" # Ollama服务器地址
self.batch_size_per_request = 8 # 每个API请求处理的句子数量,Ollama建议较小批次
self.max_concurrent_requests = 2 # 最大并发请求数,Ollama建议较低并发
self.request_timeout = 180 # 请求超时时间(秒)
self.retry_attempts = 3 # 重试次数
# 统计信息
self.total_requests = 0
self.successful_requests = 0
self.failed_requests = 0
logger.info(f"处理器初始化完成 - 模型: {self.model_name}, 批次大小: {self.batch_size_per_request}, 并发数: {self.max_concurrent_requests}")
# 扩展的Wikidata属性映射
self.property_mappings = {
# 基本关系
"http://www.wikidata.org/prop/direct/P31": "is a",
"http://www.wikidata.org/prop/direct/P279": "is a type of",
# 人物相关
"http://www.wikidata.org/prop/direct/P106": "works as",
"http://www.wikidata.org/prop/direct/P27": "is a citizen of",
"http://www.wikidata.org/prop/direct/P19": "was born in",
"http://www.wikidata.org/prop/direct/P20": "died in",
"http://www.wikidata.org/prop/direct/P569": "was born on",
"http://www.wikidata.org/prop/direct/P570": "died on",
"http://www.wikidata.org/prop/direct/P22": "has father",
"http://www.wikidata.org/prop/direct/P25": "has mother",
"http://www.wikidata.org/prop/direct/P26": "is married to",
# 组织相关
"http://www.wikidata.org/prop/direct/P102": "is a member of",
"http://www.wikidata.org/prop/direct/P108": "works for",
"http://www.wikidata.org/prop/direct/P159": "has headquarters in",
"http://www.wikidata.org/prop/direct/P112": "was founded by",
"http://www.wikidata.org/prop/direct/P571": "was founded in",
"http://www.wikidata.org/prop/direct/P169": "has CEO",
# 地理相关
"http://www.wikidata.org/prop/direct/P17": "is located in",
"http://www.wikidata.org/prop/direct/P131": "is located in",
"http://www.wikidata.org/prop/direct/P36": "has capital",
"http://www.wikidata.org/prop/direct/P47": "borders",
# 其他关系
"http://www.wikidata.org/prop/direct/P1142": "has ideology",
"http://www.wikidata.org/prop/direct/P361": "is part of",
"http://www.wikidata.org/prop/direct/P737": "was influenced by",
"http://www.wikidata.org/prop/direct/P127": "is owned by",
"http://www.wikidata.org/prop/direct/P155": "follows",
"http://www.wikidata.org/prop/direct/P156": "is followed by",
"http://www.wikidata.org/prop/direct/P138": "is named after"
}
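# Illustrative mapping (hypothetical triple): a triple whose subject surfaceform is "Marie Curie",
# whose predicate URI ends in P19 and whose object surfaceform is "Warsaw" is rendered by
# triple_to_sentence() below as "Marie Curie was born in Warsaw."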
def get_system_prompt(self) -> str:
"""获取系统提示"""
return """You are a professional text processing assistant responsible for correcting errors in sentences and evaluating the importance of knowledge.
### Sentence Correction Rules:
1. Remove Wikipedia-specific markers: such as (disambiguation), (film), (band), etc. in parentheses
2. Ensure grammatical completeness: complete subject+predicate+object structure, avoid dangling 'and is', 'or', etc.
3. Fix obvious grammatical errors: tense consistency, singular/plural consistency, correct preposition usage
4. Clean up garbled text and special characters: such as â, €, ™ and other encoding issues
5. Ensure semantic fluency: if the original sentence cannot be fixed, reorganize the language to make it coherent
6. Do not add information not present in the original text, only correct errors
### Correction Examples:
- Error: 'Argument (disambiguation) is related to philosophy, logic, and is an.'
- Corrected: 'Argument is related to philosophy and logic.'
- Error: 'Beijing is a capital city and are.'
- Corrected: 'Beijing is a capital city.'
### Importance Scoring Criteria (0.0-10.0, in increments of 0.1):
0.0 points - Completely incorrect or meaningless information
Examples: 'Apple is a metal', 'The sun rises from the west', '1+1=3'
0.5 points - Almost worthless information
Examples: 'Color of a fictional character's socks', 'Third line of dialogue from a game NPC', 'What someone had for breakfast yesterday'
1.0 points - Extremely rare, non-practical knowledge
Examples: 'Pet name of a minor novel character', 'Content of the 15th line in movie end credits', 'Nickname of website user ID 123456'
1.5 points - Very niche detailed information
Examples: 'Outfit of a passerby at minute 37 in a movie', 'Duration of background music in a game's hidden level', 'Content of the 3rd dialogue box on page 200 of a manga'
2.0 points - Details in niche professional fields
Examples: 'Color change of rare minerals at specific temperatures', 'Length of an insect's third antenna', 'Molecular formula of chemical reaction byproducts'
2.5 points - Technical details only professionals care about
Examples: 'Release date of specific software library version', 'Time complexity coefficient of an algorithm', 'Thermal expansion coefficient of a material'
3.0 points - Professional knowledge in specific fields
Examples: 'Programming language syntax features', 'Gene sequence of a virus', 'Official system of ancient dynasties'
3.5 points - Professional information with some value
Examples: 'Specific system of historical dynasty', 'Mechanism of action of a drug', 'Development time of a technical standard'
4.0 points - Meaningful knowledge known by few
Examples: 'Unique cultural traditions of a country', 'Important discoveries by a scientist', 'Detailed process of historical events'
4.5 points - Knowledge of interest to some groups
Examples: 'Author's creative background', 'Characteristics of an art movement', 'Detailed rules of a sport'
5.0 points - General knowledge of moderate importance
Examples: 'Famous attractions in cities', 'Development history of a company', 'Living habits of animals'
5.5 points - Fairly useful common sense
Examples: 'Plant growth environment', 'Healthy eating common sense', 'Basic first aid knowledge'
6.0 points - Knowledge most educated people should know
Examples: 'Shakespeare's representative works', 'Basic geometric theorems', 'Major world currencies'
6.5 points - Important cultural or scientific common sense
Examples: 'Basic structure of DNA', 'Newton's three laws', 'Major world religions'
7.0 points - Important foundational knowledge
Examples: 'Time period of World War II', 'Functions of major human organs', 'Basic mathematical operation rules'
7.5 points - Very important common sense
Examples: 'Light speed is the fastest in the universe', 'Earth is round', 'Basic principles of blood circulation'
8.0 points - Core knowledge in basic education
Examples: 'Earth orbits the sun', 'Principle of seasonal formation', 'Basic grammar rules'
8.5 points - Important knowledge everyone should master
Examples: 'Chemical formula of water H2O', 'Basic safety common sense', 'Simple mathematical calculations'
9.0 points - Extremely important basic concepts
Examples: 'Humans need oxygen to survive', 'Fire is hot', 'Basic directional concepts'
9.5 points - Core knowledge everyone must know
Examples: 'A day has 24 hours', 'A year has 12 months', 'Basic number concepts'
10.0 points - Most basic and important common sense
Examples: 'Humans need food and water to survive', 'The sky is blue', 'Stones are heavier than feathers'
When scoring, please consider:
1. Popularity of knowledge - How many people know this knowledge
2. Practical value - How useful this knowledge is in daily life
3. Educational importance - The position of this knowledge in the education system
4. Cultural significance - The importance of this knowledge for understanding the world
Please respond with valid JSON in the following format:
{
"corrected_sentence": "corrected sentence here",
"importance_score": evaluation score
}"""
async def process_batch_with_vllm_api(self, sentences: List[str]) -> List[Dict[str, Any]]:
"""Process a batch of sentences concurrently via the Ollama API"""
processed_sentences = []
async with aiohttp.ClientSession() as session:
# 创建并发任务
semaphore = asyncio.Semaphore(self.max_concurrent_requests)
tasks = []
# 将句子分成小批次
for i in range(0, len(sentences), self.batch_size_per_request):
batch_sentences = sentences[i:i + self.batch_size_per_request]
task = self.process_single_batch_request(session, semaphore, batch_sentences, i)
tasks.append(task)
# 等待所有任务完成
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
# 收集结果
for result in batch_results:
if isinstance(result, Exception):
logger.error(f"批次处理出错: {result}")
continue
if result:
processed_sentences.extend(result)
return processed_sentences
async def process_single_batch_request(self, session: aiohttp.ClientSession, semaphore: asyncio.Semaphore,
sentences: List[str], batch_index: int) -> List[Dict[str, Any]]:
"""处理单个批次的API请求"""
async with semaphore:
for attempt in range(self.retry_attempts):
try:
# 为每个句子创建单独的消息
messages = []
for sentence in sentences:
messages.append({
"role": "user",
"content": f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"
})
# 构建Ollama请求数据
request_data = {
"model": self.model_name,
"messages": [
{"role": "system", "content": self.get_system_prompt()}
] + messages,
"stream": False,
"options": {
"temperature": 0.2,
"num_predict": 500 * len(sentences) # 为每个句子分配足够的token
},
"format": "json" # Ollama的JSON格式参数
}
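# Note: all sentences of this batch are sent as separate user messages in a single /api/chat call
# with format="json"; if the model returns only one JSON object, parse_ollama_response() keeps the
# first result and falls back to default responses for the remaining sentences.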
# 发送请求到Ollama
async with session.post(
f'{self.ollama_base_url}/api/chat',
json=request_data,
timeout=aiohttp.ClientTimeout(total=self.request_timeout)
) as response:
if response.status == 200:
result = await response.json()
return self.parse_ollama_response(result, sentences, batch_index)
else:
error_text = await response.text()
logger.warning(f"API请求失败 (批次 {batch_index}, 尝试 {attempt + 1}/{self.retry_attempts}): {response.status} - {error_text}")
if attempt == self.retry_attempts - 1: # 最后一次尝试
logger.error(f"批次 {batch_index} 在 {self.retry_attempts} 次尝试后仍然失败")
self.failed_requests += len(sentences)
return self.create_default_responses(sentences)
else:
# 等待后重试
await asyncio.sleep(2 ** attempt) # 指数退避
continue
except asyncio.TimeoutError:
logger.warning(f"批次 {batch_index} 请求超时 (尝试 {attempt + 1}/{self.retry_attempts})")
if attempt == self.retry_attempts - 1:
logger.error(f"批次 {batch_index} 在 {self.retry_attempts} 次尝试后仍然超时")
self.failed_requests += len(sentences)
return self.create_default_responses(sentences)
else:
await asyncio.sleep(2 ** attempt)
continue
except Exception as e:
logger.warning(f"处理批次 {batch_index} 时出错 (尝试 {attempt + 1}/{self.retry_attempts}): {e}")
if attempt == self.retry_attempts - 1:
logger.error(f"批次 {batch_index} 在 {self.retry_attempts} 次尝试后仍然失败")
self.failed_requests += len(sentences)
return self.create_default_responses(sentences)
else:
await asyncio.sleep(2 ** attempt)
continue
# 如果所有重试都失败了
return self.create_default_responses(sentences)
def parse_ollama_response(self, response: Dict[str, Any], original_sentences: List[str], batch_index: int) -> List[Dict[str, Any]]:
"""解析Ollama响应"""
processed_sentences = []
try:
# Ollama的响应格式
message = response.get('message', {})
content = message.get('content', '')
if not content:
logger.warning(f"批次 {batch_index} 没有返回内容")
return self.create_default_responses(original_sentences)
# 尝试解析JSON响应
try:
# 如果返回的是单个JSON对象
if content.strip().startswith('{') and content.strip().endswith('}'):
response_data = json.loads(content)
processed_sentence = ProcessedSentence(
corrected_sentence=response_data.get('corrected_sentence', original_sentences[0] if original_sentences else ""),
importance_score=float(response_data.get('importance_score', 5.0))
)
processed_sentences.append({
"original_sentence": original_sentences[0] if original_sentences else "",
"corrected_sentence": processed_sentence.corrected_sentence,
"importance_score": processed_sentence.importance_score
})
self.successful_requests += 1
# 如果有多个句子但只返回一个结果,为其他句子创建默认响应
for i in range(1, len(original_sentences)):
processed_sentences.append({
"original_sentence": original_sentences[i],
"corrected_sentence": original_sentences[i],
"importance_score": 5.0
})
self.failed_requests += 1
else:
# 尝试解析多个JSON对象
json_objects = []
for line in content.split('\n'):
line = line.strip()
if line.startswith('{') and line.endswith('}'):
try:
json_objects.append(json.loads(line))
except:
continue
if json_objects:
for i, (sentence, json_obj) in enumerate(zip(original_sentences, json_objects)):
try:
processed_sentence = ProcessedSentence(
corrected_sentence=json_obj.get('corrected_sentence', sentence),
importance_score=float(json_obj.get('importance_score', 5.0))
)
processed_sentences.append({
"original_sentence": sentence,
"corrected_sentence": processed_sentence.corrected_sentence,
"importance_score": processed_sentence.importance_score
})
self.successful_requests += 1
except Exception as e:
logger.warning(f"解析JSON对象失败: {e}")
processed_sentences.append({
"original_sentence": sentence,
"corrected_sentence": sentence,
"importance_score": 5.0
})
self.failed_requests += 1
# 为剩余句子创建默认响应
for i in range(len(json_objects), len(original_sentences)):
processed_sentences.append({
"original_sentence": original_sentences[i],
"corrected_sentence": original_sentences[i],
"importance_score": 5.0
})
self.failed_requests += 1
else:
logger.warning(f"批次 {batch_index} 无法解析JSON响应: {content}")
return self.create_default_responses(original_sentences)
except (json.JSONDecodeError, ValueError) as e:
logger.warning(f"批次 {batch_index} 解析响应JSON失败: {e}")
logger.warning(f"原始内容: {content}")
return self.create_default_responses(original_sentences)
except Exception as e:
logger.error(f"解析批次 {batch_index} 响应时出错: {e}")
return self.create_default_responses(original_sentences)
return processed_sentences
def create_default_responses(self, sentences: List[str]) -> List[Dict[str, Any]]:
"""为失败的请求创建默认响应"""
default_responses = []
for sentence in sentences:
default_responses.append({
"original_sentence": sentence,
"corrected_sentence": sentence,
"importance_score": 5.0
})
return default_responses
async def process_sentences_with_vllm_api(self, sentences: List[str]) -> List[Dict[str, Any]]:
"""使用Ollama API处理句子"""
logger.info(f"开始使用Ollama API处理 {len(sentences)} 个句子...")
print(f"开始使用Ollama API处理 {len(sentences)} 个句子...")
start_time = time.time()
total_sentences = len(sentences)
total_processed_count = 0
# 检查Ollama服务状态
if not self.check_ollama_status():
logger.error("Ollama服务状态异常,无法继续处理")
print("错误:Ollama服务状态异常,请检查服务是否正常运行")
return []
# 分大批次处理(用于检查点保存)
large_batch_size = 1000 # 每1000个句子保存一次检查点
all_processed_sentences = []
for large_batch_start in range(0, total_sentences, large_batch_size):
large_batch_end = min(large_batch_start + large_batch_size, total_sentences)
large_batch_sentences = sentences[large_batch_start:large_batch_end]
large_batch_number = large_batch_start // large_batch_size + 1
logger.info(f"=== 处理大批次 {large_batch_number} ({large_batch_start + 1}-{large_batch_end}/{total_sentences}) ===")
print(f"\n=== 处理大批次 {large_batch_number} ({large_batch_start + 1}-{large_batch_end}/{total_sentences}) ===")
large_batch_start_time = time.time()
# 处理当前大批次
batch_processed = await self.process_batch_with_vllm_api(large_batch_sentences)
all_processed_sentences.extend(batch_processed)
total_processed_count += len(batch_processed)
# 保存当前大批次的检查点
checkpoint_filename = self.save_batch_checkpoint(batch_processed, large_batch_number, total_processed_count)
# 打印进度
large_batch_time = time.time() - large_batch_start_time
elapsed_time = time.time() - start_time
logger.info(f"大批次 {large_batch_number} 处理完成!")
logger.info(f" - 当前批次:成功 {len(batch_processed)},用时 {large_batch_time/60:.1f}分钟")
logger.info(f" - 总体进度:{total_processed_count}/{total_sentences} ({total_processed_count/total_sentences*100:.1f}%)")
logger.info(f" - 已用时间:{elapsed_time/60:.1f}分钟")
logger.info(f" - 批次检查点已保存:{checkpoint_filename}")
print(f"大批次 {large_batch_number} 处理完成!")
print(f" - 当前批次:成功 {len(batch_processed)},用时 {large_batch_time/60:.1f}分钟")
print(f" - 总体进度:{total_processed_count}/{total_sentences} ({total_processed_count/total_sentences*100:.1f}%)")
print(f" - 已用时间:{elapsed_time/60:.1f}分钟")
print(f" - 批次检查点已保存:{checkpoint_filename}")
if large_batch_end < total_sentences:
remaining_sentences = total_sentences - total_processed_count
avg_time_per_sentence = elapsed_time / total_processed_count
estimated_remaining_time = avg_time_per_sentence * remaining_sentences
logger.info(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟")
print(f" - 预估剩余时间:{estimated_remaining_time/60:.1f}分钟")
# 打印最终统计
total_time = time.time() - start_time
logger.info(f"=== 全部处理完成!===")
logger.info(f" - 总成功:{self.successful_requests}")
logger.info(f" - 总失败:{self.failed_requests}")
logger.info(f" - 总用时:{total_time/60:.1f}分钟")
logger.info(f" - 平均处理速度:{total_processed_count/total_time:.2f}句/秒")
print(f"\n=== 全部处理完成!===")
print(f" - 总成功:{self.successful_requests}")
print(f" - 总失败:{self.failed_requests}")
print(f" - 总用时:{total_time/60:.1f}分钟")
print(f" - 平均处理速度:{total_processed_count/total_time:.2f}句/秒")
return all_processed_sentences
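# Checkpoints are written once per large batch (large_batch_size = 1000 sentences above), so an
# interrupted run can be resumed: process_with_vllm_api() skips any sentence already present in the
# output/*_batch_*.json files before calling this method again.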
def check_ollama_status(self) -> bool:
"""检查Ollama服务是否正常运行"""
try:
# 检查Ollama API是否响应
response = requests.get(f'{self.ollama_base_url}/api/tags', timeout=10)
if response.status_code == 200:
models = response.json()
model_names = [model.get('name', 'unknown') for model in models.get('models', [])]
logger.info(f"Ollama服务状态正常可用模型: {model_names}")
# 检查目标模型是否可用
if self.model_name in model_names:
logger.info(f"目标模型 {self.model_name} 可用")
return True
else:
logger.warning(f"目标模型 {self.model_name} 不在可用模型列表中: {model_names}")
logger.info("尝试拉取模型...")
# 尝试拉取模型
try:
pull_response = requests.post(
f'{self.ollama_base_url}/api/pull',
json={"name": self.model_name},
timeout=300 # 5分钟超时
)
if pull_response.status_code == 200:
logger.info(f"成功拉取模型 {self.model_name}")
return True
else:
logger.error(f"拉取模型失败: {pull_response.status_code}")
return False
except Exception as e:
logger.error(f"拉取模型时出错: {e}")
return False
else:
logger.error(f"Ollama API响应异常状态码: {response.status_code}")
return False
except requests.exceptions.RequestException as e:
logger.error(f"无法连接到Ollama API: {e}")
return False
except Exception as e:
logger.error(f"检查Ollama状态时出错: {e}")
return False
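# The same endpoints can be probed manually (assuming the default Ollama address used above):
#   curl http://localhost:11434/api/tags
#   curl -X POST http://localhost:11434/api/pull -d '{"name": "gemma3:latest"}'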
def clean_text(self, text: str) -> str:
"""清理文本,处理特殊字符"""
if not text:
return ""
# 处理常见的Unicode字符
text = text.replace("\u2013", "-") # en dash
text = text.replace("\u2014", "-") # em dash
text = text.replace("\u2019", "'") # right single quotation mark
text = text.replace("\u2018", "'") # left single quotation mark
text = text.replace("\u201c", '"') # left double quotation mark
text = text.replace("\u201d", '"') # right double quotation mark
# 处理可能的转义序列
try:
text = text.encode('utf-8').decode('utf-8')
except:
pass
# 清理多余的空格
text = re.sub(r'\s+', ' ', text).strip()
# 移除可能的引号
text = text.strip('"\'')
return text
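# Illustrative behaviour (hypothetical input): clean_text('  "Cafe \u2013 a  bar"  ') maps the
# en dash to "-", collapses the whitespace and strips the outer quotes, returning 'Cafe - a bar'.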
def parse_large_json_file(self, file_path: str) -> List[Dict]:
"""解析大型JSON文件处理可能的格式问题"""
documents = []
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read().strip()
# 尝试不同的解析方法
if content.startswith('[') and content.endswith(']'):
# 标准JSON数组
documents = json.loads(content)
else:
# 可能是连续的JSON对象
# 尝试在}{"之间分割
if '}{"' in content:
json_strings = content.split('}{')
json_strings[0] += '}' # 第一个对象
json_strings[-1] = '{' + json_strings[-1] # 最后一个对象
for i in range(1, len(json_strings) - 1):
json_strings[i] = '{' + json_strings[i] + '}'
for json_str in json_strings:
try:
doc = json.loads(json_str)
documents.append(doc)
except json.JSONDecodeError:
continue
else:
# 尝试作为单个JSON对象
try:
documents = [json.loads(content)]
except json.JSONDecodeError:
pass
except Exception as e:
print(f"Error parsing {file_path}: {e}")
return documents
def extract_sentences_from_document(self, doc: Dict[str, Any]) -> List[str]:
"""从文档中提取句子"""
sentences = []
title = self.clean_text(doc.get('title', ''))
text = self.clean_text(doc.get('text', ''))
entities = doc.get('entities', [])
triples = doc.get('triples', [])
# 处理显式三元组
for triple in triples:
sentence = self.triple_to_sentence(triple)
if sentence:
sentences.append(sentence)
# 从实体和文本中生成基本句子(如果三元组句子不够)
if title and text and len(sentences) < 5:
# 基于标题和实体生成句子
entity_names = []
for entity in entities[:15]:
entity_name = self.clean_text(entity.get('surfaceform', ''))
if entity_name and len(entity_name) > 2:
entity_names.append(entity_name)
# 生成简单的描述句子
if entity_names:
important_entities = []
title_lower = title.lower()
for entity in entity_names:
if (entity.lower() != title_lower and
entity not in important_entities and
not any(t.lower() in entity.lower() for t in title_lower.split()[:2])):
important_entities.append(entity)
if len(important_entities) >= 6:
break
if important_entities and len(sentences) < 3:
entities_str = ', '.join(important_entities[:3])
sentences.append(f"{title} is related to {entities_str}.")
return sentences
def triple_to_sentence(self, triple: Dict[str, Any]) -> str:
"""将三元组转换为自然语言句子"""
try:
subject = triple.get('subject', {})
predicate = triple.get('predicate', {})
obj = triple.get('object', {})
subject_name = self.clean_text(subject.get('surfaceform', ''))
object_name = self.clean_text(obj.get('surfaceform', ''))
predicate_uri = predicate.get('uri', '')
# 检查是否有有效的主语和宾语
if not subject_name or not object_name:
return ""
# 检查主语和宾语是否过短或无意义
if len(subject_name) <= 2 or len(object_name) <= 2:
return ""
# 获取关系文本
relation_text = self.property_mappings.get(predicate_uri, "is related to")
# 避免重复的主语宾语
if subject_name.lower() == object_name.lower():
return ""
return f"{subject_name} {relation_text} {object_name}."
except Exception as e:
print(f"Error converting triple to sentence: {e}")
return ""
def save_batch_checkpoint(self, processed_sentences: List[Dict[str, Any]], batch_number: int, total_processed_count: int) -> str:
"""保存当前批次的检查点文件"""
# 生成检查点文件名确保在output目录中
base_name = os.path.splitext(os.path.basename(self.output_file))[0]
checkpoint_filename = os.path.join('output', f"{base_name}_batch_{batch_number}.json")
# 保存检查点
with open(checkpoint_filename, 'w', encoding='utf-8') as f:
json.dump({
"metadata": {
"batch_number": batch_number,
"batch_size": len(processed_sentences),
"total_processed_count": total_processed_count,
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
},
"sentences": processed_sentences
}, f, ensure_ascii=False, indent=2)
return checkpoint_filename
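# With the default output_file ("output/trex_sentences_enhanced.txt") the checkpoints produced here
# are named output/trex_sentences_enhanced_batch_1.json, output/trex_sentences_enhanced_batch_2.json, ...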
async def process_files(self) -> List[Dict[str, Any]]:
"""处理所有文件"""
json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))
if not json_files:
print(f"No JSON files found in {self.input_dir}")
return []
# 排序文件以确保一致的处理顺序
json_files.sort()
if self.max_files:
json_files = json_files[:self.max_files]
print(f"Found {len(json_files)} JSON files to process")
all_sentences = []
for i, file_path in enumerate(json_files):
print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")
documents = self.parse_large_json_file(file_path)
print(f" Parsed {len(documents)} documents")
for doc in documents:
sentences = self.extract_sentences_from_document(doc)
all_sentences.extend(sentences)
print(f" Generated {len(all_sentences)} total raw sentences so far")
print(f"总共提取了 {len(all_sentences)} 个原始句子")
# 去重
unique_sentences = []
seen = set()
for sentence in all_sentences:
sentence = sentence.strip()
if sentence and sentence not in seen and len(sentence) > 10:
unique_sentences.append(sentence)
seen.add(sentence)
print(f"去重后剩余 {len(unique_sentences)} 个句子")
# 保存原始句子到JSON文件
sentences_data = {
"metadata": {
"total_sentences": len(unique_sentences),
"extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"source_files": len(json_files),
"max_files_limit": self.max_files
},
"sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
}
with open(self.sentences_json, 'w', encoding='utf-8') as f:
json.dump(sentences_data, f, ensure_ascii=False, indent=2)
print(f"句子提取完成!已保存到: {self.sentences_json}")
print(f"总计句子数: {len(unique_sentences)}")
return unique_sentences
def save_sentences(self, processed_sentences: List[Dict[str, Any]]):
"""保存处理后的句子到文件"""
# 确保输出目录存在
os.makedirs('output', exist_ok=True)
# 保存为JSON格式包含完整信息
json_output_file = self.output_file.replace('.txt', '.json')
with open(json_output_file, 'w', encoding='utf-8') as f:
json.dump(processed_sentences, f, ensure_ascii=False, indent=2)
# 保存为简单文本格式(仅修正后的句子)
with open(self.output_file, 'w', encoding='utf-8') as f:
for item in processed_sentences:
f.write(item['corrected_sentence'] + '\n')
# 生成重要性排序文件
importance_sorted = sorted(processed_sentences, key=lambda x: x['importance_score'], reverse=True)
importance_file = self.output_file.replace('.txt', '_sorted_by_importance.txt')
with open(importance_file, 'w', encoding='utf-8') as f:
for item in importance_sorted:
f.write(f"[{item['importance_score']:.1f}] {item['corrected_sentence']}\n")
print(f"保存了 {len(processed_sentences)} 个处理后的句子:")
print(f" - JSON格式: {json_output_file}")
print(f" - 文本格式: {self.output_file}")
print(f" - 重要性排序: {importance_file}")
# 统计信息
scores = [item['importance_score'] for item in processed_sentences]
avg_score = sum(scores) / len(scores) if scores else 0
print(f" - 平均重要性评分: {avg_score:.2f}")
print(f" - 最高评分: {max(scores):.1f}")
print(f" - 最低评分: {min(scores):.1f}")
async def run(self):
"""Run the full pipeline: extract sentences first, then optionally run LLM post-processing"""
print("Starting enhanced TREx to sentences conversion...")
sentences = await self.process_files()
if not sentences:
print("未提取到任何句子,流程结束")
return
if self.enable_llm_processing:
await self.process_with_vllm_api()
print("Enhanced conversion completed!")
def find_latest_checkpoint(self) -> Union[tuple, None]:
"""查找最新的检查点文件"""
base_name = os.path.splitext(os.path.basename(self.output_file))[0]
pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
checkpoint_files = glob.glob(pattern)
if not checkpoint_files:
return None
# 按检查点编号排序,获取最新的
latest_file = None
latest_count = 0
for file in checkpoint_files:
try:
# 从文件名中提取数字
match = re.search(r'checkpoint_(\d+)\.json$', file)
if match:
count = int(match.group(1))
if count > latest_count:
latest_count = count
latest_file = file
except:
continue
if latest_file:
return latest_file, latest_count
else:
return None
def load_checkpoint(self, checkpoint_file: str) -> List[Dict[str, Any]]:
"""从检查点文件加载已处理的句子"""
try:
with open(checkpoint_file, 'r', encoding='utf-8') as f:
data = json.load(f)
if 'sentences' in data:
return data['sentences']
else:
# 旧格式的检查点文件
return data
except Exception as e:
print(f"加载检查点文件失败: {e}")
return []
def get_processed_sentences_from_checkpoints(self) -> Set[str]:
"""从检查点文件中获取已处理过的句子集合"""
if not self.output_file:
return set()
processed_sentences = set()
base_name = os.path.splitext(os.path.basename(self.output_file))[0]
# 首先查找新格式的批次文件
batch_pattern = os.path.join('output', f"{base_name}_batch_*.json")
batch_files = glob.glob(batch_pattern)
if batch_files:
print(f"找到 {len(batch_files)} 个批次检查点文件")
batch_files.sort() # 确保按顺序处理
for batch_file in batch_files:
try:
with open(batch_file, 'r', encoding='utf-8') as f:
data = json.load(f)
sentences_data = data.get('sentences', [])
for item in sentences_data:
original_sentence = item.get('original_sentence', '')
if original_sentence:
processed_sentences.add(original_sentence)
batch_number = data.get('metadata', {}).get('batch_number', 'unknown')
print(f" - 从批次 {batch_number} 加载了 {len(sentences_data)} 个句子")
except Exception as e:
print(f"读取批次文件 {batch_file} 失败: {e}")
continue
print(f"从批次文件总计加载了 {len(processed_sentences)} 个已处理的句子")
logger.info(f"从批次文件总计加载了 {len(processed_sentences)} 个已处理的句子")
return processed_sentences
# 如果没有批次文件,尝试查找旧格式的检查点文件
old_pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
checkpoint_files = glob.glob(old_pattern)
if not checkpoint_files:
print("未找到检查点文件,将从头开始处理")
return set()
# 找到最新的检查点文件
latest_file = None
latest_count = 0
for file in checkpoint_files:
try:
match = re.search(r'checkpoint_(\d+)\.json$', file)
if match:
count = int(match.group(1))
if count > latest_count:
latest_count = count
latest_file = file
except:
continue
if latest_file:
print(f"找到旧格式检查点: {latest_file} (包含 {latest_count} 条记录)")
logger.info(f"找到旧格式检查点: {latest_file} (包含 {latest_count} 条记录)")
try:
with open(latest_file, 'r', encoding='utf-8') as f:
data = json.load(f)
sentences_data = data.get('sentences', [])
for item in sentences_data:
original_sentence = item.get('original_sentence', '')
if original_sentence:
processed_sentences.add(original_sentence)
print(f"从旧格式检查点加载了 {len(processed_sentences)} 个已处理的句子")
logger.info(f"从旧格式检查点加载了 {len(processed_sentences)} 个已处理的句子")
except Exception as e:
print(f"读取检查点文件失败: {e}")
return set()
return processed_sentences
async def process_with_llm(self):
"""Step 2 (kept for compatibility): read sentences from the JSON file and run LLM processing"""
await self.process_with_vllm_api()
async def process_with_vllm_api(self):
"""Step 2: read sentences from the JSON file and post-process them with the Ollama API"""
if not self.enable_llm_processing:
print("Error: LLM processing is disabled!")
return
if not self.output_file:
print("Error: output_file is required for LLM processing!")
return
print("=== 步骤2:vLLM处理 ===")
# 读取句子JSON文件
if not os.path.exists(self.sentences_json):
print(f"Error: Sentences file {self.sentences_json} not found!")
print("请先运行步骤1进行句子提取")
return
print(f"正在读取句子文件: {self.sentences_json}")
try:
with open(self.sentences_json, 'r', encoding='utf-8') as f:
data = json.load(f)
all_sentences = [item["sentence"] for item in data.get("sentences", [])]
print(f"从文件中读取了 {len(all_sentences)} 个句子")
except Exception as e:
print(f"读取句子文件失败: {e}")
return
# 获取已处理的句子
processed_sentences_set = self.get_processed_sentences_from_checkpoints()
# 过滤出未处理的句子
unprocessed_sentences = []
for sentence in all_sentences:
if sentence not in processed_sentences_set:
unprocessed_sentences.append(sentence)
print(f"需要处理的句子数: {len(unprocessed_sentences)} (跳过已处理: {len(processed_sentences_set)})")
logger.info(f"需要处理的句子数: {len(unprocessed_sentences)} (跳过已处理: {len(processed_sentences_set)})")
if not unprocessed_sentences:
print("所有句子都已处理完成!")
# 如果有检查点,直接从最新检查点生成最终文件
if processed_sentences_set:
latest_checkpoint = self.find_latest_checkpoint()
if latest_checkpoint:
checkpoint_file, _ = latest_checkpoint
processed_data = self.load_checkpoint(checkpoint_file)
self.save_sentences(processed_data)
print("已从检查点生成最终输出文件")
return
# 处理未处理的句子
print("开始vLLM处理...")
# 处理新句子(现在返回空列表,数据保存在批次检查点中)
await self.process_sentences_with_vllm_api(unprocessed_sentences)
# 处理完成后,合并所有批次检查点生成最终文件
print("合并所有批次检查点生成最终文件...")
all_processed_sentences = self.merge_all_batch_checkpoints()
if all_processed_sentences:
# 保存最终结果
self.save_sentences(all_processed_sentences)
print("vLLM处理完成")
else:
print("警告:没有找到任何处理结果")
def merge_all_batch_checkpoints(self) -> List[Dict[str, Any]]:
"""合并所有批次检查点文件"""
if not self.output_file:
return []
base_name = os.path.splitext(os.path.basename(self.output_file))[0]
# 查找所有批次检查点文件
batch_pattern = os.path.join('output', f"{base_name}_batch_*.json")
batch_files = glob.glob(batch_pattern)
if not batch_files:
# 如果没有批次文件,尝试查找旧格式的检查点文件
old_pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
old_files = glob.glob(old_pattern)
if old_files:
print("找到旧格式检查点文件,尝试读取最新的...")
latest_checkpoint = self.find_latest_checkpoint()
if latest_checkpoint:
checkpoint_file, _ = latest_checkpoint
return self.load_checkpoint(checkpoint_file)
return []
print(f"找到 {len(batch_files)} 个批次检查点文件")
all_sentences = []
batch_files.sort() # 确保按顺序处理
for batch_file in batch_files:
try:
with open(batch_file, 'r', encoding='utf-8') as f:
data = json.load(f)
batch_sentences = data.get('sentences', [])
all_sentences.extend(batch_sentences)
batch_number = data.get('metadata', {}).get('batch_number', 'unknown')
batch_size = len(batch_sentences)
print(f" - 批次 {batch_number}: {batch_size} 个句子")
except Exception as e:
print(f"读取批次文件 {batch_file} 失败: {e}")
continue
print(f"总计合并了 {len(all_sentences)} 个句子")
return all_sentences
def extract_sentences(self):
"""Step 1: extract sentences from the TREx dataset and save them as JSON"""
if not self.input_dir:
print("Error: input_dir is required for sentence extraction!")
return
print("=== 步骤1:句子提取 ===")
print("开始从TREx数据集提取句子...")
json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))
if not json_files:
print(f"No JSON files found in {self.input_dir}")
return
# 排序文件以确保一致的处理顺序
json_files.sort()
if self.max_files:
json_files = json_files[:self.max_files]
print(f"Found {len(json_files)} JSON files to process")
all_sentences = []
for i, file_path in enumerate(json_files):
print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")
documents = self.parse_large_json_file(file_path)
print(f" Parsed {len(documents)} documents")
for doc in documents:
sentences = self.extract_sentences_from_document(doc)
all_sentences.extend(sentences)
print(f" Generated {len(all_sentences)} total raw sentences so far")
print(f"总共提取了 {len(all_sentences)} 个原始句子")
# 去重
unique_sentences = []
seen = set()
for sentence in all_sentences:
sentence = sentence.strip()
if sentence and sentence not in seen and len(sentence) > 10:
unique_sentences.append(sentence)
seen.add(sentence)
print(f"去重后剩余 {len(unique_sentences)} 个句子")
# 保存原始句子到JSON文件
sentences_data = {
"metadata": {
"total_sentences": len(unique_sentences),
"extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"source_files": len(json_files),
"max_files_limit": self.max_files
},
"sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
}
with open(self.sentences_json, 'w', encoding='utf-8') as f:
json.dump(sentences_data, f, ensure_ascii=False, indent=2)
print(f"句子提取完成!已保存到: {self.sentences_json}")
print(f"总计句子数: {len(unique_sentences)}")
return unique_sentences
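# The file written above (default: output/extracted_sentences.json) has the layout
#   {"metadata": {"total_sentences": ..., "extraction_timestamp": ..., "source_files": ..., "max_files_limit": ...},
#    "sentences": [{"sentence": "...", "processed": false}, ...]}
# which is exactly what process_with_vllm_api() expects to read back in step 2.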
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with vLLM processing')
# 选择运行模式
parser.add_argument('--step', choices=['extract', 'llm', 'all'], default='llm',
help='运行步骤: extract=仅提取句子, llm=仅vLLM处理, all=完整流程')
# 文件路径参数
parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files')
parser.add_argument('--sentences_json', default='extracted_sentences.json', help='JSON file for extracted sentences (will be saved in output/)')
parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path (will be saved in output/)')
# 处理参数
parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)')
parser.add_argument('--no_llm', action='store_true', help='Disable vLLM processing (basic mode)')
args = parser.parse_args()
# 根据步骤验证参数
if args.step in ['extract', 'all']:
if not os.path.exists(args.input_dir):
print(f"Error: Input directory {args.input_dir} does not exist!")
return
if args.step in ['llm', 'all']:
if args.no_llm:
print("Error: Cannot run vLLM step with --no_llm flag!")
return
# 创建处理器
processor = EnhancedTRExProcessor(
input_dir=args.input_dir,
sentences_json=args.sentences_json,
output_file=args.output_file,
max_files=args.max_files,
enable_llm_processing=not args.no_llm
)
# 根据选择的步骤运行
if args.step == 'extract':
print("=== 运行模式:仅句子提取 ===")
processor.extract_sentences()
elif args.step == 'llm':
print("=== 运行模式:仅vLLM处理 ===")
asyncio.run(processor.process_with_vllm_api())
elif args.step == 'all':
print("=== 运行模式:完整流程 ===")
# 步骤1:提取句子
print("\n--- 开始步骤1:句子提取 ---")
sentences = processor.extract_sentences()
if not sentences:
print("句子提取失败,退出")
return
if args.no_llm:
print("vLLM处理已禁用,流程结束")
return
# 步骤2:vLLM处理
print("\n--- 开始步骤2:vLLM处理 ---")
asyncio.run(processor.process_with_vllm_api())
if __name__ == "__main__":
main()