#!/usr/bin/env python3
"""
Enhanced preprocessing script for the TREx dataset.

Post-processes extracted sentences and scores their importance with a local
LLM API (the current configuration targets an Ollama chat endpoint; several
method names keep the earlier vLLM naming for compatibility).

Two independent steps are supported:
1. Sentence extraction: extract sentences from the TREx dataset and save them as JSON.
2. LLM processing: read the JSON file and run LLM post-processing and importance scoring.
"""
import json
import os
import glob
from typing import List, Dict, Any, Union, Set
import re
import asyncio
import time
import logging
from datetime import datetime
import requests
from pydantic import BaseModel, Field
import aiohttp
import concurrent.futures

# Logging setup
def setup_logging():
    """Configure the logging system."""
    # Make sure the logs directory exists
    os.makedirs('logs', exist_ok=True)

    # Build a timestamped log file name
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = f'logs/trex_processor_{timestamp}.log'

    # Log format
    log_format = '%(asctime)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s'

    # Configure the root logger
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()  # also log to the console
        ]
    )

    logger = logging.getLogger(__name__)
    logger.info(f"Logging initialized, log file: {log_file}")
    return logger


# Global logger
logger = setup_logging()

class ProcessedSentence(BaseModel):
    """Schema for a processed sentence."""
    corrected_sentence: str = Field(
        ...,
        description="Corrected sentence; only fix grammatical errors, garbled text and awkward phrasing, no extra polishing"
    )
    importance_score: float = Field(
        ...,
        description="Importance score in the range 0.0-10.0, in 0.1 steps, judging how common and important this knowledge is in the real world",
        ge=0.0,
        le=10.0
    )

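# Illustrative note (not part of the original pipeline): pydantic enforces the
# 0.0-10.0 bound declared above, so an out-of-range score raises a
# ValidationError instead of being stored silently, e.g.
#   ProcessedSentence(corrected_sentence="Beijing is a capital city.", importance_score=7.5)   # ok
#   ProcessedSentence(corrected_sentence="Beijing is a capital city.", importance_score=12.0)  # ValidationError
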
class EnhancedTRExProcessor:
    def __init__(self, input_dir: str = None, output_file: str = None, max_files: int = None,
                 sentences_json: str = None, enable_llm_processing: bool = True):
        self.input_dir = input_dir

        # Make sure the output directory exists
        os.makedirs('output', exist_ok=True)

        # Keep every output file inside the output directory
        if output_file:
            if not output_file.startswith('output/'):
                self.output_file = os.path.join('output', output_file)
            else:
                self.output_file = output_file
        else:
            self.output_file = None

        if sentences_json:
            if not sentences_json.startswith('output/'):
                self.sentences_json = os.path.join('output', sentences_json)
            else:
                self.sentences_json = sentences_json
        else:
            self.sentences_json = "output/extracted_sentences.json"

        self.max_files = max_files
        self.enable_llm_processing = enable_llm_processing

        # Ollama API configuration
        self.model_name = "gemma3:latest"  # Ollama model name
        self.ollama_base_url = "http://localhost:11434"  # Ollama server address
        self.batch_size_per_request = 8  # sentences per API request (Ollama favours small batches)
        self.max_concurrent_requests = 2  # maximum concurrent requests (keep concurrency low for Ollama)
        self.request_timeout = 180  # request timeout in seconds
        self.retry_attempts = 3  # number of retries

        # Statistics
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0

        logger.info(f"Processor initialized - model: {self.model_name}, batch size: {self.batch_size_per_request}, concurrency: {self.max_concurrent_requests}")

        # Extended Wikidata property mappings
        self.property_mappings = {
            # Basic relations
            "http://www.wikidata.org/prop/direct/P31": "is a",
            "http://www.wikidata.org/prop/direct/P279": "is a type of",

            # People
            "http://www.wikidata.org/prop/direct/P106": "works as",
            "http://www.wikidata.org/prop/direct/P27": "is a citizen of",
            "http://www.wikidata.org/prop/direct/P19": "was born in",
            "http://www.wikidata.org/prop/direct/P20": "died in",
            "http://www.wikidata.org/prop/direct/P569": "was born on",
            "http://www.wikidata.org/prop/direct/P570": "died on",
            "http://www.wikidata.org/prop/direct/P22": "has father",
            "http://www.wikidata.org/prop/direct/P25": "has mother",
            "http://www.wikidata.org/prop/direct/P26": "is married to",

            # Organizations
            "http://www.wikidata.org/prop/direct/P102": "is a member of",
            "http://www.wikidata.org/prop/direct/P108": "works for",
            "http://www.wikidata.org/prop/direct/P159": "has headquarters in",
            "http://www.wikidata.org/prop/direct/P112": "was founded by",
            "http://www.wikidata.org/prop/direct/P571": "was founded in",
            "http://www.wikidata.org/prop/direct/P169": "has CEO",

            # Geography
            "http://www.wikidata.org/prop/direct/P17": "is located in",
            "http://www.wikidata.org/prop/direct/P131": "is located in",
            "http://www.wikidata.org/prop/direct/P36": "has capital",
            "http://www.wikidata.org/prop/direct/P47": "borders",

            # Other relations
            "http://www.wikidata.org/prop/direct/P1142": "has ideology",
            "http://www.wikidata.org/prop/direct/P361": "is part of",
            "http://www.wikidata.org/prop/direct/P737": "was influenced by",
            "http://www.wikidata.org/prop/direct/P127": "is owned by",
            "http://www.wikidata.org/prop/direct/P155": "follows",
            "http://www.wikidata.org/prop/direct/P156": "is followed by",
            "http://www.wikidata.org/prop/direct/P138": "is named after"
        }

    def get_system_prompt(self) -> str:
        """Return the system prompt."""
        return """You are a professional text processing assistant responsible for correcting errors in sentences and evaluating the importance of knowledge.

### Sentence Correction Rules:
1. Remove Wikipedia-specific markers: such as (disambiguation), (film), (band), etc. in parentheses
2. Ensure grammatical completeness: complete subject + predicate + object structure, avoid dangling 'and is', 'or', etc.
3. Fix obvious grammatical errors: tense consistency, singular/plural consistency, correct preposition usage
4. Clean up garbled text and special characters: such as â, €, ™ and other encoding issues
5. Ensure semantic fluency: if the original sentence cannot be fixed, reorganize the language to make it coherent
6. Do not add information not present in the original text, only correct errors

### Correction Examples:
- Error: 'Argument (disambiguation) is related to philosophy, logic, and is an.'
- Corrected: 'Argument is related to philosophy and logic.'
- Error: 'Beijing is a capital city and are.'
- Corrected: 'Beijing is a capital city.'

Importance scoring criteria (0.0-10.0, in increments of 0.1):

0.0 points - Completely incorrect or meaningless information
Examples: 'Apple is a metal', 'The sun rises from the west', '1+1=3'
0.5 points - Almost worthless information
Examples: 'Color of a fictional character's socks', 'Third line of dialogue from a game NPC', 'What someone had for breakfast yesterday'
1.0 points - Extremely rare, non-practical knowledge
Examples: 'Pet name of a minor novel character', 'Content of the 15th line in movie end credits', 'Nickname of website user ID 123456'
1.5 points - Very niche detailed information
Examples: 'Outfit of a passerby at minute 37 in a movie', 'Duration of background music in a game's hidden level', 'Content of the 3rd dialogue box on page 200 of a manga'
2.0 points - Details in niche professional fields
Examples: 'Color change of rare minerals at specific temperatures', 'Length of an insect's third antenna', 'Molecular formula of chemical reaction byproducts'
2.5 points - Technical details only professionals care about
Examples: 'Release date of a specific software library version', 'Time complexity coefficient of an algorithm', 'Thermal expansion coefficient of a material'
3.0 points - Professional knowledge in specific fields
Examples: 'Programming language syntax features', 'Gene sequence of a virus', 'Official system of ancient dynasties'
3.5 points - Professional information with some value
Examples: 'Specific system of a historical dynasty', 'Mechanism of action of a drug', 'Development time of a technical standard'
4.0 points - Meaningful knowledge known by few
Examples: 'Unique cultural traditions of a country', 'Important discoveries by a scientist', 'Detailed process of historical events'
4.5 points - Knowledge of interest to some groups
Examples: 'An author's creative background', 'Characteristics of an art movement', 'Detailed rules of a sport'
5.0 points - General knowledge of moderate importance
Examples: 'Famous attractions in cities', 'Development history of a company', 'Living habits of animals'
5.5 points - Fairly useful common sense
Examples: 'Plant growth environment', 'Healthy eating common sense', 'Basic first aid knowledge'
6.0 points - Knowledge most educated people should know
Examples: 'Shakespeare's representative works', 'Basic geometric theorems', 'Major world currencies'
6.5 points - Important cultural or scientific common sense
Examples: 'Basic structure of DNA', 'Newton's three laws', 'Major world religions'
7.0 points - Important foundational knowledge
Examples: 'Time period of World War II', 'Functions of major human organs', 'Basic mathematical operation rules'
7.5 points - Very important common sense
Examples: 'Light speed is the fastest in the universe', 'Earth is round', 'Basic principles of blood circulation'
8.0 points - Core knowledge in basic education
Examples: 'Earth orbits the sun', 'Principle of seasonal formation', 'Basic grammar rules'
8.5 points - Important knowledge everyone should master
Examples: 'Chemical formula of water H2O', 'Basic safety common sense', 'Simple mathematical calculations'
9.0 points - Extremely important basic concepts
Examples: 'Humans need oxygen to survive', 'Fire is hot', 'Basic directional concepts'
9.5 points - Core knowledge everyone must know
Examples: 'A day has 24 hours', 'A year has 12 months', 'Basic number concepts'
10.0 points - Most basic and important common sense
Examples: 'Humans need food and water to survive', 'The sky is blue', 'Stones are heavier than feathers'

When scoring, please consider:
1. Popularity of knowledge - how many people know it
2. Practical value - how useful it is in daily life
3. Educational importance - its position in the education system
4. Cultural significance - its importance for understanding the world

Please respond with valid JSON in the following format:
{
    "corrected_sentence": "corrected sentence here",
    "importance_score": evaluation score
}"""

    async def process_batch_with_vllm_api(self, sentences: List[str]) -> List[Dict[str, Any]]:
        """Process a batch of sentences through the LLM API (method name kept for vLLM compatibility)."""
        processed_sentences = []

        async with aiohttp.ClientSession() as session:
            # Limit the number of in-flight requests
            semaphore = asyncio.Semaphore(self.max_concurrent_requests)
            tasks = []

            # Split the sentences into small per-request batches
            for i in range(0, len(sentences), self.batch_size_per_request):
                batch_sentences = sentences[i:i + self.batch_size_per_request]
                task = self.process_single_batch_request(session, semaphore, batch_sentences, i)
                tasks.append(task)

            # Wait for all requests to finish
            batch_results = await asyncio.gather(*tasks, return_exceptions=True)

            # Collect the results
            for result in batch_results:
                if isinstance(result, Exception):
                    logger.error(f"Batch processing error: {result}")
                    continue
                if result:
                    processed_sentences.extend(result)

        return processed_sentences

    async def process_single_batch_request(self, session: aiohttp.ClientSession, semaphore: asyncio.Semaphore,
                                           sentences: List[str], batch_index: int) -> List[Dict[str, Any]]:
        """Send a single batched API request."""
        async with semaphore:
            for attempt in range(self.retry_attempts):
                try:
                    # One user message per sentence
                    messages = []
                    for sentence in sentences:
                        messages.append({
                            "role": "user",
                            "content": f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"
                        })

                    # Build the Ollama request payload
                    request_data = {
                        "model": self.model_name,
                        "messages": [
                            {"role": "system", "content": self.get_system_prompt()}
                        ] + messages,
                        "stream": False,
                        "options": {
                            "temperature": 0.2,
                            "num_predict": 500 * len(sentences)  # enough tokens for every sentence
                        },
                        "format": "json"  # Ollama JSON output mode
                    }

                    # Send the request to Ollama
                    async with session.post(
                        f'{self.ollama_base_url}/api/chat',
                        json=request_data,
                        timeout=aiohttp.ClientTimeout(total=self.request_timeout)
                    ) as response:
                        if response.status == 200:
                            result = await response.json()
                            return self.parse_ollama_response(result, sentences, batch_index)
                        else:
                            error_text = await response.text()
                            logger.warning(f"API request failed (batch {batch_index}, attempt {attempt + 1}/{self.retry_attempts}): {response.status} - {error_text}")
                            if attempt == self.retry_attempts - 1:  # last attempt
                                logger.error(f"Batch {batch_index} still failed after {self.retry_attempts} attempts")
                                self.failed_requests += len(sentences)
                                return self.create_default_responses(sentences)
                            else:
                                # Back off before retrying (exponential backoff)
                                await asyncio.sleep(2 ** attempt)
                                continue

                except asyncio.TimeoutError:
                    logger.warning(f"Batch {batch_index} request timed out (attempt {attempt + 1}/{self.retry_attempts})")
                    if attempt == self.retry_attempts - 1:
                        logger.error(f"Batch {batch_index} still timed out after {self.retry_attempts} attempts")
                        self.failed_requests += len(sentences)
                        return self.create_default_responses(sentences)
                    else:
                        await asyncio.sleep(2 ** attempt)
                        continue
                except Exception as e:
                    logger.warning(f"Error while processing batch {batch_index} (attempt {attempt + 1}/{self.retry_attempts}): {e}")
                    if attempt == self.retry_attempts - 1:
                        logger.error(f"Batch {batch_index} still failed after {self.retry_attempts} attempts")
                        self.failed_requests += len(sentences)
                        return self.create_default_responses(sentences)
                    else:
                        await asyncio.sleep(2 ** attempt)
                        continue

            # All retries exhausted
            return self.create_default_responses(sentences)

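    # Reference only (not executed by this script): the payload built above
    # corresponds roughly to this manual check against a local Ollama server;
    # the model name is the one configured in __init__ and may differ.
    #
    #   curl http://localhost:11434/api/chat -d '{
    #     "model": "gemma3:latest",
    #     "messages": [{"role": "user", "content": "hello"}],
    #     "stream": false,
    #     "format": "json"
    #   }'
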
    def parse_ollama_response(self, response: Dict[str, Any], original_sentences: List[str], batch_index: int) -> List[Dict[str, Any]]:
        """Parse an Ollama chat response."""
        processed_sentences = []

        try:
            # Ollama response format: {"message": {"role": ..., "content": ...}, ...}
            message = response.get('message', {})
            content = message.get('content', '')

            if not content:
                logger.warning(f"Batch {batch_index} returned no content")
                return self.create_default_responses(original_sentences)

            # Try to parse the JSON response
            try:
                # Case 1: a single JSON object
                if content.strip().startswith('{') and content.strip().endswith('}'):
                    response_data = json.loads(content)
                    processed_sentence = ProcessedSentence(
                        corrected_sentence=response_data.get('corrected_sentence', original_sentences[0] if original_sentences else ""),
                        importance_score=float(response_data.get('importance_score', 5.0))
                    )
                    processed_sentences.append({
                        "original_sentence": original_sentences[0] if original_sentences else "",
                        "corrected_sentence": processed_sentence.corrected_sentence,
                        "importance_score": processed_sentence.importance_score
                    })
                    self.successful_requests += 1

                    # If several sentences were sent but only one result came back,
                    # fall back to default responses for the remaining sentences
                    for i in range(1, len(original_sentences)):
                        processed_sentences.append({
                            "original_sentence": original_sentences[i],
                            "corrected_sentence": original_sentences[i],
                            "importance_score": 5.0
                        })
                        self.failed_requests += 1
                else:
                    # Case 2: several JSON objects, one per line
                    json_objects = []
                    for line in content.split('\n'):
                        line = line.strip()
                        if line.startswith('{') and line.endswith('}'):
                            try:
                                json_objects.append(json.loads(line))
                            except json.JSONDecodeError:
                                continue

                    if json_objects:
                        for i, (sentence, json_obj) in enumerate(zip(original_sentences, json_objects)):
                            try:
                                processed_sentence = ProcessedSentence(
                                    corrected_sentence=json_obj.get('corrected_sentence', sentence),
                                    importance_score=float(json_obj.get('importance_score', 5.0))
                                )
                                processed_sentences.append({
                                    "original_sentence": sentence,
                                    "corrected_sentence": processed_sentence.corrected_sentence,
                                    "importance_score": processed_sentence.importance_score
                                })
                                self.successful_requests += 1
                            except Exception as e:
                                logger.warning(f"Failed to parse JSON object: {e}")
                                processed_sentences.append({
                                    "original_sentence": sentence,
                                    "corrected_sentence": sentence,
                                    "importance_score": 5.0
                                })
                                self.failed_requests += 1

                        # Default responses for any remaining sentences
                        for i in range(len(json_objects), len(original_sentences)):
                            processed_sentences.append({
                                "original_sentence": original_sentences[i],
                                "corrected_sentence": original_sentences[i],
                                "importance_score": 5.0
                            })
                            self.failed_requests += 1
                    else:
                        logger.warning(f"Batch {batch_index}: could not parse JSON response: {content}")
                        return self.create_default_responses(original_sentences)

            except (json.JSONDecodeError, ValueError) as e:
                logger.warning(f"Batch {batch_index}: failed to parse response JSON: {e}")
                logger.warning(f"Raw content: {content}")
                return self.create_default_responses(original_sentences)

        except Exception as e:
            logger.error(f"Error while parsing batch {batch_index} response: {e}")
            return self.create_default_responses(original_sentences)

        return processed_sentences

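    # Illustrative summary of the two reply shapes the parser above accepts:
    #   1. a single object:     {"corrected_sentence": "...", "importance_score": 7.0}
    #   2. one object per line: {"corrected_sentence": "...", "importance_score": 7.0}
    #                           {"corrected_sentence": "...", "importance_score": 4.5}
    # Anything else falls back to create_default_responses().
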
    def create_default_responses(self, sentences: List[str]) -> List[Dict[str, Any]]:
        """Create default responses for failed requests (sentence kept as-is, neutral score)."""
        default_responses = []
        for sentence in sentences:
            default_responses.append({
                "original_sentence": sentence,
                "corrected_sentence": sentence,
                "importance_score": 5.0
            })
        return default_responses

    async def process_sentences_with_vllm_api(self, sentences: List[str]) -> List[Dict[str, Any]]:
        """Process sentences through the Ollama API."""
        logger.info(f"Starting Ollama API processing of {len(sentences)} sentences...")
        print(f"Starting Ollama API processing of {len(sentences)} sentences...")

        start_time = time.time()
        total_sentences = len(sentences)
        total_processed_count = 0

        # Check that the Ollama service is up
        if not self.check_ollama_status():
            logger.error("Ollama service is not healthy, aborting")
            print("Error: Ollama service is not healthy, please check that it is running")
            return []

        # Process in large batches so that checkpoints can be saved
        large_batch_size = 1000  # save a checkpoint every 1000 sentences
        all_processed_sentences = []

        for large_batch_start in range(0, total_sentences, large_batch_size):
            large_batch_end = min(large_batch_start + large_batch_size, total_sentences)
            large_batch_sentences = sentences[large_batch_start:large_batch_end]
            large_batch_number = large_batch_start // large_batch_size + 1

            logger.info(f"=== Processing large batch {large_batch_number} ({large_batch_start + 1}-{large_batch_end}/{total_sentences}) ===")
            print(f"\n=== Processing large batch {large_batch_number} ({large_batch_start + 1}-{large_batch_end}/{total_sentences}) ===")

            large_batch_start_time = time.time()

            # Process the current large batch
            batch_processed = await self.process_batch_with_vllm_api(large_batch_sentences)
            all_processed_sentences.extend(batch_processed)
            total_processed_count += len(batch_processed)

            # Save a checkpoint for this large batch
            checkpoint_filename = self.save_batch_checkpoint(batch_processed, large_batch_number, total_processed_count)

            # Report progress
            large_batch_time = time.time() - large_batch_start_time
            elapsed_time = time.time() - start_time

            logger.info(f"Large batch {large_batch_number} finished!")
            logger.info(f"  - this batch: {len(batch_processed)} sentences in {large_batch_time / 60:.1f} minutes")
            logger.info(f"  - overall progress: {total_processed_count}/{total_sentences} ({total_processed_count / total_sentences * 100:.1f}%)")
            logger.info(f"  - elapsed time: {elapsed_time / 60:.1f} minutes")
            logger.info(f"  - batch checkpoint saved: {checkpoint_filename}")
            print(f"Large batch {large_batch_number} finished!")
            print(f"  - this batch: {len(batch_processed)} sentences in {large_batch_time / 60:.1f} minutes")
            print(f"  - overall progress: {total_processed_count}/{total_sentences} ({total_processed_count / total_sentences * 100:.1f}%)")
            print(f"  - elapsed time: {elapsed_time / 60:.1f} minutes")
            print(f"  - batch checkpoint saved: {checkpoint_filename}")

            if large_batch_end < total_sentences:
                remaining_sentences = total_sentences - total_processed_count
                avg_time_per_sentence = elapsed_time / total_processed_count
                estimated_remaining_time = avg_time_per_sentence * remaining_sentences
                logger.info(f"  - estimated time remaining: {estimated_remaining_time / 60:.1f} minutes")
                print(f"  - estimated time remaining: {estimated_remaining_time / 60:.1f} minutes")

        # Final statistics
        total_time = time.time() - start_time
        logger.info("=== All processing finished! ===")
        logger.info(f"  - total successful: {self.successful_requests}")
        logger.info(f"  - total failed: {self.failed_requests}")
        logger.info(f"  - total time: {total_time / 60:.1f} minutes")
        logger.info(f"  - average speed: {total_processed_count / total_time:.2f} sentences/second")
        print("\n=== All processing finished! ===")
        print(f"  - total successful: {self.successful_requests}")
        print(f"  - total failed: {self.failed_requests}")
        print(f"  - total time: {total_time / 60:.1f} minutes")
        print(f"  - average speed: {total_processed_count / total_time:.2f} sentences/second")

        return all_processed_sentences

    def check_ollama_status(self) -> bool:
        """Check whether the Ollama service is running and the target model is available."""
        try:
            # Check that the Ollama API responds
            response = requests.get(f'{self.ollama_base_url}/api/tags', timeout=10)

            if response.status_code == 200:
                models = response.json()
                model_names = [model.get('name', 'unknown') for model in models.get('models', [])]
                logger.info(f"Ollama service is healthy, available models: {model_names}")

                # Check that the target model is available
                if self.model_name in model_names:
                    logger.info(f"Target model {self.model_name} is available")
                    return True
                else:
                    logger.warning(f"Target model {self.model_name} is not in the available model list: {model_names}")
                    logger.info("Trying to pull the model...")

                    # Try to pull the model
                    try:
                        pull_response = requests.post(
                            f'{self.ollama_base_url}/api/pull',
                            json={"name": self.model_name},
                            timeout=300  # 5-minute timeout
                        )
                        if pull_response.status_code == 200:
                            logger.info(f"Successfully pulled model {self.model_name}")
                            return True
                        else:
                            logger.error(f"Failed to pull model: {pull_response.status_code}")
                            return False
                    except Exception as e:
                        logger.error(f"Error while pulling model: {e}")
                        return False
            else:
                logger.error(f"Unexpected Ollama API response, status code: {response.status_code}")
                return False

        except requests.exceptions.RequestException as e:
            logger.error(f"Cannot connect to the Ollama API: {e}")
            return False
        except Exception as e:
            logger.error(f"Error while checking Ollama status: {e}")
            return False

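    # Manual sanity checks (illustrative, assuming a default local Ollama install;
    # the model tag is the one configured in __init__):
    #   curl http://localhost:11434/api/tags    # list installed models
    #   ollama pull gemma3:latest               # pull the model used above
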
    def clean_text(self, text: str) -> str:
        """Clean text and normalize special characters."""
        if not text:
            return ""

        # Normalize common Unicode characters
        text = text.replace("–", "-")  # en dash
        text = text.replace("—", "-")  # em dash
        text = text.replace("’", "'")  # right single quotation mark
        text = text.replace("‘", "'")  # left single quotation mark
        text = text.replace("“", '"')  # left double quotation mark
        text = text.replace("”", '"')  # right double quotation mark

        # Handle possible encoding issues (a UTF-8 round trip is a no-op for valid strings)
        try:
            text = text.encode('utf-8').decode('utf-8')
        except Exception:
            pass

        # Collapse repeated whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Strip surrounding quotes
        text = text.strip('"\'')

        return text

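    # Example (illustrative): clean_text('  "Café – a history"  ') -> 'Café - a history'
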
    def parse_large_json_file(self, file_path: str) -> List[Dict]:
        """Parse a large JSON file, tolerating common format issues."""
        documents = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read().strip()

            # Try different parsing strategies
            if content.startswith('[') and content.endswith(']'):
                # Standard JSON array
                documents = json.loads(content)
            else:
                # Possibly concatenated JSON objects; try splitting on '}{'
                if '}{"' in content:
                    json_strings = content.split('}{')
                    json_strings[0] += '}'                     # first object
                    json_strings[-1] = '{' + json_strings[-1]  # last object
                    for i in range(1, len(json_strings) - 1):
                        json_strings[i] = '{' + json_strings[i] + '}'

                    for json_str in json_strings:
                        try:
                            doc = json.loads(json_str)
                            documents.append(doc)
                        except json.JSONDecodeError:
                            continue
                else:
                    # Fall back to parsing the content as a single JSON object
                    try:
                        documents = [json.loads(content)]
                    except json.JSONDecodeError:
                        pass

        except Exception as e:
            print(f"Error parsing {file_path}: {e}")

        return documents

    def extract_sentences_from_document(self, doc: Dict[str, Any]) -> List[str]:
        """Extract sentences from a single TREx document."""
        sentences = []

        title = self.clean_text(doc.get('title', ''))
        text = self.clean_text(doc.get('text', ''))
        entities = doc.get('entities', [])
        triples = doc.get('triples', [])

        # Convert the explicit triples
        for triple in triples:
            sentence = self.triple_to_sentence(triple)
            if sentence:
                sentences.append(sentence)

        # If the triples did not yield enough sentences, generate basic ones
        # from the title and entities
        if title and text and len(sentences) < 5:
            entity_names = []
            for entity in entities[:15]:
                entity_name = self.clean_text(entity.get('surfaceform', ''))
                if entity_name and len(entity_name) > 2:
                    entity_names.append(entity_name)

            # Build a simple descriptive sentence
            if entity_names:
                important_entities = []
                title_lower = title.lower()
                for entity in entity_names:
                    if (entity.lower() != title_lower and
                            entity not in important_entities and
                            not any(t.lower() in entity.lower() for t in title_lower.split()[:2])):
                        important_entities.append(entity)
                    if len(important_entities) >= 6:
                        break

                if important_entities and len(sentences) < 3:
                    entities_str = ', '.join(important_entities[:3])
                    sentences.append(f"{title} is related to {entities_str}.")

        return sentences

    def triple_to_sentence(self, triple: Dict[str, Any]) -> str:
        """Convert a Wikidata triple into a natural-language sentence."""
        try:
            subject = triple.get('subject', {})
            predicate = triple.get('predicate', {})
            obj = triple.get('object', {})

            subject_name = self.clean_text(subject.get('surfaceform', ''))
            object_name = self.clean_text(obj.get('surfaceform', ''))
            predicate_uri = predicate.get('uri', '')

            # Require a valid subject and object
            if not subject_name or not object_name:
                return ""

            # Skip subjects/objects that are too short to be meaningful
            if len(subject_name) <= 2 or len(object_name) <= 2:
                return ""

            # Map the predicate URI to a relation phrase
            relation_text = self.property_mappings.get(predicate_uri, "is related to")

            # Skip degenerate triples where subject and object coincide
            if subject_name.lower() == object_name.lower():
                return ""

            return f"{subject_name} {relation_text} {object_name}."

        except Exception as e:
            print(f"Error converting triple to sentence: {e}")
            return ""

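    # Example (illustrative): a triple with predicate URI
    # http://www.wikidata.org/prop/direct/P19 and surface forms
    # "Marie Curie" / "Warsaw" becomes "Marie Curie was born in Warsaw."
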
    def save_batch_checkpoint(self, processed_sentences: List[Dict[str, Any]], batch_number: int, total_processed_count: int) -> str:
        """Save a checkpoint file for the current large batch."""
        # Build the checkpoint file name, always inside the output directory
        base_name = os.path.splitext(os.path.basename(self.output_file))[0]
        checkpoint_filename = os.path.join('output', f"{base_name}_batch_{batch_number}.json")

        # Write the checkpoint
        with open(checkpoint_filename, 'w', encoding='utf-8') as f:
            json.dump({
                "metadata": {
                    "batch_number": batch_number,
                    "batch_size": len(processed_sentences),
                    "total_processed_count": total_processed_count,
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
                },
                "sentences": processed_sentences
            }, f, ensure_ascii=False, indent=2)

        return checkpoint_filename

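    # Resulting checkpoint layout (illustrative), e.g. output/<base_name>_batch_1.json:
    #   {
    #     "metadata": {"batch_number": 1, "batch_size": 1000, "total_processed_count": 1000, "timestamp": "..."},
    #     "sentences": [{"original_sentence": "...", "corrected_sentence": "...", "importance_score": 7.0}, ...]
    #   }
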
    async def process_files(self) -> List[Dict[str, Any]]:
        """Process all TREx files: extract, deduplicate and save sentences."""
        json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))

        if not json_files:
            print(f"No JSON files found in {self.input_dir}")
            return []

        # Sort the files for a deterministic processing order
        json_files.sort()

        if self.max_files:
            json_files = json_files[:self.max_files]

        print(f"Found {len(json_files)} JSON files to process")

        all_sentences = []
        for i, file_path in enumerate(json_files):
            print(f"Processing file {i + 1}/{len(json_files)}: {os.path.basename(file_path)}")
            documents = self.parse_large_json_file(file_path)
            print(f"  Parsed {len(documents)} documents")

            for doc in documents:
                sentences = self.extract_sentences_from_document(doc)
                all_sentences.extend(sentences)

            print(f"  Generated {len(all_sentences)} total raw sentences so far")

        print(f"Extracted {len(all_sentences)} raw sentences in total")

        # Deduplicate
        unique_sentences = []
        seen = set()
        for sentence in all_sentences:
            sentence = sentence.strip()
            if sentence and sentence not in seen and len(sentence) > 10:
                unique_sentences.append(sentence)
                seen.add(sentence)

        print(f"{len(unique_sentences)} sentences remain after deduplication")

        # Save the raw sentences to a JSON file
        sentences_data = {
            "metadata": {
                "total_sentences": len(unique_sentences),
                "extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "source_files": len(json_files),
                "max_files_limit": self.max_files
            },
            "sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
        }

        with open(self.sentences_json, 'w', encoding='utf-8') as f:
            json.dump(sentences_data, f, ensure_ascii=False, indent=2)

        print(f"Sentence extraction finished! Saved to: {self.sentences_json}")
        print(f"Total sentences: {len(unique_sentences)}")

        return unique_sentences

    def save_sentences(self, processed_sentences: List[Dict[str, Any]]):
        """Save the processed sentences to the output files."""
        # Make sure the output directory exists
        os.makedirs('output', exist_ok=True)

        # JSON output with the full information
        json_output_file = self.output_file.replace('.txt', '.json')
        with open(json_output_file, 'w', encoding='utf-8') as f:
            json.dump(processed_sentences, f, ensure_ascii=False, indent=2)

        # Plain-text output (corrected sentences only)
        with open(self.output_file, 'w', encoding='utf-8') as f:
            for item in processed_sentences:
                f.write(item['corrected_sentence'] + '\n')

        # Output sorted by importance score
        importance_sorted = sorted(processed_sentences, key=lambda x: x['importance_score'], reverse=True)
        importance_file = self.output_file.replace('.txt', '_sorted_by_importance.txt')
        with open(importance_file, 'w', encoding='utf-8') as f:
            for item in importance_sorted:
                f.write(f"[{item['importance_score']:.1f}] {item['corrected_sentence']}\n")

        print(f"Saved {len(processed_sentences)} processed sentences:")
        print(f"  - JSON format: {json_output_file}")
        print(f"  - text format: {self.output_file}")
        print(f"  - sorted by importance: {importance_file}")

        # Statistics
        scores = [item['importance_score'] for item in processed_sentences]
        avg_score = sum(scores) / len(scores) if scores else 0
        print(f"  - average importance score: {avg_score:.2f}")
        print(f"  - highest score: {max(scores):.1f}")
        print(f"  - lowest score: {min(scores):.1f}")

    async def run(self):
        """Run the processing pipeline."""
        print("Starting enhanced TREx to sentences conversion...")
        processed_sentences = await self.process_files()
        self.save_sentences(processed_sentences)
        print("Enhanced conversion completed!")

    def find_latest_checkpoint(self) -> Union[tuple, None]:
        """Find the most recent (old-format) checkpoint file."""
        base_name = os.path.splitext(os.path.basename(self.output_file))[0]
        pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
        checkpoint_files = glob.glob(pattern)

        if not checkpoint_files:
            return None

        # Sort by checkpoint number and pick the newest
        latest_file = None
        latest_count = 0

        for file in checkpoint_files:
            try:
                # Extract the number from the file name
                match = re.search(r'checkpoint_(\d+)\.json$', file)
                if match:
                    count = int(match.group(1))
                    if count > latest_count:
                        latest_count = count
                        latest_file = file
            except Exception:
                continue

        if latest_file:
            return latest_file, latest_count
        else:
            return None

    def load_checkpoint(self, checkpoint_file: str) -> List[Dict[str, Any]]:
        """Load already-processed sentences from a checkpoint file."""
        try:
            with open(checkpoint_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
                if 'sentences' in data:
                    return data['sentences']
                else:
                    # Old-format checkpoint file
                    return data
        except Exception as e:
            print(f"Failed to load checkpoint file: {e}")
            return []

    def get_processed_sentences_from_checkpoints(self) -> Set[str]:
        """Collect the set of already-processed sentences from checkpoint files."""
        if not self.output_file:
            return set()

        processed_sentences = set()
        base_name = os.path.splitext(os.path.basename(self.output_file))[0]

        # First look for new-format batch files
        batch_pattern = os.path.join('output', f"{base_name}_batch_*.json")
        batch_files = glob.glob(batch_pattern)

        if batch_files:
            print(f"Found {len(batch_files)} batch checkpoint files")
            batch_files.sort()  # process them in order

            for batch_file in batch_files:
                try:
                    with open(batch_file, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                        sentences_data = data.get('sentences', [])
                        for item in sentences_data:
                            original_sentence = item.get('original_sentence', '')
                            if original_sentence:
                                processed_sentences.add(original_sentence)
                        batch_number = data.get('metadata', {}).get('batch_number', 'unknown')
                        print(f"  - loaded {len(sentences_data)} sentences from batch {batch_number}")
                except Exception as e:
                    print(f"Failed to read batch file {batch_file}: {e}")
                    continue

            print(f"Loaded {len(processed_sentences)} already-processed sentences from batch files")
            logger.info(f"Loaded {len(processed_sentences)} already-processed sentences from batch files")
            return processed_sentences

        # No batch files found; fall back to old-format checkpoint files
        old_pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
        checkpoint_files = glob.glob(old_pattern)

        if not checkpoint_files:
            print("No checkpoint files found, starting from scratch")
            return set()

        # Find the most recent checkpoint file
        latest_file = None
        latest_count = 0
        for file in checkpoint_files:
            try:
                match = re.search(r'checkpoint_(\d+)\.json$', file)
                if match:
                    count = int(match.group(1))
                    if count > latest_count:
                        latest_count = count
                        latest_file = file
            except Exception:
                continue

        if latest_file:
            print(f"Found old-format checkpoint: {latest_file} ({latest_count} records)")
            logger.info(f"Found old-format checkpoint: {latest_file} ({latest_count} records)")
            try:
                with open(latest_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                    sentences_data = data.get('sentences', [])
                    for item in sentences_data:
                        original_sentence = item.get('original_sentence', '')
                        if original_sentence:
                            processed_sentences.add(original_sentence)
                print(f"Loaded {len(processed_sentences)} already-processed sentences from the old-format checkpoint")
                logger.info(f"Loaded {len(processed_sentences)} already-processed sentences from the old-format checkpoint")
            except Exception as e:
                print(f"Failed to read checkpoint file: {e}")
                return set()

        return processed_sentences

    async def process_with_llm(self):
        """Step 2 (compatibility wrapper): read sentences from JSON and run LLM processing."""
        await self.process_with_vllm_api()

    async def process_with_vllm_api(self):
        """Step 2: read sentences from the JSON file and run LLM processing."""
        if not self.enable_llm_processing:
            print("Error: LLM processing is disabled!")
            return

        if not self.output_file:
            print("Error: output_file is required for LLM processing!")
            return

        print("=== Step 2: LLM processing ===")

        # Read the extracted-sentences JSON file
        if not os.path.exists(self.sentences_json):
            print(f"Error: Sentences file {self.sentences_json} not found!")
            print("Please run step 1 (sentence extraction) first")
            return

        print(f"Reading sentences file: {self.sentences_json}")
        try:
            with open(self.sentences_json, 'r', encoding='utf-8') as f:
                data = json.load(f)
            all_sentences = [item["sentence"] for item in data.get("sentences", [])]
            print(f"Read {len(all_sentences)} sentences from the file")
        except Exception as e:
            print(f"Failed to read the sentences file: {e}")
            return

        # Sentences that were already processed in earlier runs
        processed_sentences_set = self.get_processed_sentences_from_checkpoints()

        # Keep only the unprocessed sentences
        unprocessed_sentences = []
        for sentence in all_sentences:
            if sentence not in processed_sentences_set:
                unprocessed_sentences.append(sentence)

        print(f"Sentences to process: {len(unprocessed_sentences)} (skipping {len(processed_sentences_set)} already processed)")
        logger.info(f"Sentences to process: {len(unprocessed_sentences)} (skipping {len(processed_sentences_set)} already processed)")

        if not unprocessed_sentences:
            print("All sentences have already been processed!")
            # If checkpoints exist, build the final output directly from the latest one
            if processed_sentences_set:
                latest_checkpoint = self.find_latest_checkpoint()
                if latest_checkpoint:
                    checkpoint_file, _ = latest_checkpoint
                    processed_data = self.load_checkpoint(checkpoint_file)
                    self.save_sentences(processed_data)
                    print("Final output generated from the checkpoint")
            return

        # Process the remaining sentences
        print("Starting LLM processing...")

        # Results are also persisted incrementally in per-batch checkpoint files
        await self.process_sentences_with_vllm_api(unprocessed_sentences)

        # After processing, merge all batch checkpoints into the final output
        print("Merging all batch checkpoints into the final output...")
        all_processed_sentences = self.merge_all_batch_checkpoints()

        if all_processed_sentences:
            # Save the final result
            self.save_sentences(all_processed_sentences)
            print("LLM processing completed!")
        else:
            print("Warning: no processing results were found")

    def merge_all_batch_checkpoints(self) -> List[Dict[str, Any]]:
        """Merge all batch checkpoint files into a single list."""
        if not self.output_file:
            return []

        base_name = os.path.splitext(os.path.basename(self.output_file))[0]

        # Find all batch checkpoint files
        batch_pattern = os.path.join('output', f"{base_name}_batch_*.json")
        batch_files = glob.glob(batch_pattern)

        if not batch_files:
            # No batch files; fall back to old-format checkpoint files
            old_pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
            old_files = glob.glob(old_pattern)
            if old_files:
                print("Found old-format checkpoint files, reading the latest one...")
                latest_checkpoint = self.find_latest_checkpoint()
                if latest_checkpoint:
                    checkpoint_file, _ = latest_checkpoint
                    return self.load_checkpoint(checkpoint_file)
            return []

        print(f"Found {len(batch_files)} batch checkpoint files")

        all_sentences = []
        batch_files.sort()  # merge them in order

        for batch_file in batch_files:
            try:
                with open(batch_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                    batch_sentences = data.get('sentences', [])
                    all_sentences.extend(batch_sentences)
                    batch_number = data.get('metadata', {}).get('batch_number', 'unknown')
                    print(f"  - batch {batch_number}: {len(batch_sentences)} sentences")
            except Exception as e:
                print(f"Failed to read batch file {batch_file}: {e}")
                continue

        print(f"Merged {len(all_sentences)} sentences in total")
        return all_sentences

    def extract_sentences(self):
        """Step 1: extract sentences from the TREx dataset and save them as JSON."""
        if not self.input_dir:
            print("Error: input_dir is required for sentence extraction!")
            return

        print("=== Step 1: sentence extraction ===")
        print("Extracting sentences from the TREx dataset...")

        json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))

        if not json_files:
            print(f"No JSON files found in {self.input_dir}")
            return

        # Sort the files for a deterministic processing order
        json_files.sort()

        if self.max_files:
            json_files = json_files[:self.max_files]

        print(f"Found {len(json_files)} JSON files to process")

        all_sentences = []
        for i, file_path in enumerate(json_files):
            print(f"Processing file {i + 1}/{len(json_files)}: {os.path.basename(file_path)}")
            documents = self.parse_large_json_file(file_path)
            print(f"  Parsed {len(documents)} documents")

            for doc in documents:
                sentences = self.extract_sentences_from_document(doc)
                all_sentences.extend(sentences)

            print(f"  Generated {len(all_sentences)} total raw sentences so far")

        print(f"Extracted {len(all_sentences)} raw sentences in total")

        # Deduplicate
        unique_sentences = []
        seen = set()
        for sentence in all_sentences:
            sentence = sentence.strip()
            if sentence and sentence not in seen and len(sentence) > 10:
                unique_sentences.append(sentence)
                seen.add(sentence)

        print(f"{len(unique_sentences)} sentences remain after deduplication")

        # Save the raw sentences to a JSON file
        sentences_data = {
            "metadata": {
                "total_sentences": len(unique_sentences),
                "extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "source_files": len(json_files),
                "max_files_limit": self.max_files
            },
            "sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
        }

        with open(self.sentences_json, 'w', encoding='utf-8') as f:
            json.dump(sentences_data, f, ensure_ascii=False, indent=2)

        print(f"Sentence extraction finished! Saved to: {self.sentences_json}")
        print(f"Total sentences: {len(unique_sentences)}")

        return unique_sentences

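# Layout of the extracted-sentences file written above (illustrative values):
#   {
#     "metadata": {"total_sentences": 1234, "extraction_timestamp": "...", "source_files": 10, "max_files_limit": null},
#     "sentences": [{"sentence": "Marie Curie was born in Warsaw.", "processed": false}, ...]
#   }
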
def main():
    """Command-line entry point."""
    import argparse

    parser = argparse.ArgumentParser(description='Convert the TREx dataset to enhanced sentences with LLM post-processing')

    # Run mode
    parser.add_argument('--step', choices=['extract', 'llm', 'all'], default='llm',
                        help='Step to run: extract=sentence extraction only, llm=LLM processing only, all=full pipeline')

    # File path arguments
    parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files')
    parser.add_argument('--sentences_json', default='extracted_sentences.json', help='JSON file for extracted sentences (will be saved in output/)')
    parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path (will be saved in output/)')

    # Processing arguments
    parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)')
    parser.add_argument('--no_llm', action='store_true', help='Disable LLM processing (basic mode)')

    args = parser.parse_args()

    # Validate arguments for the selected step
    if args.step in ['extract', 'all']:
        if not os.path.exists(args.input_dir):
            print(f"Error: Input directory {args.input_dir} does not exist!")
            return

    if args.step in ['llm', 'all']:
        if args.no_llm:
            print("Error: Cannot run the LLM step with the --no_llm flag!")
            return

    # Create the processor
    processor = EnhancedTRExProcessor(
        input_dir=args.input_dir,
        sentences_json=args.sentences_json,
        output_file=args.output_file,
        max_files=args.max_files,
        enable_llm_processing=not args.no_llm
    )

    # Run the selected step
    if args.step == 'extract':
        print("=== Mode: sentence extraction only ===")
        processor.extract_sentences()

    elif args.step == 'llm':
        print("=== Mode: LLM processing only ===")
        asyncio.run(processor.process_with_vllm_api())

    elif args.step == 'all':
        print("=== Mode: full pipeline ===")

        # Step 1: extract sentences
        print("\n--- Step 1: sentence extraction ---")
        sentences = processor.extract_sentences()

        if not sentences:
            print("Sentence extraction failed, exiting")
            return

        if args.no_llm:
            print("LLM processing is disabled, stopping here")
            return

        # Step 2: LLM processing
        print("\n--- Step 2: LLM processing ---")
        asyncio.run(processor.process_with_vllm_api())


if __name__ == "__main__":
    main()
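
# Example invocations (illustrative; "trex_processor.py" is a placeholder for
# whatever name this file is saved under, and the defaults are those defined in main()):
#   python trex_processor.py --step extract --input_dir dataset/TREx --max_files 5
#   python trex_processor.py --step llm --output_file trex_sentences_enhanced.txt
#   python trex_processor.py --step all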