fix dtype bug

gongjy 2024-09-25 10:07:30 +08:00
parent 13105cfa0c
commit 89d260145f
3 changed files with 3 additions and 3 deletions


@@ -179,7 +179,7 @@ if __name__ == "__main__":
     model = init_model()
-    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == args.dtype))
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
     optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
     if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:


@@ -198,7 +198,7 @@ if __name__ == "__main__":
         sampler=train_sampler
     )
-    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == args.dtype))
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
     optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
     if False and not lm_config.use_moe and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:


@@ -174,7 +174,7 @@ if __name__ == "__main__":
         num_workers=args.num_workers,
     )
-    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
     optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
     if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
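
All three hunks apply the same fix: the GradScaler's `enabled` flag is now derived from whether `args.dtype` names a reduced-precision type, replacing the always-true comparison `args.dtype == args.dtype` in the first two files and the float16-only check in the third. Below is a minimal standalone sketch of the corrected pattern; the helper function `make_scaler` is illustrative and not part of the commit.

import torch

def make_scaler(dtype: str) -> torch.cuda.amp.GradScaler:
    # Old checks: `args.dtype == args.dtype` is always True (scaler enabled even
    # for float32), and `args.dtype == 'float16'` silently skips bfloat16.
    # Corrected check: enable gradient scaling only for half-precision training.
    return torch.cuda.amp.GradScaler(enabled=(dtype in ['float16', 'bfloat16']))

# Usage: the scaler stays disabled for full-precision runs.
scaler_fp32 = make_scaler('float32')   # enabled=False
scaler_fp16 = make_scaler('float16')   # enabled=True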