"""Dataset classes for pre-training, SFT, DPO, triple classification and RLAIF.

All datasets read JSON / JSON-Lines files and turn each record into tensors
(or prompt strings) using a Hugging Face style tokenizer that is supplied by
the caller.
"""
import ast
import json
import os
import random
import re

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

try:
    # Not used inside this module; kept so that existing
    # ``from <this module> import train_test_split`` statements keep working.
    from sklearn.model_selection import train_test_split
except ImportError:
    train_test_split = None

try:
    from tqdm import tqdm
except ImportError:
    # Progress bars are purely cosmetic: fall back to plain iteration.
    def tqdm(iterable, **_kwargs):
        return iterable

os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Keys every usable triple must provide.
_TRIPLE_KEYS = ('subject', 'predicate', 'object')

# Predicates whose share is below this value are dropped. The log messages in
# ``TriplePretrainDataset.load_and_preprocess_data`` hard-code the same figure.
_MIN_PREDICATE_PERCENTAGE = 0.01


def _placeholder_triple():
    """Return a fresh placeholder triple for samples without a usable target."""
    return {"subject": "没有", "predicate": "发现", "object": "三元组"}


def _read_jsonl(path):
    """Read a UTF-8 JSON-Lines file into a list, skipping blank lines.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    records = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def _assistant_loss_mask(input_ids, bos_id, eos_id, max_length):
    """Build a 0/1 mask that marks the assistant spans of a token sequence.

    A span opens right after an occurrence of ``bos_id`` and closes at the
    next occurrence of ``eos_id`` (or at the end of the sequence when no
    closing marker is found). Inside a span, the first token after ``bos_id``
    is left unmasked, while the ``eos_id`` tokens and the single token that
    follows them are masked in. Writes are clamped to both ``max_length`` and
    the real sequence length, so short sequences cannot overflow the mask.

    Args:
        input_ids: list of token ids.
        bos_id: list of token ids that opens an assistant span.
        eos_id: list of token ids that closes an assistant span.
        max_length: upper bound (exclusive) for masked positions.

    Returns:
        A list of ints (0 or 1) with the same length as ``input_ids``.
    """
    total = len(input_ids)
    bos_len = len(bos_id)
    eos_len = len(eos_id)
    mask = [0] * total
    i = 0
    while i < total:
        if input_ids[i:i + bos_len] != bos_id:
            i += 1
            continue
        start = i + bos_len
        end = start
        while end < total and input_ids[end:end + eos_len] != eos_id:
            end += 1
        stop = min(end + eos_len + 1, max_length, total)
        for j in range(start + 1, stop):
            mask[j] = 1
        i = end + eos_len if end < total else total
    return mask


def _sample_key(sample):
    """Return a hashable, key-order-independent fingerprint of a JSON sample."""
    return json.dumps(sample, sort_keys=True, ensure_ascii=False)


def process_sample_filter(data_args):
    """Drop the triples of one sample whose predicate is not allowed.

    Args:
        data_args: ``(sample, valid_predicates)`` tuple.

    Returns:
        * the sample itself when it has no list-valued ``'target'``;
        * a shallow copy of the sample whose ``'target'`` holds only the
          triples with an allowed predicate (the input is left untouched);
        * ``None`` when no triple survives the filter.
    """
    sample, valid_predicates = data_args
    if 'target' not in sample or not isinstance(sample['target'], list):
        # No target information: keep the sample as it is.
        return sample

    kept = [
        triple
        for triple in sample['target']
        if isinstance(triple, dict)
        and 'predicate' in triple
        and triple['predicate'] in valid_predicates
    ]
    if not kept:
        return None

    filtered = dict(sample)
    filtered['target'] = kept
    return filtered


def process_sample_validation(data_args):
    """Validate one sample and reduce its targets to a single triple.

    Among the well-formed triples, the one whose predicate has the lowest
    ``'percentage'`` in ``predicate_vocab`` wins, which favours rare
    predicates. A triple whose predicate has no percentage statistics is only
    used while nothing else has been selected. When no triple qualifies, a
    placeholder triple is used.

    Args:
        data_args: ``(sample, predicate_vocab)`` tuple.

    Returns:
        ``{'text': ..., 'target': <one triple>}``, or ``None`` when the sample
        is not a dict or has no ``'text'`` key.
    """
    sample, predicate_vocab = data_args
    if not isinstance(sample, dict) or 'text' not in sample:
        return None

    targets = sample.get('target', [])
    selected_target = None
    if isinstance(targets, list):
        min_percentage = float('inf')
        for triple in targets:
            if not isinstance(triple, dict):
                continue
            if not all(key in triple for key in _TRIPLE_KEYS):
                continue
            predicate = triple['predicate']
            stats = predicate_vocab[predicate] if predicate in predicate_vocab else None
            if isinstance(stats, dict) and 'percentage' in stats:
                if stats['percentage'] < min_percentage:
                    min_percentage = stats['percentage']
                    selected_target = triple
            elif selected_target is None:
                selected_target = triple

    if selected_target is None:
        selected_target = _placeholder_triple()

    # Only a single target is kept per sample.
    return {'text': sample['text'], 'target': selected_target}


class PretrainDataset(Dataset):
    """Next-token-prediction dataset over a JSON-Lines file with a ``text`` field."""

    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(data_path)

    def load_data(self, path):
        """Load all records of the JSON-Lines file at ``path``."""
        return _read_jsonl(path)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(X, Y, loss_mask)``, each of length ``max_length - 1``.

        ``Y`` is ``X`` shifted left by one token; ``loss_mask`` is 1 where the
        target token is not padding.
        """
        sample = self.samples[index]
        text = str(sample['text'])

        # Wrap the text in BOS/EOS markers unless already present. Tokenizers
        # that define no such token are left alone instead of crashing.
        bos = self.tokenizer.bos_token
        eos = self.tokenizer.eos_token
        if bos and not text.startswith(bos):
            text = f"{bos}{text}"
        if eos and not text.endswith(eos):
            text = f"{text}{eos}"

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Drop only the batch dimension.
        input_ids = encoding.input_ids.squeeze(0)
        loss_mask = (input_ids != self.tokenizer.pad_token_id)

        # clone() gives independent tensors without the copy-construct warning
        # that torch.tensor(<tensor>) emits.
        X = input_ids[:-1].detach().clone().to(torch.long)
        Y = input_ids[1:].detach().clone().to(torch.long)
        loss_mask = loss_mask[1:].to(torch.long)
        return X, Y, loss_mask


class SFTDataset(Dataset):
    """Supervised fine-tuning dataset; loss is restricted to assistant replies."""

    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, path):
        """Load all records of the JSON-Lines file at ``path``."""
        return _read_jsonl(path)

    def _create_chat_prompt(self, conversations):
        """Render the conversation with the tokenizer's chat template.

        Turns at even positions are treated as ``user``, odd ones as
        ``assistant``.
        """
        messages = [
            {"role": 'user' if i % 2 == 0 else 'assistant', "content": turn['content']}
            for i, turn in enumerate(conversations)
        ]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )

    def _generate_loss_mask(self, input_ids):
        """Return the 0/1 assistant-span mask for ``input_ids``."""
        return _assistant_loss_mask(input_ids, self.bos_id, self.eos_id, self.max_length)

    def __getitem__(self, index):
        """Return ``(X, Y, loss_mask)``, each of length ``max_length - 1``."""
        sample = self.samples[index]
        prompt = self._create_chat_prompt(sample['conversations'])

        # Truncate, then right-pad to exactly max_length tokens.
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))

        loss_mask = self._generate_loss_mask(input_ids)

        X = torch.tensor(input_ids[:-1], dtype=torch.long)
        Y = torch.tensor(input_ids[1:], dtype=torch.long)
        # Shift the mask so that it lines up with the predicted tokens in Y.
        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)
        return X, Y, loss_mask


class DPODataset(Dataset):
    """Preference dataset yielding tokenized chosen / rejected conversations."""

    def __init__(self, file_path, tokenizer, max_length=4096):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
        self.data = _read_jsonl(file_path)

    def __len__(self):
        return len(self.data)

    def _encode_messages(self, messages):
        """Tokenize one conversation into shifted ``(x, y, mask)`` tensors."""
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        encoding = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            padding='max_length'
        )
        input_ids = encoding['input_ids']
        loss_mask = self._generate_loss_mask(input_ids)
        x = torch.tensor(input_ids[:-1], dtype=torch.long)
        y = torch.tensor(input_ids[1:], dtype=torch.long)
        mask = torch.tensor(loss_mask[1:], dtype=torch.long)
        return x, y, mask

    def __getitem__(self, index):
        """Return a dict with ``x/y/mask`` tensors for both responses.

        ``item['chosen']`` and ``item['rejected']`` are lists of
        ``{role, content}`` messages.
        """
        item = self.data[index]
        x_chosen, y_chosen, mask_chosen = self._encode_messages(item['chosen'])
        x_rejected, y_rejected, mask_rejected = self._encode_messages(item['rejected'])
        return {
            'x_chosen': x_chosen,
            'y_chosen': y_chosen,
            'mask_chosen': mask_chosen,
            'x_rejected': x_rejected,
            'y_rejected': y_rejected,
            'mask_rejected': mask_rejected
        }

    def _generate_loss_mask(self, input_ids):
        """Return the 0/1 assistant-span mask for ``input_ids``."""
        return _assistant_loss_mask(input_ids, self.bos_id, self.eos_id, self.max_length)


class TriplePretrainDataset(Dataset):
    """Triple dataset used for predicate classification.

    - every sample keeps exactly one target triple;
    - preprocessing results are cached next to the data file;
    - texts are tokenized lazily in ``__getitem__``.
    """

    def __init__(self, data_path=None, predicate_vocab_path=None, samples=None, tokenizer=None, max_length=512):
        """Build the dataset from a raw file or from ready-made samples.

        Args:
            data_path: raw data file (``.json`` / ``.jsonl``). Also required
                when ``samples`` is given, because the predicate-to-id cache
                is located relative to it.
            predicate_vocab_path: JSON file with predicate statistics; only
                read when ``samples`` is ``None``.
            samples: already preprocessed samples (e.g. the validation split).
            tokenizer: tokenizer used in ``__getitem__``.
            max_length: padded / truncated sequence length.

        Raises:
            ValueError: if ``samples`` is given without ``data_path``.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.val_samples = None
        self.predicate_to_id = {}

        if samples is None:
            self.predicate_vocab = self.load_predicate_vocab(predicate_vocab_path)
            print("🚀 开始加载和预处理三元组数据...")
            self.samples, self.val_samples = self.load_and_preprocess_data(data_path)
            print("🚀 加载和预处理三元组数据完成")
        else:
            if data_path is None:
                raise ValueError(
                    "data_path is required together with samples: it locates the cached predicate_to_id file"
                )
            _, cache_files = self._cache_paths(data_path)
            self.predicate_to_id = self.load_predicate_vocab(cache_files['predicate_to_id'])
            self.samples = samples
            print("🚀 加载和预处理三元组数据完成")

    @staticmethod
    def _cache_paths(data_path):
        """Return ``(cache_dir, {name: file_path})`` for a data file.

        The file stem is everything before the first dot of the base name.
        """
        cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
        data_filename = os.path.basename(data_path).split('.')[0]
        names = ('predicate_vocab', 'predicate_to_id', 'train_samples', 'val_samples')
        cache_files = {
            name: os.path.join(cache_dir, f'{data_filename}_{name}.json')
            for name in names
        }
        return cache_dir, cache_files

    def load_predicate_vocab(self, path):
        """Load a JSON file and return its content."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_val_samples(self):
        """Return the validation split (``None`` if built from ``samples``)."""
        return self.val_samples

    def clear_cache(self, data_path):
        """Delete the cache files of ``data_path`` and the cache dir if empty."""
        cache_dir, cache_files = self._cache_paths(data_path)
        for cache_file in cache_files.values():
            if os.path.exists(cache_file):
                os.remove(cache_file)
                print(f"🗑️ 已删除缓存文件: {cache_file}")

        if os.path.exists(cache_dir) and not os.listdir(cache_dir):
            os.rmdir(cache_dir)
            print(f"🗑️ 已删除空的缓存目录: {cache_dir}")

    @staticmethod
    def _read_raw_data(path):
        """Read the raw ``.json`` or ``.jsonl`` data file.

        Raises:
            ValueError: for any other file extension.
        """
        if path.endswith('.json'):
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        if path.endswith('.jsonl'):
            return _read_jsonl(path)
        raise ValueError(f"Unsupported file format: {path}")

    def _load_from_cache(self, cache_files):
        """Restore vocab, id mapping and both splits from the cache files."""
        with open(cache_files['predicate_vocab'], 'r', encoding='utf-8') as f:
            self.predicate_vocab = json.load(f)
        with open(cache_files['predicate_to_id'], 'r', encoding='utf-8') as f:
            self.predicate_to_id = json.load(f)
        with open(cache_files['train_samples'], 'r', encoding='utf-8') as f:
            train_samples = json.load(f)
        with open(cache_files['val_samples'], 'r', encoding='utf-8') as f:
            val_samples = json.load(f)
        return train_samples, val_samples

    def _save_to_cache(self, cache_files, train_samples, val_samples):
        """Write vocab, id mapping and both splits to the cache files."""
        payloads = {
            'predicate_vocab': self.predicate_vocab,
            'predicate_to_id': self.predicate_to_id,
            'train_samples': train_samples,
            'val_samples': val_samples,
        }
        for name, payload in payloads.items():
            with open(cache_files[name], 'w', encoding='utf-8') as f:
                json.dump(payload, f, ensure_ascii=False, indent=2)

    def load_and_preprocess_data(self, path, val_size=1000, seed=42):
        """Load, filter, validate and split the triple data.

        Results are cached in a ``cache`` directory next to ``path``. If all
        cache files exist they are returned as they are, regardless of
        ``val_size`` / ``seed``; call ``clear_cache`` to force a rebuild.

        Args:
            path: raw data file (``.json`` or ``.jsonl``).
            val_size: maximum number of samples in the validation split.
            seed: seed of the private RNG used for the split.

        Returns:
            ``(train_samples, val_samples)``.

        Raises:
            ValueError: if the file extension is not supported.
        """
        cache_dir, cache_files = self._cache_paths(path)
        os.makedirs(cache_dir, exist_ok=True)

        if all(os.path.exists(cache_file) for cache_file in cache_files.values()):
            print("📁 发现缓存文件,直接加载...")
            train_samples, val_samples = self._load_from_cache(cache_files)
            print("✅ 从缓存加载完成:")
            print(f"✅ 谓词词表大小: {len(self.predicate_vocab)}")
            print(f"✅ 训练集大小: {len(train_samples)}")
            print(f"✅ 测试集大小: {len(val_samples)}")
            return train_samples, val_samples

        print("📂 缓存不存在,开始加载和处理原始数据...")

        # 1. Load the raw data.
        print("📂 加载原始数据...")
        data = self._read_raw_data(path)
        print(f"📊 原始数据量: {len(data)} 个样本")

        # 2. Keep predicates that reach the minimum share. Entries that are not
        #    in statistics format are assumed to be valid.
        print("🔍 过滤低频谓词数据...")
        print(f"📊 谓词统计数据: 总共{len(self.predicate_vocab)}个谓词")
        filtered_predicate_vocab = {
            predicate: stats
            for predicate, stats in self.predicate_vocab.items()
            if not (isinstance(stats, dict) and 'percentage' in stats)
            or stats['percentage'] >= _MIN_PREDICATE_PERCENTAGE
        }
        valid_predicates = set(filtered_predicate_vocab)
        print(f"📊 占比≥0.01%的谓词: {len(valid_predicates)}个")

        # 3. Remove low-frequency triples; samples left without any are dropped.
        original_count = len(data)
        print("🚀 开始过滤低频谓词数据...")
        filtered_data = []
        for sample in tqdm(data, desc="过滤低频谓词"):
            result = process_sample_filter((sample, valid_predicates))
            if result is not None:
                filtered_data.append(result)
        data = filtered_data
        print(f"✅ 过滤完成: 去除前{original_count}条,去除后{len(data)}条")

        # 4. Shrink the vocabulary and assign consecutive class ids.
        print("🔍 更新谓词词表并创建序号映射...")
        original_vocab_size = len(self.predicate_vocab)
        self.predicate_to_id = {
            predicate: idx for idx, predicate in enumerate(filtered_predicate_vocab)
        }
        self.predicate_vocab = filtered_predicate_vocab
        print(f"✅ 谓词词表更新: 去除前{original_vocab_size}个,去除后{len(self.predicate_vocab)}个")
        print(f"✅ 谓词映射创建: {len(self.predicate_to_id)}个谓词对应序号")

        # 5. Validate samples and keep one target each, preferring rare
        #    predicates to balance the classes.
        print("🔍 验证数据格式并选择单个target(平衡数据)...")
        print("🚀 开始验证数据格式...")
        valid_samples = []
        for sample in tqdm(data, desc="验证数据格式"):
            result = process_sample_validation((sample, self.predicate_vocab))
            if result is not None:
                valid_samples.append(result)
        print(f"✅ 有效样本数: {len(valid_samples)}")

        # 6. Train / validation split. A private RNG keeps the global random
        #    state untouched. Samples identical in content to a validation
        #    sample are excluded from training as well; membership is tested
        #    through a set of fingerprints instead of a linear scan per sample.
        rng = random.Random(seed)
        val_samples = rng.sample(valid_samples, min(val_size, len(valid_samples)))
        held_out = {_sample_key(sample) for sample in val_samples}
        train_samples = [
            sample for sample in valid_samples if _sample_key(sample) not in held_out
        ]
        print(f"✅ 训练集大小: {len(train_samples)}")
        print(f"✅ 测试集大小: {len(val_samples)}")

        # 7. Persist everything for the next run.
        print("💾 保存处理结果到缓存文件...")
        self._save_to_cache(cache_files, train_samples, val_samples)
        print("✅ 缓存文件保存完成")

        return train_samples, val_samples

    def __len__(self):
        return len(self.samples)

    def _triple_to_sentence(self, triple):
        """Render a triple as ``"subject predicate object"``."""
        return f"{triple['subject']} {triple['predicate']} {triple['object']}"

    def __getitem__(self, index):
        """Return one item for the predicate classification task.

        Returns:
            dict with ``input_ids`` (length ``max_length - 1``), ``labels``
            (scalar class id) and ``loss_mask`` (boolean, True on non-padding
            positions of the shifted sequence).
        """
        sample = self.samples[index]

        # Tokenization happens at access time.
        input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze(0)
        loss_mask = (input_ids != self.tokenizer.pad_token_id)

        # Predicates missing from the mapping (e.g. the placeholder triple)
        # fall back to class 0 instead of producing a None label.
        target_predicate = sample['target']['predicate']
        predicate_label = self.predicate_to_id.get(target_predicate, 0)

        return {
            'input_ids': input_ids[:-1],
            'labels': torch.tensor(predicate_label, dtype=torch.long),
            'loss_mask': loss_mask[1:]
        }


class RLAIFDataset(Dataset):
    """Dataset yielding ``prompt`` / ``answer`` string pairs for RLAIF."""

    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, path):
        """Load all records of the JSON-Lines file at ``path``."""
        return _read_jsonl(path)

    def _create_chat_prompt(self, conversations):
        """Return ``(prompt, answer)`` for a conversation.

        The prompt is the chat-template rendering of every turn except the
        last one, with a generation prompt appended; the answer is the content
        of the last turn (empty string for an empty conversation).
        """
        messages = []
        answer = ''
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
            answer = turn['content']
        prompt = self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True
        )
        return prompt, answer

    def __getitem__(self, index):
        """Return ``{'prompt': str, 'answer': str}`` for one sample."""
        sample = self.samples[index]
        prompt, answer = self._create_chat_prompt(sample['conversations'])
        return {
            'prompt': prompt,
            'answer': answer
        }


if __name__ == "__main__":
    pass