From 56c6139896ed70d5bd89d086daeb3ebe7e5c1ae7 Mon Sep 17 00:00:00 2001
From: gongjy <2474590974@qq.com>
Date: Fri, 13 Sep 2024 13:32:24 +0800
Subject: [PATCH] update data_process & full_sft

---
 3-full_sft.py   |  8 ++++----
 data_process.py | 40 ++++++++++++++++++++++++----------------
 2 files changed, 28 insertions(+), 20 deletions(-)

diff --git a/3-full_sft.py b/3-full_sft.py
index 459c244..1c76954 100644
--- a/3-full_sft.py
+++ b/3-full_sft.py
@@ -109,7 +109,7 @@ def init_model(lm_config):
 
     if model_from == 1:
         moe_path = '_moe' if lm_config.use_moe else ''
-        ckp = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth'
+        ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
 
         model = Transformer(lm_config)
         state_dict = torch.load(ckp, map_location=device)
@@ -148,8 +148,8 @@ if __name__ == "__main__":
     out_dir = 'out'
     epochs = 19
     gradient_accumulation_steps = 1
-    batch_size = 48
-    learning_rate = 2e-8
+    batch_size = 80
+    learning_rate = 2e-4
     device = 'cuda:0'
     dtype = 'bfloat16'
     # dtype = 'float16'
@@ -175,7 +175,7 @@ if __name__ == "__main__":
     model, tokenizer = init_model(lm_config)
 
     # -----init dataloader------
-    df = pd.read_csv('./dataset/sft_data.csv')
+    df = pd.read_csv('./dataset/sft_data_single.csv')
     df = df.sample(frac=1.0)
     train_ds = SFTDataset(df, tokenizer, max_length=max_seq_len)
     train_sampler = DistributedSampler(train_ds) if ddp else None
diff --git a/data_process.py b/data_process.py
index 7595186..102b4a8 100644
--- a/data_process.py
+++ b/data_process.py
@@ -68,13 +68,17 @@ def process_seq_monkey():
     doc_ids = []
     with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
         for idx, obj in enumerate(reader):
-            content = obj.get('text', '')
-            if len(content) > 512:
+            try:
+                content = obj.get('text', '')
+                if len(content) > 512:
+                    continue
+                text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
+                doc_ids += text_id
+                if idx % 50000 == 0:
+                    print(f"seq_monkey: [{idx}]")
+            except UnicodeDecodeError as e:
+                print(f"Skipping invalid line {idx + 1}: {e}")
                 continue
-            text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
-            doc_ids += text_id
-            if idx % 50000 == 0:
-                print(f"seq_monkey: [{idx}]")
 
     arr = np.array(doc_ids, dtype=np.uint16)
     with open('./dataset/clean_seq_monkey.bin', 'wb') as f:
@@ -150,15 +154,19 @@ def sft_process(contain_history=False):
     for path in sft_datasets:
         with jsonlines.open(path) as reader:
             for idx, obj in enumerate(reader):
-                data.append({
-                    'history': obj.get('history', ''),
-                    'q': obj.get('input', '') + obj.get('q', ''),
-                    'a': obj.get('output', '') + obj.get('a', '')
-                })
+                try:
+                    data.append({
+                        'history': obj.get('history', ''),
+                        'q': obj.get('input', '') + obj.get('q', ''),
+                        'a': obj.get('output', '') + obj.get('a', '')
+                    })
 
-                if len(data) >= chunk_size:
-                    process_and_write_data(data)
-                    data = []
+                    if len(data) >= chunk_size:
+                        process_and_write_data(data)
+                        data = []
+                except jsonlines.InvalidLineError as e:
+                    print(f"Skipping invalid JSON line {idx + 1}: {e}")
+                    continue
 
     if data:
         process_and_write_data(data)
@@ -191,7 +199,7 @@ def rl_process():
 
 if __name__ == "__main__":
-    tokenizer = AutoTokenizer.from_pretrained('./model', use_fast=False)
+    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer', use_fast=False)
     print('tokenizer词表大小:', len(tokenizer))
 
     ################
     # 1: pretrain
     # 2: sft
     # 3: RL
     ################
-    process_type = 2
+    process_type = 1
 
     if process_type == 1:
         pretrain_process()