2024-08-28 16:41:44 +08:00
|
|
|
import json
|
|
|
|
import random
|
|
|
|
import re
|
|
|
|
|
|
|
|
import pandas as pd
|
|
|
|
import numpy as np
|
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
|
import torch
|
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
import os
|
2024-09-14 14:05:41 +08:00
|
|
|
|
2024-09-20 17:04:16 +08:00
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
2024-08-28 16:41:44 +08:00
|
|
|
|
|
|
|
class PretrainDataset(Dataset):
    """Dataset for causal-LM pretraining.

    Each row of *df* must provide a ``text`` column.  A sample is the text
    wrapped in BOS/EOS markers, tokenized, truncated/right-padded to
    ``max_length``, and split into a next-token-prediction pair (X, Y)
    where Y is X shifted left by one position.
    """

    def __init__(self, df, tokenizer, max_length=512):
        """
        Args:
            df: pandas DataFrame with a ``text`` column.
            tokenizer: HF-style tokenizer exposing ``bos_token``/``eos_token``
                and returning ``input_ids`` when called on a string.
            max_length: fixed sequence length after truncation/padding.
        """
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Token id used to right-pad short sequences.
        # NOTE(review): hard-coded 0 — assumed to be the tokenizer's pad id;
        # confirm against the tokenizer actually used.
        self.padding = 0

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index: int):
        sample = self.df.iloc[index]
        text = f"{self.tokenizer.bos_token}{str(sample['text'])}{self.tokenizer.eos_token}"
        # Tokenize, then truncate to the fixed length.  Standard mapping
        # access replaces the previous reliance on the non-public
        # BatchEncoding.data attribute.
        input_id = self.tokenizer(text)['input_ids'][:self.max_length]
        # Right-pad whatever remains below max_length.
        padding_len = self.max_length - len(input_id)
        input_id = input_id + [self.padding] * padding_len

        input_id = np.array(input_id)
        # Shift by one token: position t of X predicts position t of Y.
        # (astype() already returns a fresh array; the previous extra
        # np.array(...) wrappers were redundant copies.)
        X = input_id[:-1].astype(np.int64)
        Y = input_id[1:].astype(np.int64)
        return torch.from_numpy(X), torch.from_numpy(Y)
|
2024-08-28 16:41:44 +08:00
|
|
|
|
|
|
|
|
|
|
|
class SFTDataset(Dataset):
    """Dataset for supervised fine-tuning on (history, question, answer) rows.

    Rebuilds the conversation as chat messages, renders it through the
    tokenizer's chat template, tokenizes, and returns (X, Y, loss_mask)
    where the mask enables loss only on tokens after the assistant marker
    (i.e. the answer), excluding padding.
    """

    def __init__(self, df, tokenizer, max_length=1024, prompt_max_len=512, answer_max_len=256):
        """
        Args:
            df: DataFrame with ``history`` (stringified list of [q, a] pairs),
                ``q`` and ``a`` columns.
            tokenizer: HF-style tokenizer with ``apply_chat_template``.
            max_length: fixed total sequence length after truncation/padding.
            prompt_max_len / answer_max_len: kept for interface compatibility;
                not used by the current implementation.
        """
        super().__init__()
        self.df = df
        self.max_length = max_length
        self.prompt_max_len = prompt_max_len
        self.answer_max_len = answer_max_len
        self.tokenizer = tokenizer
        # Pad token id.  NOTE(review): hard-coded 0, assumed to be the
        # tokenizer's pad id — confirm against the tokenizer actually used.
        self.padding = 0
        # Token ids of the marker that precedes the assistant's answer in the
        # rendered chat template; used to locate where the loss should start.
        self.bos_id = self.tokenizer('<s>assistant')['input_ids']

    def __len__(self):
        return self.df.shape[0]

    def find_sublist_index(self, main_list, sub_list) -> int:
        """Return the start index of the LAST occurrence of *sub_list*
        inside *main_list*, or -1 if it does not occur."""
        last_index = -1
        for i in range(len(main_list) - len(sub_list) + 1):
            if main_list[i:i + len(sub_list)] == sub_list:
                last_index = i
        return last_index

    def safe_eval(self, s):
        """Parse a stringified Python literal (the ``history`` column).

        Uses ``ast.literal_eval`` instead of bare ``eval`` so that untrusted
        dataset text can never execute arbitrary code.  Returns [] on any
        parse failure (malformed strings, NaN cells, non-literal input).
        """
        import ast  # local import keeps this security fix self-contained
        try:
            return ast.literal_eval(s)
        except Exception:
            return []

    def __getitem__(self, index: int):
        sample = self.df.iloc[index]
        history = self.safe_eval(sample['history'])
        q = str(sample['q'])
        a = str(sample['a'])

        # Rebuild prior turns as chat messages; skip malformed entries that
        # do not hold both a user and an assistant utterance.
        messages = []
        for history_message in history:
            if len(history_message) <= 1:
                continue
            messages.append(
                {"role": 'user', "content": str(history_message[0])[:self.max_length // 2]}
            )
            messages.append(
                {"role": 'assistant', "content": str(history_message[1])[:self.max_length // 2]}
            )

        messages += [
            {"role": "user", "content": q},
            {"role": "assistant", "content": a},
        ]
        new_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        input_id = self.tokenizer(new_prompt)['input_ids'][:self.max_length]

        # Tokens up to and including the assistant marker count as prompt;
        # loss is disabled for them.
        # NOTE(review): if the marker is absent, find_sublist_index returns -1
        # and the mask is misaligned — original behavior, preserved here.
        question_length = self.find_sublist_index(input_id, self.bos_id) + len(self.bos_id)
        # Right-pad the remainder up to max_length.
        padding_len = self.max_length - len(input_id)
        input_id = input_id + [self.padding] * padding_len
        mask_len = len(input_id) - question_length - padding_len
        # 0 disables loss at that position (prompt and padding).
        loss_mask = [0] * question_length + [1] * (mask_len) + [0] * padding_len

        input_id = np.array(input_id)
        # Next-token shift; astype() already copies, so no extra np.array
        # wrappers are needed.
        X = input_id[:-1].astype(np.int64)
        Y = input_id[1:].astype(np.int64)
        loss_mask = np.array(loss_mask[1:]).astype(np.int64)

        return torch.from_numpy(X), torch.from_numpy(Y), torch.from_numpy(loss_mask)
|
|
|
|
|
2024-08-28 16:41:44 +08:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Module is import-only; no standalone behavior is defined here.
    pass
|