update data_process

gongjy 2024-09-22 21:17:05 +08:00
parent 8ff2ccae4c
commit bf64ffb056


@@ -161,9 +161,9 @@ def sft_process(contain_history=False):
     with open(f'./dataset/{file_name}', 'w', encoding='utf-8') as f:
         f.write('history,q,a\n')
-    sft_datasets = ['./dataset/sft_data_zh_2.jsonl']
+    sft_datasets = ['./dataset/sft_data_zh.jsonl']
     if not contain_history:
-        sft_datasets = ['./dataset/sft_data_zh_2.jsonl']
+        sft_datasets = ['./dataset/sft_data_zh.jsonl']
     for path in sft_datasets:
         with jsonlines.open(path) as reader:
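
For context, a minimal sketch of the SFT write path this hunk touches: it streams records from the jsonl file(s) listed above and appends history,q,a rows to the csv opened earlier. The helper name build_sft_csv and the record fields 'history', 'q', 'a' are assumptions for illustration only; the real sft_data_zh.jsonl schema may differ.

import csv
import jsonlines

def build_sft_csv(jsonl_paths, csv_path='./dataset/sft_data.csv'):
    # Hypothetical helper: write the header once, then stream every record
    # from the jsonl inputs into the csv. Field names are assumed.
    with open(csv_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['history', 'q', 'a'])
        for path in jsonl_paths:
            with jsonlines.open(path) as reader:
                for obj in reader:
                    writer.writerow([obj.get('history', []), obj.get('q', ''), obj.get('a', '')])

build_sft_csv(['./dataset/sft_data_zh.jsonl'])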
@@ -192,7 +192,7 @@ def rl_process():
     ################
     dataset_path = ['./dataset/dpo/dpo_zh_demo.json',
-                    './dataset/dpo/train_1.json',
+                    './dataset/dpo/train_data.json',
                     './dataset/dpo/huozi_rlhf_data.json', ]
     train_dataset = load_dataset('json', data_files=dataset_path)
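
A hedged sketch of what rl_process loads after this change: the three DPO json files merged into one Hugging Face dataset. Only the loading step shown in the hunk is reproduced; any mapping to prompt/chosen/rejected columns is omitted because it depends on the actual json schema.

from datasets import load_dataset

# File list mirrors the hunk above (train_1.json replaced by train_data.json).
dataset_path = ['./dataset/dpo/dpo_zh_demo.json',
                './dataset/dpo/train_data.json',
                './dataset/dpo/huozi_rlhf_data.json']
train_dataset = load_dataset('json', data_files=dataset_path)
print(train_dataset['train'])  # inspect columns and row count of the merged 'train' split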
@@ -220,7 +220,7 @@ if __name__ == "__main__":
     # 2: sft
     # 3: RL
     ################
-    process_type = 1
+    process_type = 3
     if process_type == 1:
         pretrain_process()
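
The practical effect of the last hunk is that running the script now performs the RL (DPO) data prep by default. A sketch of the __main__ dispatch under that assumption; only the process_type == 1 branch appears in the diff, so the elif branches below are inferred from the 1/2/3 comment block rather than taken from the source.

if __name__ == "__main__":
    process_type = 3  # 1: pretrain, 2: sft, 3: RL
    if process_type == 1:
        pretrain_process()
    elif process_type == 2:   # assumed branch, mirrors the comment block
        sft_process(contain_history=False)
    elif process_type == 3:   # assumed branch; now the default path
        rl_process()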