From 89d260145f64ea07fe4cddea8c16e98feea21851 Mon Sep 17 00:00:00 2001
From: gongjy <2474590974@qq.com>
Date: Wed, 25 Sep 2024 10:07:30 +0800
Subject: [PATCH] fix dtype bug

---
 1-pretrain.py | 2 +-
 3-full_sft.py | 2 +-
 4-lora_sft.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/1-pretrain.py b/1-pretrain.py
index ad1a45c..aa8da83 100644
--- a/1-pretrain.py
+++ b/1-pretrain.py
@@ -179,7 +179,7 @@ if __name__ == "__main__":
 
     model = init_model()
 
-    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == args.dtype))
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
     optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
 
     if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
diff --git a/3-full_sft.py b/3-full_sft.py
index d82e662..e50f54e 100644
--- a/3-full_sft.py
+++ b/3-full_sft.py
@@ -198,7 +198,7 @@ if __name__ == "__main__":
         sampler=train_sampler
     )
 
-    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == args.dtype))
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
     optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
 
     if False and not lm_config.use_moe and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
diff --git a/4-lora_sft.py b/4-lora_sft.py
index e72f8ca..f6328f1 100644
--- a/4-lora_sft.py
+++ b/4-lora_sft.py
@@ -174,7 +174,7 @@ if __name__ == "__main__":
         num_workers=args.num_workers,
     )
 
-    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
     optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
 
     if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
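
Note on the fix: the old guard in 1-pretrain.py and 3-full_sft.py, args.dtype == args.dtype, is a tautology and always True, so the GradScaler stayed enabled even for plain float32 runs; 4-lora_sft.py, by contrast, enabled it only for 'float16'. The patch unifies all three scripts on one condition: gradient scaling is on for any reduced-precision dtype and off for float32. Below is a minimal standalone sketch of that gating pattern; the Args class and its value are hypothetical stand-ins for the scripts' argparse namespace, not code from this repo.

import torch

class Args:
    # Hypothetical stand-in for the scripts' argparse namespace;
    # the scripts accept one of 'float32', 'float16', 'bfloat16'.
    dtype = 'float32'

args = Args()

# Old condition (the bug): compares args.dtype to itself, so it is
# always True and the scaler runs even in full float32.
buggy_enabled = (args.dtype == args.dtype)

# Patched condition: scale gradients only in reduced precision.
fixed_enabled = (args.dtype in ['float16', 'bfloat16'])

# GradScaler(enabled=False) becomes a no-op pass-through, so the same
# scale/step/update training loop works unchanged for float32.
scaler = torch.cuda.amp.GradScaler(enabled=fixed_enabled)
print(buggy_enabled, fixed_enabled, scaler.is_enabled())  # True False False

As a general PyTorch note (not specific to this repo): loss scaling exists to keep float16 gradients from underflowing, while bfloat16 shares float32's exponent range and usually doesn't need it; keeping the scaler enabled for bfloat16, as this patch does, is generally harmless since GradScaler still unscales correctly and retains its inf/NaN step-skipping check.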