diff --git a/1-pretrain.py b/1-pretrain.py index 8560126..56c937d 100644 --- a/1-pretrain.py +++ b/1-pretrain.py @@ -146,11 +146,6 @@ if __name__ == "__main__": use_wandb = True #是否使用wandb wandb_project = "MiniMind-Pretrain" wandb_run_name = f"MiniMind-Pretrain-Epoch-{epochs}-BatchSize-{batch_size}-LearningRate-{learning_rate}" - if use_wandb: - import wandb - wandb.init(project=wandb_project, name=wandb_run_name) - else: - wandb = None ctx = ( @@ -163,6 +158,12 @@ if __name__ == "__main__": if ddp: init_distributed_mode() device = torch.device(DEVICE) + + if use_wandb and (not ddp or ddp_local_rank == 0): + import wandb + wandb.init(project=wandb_project, name=wandb_run_name) + else: + wandb = None # ----------------------------------------------------------------------------- # -----init dataloader------