update rlaif
This commit is contained in:
parent
9e67798397
commit
4a7c1c49e8
@ -196,5 +196,50 @@ class DPODataset(Dataset):
|
|||||||
return loss_mask
|
return loss_mask
|
||||||
|
|
||||||
|
|
||||||
|
class RLAIFDataset(Dataset):
    """Dataset of (prompt, answer) text pairs read from a JSONL file.

    Each JSONL record is expected to carry a ``'conversations'`` list whose
    items are dicts with a ``'content'`` key.  Roles are assigned purely by
    position (even index -> ``user``, odd index -> ``assistant``).  The last
    turn is returned as the answer and every earlier turn is rendered into
    the prompt through the tokenizer's chat template.

    Args:
        jsonl_path: Path to the JSONL file, one JSON object per line.
        tokenizer: Object that is callable on a string (returning something
            with ``.input_ids``) and provides ``apply_chat_template``.
        max_length: Stored on the instance as ``self.max_length``; it is not
            used by any method of this class.
    """

    def __init__(self, jsonl_path, tokenizer, max_length=1024):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = self.load_data(jsonl_path)
        # Token ids of the literal marker strings, encoded without letting
        # the tokenizer add its own special tokens.
        self.bos_id = tokenizer('<s>assistant', add_special_tokens=False).input_ids
        self.eos_id = tokenizer('</s>', add_special_tokens=False).input_ids

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def load_data(self, path):
        """Parse a UTF-8 JSONL file into a list of decoded records.

        Blank and whitespace-only lines are skipped, so stray empty lines in
        the file do not abort loading.

        Args:
            path: Path of the JSONL file to read.

        Returns:
            A list with one decoded JSON value per non-blank line, in file
            order.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        samples = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # json.loads('') raises, so skip empty lines explicitly.
                    continue
                samples.append(json.loads(line))
        return samples

    def _create_chat_prompt(self, conversations):
        """Build the chat-template prompt and extract the final answer.

        Args:
            conversations: Sequence of dicts, each with a ``'content'`` key.
                Any role stored in the data is ignored; roles alternate
                user/assistant by position, starting with user.

        Returns:
            A ``(prompt, answer)`` tuple.  ``prompt`` is the result of
            ``tokenizer.apply_chat_template`` over every turn except the
            last, with ``tokenize=False`` and ``add_generation_prompt=True``.
            ``answer`` is the content of the last turn, or ``''`` when
            ``conversations`` is empty.
        """
        messages = [
            {
                "role": 'user' if i % 2 == 0 else 'assistant',
                "content": turn['content'],
            }
            for i, turn in enumerate(conversations)
        ]
        answer = messages[-1]["content"] if messages else ''
        # The last turn is held out as the answer, so it is excluded from
        # the rendered prompt.
        prompt = self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True,
        )
        return prompt, answer

    def __getitem__(self, index):
        """Return ``{'prompt': str, 'answer': str}`` for the sample at ``index``."""
        sample = self.samples[index]
        # Build the conversation prompt.
        prompt, answer = self._create_chat_prompt(sample['conversations'])
        return {
            'prompt': prompt,
            'answer': answer,
        }
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pass
|
pass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user