update others
This commit is contained in:
parent
48ea6a4cbf
commit
c81c17dab7
@ -131,7 +131,7 @@ if __name__ == "__main__":
|
|||||||
epochs = 20
|
epochs = 20
|
||||||
batch_size = 64
|
batch_size = 64
|
||||||
learning_rate = 2e-4
|
learning_rate = 2e-4
|
||||||
device = 'cuda:0'
|
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||||
dtype = 'bfloat16'
|
dtype = 'bfloat16'
|
||||||
save_dir = os.path.join(out_dir)
|
save_dir = os.path.join(out_dir)
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
@ -110,13 +110,13 @@ def init_model(lm_config):
|
|||||||
if model_from == 1:
|
if model_from == 1:
|
||||||
model = Transformer(lm_config)
|
model = Transformer(lm_config)
|
||||||
moe_path = '_moe' if lm_config.use_moe else ''
|
moe_path = '_moe' if lm_config.use_moe else ''
|
||||||
# ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
|
ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
|
||||||
# state_dict = torch.load(ckp, map_location=device)
|
state_dict = torch.load(ckp, map_location=device)
|
||||||
# unwanted_prefix = '_orig_mod.'
|
unwanted_prefix = '_orig_mod.'
|
||||||
# for k, v in list(state_dict.items()):
|
for k, v in list(state_dict.items()):
|
||||||
# if k.startswith(unwanted_prefix):
|
if k.startswith(unwanted_prefix):
|
||||||
# state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
|
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
|
||||||
# model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
else:
|
else:
|
||||||
model = AutoModel.from_pretrained('./minimind', trust_remote_code=True)
|
model = AutoModel.from_pretrained('./minimind', trust_remote_code=True)
|
||||||
|
|
||||||
@ -148,7 +148,7 @@ if __name__ == "__main__":
|
|||||||
gradient_accumulation_steps = 1
|
gradient_accumulation_steps = 1
|
||||||
batch_size = 40
|
batch_size = 40
|
||||||
learning_rate = 1e-4
|
learning_rate = 1e-4
|
||||||
device = 'cuda:0'
|
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||||
dtype = 'bfloat16'
|
dtype = 'bfloat16'
|
||||||
# dtype = 'float16'
|
# dtype = 'float16'
|
||||||
save_dir = os.path.join(out_dir)
|
save_dir = os.path.join(out_dir)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user