update data_process Chunk_size Read
parent 6759da45c1
commit 7115a98453
@@ -1,3 +1,4 @@
+import itertools
 import re
 import json
 import jsonlines
@@ -11,7 +12,6 @@ from datasets import load_dataset
 bos_token = "<s>"
 eos_token = "</s>"
 
-
 # pretrain
 def process_wiki_clean():
     with open('./dataset/clean-wikipedia-cn.json', 'r', encoding='utf-8') as f_read:
@@ -63,26 +63,40 @@ def process_other():
         f.write(arr.tobytes())
 
 
 # pretrain
-def process_seq_monkey():
+def process_seq_monkey(chunk_size=50000):
     doc_ids = []
-    with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
-        for idx, obj in enumerate(reader):
-            try:
-                content = obj.get('text', '')
-                if len(content) > 512:
-                    continue
-                text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
-                doc_ids += text_id
-                if idx % 50000 == 0:
-                    print(f"seq_monkey: [{idx}]")
-            except UnicodeDecodeError as e:
-                print(f"Skipping invalid line {idx + 1}: {e}")
-                continue
+    chunk_idx = 0
 
-    arr = np.array(doc_ids, dtype=np.uint16)
-    with open('./dataset/clean_seq_monkey.bin', 'wb') as f:
-        f.write(arr.tobytes())
+    with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
+        while True:
+            chunk = list(itertools.islice(reader, chunk_size))
+            if not chunk:
+                break
+
+            for idx, obj in enumerate(chunk):
+                try:
+                    content = obj.get('text', '')
+                    if len(content) > 512:
+                        continue
+                    text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
+                    doc_ids += text_id
+                except UnicodeDecodeError as e:
+                    print(f"Skipping invalid line {chunk_idx * chunk_size + idx + 1}: {e}")
+                    continue
+
+            chunk_idx += 1
+            print(f"Processed chunk {chunk_idx} with {chunk_size} lines")
+
+            if len(doc_ids) > 1000000:
+                arr = np.array(doc_ids, dtype=np.uint16)
+                with open(f'./dataset/clean_seq_monkey.bin', 'wb') as f:
+                    f.write(arr.tobytes())
+                doc_ids = []
+
+    if doc_ids:
+        arr = np.array(doc_ids, dtype=np.uint16)
+        with open(f'./dataset/clean_seq_monkey.bin', 'wb') as f:
+            f.write(arr.tobytes())
 
 
 def pretrain_process():
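The hunk above swaps a single whole-file pass (tokenize every line, then write one big uint16 array at the end) for chunked reads: itertools.islice pulls at most chunk_size lines at a time, and the token buffer is flushed to disk whenever it grows past one million ids, which bounds peak memory on a large corpus. One caveat worth flagging: each flush reopens clean_seq_monkey.bin with mode 'wb', which truncates the file, so every flush overwrites the previous one and only the last buffer survives on disk. Below is a minimal standalone sketch of the same pattern with an append-mode fix ('ab') and a placeholder tokenizer; the helper names (tokenize_line, chunked_tokenize, flush_at) are mine, not the repo's.

    # Sketch of the chunked-read pattern, assuming jsonlines and numpy are installed.
    # tokenize_line stands in for tokenizer(f'{bos_token}{text}{eos_token}').data['input_ids'].
    import itertools
    import jsonlines
    import numpy as np

    def tokenize_line(text):
        # Placeholder tokenizer; real ids come from the repo's tokenizer.
        return [ord(c) % 65536 for c in text]

    def chunked_tokenize(src_path, dst_path, chunk_size=50000, flush_at=1_000_000):
        open(dst_path, 'wb').close()  # truncate once up front
        doc_ids = []
        with jsonlines.open(src_path) as reader:
            while True:
                # Read at most chunk_size lines instead of the whole file.
                chunk = list(itertools.islice(reader, chunk_size))
                if not chunk:
                    break
                for obj in chunk:
                    content = obj.get('text', '')
                    if len(content) > 512:  # same length filter as the commit
                        continue
                    doc_ids += tokenize_line(content)
                if len(doc_ids) > flush_at:
                    # 'ab' (not 'wb') so earlier flushes survive; this mode is
                    # my fix, the commit itself reopens with 'wb'.
                    with open(dst_path, 'ab') as f:
                        f.write(np.array(doc_ids, dtype=np.uint16).tobytes())
                    doc_ids = []
        if doc_ids:  # flush the final partial buffer
            with open(dst_path, 'ab') as f:
                f.write(np.array(doc_ids, dtype=np.uint16).tobytes())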
@@ -172,7 +186,6 @@ def sft_process(contain_history=False):
                 process_and_write_data(data)
                 data = []
 
-
 def rl_process():
     ################
     # Dataset
@@ -212,6 +225,6 @@ if __name__ == "__main__":
     if process_type == 1:
         pretrain_process()
     if process_type == 2:
-        sft_process(contain_history=True)
+        sft_process(contain_history=False)
     if process_type == 3:
         rl_process()
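Since the output file is a flat stream of uint16 token ids (the np.uint16 dtype implies the tokenizer's vocabulary fits in 16 bits), a consumer can map it straight back with numpy. A hedged sketch of loading it: the path comes from the commit, but max_length and the reshape are my assumptions about how a pretraining loader might slice the stream into fixed-length samples.

    import numpy as np

    # One uint16 per token id; the whole file is a single contiguous stream.
    data = np.fromfile('./dataset/clean_seq_monkey.bin', dtype=np.uint16)
    print(f"{data.size} token ids")

    max_length = 512                       # assumed sample length, not from this commit
    n_samples = data.size // max_length    # drop the trailing remainder
    samples = data[: n_samples * max_length].reshape(n_samples, max_length)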