diff --git a/model/dataset.py b/model/dataset.py
index 7750789..4336c7e 100644
--- a/model/dataset.py
+++ b/model/dataset.py
@@ -58,8 +58,8 @@ class SFTDataset(Dataset):
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.samples = self.load_data(jsonl_path)
-        self.bos_id = tokenizer('<s>assistant\n', add_special_tokens=False).input_ids
-        self.eos_id = tokenizer('</s>\n', add_special_tokens=False).input_ids
+        self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
+        self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
 
     def __len__(self):
         return len(self.samples)
@@ -126,8 +126,8 @@ class DPODataset(Dataset):
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
-        self.bos_id = tokenizer('<s>assistant\n', add_special_tokens=False).input_ids
-        self.eos_id = tokenizer('</s>\n', add_special_tokens=False).input_ids
+        self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids
+        self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids
         with open(file_path, 'r', encoding='utf-8') as f:
             self.data = []
             for line in f:
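
Both hunks make the same substitution: the assistant-turn delimiters move from the
old '<s>assistant\n' / '</s>\n' pair to ChatML-style '<|im_start|>assistant' /
'<|im_end|>' markers, matching an updated chat template. For context, below is a
minimal sketch (not code from this diff) of how delimiter ID sequences like
bos_id/eos_id are typically consumed downstream, i.e. to mask the training loss
to assistant spans; build_loss_mask and the toy IDs are illustrative assumptions,
not part of dataset.py.

# Sketch: locate each bos_id ... eos_id span in the token stream and mark its
# tokens with 1s, so only assistant-generated tokens contribute to the SFT loss.
def build_loss_mask(input_ids, bos_id, eos_id):
    mask = [0] * len(input_ids)
    i = 0
    while i < len(input_ids):
        if input_ids[i:i + len(bos_id)] == bos_id:
            start = i + len(bos_id)  # first token after the start delimiter
            j = start
            # Scan forward to the matching end delimiter (or end of input).
            while j < len(input_ids) and input_ids[j:j + len(eos_id)] != eos_id:
                j += 1
            end = min(j + len(eos_id), len(input_ids))
            for k in range(start, end):
                mask[k] = 1
            i = end
        else:
            i += 1
    return mask

# Toy IDs: [1, 2] plays the role of the '<|im_start|>assistant' sequence,
# [9] the '<|im_end|>' sequence; 5, 7, 8 are ordinary content tokens.
print(build_loss_mask([5, 1, 2, 7, 8, 9, 5], bos_id=[1, 2], eos_id=[9]))
# -> [0, 0, 0, 1, 1, 1, 0]

Matching on whole ID sequences rather than single tokens matters here because a
string like '<|im_start|>assistant' generally tokenizes to more than one ID.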