From d57037624b57431763969739a27a7aa3ab973f73 Mon Sep 17 00:00:00 2001 From: gongjy <2474590974@qq.com> Date: Wed, 25 Sep 2024 12:35:29 +0800 Subject: [PATCH] update batchsize --- 1-pretrain.py | 2 +- 3-full_sft.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/1-pretrain.py b/1-pretrain.py index aa8da83..6cce4f0 100644 --- a/1-pretrain.py +++ b/1-pretrain.py @@ -122,7 +122,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="MiniMind Pretraining") parser.add_argument("--out_dir", type=str, default="out", help="Output directory") parser.add_argument("--epochs", type=int, default=20, help="Number of epochs") - parser.add_argument("--batch_size", type=int, default=64, help="Batch size") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size") parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate") parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use") parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type") diff --git a/3-full_sft.py b/3-full_sft.py index e50f54e..551043b 100644 --- a/3-full_sft.py +++ b/3-full_sft.py @@ -142,7 +142,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="MiniMind Full SFT") parser.add_argument("--out_dir", type=str, default="out", help="Output directory") parser.add_argument("--epochs", type=int, default=19, help="Number of epochs") - parser.add_argument("--batch_size", type=int, default=40, help="Batch size") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size") parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate") parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use") parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")