update data_process & full_sft

gongjy 2024-09-13 13:32:24 +08:00
parent 41b474e2bf
commit 56c6139896
2 changed files with 28 additions and 20 deletions

View File

@@ -109,7 +109,7 @@ def init_model(lm_config):
     if model_from == 1:
         moe_path = '_moe' if lm_config.use_moe else ''
-        ckp = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth'
+        ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
         model = Transformer(lm_config)
         state_dict = torch.load(ckp, map_location=device)
@@ -148,8 +148,8 @@ if __name__ == "__main__":
     out_dir = 'out'
     epochs = 19
     gradient_accumulation_steps = 1
-    batch_size = 48
-    learning_rate = 2e-8
+    batch_size = 80
+    learning_rate = 2e-4
     device = 'cuda:0'
     dtype = 'bfloat16'
     # dtype = 'float16'
@@ -175,7 +175,7 @@ if __name__ == "__main__":
     model, tokenizer = init_model(lm_config)
     # -----init dataloader------
-    df = pd.read_csv('./dataset/sft_data.csv')
+    df = pd.read_csv('./dataset/sft_data_single.csv')
     df = df.sample(frac=1.0)
     train_ds = SFTDataset(df, tokenizer, max_length=max_seq_len)
     train_sampler = DistributedSampler(train_ds) if ddp else None
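Net effect of these hunks: full SFT now warm-starts from the pretraining checkpoint rather than a previous single-chat SFT checkpoint, with batch_size 80, learning_rate 2e-4, and the single-turn dataset sft_data_single.csv. A minimal sketch of that warm-start load, assuming the repo's Transformer class and an lm_config with dim/use_moe fields; the import path and strict=False are assumptions, not part of this commit:

import torch
from model.model import Transformer  # assumed module path for the repo's Transformer

def load_pretrained_for_sft(lm_config, device='cuda:0'):
    # Mirrors the checkpoint naming used in the diff above.
    moe_path = '_moe' if lm_config.use_moe else ''
    ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
    model = Transformer(lm_config)
    state_dict = torch.load(ckp, map_location=device)
    # strict=False is an assumption, to tolerate missing or extra keys.
    model.load_state_dict(state_dict, strict=False)
    return model.to(device)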

View File

@@ -68,6 +68,7 @@ def process_seq_monkey():
     doc_ids = []
     with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
         for idx, obj in enumerate(reader):
+            try:
                 content = obj.get('text', '')
                 if len(content) > 512:
                     continue
@@ -75,6 +76,9 @@
                 doc_ids += text_id
                 if idx % 50000 == 0:
                     print(f"seq_monkey: [{idx}]")
+            except UnicodeDecodeError as e:
+                print(f"Skipping invalid line {idx + 1}: {e}")
+                continue
     arr = np.array(doc_ids, dtype=np.uint16)
     with open('./dataset/clean_seq_monkey.bin', 'wb') as f:
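The process_seq_monkey change wraps each record in try/except so a bad line is logged and skipped instead of aborting the whole tokenization pass. A standalone sketch of that skip-and-continue pattern; tokenize_corpus, max_chars, and the tokenizer call shape are illustrative assumptions, not code from the repo:

import jsonlines
import numpy as np

def tokenize_corpus(jsonl_path, tokenizer, max_chars=512):
    # Accumulate token ids from a JSONL corpus, skipping records whose
    # processing raises instead of crashing the whole run.
    doc_ids = []
    with jsonlines.open(jsonl_path) as reader:
        for idx, obj in enumerate(reader):
            try:
                content = obj.get('text', '')
                if len(content) > max_chars:
                    continue
                doc_ids += tokenizer(content)['input_ids']  # call shape is an assumption
            except UnicodeDecodeError as e:
                print(f"Skipping invalid line {idx + 1}: {e}")
                continue
    return np.array(doc_ids, dtype=np.uint16)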
@@ -150,6 +154,7 @@ def sft_process(contain_history=False):
     for path in sft_datasets:
         with jsonlines.open(path) as reader:
             for idx, obj in enumerate(reader):
+                try:
                     data.append({
                         'history': obj.get('history', ''),
                         'q': obj.get('input', '') + obj.get('q', ''),
@@ -159,6 +164,9 @@ def sft_process(contain_history=False):
                     if len(data) >= chunk_size:
                         process_and_write_data(data)
                         data = []
+                except jsonlines.InvalidLineError as e:
+                    print(f"Skipping invalid JSON line {idx + 1}: {e}")
+                    continue
     if data:
         process_and_write_data(data)
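sft_process gets the same guard: records accumulate into chunks, each full chunk is flushed, and a jsonlines.InvalidLineError on a record is logged and skipped. A sketch of that accumulate-and-flush loop; write_chunk and chunk_size are placeholders, and only the field names visible in the hunks are taken from the diff:

import jsonlines

def collect_sft_records(paths, write_chunk, chunk_size=1000):
    # Illustrative sketch of the chunked accumulate-and-flush pattern shown above.
    data = []
    for path in paths:
        with jsonlines.open(path) as reader:
            for idx, obj in enumerate(reader):
                try:
                    data.append({
                        'history': obj.get('history', ''),
                        'q': obj.get('input', '') + obj.get('q', ''),
                        # further fields as in the repo's sft_process (not shown in the hunks)
                    })
                    if len(data) >= chunk_size:
                        write_chunk(data)
                        data = []  # start a fresh chunk
                except jsonlines.InvalidLineError as e:
                    print(f"Skipping invalid JSON line {idx + 1}: {e}")
                    continue
    if data:
        write_chunk(data)  # flush the final partial chunk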
@@ -191,7 +199,7 @@ def rl_process():
 if __name__ == "__main__":
-    tokenizer = AutoTokenizer.from_pretrained('./model', use_fast=False)
+    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer', use_fast=False)
     print('tokenizer词表大小', len(tokenizer))
     ################
@@ -199,7 +207,7 @@ if __name__ == "__main__":
     # 2: sft
     # 3: RL
     ################
-    process_type = 2
+    process_type = 1
     if process_type == 1:
         pretrain_process()
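The last two hunks move the tokenizer to ./model/minimind_tokenizer and flip the default process_type to 1 (pretraining). For orientation, a sketch of the resulting __main__ dispatch; the module name in the import and the elif branches for types 2 and 3 are inferred from the hunk comments, not shown verbatim in this diff:

from transformers import AutoTokenizer
from data_process import pretrain_process, sft_process, rl_process  # assumed module name

tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer', use_fast=False)
print('tokenizer vocab size:', len(tokenizer))

process_type = 1  # 1: pretrain, 2: sft, 3: RL

if process_type == 1:
    pretrain_process()
elif process_type == 2:
    sft_process(contain_history=False)
elif process_type == 3:
    rl_process()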