diff --git a/train_pretrain.py b/train_pretrain.py index 1c9995b..4f6a4d0 100644 --- a/train_pretrain.py +++ b/train_pretrain.py @@ -183,7 +183,7 @@ if __name__ == "__main__": parser.add_argument("--learning_rate", type=float, default=5e-4) parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") #如果GPU可用,则使用GPU,否则使用CPU。 parser.add_argument("--dtype", type=str, default="bfloat16") - parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--use_wandb", default=True, action="store_true") parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain") parser.add_argument("--num_workers", type=int, default=8) parser.add_argument("--ddp", action="store_true") @@ -193,9 +193,9 @@ if __name__ == "__main__": parser.add_argument("--log_interval", type=int, default=100) #日志打印间隔,用于控制日志打印的频率。 parser.add_argument("--save_interval", type=int, default=100) #模型保存间隔,用于控制模型保存的频率。 parser.add_argument('--local_rank', type=int, default=-1) #本地进程编号,用于分布式训练。 - parser.add_argument('--dim', default=768, type=int) #模型维度,用于控制模型的大小。 - parser.add_argument('--n_layers', default=8, type=int) #层数,用于控制模型层数。 - parser.add_argument('--max_seq_len', default=512, type=int) #最大序列长度,用于控制输入序列的最大长度。 + parser.add_argument('--dim', default=1024, type=int) #模型维度,用于控制模型的大小。 + parser.add_argument('--n_layers', default=24, type=int) #层数,用于控制模型层数。 + parser.add_argument('--max_seq_len', default=1024, type=int) #最大序列长度,用于控制输入序列的最大长度。 parser.add_argument('--use_moe', default=False, type=bool) #是否使用MOE,用于控制是否使用MOE。 parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。 parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)") @@ -233,8 +233,13 @@ if __name__ == "__main__": if args.use_wandb and (not ddp or ddp_local_rank == 0): import wandb - - wandb.init(project=args.wandb_project, name=args.wandb_run_name) + + # Merge args and lm_config into a single config dictionary + config = vars(args) + for key, value in vars(lm_config).items(): + config[f"lm_{key}"] = value + + wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config) else: wandb = None