update data_process Chunk_size Read

gongjy 2024-09-21 22:57:22 +08:00
parent 6759da45c1
commit 7115a98453

data_process.py

@@ -1,3 +1,4 @@
+import itertools
 import re
 import json
 import jsonlines

@@ -11,7 +12,6 @@ from datasets import load_dataset
 bos_token = "<s>"
 eos_token = "</s>"

 # pretrain
 def process_wiki_clean():
     with open('./dataset/clean-wikipedia-cn.json', 'r', encoding='utf-8') as f_read:

@@ -63,26 +63,40 @@ def process_other():
         f.write(arr.tobytes())

 # pretrain
-def process_seq_monkey():
+def process_seq_monkey(chunk_size=50000):
     doc_ids = []
-    with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
-        for idx, obj in enumerate(reader):
-            try:
-                content = obj.get('text', '')
-                if len(content) > 512:
-                    continue
-                text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
-                doc_ids += text_id
-                if idx % 50000 == 0:
-                    print(f"seq_monkey: [{idx}]")
-            except UnicodeDecodeError as e:
-                print(f"Skipping invalid line {idx + 1}: {e}")
-                continue
+    chunk_idx = 0

-    arr = np.array(doc_ids, dtype=np.uint16)
-    with open('./dataset/clean_seq_monkey.bin', 'wb') as f:
-        f.write(arr.tobytes())
+    with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
+        while True:
+            chunk = list(itertools.islice(reader, chunk_size))
+            if not chunk:
+                break
+
+            for idx, obj in enumerate(chunk):
+                try:
+                    content = obj.get('text', '')
+                    if len(content) > 512:
+                        continue
+                    text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
+                    doc_ids += text_id
+                except UnicodeDecodeError as e:
+                    print(f"Skipping invalid line {chunk_idx * chunk_size + idx + 1}: {e}")
+                    continue
+
+            chunk_idx += 1
+            print(f"Processed chunk {chunk_idx} with {chunk_size} lines")
+
+            if len(doc_ids) > 1000000:
+                arr = np.array(doc_ids, dtype=np.uint16)
+                with open(f'./dataset/clean_seq_monkey.bin', 'wb') as f:
+                    f.write(arr.tobytes())
+                doc_ids = []
+
+    if doc_ids:
+        arr = np.array(doc_ids, dtype=np.uint16)
+        with open(f'./dataset/clean_seq_monkey.bin', 'wb') as f:
+            f.write(arr.tobytes())

 def pretrain_process():

@@ -172,7 +186,6 @@ def sft_process(contain_history=False):
         process_and_write_data(data)
         data = []

 def rl_process():
     ################
     # Dataset

@@ -212,6 +225,6 @@ if __name__ == "__main__":
     if process_type == 1:
         pretrain_process()

     if process_type == 2:
-        sft_process(contain_history=True)
+        sft_process(contain_history=False)

     if process_type == 3:
         rl_process()
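
For reference, a minimal standalone sketch of the chunked-read pattern this commit adopts: `itertools.islice` pulls at most `chunk_size` records at a time from a `jsonlines` reader, token ids accumulate in a buffer, and the buffer is flushed to a `.bin` file once it grows past a threshold. The function name `chunked_jsonl_to_bin`, the `toy_tokenize` stand-in for the project's tokenizer, and the `src_path`/`dst_path`/`flush_at` parameters are illustrative assumptions, not part of the repository. The sketch truncates the output once and then flushes in append mode (`'ab'`) so every flushed buffer is preserved in the file.

```python
import itertools

import jsonlines
import numpy as np


def toy_tokenize(text: str) -> list[int]:
    # Hypothetical stand-in for the project's tokenizer:
    # maps each UTF-8 byte to a small integer id that fits in uint16.
    return list(text.encode("utf-8"))


def chunked_jsonl_to_bin(src_path: str, dst_path: str,
                         chunk_size: int = 50000, flush_at: int = 1_000_000):
    doc_ids = []
    chunk_idx = 0
    # Truncate the output once, then append each flushed buffer below.
    open(dst_path, "wb").close()
    with jsonlines.open(src_path) as reader:
        while True:
            # Read at most chunk_size records instead of the whole file.
            chunk = list(itertools.islice(reader, chunk_size))
            if not chunk:
                break
            for obj in chunk:
                doc_ids += toy_tokenize(obj.get("text", ""))
            chunk_idx += 1
            print(f"Processed chunk {chunk_idx} with {len(chunk)} lines")
            if len(doc_ids) > flush_at:
                # 'ab' keeps previously flushed chunks in the file.
                with open(dst_path, "ab") as f:
                    f.write(np.array(doc_ids, dtype=np.uint16).tobytes())
                doc_ids = []
    if doc_ids:
        with open(dst_path, "ab") as f:
            f.write(np.array(doc_ids, dtype=np.uint16).tobytes())
```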