update sft

gongjy 2024-10-16 22:58:06 +08:00
parent 6861d1af56
commit f16991d7ec


@@ -14,7 +14,7 @@ from contextlib import nullcontext
 from torch import optim
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from model.model import Transformer
 from model.LMConfig import LMConfig
 from model.dataset import SFTDataset
@@ -118,7 +118,7 @@ def init_model():
                 state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
         model.load_state_dict(state_dict, strict=False)
     else:
-        model = AutoModel.from_pretrained('./minimind', trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained('./minimind-v1-small', trust_remote_code=True)
     Logger(f'LLM总参数量{count_parameters(model) / 1e6:.3f} 百万')
     model = model.to(args.device)
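
The two hunks above swap AutoModel for AutoModelForCausalLM and point the pretrained fallback at ./minimind-v1-small, so the loaded checkpoint carries a language-modeling head for SFT. A minimal sketch of the equivalent load, assuming the repo ships custom modeling code picked up via trust_remote_code=True; the tokenizer load and device move are illustrative, not taken from this diff:

# Hedged sketch of the new loading path; tokenizer load and device handling
# are assumptions mirroring what the surrounding script appears to do.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

ckpt = './minimind-v1-small'
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True)
model = model.to('cuda:0' if torch.cuda.is_available() else 'cpu')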
@@ -143,7 +143,7 @@ if __name__ == "__main__":
     parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
     parser.add_argument("--epochs", type=int, default=19, help="Number of epochs")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
-    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
+    parser.add_argument("--learning_rate", type=float, default=1.5e-4, help="Learning rate")
     parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
     parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
     parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")