| | import os |
| | import glob |
| | import pandas as pd |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import re |
| |
|
def extract_run_name(filename):
    """Extract the run name from a TensorBoard CSV export filename.

    Names of the form ``<prefix>_<run>[-loss]_tensorboard.csv`` yield
    ``<run>`` with any trailing ``-loss`` removed. Any other name falls back
    to the second ``_``-separated field of the file stem (extension
    stripped), truncated at its first ``-``. A stem containing no ``_`` is
    returned whole instead of raising.

    Args:
        filename: Path or bare filename of the CSV file.

    Returns:
        The run name as a string.
    """
    basename = os.path.basename(filename)

    # Non-greedy capture: a greedy ``[^_]+`` would swallow "-loss" itself,
    # leaving the optional ``(?:-loss)?`` group permanently empty.
    match = re.search(r'_([^_]+?)(?:-loss)?_tensorboard\.csv$', basename)
    if match:
        return match.group(1)

    # Work on the stem so the file extension does not leak into the label.
    stem = os.path.splitext(basename)[0]
    parts = stem.split('_')
    if len(parts) < 2:
        # Nothing to split on; indexing parts[1] would raise IndexError.
        return stem
    return parts[1].split('-')[0]
| |
|
def setup_plot_style():
    """Configure global matplotlib defaults for publication-style figures.

    Mutates ``plt.rcParams`` in place with a serif font, enlarged label and
    title sizes, a 300 dpi / 10x6 default figure, thicker lines, a dashed
    semi-transparent grid, and no top/right axis spines.
    """
    style = {}

    # Typography.
    style['font.family'] = 'serif'
    style['font.size'] = 12
    style['axes.labelsize'] = 14
    style['axes.titlesize'] = 16
    style['legend.fontsize'] = 10

    # Figure resolution and default size.
    style['figure.dpi'] = 300
    style['figure.figsize'] = (10, 6)

    # Lines and grid.
    style['lines.linewidth'] = 2.5
    style['axes.grid'] = True
    style['grid.linestyle'] = '--'
    style['grid.alpha'] = 0.6

    # Keep only the left and bottom spines.
    style['axes.spines.top'] = False
    style['axes.spines.right'] = False

    plt.rcParams.update(style)
| |
|
def get_metric_label(metric_name):
    """Return the display label used on axes and titles for a metric name.

    A few known metric names map to hand-written labels. Any other name is
    converted by replacing underscores with spaces and title-casing
    (e.g. ``'foo_bar'`` -> ``'Foo Bar'``).
    """
    known_labels = {
        'loss_epoch': 'Loss',
        'perplexityval_epoch': 'Validation Perplexity',
        'topkacc_epoch': 'Top-K Accuracy',
        'acc_trainstep': 'Training Accuracy',
    }
    if metric_name in known_labels:
        return known_labels[metric_name]
    # Fallback: derive a readable label from the raw identifier.
    return metric_name.replace('_', ' ').title()
| |
|
def get_color_mapping(run_names):
    """Assign every run a color from a fixed 10-color palette.

    Run names are sorted before assignment, so the same set of runs always
    receives the same colors regardless of input order. With more than ten
    runs the palette wraps around and colors repeat.

    Args:
        run_names: Iterable of run-name strings.

    Returns:
        dict mapping each run name to a hex color string.
    """
    palette = (
        "#e6194b",  # red
        "#f58231",  # orange
        "#ffe119",  # yellow
        "#bfef45",  # lime
        "#3cb44b",  # green
        "#42d4f4",  # cyan
        "#4363d8",  # blue
        "#911eb4",  # purple
        "#f032e6",  # magenta
        "#a9a9a9",  # grey
    )
    mapping = {}
    for idx, run in enumerate(sorted(run_names)):
        mapping[run] = palette[idx % len(palette)]
    return mapping
| |
|
def plot_metric(metric_dir, color_mapping, output_dir):
    """Plot all runs for one metric on a single figure and save it as PNG.

    Every ``*.csv`` file in ``metric_dir`` is read with pandas. Its ``Step``
    column is plotted against its ``Value`` column, labelled with the run
    name from ``extract_run_name`` and colored via ``color_mapping``
    (``'gray'`` for runs missing from the mapping). Files that fail to load
    or plot are reported and skipped. If no run could be plotted, nothing is
    saved. Otherwise the figure is written to
    ``<output_dir>/<metric_name>_plot.png``, where ``metric_name`` is the
    basename of ``metric_dir``. The figure is always closed, even on error.

    Args:
        metric_dir: Directory containing one CSV file per run for a metric.
        color_mapping: dict mapping run name -> matplotlib color.
        output_dir: Directory the PNG is written to (created if missing).
    """
    metric_name = os.path.basename(metric_dir)
    csv_files = glob.glob(os.path.join(metric_dir, '*.csv'))

    if not csv_files:
        print(f"No CSV files found in {metric_dir}")
        return

    metric_label = get_metric_label(metric_name)
    fig, ax = plt.subplots(figsize=(12, 7))
    try:
        plotted = 0
        for csv_file in sorted(csv_files):
            try:
                df = pd.read_csv(csv_file)
                run_name = extract_run_name(csv_file)
                color = color_mapping.get(run_name, 'gray')
                ax.plot(df['Step'], df['Value'], label=run_name,
                        color=color, alpha=0.9)
            except Exception as e:
                # Best-effort: one malformed file must not abort the figure.
                print(f"Error processing {csv_file}: {e}")
            else:
                plotted += 1

        if not plotted:
            # Avoid writing an empty, legend-less plot.
            print(f"No runs could be plotted for {metric_dir}; skipping")
            return

        ax.set_xlabel('Step')
        ax.set_ylabel(metric_label)

        # NOTE(review): the x data is always the 'Step' column, but the
        # title says "Epoch" for *epoch* metrics -- confirm which is intended.
        comparison = "Epoch" if "epoch" in metric_name else "Step"
        ax.set_title(f'{metric_label} vs. {comparison}', fontweight='bold')

        # Size the legend by the runs actually drawn, not files found.
        ax.legend(loc='best', frameon=True, fancybox=True, framealpha=0.9,
                  shadow=True, borderpad=1, ncol=2 if plotted > 5 else 1)
        ax.grid(True, linestyle='--', alpha=0.7)
        fig.tight_layout()

        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f'{metric_name}_plot.png')
        fig.savefig(output_path, bbox_inches='tight')
        print(f"Saved plot to {output_path}")
    finally:
        # Release the figure even if saving raised, so repeated calls
        # do not accumulate open figures.
        plt.close(fig)
| |
|
def main():
    """Generate one comparison plot per metric directory.

    Reads ``runs_jsons/`` next to this script, where each sub-directory is
    one metric holding per-run CSV files. Writes one PNG per metric to
    ``plots/`` next to this script. A single color mapping is built from the
    run names found across all metric directories, so a run keeps the same
    color in every plot.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.join(script_dir, 'runs_jsons')
    output_dir = os.path.join(script_dir, 'plots')

    if not os.path.isdir(base_dir):
        print(f"Input directory not found: {base_dir}")
        return

    os.makedirs(output_dir, exist_ok=True)
    setup_plot_style()

    # Sorted so plots are produced (and logged) in a deterministic order.
    metric_dirs = sorted(
        d for d in glob.glob(os.path.join(base_dir, '*')) if os.path.isdir(d)
    )
    if not metric_dirs:
        print(f"No metric directories found in {base_dir}")
        return

    # Gather run names across *all* metrics before plotting, so each run is
    # assigned one color shared by every figure.
    all_run_names = set()
    for metric_dir in metric_dirs:
        for csv_file in glob.glob(os.path.join(metric_dir, '*.csv')):
            all_run_names.add(extract_run_name(csv_file))

    color_mapping = get_color_mapping(all_run_names)

    for metric_dir in metric_dirs:
        plot_metric(metric_dir, color_mapping, output_dir)

    print(f"All plots have been generated in {output_dir}")
| |
|
# Script entry point: generate the plots only when run directly, not on import.
if __name__ == '__main__':
    main()
| |
|