"""Direction 4: Zigzag persistence across transformer layers.
Tracks topological features as they are born and die across successive layers
using zigzag persistent homology (Carlsson & de Silva 2010). Each layer's
point cloud defines a VR complex; the zigzag filtration connects consecutive
layers via their union complexes.
Requires dionysus>=2.0 (optional dependency).
Install: pip install dionysus OR pip install att-toolkit[zigzag]
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.stats import ks_2samp
from sklearn.decomposition import PCA
if TYPE_CHECKING:
from att.llm.loader import HiddenStateLoader
try:
import dionysus
except ImportError:
dionysus = None # type: ignore[assignment]
def _require_dionysus():
if dionysus is None:
raise ImportError(
"dionysus>=2.0 is required for zigzag persistence. "
"Install with: pip install dionysus OR pip install att-toolkit[zigzag]"
)
@dataclass
class ZigzagResult:
"""Container for zigzag persistence results at one difficulty level."""
level: int
barcodes: dict[int, np.ndarray] = field(default_factory=dict)
"""dim -> (n_features, 2) array of (birth_layer, death_layer) bars."""
n_layers_used: int = 0
layer_indices: list[int] = field(default_factory=list)
[docs]
class ZigzagLayerAnalyzer:
"""Zigzag persistent homology across transformer layers.
Constructs a zigzag filtration:
VR(X_0) <-> VR(X_0 ∪ X_1) <-> VR(X_1) <-> ... <-> VR(X_{L-1})
where X_i is the point cloud at layer i. The union complexes use the
minimum pairwise distance across both layers' embeddings of each point.
Parameters
----------
max_dim : int
Maximum homology dimension (default 1 -> H0, H1).
n_pca_components : int
PCA dimension reduction before computing distances.
subsample : int or None
Subsample points per layer to manage runtime.
threshold : float or None
VR complex distance threshold. If None, uses adaptive threshold
based on data scale.
seed : int
Random seed for subsampling.
"""
def __init__(
self,
max_dim: int = 1,
n_pca_components: int = 50,
subsample: int | None = 100,
threshold: float | None = None,
seed: int = 42,
):
_require_dionysus()
self.max_dim = max_dim
self.n_pca_components = n_pca_components
self.subsample = subsample
self.threshold = threshold
self.seed = seed
[docs]
def fit(
self,
loader: HiddenStateLoader,
level: int,
layer_indices: list[int] | None = None,
) -> ZigzagResult:
"""Compute zigzag persistence across layers for a difficulty level.
Parameters
----------
loader : HiddenStateLoader
Hidden state data.
level : int
Difficulty level (1-5).
layer_indices : list of int or None
Which layers to include. If None, uses all layers.
Returns
-------
ZigzagResult with barcodes per dimension.
"""
if layer_indices is None:
layer_indices = list(range(loader.num_layers))
n_layers = len(layer_indices)
if n_layers < 2:
raise ValueError("Need at least 2 layers for zigzag persistence")
# Get point clouds per layer, subsample consistently
clouds = self._get_clouds(loader, level, layer_indices)
n_pts = clouds[0].shape[0]
# PCA reduce each cloud
clouds_pca = self._pca_reduce(clouds)
# Compute distance matrices per layer
dist_matrices = [squareform(pdist(c)) for c in clouds_pca]
# Determine threshold
thresh = self.threshold
if thresh is None:
# Adaptive: median of all pairwise distances across layers
all_dists = np.concatenate([dm[np.triu_indices(n_pts, k=1)] for dm in dist_matrices])
thresh = float(np.percentile(all_dists, 30))
# Build zigzag filtration
barcodes = self._build_and_compute_zigzag(
dist_matrices, n_pts, n_layers, thresh
)
result = ZigzagResult(
level=level,
barcodes=barcodes,
n_layers_used=n_layers,
layer_indices=layer_indices,
)
return result
def _get_clouds(
self, loader: HiddenStateLoader, level: int, layer_indices: list[int]
) -> list[np.ndarray]:
"""Extract and subsample point clouds per layer."""
rng = np.random.default_rng(self.seed)
mask = loader.get_level_mask(level)
n_total = mask.sum()
# Consistent subsampling across layers
if self.subsample and self.subsample < n_total:
problem_indices = np.where(mask)[0]
sub_idx = rng.choice(len(problem_indices), size=self.subsample, replace=False)
sub_idx.sort()
selected = problem_indices[sub_idx]
else:
selected = np.where(mask)[0]
clouds = []
for layer in layer_indices:
cloud = loader.layer_hidden[selected, layer, :]
clouds.append(cloud)
return clouds
def _pca_reduce(self, clouds: list[np.ndarray]) -> list[np.ndarray]:
"""PCA reduce each cloud independently."""
result = []
for cloud in clouds:
n_comp = min(self.n_pca_components, cloud.shape[0] - 1, cloud.shape[1])
if n_comp < cloud.shape[1]:
pca = PCA(n_components=n_comp)
result.append(pca.fit_transform(cloud))
else:
result.append(cloud)
return result
def _build_and_compute_zigzag(
self,
dist_matrices: list[np.ndarray],
n_pts: int,
n_layers: int,
threshold: float,
) -> dict[int, np.ndarray]:
"""Build the zigzag filtration and compute persistence.
The zigzag has 2*n_layers - 1 time steps:
t=0: VR(X_0)
t=1: VR(X_0 ∪ X_1) — union complex
t=2: VR(X_1)
t=3: VR(X_1 ∪ X_2) — union complex
...
t=2(L-1): VR(X_{L-1})
Vertices are always present (all times). Edges appear/disappear based on
whether they are within threshold at the corresponding layer(s).
"""
total_times = 2 * n_layers - 1
# Collect all simplices with their [appear, disappear] intervals
# Vertices: always present
simplex_list = []
times_list = []
# Add vertices — present throughout
for i in range(n_pts):
simplex_list.append(dionysus.Simplex([i]))
times_list.append([0, total_times])
# For each pair of points, determine which time intervals they form an edge
if self.max_dim >= 1:
for i in range(n_pts):
for j in range(i + 1, n_pts):
intervals = self._edge_intervals(
i, j, dist_matrices, threshold, total_times
)
if intervals:
simplex_list.append(dionysus.Simplex([i, j]))
times_list.append(intervals)
# For triangles (H1 requires 2-simplices)
if self.max_dim >= 1:
# Precompute edge presence at each time step for triangle check
edge_at_time = self._edge_presence_matrix(
n_pts, dist_matrices, threshold, total_times
)
for i in range(n_pts):
for j in range(i + 1, n_pts):
for k in range(j + 1, n_pts):
intervals = self._triangle_intervals(
i, j, k, edge_at_time, total_times
)
if intervals:
simplex_list.append(dionysus.Simplex([i, j, k]))
times_list.append(intervals)
f = dionysus.Filtration(simplex_list)
zz, dgms, cells = dionysus.zigzag_homology_persistence(f, times_list)
# Convert diagrams to arrays, mapping times back to layer indices
barcodes = {}
for dim in range(len(dgms)):
bars = []
for pt in dgms[dim]:
b, d = pt.birth, pt.death
if d == float("inf"):
d = total_times
# Map time to layer: t=0 -> layer 0, t=2 -> layer 1, etc.
birth_layer = b / 2.0
death_layer = d / 2.0
bars.append([birth_layer, death_layer])
barcodes[dim] = np.array(bars) if bars else np.empty((0, 2))
return barcodes
def _edge_intervals(
self,
i: int,
j: int,
dist_matrices: list[np.ndarray],
threshold: float,
total_times: int,
) -> list[float]:
"""Compute appearance/disappearance times for edge (i,j).
An edge is present at:
- t=2k if dist_matrices[k][i,j] <= threshold (layer k)
- t=2k+1 if min(dist_matrices[k][i,j], dist_matrices[k+1][i,j]) <= threshold (union)
"""
n_layers = len(dist_matrices)
# Determine presence at each time step
present = np.zeros(total_times, dtype=bool)
for k in range(n_layers):
# Layer time
t_layer = 2 * k
if dist_matrices[k][i, j] <= threshold:
present[t_layer] = True
# Union time (between layer k and k+1)
if k < n_layers - 1:
t_union = 2 * k + 1
min_dist = min(dist_matrices[k][i, j], dist_matrices[k + 1][i, j])
if min_dist <= threshold:
present[t_union] = True
# Convert boolean presence to [appear, disappear] pairs
return self._presence_to_intervals(present)
def _edge_presence_matrix(
self,
n_pts: int,
dist_matrices: list[np.ndarray],
threshold: float,
total_times: int,
) -> np.ndarray:
"""(n_pts, n_pts, total_times) boolean: whether edge is present at each time."""
n_layers = len(dist_matrices)
present = np.zeros((n_pts, n_pts, total_times), dtype=bool)
for k in range(n_layers):
t_layer = 2 * k
within = dist_matrices[k] <= threshold
present[:, :, t_layer] = within
if k < n_layers - 1:
t_union = 2 * k + 1
min_dist = np.minimum(dist_matrices[k], dist_matrices[k + 1])
present[:, :, t_union] = min_dist <= threshold
return present
def _triangle_intervals(
self,
i: int,
j: int,
k: int,
edge_at_time: np.ndarray,
total_times: int,
) -> list[float]:
"""Compute intervals for triangle (i,j,k): present when all 3 edges are."""
present = (
edge_at_time[i, j, :total_times]
& edge_at_time[i, k, :total_times]
& edge_at_time[j, k, :total_times]
)
return self._presence_to_intervals(present)
@staticmethod
def _presence_to_intervals(present: np.ndarray) -> list[float]:
"""Convert boolean presence array to [appear, disappear, appear, ...] list."""
intervals = []
in_interval = False
for t, p in enumerate(present):
if p and not in_interval:
intervals.append(float(t))
in_interval = True
elif not p and in_interval:
intervals.append(float(t))
in_interval = False
if in_interval:
intervals.append(float(len(present)))
return intervals
def zigzag_feature_lifetime_stats(result: ZigzagResult, dim: int = 1) -> dict:
"""Compute summary statistics on zigzag barcode lifetimes.
Parameters
----------
result : ZigzagResult
Output from ZigzagLayerAnalyzer.fit().
dim : int
Homology dimension to summarize.
Returns
-------
dict with keys: mean_lifetime, median_lifetime, max_lifetime, std_lifetime,
n_features, n_long_lived (> 2 layers).
"""
if dim not in result.barcodes or len(result.barcodes[dim]) == 0:
return {
"mean_lifetime": 0.0,
"median_lifetime": 0.0,
"max_lifetime": 0.0,
"std_lifetime": 0.0,
"n_features": 0,
"n_long_lived": 0,
}
bars = result.barcodes[dim]
lifetimes = bars[:, 1] - bars[:, 0]
# Filter out zero-lifetime features
lifetimes = lifetimes[lifetimes > 0]
if len(lifetimes) == 0:
return {
"mean_lifetime": 0.0,
"median_lifetime": 0.0,
"max_lifetime": 0.0,
"std_lifetime": 0.0,
"n_features": 0,
"n_long_lived": 0,
}
return {
"mean_lifetime": float(np.mean(lifetimes)),
"median_lifetime": float(np.median(lifetimes)),
"max_lifetime": float(np.max(lifetimes)),
"std_lifetime": float(np.std(lifetimes)),
"n_features": int(len(lifetimes)),
"n_long_lived": int(np.sum(lifetimes > 2.0)),
}
def compare_zigzag_levels(
result_a: ZigzagResult,
result_b: ZigzagResult,
dim: int = 1,
) -> dict:
"""Compare zigzag barcodes between two difficulty levels.
Parameters
----------
result_a, result_b : ZigzagResult
Zigzag results for two different levels.
dim : int
Homology dimension to compare.
Returns
-------
dict with keys: ks_statistic, ks_pvalue, mean_lifetime_diff, n_features_diff.
"""
def _lifetimes(r):
if dim not in r.barcodes or len(r.barcodes[dim]) == 0:
return np.array([0.0])
bars = r.barcodes[dim]
lt = bars[:, 1] - bars[:, 0]
lt = lt[lt > 0]
return lt if len(lt) > 0 else np.array([0.0])
lt_a = _lifetimes(result_a)
lt_b = _lifetimes(result_b)
ks_stat, ks_p = ks_2samp(lt_a, lt_b)
stats_a = zigzag_feature_lifetime_stats(result_a, dim)
stats_b = zigzag_feature_lifetime_stats(result_b, dim)
return {
"level_a": result_a.level,
"level_b": result_b.level,
"ks_statistic": float(ks_stat),
"ks_pvalue": float(ks_p),
"mean_lifetime_a": stats_a["mean_lifetime"],
"mean_lifetime_b": stats_b["mean_lifetime"],
"mean_lifetime_diff": stats_b["mean_lifetime"] - stats_a["mean_lifetime"],
"n_features_a": stats_a["n_features"],
"n_features_b": stats_b["n_features"],
"n_features_diff": stats_b["n_features"] - stats_a["n_features"],
}