update data_process

This commit is contained in:
gongjy 2024-10-23 12:02:28 +08:00
parent 3ff66f7221
commit 69bcb8dc90

View File

@ -60,10 +60,10 @@ def sft_process(contain_history=False):
continue continue
if len(q) < 10 or len(a) < 5: if len(q) < 10 or len(a) < 5:
continue continue
if len(q) > 256 or len(a) > 256: if len(q) > 512 or len(a) > 512:
continue continue
# 判断q和a中中文字符占比是否超过70% # 判断q和a中中文字符占比是否超过70%
if not (chinese_ratio(q) > 0.9 and chinese_ratio(a) > 0.9): if not (chinese_ratio(q) > 0.86 and chinese_ratio(a) > 0.86):
continue continue
q_lst.append(q) q_lst.append(q)
@ -87,6 +87,7 @@ def sft_process(contain_history=False):
if not contain_history: if not contain_history:
sft_datasets = ['./dataset/sft_data_zh.jsonl'] sft_datasets = ['./dataset/sft_data_zh.jsonl']
chunk_num = 0
for path in sft_datasets: for path in sft_datasets:
with jsonlines.open(path) as reader: with jsonlines.open(path) as reader:
for idx, obj in enumerate(reader): for idx, obj in enumerate(reader):
@ -98,8 +99,11 @@ def sft_process(contain_history=False):
}) })
if len(data) >= chunk_size: if len(data) >= chunk_size:
chunk_num += 1
process_and_write_data(data) process_and_write_data(data)
data = [] data = []
if chunk_num % 100 == 0:
print(f'chunk:{chunk_num} process end')
except jsonlines.InvalidLineError as e: except jsonlines.InvalidLineError as e:
print(f"Skipping invalid JSON line {idx + 1}: {e}") print(f"Skipping invalid JSON line {idx + 1}: {e}")
continue continue
@ -139,7 +143,7 @@ if __name__ == "__main__":
# 2: sft # 2: sft
# 3: RL # 3: RL
################ ################
process_type = 3 process_type = 2
if process_type == 1: if process_type == 1:
pretrain_process() pretrain_process()