# Minimind/preprocessing/preprocess_trex.py

import json
import os
import sys
import argparse
from typing import List, Dict, Any, Optional
from collections import defaultdict
import pickle
from pathlib import Path


class WikidataRelationManager:
    """Wikidata relation manager: caches a local property-ID -> label mapping (offline)."""
    def __init__(self, cache_file: str = "wikidata_relations_cache.pkl",
                 mapping_file: Optional[str] = None):
self.cache_file = cache_file
self.mapping_file = mapping_file
self.relations = {}
        # API-related attributes have been removed (offline-only mode)
        # Initial base relation mapping
self.base_relations = {
            # # Basic relations
# 'P31': 'instance of',
# 'P279': 'subclass of',
# 'P17': 'country',
# 'P159': 'headquarters location',
# 'P571': 'inception',
            # # Person relations
# 'P19': 'place of birth',
# 'P20': 'place of death',
# 'P27': 'country of citizenship',
# 'P106': 'occupation',
# 'P22': 'father',
# 'P25': 'mother',
# 'P26': 'spouse',
# 'P40': 'child',
# 'P69': 'educated at',
# 'P108': 'employer',
            # # Geographic relations
# 'P36': 'capital',
# 'P131': 'located in',
# 'P47': 'shares border with',
# 'P206': 'located on terrain feature',
# 'P1376': 'capital of',
            # # Organization relations
# 'P112': 'founded by',
# 'P127': 'owned by',
# 'P169': 'chief executive officer',
# 'P488': 'chairperson',
# 'P749': 'parent organization',
            # # Creative-work relations
# 'P50': 'author',
# 'P57': 'director',
# 'P58': 'screenwriter',
# 'P161': 'cast member',
# 'P175': 'performer',
# 'P577': 'publication date',
# 'P123': 'publisher',
# 'P136': 'genre',
            # # Temporal relations
# 'P155': 'follows',
# 'P156': 'followed by',
# 'P580': 'start time',
# 'P582': 'end time',
            # # Sports relations
# 'P54': 'member of sports team',
# 'P413': 'position played on team',
# 'P118': 'league',
            # # Science relations
# 'P275': 'copyright license',
# 'P170': 'creator',
# 'P398': 'child astronomical body',
# 'P397': 'parent astronomical body',
            # # Other common relations
# 'P37': 'official language',
# 'P1923': 'place of marriage',
# 'P737': 'influenced by',
# 'P463': 'member of',
# 'P39': 'position held',
# 'P276': 'location',
# 'P1441': 'present in work',
}
self.load_cache()
    def load_cache(self):
        """Load the cached relation mapping; the JSON mapping file takes precedence."""
        try:
            # Try the JSON mapping file first
            if self.mapping_file and os.path.exists(self.mapping_file):
                with open(self.mapping_file, 'r', encoding='utf-8') as f:
                    self.relations = json.load(f)
                print(f"Loaded {len(self.relations)} relation mappings from the JSON mapping file")
                return
            # Fall back to the pickle cache file
            if os.path.exists(self.cache_file):
                with open(self.cache_file, 'rb') as f:
                    self.relations = pickle.load(f)
                print(f"Loaded {len(self.relations)} relation mappings from the pickle cache")
            else:
                self.relations = self.base_relations.copy()
                print(f"Initialized base relation mapping with {len(self.relations)} entries")
        except Exception as e:
            print(f"Failed to load cache: {e}")
            self.relations = self.base_relations.copy()
    def save_cache(self):
        """Save the relation mapping to the pickle cache."""
        try:
            with open(self.cache_file, 'wb') as f:
                pickle.dump(self.relations, f)
            print(f"Saved {len(self.relations)} relation mappings to the cache")
        except Exception as e:
            print(f"Failed to save cache: {e}")
    # Web scraping has been removed; the manager runs in purely offline mode.
    def get_relation_name(self, property_id: str) -> Optional[str]:
        """Return the relation name for a property ID, using the local mapping only."""
        if property_id in self.relations:
            return self.relations[property_id]
        # Not found in the local mapping: return None so the caller skips this relation
        return None
    # Batch fetching and preloading over the network have also been removed.
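
# A minimal sketch of the mapping-file layout, inferred from how load_cache()
# consumes it: a flat JSON object of property-ID -> label pairs. The property
# IDs below are real Wikidata properties, but the exact file contents are an
# illustrative assumption.
#
#   sample_property_mappings.json:
#   {
#       "P31": "instance of",
#       "P279": "subclass of",
#       "P17": "country"
#   }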


class TRexProcessor:
    """Processor for the T-REx dataset."""
def __init__(self, relation_manager: WikidataRelationManager):
self.relation_manager = relation_manager
def extract_predicate_id(self, uri: str) -> str:
"""从URI中提取属性ID"""
if uri and 'prop/direct/' in uri:
return uri.split('/')[-1]
elif uri and uri.startswith('P') and uri[1:].isdigit():
return uri
return uri if uri else 'unknown'
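    # Behaviour sketch (the first URI is the standard Wikidata direct-property
    # form; the inputs are illustrative):
    #   extract_predicate_id('http://www.wikidata.org/prop/direct/P31')  -> 'P31'
    #   extract_predicate_id('P279')                                     -> 'P279'
    #   extract_predicate_id('')                                         -> 'unknown'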
def get_relation_name(self, predicate_uri: str) -> Optional[str]:
"""获取关系的可读名称"""
predicate_id = self.extract_predicate_id(predicate_uri)
return self.relation_manager.get_relation_name(predicate_id)
    # Predicate collection has been removed since preloading is no longer needed.
    def is_valid_triple(self, triple: Dict[str, Any], confidence_threshold: float,
                        boundary_threshold: int) -> bool:
        """Check whether a triple satisfies the filtering criteria."""
try:
            # Ensure the triple is a dict
if not isinstance(triple, dict):
return False
            # Ensure the required fields are present
if not all(key in triple for key in ['subject', 'predicate', 'object']):
return False
subject = triple['subject']
predicate = triple['predicate']
object_info = triple['object']
            # Ensure subject, predicate and object are all dicts
if not isinstance(subject, dict) or not isinstance(predicate, dict) or not isinstance(object_info, dict):
return False
            # Ensure subject and object each have a valid URI and surface form
if not (subject.get('uri') and subject.get('surfaceform')):
return False
if not (object_info.get('uri') and object_info.get('surfaceform')):
return False
if not predicate.get('uri'):
return False
            # Check the confidence score (if present)
confidence = triple.get('confidence')
if confidence is not None and confidence < confidence_threshold:
return False
            # Check boundary information (if a threshold is set)
if boundary_threshold > 0:
subject_boundaries = subject.get('boundaries')
object_boundaries = object_info.get('boundaries')
if not subject_boundaries or not object_boundaries:
return False
                # Boundaries must be lists with at least two elements
if not (isinstance(subject_boundaries, list) and len(subject_boundaries) >= 2):
return False
if not (isinstance(object_boundaries, list) and len(object_boundaries) >= 2):
return False
try:
                    # Check that the boundary spans are long enough
subject_length = subject_boundaries[1] - subject_boundaries[0]
object_length = object_boundaries[1] - object_boundaries[0]
if subject_length < boundary_threshold or object_length < boundary_threshold:
return False
except (TypeError, IndexError):
return False
            # Check that the surface forms are sensible
subject_text = subject.get('surfaceform', '').strip()
object_text = object_info.get('surfaceform', '').strip()
if not subject_text or not object_text:
return False
            # Filter out entities that are too long or too short
if len(subject_text) > 100 or len(object_text) > 100:
return False
if len(subject_text) < 2 or len(object_text) < 2:
return False
return True
except (KeyError, TypeError, AttributeError):
return False
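    # A minimal sketch of a triple that passes the checks above (values are
    # illustrative; boundaries are [start, end] character offsets into the text):
    #   {
    #       "subject":   {"uri": ".../Q90",  "surfaceform": "Paris",  "boundaries": [0, 5]},
    #       "predicate": {"uri": ".../prop/direct/P1376"},
    #       "object":    {"uri": ".../Q142", "surfaceform": "France", "boundaries": [24, 30]},
    #       "confidence": 0.9
    #   }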
    def process_single_file(self, file_path: str, confidence_threshold: float,
                            boundary_threshold: int) -> List[Dict[str, Any]]:
        """Process a single JSON file."""
print(f"Processing file: {file_path}")
processed_data = []
try:
            with open(file_path, 'r', encoding='utf-8') as f:
                # Read the whole file as a single JSON array
                print(f"Loading JSON array file: {file_path}")
                data_list = json.load(f)
                print(f"File contains {len(data_list)} entries")
for idx, data in enumerate(data_list):
try:
                        # Basic fields
text = data.get('text', '').strip()
if not text:
continue
                        # Process the triples
triples = data.get('triples', [])
if not triples:
continue
valid_targets = []
for triple in triples:
if self.is_valid_triple(triple, confidence_threshold, boundary_threshold):
                                # Resolve the relation name; skip the triple if it cannot be resolved
relation_name = self.get_relation_name(triple['predicate']['uri'])
if relation_name is None:
                                    continue  # skip relations that cannot be resolved
target = {
'subject': triple['subject']['surfaceform'].strip(),
'predicate': relation_name,
'object': triple['object']['surfaceform'].strip()
}
valid_targets.append(target)
                        # Keep the sample if it has at least one valid triple
if valid_targets:
processed_data.append({
'text': text,
'target': valid_targets
})
                    except Exception as e:
                        if idx <= 10:  # only report errors for the first few entries
                            print(f"Error processing entry in {file_path} at index {idx}: {e}")
                        continue
        except FileNotFoundError:
            print(f"File not found: {file_path}")
        except json.JSONDecodeError as e:
            print(f"JSON parse error in {file_path}: {e}")
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
        print(f"Extracted {len(processed_data)} valid samples from {file_path}")
return processed_data
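    # Shape of each processed sample returned above (values are illustrative):
    #   {"text": "Paris is the capital of France.",
    #    "target": [{"subject": "Paris", "predicate": "capital of", "object": "France"}]}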
    def process_folder(self, folder_path: str, confidence_threshold: float,
                       boundary_threshold: int) -> List[Dict[str, Any]]:
        """Process all JSON files in a folder."""
all_processed_data = []
        if not os.path.exists(folder_path):
            raise FileNotFoundError(f"Folder does not exist: {folder_path}")
        # Collect all JSON files
        json_files = [f for f in os.listdir(folder_path) if f.endswith('.json')]
        if not json_files:
            raise ValueError(f"No JSON files found in {folder_path}")
        print(f"Found {len(json_files)} JSON files")
for filename in sorted(json_files):
file_path = os.path.join(folder_path, filename)
processed_data = self.process_single_file(file_path, confidence_threshold, boundary_threshold)
all_processed_data.extend(processed_data)
        # Save the final relation cache
self.relation_manager.save_cache()
return all_processed_data
    def generate_statistics(self, processed_data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Generate statistics over the processed data."""
total_samples = len(processed_data)
total_triples = sum(len(sample['target']) for sample in processed_data)
        # Count relation types
relation_counts = defaultdict(int)
for sample in processed_data:
for target in sample['target']:
relation_counts[target['predicate']] += 1
        # Text-length statistics
text_lengths = [len(sample['text']) for sample in processed_data]
avg_text_length = sum(text_lengths) / len(text_lengths) if text_lengths else 0
        # Triples-per-text statistics
triples_per_text = [len(sample['target']) for sample in processed_data]
avg_triples_per_text = sum(triples_per_text) / len(triples_per_text) if triples_per_text else 0
return {
'total_samples': total_samples,
'total_triples': total_triples,
'avg_text_length': round(avg_text_length, 2),
'avg_triples_per_text': round(avg_triples_per_text, 2),
'relation_distribution': dict(sorted(relation_counts.items(),
key=lambda x: x[1], reverse=True)),
'top_10_relations': dict(list(sorted(relation_counts.items(),
key=lambda x: x[1], reverse=True))[:10]),
'total_unique_relations': len(relation_counts),
'cached_relations': len(self.relation_manager.relations)
}
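    # Illustrative slice of the returned statistics (the numbers are invented):
    #   {"total_samples": 1200, "total_triples": 3400, "avg_text_length": 512.3,
    #    "avg_triples_per_text": 2.8, "total_unique_relations": 87, ...}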


def main():
    parser = argparse.ArgumentParser(description='Process the T-REx dataset using a local relation mapping')
    parser.add_argument('--folder_path', type=str, default='/home/pci/ycz/Code/Minimind/dataset/trex',
                        help='Path to the folder containing the JSON files')
    parser.add_argument('--confidence_threshold', type=float, default=0.5,
                        help='Confidence threshold (default: 0.5)')
    parser.add_argument('--boundary_threshold', type=int, default=0,
                        help='Boundary-length threshold (default: 0, no filtering)')
    parser.add_argument('--output', type=str, default='./processed_trex_data.json',
                        help='Output file name (default: ./processed_trex_data.json)')
    parser.add_argument('--stats', type=str, default='trex_statistics.json',
                        help='Statistics output file name (default: trex_statistics.json)')
    parser.add_argument('--cache_file', type=str, default='wikidata_relations_cache.pkl',
                        help='Relation cache file name (default: wikidata_relations_cache.pkl)')
    parser.add_argument('--mapping_file', type=str, default="/home/pci/ycz/Code/Minimind/preprocessing/sample_property_mappings.json",
                        help='Path to the JSON mapping file (required for relation-name mapping)')
    args = parser.parse_args()
print("T-REx数据集处理器支持动态关系获取")
print("=" * 60)
print(f"输入文件夹: {args.folder_path}")
print(f"置信度阈值: {args.confidence_threshold}")
print(f"边界长度阈值: {args.boundary_threshold}")
print(f"输出文件: {args.output}")
print(f"关系缓存文件: {args.cache_file}")
print(f"JSON映射文件: {args.mapping_file if args.mapping_file else '未指定'}")
print("=" * 60)
    # Ensure the mapping file exists
    if not args.mapping_file or not os.path.exists(args.mapping_file):
        print(f"Error: mapping file not found or not specified: {args.mapping_file}")
        print("Please provide a valid JSON mapping file.")
        return 1
    # Create the relation manager
    relation_manager = WikidataRelationManager(
        cache_file=args.cache_file,
        mapping_file=args.mapping_file
    )
    # Create the processor
    processor = TRexProcessor(relation_manager)
    try:
        # Process the data
processed_data = processor.process_folder(
args.folder_path,
args.confidence_threshold,
args.boundary_threshold
)
print(f"\n处理完成!总共处理了 {len(processed_data)} 个样本")
# 生成统计信息
stats = processor.generate_statistics(processed_data)
# 保存处理后的数据
with open(args.output, 'w', encoding='utf-8') as f:
json.dump(processed_data, f, ensure_ascii=False, indent=2)
# 保存统计信息
with open(args.stats, 'w', encoding='utf-8') as f:
json.dump(stats, f, ensure_ascii=False, indent=2)
print(f"\n数据已保存到: {args.output}")
print(f"统计信息已保存到: {args.stats}")
print(f"关系缓存已保存到: {args.cache_file}")
# 打印统计摘要
print("\n数据统计摘要:")
print("=" * 30)
print(f"总样本数: {stats['total_samples']}")
print(f"总三元组数: {stats['total_triples']}")
print(f"唯一关系数: {stats['total_unique_relations']}")
print(f"缓存关系数: {stats['cached_relations']}")
print(f"平均文本长度: {stats['avg_text_length']}")
print(f"平均每文本三元组数: {stats['avg_triples_per_text']}")
print("\n前10个最常见关系:")
for relation, count in stats['top_10_relations'].items():
print(f" {relation}: {count}")
except Exception as e:
print(f"处理过程中出错: {e}")
return 1
return 0


if __name__ == "__main__":
    sys.exit(main())
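
# Example invocation (paths are illustrative; substitute your own dataset and
# mapping-file locations):
#   python preprocess_trex.py \
#       --folder_path ./dataset/trex \
#       --mapping_file ./preprocessing/sample_property_mappings.json \
#       --confidence_threshold 0.5 \
#       --output ./processed_trex_data.json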