import itertools
import re

import matplotlib.pyplot as plt
import numpy as np

# Matches log entries such as
#   [2024-01-01 12:00:00] Epoch 1/10, Step 100/500, Loss: 0.1234
# The loss group accepts an optional sign, integer/decimal forms (".5", "1.",
# "0.25") and an optional exponent ("1.5e-05"); every string it can capture
# is a valid float() literal.
LOG_LINE_PATTERN = re.compile(
    r'\[.*?\] Epoch (\d+)/\d+, Step (\d+)/\d+, '
    r'Loss: ([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)'
)


def parse_log_file(file_path):
    """Parse a training log file and extract epoch, step and loss values.

    Args:
        file_path: Path of the log file. It is decoded as UTF-8; undecodable
            bytes are replaced rather than aborting the whole parse.

    Returns:
        A tuple ``(epochs, steps, losses)`` of parallel lists (int, int,
        float) in the order the matching entries appear in the file. If the
        file cannot be opened or read, the error is printed and three empty
        lists are returned.
    """
    try:
        with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
            log_content = f.read()
    except OSError as e:
        print(f"Error parsing log file: {e}")
        return [], [], []

    epochs, steps, losses = [], [], []
    for epoch, step, loss in LOG_LINE_PATTERN.findall(log_content):
        epochs.append(int(epoch))
        steps.append(int(step))
        losses.append(float(loss))
    return epochs, steps, losses


def _epoch_segments(epochs):
    """Group entry indices into runs of consecutive entries with the same epoch.

    Returns a list of ``(epoch, indices)`` pairs in log order. A change of
    epoch number in either direction starts a new run.
    """
    return [
        (epoch, [i for i, _ in group])
        for epoch, group in itertools.groupby(enumerate(epochs), key=lambda item: item[1])
    ]


def _continuous_steps(steps, segments):
    """Turn per-epoch step counters into one continuous x-axis.

    Each segment after the first is offset by the last continuous step of the
    preceding segment, so a step counter that restarts at every epoch keeps
    increasing along the plot.
    """
    continuous = []
    offset = 0
    for _, indices in segments:
        if continuous:
            offset = continuous[-1]
        continuous.extend(steps[i] + offset for i in indices)
    return continuous


def _print_table(epochs, steps, losses):
    """Print the extracted records as a fixed-width text table."""
    print("\nExtracted training data:")
    print("-" * 50)
    print(f"{'Epoch':<10}{'Step':<10}{'Loss':<15}")
    print("-" * 50)
    for epoch, step, loss in zip(epochs, steps, losses):
        print(f"{epoch:<10}{step:<10}{loss:<15.6f}")


def plot_loss_curve(epochs, steps, losses, output_file='loss_curve.png'):
    """Plot the loss curve, save it to *output_file*, and print the data.

    Args:
        epochs: Epoch number of each entry.
        steps: Step counter of each entry (may restart every epoch).
        losses: Loss value of each entry.
        output_file: Path of the image to write (saved at 300 dpi).
    """
    segments = _epoch_segments(epochs)
    continuous_steps = _continuous_steps(steps, segments)

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Thin line and small markers keep dense curves readable.
        ax.plot(continuous_steps, losses, marker='.', linestyle='-', color='#1f77b4',
                markersize=3, linewidth=0.8)
        ax.set_title('Training Loss Over Steps', fontsize=16)
        ax.set_xlabel('Steps (Continuous)', fontsize=14)
        ax.set_ylabel('Loss', fontsize=14)
        ax.grid(True, linestyle='--', alpha=0.5, linewidth=0.5)

        # x in data coordinates, y in axes coordinates (0 = bottom, 1 = top):
        # epoch labels stay inside the plot regardless of the loss range.
        label_transform = ax.get_xaxis_transform()
        for seg_no, (epoch, indices) in enumerate(segments):
            if seg_no > 0:
                # Mark where a new epoch begins.
                ax.axvline(x=continuous_steps[indices[0]], color='r', linestyle='--',
                           alpha=0.5, linewidth=0.8)
            mid_idx = indices[len(indices) // 2]
            ax.text(continuous_steps[mid_idx], 0.95, f'Epoch {epoch}',
                    transform=label_transform,
                    horizontalalignment='center', verticalalignment='center',
                    fontsize=10, bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 3})

        fig.tight_layout()
        fig.savefig(output_file, dpi=300)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"Loss curve saved as {output_file}")

    _print_table(epochs, steps, losses)


def main(log_file_path='out/train.log', output_file='loss_curve.png'):
    """Parse *log_file_path* and, if any entries were found, plot them."""
    epochs, steps, losses = parse_log_file(log_file_path)

    if epochs and steps and losses:
        plot_loss_curve(epochs, steps, losses, output_file=output_file)
    else:
        print("No data extracted from log file. Please check if the file format is correct.")


if __name__ == "__main__":
    main()