From e369b33265ce4cd56615c6a2af8666a74118e1f5 Mon Sep 17 00:00:00 2001
From: jingyaogong
Date: Tue, 1 Apr 2025 13:44:55 +0800
Subject: [PATCH] fix chat mask bug

---
 model/dataset.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/model/dataset.py b/model/dataset.py
index 7750789..4336c7e 100644
--- a/model/dataset.py
+++ b/model/dataset.py
@@ -58,8 +58,8 @@ class SFTDataset(Dataset):
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.samples = self.load_data(jsonl_path)
-        self.bos_id = tokenizer('<s>assistant\n', add_special_tokens=False).input_ids
-        self.eos_id = tokenizer('</s>\n', add_special_tokens=False).input_ids
+        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
+        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
 
     def __len__(self):
         return len(self.samples)
@@ -126,8 +126,8 @@ class DPODataset(Dataset):
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
-        self.bos_id = tokenizer('<s>assistant\n', add_special_tokens=False).input_ids
-        self.eos_id = tokenizer('</s>\n', add_special_tokens=False).input_ids
+        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
+        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
         with open(file_path, 'r', encoding='utf-8') as f:
             self.data = []
             for line in f:
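
Note: bos_id and eos_id above are token-id subsequences that the datasets
scan for when building the per-token loss mask, so that only assistant
replies contribute to the SFT/DPO loss. The sketch below is a minimal
illustration of that matching scheme, not code from this repository; the
function name build_loss_mask and the toy token ids are assumptions. It
shows why the marker strings must tokenize to exactly the id sequences
present in the rendered chat template: with a trailing '\n' baked into the
markers, the subsequence can fail to match (for example at the end of a
sample that carries no trailing newline), leaving assistant tokens masked
out.

    def build_loss_mask(input_ids, bos_id, eos_id):
        """Return a 0/1 mask over input_ids, with 1 for tokens inside an
        assistant span delimited by the bos_id ... eos_id sequences."""
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            # Look for the start-of-assistant marker subsequence.
            if input_ids[i:i + len(bos_id)] == bos_id:
                start = i + len(bos_id)
                end = start
                # Scan forward for the end-of-reply marker subsequence.
                while end < len(input_ids):
                    if input_ids[end:end + len(eos_id)] == eos_id:
                        break
                    end += 1
                # Unmask the reply tokens plus the eos marker itself,
                # so the model also learns where to stop.
                for j in range(start, min(end + len(eos_id), len(input_ids))):
                    loss_mask[j] = 1
                i = end + len(eos_id)
            else:
                i += 1
        return loss_mask

    # Toy usage with made-up ids: bos_id = [7, 8], eos_id = [9].
    ids = [1, 2, 7, 8, 3, 4, 9, 5]
    print(build_loss_mask(ids, bos_id=[7, 8], eos_id=[9]))
    # -> [0, 0, 0, 0, 1, 1, 1, 0]

Dropping the trailing '\n' from both marker strings makes bos_id/eos_id
tokenize to subsequences that reliably appear in the encoded conversation,
which appears to be the point of this fix.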