update data_process & full_sft
parent 41b474e2bf
commit 56c6139896
full_sft:

```diff
@@ -109,7 +109,7 @@ def init_model(lm_config):
     if model_from == 1:
         moe_path = '_moe' if lm_config.use_moe else ''
-        ckp = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth'
+        ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'

         model = Transformer(lm_config)
         state_dict = torch.load(ckp, map_location=device)
```
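The SFT script now warm-starts from the pretraining checkpoint instead of a previous full_sft checkpoint. For context, a minimal sketch of the restore step around this hunk; the import path and the `_orig_mod.` prefix handling are assumptions for illustration, not necessarily this repo's exact code:

```python
import torch

from model.model import Transformer  # assumed repo import path

def load_pretrain_checkpoint(lm_config, device='cuda:0'):
    """Sketch: build the model and restore the pretrain weights before SFT."""
    moe_path = '_moe' if lm_config.use_moe else ''
    ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'

    model = Transformer(lm_config)
    state_dict = torch.load(ckp, map_location=device)
    # torch.compile() saves keys under an '_orig_mod.' prefix; stripping it is
    # a common safeguard (assumption, not confirmed repo behavior).
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict, strict=False)
    return model.to(device)
```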
```diff
@@ -148,8 +148,8 @@ if __name__ == "__main__":
     out_dir = 'out'
     epochs = 19
     gradient_accumulation_steps = 1
-    batch_size = 48
-    learning_rate = 2e-8
+    batch_size = 80
+    learning_rate = 2e-4
     device = 'cuda:0'
     dtype = 'bfloat16'
     # dtype = 'float16'
```
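Two hyperparameters change: batch_size 48 → 80 and learning_rate 2e-8 → 2e-4 (the old value looks like a typo; 2e-4 is a conventional starting LR for full SFT at this scale). With gradient_accumulation_steps = 1, the effective batch is simply batch_size = 80 sequences per optimizer step. Scripts like this usually decay the LR over the run; a minimal cosine-decay sketch, assuming a hypothetical `get_lr` helper rather than this repo's actual scheduler:

```python
import math

def get_lr(step: int, total_steps: int, lr: float = 2e-4, min_lr: float = 2e-5) -> float:
    """Cosine decay from lr down to min_lr (hypothetical helper, not repo code)."""
    progress = min(step / max(1, total_steps), 1.0)
    return min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Typical use inside the training loop:
#   for param_group in optimizer.param_groups:
#       param_group['lr'] = get_lr(global_step, total_steps)
```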
```diff
@@ -175,7 +175,7 @@ if __name__ == "__main__":

     model, tokenizer = init_model(lm_config)
     # -----init dataloader------
-    df = pd.read_csv('./dataset/sft_data.csv')
+    df = pd.read_csv('./dataset/sft_data_single.csv')
     df = df.sample(frac=1.0)
     train_ds = SFTDataset(df, tokenizer, max_length=max_seq_len)
     train_sampler = DistributedSampler(train_ds) if ddp else None
```
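The single-turn chat CSV (`sft_data_single.csv`, produced by data_process below) replaces the combined `sft_data.csv`. A sketch of how these pieces typically assemble into the training loader; `train_ds`, `train_sampler`, and `batch_size` come from the surrounding script, while `num_workers` and `drop_last` are assumed settings:

```python
from torch.utils.data import DataLoader

# `train_ds` and `train_sampler` are built exactly as in the hunk above.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=(train_sampler is None),  # DistributedSampler already shuffles per rank
    sampler=train_sampler,
    drop_last=True,                   # keep step counts aligned across ranks (assumed)
    num_workers=4,                    # assumed value
)
```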
data_process:

```diff
@@ -68,13 +68,17 @@ def process_seq_monkey():
     doc_ids = []
     with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
         for idx, obj in enumerate(reader):
-            content = obj.get('text', '')
-            if len(content) > 512:
-                continue
-            text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
-            doc_ids += text_id
-            if idx % 50000 == 0:
-                print(f"seq_monkey: [{idx}]")
+            try:
+                content = obj.get('text', '')
+                if len(content) > 512:
+                    continue
+                text_id = tokenizer(f'{bos_token}{content}{eos_token}').data['input_ids']
+                doc_ids += text_id
+                if idx % 50000 == 0:
+                    print(f"seq_monkey: [{idx}]")
+            except UnicodeDecodeError as e:
+                print(f"Skipping invalid line {idx + 1}: {e}")
+                continue

     arr = np.array(doc_ids, dtype=np.uint16)
     with open('./dataset/clean_seq_monkey.bin', 'wb') as f:
```
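The pretraining corpus is packed as a flat stream of `np.uint16` token ids, which works because the tokenizer vocabulary fits in 16 bits. A sketch of reading the packed file back for training; `max_seq_len = 512` is an assumed value for illustration:

```python
import numpy as np

# Memory-map the packed corpus written above; dtype must match the writer's np.uint16.
data = np.memmap('./dataset/clean_seq_monkey.bin', dtype=np.uint16, mode='r')
print('total tokens:', len(data))

# One training sample is max_seq_len + 1 tokens: inputs plus the shifted targets.
max_seq_len = 512  # assumed value
x = data[:max_seq_len].astype(np.int64)       # model input
y = data[1:max_seq_len + 1].astype(np.int64)  # next-token targets
```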
```diff
@@ -150,15 +154,19 @@ def sft_process(contain_history=False):
     for path in sft_datasets:
         with jsonlines.open(path) as reader:
             for idx, obj in enumerate(reader):
-                data.append({
-                    'history': obj.get('history', ''),
-                    'q': obj.get('input', '') + obj.get('q', ''),
-                    'a': obj.get('output', '') + obj.get('a', '')
-                })
-
-                if len(data) >= chunk_size:
-                    process_and_write_data(data)
-                    data = []
+                try:
+                    data.append({
+                        'history': obj.get('history', ''),
+                        'q': obj.get('input', '') + obj.get('q', ''),
+                        'a': obj.get('output', '') + obj.get('a', '')
+                    })
+
+                    if len(data) >= chunk_size:
+                        process_and_write_data(data)
+                        data = []
+                except jsonlines.InvalidLineError as e:
+                    print(f"Skipping invalid JSON line {idx + 1}: {e}")
+                    continue

     if data:
         process_and_write_data(data)
```
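Records are flushed in chunks of `chunk_size` through `process_and_write_data`, so the whole corpus never sits in memory. The actual helper lives elsewhere in data_process; a plausible sketch of what such a chunked writer does, assuming CSV output with `history`/`q`/`a` columns:

```python
import csv
import os

def process_and_write_data(data, out_path='./dataset/sft_data_single.csv'):
    """Hypothetical chunked writer: append one batch of records to a CSV."""
    write_header = not os.path.exists(out_path)
    with open(out_path, 'a', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=['history', 'q', 'a'])
        if write_header:
            writer.writeheader()
        writer.writerows(data)
```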
```diff
@@ -191,7 +199,7 @@ def rl_process():


 if __name__ == "__main__":
-    tokenizer = AutoTokenizer.from_pretrained('./model', use_fast=False)
+    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer', use_fast=False)
     print('tokenizer vocab size:', len(tokenizer))

     ################
```
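The tokenizer now loads from the dedicated `./model/minimind_tokenizer` directory. A quick round-trip sanity check after the path change; the sample string is illustrative:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer', use_fast=False)
print('tokenizer vocab size:', len(tokenizer))

# Encode and decode a sample string; the diff accesses ids the same way
# via tokenizer(...).data['input_ids'].
ids = tokenizer('hello, MiniMind').data['input_ids']
print(ids, '->', tokenizer.decode(ids))
```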
```diff
@@ -199,7 +207,7 @@ if __name__ == "__main__":
     # 2: sft
     # 3: RL
     ################
-    process_type = 2
+    process_type = 1

     if process_type == 1:
         pretrain_process()
```