diff --git a/train_pretrain.py b/train_pretrain.py index daabc37..a3be32b 100644 --- a/train_pretrain.py +++ b/train_pretrain.py @@ -183,7 +183,7 @@ if __name__ == "__main__": parser.add_argument("--learning_rate", type=float, default=5e-4) parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") #如果GPU可用,则使用GPU,否则使用CPU。 parser.add_argument("--dtype", type=str, default="bfloat16") - parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--use_wandb", default=True, action="store_true") parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain") parser.add_argument("--num_workers", type=int, default=8) parser.add_argument("--ddp", action="store_true") @@ -240,8 +240,12 @@ if __name__ == "__main__": if args.use_wandb and (not ddp or ddp_local_rank == 0): import wandb - - wandb.init(project=args.wandb_project, name=args.wandb_run_name) + + # Merge args and lm_config parameters for wandb config + config = vars(args).copy() + config.update(lm_config.__dict__) + + wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config) else: wandb = None