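Update SFT: load the pretrained model with AutoModelForCausalLM from ./minimind-v1-small and raise the default learning rate from 1e-4 to 1.5e-4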
This commit is contained in:
parent
6861d1af56
commit
f16991d7ec
@ -14,7 +14,7 @@ from contextlib import nullcontext
|
|||||||
from torch import optim
|
from torch import optim
|
||||||
from torch.nn.parallel import DistributedDataParallel
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
from torch.utils.data import DataLoader, DistributedSampler
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
from transformers import AutoTokenizer, AutoModel
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
from model.model import Transformer
|
from model.model import Transformer
|
||||||
from model.LMConfig import LMConfig
|
from model.LMConfig import LMConfig
|
||||||
from model.dataset import SFTDataset
|
from model.dataset import SFTDataset
|
||||||
@ -118,7 +118,7 @@ def init_model():
|
|||||||
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
|
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
else:
|
else:
|
||||||
model = AutoModel.from_pretrained('./minimind', trust_remote_code=True)
|
model = AutoModelForCausalLM.from_pretrained('./minimind-v1-small', trust_remote_code=True)
|
||||||
|
|
||||||
Logger(f'LLM总参数量:{count_parameters(model) / 1e6:.3f} 百万')
|
Logger(f'LLM总参数量:{count_parameters(model) / 1e6:.3f} 百万')
|
||||||
model = model.to(args.device)
|
model = model.to(args.device)
|
||||||
@ -143,7 +143,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
|
parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
|
||||||
parser.add_argument("--epochs", type=int, default=19, help="Number of epochs")
|
parser.add_argument("--epochs", type=int, default=19, help="Number of epochs")
|
||||||
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
|
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
|
||||||
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
|
parser.add_argument("--learning_rate", type=float, default=1.5e-4, help="Learning rate")
|
||||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
|
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
|
||||||
parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
|
parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
|
||||||
parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
|
parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user