update model/dataset.py
This commit is contained in:
parent
8a407ad1c6
commit
ecf6d44133
@ -8,6 +8,7 @@ from torch.utils.data import Dataset, DataLoader
|
|||||||
import torch
|
import torch
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
@ -68,10 +69,17 @@ class SFTDataset(Dataset):
|
|||||||
last_index = i
|
last_index = i
|
||||||
return last_index
|
return last_index
|
||||||
|
|
||||||
|
def safe_eval(self, s):
|
||||||
|
try:
|
||||||
|
res = eval(s)
|
||||||
|
except Exception as e:
|
||||||
|
return []
|
||||||
|
return res
|
||||||
|
|
||||||
def __getitem__(self, index: int):
|
def __getitem__(self, index: int):
|
||||||
#
|
#
|
||||||
sample = self.df.iloc[index]
|
sample = self.df.iloc[index]
|
||||||
history = eval(sample['history'])
|
history = self.safe_eval(sample['history'])
|
||||||
q = sample['q']
|
q = sample['q']
|
||||||
a = sample['a']
|
a = sample['a']
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user