import re
import matplotlib.pyplot as plt

# Path of the training log to parse.
log_file = 'out/train.log'
# Optimizer steps in one epoch; used to turn (epoch, step) pairs into a single
# increasing global step. Set this to match your actual training run/log.
steps_per_epoch = 58880
# Where the rendered loss curve is saved.
output_file = 'out/loss_curve.png'

# Matches log lines such as "Epoch 2/10, Step 300/58880, Loss: 0.1234".
# Groups: epoch, step-within-epoch, loss. The loss group accepts an optional
# sign, decimals, scientific notation ("1.5e-05") and nan/inf; a bare
# "[0-9.]+" would truncate "1.5e-05" to "1.5" and skip nan lines entirely.
LOSS_PATTERN = re.compile(
    r'Epoch\s+(\d+)/\d+,\s+Step\s+(\d+)/\d+,\s+Loss:\s*'
    r'([-+]?(?:(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|nan|inf))'
)


def parse_losses(log_text, steps_per_epoch):
    """Extract the loss series from training-log text.

    Args:
        log_text: Full contents of the training log.
        steps_per_epoch: Steps per epoch. Epochs are treated as 1-based, so
            global_step = (epoch - 1) * steps_per_epoch + step.

    Returns:
        A pair ``(global_steps, losses)`` of equal-length lists in the order
        the matching lines appear in the log. Both are empty if nothing matches.
    """
    global_steps = []
    losses = []
    for epoch, step, loss in LOSS_PATTERN.findall(log_text):
        global_steps.append((int(epoch) - 1) * steps_per_epoch + int(step))
        losses.append(float(loss))
    return global_steps, losses


def plot_losses(global_steps, losses, output_file):
    """Plot loss against global step, save it to *output_file*, then show it."""
    plt.figure(figsize=(12, 6))
    plt.plot(global_steps, losses, label='Loss')
    plt.xlabel('Global Step')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(output_file)
    plt.show()


def main():
    """Read ``log_file``, parse the losses and render the loss curve.

    Raises:
        SystemExit: if the log contains no matching loss lines, instead of
            silently saving an empty plot.
    """
    with open(log_file, 'r', encoding='utf-8') as f:
        log_text = f.read()

    global_steps, losses = parse_losses(log_text, steps_per_epoch)
    if not losses:
        raise SystemExit(
            f'No "Epoch x/y, Step a/b, Loss: v" lines found in {log_file}'
        )

    plot_losses(global_steps, losses, output_file)


if __name__ == '__main__':
    main()