diff --git a/train_pretrain.py b/train_pretrain.py index da25834..c4c86e7 100644 --- a/train_pretrain.py +++ b/train_pretrain.py @@ -206,7 +206,7 @@ if __name__ == "__main__": parser.add_argument("--grad_clip", type=float, default=1.0) #梯度裁剪阈值,用于防止梯度爆炸。 parser.add_argument("--warmup_iters", type=int, default=0) #预热迭代次数,用于控制学习率预热过程。 parser.add_argument("--log_interval", type=int, default=100) #日志打印间隔,用于控制日志打印的频率。 - parser.add_argument("--save_interval", type=int, default=100) #模型保存间隔,用于控制模型保存的频率。 + parser.add_argument("--save_interval", type=int, default=10000) #模型保存间隔,用于控制模型保存的频率。 parser.add_argument('--local_rank', type=int, default=-1) #本地进程编号,用于分布式训练。 parser.add_argument('--dim', default=2048, type=int) #模型维度,用于控制模型的大小。 parser.add_argument('--n_layers', default=32, type=int) #层数,用于控制模型层数。