From 7115a9845368e361449fe6578ef4e32eadff95f4 Mon Sep 17 00:00:00 2001
From: gongjy <2474590974@qq.com>
Date: Sat, 21 Sep 2024 22:57:22 +0800
Subject: [PATCH] update data_process chunk_size read

---
 data_process.py | 55 ++++++++++++++++++++++++++++++-------------------
 1 file changed, 34 insertions(+), 21 deletions(-)

diff --git a/data_process.py b/data_process.py
index 102b4a8..32162ec 100644
--- a/data_process.py
+++ b/data_process.py
@@ -1,3 +1,4 @@
+import itertools
 import re
 import json
 import jsonlines
@@ -11,7 +12,6 @@ from datasets import load_dataset
 bos_token = "<s>"
 eos_token = "</s>"
 
-
 # pretrain
 def process_wiki_clean():
     with open('./dataset/clean-wikipedia-cn.json', 'r', encoding='utf-8') as f_read:
@@ -63,26 +63,40 @@ def process_other():
         f.write(arr.tobytes())
 
 
-# pretrain
-def process_seq_monkey():
+def process_seq_monkey(chunk_size=50000):
     doc_ids = []
-    with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
-        for idx, obj in enumerate(reader):
-            try:
-                content = obj.get('text', '')
-                if len(content) > 512:
-                    continue
-                text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
-                doc_ids += text_id
-                if idx % 50000 == 0:
-                    print(f"seq_monkey: [{idx}]")
-            except UnicodeDecodeError as e:
-                print(f"Skipping invalid line {idx + 1}: {e}")
-                continue
+    chunk_idx = 0
 
-    arr = np.array(doc_ids, dtype=np.uint16)
-    with open('./dataset/clean_seq_monkey.bin', 'wb') as f:
-        f.write(arr.tobytes())
+    with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
+        while True:
+            chunk = list(itertools.islice(reader, chunk_size))
+            if not chunk:
+                break
+
+            for idx, obj in enumerate(chunk):
+                try:
+                    content = obj.get('text', '')
+                    if len(content) > 512:
+                        continue
+                    text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
+                    doc_ids += text_id
+                except UnicodeDecodeError as e:
+                    print(f"Skipping invalid line {chunk_idx * chunk_size + idx + 1}: {e}")
+                    continue
+
+            chunk_idx += 1
+            print(f"Processed chunk {chunk_idx} with {len(chunk)} lines")
+
+            if len(doc_ids) > 1000000:
+                arr = np.array(doc_ids, dtype=np.uint16)
+                with open('./dataset/clean_seq_monkey.bin', 'ab') as f:
+                    f.write(arr.tobytes())
+                doc_ids = []
+
+    if doc_ids:
+        arr = np.array(doc_ids, dtype=np.uint16)
+        with open('./dataset/clean_seq_monkey.bin', 'ab') as f:
+            f.write(arr.tobytes())
 
 
 def pretrain_process():
@@ -172,7 +186,6 @@ def sft_process(contain_history=False):
             process_and_write_data(data)
             data = []
 
-
 def rl_process():
     ################
     # Dataset
     ################
@@ -212,6 +225,6 @@ if __name__ == "__main__":
     if process_type == 1:
         pretrain_process()
     if process_type == 2:
-        sft_process(contain_history=True)
+        sft_process(contain_history=False)
     if process_type == 3:
         rl_process()
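
The pattern this patch introduces is worth spelling out: instead of tokenizing the whole jsonl corpus into one in-memory list and writing a single .bin at the end, it pulls fixed-size chunks off the jsonlines reader with itertools.islice and flushes accumulated token ids to disk once the buffer passes one million entries. Below is a minimal, self-contained sketch of that chunked-read pattern, assuming only the standard library; iter_chunks and fake_reader are illustrative names, not part of the patch:

    import itertools

    def iter_chunks(iterable, chunk_size):
        # Yield successive lists of up to chunk_size items; islice consumes
        # the shared iterator, so each call resumes where the last stopped.
        iterator = iter(iterable)
        while True:
            chunk = list(itertools.islice(iterator, chunk_size))
            if not chunk:
                break
            yield chunk

    if __name__ == "__main__":
        # Hypothetical stand-in for jsonlines.open(...): any iterator works.
        fake_reader = ({"text": f"line {i}"} for i in range(12))
        for chunk_idx, chunk in enumerate(iter_chunks(fake_reader, 5), start=1):
            print(f"chunk {chunk_idx}: {len(chunk)} lines")
        # -> chunk 1: 5 lines / chunk 2: 5 lines / chunk 3: 2 lines

Because the flushes open clean_seq_monkey.bin in append mode, each flush adds to what earlier chunks wrote within the same run; a run that must start from a clean file needs to delete any leftover .bin first.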