diff --git a/1-pretrain.py b/1-pretrain.py index 4a0bb29..50fee2a 100644 --- a/1-pretrain.py +++ b/1-pretrain.py @@ -74,7 +74,7 @@ def train_epoch(epoch, wandb, accumulation_steps=8): optimizer.param_groups[-1]['lr'], spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60)) - if (use_wandb is not None) and (not ddp or dist.get_rank() == 0): + if (wandb is not None) and (not ddp or dist.get_rank() == 0): wandb.log({"loss": loss.item() * accumulation_steps, "lr": optimizer.param_groups[-1]['lr'], "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60}) diff --git a/3-full_sft.py b/3-full_sft.py index a2f9b8d..c413de0 100644 --- a/3-full_sft.py +++ b/3-full_sft.py @@ -86,7 +86,7 @@ def train_epoch(epoch, wandb): optimizer.param_groups[-1]['lr'], spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60)) - if (use_wandb is not None) and (not ddp or dist.get_rank() == 0): + if (wandb is not None) and (not ddp or dist.get_rank() == 0): wandb.log({"loss": loss, "lr": optimizer.param_groups[-1]['lr'], "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60}) diff --git a/4-lora_sft.py b/4-lora_sft.py index 2dfd22b..ab8ba31 100644 --- a/4-lora_sft.py +++ b/4-lora_sft.py @@ -73,7 +73,7 @@ def train_epoch(epoch, wandb): optimizer.param_groups[-1]['lr'], spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60)) - if use_wandb is not None: + if wandb is not None: wandb.log({"loss": loss.item(), "lr": optimizer.param_groups[-1]['lr'], "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})