diff --git a/data_process.py b/data_process.py index 2663ef1..9c03628 100644 --- a/data_process.py +++ b/data_process.py @@ -89,7 +89,7 @@ def process_seq_monkey(chunk_size=50000): if len(doc_ids) > 1000000: arr = np.array(doc_ids, dtype=np.uint16) - with open(f'./dataset/clean_seq_monkey.bin', 'wb') as f: + with open(f'./dataset/clean_seq_monkey.bin', 'ab') as f: f.write(arr.tobytes()) doc_ids = [] @@ -220,7 +220,7 @@ if __name__ == "__main__": # 2: sft # 3: RL ################ - process_type = 3 + process_type = 1 if process_type == 1: pretrain_process()