Minimind/loss.py

33 lines
843 B
Python

import re
import matplotlib.pyplot as plt
# Path of the training log to parse.
log_file = 'out/train.log'
# Steps per epoch; set this according to your actual training log.
# NOTE(review): presumably the "N" in the log's "Step x/N" field - confirm.
steps_per_epoch = 58880

# Read the whole log into memory; the regex below scans it in one pass.
with open(log_file, 'r', encoding='utf-8') as f:
    log_text = f.read()

# Extract (epoch, step, loss) from entries shaped like
#   "Epoch 1/2, Step 100/58880, Loss: 2.345"
# The loss group accepts an optional sign and exponent so that a value such
# as "1.5e-05" is captured whole instead of being cut down to "1.5", and it
# allows at most one decimal point so a stray trailing "." (which float()
# would reject) is never part of the capture. Inner groups are non-capturing,
# so findall() still yields exactly three strings per match.
pattern = re.compile(
    r'Epoch\s+(\d+)/\d+,\s+Step\s+(\d+)/\d+,\s+Loss:\s*'
    r'([-+]?(?:[0-9]+\.?[0-9]*|\.[0-9]+)(?:[eE][-+]?[0-9]+)?)',
    re.MULTILINE,
)
matches = pattern.findall(log_text)
# Collapse each (epoch, step) pair into a single x-axis value. The
# (epoch - 1) offset means epoch 1 contributes no offset, i.e. epochs are
# treated as 1-based.
global_steps = [
    (int(epoch_str) - 1) * steps_per_epoch + int(step_str)
    for epoch_str, step_str, _ in matches
]
# Loss values in the same order as global_steps.
losses = [float(loss_str) for _, _, loss_str in matches]
# Plot loss against global step on an explicit Figure/Axes pair.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(global_steps, losses, label='Loss')
ax.set_xlabel('Global Step')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Curve')
ax.legend()
ax.grid(True)
fig.tight_layout()
# Write the image to disk, then open the interactive window.
fig.savefig('out/loss_curve.png')
plt.show()