fix dtype bug
This commit is contained in:
parent
13105cfa0c
commit
89d260145f
@ -179,7 +179,7 @@ if __name__ == "__main__":
|
||||
|
||||
model = init_model()
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == args.dtype))
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
|
||||
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
|
||||
|
@ -198,7 +198,7 @@ if __name__ == "__main__":
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == args.dtype))
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
|
||||
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
if False and not lm_config.use_moe and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
|
||||
|
@ -174,7 +174,7 @@ if __name__ == "__main__":
|
||||
num_workers=args.num_workers,
|
||||
)
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
|
||||
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
|
||||
|
Loading…
x
Reference in New Issue
Block a user