diff --git a/1-pretrain.py b/1-pretrain.py index 219ca3e..175c655 100644 --- a/1-pretrain.py +++ b/1-pretrain.py @@ -139,7 +139,7 @@ if __name__ == "__main__": parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type") parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases") parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="Weights & Biases project name") - parser.add_argument("--num_workers", type=int, default=8, help="Number of workers for data loading") + parser.add_argument("--num_workers", type=int, default=1, help="Number of workers for data loading") parser.add_argument("--data_path", type=str, default="./dataset/pretrain_data.csv", help="Path to training data") parser.add_argument("--ddp", action="store_true", help="Use DistributedDataParallel") parser.add_argument("--accumulation_steps", type=int, default=8, help="Gradient accumulation steps") diff --git a/3-full_sft.py b/3-full_sft.py index fdf673a..de464f9 100644 --- a/3-full_sft.py +++ b/3-full_sft.py @@ -148,7 +148,7 @@ if __name__ == "__main__": parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type") parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases") parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="Weights & Biases project name") - parser.add_argument("--num_workers", type=int, default=8, help="Number of workers for data loading") + parser.add_argument("--num_workers", type=int, default=1, help="Number of workers for data loading") parser.add_argument("--ddp", action="store_true", help="Use DistributedDataParallel") parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps") parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")