"""Binding detection via persistence image subtraction and diagram matching."""
import warnings
import numpy as np
from att.embedding.takens import TakensEmbedder
from att.embedding.joint import JointEmbedder
from att.embedding.validation import validate_embedding, EmbeddingDegeneracyWarning
from att.topology.persistence import PersistenceAnalyzer
[docs]
class BindingDetector:
"""Detect topological binding between coupled dynamical systems.
Computes persistence images for joint and marginal embeddings, then
measures excess topology in the joint that is absent from both marginals.
Parameters
----------
max_dim : int
Maximum homology dimension (0=components, 1=loops).
method : str
"persistence_image" (PI subtraction) or "diagram_matching"
(optimal matching via Hungarian algorithm).
image_resolution : int
Resolution of persistence images.
image_sigma : float
Gaussian kernel bandwidth for persistence images.
baseline : str
"max" (conservative, pointwise max of marginals) or
"sum" (sensitive, pointwise sum of marginals).
embedding_quality_gate : bool
If True, validate all three embeddings and warn if any is degenerate.
"""
[docs]
def __init__(
self,
max_dim: int = 1,
method: str = "persistence_image",
image_resolution: int = 50,
image_sigma: float = 0.1,
baseline: str = "max",
embedding_quality_gate: bool = True,
):
if method not in ("persistence_image", "diagram_matching"):
raise ValueError(
f"Unknown method: {method}. "
"Use 'persistence_image' or 'diagram_matching'."
)
if baseline not in ("max", "sum"):
raise ValueError(f"Unknown baseline: {baseline}. Use 'max' or 'sum'.")
self.max_dim = max_dim
self.method = method
self.image_resolution = image_resolution
self.image_sigma = image_sigma
self.baseline = baseline
self.embedding_quality_gate = embedding_quality_gate
# Fitted state
self._fitted = False
self._cloud_x: np.ndarray | None = None
self._cloud_y: np.ndarray | None = None
self._cloud_joint: np.ndarray | None = None
self._result_x: dict | None = None
self._result_y: dict | None = None
self._result_joint: dict | None = None
self._images_x: list[np.ndarray] | None = None
self._images_y: list[np.ndarray] | None = None
self._images_joint: list[np.ndarray] | None = None
self._residual_images: list[np.ndarray] | None = None
self._matching_score: float | None = None
self._matching_details: dict | None = None
self._embedding_quality: dict | None = None
self._X_raw: np.ndarray | None = None
self._Y_raw: np.ndarray | None = None
# Cached embedding params for surrogate speed optimization
self._marginal_delay_x: int | None = None
self._marginal_dim_x: int | None = None
self._marginal_delay_y: int | None = None
self._marginal_dim_y: int | None = None
self._joint_delays: list[int] | None = None
self._joint_dims: list[int] | None = None
# Ensemble state
self._ensemble_scores: np.ndarray | None = None
[docs]
def fit(
self,
X: np.ndarray,
Y: np.ndarray,
joint_embedder: JointEmbedder | None = None,
marginal_embedder_x: TakensEmbedder | None = None,
marginal_embedder_y: TakensEmbedder | None = None,
subsample: int | None = None,
seed: int | None = None,
n_ensemble: int = 1,
) -> "BindingDetector":
"""Fit the detector on two coupled time series.
Parameters
----------
X, Y : 1D time series arrays
joint_embedder : pre-configured JointEmbedder, or None for auto
marginal_embedder_x, marginal_embedder_y : pre-configured TakensEmbedders
subsample : subsample point clouds before persistence computation
seed : seed for subsampling (same seed used for all three clouds)
n_ensemble : int
Number of independent fits with different subsample seeds.
If > 1, binding_score() returns the mean, and ensemble_scores
stores individual scores. Default 1 (no ensembling).
Returns
-------
self
"""
X = np.asarray(X).ravel()
Y = np.asarray(Y).ravel()
self._X_raw = X
self._Y_raw = Y
# 1. Marginal embeddings
if marginal_embedder_x is None:
marginal_embedder_x = TakensEmbedder(delay="auto", dimension="auto")
self._cloud_x = marginal_embedder_x.fit_transform(X)
if marginal_embedder_y is None:
marginal_embedder_y = TakensEmbedder(delay="auto", dimension="auto")
self._cloud_y = marginal_embedder_y.fit_transform(Y)
# 2. Joint embedding
if joint_embedder is None:
joint_embedder = JointEmbedder(delays="auto", dimensions="auto")
self._cloud_joint = joint_embedder.fit_transform([X, Y])
# Cache fitted embedding params for surrogate reuse
self._marginal_delay_x = marginal_embedder_x.delay_
self._marginal_dim_x = marginal_embedder_x.dimension_
self._marginal_delay_y = marginal_embedder_y.delay_
self._marginal_dim_y = marginal_embedder_y.dimension_
self._joint_delays = list(joint_embedder.delays_)
self._joint_dims = list(joint_embedder.dimensions_)
# 3. Embedding quality gate
eq_x = validate_embedding(self._cloud_x)
eq_y = validate_embedding(self._cloud_y)
eq_joint = validate_embedding(self._cloud_joint)
any_degenerate = bool(eq_x["degenerate"] or eq_y["degenerate"] or eq_joint["degenerate"])
self._embedding_quality = {
"marginal_x": eq_x,
"marginal_y": eq_y,
"joint": eq_joint,
"any_degenerate": any_degenerate,
}
if self.embedding_quality_gate and any_degenerate:
degen_parts = []
if eq_x["degenerate"]:
degen_parts.append(f"marginal_x (cond={eq_x['condition_number']:.1f})")
if eq_y["degenerate"]:
degen_parts.append(f"marginal_y (cond={eq_y['condition_number']:.1f})")
if eq_joint["degenerate"]:
degen_parts.append(f"joint (cond={eq_joint['condition_number']:.1f})")
warnings.warn(
f"Degenerate embedding(s): {', '.join(degen_parts)}. "
"Binding scores may be unreliable.",
EmbeddingDegeneracyWarning,
stacklevel=2,
)
# 4. Persistence computation on all three clouds (same subsample seed)
pa_seed = seed if seed is not None else 42
pa_x = PersistenceAnalyzer(max_dim=self.max_dim)
self._result_x = pa_x.fit_transform(self._cloud_x, subsample=subsample, seed=pa_seed)
pa_y = PersistenceAnalyzer(max_dim=self.max_dim)
self._result_y = pa_y.fit_transform(self._cloud_y, subsample=subsample, seed=pa_seed)
pa_joint = PersistenceAnalyzer(max_dim=self.max_dim)
self._result_joint = pa_joint.fit_transform(
self._cloud_joint, subsample=subsample, seed=pa_seed
)
# Store analyzers for reuse in significance testing
self._pa_x = pa_x
self._pa_y = pa_y
self._pa_joint = pa_joint
if self.method == "diagram_matching":
# 5b. Compute binding score from raw persistence diagrams
self._matching_score, self._matching_details = (
self._diagram_matching_score()
)
else:
# 5a. Compute PIs on shared grid (persistence_image method)
birth_range, persistence_range = self._compute_shared_ranges(
pa_x.diagrams_, pa_y.diagrams_, pa_joint.diagrams_
)
self._images_x = pa_x.to_image(
self.image_resolution, self.image_sigma, birth_range, persistence_range
)
self._images_y = pa_y.to_image(
self.image_resolution, self.image_sigma, birth_range, persistence_range
)
self._images_joint = pa_joint.to_image(
self.image_resolution, self.image_sigma, birth_range, persistence_range
)
self._birth_range = birth_range
self._persistence_range = persistence_range
# 6a. Compute residuals
self._residual_images = []
for d in range(self.max_dim + 1):
img_joint = self._images_joint[d]
img_x = self._images_x[d]
img_y = self._images_y[d]
if self.baseline == "max":
baseline_img = np.maximum(img_x, img_y)
else: # "sum"
baseline_img = img_x + img_y
self._residual_images.append(img_joint - baseline_img)
self._fitted = True
# Ensemble: re-run persistence + scoring with different subsample seeds
if n_ensemble > 1 and subsample is not None:
base_seed = seed if seed is not None else 42
ensemble_scores = []
for k in range(n_ensemble):
ens_seed = base_seed + k
pa_ek_x = PersistenceAnalyzer(max_dim=self.max_dim)
pa_ek_x.fit_transform(self._cloud_x, subsample=subsample, seed=ens_seed)
pa_ek_y = PersistenceAnalyzer(max_dim=self.max_dim)
pa_ek_y.fit_transform(self._cloud_y, subsample=subsample, seed=ens_seed)
pa_ek_joint = PersistenceAnalyzer(max_dim=self.max_dim)
pa_ek_joint.fit_transform(self._cloud_joint, subsample=subsample, seed=ens_seed)
if self.method == "persistence_image":
br, pr = self._compute_shared_ranges(
pa_ek_x.diagrams_, pa_ek_y.diagrams_, pa_ek_joint.diagrams_
)
imgs_x = pa_ek_x.to_image(self.image_resolution, self.image_sigma, br, pr)
imgs_y = pa_ek_y.to_image(self.image_resolution, self.image_sigma, br, pr)
imgs_j = pa_ek_joint.to_image(self.image_resolution, self.image_sigma, br, pr)
score_k = 0.0
for d in range(self.max_dim + 1):
if self.baseline == "max":
bl = np.maximum(imgs_x[d], imgs_y[d])
else:
bl = imgs_x[d] + imgs_y[d]
score_k += float(np.sum(np.maximum(imgs_j[d] - bl, 0)))
ensemble_scores.append(score_k)
self._ensemble_scores = np.array(ensemble_scores)
return self
[docs]
def binding_score(self) -> float:
"""Binding score (higher = more emergent topology).
For persistence_image method: L1 norm of positive residual.
For diagram_matching method: optimal matching cost between joint
and concatenated marginal persistence diagrams.
If fit() was called with n_ensemble > 1, returns the ensemble mean.
Returns
-------
float : binding score
"""
self._check_fitted()
if self._ensemble_scores is not None:
return float(np.mean(self._ensemble_scores))
if self.method == "diagram_matching":
return self._matching_score
score = 0.0
for residual in self._residual_images:
positive = np.maximum(residual, 0)
score += float(np.sum(positive))
return score
@property
def ensemble_scores(self) -> np.ndarray | None:
"""Individual scores from ensemble fitting, or None if n_ensemble=1."""
return self._ensemble_scores
[docs]
def confidence_interval(self, confidence: float = 0.95) -> tuple[float, float] | None:
"""Bootstrap percentile confidence interval from ensemble scores.
Parameters
----------
confidence : float
Confidence level (default 0.95 for 95% CI).
Returns
-------
(lower, upper) tuple, or None if no ensemble was run.
"""
if self._ensemble_scores is None or len(self._ensemble_scores) < 2:
return None
alpha = 1 - confidence
lo = float(np.percentile(self._ensemble_scores, 100 * alpha / 2))
hi = float(np.percentile(self._ensemble_scores, 100 * (1 - alpha / 2)))
return (lo, hi)
[docs]
def binding_features(self) -> dict:
"""Per-dimension breakdown of excess topology.
For persistence_image: {dim: {n_excess, total_persistence, max_persistence}}
For diagram_matching: {dim: {score, n_joint, n_baseline, n_unmatched}}
Returns
-------
dict : per-dimension feature dictionary
"""
self._check_fitted()
if self.method == "diagram_matching":
return self._matching_details
features = {}
for d, residual in enumerate(self._residual_images):
positive = np.maximum(residual, 0)
features[d] = {
"n_excess": int(np.sum(positive > 1e-10)),
"total_persistence": float(np.sum(positive)),
"max_persistence": float(np.max(positive)) if positive.max() > 0 else 0.0,
}
return features
[docs]
def binding_image(self) -> list[np.ndarray]:
"""Residual persistence images (joint - baseline).
Only available for the persistence_image method.
Returns
-------
list of (resolution, resolution) arrays, one per homology dimension
Raises
------
RuntimeError
If called with the diagram_matching method.
"""
self._check_fitted()
if self.method == "diagram_matching":
raise RuntimeError(
"binding_image() is not available for the 'diagram_matching' method. "
"Use binding_features() for per-dimension matching details."
)
return self._residual_images
[docs]
def embedding_quality(self) -> dict:
"""Embedding quality metrics for all three clouds.
Returns
-------
dict with keys: marginal_x, marginal_y, joint, any_degenerate
"""
self._check_fitted()
return self._embedding_quality
[docs]
def test_significance(
self,
n_surrogates: int = 100,
method: str = "phase_randomize",
seed: int | None = None,
subsample: int | None = None,
) -> dict:
"""Test significance of binding score against surrogate null distribution.
Generates surrogates of Y, recomputes binding score for each,
and computes a p-value. Reuses the cached marginal X persistence
result across all surrogates for efficiency.
Parameters
----------
n_surrogates : number of surrogate iterations
method : "phase_randomize", "time_shuffle", or "twin_surrogate"
seed : seed for surrogate generation
subsample : subsample for persistence (passed through)
Returns
-------
dict with p_value, observed_score, surrogate_scores, significant,
embedding_quality
"""
self._check_fitted()
if self.method == "diagram_matching":
raise NotImplementedError(
"Significance testing is not yet supported for the "
"'diagram_matching' method. Use method='persistence_image'."
)
from att.surrogates import phase_randomize, time_shuffle, twin_surrogate
if method == "phase_randomize":
surr_fn = phase_randomize
elif method == "time_shuffle":
surr_fn = time_shuffle
elif method == "twin_surrogate":
surr_fn = None # handled separately below
else:
raise ValueError(
f"Unknown method: {method}. "
"Use 'phase_randomize', 'time_shuffle', or 'twin_surrogate'."
)
if method == "twin_surrogate":
surr_Y = twin_surrogate(self._Y_raw, n_surrogates=n_surrogates, seed=seed)
# Twin surrogates are shorter due to embedding padding; truncate X_raw
n_surr_len = surr_Y.shape[1]
X_raw_trimmed = self._X_raw[:n_surr_len]
else:
surr_Y = surr_fn(self._Y_raw, n_surrogates=n_surrogates, seed=seed)
X_raw_trimmed = self._X_raw
observed = self.binding_score()
surrogate_scores = np.empty(n_surrogates)
# Cache: marginal X images are the same for every surrogate
cached_images_x = self._images_x
for i in range(n_surrogates):
surr_seed = (seed + i + 1) if seed is not None else None
score = self._compute_surrogate_score(
X_raw_trimmed, surr_Y[i], cached_images_x,
subsample=subsample, seed=surr_seed,
)
surrogate_scores[i] = score
# p-value: proportion of surrogates >= observed (with continuity correction)
p_value = (np.sum(surrogate_scores >= observed) + 1) / (n_surrogates + 1)
# Z-score: calibrated effect size against surrogate null
surr_mean = float(np.mean(surrogate_scores))
surr_std = float(np.std(surrogate_scores, ddof=1)) if n_surrogates > 1 else 1.0
z_score = (observed - surr_mean) / surr_std if surr_std > 1e-10 else 0.0
calibrated_score = observed - surr_mean
return {
"p_value": float(p_value),
"observed_score": observed,
"surrogate_scores": surrogate_scores,
"surrogate_mean": surr_mean,
"surrogate_std": surr_std,
"z_score": z_score,
"calibrated_score": calibrated_score,
"significant": p_value < 0.05,
"embedding_quality": self._embedding_quality,
}
def _diagram_matching_score(self) -> tuple[float, dict]:
"""Compute binding score via optimal diagram matching.
Uses the Hungarian algorithm to find the minimum-cost assignment
between persistence diagram features of the joint embedding and
the concatenated marginals. Unmatched joint features (those assigned
to the diagonal) and poor matches contribute to the binding score.
Returns
-------
total_score : float
Sum of assignment costs across all homology dimensions.
details : dict
Per-dimension breakdown: {dim: {score, n_joint, n_baseline, n_unmatched}}.
"""
from scipy.optimize import linear_sum_assignment
total = 0.0
details = {}
for d in range(self.max_dim + 1):
joint_dgm = self._pa_joint.diagrams_[d]
baseline_dgm = np.concatenate([
self._pa_x.diagrams_[d],
self._pa_y.diagrams_[d],
])
# Filter zero-persistence features
if len(joint_dgm) > 0:
joint_pers = joint_dgm[:, 1] - joint_dgm[:, 0]
joint_dgm = joint_dgm[joint_pers > 1e-10]
joint_pers = joint_dgm[:, 1] - joint_dgm[:, 0] if len(joint_dgm) > 0 else np.array([])
else:
joint_pers = np.array([])
if len(baseline_dgm) > 0:
baseline_pers = baseline_dgm[:, 1] - baseline_dgm[:, 0]
baseline_dgm = baseline_dgm[baseline_pers > 1e-10]
baseline_pers = baseline_dgm[:, 1] - baseline_dgm[:, 0] if len(baseline_dgm) > 0 else np.array([])
else:
baseline_pers = np.array([])
n_j = len(joint_dgm)
n_b = len(baseline_dgm)
if n_j == 0:
# No joint features: baseline features match to diagonal
# but that cost doesn't reflect binding, so score is 0
details[d] = {
"score": 0.0,
"n_joint": 0,
"n_baseline": n_b,
"n_unmatched": 0,
}
continue
if n_b == 0:
# All joint features are unmatched (sent to diagonal)
score_d = float(np.sum(joint_pers) / 2)
details[d] = {
"score": score_d,
"n_joint": n_j,
"n_baseline": 0,
"n_unmatched": n_j,
}
total += score_d
continue
# Build augmented cost matrix (n_j + n_b) x (n_j + n_b)
#
# Layout:
# Columns 0..n_b-1 : real baseline features
# Columns n_b..n_b+n_j-1: diagonal slots for joint features
#
# Rows 0..n_j-1 : real joint features
# Rows n_j..n_j+n_b-1 : diagonal slots for baseline features
#
# Top-left (n_j x n_b): L∞ distance between joint[i] and baseline[j]
# Top-right (n_j x n_j): joint[i] matched to diagonal, cost = pers_i/2
# only slot (i, n_b+i) is finite
# Bot-left (n_b x n_b): baseline[j] matched to diagonal, cost = pers_j/2
# only slot (j, j) is finite (row=n_j+j, col=j)
# Bot-right (n_b x n_j): diagonal-to-diagonal padding, zero cost
N = n_j + n_b
cost = np.full((N, N), np.inf)
# Top-left: real-to-real matching (vectorized)
birth_diff = np.abs(joint_dgm[:, 0:1] - baseline_dgm[:, 0].reshape(1, -1))
death_diff = np.abs(joint_dgm[:, 1:2] - baseline_dgm[:, 1].reshape(1, -1))
cost[:n_j, :n_b] = np.maximum(birth_diff, death_diff)
# Top-right: joint[i] to diagonal — only diagonal entries
for i in range(n_j):
cost[i, n_b + i] = joint_pers[i] / 2
# Bottom-left: baseline[j] to diagonal — only diagonal entries
for j in range(n_b):
cost[n_j + j, j] = baseline_pers[j] / 2
# Bottom-right: diagonal-to-diagonal (zero cost padding)
cost[n_j:, n_b:] = 0.0
row_ind, col_ind = linear_sum_assignment(cost)
total_cost = float(cost[row_ind, col_ind].sum())
# Count joint features matched to diagonal (col >= n_b)
n_unmatched = sum(
1 for i, j in zip(row_ind, col_ind) if i < n_j and j >= n_b
)
details[d] = {
"score": total_cost,
"n_joint": n_j,
"n_baseline": n_b,
"n_unmatched": n_unmatched,
}
total += total_cost
return total, details
def _compute_surrogate_score(
self,
X_raw: np.ndarray,
Y_surr: np.ndarray,
cached_images_x: list[np.ndarray],
subsample: int | None = None,
seed: int | None = None,
) -> float:
"""Compute binding score for a single surrogate Y, reusing marginal X.
Reuses the embedding parameters (delay, dimension) estimated during
fit() rather than re-estimating for each surrogate. This eliminates
redundant AMI/FNN computation and ensures consistent embedding geometry.
"""
pa_seed = seed if seed is not None else 42
# Marginal Y embedding — reuse fitted params (skip AMI/FNN)
emb_y = TakensEmbedder(
delay=self._marginal_delay_y,
dimension=self._marginal_dim_y,
)
cloud_y = emb_y.fit_transform(Y_surr)
# Joint embedding — reuse fitted params (skip AMI/FNN)
je = JointEmbedder(
delays=self._joint_delays,
dimensions=self._joint_dims,
)
cloud_joint = je.fit_transform([X_raw, Y_surr])
# Persistence for Y and joint
pa_y = PersistenceAnalyzer(max_dim=self.max_dim)
pa_y.fit_transform(cloud_y, subsample=subsample, seed=pa_seed)
pa_joint = PersistenceAnalyzer(max_dim=self.max_dim)
pa_joint.fit_transform(cloud_joint, subsample=subsample, seed=pa_seed)
# Compute images on the same shared grid as the original fit
images_y = pa_y.to_image(
self.image_resolution, self.image_sigma,
self._birth_range, self._persistence_range,
)
images_joint = pa_joint.to_image(
self.image_resolution, self.image_sigma,
self._birth_range, self._persistence_range,
)
# Compute residual and score
score = 0.0
for d in range(self.max_dim + 1):
if self.baseline == "max":
baseline_img = np.maximum(cached_images_x[d], images_y[d])
else:
baseline_img = cached_images_x[d] + images_y[d]
residual = images_joint[d] - baseline_img
score += float(np.sum(np.maximum(residual, 0)))
return score
[docs]
def plot_comparison(self):
"""3-panel comparison: marginal X | joint (excess) | marginal Y."""
self._check_fitted()
from att.viz.plotting import plot_binding_comparison
return plot_binding_comparison(self)
[docs]
def plot_binding_image(self):
"""Heatmap of residual persistence images."""
self._check_fitted()
from att.viz.plotting import plot_binding_image
return plot_binding_image(self._residual_images)
def _check_fitted(self):
if not self._fitted:
raise RuntimeError("Call .fit() first.")
@staticmethod
def _compute_shared_ranges(
diagrams_x: list[np.ndarray],
diagrams_y: list[np.ndarray],
diagrams_joint: list[np.ndarray],
) -> tuple[tuple[float, float], tuple[float, float]]:
"""Compute shared birth and persistence ranges across all diagrams."""
all_births = []
all_persistences = []
for diagrams in (diagrams_x, diagrams_y, diagrams_joint):
for dgm in diagrams:
if len(dgm) == 0:
continue
births = dgm[:, 0]
pers = dgm[:, 1] - dgm[:, 0]
mask = pers > 1e-10
if mask.any():
all_births.extend(births[mask])
all_persistences.extend(pers[mask])
if not all_births:
return (0.0, 1.0), (0.0, 1.0)
birth_range = (float(min(all_births)), float(max(all_births)))
persistence_range = (0.0, float(max(all_persistences)))
return birth_range, persistence_range