DynamicKV-LLM Pretrain v1.1.0

This commit is contained in:
Jax922 2025-05-14 00:42:50 +08:00
parent 089afd6728
commit 5841f8b4e5
3 changed files with 25 additions and 7 deletions

View File

@ -44,4 +44,6 @@ CUDA_VISIBLE_DEVICES=0 accelerate launch \
--max_seq_len 512 \
--use_flash_attn \
--profile \
--profile_interval 10
--profile_interval 10 \
--knowlwdge_num 4096 \
--knowlwdge_length 8

View File

@ -45,4 +45,6 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
--max_seq_len 1024 \
--use_flash_attn \
--profile \
--profile_interval 10
--profile_interval 10 \
--knowlwdge_num 1024 \
--knowlwdge_length 8

View File

@ -55,7 +55,7 @@ def init_model(lm_config, pretrained_embedding_path=None):
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
return model, tokenizer
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time):
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
loss_fct = nn.CrossEntropyLoss(reduction='none')
epoch_start_time = time.time()
total_steps_in_epoch = len(train_loader)
@ -226,13 +226,27 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
last_log_time = current_time # 更新上次日志时间
log_dict = {
"epoch": epoch + 1,
"step": step + 1,
"total_steps_in_epoch": total_steps_in_epoch,
"loss": loss.item() * args.accumulation_steps,
"lr": current_lr,
"tokens_per_sec": tokens_per_sec,
"epoch_time_left_seconds": epoch_remaining_time,
"total_time_left_seconds": total_remaining_time
}
Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
f"Loss: {loss.item()*args.accumulation_steps:.4f}, "
f"LR: {current_lr:.6f}, "
f"Speed: {tokens_per_sec:.2f} tokens/sec | "
f"Loss: {log_dict['loss']:.4f}, "
f"LR: {log_dict['lr']:.6f}, "
f"Speed: {log_dict['tokens_per_sec']:.2f} tokens/sec | "
f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
if args.use_wandb and accelerator.is_main_process and wandb:
wandb.log(log_dict)
# 保存模型 (只在主进程进行)
if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
# 使用函数开始处定义的moe_path变量
@ -425,7 +439,7 @@ def main():
#########################################################
overall_start_time = time.time() # Record overall start time
for epoch in range(args.epochs):
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time) # Pass overall start time
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb) # Pass overall start time
#########################################################
# 关闭wandb