update data_process & full_sft

This commit is contained in:
gongjy 2024-09-13 13:32:24 +08:00
parent 41b474e2bf
commit 56c6139896
2 changed files with 28 additions and 20 deletions

View File

@ -109,7 +109,7 @@ def init_model(lm_config):
if model_from == 1:
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth'
ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
model = Transformer(lm_config)
state_dict = torch.load(ckp, map_location=device)
@ -148,8 +148,8 @@ if __name__ == "__main__":
out_dir = 'out'
epochs = 19
gradient_accumulation_steps = 1
batch_size = 48
learning_rate = 2e-8
batch_size = 80
learning_rate = 2e-4
device = 'cuda:0'
dtype = 'bfloat16'
# dtype = 'float16'
@ -175,7 +175,7 @@ if __name__ == "__main__":
model, tokenizer = init_model(lm_config)
# -----init dataloader------
df = pd.read_csv('./dataset/sft_data.csv')
df = pd.read_csv('./dataset/sft_data_single.csv')
df = df.sample(frac=1.0)
train_ds = SFTDataset(df, tokenizer, max_length=max_seq_len)
train_sampler = DistributedSampler(train_ds) if ddp else None

View File

@ -68,13 +68,17 @@ def process_seq_monkey():
doc_ids = []
with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
for idx, obj in enumerate(reader):
content = obj.get('text', '')
if len(content) > 512:
try:
content = obj.get('text', '')
if len(content) > 512:
continue
text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
doc_ids += text_id
if idx % 50000 == 0:
print(f"seq_monkey: [{idx}]")
except UnicodeDecodeError as e:
print(f"Skipping invalid line {idx + 1}: {e}")
continue
text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
doc_ids += text_id
if idx % 50000 == 0:
print(f"seq_monkey: [{idx}]")
arr = np.array(doc_ids, dtype=np.uint16)
with open('./dataset/clean_seq_monkey.bin', 'wb') as f:
@ -150,15 +154,19 @@ def sft_process(contain_history=False):
for path in sft_datasets:
with jsonlines.open(path) as reader:
for idx, obj in enumerate(reader):
data.append({
'history': obj.get('history', ''),
'q': obj.get('input', '') + obj.get('q', ''),
'a': obj.get('output', '') + obj.get('a', '')
})
try:
data.append({
'history': obj.get('history', ''),
'q': obj.get('input', '') + obj.get('q', ''),
'a': obj.get('output', '') + obj.get('a', '')
})
if len(data) >= chunk_size:
process_and_write_data(data)
data = []
if len(data) >= chunk_size:
process_and_write_data(data)
data = []
except jsonlines.InvalidLineError as e:
print(f"Skipping invalid JSON line {idx + 1}: {e}")
continue
if data:
process_and_write_data(data)
@ -191,7 +199,7 @@ def rl_process():
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained('./model', use_fast=False)
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer', use_fast=False)
print('tokenizer词表大小', len(tokenizer))
################
@ -199,7 +207,7 @@ if __name__ == "__main__":
# 2: sft
# 3: RL
################
process_type = 2
process_type = 1
if process_type == 1:
pretrain_process()