update data_process Chunk_size Read

parent 6759da45c1
commit 7115a98453
data_process.py

@@ -1,3 +1,4 @@
+import itertools
 import re
 import json
 import jsonlines
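The only new import is itertools, which drives the chunked read introduced below: itertools.islice(reader, chunk_size) pulls at most chunk_size records off the jsonlines iterator per pass instead of materialising the whole corpus. A minimal, standalone sketch of that pattern (a plain range stands in for the reader; the helper name iter_chunks is made up here):

import itertools

def iter_chunks(iterable, chunk_size):
    # Repeatedly slice up to chunk_size items off the iterator until it is exhausted.
    it = iter(iterable)
    while True:
        chunk = list(itertools.islice(it, chunk_size))
        if not chunk:
            break
        yield chunk

# 7 items in chunks of 3 -> [[0, 1, 2], [3, 4, 5], [6]]
print(list(iter_chunks(range(7), 3)))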
@@ -11,7 +12,6 @@ from datasets import load_dataset
 bos_token = "<s>"
 eos_token = "</s>"
 
-
 # pretrain
 def process_wiki_clean():
     with open('./dataset/clean-wikipedia-cn.json', 'r', encoding='utf-8') as f_read:
@@ -63,26 +63,40 @@ def process_other():
        f.write(arr.tobytes())
 
 
 # pretrain
-def process_seq_monkey():
+def process_seq_monkey(chunk_size=50000):
     doc_ids = []
-    with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
-        for idx, obj in enumerate(reader):
-            try:
-                content = obj.get('text', '')
-                if len(content) > 512:
-                    continue
-                text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
-                doc_ids += text_id
-                if idx % 50000 == 0:
-                    print(f"seq_monkey: [{idx}]")
-            except UnicodeDecodeError as e:
-                print(f"Skipping invalid line {idx + 1}: {e}")
-                continue
-
-    arr = np.array(doc_ids, dtype=np.uint16)
-    with open('./dataset/clean_seq_monkey.bin', 'wb') as f:
-        f.write(arr.tobytes())
+    chunk_idx = 0
+    with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
+        while True:
+            chunk = list(itertools.islice(reader, chunk_size))
+            if not chunk:
+                break
+
+            for idx, obj in enumerate(chunk):
+                try:
+                    content = obj.get('text', '')
+                    if len(content) > 512:
+                        continue
+                    text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
+                    doc_ids += text_id
+                except UnicodeDecodeError as e:
+                    print(f"Skipping invalid line {chunk_idx * chunk_size + idx + 1}: {e}")
+                    continue
+
+            chunk_idx += 1
+            print(f"Processed chunk {chunk_idx} with {chunk_size} lines")
+
+            if len(doc_ids) > 1000000:
+                arr = np.array(doc_ids, dtype=np.uint16)
+                with open(f'./dataset/clean_seq_monkey.bin', 'wb') as f:
+                    f.write(arr.tobytes())
+                doc_ids = []
+
+    if doc_ids:
+        arr = np.array(doc_ids, dtype=np.uint16)
+        with open(f'./dataset/clean_seq_monkey.bin', 'wb') as f:
+            f.write(arr.tobytes())
 
 
 def pretrain_process():
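A side note on the flush logic in the new process_seq_monkey (an observation about the diff, not part of the commit): both writes open clean_seq_monkey.bin with 'wb', which truncates the file, so if the 1,000,000-token threshold is crossed more than once, each flush overwrites the previous one. A writer that keeps every flush would append instead; a minimal sketch, with flush_tokens as a hypothetical helper name:

import numpy as np

def flush_tokens(doc_ids, out_path='./dataset/clean_seq_monkey.bin'):
    # Append this batch of uint16 token ids ('ab') instead of truncating ('wb'),
    # so tokens written by earlier flushes stay in the file.
    arr = np.array(doc_ids, dtype=np.uint16)
    with open(out_path, 'ab') as f:
        f.write(arr.tobytes())

# Inside the chunk loop:
#     if len(doc_ids) > 1000000:
#         flush_tokens(doc_ids)
#         doc_ids = []
# and after the loop:
#     if doc_ids:
#         flush_tokens(doc_ids)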
@@ -172,7 +186,6 @@ def sft_process(contain_history=False):
                 process_and_write_data(data)
                 data = []
 
-
 def rl_process():
     ################
     # Dataset
@@ -212,6 +225,6 @@ if __name__ == "__main__":
     if process_type == 1:
         pretrain_process()
     if process_type == 2:
-        sft_process(contain_history=True)
+        sft_process(contain_history=False)
     if process_type == 3:
         rl_process()
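With the new defaults, process_type = 1 builds the pretraining .bin via the chunked reader, and process_type = 2 now builds the SFT set without multi-turn history. For a quick smoke test of the chunked path, the new chunk_size parameter can be lowered; a hypothetical usage, assuming the script is importable as data_process:

from data_process import process_seq_monkey

# Lower chunk_size means smaller read batches and more frequent
# "Processed chunk ..." progress lines; the corpus read is the same.
process_seq_monkey(chunk_size=1000)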