update export_model
This commit is contained in:
parent
2c22f1bb26
commit
13e791e516
@ -18,7 +18,7 @@ def export_transformers_model():
|
||||
lm_model = Transformer(lm_config)
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
moe_path = '_moe' if lm_config.use_moe else ''
|
||||
ckpt_path = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth'
|
||||
ckpt_path = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
|
||||
|
||||
state_dict = torch.load(ckpt_path, map_location=device)
|
||||
unwanted_prefix = '_orig_mod.'
|
||||
@ -28,13 +28,13 @@ def export_transformers_model():
|
||||
lm_model.load_state_dict(state_dict, strict=False)
|
||||
print(f'模型参数: {count_parameters(lm_model) / 1e6} 百万 = {count_parameters(lm_model) / 1e9} B (Billion)')
|
||||
|
||||
lm_model.save_pretrained("minimind-small", safe_serialization=False)
|
||||
lm_model.save_pretrained("minimind-small-T", safe_serialization=False)
|
||||
|
||||
|
||||
def export_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model',
|
||||
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer',
|
||||
trust_remote_code=True, use_fast=False)
|
||||
tokenizer.save_pretrained("minimind-small")
|
||||
tokenizer.save_pretrained("minimind-small-T")
|
||||
|
||||
|
||||
def push_to_hf():
|
||||
|
Loading…
x
Reference in New Issue
Block a user