DynamicKV-LLM Extra v1.0.0
parent
d6617702a5
commit
74e9293c9a
137 model/dataset.py
@@ -9,6 +9,7 @@ import torch
from sklearn.model_selection import train_test_split
import os
import ast
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "true"

@@ -196,6 +197,142 @@ class DPODataset(Dataset):
        return loss_mask


class TriplePretrainDataset(Dataset):
    """
    Optimized triple pretraining dataset
    - keeps only one target triple per sample
    - pre-tokenizes all target sentences
    - shows progress with a tqdm bar
    """
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        print("🚀 Loading and preprocessing triple data...")
        self.samples = self.load_and_preprocess_data(data_path)

    def load_and_preprocess_data(self, path):
        """Load and preprocess the triple data"""
        # 1. Load the raw data
        print("📂 Loading raw data...")
        if path.endswith('.json'):
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        elif path.endswith('.jsonl'):
            data = []
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        data.append(json.loads(line.strip()))
        else:
            raise ValueError(f"Unsupported file format: {path}")

        print(f"📊 Raw data size: {len(data)} samples")

        # 2. Validate the data and keep a single target per sample
        print("🔍 Validating data format and selecting a single target...")
        valid_samples = []

        for i, sample in enumerate(tqdm(data, desc="Validating data format")):
            if not isinstance(sample, dict) or 'text' not in sample:
                continue

            targets = sample.get('target', [])
            if not isinstance(targets, list) or len(targets) == 0:
                # No valid target: fall back to a placeholder triple ("没有 发现 三元组" ≈ "no triple found")
                selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
            else:
                # Pick the first structurally valid target
                selected_target = None
                for triple in targets:
                    if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
                        selected_target = triple
                        break

                # Fall back to the placeholder if no target qualified
                if selected_target is None:
                    selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}

            valid_samples.append({
                'text': sample['text'],
                'target': selected_target  # keep only one target
            })

        print(f"✅ Valid samples: {len(valid_samples)}")

        # 3. Tokenize the target sentences in batches
        print("🔤 Tokenizing target sentences in batches...")

        processed_samples = []
        batch_size = 1000  # process 1000 sentences per batch to keep memory bounded

        for i in tqdm(range(0, len(valid_samples), batch_size), desc="Tokenizing target sentences"):
            # Slice out the current batch
            batch_samples = valid_samples[i:i + batch_size]

            # Convert the batch targets into plain sentences
            batch_target_sentences = [self._triple_to_sentence(sample['target']) for sample in batch_samples]

            # Tokenize the whole batch at once
            batch_encodings = self.tokenizer(
                batch_target_sentences,
                max_length=128,  # target sentences are usually short
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            # Assemble the processed samples for this batch
            for j, sample in enumerate(batch_samples):
                processed_samples.append({
                    'text': sample['text'],  # keep the raw text; it is tokenized in __getitem__
                    'target_input_ids': batch_encodings.input_ids[j],
                    'target_attention_mask': batch_encodings.attention_mask[j],
                    'target_sentence': batch_target_sentences[j],  # keep the raw sentence for debugging
                })

        print(f"🎉 Preprocessing finished! Processed {len(processed_samples)} samples")
        return processed_samples

    def __len__(self):
        return len(self.samples)

    def _triple_to_sentence(self, triple):
        """Convert a triple into a plain sentence"""
        return f"{triple['subject']} {triple['predicate']} {triple['object']}"

    def __getitem__(self, index):
        """Return one sample; the input text is tokenized on the fly, the target is pre-tokenized"""
        sample = self.samples[index]

        # Tokenize the input text at runtime (for language modeling)
        input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)

        # Build the shifted language-modeling pair
        X = input_ids[:-1]
        Y = input_ids[1:]
        loss_mask = loss_mask[1:]

        return {
            'input_ids': X,
            'labels': Y,
            'loss_mask': loss_mask,
            'target_input_ids': sample['target_input_ids'],          # already a tensor
            'target_attention_mask': sample['target_attention_mask'],  # already a tensor
            'target_sentence': sample['target_sentence'],            # string, kept for debugging
            'original_text': sample['text']
        }


class RLAIFDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
442 preprocessing/preprocess_trex.py (new file)
@@ -0,0 +1,442 @@
import json
import os
import argparse
from typing import List, Dict, Any, Optional
from collections import defaultdict
import pickle
from pathlib import Path


class WikidataRelationManager:
    """Wikidata relation manager with local mapping and caching (offline only)"""

    def __init__(self, cache_file: str = "wikidata_relations_cache.pkl",
                 mapping_file: str = None):
        self.cache_file = cache_file
        self.mapping_file = mapping_file
        self.relations = {}
        # API-related attributes removed

        # Initial base relation mapping
        self.base_relations = {
            # # basic relations
            # 'P31': 'instance of',
            # 'P279': 'subclass of',
            # 'P17': 'country',
            # 'P159': 'headquarters location',
            # 'P571': 'inception',

            # # person relations
            # 'P19': 'place of birth',
            # 'P20': 'place of death',
            # 'P27': 'country of citizenship',
            # 'P106': 'occupation',
            # 'P22': 'father',
            # 'P25': 'mother',
            # 'P26': 'spouse',
            # 'P40': 'child',
            # 'P69': 'educated at',
            # 'P108': 'employer',

            # # geographic relations
            # 'P36': 'capital',
            # 'P131': 'located in',
            # 'P47': 'shares border with',
            # 'P206': 'located on terrain feature',
            # 'P1376': 'capital of',

            # # organization relations
            # 'P112': 'founded by',
            # 'P127': 'owned by',
            # 'P169': 'chief executive officer',
            # 'P488': 'chairperson',
            # 'P749': 'parent organization',

            # # creative-work relations
            # 'P50': 'author',
            # 'P57': 'director',
            # 'P58': 'screenwriter',
            # 'P161': 'cast member',
            # 'P175': 'performer',
            # 'P577': 'publication date',
            # 'P123': 'publisher',
            # 'P136': 'genre',

            # # temporal relations
            # 'P155': 'follows',
            # 'P156': 'followed by',
            # 'P580': 'start time',
            # 'P582': 'end time',

            # # sports relations
            # 'P54': 'member of sports team',
            # 'P413': 'position played on team',
            # 'P118': 'league',

            # # science relations
            # 'P275': 'copyright license',
            # 'P170': 'creator',
            # 'P398': 'child astronomical body',
            # 'P397': 'parent astronomical body',

            # # other common relations
            # 'P37': 'official language',
            # 'P1923': 'place of marriage',
            # 'P737': 'influenced by',
            # 'P463': 'member of',
            # 'P39': 'position held',
            # 'P276': 'location',
            # 'P1441': 'present in work',
        }

        self.load_cache()

    def load_cache(self):
        """Load the cached relation mapping; a JSON mapping file takes priority"""
        try:
            # Prefer the JSON mapping file if one was given and exists
            if self.mapping_file and os.path.exists(self.mapping_file):
                with open(self.mapping_file, 'r', encoding='utf-8') as f:
                    self.relations = json.load(f)
                print(f"Loaded {len(self.relations)} relation mappings from the JSON mapping file")
                return

            # Otherwise try the pickle cache file
            if os.path.exists(self.cache_file):
                with open(self.cache_file, 'rb') as f:
                    self.relations = pickle.load(f)
                print(f"Loaded {len(self.relations)} relation mappings from the pickle cache")
            else:
                self.relations = self.base_relations.copy()
                print(f"Initialized base relation mapping: {len(self.relations)} entries")
        except Exception as e:
            print(f"Failed to load cache: {e}")
            self.relations = self.base_relations.copy()

    def save_cache(self):
        """Save the relation mapping to the cache file"""
        try:
            with open(self.cache_file, 'wb') as f:
                pickle.dump(self.relations, f)
            print(f"Saved {len(self.relations)} relation mappings to the cache")
        except Exception as e:
            print(f"Failed to save cache: {e}")

    # Web-scraping support removed; the manager is now purely offline

    def get_relation_name(self, property_id: str) -> Optional[str]:
        """Resolve a relation name using only the local mapping"""
        if property_id in self.relations:
            return self.relations[property_id]

        # Not found locally: return None so the caller skips this relation
        return None

    # Network-based batch fetching and preloading removed


class TRexProcessor:
    """Processor for the T-REx dataset"""

    def __init__(self, relation_manager: WikidataRelationManager):
        self.relation_manager = relation_manager

    def extract_predicate_id(self, uri: str) -> str:
        """Extract the property ID from a URI"""
        if uri and 'prop/direct/' in uri:
            return uri.split('/')[-1]
        elif uri and uri.startswith('P') and uri[1:].isdigit():
            return uri
        return uri if uri else 'unknown'

    def get_relation_name(self, predicate_uri: str) -> Optional[str]:
        """Get the human-readable name of a relation"""
        predicate_id = self.extract_predicate_id(predicate_uri)
        return self.relation_manager.get_relation_name(predicate_id)

    # Predicate collection removed; preloading is no longer needed

    def is_valid_triple(self, triple: Dict[str, Any], confidence_threshold: float,
                        boundary_threshold: int) -> bool:
        """Check whether a triple passes the filtering rules"""
        try:
            # The triple must be a dict
            if not isinstance(triple, dict):
                return False

            # Required fields
            if not all(key in triple for key in ['subject', 'predicate', 'object']):
                return False

            subject = triple['subject']
            predicate = triple['predicate']
            object_info = triple['object']

            # subject, predicate and object must all be dicts
            if not isinstance(subject, dict) or not isinstance(predicate, dict) or not isinstance(object_info, dict):
                return False

            # Subject and object need a valid URI and surface form
            if not (subject.get('uri') and subject.get('surfaceform')):
                return False
            if not (object_info.get('uri') and object_info.get('surfaceform')):
                return False
            if not predicate.get('uri'):
                return False

            # Confidence check (if present)
            confidence = triple.get('confidence')
            if confidence is not None and confidence < confidence_threshold:
                return False

            # Boundary check (only if a threshold is set)
            if boundary_threshold > 0:
                subject_boundaries = subject.get('boundaries')
                object_boundaries = object_info.get('boundaries')

                if not subject_boundaries or not object_boundaries:
                    return False

                # Boundaries must be lists of length >= 2
                if not (isinstance(subject_boundaries, list) and len(subject_boundaries) >= 2):
                    return False
                if not (isinstance(object_boundaries, list) and len(object_boundaries) >= 2):
                    return False

                try:
                    # The spans must be long enough
                    subject_length = subject_boundaries[1] - subject_boundaries[0]
                    object_length = object_boundaries[1] - object_boundaries[0]

                    if subject_length < boundary_threshold or object_length < boundary_threshold:
                        return False
                except (TypeError, IndexError):
                    return False

            # The surface forms must be reasonable text
            subject_text = subject.get('surfaceform', '').strip()
            object_text = object_info.get('surfaceform', '').strip()

            if not subject_text or not object_text:
                return False

            # Drop entities that are too long or too short
            if len(subject_text) > 100 or len(object_text) > 100:
                return False
            if len(subject_text) < 2 or len(object_text) < 2:
                return False

            return True

        except (KeyError, TypeError, AttributeError):
            return False

    def process_single_file(self, file_path: str, confidence_threshold: float,
                            boundary_threshold: int) -> List[Dict[str, Any]]:
        """Process a single JSON file"""
        print(f"Processing file: {file_path}")

        processed_data = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                # The whole file is one JSON array
                print(f"Loading JSON array file: {file_path}")
                data_list = json.load(f)
                print(f"The file contains {len(data_list)} entries")

                for idx, data in enumerate(data_list):
                    try:
                        # Basic fields
                        text = data.get('text', '').strip()
                        if not text:
                            continue

                        # Triples
                        triples = data.get('triples', [])
                        if not triples:
                            continue

                        valid_targets = []

                        for triple in triples:
                            if self.is_valid_triple(triple, confidence_threshold, boundary_threshold):
                                # Resolve the relation name; skip the triple if it cannot be resolved
                                relation_name = self.get_relation_name(triple['predicate']['uri'])
                                if relation_name is None:
                                    continue  # skip unresolvable relations

                                target = {
                                    'subject': triple['subject']['surfaceform'].strip(),
                                    'predicate': relation_name,
                                    'object': triple['object']['surfaceform'].strip()
                                }
                                valid_targets.append(target)

                        # Keep the sample if it has at least one valid triple
                        if valid_targets:
                            processed_data.append({
                                'text': text,
                                'target': valid_targets
                            })

                    except Exception as e:
                        if idx <= 10:  # only print the first few errors
                            print(f"Error while processing an entry in {file_path} at index {idx}: {e}")
                        continue

        except FileNotFoundError:
            print(f"File not found: {file_path}")
        except json.JSONDecodeError as e:
            print(f"JSON decode error in {file_path}: {e}")
        except Exception as e:
            print(f"Error while processing file {file_path}: {e}")

        print(f"Extracted {len(processed_data)} valid samples from {file_path}")
        return processed_data

    def process_folder(self, folder_path: str, confidence_threshold: float,
                       boundary_threshold: int) -> List[Dict[str, Any]]:
        """Process every JSON file in a folder"""
        all_processed_data = []

        if not os.path.exists(folder_path):
            raise FileNotFoundError(f"Folder does not exist: {folder_path}")

        # Collect all JSON files
        json_files = [f for f in os.listdir(folder_path) if f.endswith('.json')]

        if not json_files:
            raise ValueError(f"No JSON files found in {folder_path}")

        print(f"Found {len(json_files)} JSON files")

        for filename in sorted(json_files):
            file_path = os.path.join(folder_path, filename)
            processed_data = self.process_single_file(file_path, confidence_threshold, boundary_threshold)
            all_processed_data.extend(processed_data)

        # Save the final relation cache
        self.relation_manager.save_cache()

        return all_processed_data

    def generate_statistics(self, processed_data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Generate statistics about the processed data"""
        total_samples = len(processed_data)
        total_triples = sum(len(sample['target']) for sample in processed_data)

        # Relation type counts
        relation_counts = defaultdict(int)
        for sample in processed_data:
            for target in sample['target']:
                relation_counts[target['predicate']] += 1

        # Text length statistics
        text_lengths = [len(sample['text']) for sample in processed_data]
        avg_text_length = sum(text_lengths) / len(text_lengths) if text_lengths else 0

        # Triples per text
        triples_per_text = [len(sample['target']) for sample in processed_data]
        avg_triples_per_text = sum(triples_per_text) / len(triples_per_text) if triples_per_text else 0

        return {
            'total_samples': total_samples,
            'total_triples': total_triples,
            'avg_text_length': round(avg_text_length, 2),
            'avg_triples_per_text': round(avg_triples_per_text, 2),
            'relation_distribution': dict(sorted(relation_counts.items(),
                                                 key=lambda x: x[1], reverse=True)),
            'top_10_relations': dict(list(sorted(relation_counts.items(),
                                                 key=lambda x: x[1], reverse=True))[:10]),
            'total_unique_relations': len(relation_counts),
            'cached_relations': len(self.relation_manager.relations)
        }


def main():
    parser = argparse.ArgumentParser(description='Process the T-REx dataset (offline relation mapping)')
    parser.add_argument('--folder_path', type=str, default='/home/pci/ycz/Code/Minimind/dataset/trex',
                        help='Path to the folder containing the JSON files')
    parser.add_argument('--confidence_threshold', type=float, default=0.5,
                        help='Confidence threshold (default: 0.5)')
    parser.add_argument('--boundary_threshold', type=int, default=0,
                        help='Boundary length threshold (default: 0, no filtering)')
    parser.add_argument('--output', type=str, default='./processed_trex_data.json',
                        help='Output file name (default: processed_trex_data.json)')
    parser.add_argument('--stats', type=str, default='trex_statistics.json',
                        help='Statistics output file name (default: trex_statistics.json)')
    parser.add_argument('--cache_file', type=str, default='wikidata_relations_cache.pkl',
                        help='Relation cache file name (default: wikidata_relations_cache.pkl)')
    parser.add_argument('--mapping_file', type=str, default="/home/pci/ycz/Code/Minimind/preprocessing/sample_property_mappings.json",
                        help='Path to the JSON mapping file (required; used to map relation names)')

    args = parser.parse_args()

    print("T-REx dataset processor (offline relation mapping)")
    print("=" * 60)
    print(f"Input folder: {args.folder_path}")
    print(f"Confidence threshold: {args.confidence_threshold}")
    print(f"Boundary length threshold: {args.boundary_threshold}")
    print(f"Output file: {args.output}")
    print(f"Relation cache file: {args.cache_file}")
    print(f"JSON mapping file: {args.mapping_file if args.mapping_file else 'not specified'}")
    print("=" * 60)

    # Make sure the mapping file exists
    if not args.mapping_file or not os.path.exists(args.mapping_file):
        print(f"Error: mapping file missing or not specified: {args.mapping_file}")
        print("Please provide a valid JSON mapping file.")
        return 1

    # Create the relation manager
    relation_manager = WikidataRelationManager(
        cache_file=args.cache_file,
        mapping_file=args.mapping_file
    )

    # Create the processor
    processor = TRexProcessor(relation_manager)

    try:
        # Process the data
        processed_data = processor.process_folder(
            args.folder_path,
            args.confidence_threshold,
            args.boundary_threshold
        )

        print(f"\nProcessing finished! Processed {len(processed_data)} samples in total")

        # Generate statistics
        stats = processor.generate_statistics(processed_data)

        # Save the processed data
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(processed_data, f, ensure_ascii=False, indent=2)

        # Save the statistics
        with open(args.stats, 'w', encoding='utf-8') as f:
            json.dump(stats, f, ensure_ascii=False, indent=2)

        print(f"\nData saved to: {args.output}")
        print(f"Statistics saved to: {args.stats}")
        print(f"Relation cache saved to: {args.cache_file}")

        # Print a summary
        print("\nData statistics summary:")
        print("=" * 30)
        print(f"Total samples: {stats['total_samples']}")
        print(f"Total triples: {stats['total_triples']}")
        print(f"Unique relations: {stats['total_unique_relations']}")
        print(f"Cached relations: {stats['cached_relations']}")
        print(f"Average text length: {stats['avg_text_length']}")
        print(f"Average triples per text: {stats['avg_triples_per_text']}")
        print("\nTop 10 most frequent relations:")
        for relation, count in stats['top_10_relations'].items():
            print(f"  {relation}: {count}")

    except Exception as e:
        print(f"Error during processing: {e}")
        return 1

    return 0


if __name__ == "__main__":
    exit(main())
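For context, an illustrative sketch (not from the repository) of the data shapes this script consumes and produces, based on the code above; the concrete sentence and field values are invented, and the paths in the invocation are placeholders.

# Shape of the --mapping_file JSON (Wikidata property ID -> readable name):
#   {"P31": "instance of", "P50": "author", ...}

# Shape of one element of the processed_trex_data.json output:
example_record = {
    "text": "The Hitchhiker's Guide to the Galaxy is a comedy science fiction series by Douglas Adams.",
    "target": [
        {"subject": "The Hitchhiker's Guide to the Galaxy",
         "predicate": "author",
         "object": "Douglas Adams"}
    ],
}

# Typical invocation (paths are placeholders):
#   python preprocessing/preprocess_trex.py --folder_path dataset/trex \
#       --mapping_file preprocessing/sample_property_mappings.json \
#       --output processed_trex_data.json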
441 preprocessing/preprocess_triple.py (new file)
@@ -0,0 +1,441 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import json
import re
import asyncio
import aiofiles
from concurrent.futures import ThreadPoolExecutor
from preprocessing.agent_system.extractor_agent.agent import DepartmentAgent
from typing import Dict, List, Tuple
import gc
import time
import psutil
from tqdm.asyncio import tqdm as async_tqdm
from tqdm import tqdm

json_path = "dataset/merged_pretrain_extra.jsonl"
output_path = "dataset/processed_triples.jsonl"

# Tuned configuration - keeps resource usage low
BATCH_SIZE = 5000      # batch size: 5,000 records per batch
MAX_CONCURRENT = 200   # concurrency limit: at most 200 paragraphs in flight
AGENT_POOL_SIZE = 20   # agent pool size: 20 agent instances

def get_memory_usage():
    """Return the current memory usage in MB"""
    process = psutil.Process(os.getpid())
    memory_info = process.memory_info()
    memory_mb = memory_info.rss / 1024 / 1024
    return memory_mb

def print_memory_info(stage=""):
    """Print the current memory usage"""
    memory_mb = get_memory_usage()
    print(f"🔧 {stage} - memory usage: {memory_mb:.1f} MB")

# Create a pool of extractor agents to avoid concurrency conflicts
def create_extractor_pool(pool_size: int = 5):
    """Create a pool of extractor agents"""
    print(f"Creating {pool_size} agent instances...")
    agents = []
    for i in range(pool_size):
        try:
            agent = DepartmentAgent(model_type="deepseek")
            agents.append(agent)
            print(f"  ✓ Agent {i+1}/{pool_size} created")
        except Exception as e:
            print(f"  ✗ Agent {i+1} failed to create: {e}")
    print(f"Agent pool ready: {len(agents)} instances created")
    return agents

# Lazily initialized agent pool
AGENT_POOL = None
agent_pool_index = 0

def get_agent_pool():
    """Return the agent pool, creating it on first use"""
    global AGENT_POOL
    if AGENT_POOL is None:
        print_memory_info("Before creating the agent pool")
        AGENT_POOL = create_extractor_pool(pool_size=AGENT_POOL_SIZE)
        print_memory_info("After creating the agent pool")
    return AGENT_POOL

def get_next_agent():
    """Round-robin over the pool to get the next available agent"""
    global agent_pool_index
    pool = get_agent_pool()
    agent = pool[agent_pool_index % len(pool)]
    agent_pool_index += 1
    return agent

def clean_and_split_text(text):
    """
    Strip the special markers at the start and end of the text, then split it into sentences
    """
    # Remove a leading <|im_start|> and a trailing <|im_end|>
    text = text.strip()
    if text.startswith('<|im_start|>'):
        text = text[len('<|im_start|>'):]
    if text.endswith('<|im_end|>'):
        text = text[:-len('<|im_end|>')]

    # Clean up remaining whitespace
    text = text.strip()

    # Split into sentences on sentence-final punctuation
    # (both ASCII and full-width . ! ?)
    sentence_endings = r'[.!?。!?]'
    sentences = re.split(sentence_endings, text)

    # Strip each sentence and drop empty or trivially short ones
    cleaned_sentences = []
    for sentence in sentences:
        sentence = sentence.strip()
        if sentence and len(sentence) > 5:  # keep only non-empty, meaningful sentences
            cleaned_sentences.append(sentence)

    return cleaned_sentences

async def extract_triple_from_sentence_async(sentence: str, context: str = None) -> Dict:
    """
    Asynchronously extract a triple from one sentence using an extractor agent
    """
    try:
        # Grab an agent instance from the pool
        agent = get_next_agent()
        result = await agent.async_run(sentence=sentence, context=context)
        return {
            "sentence": sentence,
            "triple": {
                "subject": result.triple.subject,
                "predicate": result.triple.predicate,
                "object": result.triple.object
            },
            "confidence": result.confidence
        }
    except Exception as e:
        return {
            "sentence": sentence,
            "triple": {
                "subject": "",
                "predicate": "",
                "object": ""
            },
            "confidence": 0.0,
            "error": str(e)
        }

async def process_paragraph_async(line_num: int, original_text: str, semaphore: asyncio.Semaphore) -> Dict:
    """
    Asynchronously process a single paragraph
    """
    async with semaphore:  # bound the number of concurrent paragraphs
        try:
            # Clean and split the text
            sentences = clean_and_split_text(original_text)

            if not sentences:
                return None

            # Result skeleton for this paragraph
            paragraph_result = {
                "source_line": line_num,
                "original_paragraph": original_text,
                "sentences": [],
                "triples": []
            }

            # Process every sentence concurrently
            tasks = []
            for sentence in sentences:
                task = extract_triple_from_sentence_async(sentence, context=original_text)
                tasks.append(task)

            # Wait for all sentences to finish
            triple_results = await asyncio.gather(*tasks)

            # Collect the results
            for i, sentence in enumerate(sentences):
                paragraph_result["sentences"].append(sentence)
                paragraph_result["triples"].append(triple_results[i])

            return paragraph_result

        except Exception as e:
            print(f"Error while processing line {line_num}: {e}")
            return None

async def process_batch_async(batch_data: List[Tuple[int, str]], batch_num: int) -> List[Dict]:
    """
    Asynchronously process one batch of data, with a progress bar and memory monitoring
    """
    print(f"\n=== Processing batch {batch_num} asynchronously ===")
    print(f"Batch size: {len(batch_data)} records")
    print_memory_info(f"Before batch {batch_num}")

    start_time = time.time()

    # Semaphore that bounds the concurrency
    semaphore = asyncio.Semaphore(MAX_CONCURRENT)

    # Process the tasks in chunks so we never create too many at once
    chunk_size = 1000  # 1000 tasks per chunk
    all_results = []

    for chunk_start in range(0, len(batch_data), chunk_size):
        chunk_end = min(chunk_start + chunk_size, len(batch_data))
        chunk_data = batch_data[chunk_start:chunk_end]

        print(f"Processing chunk {chunk_start//chunk_size + 1}/{(len(batch_data)-1)//chunk_size + 1} ({len(chunk_data)} records)")

        # Create the async tasks for this chunk
        tasks = []
        for line_num, original_text in chunk_data:
            task = process_paragraph_async(line_num, original_text, semaphore)
            tasks.append(task)

        # Progress bar for this chunk
        progress_bar = tqdm(total=len(tasks), desc=f"batch {batch_num} - chunk {chunk_start//chunk_size + 1}", unit="paragraph", ncols=100)

        chunk_results = []
        completed_tasks = 0

        # Consume tasks as they complete and update the progress bar
        for coro in asyncio.as_completed(tasks):
            try:
                result = await coro
                chunk_results.append(result)
                completed_tasks += 1

                # Advance the progress bar
                progress_bar.update(1)

                # Refresh the postfix every 50 completed tasks
                if completed_tasks % 50 == 0:
                    valid_results = [r for r in chunk_results if r is not None]
                    progress_bar.set_postfix({
                        'valid': len(valid_results),
                        'done': completed_tasks,
                        'success': f"{len(valid_results)/completed_tasks*100:.1f}%"
                    })
            except Exception as e:
                print(f"Task failed: {e}")
                completed_tasks += 1
                progress_bar.update(1)

        progress_bar.close()
        all_results.extend(chunk_results)

        # Free memory after every chunk
        del tasks, chunk_results
        gc.collect()

        print_memory_info(f"After batch {batch_num}, chunk {chunk_start//chunk_size + 1}")

    # Drop None results
    valid_results = [result for result in all_results if result is not None]

    # Statistics
    batch_sentences = sum(len(result["sentences"]) for result in valid_results)
    batch_triples = sum(
        sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
        for result in valid_results
    )

    end_time = time.time()
    processing_time = end_time - start_time

    print(f"Batch {batch_num} finished:")
    print(f"  - valid paragraphs: {len(valid_results)}/{len(batch_data)} ({len(valid_results)/len(batch_data)*100:.1f}%)")
    print(f"  - total sentences: {batch_sentences}")
    print(f"  - successful triples: {batch_triples}")
    print(f"  - triple success rate: {batch_triples/batch_sentences*100:.1f}%" if batch_sentences > 0 else "  - no sentences")
    print(f"  - processing time: {processing_time:.2f}s")
    print(f"  - throughput: {len(batch_data)/processing_time:.2f} paragraphs/s")

    print_memory_info(f"After batch {batch_num}")

    return valid_results

async def write_results_batch(results: List[Dict], output_path: str):
    """
    Asynchronously write results in bulk, with progress output
    """
    try:
        print(f"Writing {len(results)} results in bulk...")

        # Prepare the lines to write
        content_lines = []
        for result in results:
            content_lines.append(json.dumps(result, ensure_ascii=False))

        # Bulk asynchronous write
        async with aiofiles.open(output_path, "a", encoding="utf-8") as f:
            await f.write("\n".join(content_lines) + "\n")

        print(f"✓ Wrote {len(results)} results to {output_path}")

    except Exception as e:
        print(f"✗ Bulk write failed: {e}")
        print("Falling back to writing one record at a time...")

        # Fallback: write record by record (with a progress bar)
        async with aiofiles.open(output_path, "a", encoding="utf-8") as f:
            for result in tqdm(results, desc="Writing records", unit="record"):
                await f.write(json.dumps(result, ensure_ascii=False) + "\n")
        print(f"✓ Record-by-record write finished")

# Main processing pipeline
async def main_async():
    total_processed = 0
    total_sentences = 0
    total_triples = 0
    batch_num = 0

    print("=== Starting asynchronous batch processing of the JSONL file ===")
    print(f"Configuration:")
    print(f"  - batch size: {BATCH_SIZE:,} records")
    print(f"  - max concurrency: {MAX_CONCURRENT}")
    print(f"  - agent pool size: {AGENT_POOL_SIZE}")
    print(f"  - input file: {json_path}")
    print(f"  - output file: {output_path}")
    print()

    print_memory_info("Program start")

    # Truncate the output file
    async with aiofiles.open(output_path, "w", encoding="utf-8") as f:
        pass

    # Read and process the data
    with open(json_path, "r", encoding="utf-8") as f_in:
        batch_data = []

        for line_num, line in enumerate(f_in):
            if line.strip():  # skip empty lines
                try:
                    item = json.loads(line)
                    original_text = item.get("text", "")

                    if original_text:
                        batch_data.append((line_num + 1, original_text))

                    # Once the batch is full, process it asynchronously
                    if len(batch_data) >= BATCH_SIZE:
                        batch_num += 1

                        # Process the batch
                        batch_results = await process_batch_async(batch_data, batch_num)

                        # Write the results in bulk
                        if batch_results:
                            await write_results_batch(batch_results, output_path)

                            # Statistics
                            batch_sentences = sum(len(result["sentences"]) for result in batch_results)
                            batch_triples = sum(
                                sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
                                for result in batch_results
                            )

                            total_processed += len(batch_data)
                            total_sentences += batch_sentences
                            total_triples += batch_triples

                            print(f"\n📊 Cumulative statistics after batch {batch_num}:")
                            print(f"  - paragraphs processed: {total_processed:,}")
                            print(f"  - sentences: {total_sentences:,}")
                            print(f"  - triples: {total_triples:,}")
                            print(f"  - overall success rate: {total_triples/total_sentences*100:.1f}%")
                            print("-" * 80)

                        # Clear the batch data and free memory
                        batch_data.clear()
                        batch_results.clear()
                        gc.collect()  # force a garbage collection

                        print_memory_info(f"After cleaning up batch {batch_num}")

                except json.JSONDecodeError as e:
                    print(f"JSON decode error on line {line_num + 1}: {e}")
                except Exception as e:
                    print(f"Error while processing line {line_num + 1}: {e}")

    # Process the final, possibly incomplete batch
    if batch_data:
        batch_num += 1
        batch_results = await process_batch_async(batch_data, batch_num)

        if batch_results:
            await write_results_batch(batch_results, output_path)

            batch_sentences = sum(len(result["sentences"]) for result in batch_results)
            batch_triples = sum(
                sum(1 for triple in result["triples"] if triple["confidence"] > 0.0)
                for result in batch_results
            )

            total_processed += len(batch_data)
            total_sentences += batch_sentences
            total_triples += batch_triples

    # Final statistics
    print(f"\n{'='*80}")
    print(f"🎉 All batches processed!")
    print(f"{'='*80}")
    print(f"Final statistics:")
    print(f"  - batches: {batch_num}")
    print(f"  - paragraphs: {total_processed:,}")
    print(f"  - sentences: {total_sentences:,}")
    print(f"  - triples: {total_triples:,}")
    print(f"  - overall success rate: {total_triples/total_sentences*100:.1f}%" if total_sentences > 0 else "  - no valid sentences")
    print(f"  - output file: {output_path}")
    print(f"{'='*80}")

    print_memory_info("Before program end")

    # Show a few example results
    await show_sample_results()

async def show_sample_results():
    """Print the first few processed results as examples"""
    print("\n📋 First 3 processed results:")
    try:
        async with aiofiles.open(output_path, "r", encoding="utf-8") as f:
            i = 0
            async for line in f:
                if i >= 3:
                    break
                item = json.loads(line)
                print(f"\n--- Paragraph {i+1} (source line: {item['source_line']}) ---")
                print(f"Original paragraph: {item['original_paragraph'][:100]}...")
                print(f"Sentence count: {len(item['sentences'])}")
                if item['triples']:
                    for j, triple in enumerate(item['triples'][:2]):  # show at most 2 triples
                        print(f"  Sentence {j+1}: {triple['sentence'][:50]}...")
                        if triple['confidence'] > 0:
                            print(f"    Triple: {triple['triple']['subject']} -> {triple['triple']['predicate']} -> {triple['triple']['object']}")
                            print(f"    Confidence: {triple['confidence']:.2f}")
                        else:
                            print(f"    Extraction failed: {triple.get('error', 'unknown error')}")
                i += 1
    except Exception as e:
        print(f"Error while reading example results: {e}")

def main():
    """Entry point"""
    try:
        # Run the async main function
        asyncio.run(main_async())
    except KeyboardInterrupt:
        print("\n⚠️ Interrupted by the user")
    except Exception as e:
        print(f"❌ Error during processing: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()
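Similarly, a hedged sketch of what one line of the processed_triples.jsonl output looks like, based on process_paragraph_async above; the paragraph, sentences, and triple values are invented.

# Illustrative only: one JSON line of dataset/processed_triples.jsonl
example_line = {
    "source_line": 42,
    "original_paragraph": "Marie Curie discovered polonium. She was born in Warsaw.",
    "sentences": ["Marie Curie discovered polonium", "She was born in Warsaw"],
    "triples": [
        {"sentence": "Marie Curie discovered polonium",
         "triple": {"subject": "Marie Curie", "predicate": "discovered", "object": "polonium"},
         "confidence": 0.95},
        {"sentence": "She was born in Warsaw",
         "triple": {"subject": "", "predicate": "", "object": ""},
         "confidence": 0.0,
         "error": "extraction failed"}  # the error field appears only when the agent call raised
    ],
}

# The script is run directly (DeepSeek-backed DepartmentAgent instances must be configured):
#   python preprocessing/preprocess_triple.py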
1022 train_extra_accelerate.py (new file)
File diff suppressed because it is too large