From bed6faa379a24eff9c10d430c2c38fcffc832a6a Mon Sep 17 00:00:00 2001 From: iomgaa Date: Thu, 8 May 2025 15:47:00 +0000 Subject: [PATCH] =?UTF-8?q?DynamicKV-LLM=201.0.1=20=E4=BA=A4=E5=8F=89?= =?UTF-8?q?=E6=B3=A8=E6=84=8F=E5=8A=9B=E6=B7=BB=E5=8A=A0=E5=A4=9A=E5=A4=B4?= =?UTF-8?q?=EF=BC=9Bbf16=E4=BB=A3=E6=9B=BFfp16?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train_pretrain.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/train_pretrain.py b/train_pretrain.py index 1c9995b..4f6a4d0 100644 --- a/train_pretrain.py +++ b/train_pretrain.py @@ -183,7 +183,7 @@ if __name__ == "__main__": parser.add_argument("--learning_rate", type=float, default=5e-4) parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") #如果GPU可用,则使用GPU,否则使用CPU。 parser.add_argument("--dtype", type=str, default="bfloat16") - parser.add_argument("--use_wandb", default=False, action="store_true") + parser.add_argument("--use_wandb", default=True, action="store_true") parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain") parser.add_argument("--num_workers", type=int, default=8) parser.add_argument("--ddp", action="store_true") @@ -193,9 +193,9 @@ if __name__ == "__main__": parser.add_argument("--log_interval", type=int, default=100) #日志打印间隔,用于控制日志打印的频率。 parser.add_argument("--save_interval", type=int, default=100) #模型保存间隔,用于控制模型保存的频率。 parser.add_argument('--local_rank', type=int, default=-1) #本地进程编号,用于分布式训练。 - parser.add_argument('--dim', default=768, type=int) #模型维度,用于控制模型的大小。 - parser.add_argument('--n_layers', default=8, type=int) #层数,用于控制模型层数。 - parser.add_argument('--max_seq_len', default=512, type=int) #最大序列长度,用于控制输入序列的最大长度。 + parser.add_argument('--dim', default=1024, type=int) #模型维度,用于控制模型的大小。 + parser.add_argument('--n_layers', default=24, type=int) #层数,用于控制模型层数。 + parser.add_argument('--max_seq_len', default=1024, type=int) #最大序列长度,用于控制输入序列的最大长度。 parser.add_argument('--use_moe', default=False, type=bool) #是否使用MOE,用于控制是否使用MOE。 parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。 parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)") @@ -233,8 +233,13 @@ if __name__ == "__main__": if args.use_wandb and (not ddp or ddp_local_rank == 0): import wandb - - wandb.init(project=args.wandb_project, name=args.wandb_run_name) + + # Merge args and lm_config into a single config dictionary + config = vars(args) + for key, value in vars(lm_config).items(): + config[f"lm_{key}"] = value + + wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config) else: wandb = None