update export_model

This commit is contained in:
gongjy 2024-09-15 11:39:33 +08:00
parent 2c22f1bb26
commit 13e791e516

View File

@ -18,7 +18,7 @@ def export_transformers_model():
lm_model = Transformer(lm_config) lm_model = Transformer(lm_config)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
moe_path = '_moe' if lm_config.use_moe else '' moe_path = '_moe' if lm_config.use_moe else ''
ckpt_path = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth' ckpt_path = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
state_dict = torch.load(ckpt_path, map_location=device) state_dict = torch.load(ckpt_path, map_location=device)
unwanted_prefix = '_orig_mod.' unwanted_prefix = '_orig_mod.'
@ -28,13 +28,13 @@ def export_transformers_model():
lm_model.load_state_dict(state_dict, strict=False) lm_model.load_state_dict(state_dict, strict=False)
print(f'模型参数: {count_parameters(lm_model) / 1e6} 百万 = {count_parameters(lm_model) / 1e9} B (Billion)') print(f'模型参数: {count_parameters(lm_model) / 1e6} 百万 = {count_parameters(lm_model) / 1e9} B (Billion)')
lm_model.save_pretrained("minimind-small", safe_serialization=False) lm_model.save_pretrained("minimind-small-T", safe_serialization=False)
def export_tokenizer(): def export_tokenizer():
tokenizer = AutoTokenizer.from_pretrained('./model', tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer',
trust_remote_code=True, use_fast=False) trust_remote_code=True, use_fast=False)
tokenizer.save_pretrained("minimind-small") tokenizer.save_pretrained("minimind-small-T")
def push_to_hf(): def push_to_hf():