update rlaif
This commit is contained in:
parent
9e67798397
commit
4a7c1c49e8
@ -196,5 +196,50 @@ class DPODataset(Dataset):
|
||||
return loss_mask
|
||||
|
||||
|
||||
class RLAIFDataset(Dataset):
    """Dataset yielding (prompt, answer) string pairs from a JSONL file.

    Each non-blank line of the file is parsed as one JSON object. Items are
    read through ``sample['conversations']``, a sequence of turns that each
    carry a ``'content'`` entry. Roles are assigned purely by position
    (even index -> ``user``, odd index -> ``assistant``); any role stored in
    the data itself is not consulted.
    """

    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        """Load all samples eagerly and cache the marker token ids.

        Args:
            jsonl_path: Path to a UTF-8 encoded JSONL file.
            tokenizer: Callable tokenizer returning an object with
                ``input_ids`` and providing ``apply_chat_template``
                (presumably a Hugging Face tokenizer -- confirm with caller).
            max_length: Stored on the instance; not used by any method
                visible in this class.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        # Token ids of the assistant-start and end-of-sequence markers.
        # They are stored here but not read by the methods of this class.
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def load_data(self, path):
        """Parse a JSONL file into a list of decoded JSON objects.

        Blank or whitespace-only lines (e.g. an extra trailing newline) are
        skipped instead of being handed to the JSON parser.

        Args:
            path: Path to the UTF-8 encoded JSONL file.

        Returns:
            A list with one decoded object per non-blank line, in file order.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            OSError: If the file cannot be opened.
        """
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # An empty string is not valid JSON; skip it.
                    continue
                samples.append(json.loads(line))
        return samples

    def _create_chat_prompt(self, conversations):
        """Build the chat prompt and extract the reference answer.

        Args:
            conversations: Sequence of turns, each with a ``'content'`` entry.

        Returns:
            A ``(prompt, answer)`` tuple. ``prompt`` is the chat template
            rendered (untokenized, with the generation prompt appended) over
            every turn except the last one; ``answer`` is the content of the
            last turn, or ``''`` when there are no turns.
        """
        messages = [
            {"role": 'user' if i % 2 == 0 else 'assistant', "content": turn['content']}
            for i, turn in enumerate(conversations)
        ]
        answer = messages[-1]["content"] if messages else ''
        # The final turn is held out of the prompt: it is the answer.
        prompt = self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True
        )
        return prompt, answer

    def __getitem__(self, index):
        """Return the sample at ``index`` as ``{'prompt': str, 'answer': str}``."""
        sample = self.samples[index]
        # Build the conversation prompt.
        prompt, answer = self._create_chat_prompt(sample['conversations'])

        return {
            'prompt': prompt,
            'answer': answer
        }
|
||||
|
||||
|
||||
# Script entry point: intentionally a no-op when the module is run directly.
if __name__ == "__main__":
    pass
|
||||
|
Loading…
x
Reference in New Issue
Block a user