fix dtype bug

This commit is contained in:
gongjy 2024-09-25 10:07:30 +08:00
parent 13105cfa0c
commit 89d260145f
3 changed files with 3 additions and 3 deletions

View File

@ -179,7 +179,7 @@ if __name__ == "__main__":
model = init_model() model = init_model()
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == args.dtype)) scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2: if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:

View File

@ -198,7 +198,7 @@ if __name__ == "__main__":
sampler=train_sampler sampler=train_sampler
) )
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == args.dtype)) scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
if False and not lm_config.use_moe and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2: if False and not lm_config.use_moe and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:

View File

@ -174,7 +174,7 @@ if __name__ == "__main__":
num_workers=args.num_workers, num_workers=args.num_workers,
) )
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate) optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2: if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2: