# Source code for att.llm.attention_binding

"""Attention-hidden topological coupling analysis (Direction 10).

Measures coupling between attention topology and hidden-state topology by
computing persistence images on both and comparing them via Pearson
correlation of the flattened persistence-image vectors.
Uses row-permutation surrogates for significance testing.
"""

from __future__ import annotations

from dataclasses import dataclass, field

import numpy as np
from sklearn.decomposition import PCA

from att.topology.persistence import PersistenceAnalyzer


@dataclass
class BindingResult:
    """Result of attention-hidden binding analysis for one problem/layer."""

    # Coupling score between attention and hidden-state topology: Pearson
    # correlation of the flattened persistence-image vectors, in [-1, 1]
    # (0.0 when either image vector is constant).
    binding_score: float
    # Persistence entropy per homology dimension for the attention diagrams
    # (may be empty when binding was computed from pre-extracted diagrams).
    attention_entropy: dict[int, float]
    # Persistence entropy per homology dimension for the hidden-state diagrams.
    hidden_entropy: dict[int, float]
    # Number of persistence features per homology dimension (attention side).
    n_attention_features: dict[int, int]
    # Number of persistence features per homology dimension (hidden side).
    n_hidden_features: dict[int, int]

@dataclass
class SignificanceResult:
    """Result of permutation significance test."""

    # Binding score on the unpermuted attention/hidden pair.
    observed_score: float
    # Binding scores of the row/column-permuted surrogates, shape (n_permutations,).
    null_scores: np.ndarray
    # Two-sided permutation p-value with add-one smoothing:
    # (#{|null| >= |observed|} + 1) / (n_permutations + 1).
    p_value: float
    # (observed - null_mean) / null_std; set to 0.0 when null_std is ~0.
    z_score: float
    # Number of surrogate permutations drawn.
    n_permutations: int


[docs] class AttentionHiddenBinding: """Measure topological coupling between attention and hidden-state geometry. Treats 1 - attention_weight as a precomputed distance matrix and computes PH on it. Compares against PH on hidden-state point clouds using persistence image subtraction (same principle as BindingDetector). Parameters ---------- max_dim : int Maximum homology dimension. image_resolution : int Resolution of persistence images for comparison. image_sigma : float Gaussian kernel bandwidth for persistence images. n_pca_components : int PCA components for hidden-state clouds. subsample : int Max points for hidden-state PH. seed : int Random seed. """ def __init__( self, max_dim: int = 1, image_resolution: int = 50, image_sigma: float = 0.1, n_pca_components: int = 50, subsample: int = 100, seed: int = 42, ): self.max_dim = max_dim self.image_resolution = image_resolution self.image_sigma = image_sigma self.n_pca_components = n_pca_components self.subsample = subsample self.seed = seed
[docs] @staticmethod def attention_to_distance(attn: np.ndarray) -> np.ndarray: """Convert attention matrix to symmetric distance matrix. D = 1 - (A + A^T) / 2, clipped to [0, 1], diagonal zeroed. """ sym = (attn + attn.T) / 2.0 dist = 1.0 - sym np.clip(dist, 0.0, 1.0, out=dist) np.fill_diagonal(dist, 0.0) return dist
[docs] def compute_binding( self, attention_matrix: np.ndarray, hidden_cloud: np.ndarray, ) -> BindingResult: """Compute binding score between attention topology and hidden-state topology. Parameters ---------- attention_matrix : (n, n) attention weight matrix (head-averaged). hidden_cloud : (n, d) hidden-state vectors for the same tokens. Returns ------- BindingResult with binding score and per-dim feature counts / entropy. """ n = attention_matrix.shape[0] # Subsample both consistently if n > self.subsample: rng = np.random.default_rng(self.seed) idx = rng.choice(n, size=self.subsample, replace=False) idx = np.sort(idx) attention_matrix = attention_matrix[np.ix_(idx, idx)] hidden_cloud = hidden_cloud[idx] # Attention PH via precomputed distance attn_dist = self.attention_to_distance(attention_matrix) pa_attn = PersistenceAnalyzer( max_dim=self.max_dim, backend="ripser", metric="precomputed" ) result_attn = pa_attn.fit_transform(attn_dist) # Hidden-state PH via Euclidean distance (with PCA) n_pts = hidden_cloud.shape[0] n_comp = min(self.n_pca_components, n_pts - 1, hidden_cloud.shape[1]) if n_comp >= 2: pca = PCA(n_components=n_comp) cloud_pca = pca.fit_transform(hidden_cloud) else: cloud_pca = hidden_cloud pa_hidden = PersistenceAnalyzer(max_dim=self.max_dim, backend="ripser") result_hidden = pa_hidden.fit_transform(cloud_pca) # Compute binding via PI subtraction birth_range, pers_range = self._shared_ranges( pa_attn.diagrams_, pa_hidden.diagrams_ ) imgs_attn = pa_attn.to_image( self.image_resolution, self.image_sigma, birth_range, pers_range ) imgs_hidden = pa_hidden.to_image( self.image_resolution, self.image_sigma, birth_range, pers_range ) # Binding = L1 of |PI_attn - PI_hidden| (symmetric coupling measure) score = 0.0 for d in range(self.max_dim + 1): diff = np.abs(imgs_attn[d] - imgs_hidden[d]) # Coupling = overlap (complement of difference) # High similarity → high coupling → low diff # So binding = 1 - normalized_diff, but simpler: use correlation pass # 
Alternative: correlation-based coupling # Correlation of PI vectors: high correlation = topologies match = tight coupling score = self._pi_correlation(imgs_attn, imgs_hidden) attn_entropy = result_attn.get("persistence_entropy", {}) hidden_entropy = result_hidden.get("persistence_entropy", {}) attn_features = {} hidden_features = {} for d in range(self.max_dim + 1): attn_features[d] = len(result_attn["diagrams"][d]) if d < len(result_attn["diagrams"]) else 0 hidden_features[d] = len(result_hidden["diagrams"][d]) if d < len(result_hidden["diagrams"]) else 0 return BindingResult( binding_score=score, attention_entropy=attn_entropy, hidden_entropy=hidden_entropy, n_attention_features=attn_features, n_hidden_features=hidden_features, )
def _pi_correlation( self, imgs_a: list[np.ndarray], imgs_b: list[np.ndarray] ) -> float: """Correlation-based coupling score between two sets of persistence images. Concatenates flattened PIs across dimensions and computes Pearson correlation. Returns value in [-1, 1]; higher = tighter topological coupling. """ vec_a = np.concatenate([img.ravel() for img in imgs_a]) vec_b = np.concatenate([img.ravel() for img in imgs_b]) if np.std(vec_a) < 1e-15 or np.std(vec_b) < 1e-15: return 0.0 return float(np.corrcoef(vec_a, vec_b)[0, 1]) def _shared_ranges( self, diagrams_a: list[np.ndarray], diagrams_b: list[np.ndarray], ) -> tuple[tuple[float, float], tuple[float, float]]: """Compute shared birth and persistence ranges for PI computation.""" all_births = [] all_pers = [] for dgms in [diagrams_a, diagrams_b]: for dgm in dgms: if len(dgm) > 0: births = dgm[:, 0] deaths = dgm[:, 1] all_births.extend(births.tolist()) all_pers.extend((deaths - births).tolist()) if not all_births: return (0.0, 1.0), (0.0, 1.0) birth_range = (min(all_births), max(all_births)) pers_range = (0.0, max(all_pers) if all_pers else 1.0) # Ensure non-degenerate ranges if birth_range[1] - birth_range[0] < 1e-15: birth_range = (birth_range[0], birth_range[0] + 1.0) if pers_range[1] - pers_range[0] < 1e-15: pers_range = (0.0, 1.0) return birth_range, pers_range
[docs] def test_significance( self, attention_matrix: np.ndarray, hidden_cloud: np.ndarray, n_permutations: int = 100, ) -> SignificanceResult: """Test binding significance via row-permutation surrogates. Permutes rows (and corresponding columns) of the attention matrix to destroy the attention-hidden correspondence while preserving attention structure. The observed binding score is compared against the null distribution of surrogate scores. Parameters ---------- attention_matrix : (n, n) attention weight matrix. hidden_cloud : (n, d) hidden-state vectors. n_permutations : int Number of surrogate permutations. Returns ------- SignificanceResult with observed score, null distribution, p-value, z-score. """ observed = self.compute_binding(attention_matrix, hidden_cloud) observed_score = observed.binding_score rng = np.random.default_rng(self.seed) null_scores = np.zeros(n_permutations) for i in range(n_permutations): perm = rng.permutation(attention_matrix.shape[0]) # Permute rows and columns of attention to break correspondence with hidden attn_perm = attention_matrix[np.ix_(perm, perm)] result = self.compute_binding(attn_perm, hidden_cloud) null_scores[i] = result.binding_score null_mean = np.mean(null_scores) null_std = np.std(null_scores) if null_std > 1e-15: z_score = (observed_score - null_mean) / null_std else: z_score = 0.0 p_value = (np.sum(np.abs(null_scores) >= np.abs(observed_score)) + 1) / (n_permutations + 1) return SignificanceResult( observed_score=observed_score, null_scores=null_scores, p_value=float(p_value), z_score=float(z_score), n_permutations=n_permutations, )
@staticmethod def _normalize_diagrams(diagrams: list[np.ndarray]) -> list[np.ndarray]: """Normalize persistence diagrams to [0, 1] birth and persistence ranges. This is essential when comparing PH from different metric spaces (e.g., attention distances in [0, 1] vs Euclidean distances in [0, 200+]). """ normalized = [] for dgm in diagrams: if len(dgm) == 0: normalized.append(dgm) continue dgm = dgm.copy() births = dgm[:, 0] deaths = dgm[:, 1] b_max = births.max() d_max = deaths.max() scale = max(b_max, d_max) if scale > 1e-15: dgm[:, 0] = births / scale dgm[:, 1] = deaths / scale normalized.append(dgm) return normalized
[docs] def compute_binding_from_diagrams( self, attn_diagrams: list[np.ndarray], hidden_cloud: np.ndarray, ) -> BindingResult: """Compute binding from pre-extracted attention PH diagrams and hidden cloud. Parameters ---------- attn_diagrams : list of (n, 2) arrays, one per homology dimension. hidden_cloud : (n, d) hidden-state vectors. Returns ------- BindingResult with binding score and feature stats. """ # Hidden-state PH n_pts = hidden_cloud.shape[0] sub = min(self.subsample, n_pts) if sub < n_pts: rng = np.random.default_rng(self.seed) idx = rng.choice(n_pts, size=sub, replace=False) hidden_cloud = hidden_cloud[np.sort(idx)] n_comp = min(self.n_pca_components, hidden_cloud.shape[0] - 1, hidden_cloud.shape[1]) if n_comp >= 2: pca = PCA(n_components=n_comp) cloud_pca = pca.fit_transform(hidden_cloud) else: cloud_pca = hidden_cloud pa_hidden = PersistenceAnalyzer(max_dim=self.max_dim, backend="ripser") result_hidden = pa_hidden.fit_transform(cloud_pca) # Convert attn_diagrams to numpy if needed attn_dgms = [np.array(d) if not isinstance(d, np.ndarray) else d for d in attn_diagrams] while len(attn_dgms) <= self.max_dim: attn_dgms.append(np.empty((0, 2))) # Normalize both to [0, 1] to handle scale mismatch # (attention distances ∈ [0,1] vs Euclidean distances ∈ [0, 200+]) attn_norm = self._normalize_diagrams(attn_dgms) hidden_norm = self._normalize_diagrams(pa_hidden.diagrams_) # Compute PIs on normalized diagrams with fixed [0,1] range birth_range = (0.0, 1.0) pers_range = (0.0, 1.0) imgs_attn = self._diagrams_to_images(attn_norm, birth_range, pers_range) imgs_hidden = self._diagrams_to_images(hidden_norm, birth_range, pers_range) score = self._pi_correlation(imgs_attn, imgs_hidden) # Feature counts / entropy attn_entropy = {} attn_features = {} hidden_entropy = result_hidden.get("persistence_entropy", {}) hidden_features = {} for d in range(self.max_dim + 1): attn_features[d] = len(attn_dgms[d]) if d < len(attn_dgms) else 0 hidden_features[d] = 
len(result_hidden["diagrams"][d]) if d < len(result_hidden["diagrams"]) else 0 return BindingResult( binding_score=score, attention_entropy=attn_entropy, hidden_entropy=hidden_entropy, n_attention_features=attn_features, n_hidden_features=hidden_features, )
def _diagrams_to_images( self, diagrams: list[np.ndarray], birth_range: tuple[float, float], pers_range: tuple[float, float], ) -> list[np.ndarray]: """Convert persistence diagrams to images using Gaussian kernel.""" resolution = self.image_resolution sigma = self.image_sigma images = [] for dgm in diagrams: img = np.zeros((resolution, resolution)) if len(dgm) == 0: images.append(img) continue births = dgm[:, 0] deaths = dgm[:, 1] pers = deaths - births # Grid b_min, b_max = birth_range p_min, p_max = pers_range b_centers = np.linspace(b_min, b_max, resolution) p_centers = np.linspace(p_min, p_max, resolution) for b, p in zip(births, pers): if p <= 0: continue weight = p # weight by persistence b_diffs = (b_centers - b) ** 2 p_diffs = (p_centers - p) ** 2 kernel = weight * np.exp(-np.add.outer(p_diffs, b_diffs) / (2 * sigma ** 2)) img += kernel images.append(img) return images
[docs] def binding_profile( self, loader, attention_ph_data: dict | None = None, levels: list[int] | None = None, layer_indices: list[int] | None = None, ) -> dict: """Compute binding scores across difficulty levels and layers. If attention_ph_data is not available, returns binding scores based on hidden-state self-coupling (within-layer topology consistency). This serves as a template for when attention data becomes available. Parameters ---------- loader : HiddenStateLoader attention_ph_data : optional pre-computed attention PH (from extract_attention_weights.py). levels : difficulty levels to analyze. layer_indices : which layers to analyze. Returns ------- dict with: scores : dict mapping (level, layer) -> binding_score levels : list of levels layers : list of layer indices """ if levels is None: levels = sorted(loader.unique_levels.tolist()) if layer_indices is None: n_layers = loader.num_layers layer_indices = list(range(max(0, n_layers - 5), n_layers)) scores = {} for level in levels: mask = loader.get_level_mask(level) indices = np.where(mask)[0] for layer_idx in layer_indices: level_scores = [] for problem_idx in indices: token_traj = loader.token_trajectories[problem_idx] n_tokens = token_traj.shape[0] if n_tokens < 10: continue if attention_ph_data is not None: # Use pre-extracted attention PH attn_entry = attention_ph_data.get(problem_idx, {}).get(layer_idx) if attn_entry is None: continue # Reconstruct a synthetic attention-like distance from PH # (actual use requires raw attention matrices) continue # Self-coupling: split token cloud into two halves and # measure topological similarity (proxy when no attention data) mid = n_tokens // 2 if mid < 5: continue cloud_a = token_traj[:mid] cloud_b = token_traj[mid:] # Create synthetic "attention-like" distance from cloud_a from scipy.spatial.distance import cdist dists_a = cdist(cloud_a[:min(self.subsample, len(cloud_a))], cloud_a[:min(self.subsample, len(cloud_a))]) if dists_a.max() > 0: dists_a /= 
dists_a.max() result = self.compute_binding(dists_a, cloud_b[:min(self.subsample, len(cloud_b))]) level_scores.append(result.binding_score) scores[(level, layer_idx)] = float(np.mean(level_scores)) if level_scores else 0.0 return { "scores": scores, "levels": levels, "layers": layer_indices, }