修正了wandb上传错误的bug
This commit is contained in:
parent
089afd6728
commit
7a0ac5a639
@ -55,7 +55,7 @@ def init_model(lm_config, pretrained_embedding_path=None):
|
||||
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
|
||||
return model, tokenizer
|
||||
|
||||
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time):
|
||||
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
|
||||
loss_fct = nn.CrossEntropyLoss(reduction='none')
|
||||
epoch_start_time = time.time()
|
||||
total_steps_in_epoch = len(train_loader)
|
||||
@ -226,13 +226,27 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
|
||||
tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
|
||||
last_log_time = current_time # 更新上次日志时间
|
||||
|
||||
log_dict = {
|
||||
"epoch": epoch + 1,
|
||||
"step": step + 1,
|
||||
"total_steps_in_epoch": total_steps_in_epoch,
|
||||
"loss": loss.item() * args.accumulation_steps,
|
||||
"lr": current_lr,
|
||||
"tokens_per_sec": tokens_per_sec,
|
||||
"epoch_time_left_seconds": epoch_remaining_time,
|
||||
"total_time_left_seconds": total_remaining_time
|
||||
}
|
||||
|
||||
Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
|
||||
f"Loss: {loss.item()*args.accumulation_steps:.4f}, "
|
||||
f"LR: {current_lr:.6f}, "
|
||||
f"Speed: {tokens_per_sec:.2f} tokens/sec | "
|
||||
f"Loss: {log_dict['loss']:.4f}, "
|
||||
f"LR: {log_dict['lr']:.6f}, "
|
||||
f"Speed: {log_dict['tokens_per_sec']:.2f} tokens/sec | "
|
||||
f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
|
||||
f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
|
||||
|
||||
if args.use_wandb and accelerator.is_main_process and wandb:
|
||||
wandb.log(log_dict)
|
||||
|
||||
# 保存模型 (只在主进程进行)
|
||||
if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
|
||||
# 使用函数开始处定义的moe_path变量
|
||||
@ -425,7 +439,7 @@ def main():
|
||||
#########################################################
|
||||
overall_start_time = time.time() # Record overall start time
|
||||
for epoch in range(args.epochs):
|
||||
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time) # Pass overall start time
|
||||
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb) # Pass overall start time
|
||||
|
||||
#########################################################
|
||||
# 关闭wandb
|
||||
|
Loading…
x
Reference in New Issue
Block a user