import json
import os

import torch
from torch.utils.data import Dataset
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Load the predicate vocabulary (kept consistent with train_extra_accelerate.py)
PREDICATE_VOCAB_PATH = '/home/rwkv/RWKV-TS/RETRO_TEST/extract/predicate_vocab.json'
with open(PREDICATE_VOCAB_PATH, 'r', encoding='utf-8') as f:
    PREDICATE_LIST = json.load(f)
PREDICATE2ID = {p: i for i, p in enumerate(PREDICATE_LIST)}
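
# Expected predicate_vocab.json layout (an assumption inferred from the
# loading code above, not verified against the actual file): a JSON list of
# predicate strings, e.g. ["出生于", "位于", "任职于"], so that PREDICATE2ID
# maps each predicate string to its position in the list.
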
class PretrainDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(data_path)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]

        # Build the input text with explicit BOS/EOS markers
        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)

        # Shift by one for next-token prediction. input_ids is already a
        # tensor, so slice it directly instead of re-wrapping it with
        # torch.tensor(), which copies and raises a UserWarning.
        X = input_ids[:-1].long()
        Y = input_ids[1:].long()
        loss_mask = loss_mask[1:].long()
        return X, Y, loss_mask

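
# A minimal usage sketch for PretrainDataset, assuming a HuggingFace
# tokenizer that defines bos, eos, and pad tokens; the tokenizer path and
# data path below are placeholders, not values from this project.
def _example_pretrain_loader():
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")  # placeholder
    dataset = PretrainDataset("pretrain_data.jsonl", tokenizer, max_length=512)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    X, Y, loss_mask = next(iter(loader))  # each has shape (8, 511)
    return X, Y, loss_mask
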
class SFTDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def _create_chat_prompt(self, conversations):
        """Build a ChatML-style conversation; turns alternate user/assistant by position."""
        messages = []
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )

    def _generate_loss_mask(self, input_ids):
        """Mask out everything except assistant responses: only tokens between
        each '<s>assistant' marker and its closing '</s>' receive loss."""
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                # Scan forward to the matching end-of-sequence marker
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the conversation prompt
        prompt = self._create_chat_prompt(sample['conversations'])
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))

        # Generate the dynamic loss mask
        loss_mask = self._generate_loss_mask(input_ids)

        # Build the training tensors
        X = torch.tensor(input_ids[:-1], dtype=torch.long)
        Y = torch.tensor(input_ids[1:], dtype=torch.long)
        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)  # align with prediction positions

        return X, Y, loss_mask

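
# A minimal sketch of one SFTDataset jsonl line, inferred from
# _create_chat_prompt: turns alternate user/assistant by position and only
# the "content" field of each turn is read. The texts are illustrative.
def _example_sft_line():
    sample = {
        "conversations": [
            {"content": "What is a knowledge triple?"},
            {"content": "A (subject, predicate, object) fact extracted from text."},
        ]
    }
    return json.dumps(sample, ensure_ascii=False)
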
class DPODataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=4096):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
        with open(file_path, 'r', encoding='utf-8') as f:
            self.data = []
            for line in f:
                line = line.strip()
                obj = json.loads(line)
                self.data.append(obj)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        chosen = item['chosen']  # a list of {role, content} messages
        rejected = item['rejected']  # same structure
        chosen_prompt = self.tokenizer.apply_chat_template(
            chosen, tokenize=False, add_generation_prompt=False
        )
        rejected_prompt = self.tokenizer.apply_chat_template(
            rejected, tokenize=False, add_generation_prompt=False
        )
        chosen_encoding = self.tokenizer(
            chosen_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        rejected_encoding = self.tokenizer(
            rejected_prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )

        chosen_input_ids = chosen_encoding['input_ids']
        chosen_loss_mask = self._generate_loss_mask(chosen_input_ids)

        rejected_input_ids = rejected_encoding['input_ids']
        rejected_loss_mask = self._generate_loss_mask(rejected_input_ids)

        x_chosen = torch.tensor(chosen_input_ids[:-1], dtype=torch.long)
        y_chosen = torch.tensor(chosen_input_ids[1:], dtype=torch.long)
        mask_chosen = torch.tensor(chosen_loss_mask[1:], dtype=torch.long)
        x_rejected = torch.tensor(rejected_input_ids[:-1], dtype=torch.long)
        y_rejected = torch.tensor(rejected_input_ids[1:], dtype=torch.long)
        mask_rejected = torch.tensor(rejected_loss_mask[1:], dtype=torch.long)

        return {
            'x_chosen': x_chosen,
            'y_chosen': y_chosen,
            'mask_chosen': mask_chosen,
            'x_rejected': x_rejected,
            'y_rejected': y_rejected,
            'mask_rejected': mask_rejected
        }

    def _generate_loss_mask(self, input_ids):
        """Same masking rule as SFTDataset: only assistant-response tokens
        (between '<s>assistant' and '</s>') receive loss."""
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                for j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)
            else:
                i += 1
        return loss_mask

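
# A minimal sketch of one DPODataset jsonl line, inferred from __getitem__:
# "chosen" and "rejected" each hold a complete chat-message list that
# apply_chat_template can render. The texts are illustrative.
def _example_dpo_line():
    sample = {
        "chosen": [
            {"role": "user", "content": "Summarize the text."},
            {"role": "assistant", "content": "A concise, faithful summary."},
        ],
        "rejected": [
            {"role": "user", "content": "Summarize the text."},
            {"role": "assistant", "content": "An off-topic answer."},
        ],
    }
    return json.dumps(sample, ensure_ascii=False)
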
class TriplePretrainDataset(Dataset):
    """
    Optimized triple-extraction pretraining dataset:
    - keeps only one target triple per sample
    - pre-tokenizes all target sentences in batches
    - reports progress with tqdm
    """
    def __init__(self, data_path, tokenizer, max_length=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        print("🚀 Loading and preprocessing triple data...")
        self.samples = self.load_and_preprocess_data(data_path)

    def load_and_preprocess_data(self, path):
        """Load and preprocess the triple data."""
        # 1. Load the raw data
        print("📂 Loading raw data...")
        if path.endswith('.json'):
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        elif path.endswith('.jsonl'):
            data = []
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        data.append(json.loads(line.strip()))
        else:
            raise ValueError(f"Unsupported file format: {path}")

        print(f"📊 Raw data size: {len(data)} samples")

        # 2. Validate the data and keep a single target per sample
        print("🔍 Validating sample format and selecting one target...")
        valid_samples = []

        for sample in tqdm(data, desc="Validating samples"):
            if not isinstance(sample, dict) or 'text' not in sample:
                continue

            targets = sample.get('target', [])
            if not isinstance(targets, list) or len(targets) == 0:
                # No usable target: fall back to the placeholder triple
                # "没有 发现 三元组" ("no triple found")
                selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}
            else:
                # Pick the first structurally valid target
                selected_target = None
                for triple in targets:
                    if isinstance(triple, dict) and all(key in triple for key in ['subject', 'predicate', 'object']):
                        selected_target = triple
                        break

                # No valid target found: use the placeholder
                if selected_target is None:
                    selected_target = {"subject": "没有", "predicate": "发现", "object": "三元组"}

            valid_samples.append({
                'text': sample['text'],
                'target': selected_target  # keep exactly one target
            })

        print(f"✅ Valid samples: {len(valid_samples)}")

        # 3. Tokenize the target sentences in batches
        print("🔤 Tokenizing target sentences in batches...")

        processed_samples = []
        batch_size = 1000  # process 1000 sentences per batch to bound memory use

        for i in tqdm(range(0, len(valid_samples), batch_size), desc="Tokenizing target sentences"):
            # Slice out the current batch
            batch_samples = valid_samples[i:i + batch_size]

            # Render the batch's target triples as sentences
            batch_target_sentences = [self._triple_to_sentence(sample['target']) for sample in batch_samples]

            # Tokenize the whole batch at once
            batch_encodings = self.tokenizer(
                batch_target_sentences,
                max_length=128,  # target sentences are usually short
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            # Assemble the processed samples for this batch
            for j, sample in enumerate(batch_samples):
                processed_samples.append({
                    'text': sample['text'],  # keep the raw text; tokenized at __getitem__ time
                    'target_input_ids': batch_encodings.input_ids[j],
                    'target_attention_mask': batch_encodings.attention_mask[j],
                    'target_sentence': batch_target_sentences[j],  # kept for debugging
                })

        print(f"🎉 Preprocessing done! Processed {len(processed_samples)} samples")
        return processed_samples

    def __len__(self):
        return len(self.samples)

    def _triple_to_sentence(self, triple):
        """Render a triple as a space-separated 'subject predicate object' sentence."""
        return f"{triple['subject']} {triple['predicate']} {triple['object']}"

    def __getitem__(self, index):
        """Return one sample: the input text is tokenized at runtime, the target
        is pre-tokenized, and a predicate_label field is added."""
        sample = self.samples[index]
        # Tokenize the input text at runtime (for language modeling)
        input_text = f"{self.tokenizer.bos_token}{sample['text']}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding.input_ids.squeeze()
        loss_mask = (input_ids != self.tokenizer.pad_token_id)
        # Build the training tensors
        X = input_ids[:-1]
        Y = input_ids[1:]
        loss_mask = loss_mask[1:]
        # Extract the predicate label from target_sentence,
        # which has the form "subject predicate object"
        predicate_label = 0  # default to 0
        try:
            triple_str = sample['target_sentence']
            triple_parts = triple_str.strip().split()
            if len(triple_parts) >= 3:
                predicate = triple_parts[1]
                predicate_label = PREDICATE2ID.get(predicate, 0)
        except Exception:
            predicate_label = 0
        return {
            'input_ids': X,
            'labels': Y,
            'loss_mask': loss_mask,
            'target_input_ids': sample['target_input_ids'],  # already a tensor
            'target_attention_mask': sample['target_attention_mask'],  # already a tensor
            'target_sentence': sample['target_sentence'],  # string, for debugging
            'original_text': sample['text'],
            'predicate_label': torch.tensor(predicate_label, dtype=torch.long)
        }

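
# A minimal usage sketch for TriplePretrainDataset. The default collate_fn
# stacks the tensor fields and gathers the string fields ('target_sentence',
# 'original_text') into lists; paths are placeholders, not project values.
def _example_triple_batch():
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")  # placeholder
    dataset = TriplePretrainDataset("triples.jsonl", tokenizer, max_length=512)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    # batch['input_ids']: (4, 511); batch['target_input_ids']: (4, 128);
    # batch['predicate_label']: (4,)
    return batch
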
class RLAIFDataset(Dataset):
    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        return len(self.samples)

    def load_data(self, path):
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line.strip())
                samples.append(data)
        return samples

    def _create_chat_prompt(self, conversations):
        """Build a ChatML-style prompt: every turn except the last is rendered
        with a generation prompt appended, and the last turn's content is
        returned separately as the reference answer."""
        messages = []
        answer = ''
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
            answer = turn['content']
        return self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True
        ), answer

    def __getitem__(self, index):
        sample = self.samples[index]
        # Build the conversation prompt and reference answer
        prompt, answer = self._create_chat_prompt(sample['conversations'])

        return {
            'prompt': prompt,
            'answer': answer
        }

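
# A minimal usage sketch for RLAIFDataset: each item is a rendered prompt
# plus the reference answer, ready for sampling-based RL pipelines. The
# tokenizer path and data path are placeholders, not project values.
def _example_rlaif_item():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")  # placeholder
    dataset = RLAIFDataset("rlaif.jsonl", tokenizer)
    item = dataset[0]
    return item['prompt'], item['answer']
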
if __name__ == "__main__":
    pass
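    # Smoke-test sketch (commented out; tokenizer and data paths are
    # placeholders, not values from this project):
    # from transformers import AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")
    # ds = PretrainDataset("pretrain_data.jsonl", tokenizer)
    # X, Y, mask = ds[0]
    # print(X.shape, Y.shape, mask.shape)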