Minimind/model/dataset.py
2025-08-01 15:54:21 +08:00

125 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import random
import re
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.model_selection import train_test_split
import os
import ast
from tqdm import tqdm
# Explicitly enable Hugging Face `tokenizers` parallelism for this process.
# NOTE(review): tokenization happens inside PretrainDataset.__getitem__, which may
# run in forked DataLoader worker processes; confirm "true" is intended there.
os.environ.update(TOKENIZERS_PARALLELISM="true")
def process_sample_filter(data_args):
    """Drop triples with non-whitelisted predicates from one sample's targets.

    Args:
        data_args: A ``(sample, valid_predicates)`` tuple. It is packed into a
            single argument, presumably so the function can be passed to
            ``Pool.map``-style APIs (TODO confirm against the caller).
            ``sample`` is a dict. ``valid_predicates`` is any container that
            supports ``in``; a ``set`` keeps the lookups O(1).

    Returns:
        - ``sample`` itself, unchanged, if it has no list-valued ``'target'``.
        - Otherwise, a shallow copy of ``sample`` whose ``'target'`` keeps only
          the dict triples whose ``'predicate'`` is in ``valid_predicates``.
        - ``None`` if no triple survives, which includes an empty target list.
          This signals that the sample should be dropped.
    """
    sample, valid_predicates = data_args
    if 'target' not in sample or not isinstance(sample['target'], list):
        # No usable target information: keep the sample as-is.
        return sample

    kept = [
        triple
        for triple in sample['target']
        if isinstance(triple, dict)
        and 'predicate' in triple
        and triple['predicate'] in valid_predicates
    ]
    if not kept:
        return None
    # Build a new dict rather than assigning into ``sample``, so the caller's
    # data is never mutated as a side effect of filtering.
    return {**sample, 'target': kept}
def process_sample_validation(data_args):
    """Validate one sample and collapse its ``'target'`` list to a single triple.

    Args:
        data_args: A ``(sample, predicate_vocab)`` tuple. ``predicate_vocab``
            maps a predicate to its statistics. Only entries that are dicts
            containing a ``'percentage'`` key take part in the ranking.

    Returns:
        ``None`` if ``sample`` is not a dict or has no ``'text'`` key.
        Otherwise ``{'text': ..., 'target': triple}``, where ``triple`` is
        chosen as follows:

        * Only dict triples that have ``subject``/``predicate``/``object`` keys
          are considered.
        * Among those whose predicate has a ``'percentage'``, the smallest
          percentage (the rarest predicate) wins. Ties keep the earliest triple.
        * If none of them has a percentage, the first complete triple is used.
        * If no complete triple exists, or ``'target'`` is missing, empty or not
          a list, a fixed placeholder triple is used.
    """
    sample, predicate_vocab = data_args
    if not (isinstance(sample, dict) and 'text' in sample):
        return None

    chosen = None
    best_pct = float('inf')
    candidates = sample.get('target', [])
    if isinstance(candidates, list):
        for triple in candidates:
            is_complete = isinstance(triple, dict) and all(
                field in triple for field in ('subject', 'predicate', 'object')
            )
            if not is_complete:
                continue
            pred = triple['predicate']
            stats = predicate_vocab[pred] if pred in predicate_vocab else None
            if isinstance(stats, dict) and 'percentage' in stats:
                pct = stats['percentage']
                # Strictly-less keeps the first triple on ties.
                if pct < best_pct:
                    best_pct = pct
                    chosen = triple
            elif chosen is None:
                # Triples without statistics are only a fallback candidate.
                chosen = triple

    if chosen is None:
        # Placeholder triple used when nothing valid was found.
        chosen = {"subject": "没有", "predicate": "发现", "object": "三元组"}
    return {
        'text': sample['text'],
        'target': chosen,  # keep exactly one target
    }
class PretrainDataset(Dataset):
    """Causal language-model pretraining dataset backed by a JSON Lines file.

    Each non-blank line of ``data_path`` is parsed with ``json.loads``, and the
    resulting object must have a ``'text'`` field, which ``__getitem__`` reads.
    Items are tokenized lazily in ``__getitem__``.
    """

    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(data_path)

    def load_data(self, path):
        """Parse every non-blank line of the JSONL file at *path*.

        Returns:
            A list with one parsed JSON value per non-blank line, in file order.

        Raises:
            json.JSONDecodeError: a line is not valid JSON. The message includes
                the 1-based line number and the file path.
        """
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    # Skip blank lines (e.g. a trailing empty line) instead of
                    # failing inside json.loads('').
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError as err:
                    # Re-raise the same exception type so existing handlers
                    # still match, but point at the offending file line.
                    raise json.JSONDecodeError(
                        f"line {line_num} of {path}: {err.msg}", err.doc, err.pos
                    ) from err
                samples.append(record)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(X, Y, loss_mask)`` for sample *index*.

        ``X`` and ``Y`` are the padded token ids shifted by one position
        (inputs and next-token targets). ``loss_mask`` is 1 where the target
        token differs from ``tokenizer.pad_token_id``, else 0. All three are 1-D
        ``torch.long`` tensors, one element shorter than the encoded sequence
        (``max_length - 1`` if the tokenizer honours ``padding='max_length'``).
        """
        text = str(self.samples[index]['text'])
        bos = self.tokenizer.bos_token
        eos = self.tokenizer.eos_token
        # Wrap the text in the tokenizer's BOS/EOS tokens if they are absent.
        # The original comment names these <|im_start|>/<|im_end|>; that
        # depends on the tokenizer in use. A tokenizer without such a token
        # (None) is skipped rather than crashing in str.startswith(None).
        if bos and not text.startswith(bos):
            text = f"{bos}{text}"
        if eos and not text.endswith(eos):
            text = f"{text}{eos}"
        # NOTE(review): truncation=True can cut off the appended EOS token for
        # texts longer than max_length.
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # return_tensors='pt' yields a (1, seq_len) batch. Drop only the batch
        # dim so a length-1 sequence still stays 1-D.
        input_ids = encoding.input_ids.squeeze(0)
        not_pad = input_ids != self.tokenizer.pad_token_id
        # torch.tensor(<tensor>) emits a copy-construct UserWarning on every
        # item. .to(..., copy=True) yields the same independent long copy
        # without the warning.
        X = input_ids[:-1].to(torch.long, copy=True)
        Y = input_ids[1:].to(torch.long, copy=True)
        loss_mask = not_pad[1:].to(torch.long, copy=True)
        return X, Y, loss_mask