diff --git a/model/model.py b/model/model.py
index 67699b5..a434ff6 100644
--- a/model/model.py
+++ b/model/model.py
@@ -480,11 +480,11 @@ class MiniMindLM(PreTrainedModel):
                 step: int = 0,
                 **args):
         start_pos = args.get('start_pos', 0)
-        if self.freeze_embedding and step == 0:
-            self.tok_embeddings.weight.requires_grad = False
-            # Removed the knowledge_dataset.freeze_embedding setting so that key updates are controlled by batch_counter
-            # self.knowledge_dataset.freeze_embedding = True
-            print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
+        # if self.freeze_embedding and step == 0:
+        #     self.tok_embeddings.weight.requires_grad = False
+        #     # Removed the knowledge_dataset.freeze_embedding setting so that key updates are controlled by batch_counter
+        #     # self.knowledge_dataset.freeze_embedding = True
+        #     print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
         h = self.dropout(self.tok_embeddings(input_ids))
         pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
         for l, layer in enumerate(self.layers):
diff --git a/train_pretrain_accelerate.py b/train_pretrain_accelerate.py
index aae2e81..eac86a0 100644
--- a/train_pretrain_accelerate.py
+++ b/train_pretrain_accelerate.py
@@ -1,6 +1,6 @@
 import os
-# Set environment variables
-os.environ["WANDB_MODE"] = "offline"  # or use "dryrun"
+# Set environment variables - wandb is replaced by SwanLab
+# os.environ["SWANLAB_MODE"] = "online"  # SwanLab uses online mode
 import platform
 import argparse
 from tqdm import tqdm
@@ -21,6 +21,7 @@ from accelerate.utils import DistributedDataParallelKwargs
 from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
 import numpy as np
 from sklearn.metrics.pairwise import cosine_similarity
+import swanlab  # replaces the wandb import
 
 from model.model import MiniMindLM, RMSNorm
 from model.LMConfig import LMConfig
@@ -218,7 +219,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
     Logger(f'Total LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
     return model, tokenizer
 
-def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
+def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run):
     loss_fct = nn.CrossEntropyLoss(reduction='none')
     epoch_start_time = time.time()
     total_steps_in_epoch = len(train_loader)
@@ -413,8 +414,8 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
                        f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
                        f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
 
-                if args.use_wandb and accelerator.is_main_process and wandb:
-                    wandb.log(log_dict)
+                if args.use_swanlab and accelerator.is_main_process and swanlab_run:
+                    swanlab_run.log(log_dict)
 
             # Save the model (main process only)
             loss_total = loss.item() * args.accumulation_steps
@@ -443,8 +444,8 @@ def main():
     parser.add_argument("--batch_size", type=int, default=64)
     parser.add_argument("--learning_rate", type=float, default=2e-4)
     parser.add_argument("--dtype", type=str, default="bfloat16")
-    parser.add_argument("--use_wandb", default=True, action="store_true")
-    parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
+    parser.add_argument("--use_swanlab", default=True, action="store_true")  # replaces the wandb flag
+    parser.add_argument("--swanlab_project", type=str, default="MiniMind-Pretrain")  # replaces the wandb argument
     parser.add_argument("--num_workers", type=int, default=8)
     parser.add_argument("--accumulation_steps", type=int, default=32)
     parser.add_argument("--grad_clip", type=float, default=1.0)
@@ -456,14 +457,14 @@ def main():
     parser.add_argument('--max_seq_len', default=512, type=int)
    parser.add_argument('--use_moe', default=False, type=bool)
     parser.add_argument('--disable_db', action='store_true', help="Disable the database feature and use a fixed value of 1e-4 instead")
-    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
+    parser.add_argument("--data_path", type=str, default="./dataset/merged_pretrain.jsonl")
     parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
     parser.add_argument("--profile", action="store_true", default=True, help="Enable profiling")
     parser.add_argument("--profile_interval", type=int, default=10, help="Profiling print interval (steps)")
     parser.add_argument("--use_flash_attn", action="store_true", default=True, help="Enable FlashAttention")
     parser.add_argument("--knowledge_num", type=int, default=960400, help="Number of entries in the knowledge base")
     parser.add_argument("--knowledge_length", type=int, default=32, help="Sentence length of knowledge base entries")
-    parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="Database initialization path")
+    parser.add_argument("--database_init_path", type=str, default="./dataset/combined_prepare.json", help="Database initialization path")
     parser.add_argument("--fast_clustering", action="store_true", default=True, help="Use fast approximate clustering (for large datasets)")
     parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="Path to the clustering result cache file")
     parser.add_argument("--recompute_clusters", action="store_true", default=False, help="Force recomputation of clusters, ignoring the cache file")
@@ -479,7 +480,7 @@ def main():
         gradient_accumulation_steps=args.accumulation_steps,
         gradient_clipping=args.grad_clip,
         zero_stage=2,  # use ZeRO-2 optimization
-        offload_optimizer_device="cpu",  # offload optimizer state to CPU
+        offload_optimizer_device="none",  # do not offload optimizer state
         offload_param_device="none",  # do not offload parameters to CPU
     )
     accelerator = Accelerator(
@@ -523,18 +524,30 @@ def main():
 
     #########################################################
-    # Configure wandb
+    # Configure SwanLab
     #########################################################
-    # Set the wandb run name
-    args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
-    if args.use_wandb and accelerator.is_main_process:
-        import wandb
-        # Merge args and lm_config into one dict
-        config_dict = vars(args).copy()
-        config_dict.update(vars(lm_config))
-        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config_dict)
+    # Set the SwanLab run name
+    args.swanlab_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
+
+    # Merge args and lm_config into one dict (needed for printing the config even when SwanLab is disabled)
+    config_dict = vars(args).copy()
+    config_dict.update(vars(lm_config))
+
+    # Initialize the SwanLab experiment instance
+    swanlab_run = None
+    if args.use_swanlab and accelerator.is_main_process:
+        # Initialize SwanLab
+        swanlab_run = swanlab.init(
+            project=args.swanlab_project,
+            experiment_name=args.swanlab_run_name,
+            description="MiniMind pretraining experiment, visualized with a locally deployed SwanLab",
+            config=config_dict
+            # Set the SwanLab server address and API key
+            # host="http://100.123.118.114:11071",
+            # api_key="LesBT7HRq23HNBrOPKP8S"
+        )
     else:
-        wandb = None
+        swanlab_run = None
 
     #########################################################
     # Print info
@@ -616,13 +629,13 @@ def main():
     #########################################################
     overall_start_time = time.time()  # Record overall start time
     for epoch in range(args.epochs):
-        train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb)  # Pass overall start time
+        train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run)  # Pass overall start time
 
     #########################################################
-    # Shut down wandb
+    # Shut down SwanLab
     #########################################################
-    if args.use_wandb and accelerator.is_main_process:
-        wandb.finish()
+    if args.use_swanlab and accelerator.is_main_process and swanlab_run:
+        swanlab_run.finish()
 
 if __name__ == "__main__":
     main()
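
Note on the new flag: `--use_swanlab` keeps the old `default=True, action="store_true"` pattern, so the flag can only turn logging on, never off, from the command line; disabling SwanLab still requires changing the default. For reference, the sketch below distills the three SwanLab calls the patch relies on (init, log on the run object, finish). It is a minimal standalone example, not the training script: the experiment name, metric values, and loop are placeholders, and only the API calls mirror the diff.

# Minimal sketch of the wandb -> SwanLab swap; placeholder values throughout.
import swanlab

swanlab_run = swanlab.init(
    project="MiniMind-Pretrain",   # the --swanlab_project default
    experiment_name="smoke-test",  # hypothetical run name
    config={"batch_size": 64, "learning_rate": 2e-4},
)

for step in range(3):
    # In train_epoch the patch guards this call with:
    # args.use_swanlab and accelerator.is_main_process and swanlab_run
    swanlab_run.log({"loss": 1.0 / (step + 1)})

swanlab_run.finish()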
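The one training-behavior change hidden among the logging edits is the DeepSpeed offload switch: `offload_optimizer_device` goes from "cpu" to "none", so under ZeRO-2 the optimizer state now stays in GPU memory instead of paying CPU<->GPU transfer costs every step, at the price of higher GPU memory use. A minimal sketch of that configuration follows, assuming the accelerate and deepspeed packages are installed; the numeric values come from the argparse defaults above, and the `mixed_precision` setting is illustrative rather than taken from the patch.

# DeepSpeed setup after the patch: ZeRO-2 with no CPU offloading.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

ds_plugin = DeepSpeedPlugin(
    gradient_accumulation_steps=32,   # --accumulation_steps default
    gradient_clipping=1.0,            # --grad_clip default
    zero_stage=2,                     # shard optimizer state and gradients
    offload_optimizer_device="none",  # was "cpu"; keep optimizer state on GPU
    offload_param_device="none",      # unchanged: parameters stay on GPU
)
accelerator = Accelerator(deepspeed_plugin=ds_plugin, mixed_precision="bf16")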