diff --git a/preprocessing/trex_to_sentences_simple.py b/preprocessing/trex_to_sentences_simple.py
new file mode 100644
index 0000000..721b9f0
--- /dev/null
+++ b/preprocessing/trex_to_sentences_simple.py
@@ -0,0 +1,734 @@
+#!/usr/bin/env python3
+"""
+Enhanced preprocessing script for the TREx dataset.
+Uses the agno framework with ollama qwen3:4b for sentence post-processing and importance scoring.
+"""
+
+import json
+import os
+import glob
+from typing import List, Dict, Any, Union
+import re
+import asyncio
+import time
+from pydantic import BaseModel, Field
+from agno.agent import Agent
+from agno.models.ollama import Ollama
+
+
+class ProcessedSentence(BaseModel):
+    """Structure of a processed sentence."""
+    corrected_sentence: str = Field(
+        ...,
+        description="修正后的句子，只修正语法错误、乱码和不通顺的地方，不进行额外润色"
+    )
+    importance_score: float = Field(
+        ...,
+        description="重要性评分，范围0.0-10.0，以0.1递进。评判这个知识在现实世界中的常用程度和重要度",
+        ge=0.0,
+        le=10.0
+    )
+
+
+class EnhancedTRExProcessor:
+    def __init__(self, input_dir: str, output_file: str, max_files: int = None, enable_llm_processing: bool = True):
+        self.input_dir = input_dir
+        self.output_file = output_file
+        self.max_files = max_files
+        self.enable_llm_processing = enable_llm_processing
+
+        # Initialize the agno agent
+        if self.enable_llm_processing:
+            self.setup_agent()
+
+        # Extended Wikidata property mappings
+        self.property_mappings = {
+            # Basic relations
+            "http://www.wikidata.org/prop/direct/P31": "is a",
+            "http://www.wikidata.org/prop/direct/P279": "is a type of",
+
+            # Person-related
+            "http://www.wikidata.org/prop/direct/P106": "works as",
+            "http://www.wikidata.org/prop/direct/P27": "is a citizen of",
+            "http://www.wikidata.org/prop/direct/P19": "was born in",
+            "http://www.wikidata.org/prop/direct/P20": "died in",
+            "http://www.wikidata.org/prop/direct/P569": "was born on",
+            "http://www.wikidata.org/prop/direct/P570": "died on",
+            "http://www.wikidata.org/prop/direct/P22": "has father",
+            "http://www.wikidata.org/prop/direct/P25": "has mother",
+            "http://www.wikidata.org/prop/direct/P26": "is married to",
+
+            # Organization-related
+            "http://www.wikidata.org/prop/direct/P102": "is a member of",
+            "http://www.wikidata.org/prop/direct/P108": "works for",
+            "http://www.wikidata.org/prop/direct/P159": "has headquarters in",
+            "http://www.wikidata.org/prop/direct/P112": "was founded by",
+            "http://www.wikidata.org/prop/direct/P571": "was founded in",
+            "http://www.wikidata.org/prop/direct/P169": "has CEO",
+
+            # Geography-related
+            "http://www.wikidata.org/prop/direct/P17": "is located in",
+            "http://www.wikidata.org/prop/direct/P131": "is located in",
+            "http://www.wikidata.org/prop/direct/P36": "has capital",
+            "http://www.wikidata.org/prop/direct/P47": "borders",
+
+            # Other relations
+            "http://www.wikidata.org/prop/direct/P1142": "has ideology",
+            "http://www.wikidata.org/prop/direct/P361": "is part of",
+            "http://www.wikidata.org/prop/direct/P737": "was influenced by",
+            "http://www.wikidata.org/prop/direct/P127": "is owned by",
+            "http://www.wikidata.org/prop/direct/P155": "follows",
+            "http://www.wikidata.org/prop/direct/P156": "is followed by",
+            "http://www.wikidata.org/prop/direct/P138": "is named after"
+        }
+
+    def setup_agent(self):
+        """Set up the agno agent."""
+        try:
+            self.agent = Agent(
+                model=Ollama(
+                    id="qwen3:4b",
+                    # Use options to set temperature and other sampling parameters
+                    options={
+                        "temperature": 0.7,
+                        "top_p": 0.8,
+                        "top_k": 20,
+                        "num_ctx": 4096,
+                    }
+                ),
+                response_model=ProcessedSentence,
+                instructions=[
+                    "你是一个专业的文本处理助手，负责修正句子中的错误并评估知识的重要性。",
+                    "",
+                    "### 句子修正规则：",
+                    "1. 移除Wikipedia特有标记：如(disambiguation)、(film)、(band)等括号内容",
+                    "2. 确保句子语法完整：主语+谓语+宾语结构完整，避免悬空的'and is'、'or'等",
+                    "3. 修正明显的语法错误：时态一致、单复数一致、介词使用正确",
+                    "4. 清理乱码和特殊字符：如â、€、™等编码问题",
+                    "5. 确保句子语义通顺：如果原句无法修复，重新组织语言使其通顺",
+                    "6. 不要添加原文没有的信息，只修正错误",
+                    "",
+                    "### 修正示例：",
+                    "- 错误：'Argument (disambiguation) is related to philosophy, logic, and is an.'",
+                    "- 修正：'Argument is related to philosophy and logic.'",
+                    "",
+                    "- 错误：'Beijing is a capital city and are.'",
+                    "- 修正：'Beijing is a capital city.'",
+                    "",
+                    "重要性评分标准（0.0-10.0，以0.1递进）：",
+                    "",
+                    "0.0分 - 完全错误或无意义的信息",
+                    "例：'苹果是一种金属'、'太阳从西边升起'、'1+1=3'",
+                    "",
+                    "0.5分 - 几乎无价值的信息",
+                    "例：'某个虚构角色的袜子颜色'、'游戏中NPC的对话第三句话'、'某人昨天早餐吃了什么'",
+                    "",
+                    "1.0分 - 极其罕见、无实用价值的知识",
+                    "例：'某小说背景角色宠物名字'、'某部电影片尾字幕第15行内容'、'某网站用户ID为123456的昵称'",
+                    "",
+                    "1.5分 - 非常小众的细节信息",
+                    "例：'某电影第37分钟路人甲服装'、'某游戏隐藏关卡的背景音乐时长'、'某漫画第200页第3个对话框内容'",
+                    "",
+                    "2.0分 - 小众专业领域的细节",
+                    "例：'稀有矿物在特定温度下颜色变化'、'某种昆虫的第三对触角长度'、'某化学反应的副产物分子式'",
+                    "",
+                    "2.5分 - 专业人士才关心的技术细节",
+                    "例：'软件库特定版本发布日期'、'某算法的时间复杂度系数'、'某种材料的热膨胀系数'",
+                    "",
+                    "3.0分 - 特定领域的专业知识",
+                    "例：'编程语言语法特性'、'某种病毒的基因序列'、'古代某朝代的官职制度'",
+                    "",
+                    "3.5分 - 有一定价值的专业信息",
+                    "例：'某历史朝代特定制度'、'某种药物的作用机制'、'某技术标准的制定时间'",
+                    "",
+                    "4.0分 - 较少人知道但有意义的知识",
+                    "例：'某国家独特文化传统'、'某科学家的重要发现'、'某历史事件的详细过程'",
+                    "",
+                    "4.5分 - 部分人群感兴趣的知识",
+                    "例：'作家创作背景'、'某艺术流派特点'、'某运动项目规则细节'",
+                    "",
+                    "5.0分 - 中等重要性的一般知识",
+                    "例：'城市著名景点'、'某企业发展历史'、'某动物生活习性'",
+                    "",
+                    "5.5分 - 比较有用的常识",
+                    "例：'植物生长环境'、'健康饮食常识'、'基本急救知识'",
+                    "",
+                    "6.0分 - 多数受教育人群应该知道的知识",
+                    "例：'莎士比亚代表作品'、'基本几何定理'、'世界主要货币'",
+                    "",
+                    "6.5分 - 重要的文化或科学常识",
+                    "例：'DNA基本结构'、'牛顿三大定律'、'世界主要宗教'",
+                    "",
+                    "7.0分 - 重要的基础知识",
+                    "例：'二次世界大战时间'、'人体主要器官功能'、'基本数学运算规则'",
+                    "",
+                    "7.5分 - 非常重要的常识",
+                    "例：'光速是宇宙中最快的'、'地球是圆的'、'血液循环基本原理'",
+                    "",
+                    "8.0分 - 基础教育中的核心知识",
+                    "例：'地球绕太阳运行'、'四季形成原理'、'基本语法规则'",
+                    "",
+                    "8.5分 - 每个人都应该掌握的重要知识",
+                    "例：'水的化学式H2O'、'基本安全常识'、'简单数学计算'",
+                    "",
+                    "9.0分 - 极其重要的基础概念",
+                    "例：'人类需要氧气生存'、'火是热的'、'基本方向概念'",
+                    "",
+                    "9.5分 - 人人必知的核心知识",
+                    "例：'一天有24小时'、'一年有12个月'、'基本数字概念'",
+                    "",
+                    "10.0分 - 最基础、最重要的常识",
+                    "例：'人类需要食物和水生存'、'天空是蓝色的'、'石头比羽毛重'",
+                    "",
+                    "评分时请考虑：",
+                    "1. 知识的普及程度 - 有多少人知道这个知识",
+                    "2. 实用价值 - 这个知识在日常生活中有多大用处",
+                    "3. 教育重要性 - 这个知识在教育体系中的地位",
+                    "4. 文化意义 - 这个知识对理解世界的重要性",
+                    "",
+                    "请直接输出结构化结果，不需要思考过程。"
+                ],
+                markdown=False
+            )
+            print("LLM处理器初始化成功")
+        except Exception as e:
+            print(f"LLM处理器初始化失败: {e}")
+            print("将使用基础模式（不使用LLM后处理）")
+            self.enable_llm_processing = False
+
+    async def process_sentence_with_llm(self, sentence: str) -> ProcessedSentence:
+        """Process a single sentence with the LLM (kept for standalone calls)."""
+        try:
+            prompt = f"请修正以下句子中的错误并评估其重要性：{sentence}"
+
+            # Call the agent asynchronously via agent.arun
+            response = await self.agent.arun(prompt)
+
+            # Per the agno docs, response should already be a ProcessedSentence
+            if isinstance(response, ProcessedSentence):
+                return response
+            else:
+                message = response.messages[-1].content
+                message = message.replace("```json", "").replace("```", "")
+                message = json.loads(message)
+                return ProcessedSentence(
+                    corrected_sentence=message['corrected_sentence'],
+                    importance_score=message['importance_score']
+                )
+
+        except Exception as e:
+            print(f"LLM处理句子时出错: {e}")
+            # On error, fall back to the original sentence with a medium score
+            return ProcessedSentence(
+                corrected_sentence=sentence,
+                importance_score=5.0
+            )
+
+    def clean_text(self, text: str) -> str:
+        """Clean text and normalize special characters."""
+        if not text:
+            return ""
+
+        # Normalize common Unicode characters
+        text = text.replace("–", "-")  # en dash
+        text = text.replace("—", "-")  # em dash
+        text = text.replace("’", "'")  # right single quotation mark
+        text = text.replace("‘", "'")  # left single quotation mark
+        text = text.replace("“", '"')  # left double quotation mark
+        text = text.replace("”", '"')  # right double quotation mark
+
+        # Handle possible escape sequences
+        try:
+            text = text.encode('utf-8').decode('utf-8')
+        except Exception:
+            pass
+
+        # Collapse extra whitespace
+        text = re.sub(r'\s+', ' ', text).strip()
+
+        # Strip surrounding quotes
+        text = text.strip('"\'')
+
+        return text
+
+    def parse_large_json_file(self, file_path: str) -> List[Dict]:
+        """Parse a large JSON file, tolerating possible formatting issues."""
+        documents = []
+
+        try:
+            with open(file_path, 'r', encoding='utf-8') as f:
+                content = f.read().strip()
+
+            # Try different parsing strategies
+            if content.startswith('[') and content.endswith(']'):
+                # Standard JSON array
+                documents = json.loads(content)
+            else:
+                # Possibly concatenated JSON objects
+                # Try splitting on '}{' boundaries
+                if '}{"' in content:
+                    json_strings = content.split('}{')
+                    json_strings[0] += '}'  # first object
+                    json_strings[-1] = '{' + json_strings[-1]  # last object
+
+                    for i in range(1, len(json_strings) - 1):
+                        json_strings[i] = '{' + json_strings[i] + '}'
+
+                    for json_str in json_strings:
+                        try:
+                            doc = json.loads(json_str)
+                            documents.append(doc)
+                        except json.JSONDecodeError:
+                            continue
+                else:
+                    # Try parsing as a single JSON object
+                    try:
+                        documents = [json.loads(content)]
+                    except json.JSONDecodeError:
+                        pass
+
+        except Exception as e:
+            print(f"Error parsing {file_path}: {e}")
+
+        return documents
+
+    def extract_sentences_from_document(self, doc: Dict[str, Any]) -> List[str]:
+        """Extract sentences from a document."""
+        sentences = []
+
+        title = self.clean_text(doc.get('title', ''))
+        text = self.clean_text(doc.get('text', ''))
+        entities = doc.get('entities', [])
+        triples = doc.get('triples', [])
+
+        # Convert explicit triples
+        for triple in triples:
+            sentence = self.triple_to_sentence(triple)
+            if sentence:
+                sentences.append(sentence)
+
+        # Generate basic sentences from entities and text (if the triple sentences are not enough)
+        if title and text and len(sentences) < 5:
+            # Build sentences from the title and entities
+            entity_names = []
+            for entity in entities[:15]:
+                entity_name = self.clean_text(entity.get('surfaceform', ''))
+                if entity_name and len(entity_name) > 2:
+                    entity_names.append(entity_name)
+
+            # Generate a simple descriptive sentence
+            if entity_names:
+                important_entities = []
+                title_lower = title.lower()
+                for entity in entity_names:
+                    if (entity.lower() != title_lower and
+                        entity not in important_entities and
+                        not any(t.lower() in entity.lower() for t in title_lower.split()[:2])):
+                        important_entities.append(entity)
+                        if len(important_entities) >= 6:
+                            break
+
+                if important_entities and len(sentences) < 3:
+                    entities_str = ', '.join(important_entities[:3])
+                    sentences.append(f"{title} is related to {entities_str}.")
+
+        return sentences
+
+    def triple_to_sentence(self, triple: Dict[str, Any]) -> str:
+        """Convert a triple into a natural-language sentence."""
+        try:
+            subject = triple.get('subject', {})
+            predicate = triple.get('predicate', {})
+            obj = triple.get('object', {})
+
+            subject_name = self.clean_text(subject.get('surfaceform', ''))
+            object_name = self.clean_text(obj.get('surfaceform', ''))
+            predicate_uri = predicate.get('uri', '')
+
+            # Require a valid subject and object
+            if not subject_name or not object_name:
+                return ""
+
+            # Skip subjects/objects that are too short to be meaningful
+            if len(subject_name) <= 2 or len(object_name) <= 2:
+                return ""
+
+            # Look up the relation text
+            relation_text = self.property_mappings.get(predicate_uri, "is related to")
+
+            # Avoid identical subject and object
+            if subject_name.lower() == object_name.lower():
+                return ""
+
+            return f"{subject_name} {relation_text} {object_name}."
+
+        except Exception as e:
+            print(f"Error converting triple to sentence: {e}")
+            return ""
+
+    async def process_sentence_with_llm_concurrent(self, semaphore: asyncio.Semaphore, sentence: str, index: int, total_sentences: int, start_time: float) -> Dict[str, Any]:
+        """LLM processing with concurrency limited by a semaphore."""
+        async with semaphore:
+            try:
+                prompt = f"请修正以下句子中的错误并评估其重要性：{sentence}"
+
+                # Call the agent asynchronously via agent.arun
+                response = await self.agent.arun(prompt)
+
+                # Per the agno docs, response should already be a ProcessedSentence
+                if isinstance(response, ProcessedSentence):
+                    result = {
+                        "index": index,
+                        "original_sentence": sentence,
+                        "corrected_sentence": response.corrected_sentence,
+                        "importance_score": response.importance_score
+                    }
+                else:
+                    message = response.messages[-1].content
+                    message = message.replace("```json", "").replace("```", "")
+                    message = json.loads(message)
+                    # print(message)
+                    result = {
+                        "index": index,
+                        "original_sentence": sentence,
+                        "corrected_sentence": message['corrected_sentence'],
+                        "importance_score": message['importance_score']
+                    }
+
+                # Print detailed progress information
+                if index % 100 == 0:
+                    current_time = time.time()
+                    elapsed_time = current_time - start_time
+                    avg_time_per_sentence = elapsed_time / (index + 1) if index > 0 else elapsed_time
+                    remaining_sentences = total_sentences - (index + 1)
+                    estimated_remaining_time = avg_time_per_sentence * remaining_sentences
+
+                    # Format elapsed/remaining time for display
+                    def format_time(seconds):
+                        if seconds < 60:
+                            return f"{seconds:.1f}秒"
+                        elif seconds < 3600:
+                            minutes = seconds / 60
+                            return f"{minutes:.1f}分钟"
+                        else:
+                            hours = seconds / 3600
+                            return f"{hours:.1f}小时"
+
+                    print(f"已完成第 {index + 1} 个句子的处理")
+                    print(f"  - 剩余句子数: {remaining_sentences}")
+                    print(f"  - 平均处理时间: {avg_time_per_sentence:.2f}秒/句")
+                    print(f"  - 预估剩余时间: {format_time(estimated_remaining_time)}")
+                    print(f"  - 已用时间: {format_time(elapsed_time)}")
+
+                return result
+
+            except Exception as e:
+                print(f"处理第 {index} 个句子时出错: {e}")
+                # On error, fall back to the original sentence with a medium score
+                return {
+                    "index": index,
+                    "original_sentence": sentence,
+                    "corrected_sentence": sentence,
+                    "importance_score": 5.0
+                }
+
+    async def process_sentences_with_llm(self, sentences: List[str]) -> List[Dict[str, Any]]:
+        """Process sentences concurrently in batches, saving a checkpoint every 2000 sentences."""
+        print(f"开始使用LLM并发处理 {len(sentences)} 个句子（最大并发数：54）...")
+
+        # Record the start time
+        start_time = time.time()
+        total_sentences = len(sentences)
+
+        # Process in batches of 2000 sentences
+        batch_size = 2000
+        all_processed_sentences = []
+
+        for batch_start in range(0, total_sentences, batch_size):
+            batch_end = min(batch_start + batch_size, total_sentences)
+            batch_sentences = sentences[batch_start:batch_end]
+
+            print(f"\n=== 处理第 {batch_start//batch_size + 1} 批 ({batch_start + 1}-{batch_end}/{total_sentences}) ===")
+
+            # Semaphore to limit concurrency
+            semaphore = asyncio.Semaphore(54)
+
+            # Create tasks for the current batch
+            tasks = []
+            for i, sentence in enumerate(batch_sentences):
+                global_index = batch_start + i
+                task = self.process_sentence_with_llm_concurrent(semaphore, sentence, global_index, total_sentences, start_time)
+                tasks.append(task)
+
+            # Run the current batch concurrently
+            print(f"正在并发处理第 {batch_start//batch_size + 1} 批的 {len(batch_sentences)} 个句子...")
+            batch_results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            # Collect batch results, filtering out exceptions
+            batch_processed_sentences = []
+            batch_error_count = 0
+
+            for result in batch_results:
+                if isinstance(result, Exception):
+                    print(f"任务执行异常: {result}")
+                    batch_error_count += 1
+                elif isinstance(result, dict):
+                    batch_processed_sentences.append(result)
+                else:
+                    batch_error_count += 1
+
+            # Restore original order (concurrent execution may reorder results)
+            batch_processed_sentences.sort(key=lambda x: x['index'])
+
+            # Drop the index field
+            for item in batch_processed_sentences:
+                del item['index']
+
+            # Append to the overall results
+            all_processed_sentences.extend(batch_processed_sentences)
+
+            # Save a checkpoint
+            checkpoint_filename = self.save_checkpoint(all_processed_sentences, batch_end)
+
+            # Print statistics for the current batch
+            elapsed_time = time.time() - start_time
+            completed_sentences = len(all_processed_sentences)
+
+            print(f"第 {batch_start//batch_size + 1} 批处理完成！")
+            print(f"  - 当前批次：成功 {len(batch_processed_sentences)}，失败 {batch_error_count}")
+            print(f"  - 总体进度：{completed_sentences}/{total_sentences} ({completed_sentences/total_sentences*100:.1f}%)")
+            print(f"  - 已用时间：{elapsed_time/60:.1f}分钟")
+            print(f"  - 平均速度：{completed_sentences/elapsed_time:.2f}句/秒")
+            print(f"  - 检查点已保存：{checkpoint_filename}")
+
+            if batch_end < total_sentences:
+                remaining_sentences = total_sentences - completed_sentences
+                avg_time_per_sentence = elapsed_time / completed_sentences
+                estimated_remaining_time = avg_time_per_sentence * remaining_sentences
+                print(f"  - 预估剩余时间：{estimated_remaining_time/60:.1f}分钟")
+
+        # Print final statistics
+        total_time = time.time() - start_time
+        print(f"\n=== 全部处理完成！===")
+        print(f"  - 总成功：{len(all_processed_sentences)}")
+        print(f"  - 总用时：{total_time/60:.1f}分钟")
+        print(f"  - 平均处理速度：{len(all_processed_sentences)/total_time:.2f}句/秒")
+
+        return all_processed_sentences
+
+    def save_checkpoint(self, processed_sentences: List[Dict[str, Any]], current_count: int) -> str:
+        """Save a checkpoint file."""
+        # Build the checkpoint filename
+        base_name = os.path.splitext(self.output_file)[0]
+        checkpoint_filename = f"{base_name}_checkpoint_{current_count}.json"
+
+        # Write the checkpoint
+        with open(checkpoint_filename, 'w', encoding='utf-8') as f:
+            json.dump({
+                "metadata": {
+                    "total_processed": len(processed_sentences),
+                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
+                    "checkpoint_number": current_count
+                },
+                "sentences": processed_sentences
+            }, f, ensure_ascii=False, indent=2)
+
+        return checkpoint_filename
+
+    async def process_files(self) -> List[Dict[str, Any]]:
+        """Process all input files and return the processed sentences."""
+        json_files = glob.glob(os.path.join(self.input_dir, "re-nlg_*.json"))
+
+        if not json_files:
+            print(f"No JSON files found in {self.input_dir}")
+            return []
+
+        # Sort files for a consistent processing order
+        json_files.sort()
+
+        if self.max_files:
+            json_files = json_files[:self.max_files]
+
+        print(f"Found {len(json_files)} JSON files to process")
+
+        all_sentences = []
+
+        for i, file_path in enumerate(json_files):
+            print(f"Processing file {i+1}/{len(json_files)}: {os.path.basename(file_path)}")
+
+            documents = self.parse_large_json_file(file_path)
+            print(f"  Parsed {len(documents)} documents")
+
+            for doc in documents:
+                sentences = self.extract_sentences_from_document(doc)
+                all_sentences.extend(sentences)
+
+            print(f"  Generated {len(all_sentences)} total raw sentences so far")
+
+        print(f"总共提取了 {len(all_sentences)} 个原始句子")
+
+        # Deduplicate
+        unique_sentences = []
+        seen = set()
+        for sentence in all_sentences:
+            sentence = sentence.strip()
+            if sentence and sentence not in seen and len(sentence) > 10:
+                unique_sentences.append(sentence)
+                seen.add(sentence)
+
+        print(f"去重后剩余 {len(unique_sentences)} 个句子")
+
+        # Post-process sentences with the LLM
+        if self.enable_llm_processing:
+            processed_sentences = await self.process_sentences_with_llm(unique_sentences)
+        else:
+            # Basic mode: no LLM post-processing
+            processed_sentences = [
+                {
+                    "original_sentence": sentence,
+                    "corrected_sentence": sentence,
+                    "importance_score": 5.0
+                }
+                for sentence in unique_sentences
+            ]
+
+        return processed_sentences
+
+    def save_sentences(self, processed_sentences: List[Dict[str, Any]]):
+        """Save processed sentences to the output files."""
+        # Make sure the output directory exists
+        os.makedirs(os.path.dirname(self.output_file) if os.path.dirname(self.output_file) else '.', exist_ok=True)
+
+        # Save as JSON with full information
+        json_output_file = self.output_file.replace('.txt', '.json')
+        with open(json_output_file, 'w', encoding='utf-8') as f:
+            json.dump(processed_sentences, f, ensure_ascii=False, indent=2)
+
+        # Save as plain text (corrected sentences only)
+        with open(self.output_file, 'w', encoding='utf-8') as f:
+            for item in processed_sentences:
+                f.write(item['corrected_sentence'] + '\n')
+
+        # Write a file sorted by importance
+        importance_sorted = sorted(processed_sentences, key=lambda x: x['importance_score'], reverse=True)
+        importance_file = self.output_file.replace('.txt', '_sorted_by_importance.txt')
+        with open(importance_file, 'w', encoding='utf-8') as f:
+            for item in importance_sorted:
+                f.write(f"[{item['importance_score']:.1f}] {item['corrected_sentence']}\n")
+
+        print(f"保存了 {len(processed_sentences)} 个处理后的句子：")
+        print(f"  - JSON格式: {json_output_file}")
+        print(f"  - 文本格式: {self.output_file}")
+        print(f"  - 重要性排序: {importance_file}")
+
+        # Summary statistics (guard against an empty result set)
+        scores = [item['importance_score'] for item in processed_sentences]
+        if scores:
+            avg_score = sum(scores) / len(scores)
+            print(f"  - 平均重要性评分: {avg_score:.2f}")
+            print(f"  - 最高评分: {max(scores):.1f}")
+            print(f"  - 最低评分: {min(scores):.1f}")
+
+    async def run(self):
+        """Run the full processing pipeline."""
+        print("Starting enhanced TREx to sentences conversion...")
+        processed_sentences = await self.process_files()
+        self.save_sentences(processed_sentences)
+        print("Enhanced conversion completed!")
+
+    def find_latest_checkpoint(self) -> Union[tuple, None]:
+        """Find the most recent checkpoint file."""
+        base_name = os.path.splitext(self.output_file)[0]
+        # Checkpoints are written next to the output file (see save_checkpoint), so search there
+        pattern = f"{base_name}_checkpoint_*.json"
+        checkpoint_files = glob.glob(pattern)
+
+        if not checkpoint_files:
+            return None
+
+        # Pick the checkpoint with the highest number
+        latest_file = None
+        latest_count = 0
+
+        for file in checkpoint_files:
+            try:
+                # Extract the number from the filename
+                match = re.search(r'checkpoint_(\d+)\.json$', file)
+                if match:
+                    count = int(match.group(1))
+                    if count > latest_count:
+                        latest_count = count
+                        latest_file = file
+            except Exception:
+                continue
+
+        if latest_file:
+            return latest_file, latest_count
+        else:
+            return None
+
+    def load_checkpoint(self, checkpoint_file: str) -> List[Dict[str, Any]]:
+        """Load already-processed sentences from a checkpoint file."""
+        try:
+            with open(checkpoint_file, 'r', encoding='utf-8') as f:
+                data = json.load(f)
+
+            if 'sentences' in data:
+                return data['sentences']
+            else:
+                # Old-format checkpoint file
+                return data
+        except Exception as e:
+            print(f"加载检查点文件失败: {e}")
+            return []
+
+
+def main():
+    """Entry point."""
+    import argparse
+
+    parser = argparse.ArgumentParser(description='Convert TREx dataset to enhanced sentences with LLM processing')
+    parser.add_argument('--input_dir', default='dataset/TREx', help='Input directory containing TREx JSON files')
+    parser.add_argument('--output_file', default='trex_sentences_enhanced.txt', help='Output file path')
+    parser.add_argument('--max_files', type=int, help='Maximum number of files to process (for testing)')
+    parser.add_argument('--no_llm', action='store_true', help='Disable LLM processing (basic mode)')
+    parser.add_argument('--resume', action='store_true', help='Resume from latest checkpoint if available')
+
+    args = parser.parse_args()
+
+    if not os.path.exists(args.input_dir):
+        print(f"Error: Input directory {args.input_dir} does not exist!")
+        return
+
+    processor = EnhancedTRExProcessor(
+        args.input_dir,
+        args.output_file,
+        args.max_files,
+        enable_llm_processing=not args.no_llm
+    )
+
+    # Check whether to resume from a checkpoint
+    if args.resume:
+        checkpoint_result = processor.find_latest_checkpoint()
+        if checkpoint_result:
+            latest_checkpoint, latest_count = checkpoint_result
+            print(f"发现检查点文件: {latest_checkpoint} (包含 {latest_count} 条记录)")
+            confirm = input("是否从检查点恢复？(y/n): ").lower().strip()
+            if confirm == 'y':
+                processed_sentences = processor.load_checkpoint(latest_checkpoint)
+                if processed_sentences:
+                    print(f"成功加载 {len(processed_sentences)} 条已处理的句子")
+                    processor.save_sentences(processed_sentences)
+                    print("从检查点恢复完成！")
+                    return
+                else:
+                    print("检查点文件加载失败，将重新开始处理")
+            else:
+                print("不从检查点恢复，将重新开始处理")
+        else:
+            print("未找到检查点文件，将重新开始处理")
+
+    # Run the async pipeline
+    asyncio.run(processor.run())
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
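
For a quick smoke test, a minimal programmatic invocation might look like the sketch below; it simply mirrors what `main()` does without argparse. The paths and `max_files` value are placeholder assumptions, and it presumes the `preprocessing` directory is importable from the project root and that `agno` and `pydantic` are installed (they are imported at module load); LLM mode additionally needs a local ollama `qwen3:4b`.

```python
# Minimal usage sketch (placeholder paths; not part of the committed script).
import asyncio

from preprocessing.trex_to_sentences_simple import EnhancedTRExProcessor

processor = EnhancedTRExProcessor(
    input_dir="dataset/TREx",                  # directory containing re-nlg_*.json files
    output_file="trex_sentences_enhanced.txt",
    max_files=2,                               # small sample for testing
    enable_llm_processing=False,               # basic mode: skip qwen3:4b post-processing
)
asyncio.run(processor.run())  # writes the .txt, .json, and *_sorted_by_importance.txt outputs
```

The equivalent CLI call would be `python preprocessing/trex_to_sentences_simple.py --input_dir dataset/TREx --max_files 2 --no_llm`, with `--resume` available to pick up from the latest checkpoint.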