update sft

This commit is contained in:
gongjy 2024-10-16 22:58:06 +08:00
parent 6861d1af56
commit f16991d7ec

View File

@ -14,7 +14,7 @@ from contextlib import nullcontext
from torch import optim from torch import optim
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import Transformer from model.model import Transformer
from model.LMConfig import LMConfig from model.LMConfig import LMConfig
from model.dataset import SFTDataset from model.dataset import SFTDataset
@ -118,7 +118,7 @@ def init_model():
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
else: else:
model = AutoModel.from_pretrained('./minimind', trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained('./minimind-v1-small', trust_remote_code=True)
Logger(f'LLM总参数量{count_parameters(model) / 1e6:.3f} 百万') Logger(f'LLM总参数量{count_parameters(model) / 1e6:.3f} 百万')
model = model.to(args.device) model = model.to(args.device)
@ -143,7 +143,7 @@ if __name__ == "__main__":
parser.add_argument("--out_dir", type=str, default="out", help="Output directory") parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
parser.add_argument("--epochs", type=int, default=19, help="Number of epochs") parser.add_argument("--epochs", type=int, default=19, help="Number of epochs")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size") parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate") parser.add_argument("--learning_rate", type=float, default=1.5e-4, help="Learning rate")
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use") parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type") parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases") parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")