update data_process & full_sft
commit 56c6139896
parent 41b474e2bf
full_sft:

@@ -109,7 +109,7 @@ def init_model(lm_config):

     if model_from == 1:
         moe_path = '_moe' if lm_config.use_moe else ''
-        ckp = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth'
+        ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'

         model = Transformer(lm_config)
         state_dict = torch.load(ckp, map_location=device)
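This hunk makes full SFT warm-start from the pretrain checkpoint instead of a previous single-chat SFT checkpoint. A minimal sketch of the resulting load path (only the path construction and torch.load appear in the hunk; the load_state_dict call and strict=False are assumptions):

import torch

moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'  # SFT now starts from pretrain weights

model = Transformer(lm_config)                     # Transformer/lm_config come from this repo
state_dict = torch.load(ckp, map_location=device)
model.load_state_dict(state_dict, strict=False)    # strict=False is an assumption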
@@ -148,8 +148,8 @@ if __name__ == "__main__":
     out_dir = 'out'
     epochs = 19
     gradient_accumulation_steps = 1
-    batch_size = 48
-    learning_rate = 2e-8
+    batch_size = 80
+    learning_rate = 2e-4
     device = 'cuda:0'
     dtype = 'bfloat16'
     # dtype = 'float16'
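The substantive fix here is the learning rate: at 2e-8, AdamW updates are far too small to make visible training progress, while 2e-4 is a conventional scale for SFT at this model size; batch_size rises from 48 to 80 alongside it. A sketch of how these constants are typically consumed (the optimizer and scaler setup sit outside this hunk and are assumptions):

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)  # learning_rate = 2e-4
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))     # bfloat16 needs no loss scaling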
@@ -175,7 +175,7 @@ if __name__ == "__main__":

     model, tokenizer = init_model(lm_config)
     # -----init dataloader------
-    df = pd.read_csv('./dataset/sft_data.csv')
+    df = pd.read_csv('./dataset/sft_data_single.csv')
     df = df.sample(frac=1.0)
     train_ds = SFTDataset(df, tokenizer, max_length=max_seq_len)
     train_sampler = DistributedSampler(train_ds) if ddp else None
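df.sample(frac=1.0) shuffles the whole CSV before it is wrapped in SFTDataset, and the DistributedSampler is only created under DDP. How these feed a DataLoader, as a sketch (the DataLoader call itself is outside the hunk; num_workers is an assumption):

from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=(train_sampler is None),  # DistributedSampler handles per-rank shuffling itself
    sampler=train_sampler,
    num_workers=4,                    # assumption
)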
data_process:

@@ -68,6 +68,7 @@ def process_seq_monkey():
     doc_ids = []
     with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
         for idx, obj in enumerate(reader):
+            try:
                 content = obj.get('text', '')
                 if len(content) > 512:
                     continue
@@ -75,6 +76,9 @@ def process_seq_monkey():
                 doc_ids += text_id
                 if idx % 50000 == 0:
                     print(f"seq_monkey: [{idx}]")
+            except UnicodeDecodeError as e:
+                print(f"Skipping invalid line {idx + 1}: {e}")
+                continue

     arr = np.array(doc_ids, dtype=np.uint16)
     with open('./dataset/clean_seq_monkey.bin', 'wb') as f:
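These two hunks wrap each corpus record in try/except so a single undecodable line no longer aborts the whole pass; the collected token ids are then packed as uint16, which is valid while the vocabulary stays below 65536. A self-contained sketch of the pattern (the encoding step and the final write are outside the hunks, so tokenizer.encode and f.write are assumptions; jsonlines' reader.iter(skip_invalid=True) would be an alternative guard):

import jsonlines
import numpy as np

doc_ids = []
with jsonlines.open('./dataset/mobvoi_seq_monkey_general_open_corpus.jsonl') as reader:
    for idx, obj in enumerate(reader):
        try:
            content = obj.get('text', '')
            if len(content) > 512:               # drop overly long documents
                continue
            doc_ids += tokenizer.encode(content)  # tokenizer assumed loaded; real step is outside the hunk
            if idx % 50000 == 0:
                print(f"seq_monkey: [{idx}]")
        except UnicodeDecodeError as e:
            print(f"Skipping invalid line {idx + 1}: {e}")
            continue

arr = np.array(doc_ids, dtype=np.uint16)          # safe only while vocab size < 65536
with open('./dataset/clean_seq_monkey.bin', 'wb') as f:
    f.write(arr.tobytes())                        # assumed write, not shown in the hunk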
@@ -150,6 +154,7 @@ def sft_process(contain_history=False):
     for path in sft_datasets:
         with jsonlines.open(path) as reader:
             for idx, obj in enumerate(reader):
+                try:
                     data.append({
                         'history': obj.get('history', ''),
                         'q': obj.get('input', '') + obj.get('q', ''),
@@ -159,6 +164,9 @@ def sft_process(contain_history=False):
                     if len(data) >= chunk_size:
                         process_and_write_data(data)
                         data = []
+                except jsonlines.InvalidLineError as e:
+                    print(f"Skipping invalid JSON line {idx + 1}: {e}")
+                    continue

     if data:
         process_and_write_data(data)
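Same idea for the SFT sources: records accumulate into chunks of chunk_size, get flushed via process_and_write_data, and malformed JSON lines are skipped instead of killing the run. One caveat: jsonlines raises InvalidLineError from the for statement itself, which sits outside this try block, so reader.iter(skip_invalid=True) is the library-native way to actually skip bad lines. Sketch of the complete loop (the 'a' field is assumed symmetric to 'q'; sft_datasets, chunk_size, and process_and_write_data are defined elsewhere in data_process):

import jsonlines

data = []
for path in sft_datasets:
    with jsonlines.open(path) as reader:
        for idx, obj in enumerate(reader):
            try:
                data.append({
                    'history': obj.get('history', ''),
                    'q': obj.get('input', '') + obj.get('q', ''),
                    'a': obj.get('output', '') + obj.get('a', ''),  # assumed, by symmetry with 'q'
                })
                if len(data) >= chunk_size:
                    process_and_write_data(data)
                    data = []
            except jsonlines.InvalidLineError as e:
                print(f"Skipping invalid JSON line {idx + 1}: {e}")
                continue

if data:
    process_and_write_data(data)  # flush the final partial chunk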
@@ -191,7 +199,7 @@ def rl_process():


 if __name__ == "__main__":
-    tokenizer = AutoTokenizer.from_pretrained('./model', use_fast=False)
+    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer', use_fast=False)
     print('tokenizer词表大小:', len(tokenizer))

     ################
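The tokenizer now loads from its own subdirectory rather than ./model. A quick smoke test (paths are exactly those in the diff; the print label translates the original Chinese '词表大小', i.e. 'vocab size'):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer', use_fast=False)
print('tokenizer vocab size:', len(tokenizer))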
@@ -199,7 +207,7 @@ if __name__ == "__main__":
     # 2: sft
     # 3: RL
     ################
-    process_type = 2
+    process_type = 1

     if process_type == 1:
         pretrain_process()
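process_type now defaults to 1 (pretrain) instead of 2 (sft). Only the first branch is visible in the hunk; the full dispatch implied by the comment block would look like this (the elif branches are assumptions based on the function names in this file):

if process_type == 1:
    pretrain_process()
elif process_type == 2:
    sft_process(contain_history=False)
elif process_type == 3:
    rl_process()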