import json
import random
import re
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.model_selection import train_test_split
import os
import ast
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "true"


def process_sample_filter(data_args):
    """Filtering logic for a single sample."""
    sample, valid_predicates = data_args
    if 'target' in sample and isinstance(sample['target'], list):
        # Drop low-frequency predicates from the target list
        valid_targets = []
        for triple in sample['target']:
            if isinstance(triple, dict) and 'predicate' in triple:
                if triple['predicate'] in valid_predicates:
                    valid_targets.append(triple)

        # Keep the sample only if at least one valid target remains
        if valid_targets:
            sample['target'] = valid_targets
            return sample
        else:
            return None
    else:
        # No target information: keep the sample as-is
        return sample
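

# Illustrative sketch (not part of the original pipeline): process_sample_filter takes a
# single (sample, valid_predicates) tuple, so it can be mapped over a corpus with
# multiprocessing. The helper below is hypothetical; the worker count and the structure
# of `samples` are assumptions.
def _example_parallel_filter(samples, valid_predicates, num_workers=4):
    """Run process_sample_filter over `samples` in parallel and drop rejected ones."""
    from multiprocessing import Pool
    with Pool(num_workers) as pool:
        results = pool.map(process_sample_filter, [(s, valid_predicates) for s in samples])
    return [r for r in results if r is not None]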


def process_sample_validation(data_args):
    """Validation logic for a single sample."""
    sample, predicate_vocab = data_args
    if not isinstance(sample, dict) or 'text' not in sample:
        return None

    targets = sample.get('target', [])
    if not isinstance(targets, list) or len(targets) == 0:
        # No usable target: fall back to a default triple
        selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
    else:
        # Validate the targets and prefer the triple whose predicate has the
        # smallest share in the corpus
        selected_target = None
        min_percentage = float('inf')

        for triple in targets:
            if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
                predicate = triple['predicate']

                # Use the frequency statistics from predicate_vocab
                if predicate in predicate_vocab:
                    stats = predicate_vocab[predicate]
                    if isinstance(stats, dict) and 'percentage' in stats:
                        percentage = stats['percentage']
                        if percentage < min_percentage:
                            min_percentage = percentage
                            selected_target = triple
                    elif selected_target is None:
                        selected_target = triple
                elif selected_target is None:
                    selected_target = triple

        # If no valid target was found, use the default triple
        if selected_target is None:
            selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}

    return {
        'text': sample['text'],
        'target': selected_target  # keep only a single target
    }
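

# Illustrative sketch (hypothetical data): process_sample_validation expects a
# (sample, predicate_vocab) tuple, where predicate_vocab maps each predicate to
# statistics such as {"count": ..., "percentage": ...}. The triple whose predicate
# holds the smallest percentage is kept as the single training target. The predicates
# and counts below are made up for demonstration only.
def _example_validation_usage():
    predicate_vocab = {
        "作者": {"count": 900, "percentage": 45.0},
        "出生地": {"count": 100, "percentage": 5.0},
    }
    sample = {
        "text": "example text",
        "target": [
            {"subject": "A", "predicate": "作者", "object": "B"},
            {"subject": "C", "predicate": "出生地", "object": "D"},
        ],
    }
    # The rarer predicate ("出生地", 5.0%) is selected as the single target.
    return process_sample_validation((sample, predicate_vocab))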


class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(data_path)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        text = str(sample['text'])

        # Add <|im_start|> / <|im_end|> (the tokenizer's bos/eos tokens) if not already present
        if not text.startswith(self.tokenizer.bos_token):
            text = f"{self.tokenizer.bos_token}{text}"
        if not text.endswith(self.tokenizer.eos_token):
            text = f"{text}{self.tokenizer.eos_token}"

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        # Mask out padding positions so they do not contribute to the loss
        loss_mask = (input_ids != self.tokenizer.pad_token_id)

        # Shift by one token for next-token prediction: X predicts Y
        X = torch.tensor(input_ids[:-1], dtype=torch.long)
        Y = torch.tensor(input_ids[1:], dtype=torch.long)
        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)
        return X, Y, loss_mask
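

# Usage sketch (the paths below are hypothetical placeholders; assumes a HuggingFace
# tokenizer whose bos/eos tokens are <|im_start|> / <|im_end|>): build the dataset
# from a JSONL file and iterate one batch through a DataLoader.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("./tokenizer_dir")  # hypothetical path
    dataset = PretrainDataset("./pretrain_data.jsonl", tokenizer, max_length=512)  # hypothetical path
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)
    X, Y, loss_mask = next(iter(loader))
    # X and Y are shifted by one token for next-token prediction; loss_mask is 0 on padding.
    print(X.shape, Y.shape, loss_mask.shape)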