125 lines
4.2 KiB
Python
125 lines
4.2 KiB
Python
import json
|
||
import random
|
||
import re
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
from torch.utils.data import Dataset, DataLoader
|
||
import torch
|
||
from sklearn.model_selection import train_test_split
|
||
import os
|
||
import ast
|
||
from tqdm import tqdm
|
||
|
||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||
|
||
|
||
def process_sample_filter(data_args):
|
||
"""处理单个样本的过滤逻辑"""
|
||
sample, valid_predicates = data_args
|
||
if 'target' in sample and isinstance(sample['target'], list):
|
||
# 过滤target中的低频谓词
|
||
valid_targets = []
|
||
for triple in sample['target']:
|
||
if isinstance(triple, dict) and 'predicate' in triple:
|
||
if triple['predicate'] in valid_predicates:
|
||
valid_targets.append(triple)
|
||
|
||
# 如果还有有效的target,保留这个样本
|
||
if valid_targets:
|
||
sample['target'] = valid_targets
|
||
return sample
|
||
else:
|
||
return None
|
||
else:
|
||
# 如果没有target信息,保留样本
|
||
return sample
|
||
|
||
|
||
def process_sample_validation(data_args):
|
||
"""处理单个样本的验证逻辑"""
|
||
sample, predicate_vocab = data_args
|
||
if not isinstance(sample, dict) or 'text' not in sample:
|
||
return None
|
||
|
||
targets = sample.get('target', [])
|
||
if not isinstance(targets, list) or len(targets) == 0:
|
||
# 如果没有有效的target,创建一个默认的
|
||
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
|
||
else:
|
||
# 验证并选择target,优先选择占比小的谓词
|
||
selected_target = None
|
||
min_percentage = float('inf')
|
||
|
||
for triple in targets:
|
||
if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
|
||
predicate = triple['predicate']
|
||
|
||
# 使用predicate_vocab中的统计信息
|
||
if predicate in predicate_vocab:
|
||
stats = predicate_vocab[predicate]
|
||
if isinstance(stats, dict) and 'percentage' in stats:
|
||
percentage = stats['percentage']
|
||
if percentage < min_percentage:
|
||
min_percentage = percentage
|
||
selected_target = triple
|
||
elif selected_target is None:
|
||
selected_target = triple
|
||
elif selected_target is None:
|
||
selected_target = triple
|
||
|
||
# 如果没有找到有效的target,使用默认值
|
||
if selected_target is None:
|
||
selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
|
||
|
||
return {
|
||
'text': sample['text'],
|
||
'target': selected_target # 只保留一个target
|
||
}
|
||
|
||
|
||
class PretrainDataset(Dataset):
|
||
def __init__(self, data_path, tokenizer, max_length=512):
|
||
super().__init__()
|
||
self.tokenizer = tokenizer
|
||
self.max_length = max_length
|
||
self.samples = self.load_data(data_path)
|
||
|
||
def load_data(self, path):
|
||
samples = []
|
||
with open(path, 'r', encoding='utf-8') as f:
|
||
for line_num, line in enumerate(f, 1):
|
||
data = json.loads(line.strip())
|
||
samples.append(data)
|
||
return samples
|
||
|
||
def __len__(self):
|
||
return len(self.samples)
|
||
|
||
def __getitem__(self, index):
|
||
sample = self.samples[index]
|
||
text = str(sample['text'])
|
||
|
||
# 检查并添加<|im_start|>和<|im_end|>如果不存在
|
||
if not text.startswith(self.tokenizer.bos_token):
|
||
text = f"{self.tokenizer.bos_token}{text}"
|
||
if not text.endswith(self.tokenizer.eos_token):
|
||
text = f"{text}{self.tokenizer.eos_token}"
|
||
|
||
encoding = self.tokenizer(
|
||
text,
|
||
max_length=self.max_length,
|
||
padding='max_length',
|
||
truncation=True,
|
||
return_tensors='pt'
|
||
)
|
||
input_ids = encoding.input_ids.squeeze()
|
||
loss_mask = (input_ids != self.tokenizer.pad_token_id)
|
||
|
||
X = torch.tensor(input_ids[:-1], dtype=torch.long)
|
||
Y = torch.tensor(input_ids[1:], dtype=torch.long)
|
||
loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)
|
||
return X, Y, loss_mask
|
||
|
||
|