Initial preprocessing for building a knowledge base from the TREx dataset

parent 45da3b383b
commit c09cd63794

preprocessing/trex_to_sentences_simple.py (new file, 734 lines)

@@ -0,0 +1,734 @@
#!/usr/bin/env python3
"""
Enhanced preprocessing script for the TREx dataset.
Uses the agno framework with ollama qwen3:4b for sentence post-processing and importance scoring.
"""

import json
import os
import glob
from typing import List, Dict, Any, Union
import re
import asyncio
import time
from pydantic import BaseModel, Field
from agno.agent import Agent
from agno.models.ollama import Ollama


class ProcessedSentence(BaseModel):
    """Structure of a processed sentence."""
    corrected_sentence: str = Field(
        ...,
        description="The corrected sentence. Only fix grammatical errors, garbled characters, and awkward wording; do not add extra polish."
    )
    importance_score: float = Field(
        ...,
        description="Importance score in the range 0.0-10.0, in steps of 0.1, judging how common and important this knowledge is in the real world.",
        ge=0.0,
        le=10.0
    )

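# Illustrative sketch (not used by the pipeline): what a validated structured
# response looks like once the agent output is parsed into ProcessedSentence.
# The field values below are made up.
def _example_processed_sentence() -> ProcessedSentence:
    return ProcessedSentence(
        corrected_sentence="Beijing is the capital of China.",
        importance_score=8.5,
    )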


class EnhancedTRExProcessor:
    def __init__(self, input_dir: str, output_file: str, max_files: int = None, enable_llm_processing: bool = True):
        self.input_dir = input_dir
        self.output_file = output_file
        self.max_files = max_files
        self.enable_llm_processing = enable_llm_processing

        # Initialize the agno agent
        if self.enable_llm_processing:
            self.setup_agent()

        # Extended Wikidata property mappings
        self.property_mappings = {
            # Basic relations
            "http://www.wikidata.org/prop/direct/P31": "is a",
            "http://www.wikidata.org/prop/direct/P279": "is a type of",

            # People
            "http://www.wikidata.org/prop/direct/P106": "works as",
            "http://www.wikidata.org/prop/direct/P27": "is a citizen of",
            "http://www.wikidata.org/prop/direct/P19": "was born in",
            "http://www.wikidata.org/prop/direct/P20": "died in",
            "http://www.wikidata.org/prop/direct/P569": "was born on",
            "http://www.wikidata.org/prop/direct/P570": "died on",
            "http://www.wikidata.org/prop/direct/P22": "has father",
            "http://www.wikidata.org/prop/direct/P25": "has mother",
            "http://www.wikidata.org/prop/direct/P26": "is married to",

            # Organizations
            "http://www.wikidata.org/prop/direct/P102": "is a member of",
            "http://www.wikidata.org/prop/direct/P108": "works for",
            "http://www.wikidata.org/prop/direct/P159": "has headquarters in",
            "http://www.wikidata.org/prop/direct/P112": "was founded by",
            "http://www.wikidata.org/prop/direct/P571": "was founded in",
            "http://www.wikidata.org/prop/direct/P169": "has CEO",

            # Geography
            "http://www.wikidata.org/prop/direct/P17": "is located in",
            "http://www.wikidata.org/prop/direct/P131": "is located in",
            "http://www.wikidata.org/prop/direct/P36": "has capital",
            "http://www.wikidata.org/prop/direct/P47": "borders",

            # Other relations
            "http://www.wikidata.org/prop/direct/P1142": "has ideology",
            "http://www.wikidata.org/prop/direct/P361": "is part of",
            "http://www.wikidata.org/prop/direct/P737": "was influenced by",
            "http://www.wikidata.org/prop/direct/P127": "is owned by",
            "http://www.wikidata.org/prop/direct/P155": "follows",
            "http://www.wikidata.org/prop/direct/P156": "is followed by",
            "http://www.wikidata.org/prop/direct/P138": "is named after"
        }

    def setup_agent(self):
        """Set up the agno agent."""
        try:
            self.agent = Agent(
                model=Ollama(
                    id="qwen3:4b",
                    # Use options to set temperature and other sampling parameters
                    options={
                        "temperature": 0.7,
                        "top_p": 0.8,
                        "top_k": 20,
                        "num_ctx": 4096,
                    }
                ),
                response_model=ProcessedSentence,
                instructions=[
                    "You are a professional text-processing assistant responsible for correcting errors in sentences and assessing the importance of the knowledge they contain.",
                    "",
                    "### Sentence correction rules:",
                    "1. Remove Wikipedia-specific markup such as parenthetical qualifiers like (disambiguation), (film), (band)",
                    "2. Ensure the sentence is grammatically complete: subject + predicate + object, with no dangling fragments such as 'and is' or 'or'",
                    "3. Fix obvious grammatical errors: tense agreement, singular/plural agreement, correct preposition use",
                    "4. Clean up mojibake and stray characters such as â, €, ™",
                    "5. Make sure the sentence reads naturally; if the original cannot be repaired, rephrase it so that it does",
                    "6. Do not add information that is not in the original; only fix errors",
                    "",
                    "### Correction examples:",
                    "- Wrong: 'Argument (disambiguation) is related to philosophy, logic, and is an.'",
                    "- Fixed: 'Argument is related to philosophy and logic.'",
                    "",
                    "- Wrong: 'Beijing is a capital city and are.'",
                    "- Fixed: 'Beijing is a capital city.'",
                    "",
                    "Importance scoring scale (0.0-10.0, in steps of 0.1):",
                    "",
                    "0.0 - Completely wrong or meaningless information",
                    "Examples: 'Apples are a kind of metal', 'The sun rises in the west', '1+1=3'",
                    "",
                    "0.5 - Information with almost no value",
                    "Examples: 'the color of a fictional character's socks', 'the third line of an NPC's dialogue in a game', 'what someone had for breakfast yesterday'",
                    "",
                    "1.0 - Extremely obscure knowledge with no practical value",
                    "Examples: 'the name of a background character's pet in a novel', 'line 15 of a film's closing credits', 'the nickname of user ID 123456 on some website'",
                    "",
                    "1.5 - Very niche details",
                    "Examples: 'the outfit of an extra at minute 37 of a film', 'the duration of a hidden level's background music in a game', 'the content of the third speech bubble on page 200 of a comic'",
                    "",
                    "2.0 - Details from niche specialist fields",
                    "Examples: 'the color change of a rare mineral at a specific temperature', 'the length of an insect's third pair of antennae', 'the molecular formula of a reaction by-product'",
                    "",
                    "2.5 - Technical details only professionals care about",
                    "Examples: 'the release date of a specific library version', 'the constant factor in an algorithm's time complexity', 'the thermal expansion coefficient of a material'",
                    "",
                    "3.0 - Specialist knowledge in a particular field",
                    "Examples: 'syntax features of a programming language', 'the genome sequence of a virus', 'the bureaucratic system of an ancient dynasty'",
                    "",
                    "3.5 - Specialist information of some value",
                    "Examples: 'a specific institution of a historical dynasty', 'the mechanism of action of a drug', 'when a technical standard was established'",
                    "",
                    "4.0 - Meaningful knowledge known by relatively few people",
                    "Examples: 'a country's unique cultural tradition', 'an important discovery by a scientist', 'the detailed course of a historical event'",
                    "",
                    "4.5 - Knowledge of interest to part of the population",
                    "Examples: 'the background of a writer's work', 'the characteristics of an art movement', 'detailed rules of a sport'",
                    "",
                    "5.0 - General knowledge of medium importance",
                    "Examples: 'famous sights of a city', 'the history of a company', 'the habits of an animal'",
                    "",
                    "5.5 - Fairly useful common knowledge",
                    "Examples: 'growing conditions for plants', 'healthy-eating basics', 'basic first aid'",
                    "",
                    "6.0 - Knowledge most educated people should have",
                    "Examples: 'Shakespeare's major works', 'basic geometry theorems', 'the world's major currencies'",
                    "",
                    "6.5 - Important cultural or scientific common knowledge",
                    "Examples: 'the basic structure of DNA', 'Newton's three laws', 'the world's major religions'",
                    "",
                    "7.0 - Important foundational knowledge",
                    "Examples: 'the dates of World War II', 'the functions of major human organs', 'basic arithmetic rules'",
                    "",
                    "7.5 - Very important common knowledge",
                    "Examples: 'light is the fastest thing in the universe', 'the Earth is round', 'the basics of blood circulation'",
                    "",
                    "8.0 - Core knowledge from basic education",
                    "Examples: 'the Earth orbits the Sun', 'how the seasons form', 'basic grammar rules'",
                    "",
                    "8.5 - Important knowledge everyone should have",
                    "Examples: 'the chemical formula of water is H2O', 'basic safety knowledge', 'simple arithmetic'",
                    "",
                    "9.0 - Extremely important basic concepts",
                    "Examples: 'humans need oxygen to survive', 'fire is hot', 'basic notions of direction'",
                    "",
                    "9.5 - Core knowledge everyone must know",
                    "Examples: 'a day has 24 hours', 'a year has 12 months', 'basic number concepts'",
                    "",
                    "10.0 - The most basic and most important common knowledge",
                    "Examples: 'humans need food and water to survive', 'the sky is blue', 'a stone is heavier than a feather'",
                    "",
                    "When scoring, consider:",
                    "1. How widely known the knowledge is - how many people know it",
                    "2. Practical value - how useful it is in daily life",
                    "3. Educational importance - its place in the education system",
                    "4. Cultural significance - how important it is for understanding the world",
                    "",
                    "Output the structured result directly, without any reasoning process."
                ],
                markdown=False
            )
            print("LLM processor initialized successfully")
        except Exception as e:
            print(f"Failed to initialize the LLM processor: {e}")
            print("Falling back to basic mode (no LLM post-processing)")
            self.enable_llm_processing = False

    async def process_sentence_with_llm(self, sentence: str) -> ProcessedSentence:
        """Process a single sentence with the LLM (kept for standalone calls)."""
        try:
            prompt = f"Please correct the errors in the following sentence and assess its importance: {sentence}"

            # Asynchronous call via agent.arun
            response = await self.agent.arun(prompt)

            # According to the agno docs, response should already be a ProcessedSentence
            if isinstance(response, ProcessedSentence):
                return response
            else:
                message = response.messages[-1].content
                message = message.replace("```json", "").replace("```", "")
                message = json.loads(message)
                return ProcessedSentence(
                    corrected_sentence=message['corrected_sentence'],
                    importance_score=message['importance_score']
                )

        except Exception as e:
            print(f"Error while processing sentence with LLM: {e}")
            # On error, return the original sentence with a medium score
            return ProcessedSentence(
                corrected_sentence=sentence,
                importance_score=5.0
            )

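    # Usage sketch for process_sentence_with_llm above (illustrative; assumes a
    # local Ollama server with the qwen3:4b model pulled):
    #   processor = EnhancedTRExProcessor("dataset/TREx", "trex_sentences_enhanced.txt", max_files=1)
    #   result = asyncio.run(processor.process_sentence_with_llm(
    #       "Argument (disambiguation) is related to philosophy, logic, and is an."))
    #   print(result.corrected_sentence, result.importance_score)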

    def clean_text(self, text: str) -> str:
        """Clean text and normalize special characters."""
        if not text:
            return ""

        # Normalize common Unicode punctuation
        text = text.replace("–", "-")        # en dash
        text = text.replace("—", "-")        # em dash
        text = text.replace("\u2019", "'")   # right single quotation mark
        text = text.replace("\u2018", "'")   # left single quotation mark
        text = text.replace("\u201c", '"')   # left double quotation mark
        text = text.replace("\u201d", '"')   # right double quotation mark

        # Handle possible encoding issues
        try:
            text = text.encode('utf-8').decode('utf-8')
        except Exception:
            pass

        # Collapse redundant whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Strip surrounding quotes
        text = text.strip('"\'')

        return text

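    # Example for clean_text above (illustrative): curly quotes and dashes are
    # normalized, whitespace is collapsed, and surrounding quotes are stripped:
    #   clean_text('  “The Beatles’  song”  ')  ->  "The Beatles' song"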

    def parse_large_json_file(self, file_path: str) -> List[Dict]:
        """Parse a large JSON file, tolerating possible formatting issues."""
        documents = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read().strip()

            # Try different parsing strategies
            if content.startswith('[') and content.endswith(']'):
                # Standard JSON array
                documents = json.loads(content)
            else:
                # Possibly a stream of concatenated JSON objects;
                # try splitting between adjacent objects ('}{"')
                if '}{"' in content:
                    json_strings = content.split('}{')
                    json_strings[0] += '}'  # first object
                    json_strings[-1] = '{' + json_strings[-1]  # last object

                    for i in range(1, len(json_strings) - 1):
                        json_strings[i] = '{' + json_strings[i] + '}'

                    for json_str in json_strings:
                        try:
                            doc = json.loads(json_str)
                            documents.append(doc)
                        except json.JSONDecodeError:
                            continue
                else:
                    # Try parsing as a single JSON object
                    try:
                        documents = [json.loads(content)]
                    except json.JSONDecodeError:
                        pass

        except Exception as e:
            print(f"Error parsing {file_path}: {e}")

        return documents

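    # parse_large_json_file above accepts either file layout (illustrative snippets):
    #   '[{"title": "A"}, {"title": "B"}]'   -> parsed as a standard JSON array
    #   '{"title": "A"}{"title": "B"}'       -> concatenated objects, split on '}{'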

    def extract_sentences_from_document(self, doc: Dict[str, Any]) -> List[str]:
        """Extract sentences from a document."""
        sentences = []

        title = self.clean_text(doc.get('title', ''))
        text = self.clean_text(doc.get('text', ''))
        entities = doc.get('entities', [])
        triples = doc.get('triples', [])

        # Convert the explicit triples
        for triple in triples:
            sentence = self.triple_to_sentence(triple)
            if sentence:
                sentences.append(sentence)

        # Generate basic sentences from the entities and text (if the triple sentences are not enough)
        if title and text and len(sentences) < 5:
            # Build sentences from the title and entities
            entity_names = []
            for entity in entities[:15]:
                entity_name = self.clean_text(entity.get('surfaceform', ''))
                if entity_name and len(entity_name) > 2:
                    entity_names.append(entity_name)

            # Generate a simple descriptive sentence
            if entity_names:
                important_entities = []
                title_lower = title.lower()
                for entity in entity_names:
                    if (entity.lower() != title_lower and
                            entity not in important_entities and
                            not any(t.lower() in entity.lower() for t in title_lower.split()[:2])):
                        important_entities.append(entity)
                        if len(important_entities) >= 6:
                            break

                if important_entities and len(sentences) < 3:
                    entities_str = ', '.join(important_entities[:3])
                    sentences.append(f"{title} is related to {entities_str}.")

        return sentences

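    # Shape of a TREx record as read above (field names follow the doc.get()
    # calls; values are illustrative):
    #   {"title": "...", "text": "...",
    #    "entities": [{"surfaceform": "..."}, ...],
    #    "triples": [{"subject": {...}, "predicate": {...}, "object": {...}}, ...]}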

    def triple_to_sentence(self, triple: Dict[str, Any]) -> str:
        """Convert a triple into a natural-language sentence."""
        try:
            subject = triple.get('subject', {})
            predicate = triple.get('predicate', {})
            obj = triple.get('object', {})

            subject_name = self.clean_text(subject.get('surfaceform', ''))
            object_name = self.clean_text(obj.get('surfaceform', ''))
            predicate_uri = predicate.get('uri', '')

            # Require a valid subject and object
            if not subject_name or not object_name:
                return ""

            # Skip subjects and objects that are too short to be meaningful
            if len(subject_name) <= 2 or len(object_name) <= 2:
                return ""

            # Look up the relation text
            relation_text = self.property_mappings.get(predicate_uri, "is related to")

            # Avoid sentences whose subject and object are identical
            if subject_name.lower() == object_name.lower():
                return ""

            return f"{subject_name} {relation_text} {object_name}."

        except Exception as e:
            print(f"Error converting triple to sentence: {e}")
            return ""

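    # Illustrative input/output for triple_to_sentence above (made-up surface
    # forms; P17 maps to "is located in" in property_mappings):
    #   {"subject": {"surfaceform": "Paris"},
    #    "predicate": {"uri": "http://www.wikidata.org/prop/direct/P17"},
    #    "object": {"surfaceform": "France"}}
    #   -> "Paris is located in France."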

    async def process_sentence_with_llm_concurrent(self, semaphore: asyncio.Semaphore, sentence: str, index: int, total_sentences: int, start_time: float) -> Dict[str, Any]:
        """Process a sentence with the LLM, limiting concurrency with a semaphore."""
        async with semaphore:
            try:
                prompt = f"Please correct the errors in the following sentence and assess its importance: {sentence}"

                # Asynchronous call via agent.arun
                response = await self.agent.arun(prompt)

                # According to the agno docs, response should already be a ProcessedSentence
                if isinstance(response, ProcessedSentence):
                    result = {
                        "index": index,
                        "original_sentence": sentence,
                        "corrected_sentence": response.corrected_sentence,
                        "importance_score": response.importance_score
                    }
                else:
                    message = response.messages[-1].content
                    message = message.replace("```json", "").replace("```", "")
                    message = json.loads(message)
                    # print(message)
                    result = {
                        "index": index,
                        "original_sentence": sentence,
                        "corrected_sentence": message['corrected_sentence'],
                        "importance_score": message['importance_score']
                    }

                # Print detailed progress information
                if index % 100 == 0:
                    current_time = time.time()
                    elapsed_time = current_time - start_time
                    avg_time_per_sentence = elapsed_time / (index + 1) if index > 0 else elapsed_time
                    remaining_sentences = total_sentences - (index + 1)
                    estimated_remaining_time = avg_time_per_sentence * remaining_sentences

                    # Format a duration for display
                    def format_time(seconds):
                        if seconds < 60:
                            return f"{seconds:.1f}s"
                        elif seconds < 3600:
                            minutes = seconds / 60
                            return f"{minutes:.1f}min"
                        else:
                            hours = seconds / 3600
                            return f"{hours:.1f}h"

                    print(f"Finished processing sentence {index + 1}")
                    print(f"  - Remaining sentences: {remaining_sentences}")
                    print(f"  - Average time per sentence: {avg_time_per_sentence:.2f}s")
                    print(f"  - Estimated time remaining: {format_time(estimated_remaining_time)}")
                    print(f"  - Elapsed time: {format_time(elapsed_time)}")

                return result

            except Exception as e:
                print(f"Error while processing sentence {index}: {e}")
                # On error, return the original sentence with a medium score
                return {
                    "index": index,
                    "original_sentence": sentence,
                    "corrected_sentence": sentence,
                    "importance_score": 5.0
                }

    async def process_sentences_with_llm(self, sentences: List[str]) -> List[Dict[str, Any]]:
        """Process sentences concurrently in batches, saving a checkpoint every 2000 sentences."""
        print(f"Starting concurrent LLM processing of {len(sentences)} sentences (max concurrency: 54)...")

        # Record the start time
        start_time = time.time()
        total_sentences = len(sentences)

        # Process in batches of 2000 sentences
        batch_size = 2000
        all_processed_sentences = []

        for batch_start in range(0, total_sentences, batch_size):
            batch_end = min(batch_start + batch_size, total_sentences)
            batch_sentences = sentences[batch_start:batch_end]

            print(f"\n=== Processing batch {batch_start//batch_size + 1} ({batch_start + 1}-{batch_end}/{total_sentences}) ===")

            # Create a semaphore to limit concurrency
            semaphore = asyncio.Semaphore(54)

            # Create the tasks for the current batch
            tasks = []
            for i, sentence in enumerate(batch_sentences):
                global_index = batch_start + i
                task = self.process_sentence_with_llm_concurrent(semaphore, sentence, global_index, total_sentences, start_time)
                tasks.append(task)

            # Run the current batch concurrently
            print(f"Concurrently processing the {len(batch_sentences)} sentences of batch {batch_start//batch_size + 1}...")
            batch_results = await asyncio.gather(*tasks, return_exceptions=True)

            # Collect the batch results, filtering out exceptions
            batch_processed_sentences = []
            batch_error_count = 0

            for result in batch_results:
                if isinstance(result, Exception):
                    print(f"Task raised an exception: {result}")
                    batch_error_count += 1
                elif isinstance(result, dict):
                    batch_processed_sentences.append(result)
                else:
                    batch_error_count += 1

            # Restore the original order (concurrent execution may reorder results)
            batch_processed_sentences.sort(key=lambda x: x['index'])

            # Drop the index field
            for item in batch_processed_sentences:
                del item['index']

            # Append to the overall results
            all_processed_sentences.extend(batch_processed_sentences)

            # Save a checkpoint
            checkpoint_filename = self.save_checkpoint(all_processed_sentences, batch_end)

            # Print statistics for the current batch
            elapsed_time = time.time() - start_time
            completed_sentences = len(all_processed_sentences)

            print(f"Batch {batch_start//batch_size + 1} finished!")
            print(f"  - This batch: {len(batch_processed_sentences)} succeeded, {batch_error_count} failed")
            print(f"  - Overall progress: {completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
            print(f"  - Elapsed time: {elapsed_time/60:.1f} min")
            print(f"  - Average speed: {completed_sentences/elapsed_time:.2f} sentences/s")
            print(f"  - Checkpoint saved: {checkpoint_filename}")

            if batch_end < total_sentences:
                remaining_sentences = total_sentences - completed_sentences
                avg_time_per_sentence = elapsed_time / completed_sentences
                estimated_remaining_time = avg_time_per_sentence * remaining_sentences
                print(f"  - Estimated time remaining: {estimated_remaining_time/60:.1f} min")

        # Print the final statistics
        total_time = time.time() - start_time
        print(f"\n=== All processing finished! ===")
        print(f"  - Total succeeded: {len(all_processed_sentences)}")
        print(f"  - Total time: {total_time/60:.1f} min")
        print(f"  - Average processing speed: {len(all_processed_sentences)/total_time:.2f} sentences/s")

        return all_processed_sentences

    def save_checkpoint(self, processed_sentences: List[Dict[str, Any]], current_count: int) -> str:
        """Save a checkpoint file."""
        # Build the checkpoint file name
        base_name = os.path.splitext(self.output_file)[0]
        checkpoint_filename = f"{base_name}_checkpoint_{current_count}.json"

        # Write the checkpoint
        with open(checkpoint_filename, 'w', encoding='utf-8') as f:
            json.dump({
                "metadata": {
                    "total_processed": len(processed_sentences),
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                    "checkpoint_number": current_count
                },
                "sentences": processed_sentences
            }, f, ensure_ascii=False, indent=2)

        return checkpoint_filename

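    # Checkpoint layout written above and read back by load_checkpoint below
    # (values are illustrative):
    #   {"metadata": {"total_processed": 2000,
    #                 "timestamp": "2024-01-01 12:00:00",
    #                 "checkpoint_number": 2000},
    #    "sentences": [{"original_sentence": "...",
    #                   "corrected_sentence": "...",
    #                   "importance_score": 7.5}, ...]}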

    async def process_files(self) -> List[Dict[str, Any]]:
        """Process all input files."""
        json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))

        if not json_files:
            print(f"No JSON files found in {self.input_dir}")
            return []

        # Sort the files to ensure a consistent processing order
        json_files.sort()

        if self.max_files:
            json_files = json_files[:self.max_files]

        print(f"Found {len(json_files)} JSON files to process")

        all_sentences = []

        for i, file_path in enumerate(json_files):
            print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")

            documents = self.parse_large_json_file(file_path)
            print(f"  Parsed {len(documents)} documents")

            for doc in documents:
                sentences = self.extract_sentences_from_document(doc)
                all_sentences.extend(sentences)

            print(f"  Generated {len(all_sentences)} total raw sentences so far")

        print(f"Extracted {len(all_sentences)} raw sentences in total")

        # Deduplicate
        unique_sentences = []
        seen = set()
        for sentence in all_sentences:
            sentence = sentence.strip()
            if sentence and sentence not in seen and len(sentence) > 10:
                unique_sentences.append(sentence)
                seen.add(sentence)

        print(f"{len(unique_sentences)} sentences remain after deduplication")

        # Post-process the sentences with the LLM
        if self.enable_llm_processing:
            processed_sentences = await self.process_sentences_with_llm(unique_sentences)
        else:
            # Basic mode: no LLM
            processed_sentences = [
                {
                    "original_sentence": sentence,
                    "corrected_sentence": sentence,
                    "importance_score": 5.0
                }
                for sentence in unique_sentences
            ]

        return processed_sentences

    def save_sentences(self, processed_sentences: List[Dict[str, Any]]):
        """Save the processed sentences to disk."""
        # Make sure the output directory exists
        os.makedirs(os.path.dirname(self.output_file) if os.path.dirname(self.output_file) else '.', exist_ok=True)

        # Save as JSON with the full information
        json_output_file = self.output_file.replace('.txt', '.json')
        with open(json_output_file, 'w', encoding='utf-8') as f:
            json.dump(processed_sentences, f, ensure_ascii=False, indent=2)

        # Save as plain text (corrected sentences only)
        with open(self.output_file, 'w', encoding='utf-8') as f:
            for item in processed_sentences:
                f.write(item['corrected_sentence'] + '\n')

        # Write a file sorted by importance
        importance_sorted = sorted(processed_sentences, key=lambda x: x['importance_score'], reverse=True)
        importance_file = self.output_file.replace('.txt', '_sorted_by_importance.txt')
        with open(importance_file, 'w', encoding='utf-8') as f:
            for item in importance_sorted:
                f.write(f"[{item['importance_score']:.1f}] {item['corrected_sentence']}\n")

        print(f"Saved {len(processed_sentences)} processed sentences:")
        print(f"  - JSON format: {json_output_file}")
        print(f"  - Text format: {self.output_file}")
        print(f"  - Sorted by importance: {importance_file}")

        # Statistics
        scores = [item['importance_score'] for item in processed_sentences]
        avg_score = sum(scores) / len(scores) if scores else 0
        print(f"  - Average importance score: {avg_score:.2f}")
        if scores:
            print(f"  - Highest score: {max(scores):.1f}")
            print(f"  - Lowest score: {min(scores):.1f}")

    async def run(self):
        """Run the processing pipeline."""
        print("Starting enhanced TREx to sentences conversion...")
        processed_sentences = await self.process_files()
        self.save_sentences(processed_sentences)
        print("Enhanced conversion completed!")

    def find_latest_checkpoint(self) -> Union[tuple, None]:
        """Find the most recent checkpoint file."""
        base_name = os.path.splitext(self.output_file)[0]
        # Checkpoints are written next to the output file (see save_checkpoint)
        pattern = f"{base_name}_checkpoint_*.json"
        checkpoint_files = glob.glob(pattern)

        if not checkpoint_files:
            return None

        # Sort by checkpoint number and keep the most recent one
        latest_file = None
        latest_count = 0

        for file in checkpoint_files:
            try:
                # Extract the number from the file name
                match = re.search(r'checkpoint_(\d+)\.json$', file)
                if match:
                    count = int(match.group(1))
                    if count > latest_count:
                        latest_count = count
                        latest_file = file
            except Exception:
                continue

        if latest_file:
            return latest_file, latest_count
        else:
            return None

    def load_checkpoint(self, checkpoint_file: str) -> List[Dict[str, Any]]:
        """Load already processed sentences from a checkpoint file."""
        try:
            with open(checkpoint_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if 'sentences' in data:
                return data['sentences']
            else:
                # Old-style checkpoint files store the list directly
                return data
        except Exception as e:
            print(f"Failed to load checkpoint file: {e}")
            return []


def main():
    """Entry point."""
    import argparse

    parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with LLM processing')
    parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files')
    parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path')
    parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)')
    parser.add_argument('--no_llm', action='store_true', help='Disable LLM processing (basic mode)')
    parser.add_argument('--resume', action='store_true', help='Resume from latest checkpoint if available')

    args = parser.parse_args()

    if not os.path.exists(args.input_dir):
        print(f"Error: Input directory {args.input_dir} does not exist!")
        return

    processor = EnhancedTRExProcessor(
        args.input_dir,
        args.output_file,
        args.max_files,
        enable_llm_processing=not args.no_llm
    )

    # Check whether to resume from a checkpoint
    if args.resume:
        checkpoint_result = processor.find_latest_checkpoint()
        if checkpoint_result:
            latest_checkpoint, latest_count = checkpoint_result
            print(f"Found checkpoint file: {latest_checkpoint} (contains {latest_count} records)")
            confirm = input("Resume from this checkpoint? (y/n): ").lower().strip()
            if confirm == 'y':
                processed_sentences = processor.load_checkpoint(latest_checkpoint)
                if processed_sentences:
                    print(f"Successfully loaded {len(processed_sentences)} processed sentences")
                    processor.save_sentences(processed_sentences)
                    print("Recovery from checkpoint completed!")
                    return
                else:
                    print("Failed to load the checkpoint file; starting over")
            else:
                print("Not resuming from the checkpoint; starting over")
        else:
            print("No checkpoint file found; starting over")

    # Run the asynchronous pipeline
    asyncio.run(processor.run())


if __name__ == "__main__":
    main()
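# Usage sketch (paths are illustrative; --no_llm skips the Ollama dependency):
#   python preprocessing/trex_to_sentences_simple.py --input_dir dataset/TREx --max_files 2 --no_llm
#   python preprocessing/trex_to_sentences_simple.py --resume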