修正了wandb上传错误的bug

This commit is contained in:
iomgaa 2025-05-14 00:22:36 +08:00
parent 089afd6728
commit 7a0ac5a639

View File

@ -55,7 +55,7 @@ def init_model(lm_config, pretrained_embedding_path=None):
Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万') Logger(f'LLM总参数量{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
return model, tokenizer return model, tokenizer
def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time): def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
loss_fct = nn.CrossEntropyLoss(reduction='none') loss_fct = nn.CrossEntropyLoss(reduction='none')
epoch_start_time = time.time() epoch_start_time = time.time()
total_steps_in_epoch = len(train_loader) total_steps_in_epoch = len(train_loader)
@ -226,13 +226,27 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, a
tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0 tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
last_log_time = current_time # 更新上次日志时间 last_log_time = current_time # 更新上次日志时间
log_dict = {
"epoch": epoch + 1,
"step": step + 1,
"total_steps_in_epoch": total_steps_in_epoch,
"loss": loss.item() * args.accumulation_steps,
"lr": current_lr,
"tokens_per_sec": tokens_per_sec,
"epoch_time_left_seconds": epoch_remaining_time,
"total_time_left_seconds": total_remaining_time
}
Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, " Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
f"Loss: {loss.item()*args.accumulation_steps:.4f}, " f"Loss: {log_dict['loss']:.4f}, "
f"LR: {current_lr:.6f}, " f"LR: {log_dict['lr']:.6f}, "
f"Speed: {tokens_per_sec:.2f} tokens/sec | " f"Speed: {log_dict['tokens_per_sec']:.2f} tokens/sec | "
f"Epoch Time Left: {format_time(epoch_remaining_time)} | " f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
f"Total Time Left: {format_time(total_remaining_time)}", accelerator) f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
if args.use_wandb and accelerator.is_main_process and wandb:
wandb.log(log_dict)
# 保存模型 (只在主进程进行) # 保存模型 (只在主进程进行)
if (step + 1) % args.save_interval == 0 and accelerator.is_main_process: if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
# 使用函数开始处定义的moe_path变量 # 使用函数开始处定义的moe_path变量
@ -425,7 +439,7 @@ def main():
######################################################### #########################################################
overall_start_time = time.time() # Record overall start time overall_start_time = time.time() # Record overall start time
for epoch in range(args.epochs): for epoch in range(args.epochs):
train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time) # Pass overall start time train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb) # Pass overall start time
######################################################### #########################################################
# 关闭wandb # 关闭wandb