update
This commit is contained in:
parent
83b91859ce
commit
44cd7b4d72
@ -480,11 +480,11 @@ class MiniMindLM(PreTrainedModel):
|
||||
step: int = 0,
|
||||
**args):
|
||||
start_pos = args.get('start_pos', 0)
|
||||
if self.freeze_embedding and step == 0:
|
||||
self.tok_embeddings.weight.requires_grad = False
|
||||
# 移除对knowledge_dataset.freeze_embedding的设置,让键更新由batch_counter控制
|
||||
# self.knowledge_dataset.freeze_embedding = True
|
||||
print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
|
||||
# if self.freeze_embedding and step == 0:
|
||||
# self.tok_embeddings.weight.requires_grad = False
|
||||
# # 移除对knowledge_dataset.freeze_embedding的设置,让键更新由batch_counter控制
|
||||
# # self.knowledge_dataset.freeze_embedding = True
|
||||
# print("tok_embeddings.weight.requires_grad: ", self.tok_embeddings.weight.requires_grad)
|
||||
h = self.dropout(self.tok_embeddings(input_ids))
|
||||
pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]
|
||||
for l, layer in enumerate(self.layers):
|
||||
|
@ -1,6 +1,6 @@
|
||||
import os
|
||||
# 设置环境变量
|
||||
os.environ["WANDB_MODE"] = "offline" # 或者使用 "dryrun"
|
||||
# 设置环境变量 - 将wandb替换为SwanLab
|
||||
# os.environ["SWANLAB_MODE"] = "online" # SwanLab使用在线模式
|
||||
import platform
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
@ -21,6 +21,7 @@ from accelerate.utils import DistributedDataParallelKwargs
|
||||
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
|
||||
import numpy as np
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import swanlab # 替换wandb导入
|
||||
|
||||
from model.model import MiniMindLM, RMSNorm
|
||||
from model.LMConfig import LMConfig
|
||||
@ -218,7 +219,7 @@ def init_model(lm_config, pretrained_embedding_path=None, database_init_path=Non
|
||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
return model, tokenizer
|
||||
|
||||
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
|
||||
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run):
|
||||
loss_fct = nn.CrossEntropyLoss(reduction='none')
|
||||
epoch_start_time = time.time()
|
||||
total_steps_in_epoch = len(train_loader)
|
||||
@ -413,8 +414,8 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
|
||||
f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
|
||||
|
||||
if args.use_wandb and accelerator.is_main_process and wandb:
|
||||
wandb.log(log_dict)
|
||||
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
|
||||
swanlab_run.log(log_dict)
|
||||
|
||||
# 保存模型 (只在主进程进行)
|
||||
loss_total = loss.item() * args.accumulation_steps
|
||||
@ -443,8 +444,8 @@ def main():
|
||||
parser.add_argument("--batch_size", type=int, default=64)
|
||||
parser.add_argument("--learning_rate", type=float, default=2e-4)
|
||||
parser.add_argument("--dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--use_wandb", default=True, action="store_true")
|
||||
parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain")
|
||||
parser.add_argument("--use_swanlab", default=True, action="store_true") # 替换wandb参数
|
||||
parser.add_argument("--swanlab_project", type=str, default="MiniMind-Pretrain") # 替换wandb参数
|
||||
parser.add_argument("--num_workers", type=int, default=8)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=32)
|
||||
parser.add_argument("--grad_clip", type=float, default=1.0)
|
||||
@ -456,14 +457,14 @@ def main():
|
||||
parser.add_argument('--max_seq_len', default=512, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument('--disable_db', action='store_true', help="禁用数据库功能,使用固定值1e-4替代")
|
||||
parser.add_argument("--data_path", type=str, default="./dataset/pretrain_hq.jsonl")
|
||||
parser.add_argument("--data_path", type=str, default="./dataset/merged_pretrain.jsonl")
|
||||
parser.add_argument("--pretrained_embedding_path", type=str, default=None, help="Path to pretrained token embedding weights (.pth file)")
|
||||
parser.add_argument("--profile", action="store_true", default=True, help="启用性能分析")
|
||||
parser.add_argument("--profile_interval", type=int, default=10, help="性能分析打印间隔(步数)")
|
||||
parser.add_argument("--use_flash_attn", action="store_true", default=True, help="启用FlashAttention")
|
||||
parser.add_argument("--knowledge_num", type=int, default=960400,help="知识库的数据数目")
|
||||
parser.add_argument("--knowledge_length", type=int, default=32,help="知识库的句子长度")
|
||||
parser.add_argument("--database_init_path", type=str, default="./dataset/database_init.json", help="数据库初始化路径")
|
||||
parser.add_argument("--database_init_path", type=str, default="./dataset/combined_prepare.json", help="数据库初始化路径")
|
||||
parser.add_argument("--fast_clustering", action="store_true", default=True, help="使用快速近似聚类算法(适用于大数据集)")
|
||||
parser.add_argument("--cluster_cache_path", type=str, default="./cache/cluster_tokens_single.pt", help="聚类结果缓存文件路径")
|
||||
parser.add_argument("--recompute_clusters", action="store_true", default=False, help="强制重新计算聚类,忽略缓存文件")
|
||||
@ -479,7 +480,7 @@ def main():
|
||||
gradient_accumulation_steps=args.accumulation_steps,
|
||||
gradient_clipping=args.grad_clip,
|
||||
zero_stage=2, # 使用ZeRO-2优化
|
||||
offload_optimizer_device="cpu", # 将优化器状态卸载到CPU
|
||||
offload_optimizer_device="none", # do not offload optimizer state to CPU
|
||||
offload_param_device="none", # 不将参数卸载到CPU
|
||||
)
|
||||
accelerator = Accelerator(
|
||||
@ -523,18 +524,30 @@ def main():
|
||||
|
||||
|
||||
#########################################################
|
||||
# 配置wandb
|
||||
# 配置SwanLab
|
||||
#########################################################
|
||||
# 设置wandb运行名称
|
||||
args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
if args.use_wandb and accelerator.is_main_process:
|
||||
import wandb
|
||||
# 合并args和lm_config为一个字典
|
||||
config_dict = vars(args).copy()
|
||||
config_dict.update(vars(lm_config))
|
||||
wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=config_dict)
|
||||
# 设置SwanLab运行名称
|
||||
args.swanlab_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
|
||||
|
||||
# 合并args和lm_config为一个字典(无论是否使用SwanLab都需要,用于打印配置信息)
|
||||
config_dict = vars(args).copy()
|
||||
config_dict.update(vars(lm_config))
|
||||
|
||||
# 初始化SwanLab实验实例
|
||||
swanlab_run = None
|
||||
if args.use_swanlab and accelerator.is_main_process:
|
||||
# 初始化SwanLab
|
||||
swanlab_run = swanlab.init(
|
||||
project=args.swanlab_project,
|
||||
experiment_name=args.swanlab_run_name,
|
||||
description="MiniMind预训练实验,使用本地部署的SwanLab进行可视化",
|
||||
config=config_dict
|
||||
# 设置SwanLab服务器地址和API Key
|
||||
# host="http://100.123.118.114:11071",
|
||||
# api_key="<YOUR_SWANLAB_API_KEY>"  # never commit a real key; read it from an environment variable instead
|
||||
)
|
||||
else:
|
||||
wandb = None
|
||||
swanlab_run = None
|
||||
|
||||
#########################################################
|
||||
# 打印信息
|
||||
@ -616,13 +629,13 @@ def main():
|
||||
#########################################################
|
||||
overall_start_time = time.time() # Record overall start time
|
||||
for epoch in range(args.epochs):
|
||||
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb) # Pass overall start time
|
||||
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, swanlab_run) # Pass overall start time
|
||||
|
||||
#########################################################
|
||||
# 关闭wandb
|
||||
# 关闭SwanLab
|
||||
#########################################################
|
||||
if args.use_wandb and accelerator.is_main_process:
|
||||
wandb.finish()
|
||||
if args.use_swanlab and accelerator.is_main_process and swanlab_run:
|
||||
swanlab_run.finish()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user