33 lines
843 B
Python
33 lines
843 B
Python
|
import re
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
# Path of the training log to parse.
log_file = 'out/train.log'

# Number of steps in one epoch, used to turn (epoch, step) into a global
# step. It has to match the actual training log. Set it to None to use the
# per-epoch total that each log line already reports ("Step x/N") instead.
steps_per_epoch = 58880

# One log entry: "Epoch <e>/<E>, Step <s>/<S>, Loss: <value>".
# Groups: epoch, step, steps-in-epoch total, loss.
# The loss group accepts an optional sign and exponent so that values such
# as "1.2e-05" are read in full rather than truncated to "1.2", and it does
# not swallow trailing punctuation ("0.5." is read as "0.5").
_LOG_ENTRY_RE = re.compile(
    r'Epoch\s+(\d+)/\d+,\s+Step\s+(\d+)/(\d+),\s+Loss:\s*'
    r'([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
)


def parse_loss_log(log_text: str, steps_per_epoch=None) -> tuple:
    """Extract (global step, loss) series from training-log text.

    Args:
        log_text: Full text of the log. Entries are found anywhere in the
            text, in order of appearance.
        steps_per_epoch: Steps per epoch used to compute the global step.
            When None, the total reported in each entry ("Step x/N") is
            used for that entry.

    Returns:
        A ``(global_steps, losses)`` tuple of two equally long lists: ints
        computed as ``(epoch - 1) * steps_per_epoch + step`` (epoch numbers
        are treated as 1-based) and the matching losses as floats. Both
        lists are empty when the text contains no entries.
    """
    global_steps = []
    losses = []
    for entry in _LOG_ENTRY_RE.finditer(log_text):
        epoch = int(entry.group(1))
        step = int(entry.group(2))
        if steps_per_epoch is None:
            per_epoch = int(entry.group(3))
        else:
            per_epoch = steps_per_epoch
        global_steps.append((epoch - 1) * per_epoch + step)
        losses.append(float(entry.group(4)))
    return global_steps, losses


def plot_loss_curve(global_steps, losses, output_path='out/loss_curve.png'):
    """Plot loss against global step, save the figure, then display it.

    Args:
        global_steps: X values (global step numbers).
        losses: Y values (loss at each step), same length as global_steps.
        output_path: File the figure is written to before it is shown.
    """
    plt.figure(figsize=(12, 6))
    plt.plot(global_steps, losses, label='Loss')
    plt.xlabel('Global Step')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    # Save before show(): the figure is written out first, then displayed.
    plt.savefig(output_path)
    plt.show()


def main() -> None:
    """Read ``log_file``, extract the loss series and plot it."""
    with open(log_file, 'r', encoding='utf-8') as f:
        log_text = f.read()

    # Extract epoch, step, loss.
    global_steps, losses = parse_loss_log(log_text, steps_per_epoch)
    if not losses:
        # Fail loudly instead of saving an empty chart.
        raise SystemExit(
            f"No 'Epoch x/y, Step x/y, Loss: z' entries found in {log_file}"
        )

    plot_loss_curve(global_steps, losses, 'out/loss_curve.png')


if __name__ == '__main__':
    main()
|