# Minimind/model/dataset.py
import json
import random
import re
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.model_selection import train_test_split
import os
import ast
from tqdm import tqdm
os.environ["TOKENIZERS_PARALLELISM"] = "true"

def process_sample_filter(data_args):
    """Filtering logic for a single sample: drop triples with low-frequency predicates."""
    sample, valid_predicates = data_args
    if 'target' in sample and isinstance(sample['target'], list):
        # Filter out triples whose predicate is not in the valid set
        valid_targets = []
        for triple in sample['target']:
            if isinstance(triple, dict) and 'predicate' in triple:
                if triple['predicate'] in valid_predicates:
                    valid_targets.append(triple)
        # Keep the sample only if at least one valid target remains
        if valid_targets:
            sample['target'] = valid_targets
            return sample
        else:
            return None
    else:
        # No target information: keep the sample unchanged
        return sample
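
# Illustrative usage sketch (not in the original file): the single-tuple
# data_args signature suggests these helpers are meant for a process pool,
# e.g. multiprocessing.Pool.imap with tqdm for progress. The names
# `samples` and `valid_predicates` below are assumptions.
#
#   from multiprocessing import Pool
#   args = [(s, valid_predicates) for s in samples]
#   with Pool() as pool:
#       filtered = [s for s in tqdm(pool.imap(process_sample_filter, args),
#                                   total=len(args)) if s is not None]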

def process_sample_validation(data_args):
    """Validation logic for a single sample: keep exactly one target triple."""
    sample, predicate_vocab = data_args
    if not isinstance(sample, dict) or 'text' not in sample:
        return None
    targets = sample.get('target', [])
    if not isinstance(targets, list) or len(targets) == 0:
        # No valid target: fall back to a default "no triple found" placeholder
        selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
    else:
        # Validate the triples and prefer the predicate with the smallest share
        selected_target = None
        min_percentage = float('inf')
        for triple in targets:
            if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
                predicate = triple['predicate']
                # Use the frequency statistics stored in predicate_vocab
                if predicate in predicate_vocab:
                    stats = predicate_vocab[predicate]
                    if isinstance(stats, dict) and 'percentage' in stats:
                        percentage = stats['percentage']
                        if percentage < min_percentage:
                            min_percentage = percentage
                            selected_target = triple
                    elif selected_target is None:
                        selected_target = triple
                elif selected_target is None:
                    selected_target = triple
        # Fall back to the default placeholder if nothing valid was found
        if selected_target is None:
            selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
    return {
        'text': sample['text'],
        'target': selected_target  # keep only one target
    }
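
# Worked example (hypothetical data) of the selection rule above: the triple
# whose predicate has the smallest 'percentage' in predicate_vocab wins.
#
#   vocab = {"国籍": {"percentage": 40.2}, "出生地": {"percentage": 12.5}}
#   sample = {"text": "...", "target": [
#       {"subject": "A", "predicate": "国籍", "object": "B"},
#       {"subject": "A", "predicate": "出生地", "object": "C"}]}
#   process_sample_validation((sample, vocab))
#   # -> {'text': '...', 'target': {'subject': 'A', 'predicate': '出生地', 'object': 'C'}}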

class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(data_path)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples
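    # Expected input format (assumption inferred from load_data/__getitem__):
    # a JSONL file where every line is a JSON object with a 'text' field, e.g.
    #   {"text": "<|im_start|>some pretraining text<|im_end|>"}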
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        text = str(sample['text'])
        # Add <|im_start|> / <|im_end|> (the tokenizer's bos/eos tokens) if missing
        if not text.startswith(self.tokenizer.bos_token):
            text = f"{self.tokenizer.bos_token}{text}"
        if not text.endswith(self.tokenizer.eos_token):
            text = f"{text}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)
        # Shift by one for next-token prediction; slicing the existing tensor
        # avoids the torch.tensor(tensor) copy-construct warning
        X = input_ids[:-1].long()
        Y = input_ids[1:].long()
        loss_mask = loss_mask[1:].long()
        return X, Y, loss_mask
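
# Minimal usage sketch (assumed, not part of the original file): the tokenizer
# and data paths below are placeholders.
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("./model/minimind_tokenizer")
#   ds = PretrainDataset("./dataset/pretrain.jsonl", tokenizer, max_length=512)
#   loader = DataLoader(ds, batch_size=8, shuffle=True, num_workers=2)
#   X, Y, loss_mask = next(iter(loader))   # each of shape (8, 511)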