update config
parent b043ec996b
commit f3f1cc5fac
```diff
@@ -99,7 +99,7 @@ def init_model():
     # model init
     model = Transformer(lm_config).to(device)
-    # moe_path = '_moe' if lm_config.use_moe else ''
+    moe_path = '_moe' if lm_config.use_moe else ''
     # ckp = f'{save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
     #
     # state_dict = torch.load(ckp, map_location=device)
```
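The uncommented line restores the `moe_path` suffix, which keeps dense and mixture-of-experts checkpoints from overwriting each other. A standalone sketch of the naming convention; the `save_dir`, `dim`, and `use_moe` values here are illustrative assumptions, not taken from the diff:

```python
# Checkpoint naming convention used above: a '_moe' suffix distinguishes
# MoE checkpoints from dense ones. All values below are assumed.
save_dir = './out'
dim = 512
use_moe = False

moe_path = '_moe' if use_moe else ''
ckp = f'{save_dir}/pretrain_{dim}{moe_path}.pth'
print(ckp)  # ./out/pretrain_512.pth (./out/pretrain_512_moe.pth when use_moe=True)
```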
```diff
@@ -21,7 +21,7 @@ def init_model(lm_config):
     if model_from == 1:
         moe_path = '_moe' if lm_config.use_moe else ''
-        ckp = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth'
+        ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'

         model = Transformer(lm_config)
         state_dict = torch.load(ckp, map_location=device)
```
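The only change in this hunk is the checkpoint path. Worth noting in passing: `torch.load(..., map_location=device)` is what lets a checkpoint saved on one device be restored on another. A self-contained sketch under assumed values; the path and tensor are made up for illustration:

```python
import torch

# Save a toy state dict, then reload it onto the CPU regardless of the
# device it was created on; map_location performs the remapping.
torch.save({'w': torch.randn(2, 2)}, '/tmp/full_sft_demo.pth')  # illustrative path
state_dict = torch.load('/tmp/full_sft_demo.pth', map_location='cpu')
print(state_dict['w'].device)  # cpu
```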
```diff
@@ -108,17 +108,15 @@ def init_model(lm_config):
         return sum(p.numel() for p in model.parameters() if p.requires_grad)

     if model_from == 1:
-        model = Transformer(lm_config)
         moe_path = '_moe' if lm_config.use_moe else ''
         ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'

         model = Transformer(lm_config)
-        # state_dict = torch.load(ckp, map_location=device)
-        #
-        # unwanted_prefix = '_orig_mod.'
-        # for k, v in list(state_dict.items()):
-        #     if k.startswith(unwanted_prefix):
-        #         state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
-        # model.load_state_dict(state_dict, strict=False)
+        state_dict = torch.load(ckp, map_location=device)
+        unwanted_prefix = '_orig_mod.'
+        for k, v in list(state_dict.items()):
+            if k.startswith(unwanted_prefix):
+                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+        model.load_state_dict(state_dict, strict=False)
     else:
         model = AutoModel.from_pretrained('./minimind', trust_remote_code=True)
```
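The uncommented block strips the `_orig_mod.` prefix that `torch.compile` adds to every key when a compiled model's state dict is saved; without the rename, `load_state_dict` on an uncompiled `Transformer` would find no matching keys. The same idea in isolation, with a placeholder tensor so no real checkpoint is needed:

```python
import torch

# State dicts saved from a torch.compile()-wrapped module carry an
# '_orig_mod.' prefix on every key; strip it so the keys match the
# plain module again. The tensor is a stand-in for real weights.
state_dict = {'_orig_mod.tok_embeddings.weight': torch.zeros(4, 2)}

unwanted_prefix = '_orig_mod.'
for k in list(state_dict):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

print(list(state_dict))  # ['tok_embeddings.weight']
```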
```diff
@@ -148,7 +146,7 @@ if __name__ == "__main__":
     out_dir = 'out'
     epochs = 19
     gradient_accumulation_steps = 1
-    batch_size = 80
+    batch_size = 50
     learning_rate = 2e-4
     device = 'cuda:0'
     dtype = 'bfloat16'
```
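Since `gradient_accumulation_steps` stays at 1, dropping `batch_size` from 80 to 50 shrinks the effective batch per optimizer step by the same amount. A quick sanity check of the relationship; the before/after comparison is illustrative:

```python
# Effective batch per optimizer step = per-step batch size x accumulation steps.
gradient_accumulation_steps = 1
for batch_size in (80, 50):  # before / after this commit
    print(batch_size * gradient_accumulation_steps)  # 80, then 50
```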
```diff
@@ -175,7 +173,7 @@ if __name__ == "__main__":

     model, tokenizer = init_model(lm_config)
     # -----init dataloader------
-    df = pd.read_csv('./dataset/sft_data_single.csv')
+    df = pd.read_csv('./dataset/sft_data_multi.csv')
     df = df.sample(frac=1.0)
     train_ds = SFTDataset(df, tokenizer, max_length=max_seq_len)
     train_sampler = DistributedSampler(train_ds) if ddp else None
```
model/mistral_tokenizer/tokenizer.model (new binary file)
Binary file not shown.
model/mistral_tokenizer/tokenizer_config.json (new file, 44 lines)

```json
{
  "add_bos_token": false,
  "add_eos_token": false,
  "add_prefix_space": true,
  "added_tokens_decoder": {
    "0": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "additional_special_tokens": [],
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "legacy": true,
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": null,
  "sp_model_kwargs": {},
  "spaces_between_special_tokens": false,
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": "<unk>",
  "use_default_system_prompt": false,
  "chat_template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<s>user\\n' + content + '</s>\\n<s>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '</s>' + '\\n' }}{% endif %}{% endfor %}"
}
```
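The `chat_template` field is a Jinja template that Hugging Face tokenizers use to render a message list into a single prompt string. A minimal usage sketch, assuming the files from this commit are on disk and a recent `transformers` is installed; the message contents are made up:

```python
from transformers import AutoTokenizer

# Load the tokenizer added in this commit; apply_chat_template renders
# the Jinja chat_template above over a list of role/content messages.
tokenizer = AutoTokenizer.from_pretrained('./model/mistral_tokenizer')

messages = [
    {'role': 'user', 'content': 'hello'},          # made-up turns
    {'role': 'assistant', 'content': 'hi there'},
]
print(tokenizer.apply_chat_template(messages, tokenize=False))
# <s>user
# hello</s>
# <s>assistant
# hi there</s>
```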