fix chat mask bug
This commit is contained in:
parent
258507ff89
commit
e369b33265
@ -58,8 +58,8 @@ class SFTDataset(Dataset):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.samples = self.load_data(jsonl_path)
|
||||
self.bos_id = tokenizer('<s>assistant\n', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer('</s>\n', add_special_tokens=False).input_ids
|
||||
self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
@ -126,8 +126,8 @@ class DPODataset(Dataset):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||
self.bos_id = tokenizer('<s>assistant\n', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer('</s>\n', add_special_tokens=False).input_ids
|
||||
self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
|
||||
self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
self.data = []
|
||||
for line in f:
|
||||
|
Loading…
x
Reference in New Issue
Block a user