DynamicKV-LLM 1.0.1 交叉注意力添加多头;bf16代替fp16

This commit is contained in:
iomgaa 2025-05-08 15:47:00 +00:00
parent 10f15724b4
commit bed6faa379

View File

@ -183,7 +183,7 @@ if __name__ == "__main__":
parser.add_argument("--learning_rate", type=float, default=5e-4)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") #如果GPU可用则使用GPU否则使用CPU。
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--use_wandb", default=True, action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--ddp", action="store_true")
@ -193,9 +193,9 @@ if __name__ == "__main__":
parser.add_argument("--log_interval", type=int, default=100) #日志打印间隔,用于控制日志打印的频率。
parser.add_argument("--save_interval", type=int, default=100) #模型保存间隔,用于控制模型保存的频率。
parser.add_argument('--local_rank', type=int, default=-1) #本地进程编号,用于分布式训练。
parser.add_argument('--dim', default=768, type=int) #模型维度,用于控制模型的大小。
parser.add_argument('--n_layers', default=8, type=int) #层数,用于控制模型层数。
parser.add_argument('--max_seq_len', default=512, type=int) #最大序列长度,用于控制输入序列的最大长度。
parser.add_argument('--dim', default=1024, type=int) #模型维度,用于控制模型的大小。
parser.add_argument('--n_layers', default=24, type=int) #层数,用于控制模型层数。
parser.add_argument('--max_seq_len', default=1024, type=int) #最大序列长度,用于控制输入序列的最大长度。
parser.add_argument('--use_moe', default=False, type=bool) #是否使用MOE用于控制是否使用MOE。
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl") #数据路径,用于控制数据集的路径。
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
@ -233,8 +233,13 @@ if __name__ == "__main__":
if args.use_wandb and (not ddp or ddp_local_rank == 0):
import wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name)
# Merge args and lm_config into a single config dictionary
config = vars(args)
for key, value in vars(lm_config).items():
config[f"lm_{key}"] = value
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config)
else:
wandb = None