update data_process Chunk_size Read

gongjy 2024-09-21 22:57:22 +08:00
parent 6759da45c1
commit 7115a98453


@@ -1,3 +1,4 @@
+import itertools
 import re
 import json
 import jsonlines
@@ -11,7 +12,6 @@ from datasets import load_dataset
 bos_token = "<s>"
 eos_token = "</s>"
 
-
 # pretrain
 def process_wiki_clean():
     with open('./dataset/clean-wikipedia-cn.json', 'r', encoding='utf-8') as f_read:
@@ -63,25 +63,39 @@ def process_other():
         f.write(arr.tobytes())
 
-# pretrain
-def process_seq_monkey():
+def process_seq_monkey(chunk_size=50000):
     doc_ids = []
+    chunk_idx = 0
     with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
-        for idx, obj in enumerate(reader):
-            try:
-                content = obj.get('text', '')
-                if len(content) > 512:
-                    continue
-                text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
-                doc_ids += text_id
-                if idx % 50000 == 0:
-                    print(f"seq_monkey: [{idx}]")
-            except UnicodeDecodeError as e:
-                print(f"Skipping invalid line {idx + 1}: {e}")
-                continue
-    arr = np.array(doc_ids, dtype=np.uint16)
-    with open('./dataset/clean_seq_monkey.bin', 'wb') as f:
-        f.write(arr.tobytes())
+        while True:
+            chunk = list(itertools.islice(reader, chunk_size))
+            if not chunk:
+                break
+            for idx, obj in enumerate(chunk):
+                try:
+                    content = obj.get('text', '')
+                    if len(content) > 512:
+                        continue
+                    text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
+                    doc_ids += text_id
+                except UnicodeDecodeError as e:
+                    print(f"Skipping invalid line {chunk_idx * chunk_size + idx + 1}: {e}")
+                    continue
+            chunk_idx += 1
+            print(f"Processed chunk {chunk_idx} with {chunk_size} lines")
+            if len(doc_ids) > 1000000:
+                arr = np.array(doc_ids, dtype=np.uint16)
+                with open(f'./dataset/clean_seq_monkey.bin', 'wb') as f:
+                    f.write(arr.tobytes())
+                doc_ids = []
+    if doc_ids:
+        arr = np.array(doc_ids, dtype=np.uint16)
+        with open(f'./dataset/clean_seq_monkey.bin', 'wb') as f:
+            f.write(arr.tobytes())
@@ -172,7 +186,6 @@ def sft_process(contain_history=False):
             process_and_write_data(data)
             data = []
 
-
 def rl_process():
     ################
     # Dataset
@@ -212,6 +225,6 @@ if __name__ == "__main__":
    if process_type == 1:
        pretrain_process()
    if process_type == 2:
-       sft_process(contain_history=True)
+       sft_process(contain_history=False)
    if process_type == 3:
        rl_process()
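
The core of the commit is the switch from iterating the JSONL reader line by line to pulling fixed-size chunks with itertools.islice and flushing the token buffer to disk once it grows past a threshold. Below is a minimal, self-contained sketch of that pattern, not the repository's code: the encode() helper, the file paths, and the flush_at parameter are placeholders standing in for the real tokenizer and dataset, and the sketch opens the output in append mode so successive flushes accumulate in a single .bin file.

import itertools
import jsonlines
import numpy as np


def encode(text):
    # Placeholder tokenizer: maps characters to ids that fit in uint16.
    return [ord(ch) % 65535 for ch in text]


def process_in_chunks(src='corpus.jsonl', dst='corpus.bin',
                      chunk_size=50000, flush_at=1_000_000):
    doc_ids = []
    chunk_idx = 0
    open(dst, 'wb').close()  # truncate once; later writes append
    with jsonlines.open(src) as reader:
        while True:
            # Pull at most chunk_size parsed objects without loading the whole file.
            chunk = list(itertools.islice(reader, chunk_size))
            if not chunk:
                break
            for obj in chunk:
                text = obj.get('text', '')
                if len(text) > 512:  # skip overly long samples, mirroring the script
                    continue
                doc_ids += encode(text)
            chunk_idx += 1
            print(f"processed chunk {chunk_idx} ({len(chunk)} lines)")
            if len(doc_ids) > flush_at:
                # Flush the accumulated ids as uint16 and free the buffer.
                with open(dst, 'ab') as f:
                    f.write(np.array(doc_ids, dtype=np.uint16).tobytes())
                doc_ids = []
    if doc_ids:  # write whatever is left after the last chunk
        with open(dst, 'ab') as f:
            f.write(np.array(doc_ids, dtype=np.uint16).tobytes())


if __name__ == '__main__':
    process_in_chunks()

Appending on each flush is a design choice of this sketch: it keeps peak memory bounded by flush_at tokens while still producing one contiguous uint16 stream on disk.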