import json import os import argparse from typing import List, Dict, Any, Optional from collections import defaultdict import pickle from pathlib import Path class WikidataRelationManager: """Wikidata关系管理器,支持动态获取和缓存""" def __init__(self, cache_file: str = "wikidata_relations_cache.pkl", mapping_file: str = None): self.cache_file = cache_file self.mapping_file = mapping_file self.relations = {} # 删除了API相关属性 # 初始的基础关系映射 self.base_relations = { # # 基本关系 # 'P31': 'instance of', # 'P279': 'subclass of', # 'P17': 'country', # 'P159': 'headquarters location', # 'P571': 'inception', # # 人物关系 # 'P19': 'place of birth', # 'P20': 'place of death', # 'P27': 'country of citizenship', # 'P106': 'occupation', # 'P22': 'father', # 'P25': 'mother', # 'P26': 'spouse', # 'P40': 'child', # 'P69': 'educated at', # 'P108': 'employer', # # 地理关系 # 'P36': 'capital', # 'P131': 'located in', # 'P47': 'shares border with', # 'P206': 'located on terrain feature', # 'P1376': 'capital of', # # 组织关系 # 'P112': 'founded by', # 'P127': 'owned by', # 'P169': 'chief executive officer', # 'P488': 'chairperson', # 'P749': 'parent organization', # # 作品关系 # 'P50': 'author', # 'P57': 'director', # 'P58': 'screenwriter', # 'P161': 'cast member', # 'P175': 'performer', # 'P577': 'publication date', # 'P123': 'publisher', # 'P136': 'genre', # # 时间关系 # 'P155': 'follows', # 'P156': 'followed by', # 'P580': 'start time', # 'P582': 'end time', # # 体育关系 # 'P54': 'member of sports team', # 'P413': 'position played on team', # 'P118': 'league', # # 科学关系 # 'P275': 'copyright license', # 'P170': 'creator', # 'P398': 'child astronomical body', # 'P397': 'parent astronomical body', # # 其他常见关系 # 'P37': 'official language', # 'P1923': 'place of marriage', # 'P737': 'influenced by', # 'P463': 'member of', # 'P39': 'position held', # 'P276': 'location', # 'P1441': 'present in work', } self.load_cache() def load_cache(self): """加载缓存的关系映射,优先使用JSON映射文件""" try: # 优先尝试加载JSON映射文件 if self.mapping_file and os.path.exists(self.mapping_file): with open(self.mapping_file, 'r', encoding='utf-8') as f: self.relations = json.load(f) print(f"从JSON映射文件加载了 {len(self.relations)} 个关系映射") return # 尝试加载pickle缓存文件 if os.path.exists(self.cache_file): with open(self.cache_file, 'rb') as f: self.relations = pickle.load(f) print(f"从pickle缓存加载了 {len(self.relations)} 个关系映射") else: self.relations = self.base_relations.copy() print(f"初始化基础关系映射: {len(self.relations)} 个") except Exception as e: print(f"加载缓存失败: {e}") self.relations = self.base_relations.copy() def save_cache(self): """保存关系映射到缓存""" try: with open(self.cache_file, 'wb') as f: pickle.dump(self.relations, f) print(f"已保存 {len(self.relations)} 个关系映射到缓存") except Exception as e: print(f"保存缓存失败: {e}") # 删除了网络抓取功能,改为纯离线模式 def get_relation_name(self, property_id: str) -> Optional[str]: """获取关系名称,仅使用本地映射""" if property_id in self.relations: return self.relations[property_id] # 如果本地映射中没有找到,返回None(表示跳过这个关系) return None # 删除了网络请求相关的批量获取和预加载功能 class TRexProcessor: """T-REx数据集处理器""" def __init__(self, relation_manager: WikidataRelationManager): self.relation_manager = relation_manager def extract_predicate_id(self, uri: str) -> str: """从URI中提取属性ID""" if uri and 'prop/direct/' in uri: return uri.split('/')[-1] elif uri and uri.startswith('P') and uri[1:].isdigit(): return uri return uri if uri else 'unknown' def get_relation_name(self, predicate_uri: str) -> Optional[str]: """获取关系的可读名称""" predicate_id = self.extract_predicate_id(predicate_uri) return self.relation_manager.get_relation_name(predicate_id) # 删除了谓词收集功能,因为不再需要预加载 def is_valid_triple(self, triple: Dict[str, Any], confidence_threshold: float, boundary_threshold: int) -> bool: """检查三元组是否满足过滤条件""" try: # 检查triple是否为字典 if not isinstance(triple, dict): return False # 检查必要字段 if not all(key in triple for key in ['subject', 'predicate', 'object']): return False subject = triple['subject'] predicate = triple['predicate'] object_info = triple['object'] # 检查subject、predicate、object是否都为字典 if not isinstance(subject, dict) or not isinstance(predicate, dict) or not isinstance(object_info, dict): return False # 检查主语和宾语是否有有效的URI和surfaceform if not (subject.get('uri') and subject.get('surfaceform')): return False if not (object_info.get('uri') and object_info.get('surfaceform')): return False if not predicate.get('uri'): return False # 检查置信度(如果存在) confidence = triple.get('confidence') if confidence is not None and confidence < confidence_threshold: return False # 检查边界信息(如果设置了阈值) if boundary_threshold > 0: subject_boundaries = subject.get('boundaries') object_boundaries = object_info.get('boundaries') if not subject_boundaries or not object_boundaries: return False # 检查边界是否为列表且长度至少为2 if not (isinstance(subject_boundaries, list) and len(subject_boundaries) >= 2): return False if not (isinstance(object_boundaries, list) and len(object_boundaries) >= 2): return False try: # 检查边界长度是否合理 subject_length = subject_boundaries[1] - subject_boundaries[0] object_length = object_boundaries[1] - object_boundaries[0] if subject_length < boundary_threshold or object_length < boundary_threshold: return False except (TypeError, IndexError): return False # 检查文本内容是否合理 subject_text = subject.get('surfaceform', '').strip() object_text = object_info.get('surfaceform', '').strip() if not subject_text or not object_text: return False # 过滤掉过长或过短的实体 if len(subject_text) > 100 or len(object_text) > 100: return False if len(subject_text) < 2 or len(object_text) < 2: return False return True except (KeyError, TypeError, AttributeError): return False def process_single_file(self, file_path: str, confidence_threshold: float, boundary_threshold: int) -> List[Dict[str, Any]]: """处理单个JSON文件""" print(f"Processing file: {file_path}") processed_data = [] try: with open(file_path, 'r', encoding='utf-8') as f: # 读取整个文件作为JSON数组 print(f"正在加载JSON数组文件: {file_path}") data_list = json.load(f) print(f"文件包含 {len(data_list)} 个条目") for idx, data in enumerate(data_list): try: # 获取基本信息 text = data.get('text', '').strip() if not text: continue # 处理三元组 triples = data.get('triples', []) if not triples: continue valid_targets = [] for triple in triples: if self.is_valid_triple(triple, confidence_threshold, boundary_threshold): # 获取关系名称,如果无法解析则跳过这个三元组 relation_name = self.get_relation_name(triple['predicate']['uri']) if relation_name is None: continue # 跳过无法解析的关系 target = { 'subject': triple['subject']['surfaceform'].strip(), 'predicate': relation_name, 'object': triple['object']['surfaceform'].strip() } valid_targets.append(target) # 如果有有效的三元组,添加到结果中 if valid_targets: processed_data.append({ 'text': text, 'target': valid_targets }) except Exception as e: if idx <= 10: # 只打印前10个错误 print(f"处理条目时出错 in {file_path} at index {idx}: {e}") continue except FileNotFoundError: print(f"文件未找到: {file_path}") except json.JSONDecodeError as e: print(f"JSON解析错误 in {file_path}: {e}") except Exception as e: print(f"处理文件时出错 {file_path}: {e}") print(f"从 {file_path} 提取了 {len(processed_data)} 个有效样本") return processed_data def process_folder(self, folder_path: str, confidence_threshold: float, boundary_threshold: int) -> List[Dict[str, Any]]: """处理文件夹中的所有JSON文件""" all_processed_data = [] if not os.path.exists(folder_path): raise FileNotFoundError(f"文件夹不存在: {folder_path}") # 获取所有JSON文件 json_files = [f for f in os.listdir(folder_path) if f.endswith('.json')] if not json_files: raise ValueError(f"在 {folder_path} 中没有找到JSON文件") print(f"找到 {len(json_files)} 个JSON文件") for filename in sorted(json_files): file_path = os.path.join(folder_path, filename) processed_data = self.process_single_file(file_path, confidence_threshold, boundary_threshold) all_processed_data.extend(processed_data) # 保存最终的关系缓存 self.relation_manager.save_cache() return all_processed_data def generate_statistics(self, processed_data: List[Dict[str, Any]]) -> Dict[str, Any]: """生成数据统计信息""" total_samples = len(processed_data) total_triples = sum(len(sample['target']) for sample in processed_data) # 统计关系类型 relation_counts = defaultdict(int) for sample in processed_data: for target in sample['target']: relation_counts[target['predicate']] += 1 # 统计文本长度 text_lengths = [len(sample['text']) for sample in processed_data] avg_text_length = sum(text_lengths) / len(text_lengths) if text_lengths else 0 # 统计每个文本的三元组数量 triples_per_text = [len(sample['target']) for sample in processed_data] avg_triples_per_text = sum(triples_per_text) / len(triples_per_text) if triples_per_text else 0 return { 'total_samples': total_samples, 'total_triples': total_triples, 'avg_text_length': round(avg_text_length, 2), 'avg_triples_per_text': round(avg_triples_per_text, 2), 'relation_distribution': dict(sorted(relation_counts.items(), key=lambda x: x[1], reverse=True)), 'top_10_relations': dict(list(sorted(relation_counts.items(), key=lambda x: x[1], reverse=True))[:10]), 'total_unique_relations': len(relation_counts), 'cached_relations': len(self.relation_manager.relations) } def main(): parser = argparse.ArgumentParser(description='处理T-REx数据集(支持动态关系获取)') parser.add_argument('--folder_path', type=str,default='/home/pci/ycz/Code/Minimind/dataset/trex', help='包含JSON文件的文件夹路径') parser.add_argument('--confidence_threshold', type=float, default=0.5, help='置信度阈值 (默认: 0.0)') parser.add_argument('--boundary_threshold', type=int, default=0, help='边界长度阈值 (默认: 0, 不过滤)') parser.add_argument('--output', type=str, default='./processed_trex_data.json', help='输出文件名 (默认: processed_trex_data.json)') parser.add_argument('--stats', type=str, default='trex_statistics.json', help='统计信息输出文件名 (默认: trex_statistics.json)') parser.add_argument('--cache_file', type=str, default='wikidata_relations_cache.pkl', help='关系缓存文件名 (默认: wikidata_relations_cache.pkl)') parser.add_argument('--mapping_file', type=str, default="/home/pci/ycz/Code/Minimind/preprocessing/sample_property_mappings.json", help='JSON映射文件路径 (必须提供,用于关系名称映射)') args = parser.parse_args() print("T-REx数据集处理器(支持动态关系获取)") print("=" * 60) print(f"输入文件夹: {args.folder_path}") print(f"置信度阈值: {args.confidence_threshold}") print(f"边界长度阈值: {args.boundary_threshold}") print(f"输出文件: {args.output}") print(f"关系缓存文件: {args.cache_file}") print(f"JSON映射文件: {args.mapping_file if args.mapping_file else '未指定'}") print("=" * 60) # 检查映射文件是否存在 if not args.mapping_file or not os.path.exists(args.mapping_file): print(f"错误: 映射文件不存在或未指定: {args.mapping_file}") print("请确保提供有效的JSON映射文件。") return 1 # 创建关系管理器 relation_manager = WikidataRelationManager( cache_file=args.cache_file, mapping_file=args.mapping_file ) # 创建处理器 processor = TRexProcessor(relation_manager) try: # 处理数据 processed_data = processor.process_folder( args.folder_path, args.confidence_threshold, args.boundary_threshold ) print(f"\n处理完成!总共处理了 {len(processed_data)} 个样本") # 生成统计信息 stats = processor.generate_statistics(processed_data) # 保存处理后的数据 with open(args.output, 'w', encoding='utf-8') as f: json.dump(processed_data, f, ensure_ascii=False, indent=2) # 保存统计信息 with open(args.stats, 'w', encoding='utf-8') as f: json.dump(stats, f, ensure_ascii=False, indent=2) print(f"\n数据已保存到: {args.output}") print(f"统计信息已保存到: {args.stats}") print(f"关系缓存已保存到: {args.cache_file}") # 打印统计摘要 print("\n数据统计摘要:") print("=" * 30) print(f"总样本数: {stats['total_samples']}") print(f"总三元组数: {stats['total_triples']}") print(f"唯一关系数: {stats['total_unique_relations']}") print(f"缓存关系数: {stats['cached_relations']}") print(f"平均文本长度: {stats['avg_text_length']}") print(f"平均每文本三元组数: {stats['avg_triples_per_text']}") print("\n前10个最常见关系:") for relation, count in stats['top_10_relations'].items(): print(f" {relation}: {count}") except Exception as e: print(f"处理过程中出错: {e}") return 1 return 0 if __name__ == "__main__": exit(main())