"""CROCKER (Contour Realization Of Computed K-dimensional hole Evolution in Rips complex).
Computes Betti number heatmaps with filtration scale on one axis and a
varying parameter (difficulty level or layer index) on the other.
"""
from __future__ import annotations
import numpy as np
from sklearn.decomposition import PCA
from att.topology.persistence import PersistenceAnalyzer
class CROCKERMatrix:
    """Compute CROCKER matrices for LLM hidden-state topology.

    Produces 2D heatmaps of Betti numbers β_k(ε, p) where ε is the
    filtration radius and p is a varying parameter (difficulty level
    or transformer layer index).

    Parameters
    ----------
    n_filtration_steps : int
        Grid resolution along the filtration axis.
    max_dim : int
        Maximum homology dimension.
    n_pca_components : int
        PCA dimensions before PH.
    subsample : int or None
        Max points per cloud.
    seed : int
        Random seed.
    """

    def __init__(
        self,
        n_filtration_steps: int = 100,
        max_dim: int = 1,
        n_pca_components: int = 50,
        subsample: int | None = 200,
        seed: int = 42,
    ):
        self.n_filtration_steps = n_filtration_steps
        self.max_dim = max_dim
        self.n_pca_components = n_pca_components
        self.subsample = subsample
        self.seed = seed
        # Filled by fit_by_difficulty() / fit_by_layer()
        self._matrices: dict[int, np.ndarray] | None = None
        self._parameter_labels: list[str] | None = None
        self._filtration_range: tuple[float, float] | None = None
        self._mode: str | None = None

    def fit_by_difficulty(
        self,
        loader,
        layer: int = -1,
        levels: list[int] | None = None,
    ) -> "CROCKERMatrix":
        """Compute CROCKER matrices with difficulty level as parameter axis.

        Parameters
        ----------
        loader : HiddenStateLoader
        layer : int
            Layer index to analyze (-1 = final layer).
        levels : list of int or None
            Levels to include (None = all).

        Returns
        -------
        CROCKERMatrix
            ``self``, fitted (fluent interface).
        """
        if levels is None:
            levels = sorted(loader.unique_levels.tolist())
        diagrams_list = []
        labels = []
        for level in levels:
            cloud = loader.get_level_cloud(level, layer=layer)
            diagrams = self._compute_diagrams(cloud)
            diagrams_list.append(diagrams)
            labels.append(f"L{level}")
        self._build_matrices(diagrams_list)
        self._parameter_labels = labels
        self._mode = "difficulty"
        return self

    def fit_by_layer(
        self,
        loader,
        level: int = 1,
        layers: list[int] | None = None,
    ) -> "CROCKERMatrix":
        """Compute CROCKER matrices with layer index as parameter axis.

        Parameters
        ----------
        loader : HiddenStateLoader
        level : int
            Difficulty level to analyze.
        layers : list of int or None
            Layer indices to include (None = all).

        Returns
        -------
        CROCKERMatrix
            ``self``, fitted (fluent interface).
        """
        if layers is None:
            layers = list(range(loader.num_layers))
        diagrams_list = []
        labels = []
        for layer_idx in layers:
            cloud = loader.get_level_cloud(level, layer=layer_idx)
            diagrams = self._compute_diagrams(cloud)
            diagrams_list.append(diagrams)
            labels.append(f"Ly{layer_idx}")
        self._build_matrices(diagrams_list)
        self._parameter_labels = labels
        self._mode = "layer"
        return self

    def _compute_diagrams(self, cloud: np.ndarray) -> list[np.ndarray]:
        """PCA + PH on a single point cloud.

        Returns one (n_bars, 2) diagram per homology dimension 0..max_dim.
        Clouds with fewer than 3 points get empty diagrams (PH on so few
        points is meaningless and some backends fail outright).
        """
        n_pts = cloud.shape[0]
        if n_pts < 3:
            return [np.empty((0, 2)) for _ in range(self.max_dim + 1)]
        # PCA components cannot exceed n_samples - 1 or the ambient dim.
        n_comp = min(self.n_pca_components, n_pts - 1, cloud.shape[1])
        pca = PCA(n_components=n_comp)
        cloud_pca = pca.fit_transform(cloud)
        pa = PersistenceAnalyzer(max_dim=self.max_dim, backend="ripser")
        sub = min(n_pts, self.subsample) if self.subsample else None
        result = pa.fit_transform(cloud_pca, subsample=sub, seed=self.seed)
        return result["diagrams"]

    def _build_matrices(self, diagrams_list: list[list[np.ndarray]]) -> None:
        """Build Betti matrices from a list of persistence diagrams.

        Sets ``self._matrices`` (one (n_filtration_steps, n_params) matrix
        per homology dimension) and ``self._filtration_range``.
        """
        # Find global filtration range across all diagrams.
        all_births = []
        all_deaths = []
        for diagrams in diagrams_list:
            for dgm in diagrams:
                if len(dgm) > 0:
                    all_births.extend(dgm[:, 0].tolist())
                    all_deaths.extend(dgm[:, 1].tolist())
        if not all_births:
            # No features anywhere: all-zero matrices on a dummy [0, 1] grid.
            n_params = len(diagrams_list)
            self._matrices = {
                d: np.zeros((self.n_filtration_steps, n_params))
                for d in range(self.max_dim + 1)
            }
            self._filtration_range = (0.0, 1.0)
            return
        eps_min = min(all_births)
        # Essential classes (e.g. the last H0 component under ripser) have
        # death = inf; feeding inf into np.linspace would produce a nan/inf
        # grid, so bound the range by the largest *finite* death. Infinite
        # bars still count below because `grid < inf` is always True.
        finite_deaths = [d for d in all_deaths if np.isfinite(d)]
        eps_max = max(finite_deaths) if finite_deaths else max(all_births) + 1.0
        self._filtration_range = (eps_min, eps_max)
        grid = np.linspace(eps_min, eps_max, self.n_filtration_steps)
        n_params = len(diagrams_list)
        matrices = {
            d: np.zeros((self.n_filtration_steps, n_params))
            for d in range(self.max_dim + 1)
        }
        for p_idx, diagrams in enumerate(diagrams_list):
            for dim in range(self.max_dim + 1):
                if dim < len(diagrams):
                    dgm = diagrams[dim]
                    for birth, death in dgm:
                        # A bar [birth, death) contributes 1 to β_dim at
                        # every grid value it covers.
                        matrices[dim][:, p_idx] += (grid >= birth) & (
                            grid < death
                        )
        self._matrices = matrices

    @property
    def betti_matrices(self) -> dict[int, np.ndarray]:
        """Betti matrices keyed by homology dimension.

        Each matrix has shape (n_filtration_steps, n_parameters).
        """
        if self._matrices is None:
            raise RuntimeError("Call fit_by_difficulty() or fit_by_layer() first.")
        return self._matrices

    @property
    def parameter_labels(self) -> list[str]:
        """Labels for the parameter axis."""
        if self._parameter_labels is None:
            raise RuntimeError("Call fit_by_difficulty() or fit_by_layer() first.")
        return self._parameter_labels

    @property
    def filtration_range(self) -> tuple[float, float]:
        """(min, max) of the filtration grid."""
        if self._filtration_range is None:
            raise RuntimeError("Call fit_by_difficulty() or fit_by_layer() first.")
        return self._filtration_range

    def pairwise_l1_distances(self, dim: int = 1) -> np.ndarray:
        """L1 distances between CROCKER slices (columns) at a given homology dim.

        Returns
        -------
        (n_params, n_params) symmetric distance matrix.
        """
        mat = self.betti_matrices[dim]
        n = mat.shape[1]
        dists = np.zeros((n, n))
        for i in range(n):
            for j in range(i + 1, n):
                d = np.sum(np.abs(mat[:, i] - mat[:, j]))
                dists[i, j] = d
                dists[j, i] = d
        return dists