diff --git a/train_full_sft.py b/train_full_sft.py index 859dafc..5185861 100644 --- a/train_full_sft.py +++ b/train_full_sft.py @@ -122,7 +122,7 @@ def init_distributed_mode(): if __name__ == "__main__": parser = argparse.ArgumentParser(description="MiniMind Full SFT") parser.add_argument("--out_dir", type=str, default="out") - parser.add_argument("--epochs", type=int, default=6) + parser.add_argument("--epochs", type=int, default=1) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--learning_rate", type=float, default=5e-5) parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")