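Fix chat mask bug: change the assistant BOS/EOS marker strings tokenized for `bos_id` and `eos_id` in SFTDataset and DPODataset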
This commit is contained in:
parent
258507ff89
commit
e369b33265
@@ -58,8 +58,8 @@ class SFTDataset(Dataset):
|
|||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.samples = self.load_data(jsonl_path)
|
self.samples = self.load_data(jsonl_path)
|
||||||
self.bos_id = tokenizer('<s>assistant\n', add_special_tokens=False).input_ids
|
self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
|
||||||
self.eos_id = tokenizer('</s>\n', add_special_tokens=False).input_ids
|
self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.samples)
|
return len(self.samples)
|
||||||
@@ -126,8 +126,8 @@ class DPODataset(Dataset):
|
|||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||||
self.bos_id = tokenizer('<s>assistant\n', add_special_tokens=False).input_ids
|
self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
|
||||||
self.eos_id = tokenizer('</s>\n', add_special_tokens=False).input_ids
|
self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
self.data = []
|
self.data = []
|
||||||
for line in f:
|
for line in f:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user