diff --git a/3-full_sft.py b/3-full_sft.py index 6bee3c1..fdf673a 100644 --- a/3-full_sft.py +++ b/3-full_sft.py @@ -14,7 +14,7 @@ from contextlib import nullcontext from torch import optim from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader, DistributedSampler -from transformers import AutoTokenizer, AutoModel +from transformers import AutoTokenizer, AutoModelForCausalLM from model.model import Transformer from model.LMConfig import LMConfig from model.dataset import SFTDataset @@ -118,7 +118,7 @@ def init_model(): state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) model.load_state_dict(state_dict, strict=False) else: - model = AutoModel.from_pretrained('./minimind', trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained('./minimind-v1-small', trust_remote_code=True) Logger(f'LLM总参数量:{count_parameters(model) / 1e6:.3f} 百万') model = model.to(args.device) @@ -143,7 +143,7 @@ if __name__ == "__main__": parser.add_argument("--out_dir", type=str, default="out", help="Output directory") parser.add_argument("--epochs", type=int, default=19, help="Number of epochs") parser.add_argument("--batch_size", type=int, default=32, help="Batch size") - parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate") + parser.add_argument("--learning_rate", type=float, default=1.5e-4, help="Learning rate") parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use") parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type") parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")