Forward args.n_heads and args.knowledge_dim to the LMConfig built in main().
NOTE(review): the indentation of the hunk lines below was reconstructed from a
whitespace-collapsed copy of this patch; confirm it against the target file
before applying.

diff --git a/train_pretrain_accelerate.py b/train_pretrain_accelerate.py
index 22600cd..117b8b3 100644
--- a/train_pretrain_accelerate.py
+++ b/train_pretrain_accelerate.py
@@ -1421,12 +1421,14 @@ def main():
     lm_config = LMConfig(
         dim=args.dim,
         n_layers=args.n_layers,
+        n_heads=args.n_heads,
         max_seq_len=args.max_seq_len,
         use_moe=args.use_moe,
         disable_db=args.disable_db,
         flash_attn=args.use_flash_attn,
         knowledge_num=args.knowledge_num,
         knowledge_length=args.knowledge_length,
+        knowledge_dim=args.knowledge_dim,
         embeddings_epoch=args.embedding_epoch,
         freeze_ratio=args.freeze_ratio
     )