From cb286d26d1dd2143bd05a074cc1fa5688a53a1c9 Mon Sep 17 00:00:00 2001 From: iomgaa Date: Sat, 10 May 2025 20:23:52 +0800 Subject: [PATCH] =?UTF-8?q?wandb=E5=8C=85=E5=90=ABconfig=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train_pretrain.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/train_pretrain.py b/train_pretrain.py index daabc37..a3be32b 100644 --- a/train_pretrain.py +++ b/train_pretrain.py @@ -183,7 +183,7 @@ if __name__ == "__main__": parser.add_argument("--learning_rate", type=float, default=5e-4) parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") #如果GPU可用,则使用GPU,否则使用CPU。 parser.add_argument("--dtype", type=str, default="bfloat16") - parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--use_wandb", default=True, action="store_true") parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain") parser.add_argument("--num_workers", type=int, default=8) parser.add_argument("--ddp", action="store_true") @@ -240,8 +240,12 @@ if __name__ == "__main__": if args.use_wandb and (not ddp or ddp_local_rank == 0): import wandb - - wandb.init(project=args.wandb_project, name=args.wandb_run_name) + + # Merge args and lm_config parameters for wandb config + config = vars(args).copy() + config.update(lm_config.__dict__) + + wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config) else: wandb = None