update data_process

This commit is contained in:
gongjy 2024-09-22 21:17:05 +08:00
parent 8ff2ccae4c
commit bf64ffb056

View File

@@ -161,9 +161,9 @@ def sft_process(contain_history=False):
with open(f'./dataset/{file_name}', 'w', encoding='utf-8') as f:
f.write('history,q,a\n')
sft_datasets = ['./dataset/sft_data_zh_2.jsonl']
sft_datasets = ['./dataset/sft_data_zh.jsonl']
if not contain_history:
sft_datasets = ['./dataset/sft_data_zh_2.jsonl']
sft_datasets = ['./dataset/sft_data_zh.jsonl']
for path in sft_datasets:
with jsonlines.open(path) as reader:
@@ -192,7 +192,7 @@ def rl_process():
################
dataset_path = ['./dataset/dpo/dpo_zh_demo.json',
'./dataset/dpo/train_1.json',
'./dataset/dpo/train_data.json',
'./dataset/dpo/huozi_rlhf_data.json', ]
train_dataset = load_dataset('json', data_files=dataset_path)
@@ -220,7 +220,7 @@ if __name__ == "__main__":
# 2: sft
# 3: RL
################
process_type = 1
process_type = 3
if process_type == 1:
pretrain_process()