diff --git a/data_process.py b/data_process.py index 32162ec..2663ef1 100644 --- a/data_process.py +++ b/data_process.py @@ -161,9 +161,9 @@ def sft_process(contain_history=False): with open(f'./dataset/{file_name}', 'w', encoding='utf-8') as f: f.write('history,q,a\n') - sft_datasets = ['./dataset/sft_data_zh_2.jsonl'] + sft_datasets = ['./dataset/sft_data_zh.jsonl'] if not contain_history: - sft_datasets = ['./dataset/sft_data_zh_2.jsonl'] + sft_datasets = ['./dataset/sft_data_zh.jsonl'] for path in sft_datasets: with jsonlines.open(path) as reader: @@ -192,7 +192,7 @@ def rl_process(): ################ dataset_path = ['./dataset/dpo/dpo_zh_demo.json', - './dataset/dpo/train_1.json', + './dataset/dpo/train_data.json', './dataset/dpo/huozi_rlhf_data.json', ] train_dataset = load_dataset('json', data_files=dataset_path) @@ -220,7 +220,7 @@ if __name__ == "__main__": # 2: sft # 3: RL ################ - process_type = 1 + process_type = 3 if process_type == 1: pretrain_process()