Minimind/model/dataset.py


import json
import os
import random

import torch
from torch.utils.data import Dataset
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "true"

def process_sample_filter(data_args):
    """Filtering logic for a single sample."""
    sample, valid_predicates = data_args
    if 'target' in sample and isinstance(sample['target'], list):
        # Drop low-frequency predicates from the target list
        valid_targets = []
        for triple in sample['target']:
            if isinstance(triple, dict) and 'predicate' in triple:
                if triple['predicate'] in valid_predicates:
                    valid_targets.append(triple)
        # Keep the sample only if at least one valid target remains
        if valid_targets:
            sample['target'] = valid_targets
            return sample
        else:
            return None
    else:
        # No target information: keep the sample as-is
        return sample

def process_sample_validation(data_args):
    """Validation logic for a single sample."""
    sample, predicate_vocab = data_args
    if not isinstance(sample, dict) or 'text' not in sample:
        return None
    targets = sample.get('target', [])
    if not isinstance(targets, list) or len(targets) == 0:
        # No valid targets: fall back to a default triple
        # ("没有 发现 三元组" ≈ "no triples found")
        selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
    else:
        # Validate the targets and pick exactly one, preferring the predicate
        # with the smallest share of the corpus to rebalance the data
        selected_target = None
        min_percentage = float('inf')
        for triple in targets:
            if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
                predicate = triple['predicate']
                # Use the statistics stored in predicate_vocab
                if predicate in predicate_vocab:
                    stats = predicate_vocab[predicate]
                    if isinstance(stats, dict) and 'percentage' in stats:
                        percentage = stats['percentage']
                        if percentage < min_percentage:
                            min_percentage = percentage
                            selected_target = triple
                    elif selected_target is None:
                        selected_target = triple
                elif selected_target is None:
                    selected_target = triple
        # Still nothing valid: use the default triple
        if selected_target is None:
            selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
    return {
        'text': sample['text'],
        'target': selected_target  # keep exactly one target
    }
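
# A worked example of the selection rule above, with hypothetical statistics
# (not taken from the real vocabulary): given
#   predicate_vocab = {"出生于": {"percentage": 5.0}, "导演": {"percentage": 0.02}}
# and a sample whose targets contain one triple per predicate, the "导演"
# triple is selected because its percentage is smaller, so rare predicates end
# up over-represented relative to common ones.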

class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(data_path)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the input text
        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)
        # Shift by one position so X predicts Y; slicing the existing tensor
        # avoids the torch.tensor-on-a-tensor copy warning
        X = input_ids[:-1].to(torch.long)
        Y = input_ids[1:].to(torch.long)
        loss_mask = loss_mask[1:].to(torch.long)
        return X, Y, loss_mask

class SFTDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def _create_chat_prompt(self, conversations):
        """Build a ChatML-formatted conversation."""
        messages = []
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )

    def _generate_loss_mask(self, input_ids):
        # Mark only tokens inside assistant responses (between the
        # '<s>assistant' and '</s>' marker token sequences) as loss targets
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask
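
    # A small trace of the masking loop above, under the assumption (purely
    # for illustration) that '<s>assistant' tokenizes to [1, 2] and '</s>'
    # to [3]:
    #   input_ids = [9, 9, 1, 2, 5, 6, 3, 9]
    #   the bos_id match at i=2 gives start=4; the scan finds eos_id at end=6,
    #   so indices start+1 .. end+len(eos_id) (here 5..7) get loss_mask = 1.
    # After __getitem__ shifts everything by one (loss_mask[1:]), these flags
    # line up with the positions whose predictions are assistant-reply tokens.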

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the conversation prompt
        prompt = self._create_chat_prompt(sample['conversations'])
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        # Generate the dynamic loss mask
        loss_mask = self._generate_loss_mask(input_ids)
        # Build the training tensors
        X = torch.tensor(input_ids[:-1], dtype=torch.long)
        Y = torch.tensor(input_ids[1:], dtype=torch.long)
        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)  # align with predicted positions
        return X, Y, loss_mask

class DPODataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=4096):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
        with open(file_path, 'r', encoding='utf-8') as f:
            self.data = []
            for line in f:
                line = line.strip()
                obj = json.loads(line)
                self.data.append(obj)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        chosen = item['chosen']      # a list of {role, content} messages
        rejected = item['rejected']  # same structure
        chosen_prompt = self.tokenizer.apply_chat_template(
            chosen, tokenize=False, add_generation_prompt=False
        )
        rejected_prompt = self.tokenizer.apply_chat_template(
            rejected, tokenize=False, add_generation_prompt=False
        )
        chosen_encoding = self.tokenizer(
            chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        rejected_encoding = self.tokenizer(
            rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        chosen_input_ids = chosen_encoding['input_ids']
        chosen_loss_mask = self._generate_loss_mask(chosen_input_ids)
        rejected_input_ids = rejected_encoding['input_ids']
        rejected_loss_mask = self._generate_loss_mask(rejected_input_ids)
        x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
        y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
        mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
        x_rejected = torch.tensor(rejected_input_ids[:-1], dtype=torch.long)
        y_rejected = torch.tensor(rejected_input_ids[1:], dtype=torch.long)
        mask_rejected = torch.tensor(rejected_loss_mask[1:], dtype=torch.long)
        return {
            'x_chosen': x_chosen,
            'y_chosen': y_chosen,
            'mask_chosen': mask_chosen,
            'x_rejected': x_rejected,
            'y_rejected': y_rejected,
            'mask_rejected': mask_rejected
        }

    def _generate_loss_mask(self, input_ids):
        # Same span-masking logic as SFTDataset: only assistant responses count
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask
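
# Expected shape of one line in the DPO jsonl file, as implied by the
# accessors above (field values are illustrative only):
#   {"chosen":   [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}],
#    "rejected": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}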

class TriplePretrainDataset(Dataset):
    """
    Optimized triple-extraction pretraining dataset.
    - Keeps exactly one target triple per sample
    - Preprocesses (and caches) all data up front
    - Shows progress bars while processing
    """
    def __init__(self, data_path=None, predicate_vocab_path=None, samples=None, tokenizer=None, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.val_samples = None
        self.predicate_to_id = {}  # initialized here, filled during preprocessing
        if samples is None:
            self.predicate_vocab = self.load_predicate_vocab(predicate_vocab_path)
            print("🚀 Loading and preprocessing triple data...")
            self.samples, self.val_samples = self.load_and_preprocess_data(data_path)
            print("🚀 Finished loading and preprocessing triple data")
        else:
            # Prebuilt samples passed in: data_path still locates the cached
            # predicate_to_id mapping
            cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
            data_filename = os.path.basename(data_path).split('.')[0]
            predicate_to_id_path = os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json')
            self.predicate_to_id = self.load_predicate_vocab(predicate_to_id_path)
            self.samples = samples
            print("🚀 Finished loading and preprocessing triple data")

    def load_predicate_vocab(self, path):
        with open(path, 'r', encoding='utf-8') as f:
            predicate_vocab = json.load(f)
        return predicate_vocab

    def get_val_samples(self):
        return self.val_samples

    def clear_cache(self, data_path):
        """Remove all cache files for the given data path."""
        cache_dir = os.path.join(os.path.dirname(data_path), 'cache')
        data_filename = os.path.basename(data_path).split('.')[0]
        cache_files = [
            os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
            os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
            os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
            os.path.join(cache_dir, f'{data_filename}_val_samples.json')
        ]
        for cache_file in cache_files:
            if os.path.exists(cache_file):
                os.remove(cache_file)
                print(f"🗑️ Removed cache file: {cache_file}")
        if os.path.exists(cache_dir) and not os.listdir(cache_dir):
            os.rmdir(cache_dir)
            print(f"🗑️ Removed empty cache directory: {cache_dir}")

    def load_and_preprocess_data(self, path):
        """Load and preprocess the triple data."""
        # Cache file names are derived from the data file path
        cache_dir = os.path.join(os.path.dirname(path), 'cache')
        os.makedirs(cache_dir, exist_ok=True)
        data_filename = os.path.basename(path).split('.')[0]
        cache_files = {
            'predicate_vocab': os.path.join(cache_dir, f'{data_filename}_predicate_vocab.json'),
            'predicate_to_id': os.path.join(cache_dir, f'{data_filename}_predicate_to_id.json'),
            'train_samples': os.path.join(cache_dir, f'{data_filename}_train_samples.json'),
            'val_samples': os.path.join(cache_dir, f'{data_filename}_val_samples.json')
        }
        # If every cache file exists, load from cache and skip preprocessing
        cache_exists = all(os.path.exists(cache_file) for cache_file in cache_files.values())
        if cache_exists:
            print("📁 Cache files found, loading directly...")
            with open(cache_files['predicate_vocab'], 'r', encoding='utf-8') as f:
                self.predicate_vocab = json.load(f)
            with open(cache_files['predicate_to_id'], 'r', encoding='utf-8') as f:
                self.predicate_to_id = json.load(f)
            with open(cache_files['train_samples'], 'r', encoding='utf-8') as f:
                train_samples = json.load(f)
            with open(cache_files['val_samples'], 'r', encoding='utf-8') as f:
                val_samples = json.load(f)
            print("✅ Loaded from cache:")
            print(f"✅ Predicate vocabulary size: {len(self.predicate_vocab)}")
            print(f"✅ Training set size: {len(train_samples)}")
            print(f"✅ Validation set size: {len(val_samples)}")
            return train_samples, val_samples

        # No cache: process the raw data from scratch
        # 1. Load the raw data
        print("📂 No cache found, loading and processing raw data...")
        if path.endswith('.json'):
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        elif path.endswith('.jsonl'):
            data = []
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        data.append(json.loads(line.strip()))
        else:
            raise ValueError(f"Unsupported file format: {path}")
        print(f"📊 Raw data size: {len(data)} samples")

        # 2. Use self.predicate_vocab to drop data whose predicates cover
        #    less than 0.01% of the corpus
        print("🔍 Filtering low-frequency predicate data...")
        print(f"📊 Predicate statistics: {len(self.predicate_vocab)} predicates in total")
        # 3. Collect predicates whose share is >= 0.01%
        valid_predicates = set()
        for predicate, stats in self.predicate_vocab.items():
            if isinstance(stats, dict) and 'percentage' in stats:
                if stats['percentage'] >= 0.01:
                    valid_predicates.add(predicate)
            else:
                # Not in statistics format: assume the predicate is valid
                valid_predicates.add(predicate)
        print(f"📊 Predicates with share >= 0.01%: {len(valid_predicates)}")

        # 4. Filter the data: drop triples with low-frequency predicates
        #    (single process)
        original_count = len(data)
        filtered_data = []
        for sample in tqdm(data, desc="Filtering low-frequency predicates"):
            result = process_sample_filter((sample, valid_predicates))
            if result is not None:
                filtered_data.append(result)
        data = filtered_data
        print(f"✅ Filtering done: {original_count} samples before, {len(data)} after")

        # 5. Drop predicates with share < 0.01% from self.predicate_vocab and
        #    build the predicate-to-index mapping
        print("🔍 Updating predicate vocabulary and building the index mapping...")
        original_vocab_size = len(self.predicate_vocab)
        filtered_predicate_vocab = {}
        for predicate, stats in self.predicate_vocab.items():
            if isinstance(stats, dict) and 'percentage' in stats:
                if stats['percentage'] >= 0.01:
                    filtered_predicate_vocab[predicate] = stats
            else:
                # Not in statistics format: keep it
                filtered_predicate_vocab[predicate] = stats
        self.predicate_to_id = {predicate: idx for idx, predicate in enumerate(filtered_predicate_vocab.keys())}
        self.predicate_vocab = filtered_predicate_vocab
        print(f"✅ Predicate vocabulary updated: {original_vocab_size} before, {len(self.predicate_vocab)} after")
        print(f"✅ Predicate mapping built: {len(self.predicate_to_id)} predicates indexed")

        # 6. Validate samples and keep a single target each, preferring rare
        #    predicates to balance the data (single process)
        print("🔍 Validating data format and selecting one target per sample...")
        valid_samples = []
        for sample in tqdm(data, desc="Validating data format"):
            result = process_sample_validation((sample, self.predicate_vocab))
            if result is not None:
                valid_samples.append(result)
        print(f"✅ Valid samples: {len(valid_samples)}")

        # 7. Split into training and validation sets
        random.seed(42)
        val_samples = random.sample(valid_samples, min(1000, len(valid_samples)))
        train_samples = [sample for sample in valid_samples if sample not in val_samples]
        print(f"✅ Training set size: {len(train_samples)}")
        print(f"✅ Validation set size: {len(val_samples)}")

        # 8. Save the processed results to cache
        print("💾 Saving processed results to cache files...")
        with open(cache_files['predicate_vocab'], 'w', encoding='utf-8') as f:
            json.dump(self.predicate_vocab, f, ensure_ascii=False, indent=2)
        with open(cache_files['predicate_to_id'], 'w', encoding='utf-8') as f:
            json.dump(self.predicate_to_id, f, ensure_ascii=False, indent=2)
        with open(cache_files['train_samples'], 'w', encoding='utf-8') as f:
            json.dump(train_samples, f, ensure_ascii=False, indent=2)
        with open(cache_files['val_samples'], 'w', encoding='utf-8') as f:
            json.dump(val_samples, f, ensure_ascii=False, indent=2)
        print("✅ Cache files saved")
        return train_samples, val_samples
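
    # Cache layout example: for data_path '<dir>/triples.jsonl' (a hypothetical
    # name), the four cache files land in '<dir>/cache/' as
    # 'triples_predicate_vocab.json', 'triples_predicate_to_id.json',
    # 'triples_train_samples.json' and 'triples_val_samples.json'; remove them
    # via clear_cache() to force a full re-preprocess.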

    def __len__(self):
        return len(self.samples)

    def _triple_to_sentence(self, triple):
        """Render a triple as a plain sentence."""
        return f"{triple['subject']} {triple['predicate']} {triple['object']}"

    def __getitem__(self, index):
        """Return one sample for the predicate-classification task."""
        sample = self.samples[index]
        # Tokenize the input text at access time
        input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)
        # Predicate classification label; fall back to 0 for unknown predicates
        target_predicate = sample['target']['predicate']
        predicate_label = self.predicate_to_id.get(target_predicate, 0)
        # Build the training tensors
        X = input_ids[:-1]
        loss_mask = loss_mask[1:]
        return {
            'input_ids': X,
            'labels': torch.tensor(predicate_label, dtype=torch.long),  # predicate class id
            'loss_mask': loss_mask
        }

class RLAIFDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def _create_chat_prompt(self, conversations):
        """Build a ChatML-formatted prompt and return the reference answer."""
        messages = []
        answer = ''
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
            answer = turn['content']  # keep the last turn as the reference answer
        # Drop the final turn from the prompt and ask the model to generate it
        return self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True
        ), answer

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the conversation prompt and reference answer
        prompt, answer = self._create_chat_prompt(sample['conversations'])
        return {
            'prompt': prompt,
            'answer': answer
        }

if __name__ == "__main__":
    pass
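    # A minimal smoke-test sketch, assuming a HuggingFace-style tokenizer and
    # the jsonl file below exist locally (both paths are placeholders, not
    # shipped with this file):
    #
    #   from transformers import AutoTokenizer
    #   tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    #   ds = SFTDataset('./dataset/sft_data.jsonl', tokenizer, max_length=1024)
    #   X, Y, loss_mask = ds[0]
    #   print(X.shape, Y.shape, int(loss_mask.sum()))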