update sft
Commit: f16991d7ec
Parent: 6861d1af56
@@ -14,7 +14,7 @@ from contextlib import nullcontext
 from torch import optim
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, DistributedSampler
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from model.model import Transformer
 from model.LMConfig import LMConfig
 from model.dataset import SFTDataset
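The import swap is the core of this change: AutoModel.from_pretrained returns only the bare transformer backbone (hidden states, no vocabulary projection), while AutoModelForCausalLM attaches the language-modeling head, so the loaded checkpoint can produce next-token logits and the cross-entropy loss that supervised fine-tuning optimizes. A minimal sketch of the difference; the "gpt2" checkpoint name is purely illustrative and not part of this repository:

from transformers import AutoModel, AutoModelForCausalLM

# Bare backbone: forward() returns hidden states only.
backbone = AutoModel.from_pretrained("gpt2")

# Causal-LM wrapper: adds the LM head, so forward(input_ids, labels=input_ids)
# also returns the next-token loss used during SFT.
lm = AutoModelForCausalLM.from_pretrained("gpt2")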
@@ -118,7 +118,7 @@ def init_model():
                 state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
         model.load_state_dict(state_dict, strict=False)
     else:
-        model = AutoModel.from_pretrained('./minimind', trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained('./minimind-v1-small', trust_remote_code=True)
 
     Logger(f'Total LLM parameters: {count_parameters(model) / 1e6:.3f} million')
     model = model.to(args.device)
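For context, init_model either rebuilds the project's own Transformer from a raw state_dict (stripping the prefix that torch.compile prepends to parameter names) or loads a published Hugging Face checkpoint. A condensed sketch of that branch, assuming a use_local flag, a ckp path, and the '_orig_mod.' prefix value, none of which are shown in this diff:

import torch
from transformers import AutoModelForCausalLM
from model.model import Transformer       # project-defined architecture
from model.LMConfig import LMConfig

def init_model(use_local: bool = True, ckp: str = "./out/pretrain.pth"):
    if use_local:
        model = Transformer(LMConfig())
        state_dict = torch.load(ckp, map_location="cpu")
        unwanted_prefix = "_orig_mod."     # assumed: prefix added by torch.compile
        for k in list(state_dict.keys()):
            if k.startswith(unwanted_prefix):
                state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
        model.load_state_dict(state_dict, strict=False)
    else:
        # Published checkpoint, loaded with its causal-LM head attached.
        model = AutoModelForCausalLM.from_pretrained(
            "./minimind-v1-small", trust_remote_code=True
        )
    return model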
@@ -143,7 +143,7 @@ if __name__ == "__main__":
     parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
     parser.add_argument("--epochs", type=int, default=19, help="Number of epochs")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
-    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
+    parser.add_argument("--learning_rate", type=float, default=1.5e-4, help="Learning rate")
     parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
     parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
     parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
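The only edit in this hunk raises the default SFT learning rate from 1e-4 to 1.5e-4; every argument can still be overridden at launch. A hypothetical invocation (the script name sft.py is an assumption, not taken from this diff):

python sft.py --epochs 19 --batch_size 32 --learning_rate 1.5e-4 --dtype bfloat16 --use_wandb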