443 lines
18 KiB
Python
443 lines
18 KiB
Python
|
import json
|
|||
|
import os
|
|||
|
import argparse
|
|||
|
from typing import List, Dict, Any, Optional
|
|||
|
from collections import defaultdict
|
|||
|
import pickle
|
|||
|
from pathlib import Path
|
|||
|
|
|||
|
class WikidataRelationManager:
|
|||
|
"""Wikidata关系管理器,支持动态获取和缓存"""
|
|||
|
|
|||
|
def __init__(self, cache_file: str = "wikidata_relations_cache.pkl",
|
|||
|
mapping_file: str = None):
|
|||
|
self.cache_file = cache_file
|
|||
|
self.mapping_file = mapping_file
|
|||
|
self.relations = {}
|
|||
|
# 删除了API相关属性
|
|||
|
|
|||
|
# 初始的基础关系映射
|
|||
|
self.base_relations = {
|
|||
|
# # 基本关系
|
|||
|
# 'P31': 'instance of',
|
|||
|
# 'P279': 'subclass of',
|
|||
|
# 'P17': 'country',
|
|||
|
# 'P159': 'headquarters location',
|
|||
|
# 'P571': 'inception',
|
|||
|
|
|||
|
# # 人物关系
|
|||
|
# 'P19': 'place of birth',
|
|||
|
# 'P20': 'place of death',
|
|||
|
# 'P27': 'country of citizenship',
|
|||
|
# 'P106': 'occupation',
|
|||
|
# 'P22': 'father',
|
|||
|
# 'P25': 'mother',
|
|||
|
# 'P26': 'spouse',
|
|||
|
# 'P40': 'child',
|
|||
|
# 'P69': 'educated at',
|
|||
|
# 'P108': 'employer',
|
|||
|
|
|||
|
# # 地理关系
|
|||
|
# 'P36': 'capital',
|
|||
|
# 'P131': 'located in',
|
|||
|
# 'P47': 'shares border with',
|
|||
|
# 'P206': 'located on terrain feature',
|
|||
|
# 'P1376': 'capital of',
|
|||
|
|
|||
|
# # 组织关系
|
|||
|
# 'P112': 'founded by',
|
|||
|
# 'P127': 'owned by',
|
|||
|
# 'P169': 'chief executive officer',
|
|||
|
# 'P488': 'chairperson',
|
|||
|
# 'P749': 'parent organization',
|
|||
|
|
|||
|
# # 作品关系
|
|||
|
# 'P50': 'author',
|
|||
|
# 'P57': 'director',
|
|||
|
# 'P58': 'screenwriter',
|
|||
|
# 'P161': 'cast member',
|
|||
|
# 'P175': 'performer',
|
|||
|
# 'P577': 'publication date',
|
|||
|
# 'P123': 'publisher',
|
|||
|
# 'P136': 'genre',
|
|||
|
|
|||
|
# # 时间关系
|
|||
|
# 'P155': 'follows',
|
|||
|
# 'P156': 'followed by',
|
|||
|
# 'P580': 'start time',
|
|||
|
# 'P582': 'end time',
|
|||
|
|
|||
|
# # 体育关系
|
|||
|
# 'P54': 'member of sports team',
|
|||
|
# 'P413': 'position played on team',
|
|||
|
# 'P118': 'league',
|
|||
|
|
|||
|
# # 科学关系
|
|||
|
# 'P275': 'copyright license',
|
|||
|
# 'P170': 'creator',
|
|||
|
# 'P398': 'child astronomical body',
|
|||
|
# 'P397': 'parent astronomical body',
|
|||
|
|
|||
|
# # 其他常见关系
|
|||
|
# 'P37': 'official language',
|
|||
|
# 'P1923': 'place of marriage',
|
|||
|
# 'P737': 'influenced by',
|
|||
|
# 'P463': 'member of',
|
|||
|
# 'P39': 'position held',
|
|||
|
# 'P276': 'location',
|
|||
|
# 'P1441': 'present in work',
|
|||
|
}
|
|||
|
|
|||
|
self.load_cache()
|
|||
|
|
|||
|
def load_cache(self):
|
|||
|
"""加载缓存的关系映射,优先使用JSON映射文件"""
|
|||
|
try:
|
|||
|
# 优先尝试加载JSON映射文件
|
|||
|
if self.mapping_file and os.path.exists(self.mapping_file):
|
|||
|
with open(self.mapping_file, 'r', encoding='utf-8') as f:
|
|||
|
self.relations = json.load(f)
|
|||
|
print(f"从JSON映射文件加载了 {len(self.relations)} 个关系映射")
|
|||
|
return
|
|||
|
|
|||
|
# 尝试加载pickle缓存文件
|
|||
|
if os.path.exists(self.cache_file):
|
|||
|
with open(self.cache_file, 'rb') as f:
|
|||
|
self.relations = pickle.load(f)
|
|||
|
print(f"从pickle缓存加载了 {len(self.relations)} 个关系映射")
|
|||
|
else:
|
|||
|
self.relations = self.base_relations.copy()
|
|||
|
print(f"初始化基础关系映射: {len(self.relations)} 个")
|
|||
|
except Exception as e:
|
|||
|
print(f"加载缓存失败: {e}")
|
|||
|
self.relations = self.base_relations.copy()
|
|||
|
|
|||
|
def save_cache(self):
|
|||
|
"""保存关系映射到缓存"""
|
|||
|
try:
|
|||
|
with open(self.cache_file, 'wb') as f:
|
|||
|
pickle.dump(self.relations, f)
|
|||
|
print(f"已保存 {len(self.relations)} 个关系映射到缓存")
|
|||
|
except Exception as e:
|
|||
|
print(f"保存缓存失败: {e}")
|
|||
|
|
|||
|
# 删除了网络抓取功能,改为纯离线模式
|
|||
|
|
|||
|
def get_relation_name(self, property_id: str) -> Optional[str]:
|
|||
|
"""获取关系名称,仅使用本地映射"""
|
|||
|
if property_id in self.relations:
|
|||
|
return self.relations[property_id]
|
|||
|
|
|||
|
# 如果本地映射中没有找到,返回None(表示跳过这个关系)
|
|||
|
return None
|
|||
|
|
|||
|
# 删除了网络请求相关的批量获取和预加载功能
|
|||
|
|
|||
|
class TRexProcessor:
|
|||
|
"""T-REx数据集处理器"""
|
|||
|
|
|||
|
def __init__(self, relation_manager: WikidataRelationManager):
|
|||
|
self.relation_manager = relation_manager
|
|||
|
|
|||
|
def extract_predicate_id(self, uri: str) -> str:
|
|||
|
"""从URI中提取属性ID"""
|
|||
|
if uri and 'prop/direct/' in uri:
|
|||
|
return uri.split('/')[-1]
|
|||
|
elif uri and uri.startswith('P') and uri[1:].isdigit():
|
|||
|
return uri
|
|||
|
return uri if uri else 'unknown'
|
|||
|
|
|||
|
def get_relation_name(self, predicate_uri: str) -> Optional[str]:
|
|||
|
"""获取关系的可读名称"""
|
|||
|
predicate_id = self.extract_predicate_id(predicate_uri)
|
|||
|
return self.relation_manager.get_relation_name(predicate_id)
|
|||
|
|
|||
|
# 删除了谓词收集功能,因为不再需要预加载
|
|||
|
|
|||
|
def is_valid_triple(self, triple: Dict[str, Any], confidence_threshold: float,
|
|||
|
boundary_threshold: int) -> bool:
|
|||
|
"""检查三元组是否满足过滤条件"""
|
|||
|
try:
|
|||
|
# 检查triple是否为字典
|
|||
|
if not isinstance(triple, dict):
|
|||
|
return False
|
|||
|
|
|||
|
# 检查必要字段
|
|||
|
if not all(key in triple for key in ['subject', 'predicate', 'object']):
|
|||
|
return False
|
|||
|
|
|||
|
subject = triple['subject']
|
|||
|
predicate = triple['predicate']
|
|||
|
object_info = triple['object']
|
|||
|
|
|||
|
# 检查subject、predicate、object是否都为字典
|
|||
|
if not isinstance(subject, dict) or not isinstance(predicate, dict) or not isinstance(object_info, dict):
|
|||
|
return False
|
|||
|
|
|||
|
# 检查主语和宾语是否有有效的URI和surfaceform
|
|||
|
if not (subject.get('uri') and subject.get('surfaceform')):
|
|||
|
return False
|
|||
|
if not (object_info.get('uri') and object_info.get('surfaceform')):
|
|||
|
return False
|
|||
|
if not predicate.get('uri'):
|
|||
|
return False
|
|||
|
|
|||
|
# 检查置信度(如果存在)
|
|||
|
confidence = triple.get('confidence')
|
|||
|
if confidence is not None and confidence < confidence_threshold:
|
|||
|
return False
|
|||
|
|
|||
|
# 检查边界信息(如果设置了阈值)
|
|||
|
if boundary_threshold > 0:
|
|||
|
subject_boundaries = subject.get('boundaries')
|
|||
|
object_boundaries = object_info.get('boundaries')
|
|||
|
|
|||
|
if not subject_boundaries or not object_boundaries:
|
|||
|
return False
|
|||
|
|
|||
|
# 检查边界是否为列表且长度至少为2
|
|||
|
if not (isinstance(subject_boundaries, list) and len(subject_boundaries) >= 2):
|
|||
|
return False
|
|||
|
if not (isinstance(object_boundaries, list) and len(object_boundaries) >= 2):
|
|||
|
return False
|
|||
|
|
|||
|
try:
|
|||
|
# 检查边界长度是否合理
|
|||
|
subject_length = subject_boundaries[1] - subject_boundaries[0]
|
|||
|
object_length = object_boundaries[1] - object_boundaries[0]
|
|||
|
|
|||
|
if subject_length < boundary_threshold or object_length < boundary_threshold:
|
|||
|
return False
|
|||
|
except (TypeError, IndexError):
|
|||
|
return False
|
|||
|
|
|||
|
# 检查文本内容是否合理
|
|||
|
subject_text = subject.get('surfaceform', '').strip()
|
|||
|
object_text = object_info.get('surfaceform', '').strip()
|
|||
|
|
|||
|
if not subject_text or not object_text:
|
|||
|
return False
|
|||
|
|
|||
|
# 过滤掉过长或过短的实体
|
|||
|
if len(subject_text) > 100 or len(object_text) > 100:
|
|||
|
return False
|
|||
|
if len(subject_text) < 2 or len(object_text) < 2:
|
|||
|
return False
|
|||
|
|
|||
|
return True
|
|||
|
|
|||
|
except (KeyError, TypeError, AttributeError):
|
|||
|
return False
|
|||
|
|
|||
|
def process_single_file(self, file_path: str, confidence_threshold: float,
|
|||
|
boundary_threshold: int) -> List[Dict[str, Any]]:
|
|||
|
"""处理单个JSON文件"""
|
|||
|
print(f"Processing file: {file_path}")
|
|||
|
|
|||
|
processed_data = []
|
|||
|
|
|||
|
try:
|
|||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
|||
|
# 读取整个文件作为JSON数组
|
|||
|
print(f"正在加载JSON数组文件: {file_path}")
|
|||
|
data_list = json.load(f)
|
|||
|
print(f"文件包含 {len(data_list)} 个条目")
|
|||
|
|
|||
|
for idx, data in enumerate(data_list):
|
|||
|
try:
|
|||
|
# 获取基本信息
|
|||
|
text = data.get('text', '').strip()
|
|||
|
if not text:
|
|||
|
continue
|
|||
|
|
|||
|
# 处理三元组
|
|||
|
triples = data.get('triples', [])
|
|||
|
if not triples:
|
|||
|
continue
|
|||
|
|
|||
|
valid_targets = []
|
|||
|
|
|||
|
for triple in triples:
|
|||
|
if self.is_valid_triple(triple, confidence_threshold, boundary_threshold):
|
|||
|
# 获取关系名称,如果无法解析则跳过这个三元组
|
|||
|
relation_name = self.get_relation_name(triple['predicate']['uri'])
|
|||
|
if relation_name is None:
|
|||
|
continue # 跳过无法解析的关系
|
|||
|
|
|||
|
target = {
|
|||
|
'subject': triple['subject']['surfaceform'].strip(),
|
|||
|
'predicate': relation_name,
|
|||
|
'object': triple['object']['surfaceform'].strip()
|
|||
|
}
|
|||
|
valid_targets.append(target)
|
|||
|
|
|||
|
# 如果有有效的三元组,添加到结果中
|
|||
|
if valid_targets:
|
|||
|
processed_data.append({
|
|||
|
'text': text,
|
|||
|
'target': valid_targets
|
|||
|
})
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
if idx <= 10: # 只打印前10个错误
|
|||
|
print(f"处理条目时出错 in {file_path} at index {idx}: {e}")
|
|||
|
continue
|
|||
|
|
|||
|
except FileNotFoundError:
|
|||
|
print(f"文件未找到: {file_path}")
|
|||
|
except json.JSONDecodeError as e:
|
|||
|
print(f"JSON解析错误 in {file_path}: {e}")
|
|||
|
except Exception as e:
|
|||
|
print(f"处理文件时出错 {file_path}: {e}")
|
|||
|
|
|||
|
print(f"从 {file_path} 提取了 {len(processed_data)} 个有效样本")
|
|||
|
return processed_data
|
|||
|
|
|||
|
def process_folder(self, folder_path: str, confidence_threshold: float,
|
|||
|
boundary_threshold: int) -> List[Dict[str, Any]]:
|
|||
|
"""处理文件夹中的所有JSON文件"""
|
|||
|
all_processed_data = []
|
|||
|
|
|||
|
if not os.path.exists(folder_path):
|
|||
|
raise FileNotFoundError(f"文件夹不存在: {folder_path}")
|
|||
|
|
|||
|
# 获取所有JSON文件
|
|||
|
json_files = [f for f in os.listdir(folder_path) if f.endswith('.json')]
|
|||
|
|
|||
|
if not json_files:
|
|||
|
raise ValueError(f"在 {folder_path} 中没有找到JSON文件")
|
|||
|
|
|||
|
print(f"找到 {len(json_files)} 个JSON文件")
|
|||
|
|
|||
|
for filename in sorted(json_files):
|
|||
|
file_path = os.path.join(folder_path, filename)
|
|||
|
processed_data = self.process_single_file(file_path, confidence_threshold, boundary_threshold)
|
|||
|
all_processed_data.extend(processed_data)
|
|||
|
|
|||
|
# 保存最终的关系缓存
|
|||
|
self.relation_manager.save_cache()
|
|||
|
|
|||
|
return all_processed_data
|
|||
|
|
|||
|
def generate_statistics(self, processed_data: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|||
|
"""生成数据统计信息"""
|
|||
|
total_samples = len(processed_data)
|
|||
|
total_triples = sum(len(sample['target']) for sample in processed_data)
|
|||
|
|
|||
|
# 统计关系类型
|
|||
|
relation_counts = defaultdict(int)
|
|||
|
for sample in processed_data:
|
|||
|
for target in sample['target']:
|
|||
|
relation_counts[target['predicate']] += 1
|
|||
|
|
|||
|
# 统计文本长度
|
|||
|
text_lengths = [len(sample['text']) for sample in processed_data]
|
|||
|
avg_text_length = sum(text_lengths) / len(text_lengths) if text_lengths else 0
|
|||
|
|
|||
|
# 统计每个文本的三元组数量
|
|||
|
triples_per_text = [len(sample['target']) for sample in processed_data]
|
|||
|
avg_triples_per_text = sum(triples_per_text) / len(triples_per_text) if triples_per_text else 0
|
|||
|
|
|||
|
return {
|
|||
|
'total_samples': total_samples,
|
|||
|
'total_triples': total_triples,
|
|||
|
'avg_text_length': round(avg_text_length, 2),
|
|||
|
'avg_triples_per_text': round(avg_triples_per_text, 2),
|
|||
|
'relation_distribution': dict(sorted(relation_counts.items(),
|
|||
|
key=lambda x: x[1], reverse=True)),
|
|||
|
'top_10_relations': dict(list(sorted(relation_counts.items(),
|
|||
|
key=lambda x: x[1], reverse=True))[:10]),
|
|||
|
'total_unique_relations': len(relation_counts),
|
|||
|
'cached_relations': len(self.relation_manager.relations)
|
|||
|
}
|
|||
|
|
|||
|
def main():
|
|||
|
parser = argparse.ArgumentParser(description='处理T-REx数据集(支持动态关系获取)')
|
|||
|
parser.add_argument('--folder_path', type=str,default='/home/pci/ycz/Code/Minimind/dataset/trex', help='包含JSON文件的文件夹路径')
|
|||
|
parser.add_argument('--confidence_threshold', type=float, default=0.5,
|
|||
|
help='置信度阈值 (默认: 0.0)')
|
|||
|
parser.add_argument('--boundary_threshold', type=int, default=0,
|
|||
|
help='边界长度阈值 (默认: 0, 不过滤)')
|
|||
|
parser.add_argument('--output', type=str, default='./processed_trex_data.json',
|
|||
|
help='输出文件名 (默认: processed_trex_data.json)')
|
|||
|
parser.add_argument('--stats', type=str, default='trex_statistics.json',
|
|||
|
help='统计信息输出文件名 (默认: trex_statistics.json)')
|
|||
|
parser.add_argument('--cache_file', type=str, default='wikidata_relations_cache.pkl',
|
|||
|
help='关系缓存文件名 (默认: wikidata_relations_cache.pkl)')
|
|||
|
parser.add_argument('--mapping_file', type=str, default="/home/pci/ycz/Code/Minimind/preprocessing/sample_property_mappings.json",
|
|||
|
help='JSON映射文件路径 (必须提供,用于关系名称映射)')
|
|||
|
|
|||
|
args = parser.parse_args()
|
|||
|
|
|||
|
print("T-REx数据集处理器(支持动态关系获取)")
|
|||
|
print("=" * 60)
|
|||
|
print(f"输入文件夹: {args.folder_path}")
|
|||
|
print(f"置信度阈值: {args.confidence_threshold}")
|
|||
|
print(f"边界长度阈值: {args.boundary_threshold}")
|
|||
|
print(f"输出文件: {args.output}")
|
|||
|
print(f"关系缓存文件: {args.cache_file}")
|
|||
|
print(f"JSON映射文件: {args.mapping_file if args.mapping_file else '未指定'}")
|
|||
|
print("=" * 60)
|
|||
|
|
|||
|
# 检查映射文件是否存在
|
|||
|
if not args.mapping_file or not os.path.exists(args.mapping_file):
|
|||
|
print(f"错误: 映射文件不存在或未指定: {args.mapping_file}")
|
|||
|
print("请确保提供有效的JSON映射文件。")
|
|||
|
return 1
|
|||
|
|
|||
|
# 创建关系管理器
|
|||
|
relation_manager = WikidataRelationManager(
|
|||
|
cache_file=args.cache_file,
|
|||
|
mapping_file=args.mapping_file
|
|||
|
)
|
|||
|
|
|||
|
# 创建处理器
|
|||
|
processor = TRexProcessor(relation_manager)
|
|||
|
|
|||
|
try:
|
|||
|
# 处理数据
|
|||
|
processed_data = processor.process_folder(
|
|||
|
args.folder_path,
|
|||
|
args.confidence_threshold,
|
|||
|
args.boundary_threshold
|
|||
|
)
|
|||
|
|
|||
|
print(f"\n处理完成!总共处理了 {len(processed_data)} 个样本")
|
|||
|
|
|||
|
# 生成统计信息
|
|||
|
stats = processor.generate_statistics(processed_data)
|
|||
|
|
|||
|
# 保存处理后的数据
|
|||
|
with open(args.output, 'w', encoding='utf-8') as f:
|
|||
|
json.dump(processed_data, f, ensure_ascii=False, indent=2)
|
|||
|
|
|||
|
# 保存统计信息
|
|||
|
with open(args.stats, 'w', encoding='utf-8') as f:
|
|||
|
json.dump(stats, f, ensure_ascii=False, indent=2)
|
|||
|
|
|||
|
print(f"\n数据已保存到: {args.output}")
|
|||
|
print(f"统计信息已保存到: {args.stats}")
|
|||
|
print(f"关系缓存已保存到: {args.cache_file}")
|
|||
|
|
|||
|
# 打印统计摘要
|
|||
|
print("\n数据统计摘要:")
|
|||
|
print("=" * 30)
|
|||
|
print(f"总样本数: {stats['total_samples']}")
|
|||
|
print(f"总三元组数: {stats['total_triples']}")
|
|||
|
print(f"唯一关系数: {stats['total_unique_relations']}")
|
|||
|
print(f"缓存关系数: {stats['cached_relations']}")
|
|||
|
print(f"平均文本长度: {stats['avg_text_length']}")
|
|||
|
print(f"平均每文本三元组数: {stats['avg_triples_per_text']}")
|
|||
|
print("\n前10个最常见关系:")
|
|||
|
for relation, count in stats['top_10_relations'].items():
|
|||
|
print(f" {relation}: {count}")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"处理过程中出错: {e}")
|
|||
|
return 1
|
|||
|
|
|||
|
return 0
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
exit(main())
|