update data_process
This commit is contained in:
parent
8ff2ccae4c
commit
bf64ffb056
@@ -161,9 +161,9 @@ def sft_process(contain_history=False):
|
||||
with open(f'./dataset/{file_name}', 'w', encoding='utf-8') as f:
|
||||
f.write('history,q,a\n')
|
||||
|
||||
- sft_datasets = ['./dataset/sft_data_zh_2.jsonl']
|
||||
+ sft_datasets = ['./dataset/sft_data_zh.jsonl']
|
||||
if not contain_history:
|
||||
- sft_datasets = ['./dataset/sft_data_zh_2.jsonl']
|
||||
+ sft_datasets = ['./dataset/sft_data_zh.jsonl']
|
||||
|
||||
for path in sft_datasets:
|
||||
with jsonlines.open(path) as reader:
|
||||
@@ -192,7 +192,7 @@ def rl_process():
|
||||
################
|
||||
|
||||
dataset_path = ['./dataset/dpo/dpo_zh_demo.json',
|
||||
'./dataset/dpo/train_1.json',
|
||||
'./dataset/dpo/train_data.json',
|
||||
'./dataset/dpo/huozi_rlhf_data.json', ]
|
||||
|
||||
train_dataset = load_dataset('json', data_files=dataset_path)
|
||||
@@ -220,7 +220,7 @@ if __name__ == "__main__":
|
||||
# 2: sft
|
||||
# 3: RL
|
||||
################
|
||||
- process_type = 1
|
||||
+ process_type = 3
|
||||
|
||||
if process_type == 1:
|
||||
pretrain_process()
|
||||
|
Loading…
x
Reference in New Issue
Block a user