Minimind/model/dataset.py
import json
import random
import re
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.model_selection import train_test_split
import os
import ast
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Load the predicate vocabulary; must stay consistent with train_extra_accelerate.py
PREDICATE_VOCAB_PATH = '/home/rwkv/RWKV-TS/RETRO_TEST/extract/predicate_vocab.json'
with open(PREDICATE_VOCAB_PATH, 'r', encoding='utf-8') as f:
    PREDICATE_LIST = json.load(f)
PREDICATE2ID = {p: i for i, p in enumerate(PREDICATE_LIST)}
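
# --- Hypothetical addition (not in the original file) ------------------------
# predicate_vocab.json is assumed to be a flat JSON list of predicate strings,
# so PREDICATE2ID maps predicate -> index. The reverse map below is a small
# convenience sketch for decoding a predicted class id back to its predicate.
ID2PREDICATE = {i: p for p, i in PREDICATE2ID.items()}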


class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(data_path)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the input text: <bos> text <eos>
        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)
        # Shift by one position for next-token prediction; input_ids is already a
        # tensor, so slice and clone instead of re-wrapping it with torch.tensor().
        X = input_ids[:-1].clone()
        Y = input_ids[1:].clone()
        loss_mask = loss_mask[1:].long()
        return X, Y, loss_mask
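

# --- Hedged usage sketch (not part of the original training code) ------------
# Wraps PretrainDataset in a DataLoader and checks the next-token alignment:
# Y is X shifted left by one position, and loss_mask zeroes out padding.
# The jsonl path and tokenizer directory are placeholders supplied by the caller.
def _demo_pretrain_shift(data_path, tokenizer_dir, batch_size=2):
    from transformers import AutoTokenizer  # assumed project dependency
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    dataset = PretrainDataset(data_path, tokenizer, max_length=512)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    X, Y, loss_mask = next(iter(loader))
    # The language-modelling shift: X[:, 1:] equals Y[:, :-1].
    assert torch.equal(X[:, 1:], Y[:, :-1])
    print(X.shape, Y.shape, loss_mask.sum(dim=1))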


class SFTDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def _create_chat_prompt(self, conversations):
        """Build a ChatML-style conversation prompt (see the usage sketch after this class)."""
        messages = []
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )

    def _generate_loss_mask(self, input_ids):
        # Only tokens inside '<s>assistant' ... '</s>' spans (the assistant replies)
        # contribute to the loss.
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the conversation prompt
        prompt = self._create_chat_prompt(sample['conversations'])
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
        # Generate the dynamic loss mask
        loss_mask = self._generate_loss_mask(input_ids)
        # Build the training pair
        X = torch.tensor(input_ids[:-1], dtype=torch.long)
        Y = torch.tensor(input_ids[1:], dtype=torch.long)
        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)  # align with the predicted positions
        return X, Y, loss_mask
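

# --- Hedged usage sketch (not part of the original training code) ------------
# Shows how the dynamic loss mask behaves: only tokens of the assistant replies
# (between '<s>assistant' and '</s>') receive loss. The jsonl path and tokenizer
# directory are assumptions supplied by the caller; the tokenizer is assumed to
# ship a chat template.
def _demo_sft_loss_mask(jsonl_path, tokenizer_dir, max_length=512):
    from transformers import AutoTokenizer  # assumed project dependency
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    dataset = SFTDataset(jsonl_path, tokenizer, max_length=max_length)
    X, Y, loss_mask = dataset[0]
    # X/Y are shifted by one token, so Y[t] is the prediction target for X[t].
    print('sequence length:', X.shape[0])
    print('positions with loss (assistant replies only):', int(loss_mask.sum()))
    return X, Y, loss_mask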


class DPODataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=4096):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
        with open(file_path, 'r', encoding='utf-8') as f:
            self.data = []
            for line in f:
                line = line.strip()
                obj = json.loads(line)
                self.data.append(obj)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        chosen = item['chosen']      # a list of {role, content} messages
        rejected = item['rejected']  # same structure as 'chosen'
        chosen_prompt = self.tokenizer.apply_chat_template(
            chosen, tokenize=False, add_generation_prompt=False
        )
        rejected_prompt = self.tokenizer.apply_chat_template(
            rejected, tokenize=False, add_generation_prompt=False
        )
        chosen_encoding = self.tokenizer(
            chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        rejected_encoding = self.tokenizer(
            rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        chosen_input_ids = chosen_encoding['input_ids']
        chosen_loss_mask = self._generate_loss_mask(chosen_input_ids)
        rejected_input_ids = rejected_encoding['input_ids']
        rejected_loss_mask = self._generate_loss_mask(rejected_input_ids)
        x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
        y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
        mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
        x_rejected = torch.tensor(rejected_input_ids[:-1], dtype=torch.long)
        y_rejected = torch.tensor(rejected_input_ids[1:], dtype=torch.long)
        mask_rejected = torch.tensor(rejected_loss_mask[1:], dtype=torch.long)
        return {
            'x_chosen': x_chosen,
            'y_chosen': y_chosen,
            'mask_chosen': mask_chosen,
            'x_rejected': x_rejected,
            'y_rejected': y_rejected,
            'mask_rejected': mask_rejected
        }

    def _generate_loss_mask(self, input_ids):
        # Same span-masking rule as SFTDataset: keep loss only on assistant replies.
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask
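

# --- Hedged usage sketch (not part of the original training code) ------------
# Illustrates the expected jsonl record layout for DPODataset: each line holds a
# 'chosen' and a 'rejected' message list. The temporary file, the toy record, and
# the tokenizer directory are assumptions made for this sketch only.
def _demo_dpo_record(tokenizer_dir):
    import tempfile
    from transformers import AutoTokenizer  # assumed project dependency
    record = {
        "chosen": [
            {"role": "user", "content": "What is 1 + 1?"},
            {"role": "assistant", "content": "1 + 1 equals 2."},
        ],
        "rejected": [
            {"role": "user", "content": "What is 1 + 1?"},
            {"role": "assistant", "content": "I don't know."},
        ],
    }
    with tempfile.NamedTemporaryFile('w', suffix='.jsonl', delete=False, encoding='utf-8') as f:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')
        tmp_path = f.name
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    dataset = DPODataset(tmp_path, tokenizer, max_length=512)
    item = dataset[0]
    print({k: v.shape for k, v in item.items()})  # six aligned tensors
    os.unlink(tmp_path)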


class TriplePretrainDataset(Dataset):
    """
    Optimized triple-extraction pretraining dataset.
    - Keeps only one target triple per sample.
    - Pre-tokenizes all target sentences up front.
    - Shows progress bars while preprocessing.
    (A usage sketch follows this class definition.)
    """
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        print("🚀 Loading and preprocessing triple data...")
        self.samples = self.load_and_preprocess_data(data_path)

    def load_and_preprocess_data(self, path):
        """Load and preprocess the triple data."""
        # 1. Load the raw data
        print("📂 Loading raw data...")
        if path.endswith('.json'):
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        elif path.endswith('.jsonl'):
            data = []
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        data.append(json.loads(line.strip()))
        else:
            raise ValueError(f"Unsupported file format: {path}")
        print(f"📊 Raw data size: {len(data)} samples")

        # 2. Validate the data and keep a single target per sample
        print("🔍 Validating format and selecting a single target...")
        valid_samples = []
        for i, sample in enumerate(tqdm(data, desc="Validating data format")):
            if not isinstance(sample, dict) or 'text' not in sample:
                continue
            targets = sample.get('target', [])
            if not isinstance(targets, list) or len(targets) == 0:
                # No valid target: fall back to a placeholder triple ("no triple found")
                selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
            else:
                # Pick the first well-formed target
                selected_target = None
                for triple in targets:
                    if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
                        selected_target = triple
                        break
                # No well-formed target found: use the placeholder
                if selected_target is None:
                    selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
            valid_samples.append({
                'text': sample['text'],
                'target': selected_target  # keep exactly one target
            })
        print(f"✅ Valid samples: {len(valid_samples)}")

        # 3. Tokenize the target sentences in batches
        print("🔤 Tokenizing target sentences in batches...")
        processed_samples = []
        batch_size = 1000  # 1000 sentences per batch to keep memory usage bounded
        for i in tqdm(range(0, len(valid_samples), batch_size), desc="Tokenizing target sentences"):
            # Current batch
            batch_samples = valid_samples[i:i + batch_size]
            # Target sentences of the current batch
            batch_target_sentences = [self._triple_to_sentence(sample['target']) for sample in batch_samples]
            # Batch-tokenize the current batch
            batch_encodings = self.tokenizer(
                batch_target_sentences,
                max_length=128,  # target sentences are usually short
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            # Assemble the processed samples for this batch
            for j, sample in enumerate(batch_samples):
                processed_samples.append({
                    'text': sample['text'],  # keep the raw text; it is tokenized lazily in __getitem__
                    'target_input_ids': batch_encodings.input_ids[j],
                    'target_attention_mask': batch_encodings.attention_mask[j],
                    'target_sentence': batch_target_sentences[j],  # raw sentence, kept for debugging
                })
        print(f"🎉 Preprocessing done! Processed {len(processed_samples)} samples")
        return processed_samples

    def __len__(self):
        return len(self.samples)

    def _triple_to_sentence(self, triple):
        """Convert a triple into a single sentence."""
        return f"{triple['subject']} {triple['predicate']} {triple['object']}"

    def __getitem__(self, index):
        """Return one sample: the input text is tokenized at access time, the target is
        pre-tokenized, and a predicate_label field is added for predicate classification."""
        sample = self.samples[index]
        # Tokenize the input text on the fly (used for language modelling)
        input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)
        # Build the training pair
        X = input_ids[:-1]
        Y = input_ids[1:]
        loss_mask = loss_mask[1:]
        # Extract the predicate label from the middle token of target_sentence
        predicate_label = 0  # default class 0
        try:
            # target_sentence format: "subject predicate object"
            # (note: this assumes the subject itself contains no whitespace)
            triple_str = sample['target_sentence']
            triple_parts = triple_str.strip().split()
            if len(triple_parts) >= 3:
                predicate = triple_parts[1]
                predicate_label = PREDICATE2ID.get(predicate, 0)
        except Exception:
            predicate_label = 0
        return {
            'input_ids': X,
            'labels': Y,
            'loss_mask': loss_mask,
            'target_input_ids': sample['target_input_ids'],            # already a tensor
            'target_attention_mask': sample['target_attention_mask'],  # already a tensor
            'target_sentence': sample['target_sentence'],              # string, kept for debugging
            'original_text': sample['text'],
            'predicate_label': torch.tensor(predicate_label, dtype=torch.long)
        }
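

# --- Hedged usage sketch (not part of the original training code) ------------
# Illustrates the dict-style batches produced by TriplePretrainDataset: tensor
# fields are stacked by the default collate while string fields become lists.
# The data path and tokenizer directory are placeholders supplied by the caller.
def _demo_triple_batch(data_path, tokenizer_dir, batch_size=2):
    from transformers import AutoTokenizer  # assumed project dependency
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    dataset = TriplePretrainDataset(data_path, tokenizer, max_length=512)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    batch = next(iter(loader))
    print('input_ids:', batch['input_ids'].shape)                  # (B, max_length - 1)
    print('target_input_ids:', batch['target_input_ids'].shape)    # (B, 128)
    print('predicate_label:', batch['predicate_label'])            # class ids into PREDICATE_LIST
    print('target_sentence:', batch['target_sentence'])            # list of B strings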


class RLAIFDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def _create_chat_prompt(self, conversations):
        """Build a ChatML-style prompt from all turns except the last one, and
        return the last turn's content as the reference answer."""
        messages = []
        answer = ''
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
            answer = turn['content']
        return self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True
        ), answer

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the conversation prompt and reference answer
        prompt, answer = self._create_chat_prompt(sample['conversations'])
        return {
            'prompt': prompt,
            'answer': answer
        }


if __name__ == "__main__":
    pass
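    # Hedged smoke test (not in the original file): run one of the sketches above
    # when paths are supplied on the command line, e.g.
    #   python dataset.py <pretrain_jsonl> <tokenizer_dir>
    import sys
    if len(sys.argv) == 3:
        _demo_pretrain_shift(sys.argv[1], sys.argv[2])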