update export_model

This commit is contained in:
gongjy 2024-09-15 11:39:33 +08:00
parent 2c22f1bb26
commit 13e791e516

View File

@ -18,7 +18,7 @@ def export_transformers_model():
lm_model = Transformer(lm_config)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
moe_path = '_moe' if lm_config.use_moe else ''
ckpt_path = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth'
ckpt_path = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
state_dict = torch.load(ckpt_path, map_location=device)
unwanted_prefix = '_orig_mod.'
@ -28,13 +28,13 @@ def export_transformers_model():
lm_model.load_state_dict(state_dict, strict=False)
print(f'模型参数: {count_parameters(lm_model) / 1e6} 百万 = {count_parameters(lm_model) / 1e9} B (Billion)')
lm_model.save_pretrained("minimind-small", safe_serialization=False)
lm_model.save_pretrained("minimind-small-T", safe_serialization=False)
def export_tokenizer():
tokenizer = AutoTokenizer.from_pretrained('./model',
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer',
trust_remote_code=True, use_fast=False)
tokenizer.save_pretrained("minimind-small")
tokenizer.save_pretrained("minimind-small-T")
def push_to_hf():