import json
import random
import re
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.model_selection import train_test_split
import os
import ast
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "true"


def process_sample_filter(data_args):
    """Filtering logic for a single sample."""
    sample, valid_predicates = data_args
    if 'target' in sample and isinstance(sample['target'], list):
        # Drop triples whose predicate is not in the valid (high-frequency) set
        valid_targets = []
        for triple in sample['target']:
            if isinstance(triple, dict) and 'predicate' in triple:
                if triple['predicate'] in valid_predicates:
                    valid_targets.append(triple)
        # Keep the sample only if at least one valid target remains
        if valid_targets:
            sample['target'] = valid_targets
            return sample
        else:
            return None
    else:
        # No target information: keep the sample as-is
        return sample


def process_sample_validation(data_args):
    """Validation logic for a single sample."""
    sample, predicate_vocab = data_args
    if not isinstance(sample, dict) or 'text' not in sample:
        return None

    targets = sample.get('target', [])
    if not isinstance(targets, list) or len(targets) == 0:
        # No valid targets: fall back to a default triple ("no triple found")
        selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
    else:
        # Validate the targets and select one, preferring the predicate
        # with the smallest share in the corpus
        selected_target = None
        min_percentage = float('inf')
        for triple in targets:
            if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
                predicate = triple['predicate']
                # Use the frequency statistics stored in predicate_vocab
                if predicate in predicate_vocab:
                    stats = predicate_vocab[predicate]
                    if isinstance(stats, dict) and 'percentage' in stats:
                        percentage = stats['percentage']
                        if percentage < min_percentage:
                            min_percentage = percentage
                            selected_target = triple
                    elif selected_target is None:
                        selected_target = triple
                elif selected_target is None:
                    selected_target = triple
        # Fall back to the default triple if nothing valid was found
        if selected_target is None:
            selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}

    return {
        'text': sample['text'],
        'target': selected_target  # keep a single target only
    }


class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(data_path)

    def load_data(self, path):
        # Each line of the file is one JSON object
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        text = str(sample['text'])

        # Prepend the BOS token and append the EOS token if they are missing
        if not text.startswith(self.tokenizer.bos_token):
            text = f"{self.tokenizer.bos_token}{text}"
        if not text.endswith(self.tokenizer.eos_token):
            text = f"{text}{self.tokenizer.eos_token}"

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        # Mask out padding positions so they do not contribute to the loss
        loss_mask = (input_ids != self.tokenizer.pad_token_id)

        # Shift by one position for next-token prediction: X predicts Y.
        # input_ids is already a tensor, so slice/clone instead of re-wrapping
        # it in torch.tensor(), which would raise a UserWarning.
        X = input_ids[:-1].clone().long()
        Y = input_ids[1:].clone().long()
        loss_mask = loss_mask[1:].clone().long()
        return X, Y, loss_mask
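

# --- Usage sketch (illustrative only) ---
# A minimal example of wiring PretrainDataset into a DataLoader for next-token
# pretraining. The tokenizer directory and the .jsonl path below are
# assumptions, not paths defined by this project; substitute your own.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("./tokenizer")  # assumed local tokenizer dir
    dataset = PretrainDataset("./data/pretrain.jsonl", tokenizer, max_length=512)  # assumed data file
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)

    X, Y, loss_mask = next(iter(loader))
    # X/Y are the shifted input/target id tensors; loss_mask marks non-padding targets.
    print(X.shape, Y.shape, loss_mask.shape)  # each: (batch_size, max_length - 1)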