Fix a bug that caused wandb upload errors
parent 089afd6728, commit 7a0ac5a639
@@ -55,7 +55,7 @@ def init_model(lm_config, pretrained_embedding_path=None):
     Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
     return model, tokenizer
 
-def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time):
+def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb):
     loss_fct = nn.CrossEntropyLoss(reduction='none')
     epoch_start_time = time.time()
     total_steps_in_epoch = len(train_loader)
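This first hunk is the heart of the fix: `train_epoch` now receives the wandb handle as an explicit parameter instead of assuming a usable global, so callers can pass None whenever logging is disabled or the process is not the main one. A minimal, standalone sketch of that pattern; the `make_run` helper and project name are illustrative assumptions, not code from this repo:

# Sketch of the "pass the wandb handle explicitly, guard every log" pattern.
# make_run and the project name are assumptions; only the guarded-handle
# idea comes from this diff.
def make_run(use_wandb: bool):
    """Return an active wandb run, or None when logging is disabled."""
    if not use_wandb:
        return None
    import wandb  # lazy import: disabled runs never need the package
    return wandb.init(project="demo")  # assumed project name

def train_epoch(epoch: int, steps: int, wandb) -> None:
    for step in range(steps):
        log_dict = {"epoch": epoch + 1, "step": step + 1}
        if wandb:  # falsy when logging is off or on non-main processes
            wandb.log(log_dict)

run = make_run(use_wandb=False)  # flip to True to actually upload
train_epoch(0, steps=3, wandb=run)
if run:
    run.finish()  # flush buffered metrics and close the run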
@@ -226,13 +226,27 @@ def train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time):
             tokens_per_sec = tokens_processed_interval / interval_elapsed_time if interval_elapsed_time > 0 else 0
             last_log_time = current_time  # update the last-log timestamp
 
+            log_dict = {
+                "epoch": epoch + 1,
+                "step": step + 1,
+                "total_steps_in_epoch": total_steps_in_epoch,
+                "loss": loss.item() * args.accumulation_steps,
+                "lr": current_lr,
+                "tokens_per_sec": tokens_per_sec,
+                "epoch_time_left_seconds": epoch_remaining_time,
+                "total_time_left_seconds": total_remaining_time
+            }
+
             Logger(f"Epoch {epoch+1}/{args.epochs}, Step {step+1}/{total_steps_in_epoch}, "
-                   f"Loss: {loss.item()*args.accumulation_steps:.4f}, "
-                   f"LR: {current_lr:.6f}, "
-                   f"Speed: {tokens_per_sec:.2f} tokens/sec | "
+                   f"Loss: {log_dict['loss']:.4f}, "
+                   f"LR: {log_dict['lr']:.6f}, "
+                   f"Speed: {log_dict['tokens_per_sec']:.2f} tokens/sec | "
                    f"Epoch Time Left: {format_time(epoch_remaining_time)} | "
                    f"Total Time Left: {format_time(total_remaining_time)}", accelerator)
 
+            if args.use_wandb and accelerator.is_main_process and wandb:
+                wandb.log(log_dict)
+
         # save the model (main process only)
         if (step + 1) % args.save_interval == 0 and accelerator.is_main_process:
             # use the moe_path variable defined at the top of the function
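One detail worth calling out: the logged loss is `loss.item() * args.accumulation_steps`. The usual reason (an assumption here, since the division happens outside this hunk) is that with gradient accumulation the loss is divided by the accumulation count before `backward()`, so logging multiplies it back to report the true per-batch value. A self-contained sketch with a dummy model; only the scaling pattern mirrors the diff:

# Gradient-accumulation sketch: the per-step loss is scaled down before
# backward(), so the logged value is scaled back up.
import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fct = nn.MSELoss()
accumulation_steps = 4

for step in range(8):
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    loss = loss_fct(model(x), y) / accumulation_steps  # scaled for accumulation
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    # multiply back to recover the unscaled per-batch loss for logging
    print(f"step {step + 1}: loss {loss.item() * accumulation_steps:.4f}")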
@@ -425,7 +439,7 @@ def main():
     #########################################################
     overall_start_time = time.time()  # Record overall start time
     for epoch in range(args.epochs):
-        train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time)  # Pass overall start time
+        train_epoch(epoch, accelerator, model, train_loader, optimizer, scheduler, args, ctx, overall_start_time, wandb)  # Pass overall start time
 
     #########################################################
     # close wandb
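For completeness, the wiring in `main()` that this last hunk implies: initialize wandb only on the main process, thread the handle (or None) through the epoch loop, and finish the run afterwards so buffered metrics flush. A standalone sketch; the project name, the stub `train_epoch`, the lazy import, and the `SimpleNamespace` args are assumptions for illustration:

# Sketch of the main()-side wiring implied by the diff: init on the main
# process only, pass the handle through, close it after training.
from types import SimpleNamespace
from accelerate import Accelerator

def train_epoch(epoch, accelerator, args, wandb):  # stub for the real loop
    if args.use_wandb and accelerator.is_main_process and wandb:
        wandb.log({"epoch": epoch + 1})

def main():
    args = SimpleNamespace(use_wandb=False, epochs=2)  # flip use_wandb to upload
    accelerator = Accelerator()

    wandb = None
    if args.use_wandb and accelerator.is_main_process:
        import wandb  # lazy import (an assumption; keeps wandb optional)
        wandb.init(project="demo")  # assumed project name

    for epoch in range(args.epochs):
        train_epoch(epoch, accelerator, args, wandb)

    # close wandb: flush pending data and end the run
    if args.use_wandb and accelerator.is_main_process and wandb:
        wandb.finish()

if __name__ == "__main__":
    main()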