wandb包含config信息
This commit is contained in:
parent
0c8c6e5d1a
commit
cb286d26d1
@ -183,7 +183,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--learning_rate", type=float, default=5e-4)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") #如果GPU可用,则使用GPU,否则使用CPU。
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--use_wandb", default=True, action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument("--ddp", action="store_true")
|
||||
@ -240,8 +240,12 @@ if __name__ == "__main__":
|
||||
|
||||
if args.use_wandb and (not ddp or ddp_local_rank == 0):
|
||||
import wandb
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
|
||||
|
||||
# Merge args and lm_config parameters for wandb config
|
||||
config = vars(args).copy()
|
||||
config.update(lm_config.__dict__)
|
||||
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config)
|
||||
else:
|
||||
wandb = None
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user