"""Plot the training-loss curve recorded in a text training log.

The script scans the log for entries of the form
``Epoch <e>/<E>, Step <s>/<S>, Loss: <value>``, converts each
(epoch, step) pair into a single global step index and plots loss
against that index.
"""
import re
from typing import List, Sequence, Tuple

log_file = 'out/train.log'
# Number of steps in one epoch; set this to match the actual training log.
steps_per_epoch = 58880

# Extracts epoch, step and loss from each log entry.  The loss group
# accepts an optional sign, plain decimals and scientific notation
# (e.g. ``1.5e-05``) without swallowing a trailing sentence period.
pattern = re.compile(
    r'Epoch\s+(\d+)/\d+,\s+Step\s+(\d+)/\d+,\s+Loss:\s*'
    r'([-+]?(?:[0-9]+\.?[0-9]*|\.[0-9]+)(?:[eE][-+]?[0-9]+)?)'
)


def parse_loss_log(log_text: str, epoch_size: int) -> Tuple[List[int], List[float]]:
    """Extract (global_step, loss) series from raw log text.

    Args:
        log_text: Full contents of the training log.
        epoch_size: Number of steps per epoch, used to turn an
            (epoch, step) pair into one monotonically growing index.

    Returns:
        A ``(global_steps, losses)`` pair of equally long lists, in the
        order the entries appear in the log.  Both are empty when no
        entry matches.
    """
    global_steps: List[int] = []
    losses: List[float] = []
    for epoch, step, loss in pattern.findall(log_text):
        # Assumes epoch numbers in the log start at 1, so the first
        # epoch contributes no offset.
        global_steps.append((int(epoch) - 1) * epoch_size + int(step))
        losses.append(float(loss))
    return global_steps, losses


def plot_loss_curve(global_steps: Sequence[int], losses: Sequence[float],
                    output_path: str) -> None:
    """Draw the loss curve, save it to *output_path* and display it."""
    # Imported lazily so the parsing helper stays usable (and testable)
    # on machines without matplotlib or a display backend.
    import matplotlib.pyplot as plt

    plt.figure(figsize=(12, 6))
    plt.plot(global_steps, losses, label='Loss')
    plt.xlabel('Global Step')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.show()


def main() -> None:
    """Read the configured log file and plot its loss curve."""
    with open(log_file, 'r', encoding='utf-8') as f:
        log_text = f.read()

    global_steps, losses = parse_loss_log(log_text, steps_per_epoch)
    if not global_steps:
        # An empty figure would hide the fact that nothing was parsed.
        raise SystemExit(
            "No 'Epoch x/y, Step x/y, Loss: v' entries found in " + log_file
        )

    plot_loss_curve(global_steps, losses, 'out/loss_curve.png')


if __name__ == '__main__':
    main()