#!/usr/bin/env python3
"""
Enhanced preprocessing script for the TREx dataset.

Uses the agno framework with a local Ollama model (the agent below is configured
with gemma3:latest) for sentence post-processing and importance scoring.

Two independent steps are supported:
1. Sentence extraction: extract sentences from the TREx dataset and save them as JSON.
2. LLM processing: read the JSON file and run LLM post-processing and importance scoring.
"""
import json
import os
import glob
from typing import List, Dict, Any, Union, Set
import re
import asyncio
import time
import logging
from datetime import datetime
import subprocess
import requests

from pydantic import BaseModel, Field
from agno.agent import Agent
from agno.models.ollama import Ollama

# Logging setup
def setup_logging():
    """Configure the logging system."""
    # Make sure the logs directory exists
    os.makedirs('logs', exist_ok=True)

    # Build a timestamped log file name
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = f'logs/trex_processor_{timestamp}.log'

    # Log format
    log_format = '%(asctime)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s'

    # Configure the root logger
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()  # also print to the console
        ]
    )

    logger = logging.getLogger(__name__)
    logger.info(f"Logging initialized, log file: {log_file}")
    return logger


# Global logger
logger = setup_logging()

class ProcessedSentence(BaseModel):
    """Structured result for a processed sentence."""
    corrected_sentence: str = Field(
        ...,
        description="The corrected sentence. Only fix grammatical errors, garbled text and awkward wording; do not polish beyond that."
    )
    importance_score: float = Field(
        ...,
        description="Importance score from 0.0 to 10.0 in steps of 0.1, judging how common and important this knowledge is in the real world.",
        ge=0.0,
        le=10.0
    )
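
# For reference, a response that validates against ProcessedSentence would look roughly
# like this (the sentence and score here are made-up illustrations):
#
#   ProcessedSentence(corrected_sentence="Beijing is a capital city.", importance_score=7.0)
#
# which serializes to:
#   {"corrected_sentence": "Beijing is a capital city.", "importance_score": 7.0}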


class EnhancedTRExProcessor:
    def __init__(self, input_dir: str = None, output_file: str = None, max_files: int = None,
                 sentences_json: str = None, enable_llm_processing: bool = True):
        self.input_dir = input_dir

        # Make sure the output directory exists
        os.makedirs('output', exist_ok=True)

        # Keep all output files inside the output directory
        if output_file:
            if not output_file.startswith('output/'):
                self.output_file = os.path.join('output', output_file)
            else:
                self.output_file = output_file
        else:
            self.output_file = None

        if sentences_json:
            if not sentences_json.startswith('output/'):
                self.sentences_json = os.path.join('output', sentences_json)
            else:
                self.sentences_json = sentences_json
        else:
            self.sentences_json = "output/extracted_sentences.json"

        self.max_files = max_files
        self.enable_llm_processing = enable_llm_processing

        # LLM processing configuration
        self.llm_timeout = 60         # per-request timeout in seconds
        self.max_concurrent = 8       # maximum number of concurrent LLM requests
        self.max_retries = 2          # limited retries to avoid very long waits
        self.heartbeat_interval = 30  # heartbeat check interval in seconds

        # Statistics
        self.total_requests = 0
        self.successful_requests = 0
        self.failed_requests = 0
        self.timeout_requests = 0
        self.last_successful_time = time.time()
        self.last_activity_time = time.time()  # time of the most recent activity

        # Initialize the agno agent (only when LLM processing is enabled)
        if self.enable_llm_processing:
            self.setup_agent()

        logger.info(f"Processor initialized - concurrency: {self.max_concurrent}, timeout: {self.llm_timeout}s")

        # Extended Wikidata property mappings
        self.property_mappings = {
            # Basic relations
            "http://www.wikidata.org/prop/direct/P31": "is a",
            "http://www.wikidata.org/prop/direct/P279": "is a type of",

            # People
            "http://www.wikidata.org/prop/direct/P106": "works as",
            "http://www.wikidata.org/prop/direct/P27": "is a citizen of",
            "http://www.wikidata.org/prop/direct/P19": "was born in",
            "http://www.wikidata.org/prop/direct/P20": "died in",
            "http://www.wikidata.org/prop/direct/P569": "was born on",
            "http://www.wikidata.org/prop/direct/P570": "died on",
            "http://www.wikidata.org/prop/direct/P22": "has father",
            "http://www.wikidata.org/prop/direct/P25": "has mother",
            "http://www.wikidata.org/prop/direct/P26": "is married to",

            # Organizations
            "http://www.wikidata.org/prop/direct/P102": "is a member of",
            "http://www.wikidata.org/prop/direct/P108": "works for",
            "http://www.wikidata.org/prop/direct/P159": "has headquarters in",
            "http://www.wikidata.org/prop/direct/P112": "was founded by",
            "http://www.wikidata.org/prop/direct/P571": "was founded in",
            "http://www.wikidata.org/prop/direct/P169": "has CEO",

            # Geography
            "http://www.wikidata.org/prop/direct/P17": "is located in",
            "http://www.wikidata.org/prop/direct/P131": "is located in",
            "http://www.wikidata.org/prop/direct/P36": "has capital",
            "http://www.wikidata.org/prop/direct/P47": "borders",

            # Other relations
            "http://www.wikidata.org/prop/direct/P1142": "has ideology",
            "http://www.wikidata.org/prop/direct/P361": "is part of",
            "http://www.wikidata.org/prop/direct/P737": "was influenced by",
            "http://www.wikidata.org/prop/direct/P127": "is owned by",
            "http://www.wikidata.org/prop/direct/P155": "follows",
            "http://www.wikidata.org/prop/direct/P156": "is followed by",
            "http://www.wikidata.org/prop/direct/P138": "is named after"
        }

    def setup_agent(self):
        """Set up the agno agent."""
        try:
            self.agent = Agent(
                model=Ollama(
                    id="gemma3:latest",
                    # Use options to set temperature and other sampling parameters
                    options={
                        "temperature": 0.2,
                        "top_p": 0.8,
                        "top_k": 20,
                        "num_ctx": 4096,
                    }
                ),
                response_model=ProcessedSentence,
                instructions=[
                    "You are a professional text processing assistant responsible for correcting errors in sentences and evaluating the importance of knowledge.",
                    "",
                    "### Sentence Correction Rules:",
                    "1. Remove Wikipedia-specific markers: such as (disambiguation), (film), (band), etc. in parentheses",
                    "2. Ensure grammatical completeness: complete subject+predicate+object structure, avoid dangling 'and is', 'or', etc.",
                    "3. Fix obvious grammatical errors: tense consistency, singular/plural consistency, correct preposition usage",
                    "4. Clean up garbled text and special characters: such as â, €, ™ and other encoding issues",
                    "5. Ensure semantic fluency: if the original sentence cannot be fixed, reorganize the language to make it coherent",
                    "6. Do not add information not present in the original text, only correct errors",
                    "",
                    "### Correction Examples:",
                    "- Error: 'Argument (disambiguation) is related to philosophy, logic, and is an.'",
                    "- Corrected: 'Argument is related to philosophy and logic.'",
                    "",
                    "- Error: 'Beijing is a capital city and are.'",
                    "- Corrected: 'Beijing is a capital city.'",
                    "",
                    "Importance scoring criteria (0.0-10.0, in increments of 0.1):",
                    "",
                    "0.0 points - Completely incorrect or meaningless information",
                    "Examples: 'Apple is a metal', 'The sun rises from the west', '1+1=3'",
                    "",
                    "0.5 points - Almost worthless information",
                    "Examples: 'Color of a fictional character's socks', 'Third line of dialogue from a game NPC', 'What someone had for breakfast yesterday'",
                    "",
                    "1.0 points - Extremely rare, non-practical knowledge",
                    "Examples: 'Pet name of a minor novel character', 'Content of the 15th line in movie end credits', 'Nickname of website user ID 123456'",
                    "",
                    "1.5 points - Very niche detailed information",
                    "Examples: 'Outfit of a passerby at minute 37 in a movie', 'Duration of background music in a game's hidden level', 'Content of the 3rd dialogue box on page 200 of a manga'",
                    "",
                    "2.0 points - Details in niche professional fields",
                    "Examples: 'Color change of rare minerals at specific temperatures', 'Length of an insect's third antenna', 'Molecular formula of chemical reaction byproducts'",
                    "",
                    "2.5 points - Technical details only professionals care about",
                    "Examples: 'Release date of specific software library version', 'Time complexity coefficient of an algorithm', 'Thermal expansion coefficient of a material'",
                    "",
                    "3.0 points - Professional knowledge in specific fields",
                    "Examples: 'Programming language syntax features', 'Gene sequence of a virus', 'Official system of ancient dynasties'",
                    "",
                    "3.5 points - Professional information with some value",
                    "Examples: 'Specific system of historical dynasty', 'Mechanism of action of a drug', 'Development time of a technical standard'",
                    "",
                    "4.0 points - Meaningful knowledge known by few",
                    "Examples: 'Unique cultural traditions of a country', 'Important discoveries by a scientist', 'Detailed process of historical events'",
                    "",
                    "4.5 points - Knowledge of interest to some groups",
                    "Examples: 'Author's creative background', 'Characteristics of an art movement', 'Detailed rules of a sport'",
                    "",
                    "5.0 points - General knowledge of moderate importance",
                    "Examples: 'Famous attractions in cities', 'Development history of a company', 'Living habits of animals'",
                    "",
                    "5.5 points - Fairly useful common sense",
                    "Examples: 'Plant growth environment', 'Healthy eating common sense', 'Basic first aid knowledge'",
                    "",
                    "6.0 points - Knowledge most educated people should know",
                    "Examples: 'Shakespeare's representative works', 'Basic geometric theorems', 'Major world currencies'",
                    "",
                    "6.5 points - Important cultural or scientific common sense",
                    "Examples: 'Basic structure of DNA', 'Newton's three laws', 'Major world religions'",
                    "",
                    "7.0 points - Important foundational knowledge",
                    "Examples: 'Time period of World War II', 'Functions of major human organs', 'Basic mathematical operation rules'",
                    "",
                    "7.5 points - Very important common sense",
                    "Examples: 'Light speed is the fastest in the universe', 'Earth is round', 'Basic principles of blood circulation'",
                    "",
                    "8.0 points - Core knowledge in basic education",
                    "Examples: 'Earth orbits the sun', 'Principle of seasonal formation', 'Basic grammar rules'",
                    "",
                    "8.5 points - Important knowledge everyone should master",
                    "Examples: 'Chemical formula of water H2O', 'Basic safety common sense', 'Simple mathematical calculations'",
                    "",
                    "9.0 points - Extremely important basic concepts",
                    "Examples: 'Humans need oxygen to survive', 'Fire is hot', 'Basic directional concepts'",
                    "",
                    "9.5 points - Core knowledge everyone must know",
                    "Examples: 'A day has 24 hours', 'A year has 12 months', 'Basic number concepts'",
                    "",
                    "10.0 points - Most basic and important common sense",
                    "Examples: 'Humans need food and water to survive', 'The sky is blue', 'Stones are heavier than feathers'",
                    "",
                    "When scoring, please consider:",
                    "1. Popularity of knowledge - How many people know this knowledge",
                    "2. Practical value - How useful this knowledge is in daily life",
                    "3. Educational importance - The position of this knowledge in the education system",
                    "4. Cultural significance - The importance of this knowledge for understanding the world",
                    "",
                    "Please output structured results directly without showing the thinking process."
                ],
                markdown=False
            )

            logger.info("LLM processor initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize the LLM processor: {e}")
            print(f"Failed to initialize the LLM processor: {e}")
            print("Falling back to basic mode (no LLM post-processing)")
            self.enable_llm_processing = False

    async def process_sentence_with_llm(self, sentence: str) -> ProcessedSentence:
        """Process a single sentence with the LLM (kept for standalone calls)."""
        for attempt in range(self.max_retries):
            try:
                prompt = f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"

                # Add a timeout via asyncio.wait_for
                response = await asyncio.wait_for(
                    self.agent.arun(prompt),
                    timeout=self.llm_timeout
                )

                # According to the agno docs, response should already be a ProcessedSentence
                if isinstance(response, ProcessedSentence):
                    return response
                else:
                    message = response.messages[-1].content
                    message = message.replace("```json", "").replace("```", "")
                    message = json.loads(message)
                    return ProcessedSentence(
                        corrected_sentence=message['corrected_sentence'],
                        importance_score=message['importance_score']
                    )

            except asyncio.TimeoutError:
                logger.warning(f"LLM request timed out (attempt {attempt + 1}/{self.max_retries}): {sentence[:50]}...")
                if attempt == self.max_retries - 1:
                    logger.error(f"LLM request timed out on the final attempt, using default handling: {sentence[:50]}...")
                    break
                # Wait before retrying (exponential backoff)
                await asyncio.sleep(2 ** attempt)

            except Exception as e:
                logger.error(f"Error while processing sentence with LLM (attempt {attempt + 1}/{self.max_retries}): {e}")
                if attempt == self.max_retries - 1:
                    break
                await asyncio.sleep(1)

        # All retries failed: return the original sentence with a medium score
        logger.warning(f"Using default handling for: {sentence[:50]}...")
        return ProcessedSentence(
            corrected_sentence=sentence,
            importance_score=5.0
        )

    def clean_text(self, text: str) -> str:
        """Clean text and normalize special characters."""
        if not text:
            return ""

        # Normalize common Unicode punctuation
        text = text.replace("\u2013", "-")   # en dash
        text = text.replace("\u2014", "-")   # em dash
        text = text.replace("\u2019", "'")   # right single quotation mark
        text = text.replace("\u2018", "'")   # left single quotation mark
        text = text.replace("\u201c", '"')   # left double quotation mark
        text = text.replace("\u201d", '"')   # right double quotation mark

        # Handle possible escape-sequence issues
        try:
            text = text.encode('utf-8').decode('utf-8')
        except (UnicodeEncodeError, UnicodeDecodeError):
            pass

        # Collapse extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Strip surrounding quotes
        text = text.strip('"\'')

        return text

    def parse_large_json_file(self, file_path: str) -> List[Dict]:
        """Parse a large JSON file, handling possible formatting problems."""
        documents = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read().strip()

            # Try different parsing strategies
            if content.startswith('[') and content.endswith(']'):
                # Standard JSON array
                documents = json.loads(content)
            else:
                # Possibly a stream of concatenated JSON objects
                if '}{"' in content:
                    # Split between adjacent objects
                    json_strings = content.split('}{')
                    json_strings[0] += '}'                     # first object
                    json_strings[-1] = '{' + json_strings[-1]  # last object

                    for i in range(1, len(json_strings) - 1):
                        json_strings[i] = '{' + json_strings[i] + '}'

                    for json_str in json_strings:
                        try:
                            doc = json.loads(json_str)
                            documents.append(doc)
                        except json.JSONDecodeError:
                            continue
                else:
                    # Try to parse the content as a single JSON object
                    try:
                        documents = [json.loads(content)]
                    except json.JSONDecodeError:
                        pass

        except Exception as e:
            print(f"Error parsing {file_path}: {e}")

        return documents
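
    # Illustrative input handled by parse_large_json_file (made-up minimal data):
    # a file containing '{"title": "A"}{"title": "B"}' is not valid JSON as a whole,
    # but the '}{' split above recovers the two objects {"title": "A"} and {"title": "B"}.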

    def extract_sentences_from_document(self, doc: Dict[str, Any]) -> List[str]:
        """Extract sentences from a single document."""
        sentences = []

        title = self.clean_text(doc.get('title', ''))
        text = self.clean_text(doc.get('text', ''))
        entities = doc.get('entities', [])
        triples = doc.get('triples', [])

        # Convert the explicit triples
        for triple in triples:
            sentence = self.triple_to_sentence(triple)
            if sentence:
                sentences.append(sentence)

        # Generate basic sentences from the title and entities
        # (only if the triples did not yield enough sentences)
        if title and text and len(sentences) < 5:
            entity_names = []
            for entity in entities[:15]:
                entity_name = self.clean_text(entity.get('surfaceform', ''))
                if entity_name and len(entity_name) > 2:
                    entity_names.append(entity_name)

            # Build a simple descriptive sentence
            if entity_names:
                important_entities = []
                title_lower = title.lower()

                for entity in entity_names:
                    if (entity.lower() != title_lower and
                            entity not in important_entities and
                            not any(t in entity.lower() for t in title_lower.split()[:2])):
                        important_entities.append(entity)
                        if len(important_entities) >= 6:
                            break

                if important_entities and len(sentences) < 3:
                    entities_str = ', '.join(important_entities[:3])
                    sentences.append(f"{title} is related to {entities_str}.")

        return sentences

    def triple_to_sentence(self, triple: Dict[str, Any]) -> str:
        """Convert a triple into a natural-language sentence."""
        try:
            subject = triple.get('subject', {})
            predicate = triple.get('predicate', {})
            obj = triple.get('object', {})

            subject_name = self.clean_text(subject.get('surfaceform', ''))
            object_name = self.clean_text(obj.get('surfaceform', ''))
            predicate_uri = predicate.get('uri', '')

            # Require a valid subject and object
            if not subject_name or not object_name:
                return ""

            # Skip subjects/objects that are too short to be meaningful
            if len(subject_name) <= 2 or len(object_name) <= 2:
                return ""

            # Map the predicate URI to a relation phrase
            relation_text = self.property_mappings.get(predicate_uri, "is related to")

            # Avoid sentences whose subject and object are identical
            if subject_name.lower() == object_name.lower():
                return ""

            return f"{subject_name} {relation_text} {object_name}."

        except Exception as e:
            print(f"Error converting triple to sentence: {e}")
            return ""
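
    # Illustrative example (made-up surface forms): a triple whose subject surfaceform is
    # "Marie Curie", predicate URI is http://www.wikidata.org/prop/direct/P19 and object
    # surfaceform is "Warsaw" is rendered as: "Marie Curie was born in Warsaw."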

    async def process_sentence_with_llm_concurrent(self, semaphore: asyncio.Semaphore, sentence: str, index: int,
                                                   total_sentences: int, start_time: float) -> Dict[str, Any]:
        """Process a sentence with the LLM, using a semaphore to limit concurrency."""
        async with semaphore:
            self.total_requests += 1
            self.last_activity_time = time.time()  # update the activity timestamp

            success = False

            for attempt in range(self.max_retries):
                try:
                    prompt = f"Please correct the errors in the following sentence and evaluate its importance: {sentence}"

                    # Add a timeout via asyncio.wait_for
                    response = await asyncio.wait_for(
                        self.agent.arun(prompt),
                        timeout=self.llm_timeout
                    )

                    # According to the agno docs, response should already be a ProcessedSentence
                    if isinstance(response, ProcessedSentence):
                        result = {
                            "index": index,
                            "original_sentence": sentence,
                            "corrected_sentence": response.corrected_sentence,
                            "importance_score": response.importance_score
                        }
                    else:
                        message = response.messages[-1].content
                        message = message.replace("```json", "").replace("```", "")
                        message = json.loads(message)
                        result = {
                            "index": index,
                            "original_sentence": sentence,
                            "corrected_sentence": message['corrected_sentence'],
                            "importance_score": message['importance_score']
                        }

                    # Success
                    self.successful_requests += 1
                    self.last_successful_time = time.time()
                    self.last_activity_time = time.time()
                    success = True

                    # Print detailed progress every 50 sentences
                    if index % 50 == 0:
                        current_time = time.time()
                        elapsed_time = current_time - start_time
                        avg_time_per_sentence = elapsed_time / (index + 1) if index > 0 else elapsed_time
                        remaining_sentences = total_sentences - (index + 1)
                        estimated_remaining_time = avg_time_per_sentence * remaining_sentences
                        success_rate = (self.successful_requests / self.total_requests * 100) if self.total_requests > 0 else 0

                        # Human-readable time formatting
                        def format_time(seconds):
                            if seconds < 60:
                                return f"{seconds:.1f}s"
                            elif seconds < 3600:
                                minutes = seconds / 60
                                return f"{minutes:.1f}min"
                            else:
                                hours = seconds / 3600
                                return f"{hours:.1f}h"

                        progress_lines = [
                            f"Finished sentence {index + 1}",
                            f"  - Remaining sentences: {remaining_sentences}",
                            f"  - Average time per sentence: {avg_time_per_sentence:.2f}s",
                            f"  - Estimated time remaining: {format_time(estimated_remaining_time)}",
                            f"  - Elapsed time: {format_time(elapsed_time)}",
                            f"  - Success rate: {success_rate:.1f}% ({self.successful_requests}/{self.total_requests})"
                        ]
                        for line in progress_lines:
                            logger.info(line)
                            print(line)

                    return result

                except asyncio.TimeoutError:
                    self.timeout_requests += 1
                    self.last_activity_time = time.time()
                    logger.warning(f"Sentence {index} timed out (attempt {attempt + 1}/{self.max_retries}): {sentence[:50]}...")
                    if attempt == self.max_retries - 1:
                        logger.error(f"Sentence {index} timed out on the final attempt, using default handling")
                        break
                    # Exponential backoff
                    await asyncio.sleep(2 ** attempt)

                except Exception as e:
                    self.last_activity_time = time.time()
                    logger.error(f"Error processing sentence {index} (attempt {attempt + 1}/{self.max_retries}): {e}")
                    if attempt == self.max_retries - 1:
                        break
                    await asyncio.sleep(1)

            # All retries failed: fall back to the default result
            if not success:
                self.failed_requests += 1
                logger.warning(f"Using default handling for sentence {index}: {sentence[:50]}...")

            return {
                "index": index,
                "original_sentence": sentence,
                "corrected_sentence": sentence,
                "importance_score": 5.0
            }

    async def heartbeat_monitor(self, total_sentences: int):
        """Heartbeat monitor that detects long periods without successful responses."""
        consecutive_warnings = 0

        while True:
            await asyncio.sleep(self.heartbeat_interval)

            current_time = time.time()
            time_since_last_success = current_time - self.last_successful_time
            time_since_last_activity = current_time - self.last_activity_time

            # Check the time since the last successful response
            if time_since_last_success > self.heartbeat_interval:
                consecutive_warnings += 1
                logger.warning(f"⚠️ Heartbeat check #{consecutive_warnings}: no successful LLM response for {time_since_last_success:.1f}s")
                print(f"⚠️ Heartbeat check #{consecutive_warnings}: no successful LLM response for {time_since_last_success:.1f}s")

                # Print the current statistics
                if self.total_requests > 0:
                    success_rate = self.successful_requests / self.total_requests * 100
                    logger.warning(f"Current stats: total {self.total_requests}, successful {self.successful_requests} ({success_rate:.1f}%), timeouts {self.timeout_requests}")
                    print(f"Current stats: total {self.total_requests}, successful {self.successful_requests} ({success_rate:.1f}%), timeouts {self.timeout_requests}")

                if time_since_last_success > self.heartbeat_interval * 3:
                    logger.error(f"❌ Severe warning: the LLM may be stuck, no successful response for {time_since_last_success:.1f}s!")
                    print(f"❌ Severe warning: the LLM may be stuck, no successful response for {time_since_last_success:.1f}s!")
                    print("Suggestion: check the ollama service status, or consider restarting the program")

                    # Check the ollama service status
                    if not self.check_ollama_status():
                        logger.critical("💀 Ollama service problem detected, this may be why processing is stuck!")
                        print("💀 Ollama service problem detected, this may be why processing is stuck!")

                if consecutive_warnings >= 5:
                    logger.critical(f"💀 Fatal: {consecutive_warnings} consecutive heartbeat warnings, manual intervention may be required")
                    print(f"💀 Fatal: {consecutive_warnings} consecutive heartbeat warnings, manual intervention may be required")
            else:
                if consecutive_warnings > 0:
                    logger.info(f"✅ Heartbeat back to normal: last success {time_since_last_success:.1f}s ago")
                    print(f"✅ Heartbeat back to normal: last success {time_since_last_success:.1f}s ago")
                    consecutive_warnings = 0
                logger.debug(f"💓 Heartbeat OK: last success {time_since_last_success:.1f}s ago")

    async def process_sentences_with_llm(self, sentences: List[str]) -> List[Dict[str, Any]]:
        """Process sentences concurrently in batches, saving a checkpoint after each batch."""
        logger.info(f"Starting concurrent LLM processing of {len(sentences)} sentences (max concurrency: {self.max_concurrent})...")
        print(f"Starting concurrent LLM processing of {len(sentences)} sentences (max concurrency: {self.max_concurrent})...")

        # Record the start time
        start_time = time.time()
        total_sentences = len(sentences)

        # Process in batches of 1000 sentences
        batch_size = 1000
        all_processed_sentences = []

        # Start the heartbeat monitor
        heartbeat_task = asyncio.create_task(self.heartbeat_monitor(total_sentences))

        try:
            for batch_start in range(0, total_sentences, batch_size):
                batch_end = min(batch_start + batch_size, total_sentences)
                batch_sentences = sentences[batch_start:batch_end]

                logger.info(f"=== Processing batch {batch_start // batch_size + 1} ({batch_start + 1}-{batch_end}/{total_sentences}) ===")
                print(f"\n=== Processing batch {batch_start // batch_size + 1} ({batch_start + 1}-{batch_end}/{total_sentences}) ===")

                # Semaphore limiting the number of concurrent requests
                semaphore = asyncio.Semaphore(self.max_concurrent)

                # Reset per-batch statistics
                batch_start_time = time.time()
                self.total_requests = 0
                self.successful_requests = 0
                self.failed_requests = 0
                self.timeout_requests = 0

                # Create the tasks for this batch
                tasks = []
                for i, sentence in enumerate(batch_sentences):
                    global_index = batch_start + i
                    task = self.process_sentence_with_llm_concurrent(semaphore, sentence, global_index, total_sentences, start_time)
                    tasks.append(task)

                # Run the current batch concurrently
                logger.info(f"Concurrently processing the {len(batch_sentences)} sentences of batch {batch_start // batch_size + 1}...")
                print(f"Concurrently processing the {len(batch_sentences)} sentences of batch {batch_start // batch_size + 1}...")

                batch_results = await asyncio.gather(*tasks, return_exceptions=True)

                # Collect the batch results, filtering out exceptions
                batch_processed_sentences = []
                batch_error_count = 0

                for result in batch_results:
                    if isinstance(result, Exception):
                        logger.error(f"Task raised an exception: {result}")
                        print(f"Task raised an exception: {result}")
                        batch_error_count += 1
                    elif isinstance(result, dict):
                        batch_processed_sentences.append(result)
                    else:
                        batch_error_count += 1

                # Restore the original order (concurrency may have shuffled it)
                batch_processed_sentences.sort(key=lambda x: x['index'])

                # Drop the index field
                for item in batch_processed_sentences:
                    del item['index']

                # Append to the overall results
                all_processed_sentences.extend(batch_processed_sentences)

                # Save a checkpoint
                checkpoint_filename = self.save_checkpoint(all_processed_sentences, batch_end)

                # Print statistics for this batch
                elapsed_time = time.time() - start_time
                batch_time = time.time() - batch_start_time
                completed_sentences = len(all_processed_sentences)

                batch_stats = [
                    f"Batch {batch_start // batch_size + 1} finished!",
                    f"  - This batch: {len(batch_processed_sentences)} succeeded, {batch_error_count} failed",
                    f"  - Batch time: {batch_time / 60:.1f} minutes",
                    f"  - LLM stats: {self.successful_requests} succeeded, {self.failed_requests} failed, {self.timeout_requests} timed out",
                    f"  - Overall progress: {completed_sentences}/{total_sentences} ({completed_sentences / total_sentences * 100:.1f}%)",
                    f"  - Elapsed time: {elapsed_time / 60:.1f} minutes",
                    f"  - Average speed: {completed_sentences / elapsed_time:.2f} sentences/second",
                    f"  - Checkpoint saved: {checkpoint_filename}"
                ]
                for line in batch_stats:
                    logger.info(line)
                    print(line)

                if batch_end < total_sentences:
                    remaining_sentences = total_sentences - completed_sentences
                    avg_time_per_sentence = elapsed_time / completed_sentences
                    estimated_remaining_time = avg_time_per_sentence * remaining_sentences
                    logger.info(f"  - Estimated time remaining: {estimated_remaining_time / 60:.1f} minutes")
                    print(f"  - Estimated time remaining: {estimated_remaining_time / 60:.1f} minutes")

                # Rest briefly between batches to avoid overloading the service
                if batch_end < total_sentences:
                    logger.info("Resting 5 seconds between batches...")
                    await asyncio.sleep(5)

        finally:
            # Cancel the heartbeat monitor
            heartbeat_task.cancel()
            try:
                await heartbeat_task
            except asyncio.CancelledError:
                pass

        # Print the final statistics
        total_time = time.time() - start_time

        summary = [
            "=== All processing finished! ===",
            f"  - Total succeeded: {len(all_processed_sentences)}",
            f"  - Total time: {total_time / 60:.1f} minutes",
            f"  - Average speed: {len(all_processed_sentences) / total_time:.2f} sentences/second"
        ]
        print()
        for line in summary:
            logger.info(line)
            print(line)

        return all_processed_sentences

    def save_checkpoint(self, processed_sentences: List[Dict[str, Any]], current_count: int) -> str:
        """Save a checkpoint file."""
        # Build the checkpoint file name inside the output directory
        base_name = os.path.splitext(os.path.basename(self.output_file))[0]
        checkpoint_filename = os.path.join('output', f"{base_name}_checkpoint_{current_count}.json")

        # Write the checkpoint
        with open(checkpoint_filename, 'w', encoding='utf-8') as f:
            json.dump({
                "metadata": {
                    "total_processed": len(processed_sentences),
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                    "checkpoint_number": current_count
                },
                "sentences": processed_sentences
            }, f, ensure_ascii=False, indent=2)

        return checkpoint_filename
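
    # For example (assuming the default output_file "trex_sentences_enhanced.txt"), the
    # checkpoint written after the first 1000-sentence batch would be
    # "output/trex_sentences_enhanced_checkpoint_1000.json", containing the "metadata"
    # block above and the "sentences" accumulated so far.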

    async def process_files(self) -> List[Dict[str, Any]]:
        """Process all input files."""
        json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))

        if not json_files:
            print(f"No JSON files found in {self.input_dir}")
            return []

        # Sort files for a consistent processing order
        json_files.sort()

        if self.max_files:
            json_files = json_files[:self.max_files]

        print(f"Found {len(json_files)} JSON files to process")

        all_sentences = []

        for i, file_path in enumerate(json_files):
            print(f"Processing file {i + 1}/{len(json_files)}: {os.path.basename(file_path)}")

            documents = self.parse_large_json_file(file_path)
            print(f"  Parsed {len(documents)} documents")

            for doc in documents:
                sentences = self.extract_sentences_from_document(doc)
                all_sentences.extend(sentences)

            print(f"  Generated {len(all_sentences)} total raw sentences so far")

        print(f"Extracted {len(all_sentences)} raw sentences in total")

        # Deduplicate
        unique_sentences = []
        seen = set()
        for sentence in all_sentences:
            sentence = sentence.strip()
            if sentence and sentence not in seen and len(sentence) > 10:
                unique_sentences.append(sentence)
                seen.add(sentence)

        print(f"{len(unique_sentences)} sentences remain after deduplication")

        # Save the raw sentences to a JSON file
        sentences_data = {
            "metadata": {
                "total_sentences": len(unique_sentences),
                "extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "source_files": len(json_files),
                "max_files_limit": self.max_files
            },
            "sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
        }

        with open(self.sentences_json, 'w', encoding='utf-8') as f:
            json.dump(sentences_data, f, ensure_ascii=False, indent=2)

        print(f"Sentence extraction finished! Saved to: {self.sentences_json}")
        print(f"Total sentences: {len(unique_sentences)}")

        return unique_sentences

    def save_sentences(self, processed_sentences: List[Dict[str, Any]]):
        """Save the processed sentences to the output files."""
        # Make sure the output directory exists
        os.makedirs('output', exist_ok=True)

        # JSON output with the full information
        json_output_file = self.output_file.replace('.txt', '.json')
        with open(json_output_file, 'w', encoding='utf-8') as f:
            json.dump(processed_sentences, f, ensure_ascii=False, indent=2)

        # Plain-text output (corrected sentences only)
        with open(self.output_file, 'w', encoding='utf-8') as f:
            for item in processed_sentences:
                f.write(item['corrected_sentence'] + '\n')

        # Output sorted by importance score
        importance_sorted = sorted(processed_sentences, key=lambda x: x['importance_score'], reverse=True)
        importance_file = self.output_file.replace('.txt', '_sorted_by_importance.txt')
        with open(importance_file, 'w', encoding='utf-8') as f:
            for item in importance_sorted:
                f.write(f"[{item['importance_score']:.1f}] {item['corrected_sentence']}\n")

        print(f"Saved {len(processed_sentences)} processed sentences:")
        print(f"  - JSON format: {json_output_file}")
        print(f"  - Text format: {self.output_file}")
        print(f"  - Sorted by importance: {importance_file}")

        # Statistics
        scores = [item['importance_score'] for item in processed_sentences]
        avg_score = sum(scores) / len(scores) if scores else 0
        print(f"  - Average importance score: {avg_score:.2f}")
        if scores:
            print(f"  - Highest score: {max(scores):.1f}")
            print(f"  - Lowest score: {min(scores):.1f}")

    async def run(self):
        """Run the extraction pipeline (legacy entry point; main() drives the two-step flow)."""
        print("Starting enhanced TREx to sentences conversion...")
        await self.process_files()
        print("Enhanced conversion completed!")

    def find_latest_checkpoint(self) -> Union[tuple, None]:
        """Find the most recent checkpoint file."""
        base_name = os.path.splitext(os.path.basename(self.output_file))[0]
        pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
        checkpoint_files = glob.glob(pattern)

        if not checkpoint_files:
            return None

        # Pick the checkpoint with the highest number
        latest_file = None
        latest_count = 0

        for file in checkpoint_files:
            try:
                # Extract the number from the file name
                match = re.search(r'checkpoint_(\d+)\.json$', file)
                if match:
                    count = int(match.group(1))
                    if count > latest_count:
                        latest_count = count
                        latest_file = file
            except Exception:
                continue

        if latest_file:
            return latest_file, latest_count
        else:
            return None

    def load_checkpoint(self, checkpoint_file: str) -> List[Dict[str, Any]]:
        """Load already-processed sentences from a checkpoint file."""
        try:
            with open(checkpoint_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if 'sentences' in data:
                return data['sentences']
            else:
                # Old-format checkpoint file
                return data
        except Exception as e:
            print(f"Failed to load checkpoint file: {e}")
            return []

    def get_processed_sentences_from_checkpoints(self) -> Set[str]:
        """Collect the set of sentences already processed according to the checkpoints."""
        if not self.output_file:
            return set()

        processed_sentences = set()

        # Find all checkpoint files
        base_name = os.path.splitext(os.path.basename(self.output_file))[0]
        pattern = os.path.join('output', f"{base_name}_checkpoint_*.json")
        checkpoint_files = glob.glob(pattern)

        if not checkpoint_files:
            print("No checkpoint files found, starting from scratch")
            return set()

        # Find the most recent checkpoint file
        latest_file = None
        latest_count = 0

        for file in checkpoint_files:
            try:
                match = re.search(r'checkpoint_(\d+)\.json$', file)
                if match:
                    count = int(match.group(1))
                    if count > latest_count:
                        latest_count = count
                        latest_file = file
            except Exception:
                continue

        if latest_file:
            print(f"Found latest checkpoint: {latest_file} ({latest_count} records)")
            logger.info(f"Found latest checkpoint: {latest_file} ({latest_count} records)")

            try:
                with open(latest_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                    sentences_data = data.get('sentences', [])

                    for item in sentences_data:
                        original_sentence = item.get('original_sentence', '')
                        if original_sentence:
                            processed_sentences.add(original_sentence)

                print(f"Loaded {len(processed_sentences)} already-processed sentences from the checkpoint")
                logger.info(f"Loaded {len(processed_sentences)} already-processed sentences from the checkpoint")

            except Exception as e:
                print(f"Failed to read checkpoint file: {e}")
                return set()

        return processed_sentences

    async def process_with_llm(self):
        """Step 2: read sentences from the JSON file and run LLM processing."""
        if not self.enable_llm_processing:
            print("Error: LLM processing is disabled!")
            return

        if not self.output_file:
            print("Error: output_file is required for LLM processing!")
            return

        print("=== Step 2: LLM processing ===")

        # Read the sentences JSON file
        if not os.path.exists(self.sentences_json):
            print(f"Error: Sentences file {self.sentences_json} not found!")
            print("Please run step 1 (sentence extraction) first")
            return

        print(f"Reading sentences file: {self.sentences_json}")

        try:
            with open(self.sentences_json, 'r', encoding='utf-8') as f:
                data = json.load(f)

            all_sentences = [item["sentence"] for item in data.get("sentences", [])]
            print(f"Read {len(all_sentences)} sentences from the file")

        except Exception as e:
            print(f"Failed to read sentences file: {e}")
            return

        # Collect the sentences that have already been processed
        processed_sentences_set = self.get_processed_sentences_from_checkpoints()

        # Keep only the sentences that still need processing
        unprocessed_sentences = []
        for sentence in all_sentences:
            if sentence not in processed_sentences_set:
                unprocessed_sentences.append(sentence)

        print(f"Sentences to process: {len(unprocessed_sentences)} (skipping {len(processed_sentences_set)} already processed)")
        logger.info(f"Sentences to process: {len(unprocessed_sentences)} (skipping {len(processed_sentences_set)} already processed)")

        if not unprocessed_sentences:
            print("All sentences have already been processed!")

            # If there is a checkpoint, generate the final files directly from it
            if processed_sentences_set:
                latest_checkpoint = self.find_latest_checkpoint()
                if latest_checkpoint:
                    checkpoint_file, _ = latest_checkpoint
                    processed_data = self.load_checkpoint(checkpoint_file)
                    self.save_sentences(processed_data)
                    print("Final output files generated from the checkpoint")
            return

        # Process the remaining sentences
        print("Starting LLM processing...")

        # Check the ollama service status first
        logger.info("Checking Ollama service status...")
        if not self.check_ollama_status():
            logger.error("Ollama service is not healthy, cannot continue")
            print("Error: Ollama service is not healthy, please check that it is running")
            return

        new_processed_sentences = await self.process_sentences_with_llm(unprocessed_sentences)

        # Merge with previous results if there are any
        if processed_sentences_set:
            latest_checkpoint = self.find_latest_checkpoint()
            if latest_checkpoint:
                checkpoint_file, _ = latest_checkpoint
                previous_processed = self.load_checkpoint(checkpoint_file)

                # Merge the results
                all_processed_sentences = previous_processed + new_processed_sentences
                print(f"Merged {len(previous_processed)} previous records with {len(new_processed_sentences)} newly processed records")
            else:
                all_processed_sentences = new_processed_sentences
        else:
            all_processed_sentences = new_processed_sentences

        # Save the final results
        self.save_sentences(all_processed_sentences)

        print("LLM processing finished!")

    # ==================== Sentence extraction ====================

    def extract_sentences(self):
        """Step 1: extract sentences from the TREx dataset and save them as JSON."""
        if not self.input_dir:
            print("Error: input_dir is required for sentence extraction!")
            return

        print("=== Step 1: Sentence extraction ===")
        print("Extracting sentences from the TREx dataset...")

        json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))

        if not json_files:
            print(f"No JSON files found in {self.input_dir}")
            return

        # Sort files for a consistent processing order
        json_files.sort()

        if self.max_files:
            json_files = json_files[:self.max_files]

        print(f"Found {len(json_files)} JSON files to process")

        all_sentences = []

        for i, file_path in enumerate(json_files):
            print(f"Processing file {i + 1}/{len(json_files)}: {os.path.basename(file_path)}")

            documents = self.parse_large_json_file(file_path)
            print(f"  Parsed {len(documents)} documents")

            for doc in documents:
                sentences = self.extract_sentences_from_document(doc)
                all_sentences.extend(sentences)

            print(f"  Generated {len(all_sentences)} total raw sentences so far")

        print(f"Extracted {len(all_sentences)} raw sentences in total")

        # Deduplicate
        unique_sentences = []
        seen = set()
        for sentence in all_sentences:
            sentence = sentence.strip()
            if sentence and sentence not in seen and len(sentence) > 10:
                unique_sentences.append(sentence)
                seen.add(sentence)

        print(f"{len(unique_sentences)} sentences remain after deduplication")

        # Save the raw sentences to a JSON file
        sentences_data = {
            "metadata": {
                "total_sentences": len(unique_sentences),
                "extraction_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "source_files": len(json_files),
                "max_files_limit": self.max_files
            },
            "sentences": [{"sentence": sentence, "processed": False} for sentence in unique_sentences]
        }

        with open(self.sentences_json, 'w', encoding='utf-8') as f:
            json.dump(sentences_data, f, ensure_ascii=False, indent=2)

        print(f"Sentence extraction finished! Saved to: {self.sentences_json}")
        print(f"Total sentences: {len(unique_sentences)}")

        return unique_sentences

    def check_ollama_status(self) -> bool:
        """Check whether the ollama service is up and responding."""
        try:
            # Check whether the ollama process is running
            result = subprocess.run(['pgrep', 'ollama'], capture_output=True, text=True)
            if result.returncode != 0:
                logger.error("Ollama process is not running")
                return False

            # Check whether the ollama API responds
            response = requests.get('http://localhost:11434/api/tags', timeout=5)
            if response.status_code == 200:
                logger.info("Ollama service is healthy")
                return True
            else:
                logger.error(f"Ollama API returned an unexpected status code: {response.status_code}")
                return False

        except requests.exceptions.RequestException as e:
            logger.error(f"Cannot reach the Ollama API: {e}")
            return False
        except Exception as e:
            logger.error(f"Error while checking Ollama status: {e}")
            return False


def main():
    """Entry point."""
    import argparse

    parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with LLM processing')

    # Run mode
    parser.add_argument('--step', choices=['extract', 'llm', 'all'], default='llm',
                        help='Step to run: extract=sentence extraction only, llm=LLM processing only, all=full pipeline')

    # File path arguments
    parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files')
    parser.add_argument('--sentences_json', default='extracted_sentences.json', help='JSON file for extracted sentences (will be saved in output/)')
    parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path (will be saved in output/)')

    # Processing arguments
    parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)')
    parser.add_argument('--no_llm', action='store_true', help='Disable LLM processing (basic mode)')

    args = parser.parse_args()

    # Validate the arguments for the selected step
    if args.step in ['extract', 'all']:
        if not os.path.exists(args.input_dir):
            print(f"Error: Input directory {args.input_dir} does not exist!")
            return

    if args.step in ['llm', 'all']:
        if args.no_llm:
            print("Error: Cannot run LLM step with --no_llm flag!")
            return

    # Create the processor
    processor = EnhancedTRExProcessor(
        input_dir=args.input_dir,
        sentences_json=args.sentences_json,
        output_file=args.output_file,
        max_files=args.max_files,
        enable_llm_processing=not args.no_llm
    )

    # Run the selected step(s)
    if args.step == 'extract':
        print("=== Mode: sentence extraction only ===")
        processor.extract_sentences()

    elif args.step == 'llm':
        print("=== Mode: LLM processing only ===")
        asyncio.run(processor.process_with_llm())

    elif args.step == 'all':
        print("=== Mode: full pipeline ===")

        # Step 1: extract sentences
        print("\n--- Starting step 1: sentence extraction ---")
        sentences = processor.extract_sentences()

        if not sentences:
            print("Sentence extraction failed, exiting")
            return

        if args.no_llm:
            print("LLM processing is disabled, stopping here")
            return

        # Step 2: LLM processing
        print("\n--- Starting step 2: LLM processing ---")
        asyncio.run(processor.process_with_llm())


if __name__ == "__main__":
    main()