diff --git a/train_pretrain_accelerate.py b/train_pretrain_accelerate.py
index 287ce0c..5faf219 100644
--- a/train_pretrain_accelerate.py
+++ b/train_pretrain_accelerate.py
@@ -55,7 +55,7 @@ def init_model(lm_config, pretrained_embedding_path=None):
     Logger(f'Total LLM parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million')
     return model, tokenizer
 
-def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time):
+def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
     loss_fct = nn.CrossEntropyLoss(reduction='none')
     epoch_start_time = time.time()
     total_steps_in_epoch = len(train_loader)
@@ -226,13 +226,27 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
             tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
             last_log_time = current_time  # update the last-log timestamp
 
+            log_dict = {
+                "epoch": epoch + 1,
+                "step": step + 1,
+                "total_steps_in_epoch": total_steps_in_epoch,
+                "loss": loss.item() * args.accumulation_steps,
+                "lr": current_lr,
+                "tokens_per_sec": tokens_per_sec,
+                "epoch_time_left_seconds": epoch_remaining_time,
+                "total_time_left_seconds": total_remaining_time
+            }
+
             Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
-                   f"Loss: {loss.item()*args.accumulation_steps:.4f}, "
-                   f"LR: {current_lr:.6f}, "
-                   f"Speed: {tokens_per_sec:.2f} tokens/sec | "
+                   f"Loss: {log_dict['loss']:.4f}, "
+                   f"LR: {log_dict['lr']:.6f}, "
+                   f"Speed: {log_dict['tokens_per_sec']:.2f} tokens/sec | "
                    f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
                    f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
 
+            if args.use_wandb and accelerator.is_main_process and wandb:
+                wandb.log(log_dict)
+
         # Save the model (main process only)
         if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
             # Use the moe_path variable defined at the start of the function
@@ -425,7 +439,7 @@ def main():
     #########################################################
     overall_start_time = time.time()  # Record overall start time
     for epoch in range(args.epochs):
-        train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time)  # Pass overall start time
+        train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb)  # Pass overall start time
 
     #########################################################
     # Shut down wandb
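
Note: this patch assumes main() creates a wandb handle (or None) before the epoch loop, since train_epoch now receives it as an argument and gates logging on `args.use_wandb and accelerator.is_main_process and wandb`. A minimal sketch of that assumed wiring follows; the flag names args.wandb_project and args.wandb_run_name are illustrative and not taken from the patch.

# Sketch only: assumed setup in main(), not part of the diff above.
# Initialize wandb on the main process and pass the handle into train_epoch,
# matching the guard used in the logging block.
wandb = None
if args.use_wandb and accelerator.is_main_process:
    import wandb  # imported lazily so runs without logging don't need the package
    wandb.init(project=args.wandb_project,  # illustrative flag name
               name=args.wandb_run_name,    # illustrative flag name
               config=vars(args))

overall_start_time = time.time()
for epoch in range(args.epochs):
    train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler,
                args, ctx, overall_start_time, wandb)

if args.use_wandb and accelerator.is_main_process and wandb:
    wandb.finish()  # the "shut down wandb" step at the end of main()

Passing the module (or None) explicitly keeps train_epoch free of import-time side effects and lets non-main ranks skip logging without touching wandb at all.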