diff --git a/1-pretrain.py b/1-pretrain.py index d51b0e3..219ca3e 100644 --- a/1-pretrain.py +++ b/1-pretrain.py @@ -147,6 +147,7 @@ if __name__ == "__main__": parser.add_argument("--warmup_iters", type=int, default=0, help="Number of warmup iterations") parser.add_argument("--log_interval", type=int, default=100, help="Logging interval") parser.add_argument("--save_interval", type=int, default=1000, help="Model saving interval") + parser.add_argument('--local_rank', type=int, default=-1, help='local rank for distributed training') args = parser.parse_args() diff --git a/3-full_sft.py b/3-full_sft.py index 551043b..6bee3c1 100644 --- a/3-full_sft.py +++ b/3-full_sft.py @@ -155,6 +155,7 @@ if __name__ == "__main__": parser.add_argument("--warmup_iters", type=int, default=0, help="Number of warmup iterations") parser.add_argument("--log_interval", type=int, default=100, help="Logging interval") parser.add_argument("--save_interval", type=int, default=1000, help="Model saving interval") + parser.add_argument('--local_rank', type=int, default=-1, help='local rank for distributed training') args = parser.parse_args()