2024-08-28 16:41:44 +08:00
|
|
|
|
import json
|
|
|
|
|
import random
|
|
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
import pandas as pd
|
|
|
|
|
import numpy as np
|
|
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
|
|
import torch
|
|
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
|
import os
|
2025-02-09 23:49:47 +08:00
|
|
|
|
import ast
|
2024-09-14 14:05:41 +08:00
|
|
|
|
|
2024-09-20 17:04:16 +08:00
|
|
|
|
# Set at import time, before any dataset below invokes a tokenizer.
# NOTE(review): presumably this turns off the Hugging Face `tokenizers`
# library's internal thread pool to avoid its fork warning when DataLoader
# worker processes are used -- confirm against the training scripts.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
|
2024-08-28 16:41:44 +08:00
|
|
|
|
|
|
|
|
|
class PretrainDataset(Dataset):
    """Next-token-prediction dataset backed by a JSONL file.

    Every non-blank line of the file must be a JSON object with a ``text``
    field. Items are returned as ``(X, Y, loss_mask)`` where ``Y`` is the
    encoded sequence shifted left by one token relative to ``X`` and
    ``loss_mask`` is 1 wherever the target token differs from the
    tokenizer's pad token id.
    """

    def __init__(self, data_path, tokenizer, max_length=512):
        """Remember the tokenizer settings and load all samples into memory.

        Args:
            data_path: Path to the UTF-8 encoded JSONL file.
            tokenizer: Callable tokenizer exposing ``bos_token``,
                ``eos_token`` and ``pad_token_id``.
            max_length: Length each sample is padded/truncated to.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(data_path)

    def load_data(self, path):
        """Parse a JSONL file into a list of decoded JSON values.

        Blank (or whitespace-only) lines are skipped so that a trailing
        newline or an accidental empty line does not abort loading.

        Args:
            path: Path to the UTF-8 encoded JSONL file.

        Returns:
            A list with one decoded object per non-blank line, in file order.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON. The
                message is prefixed with ``<path>:<line number>:``.
        """
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    samples.append(json.loads(line))
                except json.JSONDecodeError as exc:
                    # Same exception type as before, but now says where.
                    raise json.JSONDecodeError(
                        f"{path}:{line_num}: {exc.msg}", exc.doc, exc.pos
                    ) from exc
        return samples

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Encode one sample for next-token prediction.

        Args:
            index: Position of the sample in ``self.samples``.

        Returns:
            A tuple ``(X, Y, loss_mask)`` of ``torch.long`` tensors. ``X`` is
            the encoded sequence without its last token, ``Y`` the sequence
            without its first token, and ``loss_mask`` is aligned with ``Y``.
        """
        sample = self.samples[index]

        # Build the input text: raw sample text wrapped in BOS/EOS tokens.
        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Assumes the tokenizer returns a (1, seq_len) tensor for a single
        # string. squeeze(0) removes only that leading axis; a bare squeeze()
        # would also collapse the sequence axis when max_length == 1.
        input_ids = encoding.input_ids.squeeze(0)
        loss_mask = (input_ids != self.tokenizer.pad_token_id)

        # Copy via clone() rather than torch.tensor(tensor), which PyTorch
        # warns about when the argument is already a tensor.
        X = input_ids[:-1].detach().clone().to(torch.long)
        Y = input_ids[1:].detach().clone().to(torch.long)
        # Drop the first mask entry so the mask lines up with Y.
        loss_mask = loss_mask[1:].to(torch.long)
        return X, Y, loss_mask
|
2024-08-28 16:41:44 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SFTDataset(Dataset):
    """Supervised fine-tuning dataset over JSONL conversation records.

    Every non-blank line of the file must be a JSON object with a
    ``conversations`` list whose entries carry a ``content`` field. Items are
    returned as ``(X, Y, loss_mask)`` where the mask selects the spans that
    follow each ``<s>assistant`` marker.
    """

    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        """Load the samples and pre-tokenize the span markers.

        Args:
            jsonl_path: Path to the UTF-8 encoded JSONL file.
            tokenizer: Callable tokenizer that also provides
                ``apply_chat_template`` and ``pad_token_id``.
            max_length: Length each sample is truncated/padded to.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        # Token-id sequences delimiting an assistant reply in the rendered
        # prompt; used by _generate_loss_mask.
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def load_data(self, path):
        """Parse a JSONL file into a list of decoded JSON values.

        Blank (or whitespace-only) lines are skipped so that a trailing
        newline or an accidental empty line does not abort loading.

        Args:
            path: Path to the UTF-8 encoded JSONL file.

        Returns:
            A list with one decoded object per non-blank line, in file order.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON. The
                message is prefixed with ``<path>:<line number>:``.
        """
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    samples.append(json.loads(line))
                except json.JSONDecodeError as exc:
                    raise json.JSONDecodeError(
                        f"{path}:{line_num}: {exc.msg}", exc.doc, exc.pos
                    ) from exc
        return samples

    def _create_chat_prompt(self, conversations):
        """Render a conversation into a single chat-template string.

        Roles are assigned purely by position (even index -> ``user``, odd
        index -> ``assistant``); any ``role`` key on a turn is ignored.

        Args:
            conversations: Sequence of mappings, each with a ``content`` key.

        Returns:
            The untokenized string produced by the tokenizer's chat template,
            without a trailing generation prompt.
        """
        messages = []
        for i, turn in enumerate(conversations):
            role = 'user' if i % 2 == 0 else 'assistant'
            messages.append({"role": role, "content": turn['content']})
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )

    def _generate_loss_mask(self, input_ids):
        """Build a 0/1 mask selecting the assistant-reply positions.

        For every occurrence of ``self.bos_id`` the mask is set to 1 from one
        token past the end of the marker up to and including the matching
        ``self.eos_id`` plus one further token. If no end marker follows, the
        span runs to the end of the sequence.
        NOTE(review): the one-token offsets presumably account for the newline
        the chat template emits after each marker -- confirm against the
        template.

        Args:
            input_ids: List of token ids.

        Returns:
            A list of ints (0 or 1) with the same length as ``input_ids``.
        """
        total = len(input_ids)
        bos_len = len(self.bos_id)
        eos_len = len(self.eos_id)
        loss_mask = [0] * total
        i = 0
        while i < total:
            if input_ids[i:i + bos_len] != self.bos_id:
                i += 1
                continue
            start = i + bos_len
            end = start
            while end < total and input_ids[end:end + eos_len] != self.eos_id:
                end += 1
            # Clamp to the real sequence length as well as max_length so a
            # sequence shorter than max_length cannot index out of range.
            stop = min(end + eos_len + 1, self.max_length, total)
            for j in range(start + 1, stop):
                loss_mask[j] = 1
            i = end + eos_len if end < total else total
        return loss_mask

    def __getitem__(self, index):
        """Encode one conversation for supervised fine-tuning.

        Args:
            index: Position of the sample in ``self.samples``.

        Returns:
            A tuple ``(X, Y, loss_mask)`` of ``torch.long`` tensors, each of
            length ``max_length - 1``; ``Y`` is the sequence shifted left by
            one relative to ``X`` and ``loss_mask`` is aligned with ``Y``.
        """
        sample = self.samples[index]
        # Build the conversation prompt.
        prompt = self._create_chat_prompt(sample['conversations'])
        input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
        # Fall back to 0 when the tokenizer defines no pad token, matching
        # the fallback DPODataset uses; a None pad id cannot be tensorized.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 0
        input_ids = input_ids + [pad_id] * (self.max_length - len(input_ids))

        # Generate the dynamic loss mask.
        loss_mask = self._generate_loss_mask(input_ids)

        # Build the training tensors.
        X = torch.tensor(input_ids[:-1], dtype=torch.long)
        Y = torch.tensor(input_ids[1:], dtype=torch.long)
        loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)  # align with predicted positions

        return X, Y, loss_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DPODataset(Dataset):
    """Preference-pair dataset over JSONL records for DPO training.

    Every non-blank line of the file must be a JSON object with ``chosen``
    and ``rejected`` entries, each a list of ``{role, content}`` messages.
    Items are returned as a dict of shifted input/target/mask tensors for
    both members of the pair.
    """

    def __init__(self, file_path, tokenizer, max_length=4096):
        """Load the records and pre-tokenize the span markers.

        Args:
            file_path: Path to the UTF-8 encoded JSONL file.
            tokenizer: Callable tokenizer that also provides
                ``apply_chat_template`` and ``pad_token_id``.
            max_length: Length each sequence is truncated/padded to.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Pad id with a 0 fallback; kept as a public attribute for callers.
        self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        # Token-id sequences delimiting an assistant reply in the rendered
        # prompt; used by _generate_loss_mask.
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
        self.data = self.load_data(file_path)

    def load_data(self, path):
        """Parse a JSONL file into a list of decoded JSON values.

        Blank (or whitespace-only) lines are skipped so that a trailing
        newline or an accidental empty line does not abort loading.

        Args:
            path: Path to the UTF-8 encoded JSONL file.

        Returns:
            A list with one decoded object per non-blank line, in file order.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON. The
                message is prefixed with ``<path>:<line number>:``.
        """
        records = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError as exc:
                    raise json.JSONDecodeError(
                        f"{path}:{line_num}: {exc.msg}", exc.doc, exc.pos
                    ) from exc
        return records

    def __len__(self):
        """Return the number of loaded preference pairs."""
        return len(self.data)

    def _encode_messages(self, messages):
        """Render, tokenize and mask one side of a preference pair.

        Args:
            messages: List of ``{role, content}`` dicts.

        Returns:
            A tuple ``(x, y, mask)`` of ``torch.long`` tensors where ``y`` is
            the token sequence shifted left by one relative to ``x`` and
            ``mask`` is aligned with ``y``.
        """
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        encoding = self.tokenizer(
            prompt, truncation=True, max_length=self.max_length, padding='max_length'
        )
        input_ids = encoding['input_ids']
        loss_mask = self._generate_loss_mask(input_ids)
        x = torch.tensor(input_ids[:-1], dtype=torch.long)
        y = torch.tensor(input_ids[1:], dtype=torch.long)
        mask = torch.tensor(loss_mask[1:], dtype=torch.long)
        return x, y, mask

    def __getitem__(self, index):
        """Encode the chosen and rejected conversations of one record.

        Args:
            index: Position of the record in ``self.data``.

        Returns:
            A dict with keys ``x_chosen``, ``y_chosen``, ``mask_chosen``,
            ``x_rejected``, ``y_rejected`` and ``mask_rejected``, each a
            ``torch.long`` tensor.
        """
        item = self.data[index]
        chosen = item['chosen']  # a list of {role, content} messages
        rejected = item['rejected']  # same structure as above

        x_chosen, y_chosen, mask_chosen = self._encode_messages(chosen)
        x_rejected, y_rejected, mask_rejected = self._encode_messages(rejected)

        return {
            'x_chosen': x_chosen,
            'y_chosen': y_chosen,
            'mask_chosen': mask_chosen,
            'x_rejected': x_rejected,
            'y_rejected': y_rejected,
            'mask_rejected': mask_rejected
        }

    def _generate_loss_mask(self, input_ids):
        """Build a 0/1 mask selecting the assistant-reply positions.

        For every occurrence of ``self.bos_id`` the mask is set to 1 from one
        token past the end of the marker up to and including the matching
        ``self.eos_id`` plus one further token. If no end marker follows, the
        span runs to the end of the sequence.
        NOTE(review): the one-token offsets presumably account for the newline
        the chat template emits after each marker -- confirm against the
        template.

        Args:
            input_ids: List of token ids.

        Returns:
            A list of ints (0 or 1) with the same length as ``input_ids``.
        """
        total = len(input_ids)
        bos_len = len(self.bos_id)
        eos_len = len(self.eos_id)
        loss_mask = [0] * total
        i = 0
        while i < total:
            if input_ids[i:i + bos_len] != self.bos_id:
                i += 1
                continue
            start = i + bos_len
            end = start
            while end < total and input_ids[end:end + eos_len] != self.eos_id:
                end += 1
            # Clamp to the real sequence length as well as max_length so a
            # sequence shorter than max_length cannot index out of range.
            stop = min(end + eos_len + 1, self.max_length, total)
            for j in range(start + 1, stop):
                loss_mask[j] = 1
            i = end + eos_len if end < total else total
        return loss_mask
|
2024-09-20 17:04:16 +08:00
|
|
|
|
|
2024-08-28 16:41:44 +08:00
|
|
|
|
|
2025-04-05 16:06:08 +08:00
|
|
|
|
class RLAIFDataset(Dataset):
    """Prompt/answer dataset over JSONL conversation records.

    Every non-blank line of the file must be a JSON object with a
    ``conversations`` list whose entries carry a ``content`` field. Items are
    returned as ``{'prompt': str, 'answer': str}`` where the prompt is the
    conversation without its last turn and the answer is that last turn.
    """

    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        """Load the samples and pre-tokenize the span markers.

        Args:
            jsonl_path: Path to the UTF-8 encoded JSONL file.
            tokenizer: Callable tokenizer that also provides
                ``apply_chat_template``.
            max_length: Stored on the instance; not used by the methods of
                this class.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        # Marker token ids; stored for callers, not read by this class.
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def load_data(self, path):
        """Parse a JSONL file into a list of decoded JSON values.

        Blank (or whitespace-only) lines are skipped so that a trailing
        newline or an accidental empty line does not abort loading.

        Args:
            path: Path to the UTF-8 encoded JSONL file.

        Returns:
            A list with one decoded object per non-blank line, in file order.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON. The
                message is prefixed with ``<path>:<line number>:``.
        """
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    samples.append(json.loads(line))
                except json.JSONDecodeError as exc:
                    raise json.JSONDecodeError(
                        f"{path}:{line_num}: {exc.msg}", exc.doc, exc.pos
                    ) from exc
        return samples

    def _create_chat_prompt(self, conversations):
        """Split a conversation into a rendered prompt and its final answer.

        Roles are assigned purely by position (even index -> ``user``, odd
        index -> ``assistant``); any ``role`` key on a turn is ignored.

        Args:
            conversations: Sequence of mappings, each with a ``content`` key.

        Returns:
            A tuple ``(prompt, answer)``. ``prompt`` is the chat-template
            string for every turn except the last, rendered with a generation
            prompt appended; ``answer`` is the last turn's content, or ``''``
            when the conversation is empty.
        """
        messages = [
            {"role": 'user' if i % 2 == 0 else 'assistant', "content": turn['content']}
            for i, turn in enumerate(conversations)
        ]
        # The last turn is held out as the reference answer.
        answer = messages[-1]['content'] if messages else ''
        prompt = self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True
        )
        return prompt, answer

    def __getitem__(self, index):
        """Return the prompt/answer pair for one sample.

        Args:
            index: Position of the sample in ``self.samples``.

        Returns:
            A dict with the string entries ``prompt`` and ``answer``.
        """
        sample = self.samples[index]
        # Build the conversation prompt.
        prompt, answer = self._create_chat_prompt(sample['conversations'])

        return {
            'prompt': prompt,
            'answer': answer
        }
|
|
|
|
|
|
|
|
|
|
|
2024-08-28 16:41:44 +08:00
|
|
|
|
if __name__ == "__main__":
    # Deliberate no-op: running this module directly does nothing.
    pass
|