Source code for att.viz.plotting

"""Publication-quality plotting utilities."""

import json
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import matplotlib


def plot_persistence_diagram(
    diagrams: list[np.ndarray],
    ax: matplotlib.axes.Axes | None = None,
    colormap: str = "viridis",
) -> matplotlib.figure.Figure:
    """Plot persistence diagrams for all homology dimensions.

    Parameters
    ----------
    diagrams : list of arrays, one per homology dimension. Each non-empty
        array is read as ``dgm[:, 0]`` (birth) and ``dgm[:, 1]`` (death).
    ax : axes to draw on; a new 6x6 inch figure is created when omitted.
    colormap : name of the matplotlib colormap sampled once per dimension.

    Returns
    -------
    The figure that owns the axes that were drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 6))
    else:
        fig = ax.get_figure()

    cmap = plt.get_cmap(colormap)
    n_dims = len(diagrams)
    colors = [cmap(i / max(n_dims - 1, 1)) for i in range(n_dims)]

    # Only finite coordinates may define the diagonal's extent: a single
    # infinite death value would otherwise make the guide line undrawable.
    chunks = [
        np.asarray(dgm, dtype=float).ravel() for dgm in diagrams if len(dgm) > 0
    ]
    values = np.concatenate(chunks) if chunks else np.empty(0)
    values = values[np.isfinite(values)]
    if values.size:
        vmin, vmax = float(values.min()), float(values.max())
    else:
        vmin, vmax = 0, 1

    # Diagonal (birth == death)
    ax.plot([vmin, vmax], [vmin, vmax], "k--", alpha=0.3, linewidth=1)

    for dim, dgm in enumerate(diagrams):
        if len(dgm) > 0:
            ax.scatter(
                dgm[:, 0],
                dgm[:, 1],
                c=[colors[dim]] * len(dgm),
                label=f"H{dim}",
                s=20,
                alpha=0.7,
                edgecolors="k",
                linewidths=0.3,
            )

    ax.set_xlabel("Birth")
    ax.set_ylabel("Death")
    ax.set_title("Persistence Diagram")
    # Skip the legend when nothing is labelled, so no empty legend box is
    # drawn and matplotlib does not warn about missing labelled artists.
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.set_aspect("equal")
    return fig
def plot_persistence_image(
    images: list[np.ndarray],
    ax: matplotlib.axes.Axes | None = None,
    colormap: str = "hot",
) -> matplotlib.figure.Figure:
    """Plot persistence images for all homology dimensions.

    Parameters
    ----------
    images : list of 2-D arrays, one per homology dimension.
    ax : optional single axes to draw on. Because images are paired with the
        available axes, only the first image is drawn when ``ax`` is given.
        When omitted, a row of ``len(images)`` panels is created.
    colormap : matplotlib colormap name passed to ``imshow``.

    Returns
    -------
    The figure holding the panel(s); ``tight_layout`` has been applied to it.
    """
    n_panels = len(images)
    if ax is None:
        fig, panels = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4))
        # plt.subplots hands back a bare Axes rather than an array for 1x1.
        if n_panels == 1:
            panels = [panels]
    else:
        fig = ax.get_figure()
        panels = [ax]

    for dim, (image, panel) in enumerate(zip(images, panels)):
        mappable = panel.imshow(image, cmap=colormap, origin="lower", aspect="auto")
        panel.set_title(f"H{dim} Persistence Image")
        fig.colorbar(mappable, ax=panel, fraction=0.046, pad=0.04)

    fig.tight_layout()
    return fig
def plot_barcode(
    diagrams: list[np.ndarray],
    ax: matplotlib.axes.Axes | None = None,
) -> matplotlib.figure.Figure:
    """Plot persistence barcodes.

    Each (birth, death) pair becomes one horizontal bar on its own row.
    Dimensions are stacked in order, and within a dimension the bars are
    sorted by lifetime, longest first.

    Parameters
    ----------
    diagrams : list of arrays, one per homology dimension, read as
        ``dgm[:, 0]`` (birth) and ``dgm[:, 1]`` (death).
    ax : axes to draw on; a new 8x5 inch figure is created when omitted.

    Returns
    -------
    The figure that owns the axes that were drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 5))
    else:
        fig = ax.get_figure()

    palette = ["tab:blue", "tab:orange", "tab:green", "tab:red"]
    row = 0
    legend_entries = []
    for dim, dgm in enumerate(diagrams):
        if len(dgm) == 0:
            continue
        bar_color = palette[dim % len(palette)]  # colours repeat after four dims

        # Negating the lifetimes makes argsort return the longest bar first.
        longest_first = np.argsort(-(dgm[:, 1] - dgm[:, 0]))
        for birth, death in dgm[longest_first]:
            ax.plot([birth, death], [row, row], color=bar_color, linewidth=1.5)
            row += 1

        # Proxy artist: one legend entry per non-empty dimension.
        legend_entries.append(
            plt.Line2D([0], [0], color=bar_color, label=f"H{dim}")
        )

    ax.set_xlabel("Filtration Parameter")
    ax.set_ylabel("Feature")
    ax.set_title("Persistence Barcode")
    ax.legend(handles=legend_entries)
    return fig
def plot_betti_curve(
    betti_curves: list[np.ndarray],
    ax: matplotlib.axes.Axes | None = None,
) -> matplotlib.figure.Figure:
    """Plot Betti curves.

    Parameters
    ----------
    betti_curves : list of 1-D sequences, one per homology dimension. Each is
        plotted against its own index.
    ax : axes to draw on; a new 8x4 inch figure is created when omitted.

    Returns
    -------
    The figure that owns the axes that were drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 4))
    else:
        fig = ax.get_figure()

    colors = ["tab:blue", "tab:orange", "tab:green", "tab:red"]
    for dim, curve in enumerate(betti_curves):
        # NOTE(review): the label's f-string was malformed (missing opening
        # quote and prefix). The prefix is reconstructed as "H" to match the
        # other plots in this module -- confirm against the intended label.
        ax.plot(curve, color=colors[dim % len(colors)], label=f"H{dim}")

    ax.set_xlabel("Filtration Index")
    ax.set_ylabel("Betti Number")
    ax.set_title("Betti Curves")
    ax.legend()
    return fig
def plot_attractor_3d(
    cloud: np.ndarray,
    color_by: str = "time",
    backend: str = "plotly",
):
    """Render an attractor point cloud in 3D, coloured by row index.

    Parameters
    ----------
    cloud : 2-D array-like; only the first three columns are used as x, y, z.
    color_by : currently not read by the function body; colouring is always
        by row index.
    backend : ``"plotly"`` draws a coloured 3-D line with plotly. Any other
        value draws a matplotlib 3-D scatter.

    Returns
    -------
    A plotly figure for the plotly backend, otherwise a matplotlib figure.
    """
    cloud = np.asarray(cloud)[:, :3]

    if backend != "plotly":
        fig = plt.figure(figsize=(8, 6))
        axes3d = fig.add_subplot(111, projection="3d")
        axes3d.scatter(
            cloud[:, 0],
            cloud[:, 1],
            cloud[:, 2],
            c=np.arange(len(cloud)),
            cmap="viridis",
            s=0.5,
            alpha=0.5,
        )
        axes3d.set_xlabel("x")
        axes3d.set_ylabel("y")
        axes3d.set_zlabel("z")
        axes3d.set_title("Attractor")
        return fig

    # Imported lazily so plotly is only needed when this backend is chosen.
    import plotly.graph_objects as go

    index = np.arange(len(cloud))
    trace = go.Scatter3d(
        x=cloud[:, 0],
        y=cloud[:, 1],
        z=cloud[:, 2],
        mode="lines",
        line={"color": index, "colorscale": "Viridis", "width": 2},
    )
    fig = go.Figure(data=[trace])
    fig.update_layout(
        title="Attractor",
        scene={"xaxis_title": "x", "yaxis_title": "y", "zaxis_title": "z"},
        width=700,
        height=600,
    )
    return fig
def plot_surrogate_distribution(
    observed: float,
    surrogates: np.ndarray,
    ax: matplotlib.axes.Axes | None = None,
) -> matplotlib.figure.Figure:
    """Histogram of surrogate scores with the observed score marked.

    Draws a 30-bin histogram of ``surrogates``, a solid red vertical line at
    ``observed`` and a dashed orange line at the surrogates' 95th percentile.

    Parameters
    ----------
    observed : score to mark on the histogram.
    surrogates : array of surrogate scores.
    ax : axes to draw on; a new 7x4 inch figure is created when omitted.

    Returns
    -------
    The figure that owns the axes that were drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(7, 4))
    else:
        fig = ax.get_figure()

    ax.hist(surrogates, bins=30, alpha=0.7, color="steelblue", edgecolor="white")

    observed_label = f"Observed = {observed:.4f}"
    ax.axvline(observed, color="red", linewidth=2, label=observed_label)

    threshold = np.percentile(surrogates, 95)
    threshold_label = f"95th pctile = {threshold:.4f}"
    ax.axvline(
        threshold,
        color="orange",
        linewidth=1.5,
        linestyle="--",
        label=threshold_label,
    )

    ax.set_xlabel("Binding Score")
    ax.set_ylabel("Count")
    ax.set_title("Surrogate Distribution")
    ax.legend()
    return fig
def plot_benchmark_sweep(results, ax=None) -> matplotlib.figure.Figure:
    """Plot benchmark sweep with all methods overlaid.

    One line is drawn per distinct value of ``results["method"]``, with the
    rows sorted by ``coupling``.

    Parameters
    ----------
    results : pd.DataFrame with columns ``coupling``, ``method``, ``score``
        and optionally ``score_normalized``. The normalized column is plotted
        when present, otherwise the raw ``score`` column.
    ax : axes to draw on; a new 8x5 inch figure is created when omitted.

    Returns
    -------
    The figure that owns the axes that were drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 5))
    else:
        fig = ax.get_figure()

    # The column choice does not depend on the method, so decide it once.
    has_normalized = "score_normalized" in results.columns
    score_col = "score_normalized" if has_normalized else "score"

    methods = results["method"].unique()
    colors = plt.cm.tab10(np.linspace(0, 1, len(methods)))
    for method, color in zip(methods, colors):
        subset = results[results["method"] == method].sort_values("coupling")
        ax.plot(
            subset["coupling"],
            subset[score_col],
            "o-",
            color=color,
            label=method,
            markersize=4,
        )

    ax.set_xlabel("Coupling Strength")
    # Only claim normalization when the normalized column is what is plotted.
    ax.set_ylabel("Score (normalized)" if has_normalized else "Score")
    ax.set_title("Coupling Benchmark Sweep")
    ax.legend()
    return fig
def plot_binding_comparison(detector) -> matplotlib.figure.Figure:
    """3-panel comparison: marginal X | joint | marginal Y.

    All three panels share one axis range so the diagrams can be compared
    visually. The middle panel is titled "Joint (excess highlighted)", but
    all three panels are drawn the same way.

    Parameters
    ----------
    detector : BindingDetector with fitted state. Reads the ``"diagrams"``
        entry of ``_result_x``, ``_result_joint`` and ``_result_y``.

    Returns
    -------
    matplotlib Figure
    """
    diagrams_x = detector._result_x["diagrams"]
    diagrams_joint = detector._result_joint["diagrams"]
    diagrams_y = detector._result_y["diagrams"]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    titles = ["Marginal X", "Joint (excess highlighted)", "Marginal Y"]
    all_diagrams = [diagrams_x, diagrams_joint, diagrams_y]
    colors = ["tab:blue", "tab:orange", "tab:green"]

    # Global axis range from finite coordinates only: set_xlim/set_ylim reject
    # non-finite limits, so one infinite death value would abort the plot.
    chunks = [
        np.asarray(dgm, dtype=float).ravel()
        for diag_set in all_diagrams
        for dgm in diag_set
        if len(dgm) > 0
    ]
    values = np.concatenate(chunks) if chunks else np.empty(0)
    values = values[np.isfinite(values)]
    if values.size:
        vmin, vmax = float(values.min()), float(values.max())
    else:
        vmin, vmax = 0, 1

    for diags, ax, title in zip(all_diagrams, axes, titles):
        # Diagonal (birth == death)
        ax.plot([vmin, vmax], [vmin, vmax], "k--", alpha=0.3, linewidth=1)
        for dim, dgm in enumerate(diags):
            if len(dgm) > 0:
                ax.scatter(
                    dgm[:, 0],
                    dgm[:, 1],
                    c=colors[dim % len(colors)],
                    label=f"H{dim}",
                    s=20,
                    alpha=0.7,
                    edgecolors="k",
                    linewidths=0.3,
                )
        ax.set_xlabel("Birth")
        ax.set_ylabel("Death")
        ax.set_title(title)
        # A panel with no points has nothing to put in a legend.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.set_aspect("equal")
        ax.set_xlim(vmin - 0.5, vmax + 0.5)
        ax.set_ylim(vmin - 0.5, vmax + 0.5)

    fig.tight_layout()
    return fig
def plot_binding_image(
    images: list[np.ndarray],
    colormap: str = "RdBu_r",
) -> matplotlib.figure.Figure:
    """Heatmap of residual persistence images.

    Each image gets its own panel with a colour scale symmetric about zero,
    so positive and negative residuals are directly comparable.

    Parameters
    ----------
    images : list of (resolution, resolution) residual images, one per dimension
    colormap : diverging colormap (red=emergent, blue=deficit)

    Returns
    -------
    matplotlib Figure
    """
    n_panels = len(images)
    # squeeze=False always yields a 2-D grid, so one panel needs no special case.
    fig, grid = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4), squeeze=False)

    for dim, (image, panel) in enumerate(zip(images, grid.ravel())):
        lowest, highest = image.min(), image.max()
        # Fall back to 1.0 for an all-zero image to avoid a zero-width scale.
        bound = max(abs(lowest), abs(highest)) or 1.0
        mappable = panel.imshow(
            image,
            cmap=colormap,
            origin="lower",
            aspect="auto",
            vmin=-bound,
            vmax=bound,
        )
        panel.set_title(f"H{dim} Binding Image")
        panel.set_xlabel("Birth")
        panel.set_ylabel("Persistence")
        fig.colorbar(mappable, ax=panel, fraction=0.046, pad=0.04)

    fig.tight_layout()
    return fig
def export_to_json(results: dict, path: str) -> None:
    """Export computed results as JSON.

    NumPy arrays become nested lists and NumPy bool/int/float scalars become
    their builtin equivalents, including when they appear as dict keys or
    inside tuples. Tuples are written as JSON arrays. Missing parent
    directories of ``path`` are created.

    Parameters
    ----------
    results : dict of results, possibly nested.
    path : destination file path; an existing file is overwritten.
    """

    def _scalar(obj):
        """Return the builtin equivalent of a NumPy bool/int/float scalar."""
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return obj

    def _convert(obj):
        """Recursively replace NumPy containers and scalars with builtins."""
        if isinstance(obj, np.ndarray):
            as_list = obj.tolist()
            # tolist() returns the stored objects of an object-dtype array
            # unchanged, so those elements still need converting.
            return _convert(as_list) if obj.dtype == object else as_list
        if isinstance(obj, dict):
            return {_scalar(k): _convert(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [_convert(v) for v in obj]
        return _scalar(obj)

    # Convert before opening the file so a failed conversion cannot leave a
    # truncated or empty file behind.
    payload = _convert(results)

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
def load_from_json(path: str) -> dict:
    """Load results from JSON.

    Parameters
    ----------
    path : file to read.

    Returns
    -------
    The decoded JSON document.
    """
    with open(path, "r") as handle:
        text = handle.read()
    return json.loads(text)
def plot_transition_timeline(detector, ground_truth=None, figsize=(12, 6)):
    """Plot topology transition timeline from a fitted TransitionDetector.

    The top panel shows the distance between consecutive persistence images
    with detected changepoints (red dashed) and optional ground truth (green
    dotted). The bottom panel shows the H1 persistence entropy per window.

    Parameters
    ----------
    detector : TransitionDetector
        Must be fitted: its ``_result`` must not be None.
    ground_truth : list of int or None
        True transition sample indices (plotted as green dotted lines).
    figsize : tuple
        Figure size.

    Returns
    -------
    matplotlib Figure

    Raises
    ------
    RuntimeError
        If ``detector._result`` is None.
    """
    import warnings

    result = detector._result
    if result is None:
        raise RuntimeError("TransitionDetector must be fitted first.")

    # Coerce to an array: on a plain list, ``+`` would concatenate instead of
    # adding element-wise and the midpoint computation below would fail.
    window_centers = np.asarray(result["window_centers"])
    image_distances = result["image_distances"]

    # image_distances has len = len(window_centers) - 1, so each distance is
    # placed at the midpoint between its two consecutive window centers.
    dist_x = (window_centers[:-1] + window_centers[1:]) / 2

    # H1 persistence entropy per window; persistence_entropy is a list per
    # dimension, so windows without an H1 entry contribute 0.0.
    h1_entropy = [
        topo["persistence_entropy"][1]
        if len(topo["persistence_entropy"]) > 1
        else 0.0
        for topo in result["topology_timeseries"]
    ]

    fig, axes = plt.subplots(2, 1, figsize=figsize, sharex=True)

    # Top panel: image distances + changepoints
    ax = axes[0]
    ax.plot(dist_x, image_distances, 'k-', linewidth=1.5, label='PI distance')
    ax.set_ylabel('Image distance (L2)')
    ax.set_title('Topological Transition Timeline')

    # Detected changepoints are best-effort: a failure must not prevent the
    # rest of the figure from being drawn, but it should not be silent.
    try:
        labelled = False
        for cp in detector.detect_changepoints():
            if cp < len(dist_x):
                # Label the first marker actually drawn, so the legend entry
                # survives an out-of-range first changepoint.
                ax.axvline(dist_x[cp], color='red', linestyle='--', alpha=0.8,
                           label=None if labelled else 'Detected')
                labelled = True
    except Exception as exc:
        warnings.warn(
            f"Changepoint markers omitted from timeline: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )

    # Ground truth
    if ground_truth is not None:
        for i, gt in enumerate(ground_truth):
            ax.axvline(gt, color='green', linestyle=':', alpha=0.7, linewidth=2,
                       label='Ground truth' if i == 0 else None)

    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)

    # Bottom panel: H1 persistence entropy
    ax = axes[1]
    ax.plot(window_centers, h1_entropy, 'b-', linewidth=1.5)
    ax.set_xlabel('Sample index')
    ax.set_ylabel('H1 persistence entropy')
    ax.set_title('Loop Complexity Over Time')

    if ground_truth is not None:
        for gt in ground_truth:
            ax.axvline(gt, color='green', linestyle=':', alpha=0.7, linewidth=2)

    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig