112 lines
3.6 KiB
Python
112 lines
3.6 KiB
Python
|
import re
|
||
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
|
||
|
def parse_log_file(file_path):
    """
    Parse the training log file to extract epoch, step, and loss information.

    Entries are expected to look like
    ``[<anything>] Epoch 2/10, Step 30/500, Loss: 0.1234``. Text that does
    not match this shape is ignored.

    Args:
        file_path: Path of the UTF-8 encoded log file to read.

    Returns:
        A tuple ``(epochs, steps, losses)`` of three equally long lists
        (ints, ints, floats) in the order the entries appear in the log.
        All three lists are empty when nothing matches, or when the file
        cannot be opened or decoded (an error message is printed then).
    """
    # The loss group accepts exactly one well-formed float, optionally signed
    # and optionally in scientific notation. A looser class such as [\d.]+
    # would also swallow a trailing sentence period ("0.5.") and make float()
    # fail, or cut "3e-05" down to "3".
    entry_re = re.compile(
        r'\[.*?\] Epoch (\d+)/\d+, Step (\d+)/\d+, '
        r'Loss: ([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
    )

    epochs = []
    steps = []
    losses = []

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            # The pattern cannot span a newline, so scanning line by line
            # finds the same entries without loading the whole log at once.
            for line in f:
                for found in entry_re.finditer(line):
                    epochs.append(int(found.group(1)))
                    steps.append(int(found.group(2)))
                    losses.append(float(found.group(3)))
    except (OSError, ValueError) as e:
        # OSError: missing or unreadable file. ValueError: undecodable bytes
        # (UnicodeDecodeError) or an invalid path string.
        print(f"Error parsing log file: {e}")
        return [], [], []

    return epochs, steps, losses
|
||
|
|
||
|
def _continuous_steps(epochs, steps):
    """Shift per-epoch step numbers so the x axis keeps growing across epochs."""
    shifted = []
    offset = 0
    prev_epoch = 0
    for i, (e, s) in enumerate(zip(epochs, steps)):
        if e > prev_epoch:
            # A new epoch starts: continue from the last x value plotted so
            # far. The very first entry has nothing before it to continue from.
            if i > 0:
                offset = shifted[-1]
            prev_epoch = e
        shifted.append(s + offset)
    return shifted


def plot_loss_curve(epochs, steps, losses, output_file='loss_curve.png'):
    """
    Plot the loss curve and save it to a file.

    Step numbers are shifted so the x axis is continuous across epochs. A
    dashed red vertical line marks the first logged entry of each new epoch,
    and each epoch gets a text label near the top of the plot. After saving,
    the extracted data is also printed to stdout as a table.

    Args:
        epochs: Epoch number of each logged entry.
        steps: Step number of each logged entry, as written in the log.
        losses: Loss value of each logged entry.
        output_file: Path the figure is written to.
    """
    continuous_steps = _continuous_steps(epochs, steps)

    # Group entry indices by epoch in one pass instead of rescanning the whole
    # epoch list once per distinct epoch.
    indices_by_epoch = {}
    for i, e in enumerate(epochs):
        indices_by_epoch.setdefault(e, []).append(i)

    fig = plt.figure(figsize=(12, 6))
    try:
        # Thin line and small markers keep dense logs readable.
        plt.plot(continuous_steps, losses, marker='.', linestyle='-',
                 color='#1f77b4', markersize=3, linewidth=0.8)

        plt.title('Training Loss Over Steps', fontsize=16)
        plt.xlabel('Steps (Continuous)', fontsize=14)
        plt.ylabel('Loss', fontsize=14)
        plt.grid(True, linestyle='--', alpha=0.5, linewidth=0.5)

        # Thin red dashed line at the first entry of every new epoch.
        for i in range(1, len(epochs)):
            if epochs[i] > epochs[i - 1]:
                plt.axvline(x=continuous_steps[i], color='r',
                            linestyle='--', alpha=0.5, linewidth=0.8)

        # Add epoch labels at the middle entry of each epoch.
        if indices_by_epoch:
            label_y = max(losses) * 0.95
            for e in sorted(indices_by_epoch):
                indices = indices_by_epoch[e]
                mid_idx = indices[len(indices) // 2]
                plt.text(continuous_steps[mid_idx], label_y, f'Epoch {e}',
                         horizontalalignment='center',
                         verticalalignment='center',
                         fontsize=10,
                         bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 3})

        plt.tight_layout()
        plt.savefig(output_file, dpi=300)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it here to avoid accumulating figures across calls.
        plt.close(fig)
    print(f"Loss curve saved as {output_file}")

    # Also display the data in a table format
    print("\nExtracted training data:")
    print("-" * 50)
    print(f"{'Epoch':<10}{'Step':<10}{'Loss':<15}")
    print("-" * 50)
    for e, s, l in zip(epochs, steps, losses):
        print(f"{e:<10}{s:<10}{l:<15.6f}")
|
||
|
|
||
|
def main(log_file_path='out/train.log', output_file='loss_curve.png'):
    """
    Parse a training log and plot its loss curve.

    Args:
        log_file_path: Path to the training log to read.
        output_file: Path the loss-curve image is written to.
    """
    # Parse the log file
    epochs, steps, losses = parse_log_file(log_file_path)

    # parse_log_file returns empty lists both on read errors and when no line
    # matches, so there is nothing to plot in either case.
    if not (epochs and steps and losses):
        print("No data extracted from log file. Please check if the file format is correct.")
        return

    plot_loss_curve(epochs, steps, losses, output_file=output_file)


if __name__ == "__main__":
    main()
|