Minimind/loss.py

112 lines
3.6 KiB
Python
Raw Normal View History

2025-06-17 13:01:20 +08:00
import re
import matplotlib.pyplot as plt
import numpy as np
def parse_log_file(file_path):
    """
    Parse the training log file to extract epoch, step, and loss information.

    An entry is any text of the form
    ``[<anything>] Epoch <e>/<E>, Step <s>/<S>, Loss: <value>`` found on a
    single line; everything else in the file is ignored.

    Args:
        file_path: Path of the UTF-8 encoded log file to read.

    Returns:
        A tuple ``(epochs, steps, losses)`` of three equally long lists
        (ints, ints, floats) in the order the entries appear in the log.
        All three lists are empty when the file cannot be read or decoded
        (the error is printed) or when no entry matches.
    """
    # The loss is captured as one well-formed decimal, optionally signed and
    # optionally in scientific notation. This keeps "5e-05" from being
    # truncated to "5" and keeps punctuation that directly follows the
    # number (e.g. "Loss: 1.25.") out of the captured value, so float()
    # below cannot fail on a match.
    pattern = re.compile(
        r'\[.*?\] Epoch (\d+)/\d+, Step (\d+)/\d+, '
        r'Loss: ([+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?)'
    )
    epochs = []
    steps = []
    losses = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            # An entry never spans lines ('.' does not match a newline), so
            # the log can be scanned line by line instead of being loaded
            # into memory as a whole.
            for line in f:
                for epoch, step, loss in pattern.findall(line):
                    epochs.append(int(epoch))
                    steps.append(int(step))
                    losses.append(float(loss))
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error parsing log file: {e}")
        return [], [], []
    return epochs, steps, losses
def plot_loss_curve(epochs, steps, losses, output_file='loss_curve.png'):
    """
    Plot the loss curve and save it to a file.

    The x axis uses "continuous" steps: the step numbers of each epoch are
    shifted so that they continue after the previous epoch instead of
    restarting. A dashed red vertical line marks the first logged entry of
    every new epoch, and each epoch gets an ``Epoch <n>`` label placed at the
    x position of its middle logged entry, at 95% of the maximum loss.

    After saving, the extracted data is also printed as a table.

    Args:
        epochs: Epoch number of every logged entry.
        steps: Step number (within its epoch) of every logged entry.
        losses: Loss value of every logged entry.
        output_file: Path the figure is written to (saved at 300 dpi).
    """
    continuous_steps = _continuous_steps(epochs, steps)

    fig = plt.figure(figsize=(12, 6))
    try:
        # Thin line and small markers keep densely logged curves readable.
        plt.plot(continuous_steps, losses, marker='.', linestyle='-',
                 color='#1f77b4', markersize=3, linewidth=0.8)
        plt.title('Training Loss Over Steps', fontsize=16)
        plt.xlabel('Steps (Continuous)', fontsize=14)
        plt.ylabel('Loss', fontsize=14)
        plt.grid(True, linestyle='--', alpha=0.5, linewidth=0.5)

        # Thin red separator wherever the epoch number increases.
        for i in range(1, len(epochs)):
            if epochs[i] > epochs[i - 1]:
                plt.axvline(x=continuous_steps[i], color='r',
                            linestyle='--', alpha=0.5, linewidth=0.8)

        # Group entry indices by epoch in a single pass rather than
        # rescanning the whole epoch list once per epoch.
        indices_by_epoch = {}
        for i, epoch in enumerate(epochs):
            indices_by_epoch.setdefault(epoch, []).append(i)

        # Add epoch labels
        if indices_by_epoch:
            label_y = max(losses) * 0.95  # same height for every label
            for epoch in sorted(indices_by_epoch):
                indices = indices_by_epoch[epoch]
                mid_idx = indices[len(indices) // 2]
                plt.text(continuous_steps[mid_idx], label_y, f'Epoch {epoch}',
                         horizontalalignment='center',
                         verticalalignment='center',
                         fontsize=10,
                         bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 3})

        plt.tight_layout()
        plt.savefig(output_file, dpi=300)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate figures.
        plt.close(fig)
    print(f"Loss curve saved as {output_file}")

    # Also display the data in a table format
    print("\nExtracted training data:")
    print("-" * 50)
    print(f"{'Epoch':<10}{'Step':<10}{'Loss':<15}")
    print("-" * 50)
    for epoch, step, loss in zip(epochs, steps, losses):
        print(f"{epoch:<10}{step:<10}{loss:<15.6f}")


def _continuous_steps(epochs, steps):
    """Shift each epoch's step numbers to continue after the previous epoch."""
    continuous = []
    offset = 0
    highest_epoch = 0
    for epoch, step in zip(epochs, steps):
        if epoch > highest_epoch:
            # A new epoch starts: later steps are counted from the last x
            # position plotted so far (nothing to shift for the first entry).
            if continuous:
                offset = continuous[-1]
            highest_epoch = epoch
        continuous.append(step + offset)
    return continuous
def main(log_file_path='out/train.log'):
    """
    Parse a training log and plot its loss curve.

    Args:
        log_file_path: Path of the log file to read. Defaults to
            ``'out/train.log'``, so calling ``main()`` without arguments
            behaves as before.
    """
    # Parse the log file
    epochs, steps, losses = parse_log_file(log_file_path)

    # Nothing to plot when parsing failed or no entry matched.
    if not (epochs and steps and losses):
        print("No data extracted from log file. Please check if the file format is correct.")
        return

    plot_loss_curve(epochs, steps, losses)
# Run the parse-and-plot pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()