update export_model
parent 2c22f1bb26
commit 13e791e516
@@ -18,7 +18,7 @@ def export_transformers_model():
     lm_model = Transformer(lm_config)
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     moe_path = '_moe' if lm_config.use_moe else ''
-    ckpt_path = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth'
+    ckpt_path = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'

     state_dict = torch.load(ckpt_path, map_location=device)
     unwanted_prefix = '_orig_mod.'
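The lines elided between the two hunks handle the `_orig_mod.` prefix referenced in the context above. As a hedged sketch only (not the repository's exact code), the usual pattern for cleaning a checkpoint saved from a torch.compile-wrapped model before calling load_state_dict looks like this; the helper name strip_compile_prefix is hypothetical:

def strip_compile_prefix(state_dict, unwanted_prefix='_orig_mod.'):
    # torch.compile wraps the module, so saved keys look like '_orig_mod.layers.0...'.
    # Renaming them back lets the plain (uncompiled) Transformer accept the weights.
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    return state_dict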
@@ -28,13 +28,13 @@ def export_transformers_model():
     lm_model.load_state_dict(state_dict, strict=False)
     print(f'模型参数: {count_parameters(lm_model) / 1e6} 百万 = {count_parameters(lm_model) / 1e9} B (Billion)')

-    lm_model.save_pretrained("minimind-small", safe_serialization=False)
+    lm_model.save_pretrained("minimind-small-T", safe_serialization=False)


 def export_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained('./model',
+    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer',
                                               trust_remote_code=True, use_fast=False)
-    tokenizer.save_pretrained("minimind-small")
+    tokenizer.save_pretrained("minimind-small-T")


 def push_to_hf():
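For context on what these save_pretrained calls produce: both functions now write into ./minimind-small-T (config plus pytorch_model.bin, since safe_serialization=False, plus tokenizer files). A hedged usage sketch of reloading that folder through the transformers Auto* classes follows; whether AutoModelForCausalLM resolves this custom architecture depends on the repo's config/auto_map setup, so treat it as an assumption for illustration, not part of this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('./minimind-small-T', trust_remote_code=True, use_fast=False)
model = AutoModelForCausalLM.from_pretrained('./minimind-small-T', trust_remote_code=True)
print(model.num_parameters() / 1e6, 'M parameters')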
|