diff --git a/1-pretrain.py b/1-pretrain.py
index 8560126..56c937d 100644
--- a/1-pretrain.py
+++ b/1-pretrain.py
@@ -146,11 +146,6 @@ if __name__ == "__main__":
     use_wandb = True #是否使用wandb
     wandb_project = "MiniMind-Pretrain"
     wandb_run_name = f"MiniMind-Pretrain-Epoch-{epochs}-BatchSize-{batch_size}-LearningRate-{learning_rate}"
-    if use_wandb:
-        import wandb
-        wandb.init(project=wandb_project, name=wandb_run_name)
-    else:
-        wandb = None
 
 
     ctx = (
@@ -163,6 +158,12 @@ if __name__ == "__main__":
     if ddp:
         init_distributed_mode()
         device = torch.device(DEVICE)
+
+    if use_wandb and (not ddp or ddp_local_rank == 0):
+        import wandb
+        wandb.init(project=wandb_project, name=wandb_run_name)
+    else:
+        wandb = None
     # -----------------------------------------------------------------------------
 
     # -----init dataloader------