update config

commit f3f1cc5fac (parent b043ec996b)
gongjy committed 2024-09-15 15:08:04 +08:00
5 changed files with 55 additions and 13 deletions

View File

@@ -99,7 +99,7 @@ def init_model():
     # model init
     model = Transformer(lm_config).to(device)
-    # moe_path = '_moe' if lm_config.use_moe else ''
+    moe_path = '_moe' if lm_config.use_moe else ''
     # ckp = f'{save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
     #
     # state_dict = torch.load(ckp, map_location=device)
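Note: once active, moe_path selects between the dense and MoE checkpoint files. A minimal sketch of the naming scheme, with placeholder values (the LMConfig stand-in and save_dir below are assumptions, not taken from this diff):

# Illustrative only: stand-in config mirroring the lm_config fields used above.
class LMConfig:
    dim = 512
    use_moe = False

lm_config = LMConfig()
save_dir = 'out'  # assumption: the diff leaves the real ckp line commented

moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'{save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
print(ckp)  # out/pretrain_512.pth (out/pretrain_512_moe.pth when use_moe=True)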

View File

@@ -21,7 +21,7 @@ def init_model(lm_config):
     if model_from == 1:
         moe_path = '_moe' if lm_config.use_moe else ''
-        ckp = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth'
+        ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
         model = Transformer(lm_config)
         state_dict = torch.load(ckp, map_location=device)

View File

@@ -108,17 +108,15 @@ def init_model(lm_config):
         return sum(p.numel() for p in model.parameters() if p.requires_grad)

     if model_from == 1:
+        model = Transformer(lm_config)
         moe_path = '_moe' if lm_config.use_moe else ''
         ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
-        model = Transformer(lm_config)
-        # state_dict = torch.load(ckp, map_location=device)
-        #
-        # unwanted_prefix = '_orig_mod.'
-        # for k, v in list(state_dict.items()):
-        #     if k.startswith(unwanted_prefix):
-        #         state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
-        # model.load_state_dict(state_dict, strict=False)
+        state_dict = torch.load(ckp, map_location=device)
+        unwanted_prefix = '_orig_mod.'
+        for k, v in list(state_dict.items()):
+            if k.startswith(unwanted_prefix):
+                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+        model.load_state_dict(state_dict, strict=False)
     else:
         model = AutoModel.from_pretrained('./minimind', trust_remote_code=True)
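For reference, the re-enabled loading logic as a standalone helper. The '_orig_mod.' prefix is what torch.compile() prepends to parameter names when saving a compiled module, so stripping it lets such a checkpoint load into a plain, uncompiled model. A sketch only; load_pretrain_weights is a hypothetical name, not from the repo:

import torch

def load_pretrain_weights(model, ckp, device='cpu'):
    # map_location remaps GPU-saved tensors onto the target device while reading.
    state_dict = torch.load(ckp, map_location=device)
    # torch.compile() saves keys as '_orig_mod.<name>'; rename them back so
    # they match the plain module's parameter names.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    # strict=False tolerates keys missing in either direction.
    model.load_state_dict(state_dict, strict=False)
    return model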
@@ -148,7 +146,7 @@ if __name__ == "__main__":
     out_dir = 'out'
     epochs = 19
     gradient_accumulation_steps = 1
-    batch_size = 80
+    batch_size = 50
     learning_rate = 2e-4
     device = 'cuda:0'
     dtype = 'bfloat16'
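Since gradient_accumulation_steps stays at 1, this change also lowers the effective batch per optimizer step from 80 to 50 samples; a quick check, illustrative only:

batch_size = 50
gradient_accumulation_steps = 1
print(batch_size * gradient_accumulation_steps)  # 50 samples per optimizer step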
@@ -175,7 +173,7 @@ if __name__ == "__main__":
     model, tokenizer = init_model(lm_config)
     # -----init dataloader------
-    df = pd.read_csv('./dataset/sft_data_single.csv')
+    df = pd.read_csv('./dataset/sft_data_multi.csv')
     df = df.sample(frac=1.0)
     train_ds = SFTDataset(df, tokenizer, max_length=max_seq_len)
     train_sampler = DistributedSampler(train_ds) if ddp else None
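The tail of this hunk shuffles the full DataFrame (df.sample(frac=1.0) is a complete reshuffle), wraps it in SFTDataset, and uses a DistributedSampler only under DDP. A hedged sketch of how these pieces typically feed a DataLoader; the DataLoader call itself is not part of this diff and its arguments are assumptions:

from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    sampler=train_sampler,            # DistributedSampler under DDP, else None
    shuffle=(train_sampler is None),  # sampler and shuffle are mutually exclusive
)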

Binary file not shown.

View File

@@ -0,0 +1,44 @@
+{
+  "add_bos_token": false,
+  "add_eos_token": false,
+  "add_prefix_space": true,
+  "added_tokens_decoder": {
+    "0": {
+      "content": "<unk>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "1": {
+      "content": "<s>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "2": {
+      "content": "</s>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    }
+  },
+  "additional_special_tokens": [],
+  "bos_token": "<s>",
+  "clean_up_tokenization_spaces": false,
+  "eos_token": "</s>",
+  "legacy": true,
+  "model_max_length": 1000000000000000019884624838656,
+  "pad_token": null,
+  "sp_model_kwargs": {},
+  "spaces_between_special_tokens": false,
+  "tokenizer_class": "LlamaTokenizer",
+  "unk_token": "<unk>",
+  "use_default_system_prompt": false,
+  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
+}
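The chat_template above renders a leading system message verbatim, wraps each user turn as '<s>user\n...</s>\n' followed immediately by the opening '<s>assistant\n', and closes assistant turns with '</s>\n'. A quick check with transformers; the './minimind' path is assumed here to be the local model directory this config ships with:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('./minimind', trust_remote_code=True)
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "hello"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
# Expected rendering under this template:
# You are a helpful assistant.<s>user\nhello</s>\n<s>assistant\n
print(prompt)

Because the template already opens the assistant header after each user turn, no add_generation_prompt flag is needed to put the model in a position to generate.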