From 4a7c1c49e8eaae1702c2ff58c33c6cbb006648de Mon Sep 17 00:00:00 2001
From: jingyaogong
Date: Sat, 5 Apr 2025 16:06:08 +0800
Subject: [PATCH] update rlaif

---
 model/dataset.py | 45 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/model/dataset.py b/model/dataset.py
index 4336c7e..d67cb8c 100644
--- a/model/dataset.py
+++ b/model/dataset.py
@@ -196,5 +196,50 @@ class DPODataset(Dataset):
         return loss_mask
 
 
+class RLAIFDataset(Dataset):
+    def __init__(self, jsonl_path, tokenizer, max_length=1024):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.samples = self.load_data(jsonl_path)
+        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
+        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids
+
+    def __len__(self):
+        return len(self.samples)
+
+    def load_data(self, path):
+        samples = []
+        with open(path, 'r', encoding='utf-8') as f:
+            for line_num, line in enumerate(f, 1):
+                data = json.loads(line.strip())
+                samples.append(data)
+        return samples
+
+    def _create_chat_prompt(self, conversations):
+        """Build a conversation prompt in ChatML format"""
+        messages = []
+        answer = ''
+        for i, turn in enumerate(conversations):
+            role = 'user' if i % 2 == 0 else 'assistant'
+            messages.append({"role": role, "content": turn['content']})
+            answer = turn['content']
+        return self.tokenizer.apply_chat_template(
+            messages[:-1],
+            tokenize=False,
+            add_generation_prompt=True
+        ), answer
+
+    def __getitem__(self, index):
+        sample = self.samples[index]
+        # Build the conversation prompt
+        prompt, answer = self._create_chat_prompt(sample['conversations'])
+
+        return {
+            'prompt': prompt,
+            'answer': answer
+        }
+
+
 if __name__ == "__main__":
     pass
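
Below is a minimal usage sketch (not part of the patch) showing how the new RLAIFDataset might be consumed. The tokenizer directory and the rlaif.jsonl path are placeholders assumed for illustration, and the script is assumed to run from the repository root.

    # Hypothetical usage sketch; paths are assumptions, not part of the patch.
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    from model.dataset import RLAIFDataset

    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')  # assumed tokenizer dir
    dataset = RLAIFDataset('./dataset/rlaif.jsonl', tokenizer, max_length=1024)  # assumed data path

    # Each item is a dict of raw strings, so the default collate_fn simply
    # groups them into lists of prompts and answers for the RLAIF rollout step.
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    for batch in loader:
        print(batch['prompt'][0])   # chat-templated prompt with the generation tag appended
        print(batch['answer'][0])   # reference answer taken from the last turn
        break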