fix chat mask bug

jingyaogong 2025-04-01 13:44:55 +08:00
parent 258507ff89
commit e369b33265


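Note: the chat markers used to locate assistant spans previously carried a trailing newline ('<s>assistant\n' and '</s>\n'). Tokenized in isolation, such strings can yield token-id sequences that never appear verbatim inside a fully tokenized conversation (the '\n' may merge with adjacent tokens), so the subsequence search that builds the loss mask could fail to match and assistant tokens went unsupervised. Dropping the newline presumably makes the marker ids match the ids produced inside real conversations; a sketch of the mask logic these ids feed follows the diff below.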
@@ -58,8 +58,8 @@ class SFTDataset(Dataset):
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.samples = self.load_data(jsonl_path)
-        self.bos_id = tokenizer('<s>assistant\n', add_special_tokens=False).input_ids
-        self.eos_id = tokenizer('</s>\n', add_special_tokens=False).input_ids
+        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
+        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

     def __len__(self):
         return len(self.samples)
@@ -126,8 +126,8 @@ class DPODataset(Dataset):
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.padding = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
-        self.bos_id = tokenizer('<s>assistant\n', add_special_tokens=False).input_ids
-        self.eos_id = tokenizer('</s>\n', add_special_tokens=False).input_ids
+        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
+        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
         with open(file_path, 'r', encoding='utf-8') as f:
             self.data = []
             for line in f:
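
For context, here is a minimal sketch of how bos_id / eos_id are typically used to build a per-token loss mask in datasets like these. The helper name _generate_loss_mask and the exact control flow are assumptions for illustration, not part of this diff.

    def _generate_loss_mask(self, input_ids):
        # Hypothetical helper: supervise only tokens inside assistant turns.
        loss_mask = [0] * len(input_ids)
        i = 0
        while i < len(input_ids):
            # Match the start-of-assistant marker token sequence.
            if input_ids[i:i + len(self.bos_id)] == self.bos_id:
                start = i + len(self.bos_id)
                end = start
                # Advance until the end-of-turn marker or end of sequence.
                while end < len(input_ids):
                    if input_ids[end:end + len(self.eos_id)] == self.eos_id:
                        break
                    end += 1
                # Mark the assistant tokens (and the closing marker) for loss.
                for j in range(start, min(end + len(self.eos_id), self.max_length)):
                    loss_mask[j] = 1
                i = end + len(self.eos_id)
            else:
                i += 1
        return loss_mask

Under this reading, the old '\n'-suffixed markers meant the equality check on the token-id subsequence would rarely fire, leaving loss_mask all zeros for affected samples.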