Source code for pyserep.visualization.summary_plots

"""
pyserep.visualization.summary_plots
=======================================
Performance dashboard and summary visualisations.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import numpy as np

if TYPE_CHECKING:
    from pyserep.pipeline.serep_pipeline import PipelineResults


def plot_performance_dashboard(
    results: "PipelineResults",
    save_path: Optional[str] = None,
    show: bool = False,
) -> None:
    """
    Render the 6-panel SEREP ROM performance dashboard.

    Panels (2 x 3 grid, filled row by row):

    1. Histogram of the selected mode frequencies, with the configured
       frequency bands shaded
    2. Eigenvalue preservation error versus frequency (log y-axis)
    3. Reduction summary: full DOFs / master DOFs / modes (horizontal bars)
    4. Timing breakdown per pipeline stage (bar chart)
    5. FRF FLOP comparison, ROM vs reference (log-scale bar chart)
    6. Key metrics summary (text box)

    Parameters
    ----------
    results : PipelineResults
        Pipeline output. The attributes read here are ``performance``
        (accessed with ``.get``), ``freqs_hz``, ``selected_modes``,
        ``config.effective_bands``, ``config.n_bands``, ``freq_errors``,
        ``phi`` (only ``shape[0]``), ``master_dofs``, ``kappa``,
        ``frf.errors`` and ``max_freq_err``.
    save_path : str, optional
        When truthy, the figure is written to this path at 150 dpi and a
        confirmation line is printed.
    show : bool
        When True, ``plt.show()`` is called before the figure is closed.

    Notes
    -----
    matplotlib is imported lazily. If it is not available a
    ``RuntimeWarning`` is emitted and the function returns without plotting.
    The figure is always closed, even if drawing or saving raises.
    """
    try:
        import matplotlib.gridspec as gridspec
        import matplotlib.pyplot as plt
    except ImportError:
        # Plotting stays optional, but the caller is told why no figure
        # (and no file at ``save_path``) was produced.
        import warnings

        warnings.warn(
            "matplotlib is not installed; performance dashboard was skipped",
            RuntimeWarning,
            stacklevel=2,
        )
        return

    perf = results.performance
    freqs = results.freqs_hz
    modes = results.selected_modes
    n_full = results.phi.shape[0]
    n_master = len(results.master_dofs)

    fig = plt.figure(figsize=(14, 10))
    try:
        gs = gridspec.GridSpec(2, 3, hspace=0.5, wspace=0.4)
        axes = [fig.add_subplot(gs[i // 3, i % 3]) for i in range(6)]
        fig.suptitle(
            "SEREP ROM — Performance Dashboard", fontsize=12, fontweight="bold"
        )

        _draw_mode_histogram(axes[0], freqs, modes, results.config.effective_bands)
        _draw_eigenvalue_errors(axes[1], freqs, modes, results.freq_errors)
        _draw_reduction_summary(axes[2], n_full, n_master, len(modes))
        _draw_timing_breakdown(axes[3], perf)
        _draw_flop_comparison(axes[4], perf)
        _draw_key_metrics(axes[5], results, n_full, n_master)

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches="tight")
            print(f"[plot] Dashboard saved: {save_path}")
        if show:
            plt.show()
    finally:
        # Release the figure even when a panel or savefig raised.
        plt.close(fig)


def _draw_mode_histogram(ax, freqs, modes, bands) -> None:
    """Panel 1: histogram of selected mode frequencies with bands shaded."""
    ax.hist(freqs[modes], bins=20, color="steelblue", edgecolor="white", lw=0.5)
    if bands:
        for band in bands:
            ax.axvspan(band.f_min, band.f_max, color="lightblue", alpha=0.3)
    ax.set_xlabel("Frequency (Hz)", fontsize=9)
    ax.set_ylabel("Count", fontsize=9)
    ax.set_title(
        f"Selected Modes ({len(modes)} total)", fontsize=9, fontweight="bold"
    )


def _draw_eigenvalue_errors(ax, freqs, modes, freq_errors) -> None:
    """Panel 2: eigenvalue preservation error against sorted mode frequency."""
    if freq_errors is not None and len(freq_errors) > 0:
        f_sel = np.sort(freqs[modes])[: len(freq_errors)]
        # Trim the errors as well so x and y always have equal length.
        errs = np.asarray(freq_errors)[: len(f_sel)]
        # Floor at 1e-12 so exact zeros remain drawable on the log axis.
        ax.semilogy(
            f_sel, np.maximum(errs, 1e-12), "o-", ms=4, color="coral", lw=1.2
        )
        ax.axhline(0.01, color="navy", ls="--", lw=0.8, alpha=0.7, label="0.01%")
        ax.axhline(1e-6, color="green", ls=":", lw=0.8, alpha=0.7, label="10⁻⁶%")
        ax.legend(fontsize=7)
    ax.set_xlabel("Frequency (Hz)", fontsize=9)
    ax.set_ylabel("Error (%)", fontsize=9)
    ax.set_title("Eigenvalue Preservation Error", fontsize=9, fontweight="bold")
    ax.grid(True, which="both", alpha=0.3)


def _draw_reduction_summary(ax, n_full: int, n_master: int, n_modes: int) -> None:
    """Panel 3: horizontal bars comparing full DOFs, master DOFs and modes."""
    labels = ["Full DOFs", "Master DOFs", "Modes\n(of all computed)"]
    values = [n_full, n_master, n_modes]
    colours = ["#e74c3c", "#27ae60", "#3498db"]
    bars = ax.barh(labels, values, color=colours, edgecolor="black", lw=0.5)
    for bar, val in zip(bars, values):
        ax.text(
            bar.get_width() * 1.01,
            bar.get_y() + bar.get_height() / 2,
            f"{val:,}",
            va="center",
            fontsize=8,
        )
    ax.set_xlabel("Count", fontsize=9)
    # Guard the ratio so an empty mode-shape matrix does not raise.
    retention = n_master / n_full * 100 if n_full else 0.0
    ax.set_title(
        f"ROM Reduction — {retention:.4f}% DOF retention",
        fontsize=9,
        fontweight="bold",
    )


def _draw_timing_breakdown(ax, perf) -> None:
    """Panel 4: bar chart of the per-stage timings stored in ``perf``."""
    timing_keys = ["eigensolver", "mode_select", "dof_select", "rom_build", "frf"]
    timing_labels = ["Eigensolver", "Mode\nSelect", "DOF\nSelect", "ROM\nBuild", "FRF"]
    timing_vals = [perf.get(f"t_{k}_s", 0.0) for k in timing_keys]
    colours_t = ["#9b59b6", "#3498db", "#1abc9c", "#e67e22", "#e74c3c"]
    ax.bar(timing_labels, timing_vals, color=colours_t, edgecolor="black", lw=0.5)
    label_offset = max(timing_vals) * 0.01
    for x, val in enumerate(timing_vals):
        # Only annotate stages that took a visible amount of time.
        if val > 0.01:
            ax.text(x, val + label_offset, f"{val:.2f}s", ha="center", fontsize=7)
    ax.set_ylabel("Time (s)", fontsize=9)
    ax.set_title(
        f"Timing Breakdown (total: {perf.get('t_total_s', 0):.2f}s)",
        fontsize=9,
        fontweight="bold",
    )
    ax.grid(True, axis="y", alpha=0.3)


def _draw_flop_comparison(ax, perf) -> None:
    """Panel 5: log-scale bars of FRF FLOPs for the ROM and the reference."""
    flop_rom = perf.get("frf_flops_rom", 0)
    flop_ref = perf.get("frf_flops_ref", 0)
    if flop_ref > 0 and flop_rom > 0:
        ax.bar(
            ["ROM\n(direct)", "Reference\n(modal)"],
            [flop_rom, flop_ref],
            color=["#27ae60", "#e74c3c"],
            edgecolor="black",
            lw=0.5,
        )
        ax.set_yscale("log")
        speedup = flop_ref / flop_rom
        ax.set_title(
            f"FRF FLOPs (speedup: {speedup:.1f}×)", fontsize=9, fontweight="bold"
        )
    else:
        # Without both counts there is nothing to compare; label the empty
        # panel instead of leaving it untitled.
        ax.set_title("FRF FLOPs (not available)", fontsize=9, fontweight="bold")
    ax.set_ylabel("FLOPs (log)", fontsize=9)
    ax.grid(True, axis="y", alpha=0.3)


def _draw_key_metrics(ax, results: "PipelineResults", n_full: int, n_master: int) -> None:
    """Panel 6: text box with condition number, error maxima and sizes."""
    frf_errors = list(results.frf.errors.values())
    # Both figures are the worst case (maximum) over all FRF error entries.
    worst_max_pct = max((e.get("max_pct", 0) for e in frf_errors), default=0)
    worst_rms_pct = max((e.get("rms_pct", 0) for e in frf_errors), default=0)
    kappa_text = (
        f"κ(Φₐ) = {results.kappa:.4e}\n\n"
        f"Max FRF error : {worst_max_pct:.6f}%\n"
        f"RMS FRF error : {worst_rms_pct:.6f}%\n"
        f"Max freq error: {results.max_freq_err:.8f}%\n\n"
        f"Bands : {results.config.n_bands}\n"
        # Separator added: the two counts used to be printed fused together.
        f"DOFs : {n_full:,} → {n_master:,}"
    )
    ax.text(
        0.5,
        0.5,
        kappa_text,
        ha="center",
        va="center",
        transform=ax.transAxes,
        fontsize=9,
        fontfamily="monospace",
        bbox=dict(boxstyle="round", facecolor="lightyellow", alpha=0.8),
    )
    ax.set_title("Key Metrics Summary", fontsize=9, fontweight="bold")
    ax.axis("off")