amygdala: select top-k eigenvectors AFTER PCA, not per-story truncation

Kent: 'full rank is going to give you everything — you still have to
select down, but you can do that /after/ PCA'.

Previously I was discarding information per story via k=20 truncation
of the SVD. That destroyed per-head discriminability before we ever saw
the eigenvalue spectrum. The alternative 'keep full rank' run then
accumulated too many shared directions, leaving the top-1 eigenvector
arbitrary within a flat spectrum.
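
The failure mode can be sketched in a few lines of numpy (toy
dimensions and a synthetic story generator, not the real pipeline): a
concept direction that is weak inside every story survives full-rank
subspaces but is dropped by per-story truncation.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, n_tok = 4, 50
concept = np.eye(hidden)[0]   # weak direction shared by every story

def story(noise_axis):
    # tokens = dominant story-specific axis + weak shared concept
    t = rng.normal(size=(n_tok, 1))
    s = rng.normal(size=(n_tok, 1))
    return (3.0 * t * np.eye(hidden)[noise_axis] + 0.5 * s * concept).T

def svd_basis(X, k):
    U, _, _ = np.linalg.svd(X, full_matrices=False)
    return U[:, :k]

stories = [story(2), story(3)]

# k=1 truncation keeps only each story's dominant axis (~e2, ~e3): the
# shared concept is gone before the eigendecomposition ever runs.
M_trunc = sum(svd_basis(X, 1) @ svd_basis(X, 1).T for X in stories) / 2
# Full rank keeps the concept inside every story subspace; it is the one
# direction common to both, so it tops the averaged-projector spectrum.
M_full = sum(svd_basis(X, 2) @ svd_basis(X, 2).T for X in stories) / 2

top = lambda M: np.linalg.eigh(M)[1][:, -1]
lost = abs(top(M_trunc) @ concept)   # near 0: concept discarded
kept = abs(top(M_full) @ concept)    # near 1: concept recovered
```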

Correct approach: keep per-story subspaces at full rank (no information
loss) and select k eigenvectors of M = M_pos - M_base at the final step,
combined in a sum weighted by eigenvalue. This captures the
multi-dimensional shared subspace when the spectrum is flat (the common
case), and reduces to the top-1 behavior when the spectrum has a clear
gap.
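
The final step can be sketched standalone (a numpy stand-in for the
torch code; the function name and toy inputs are illustrative):

```python
import numpy as np

def concept_direction(pos_V, base_V, top_k=5):
    """Eigenvalue-weighted combination of top-k eigenvectors of M_pos - M_base.

    pos_V / base_V: lists of [hidden, k_i] orthonormal story bases, kept
    at full rank; selection happens only here, after the eigh.
    """
    avg_proj = lambda Vs: sum(V @ V.T for V in Vs) / len(Vs)
    M = avg_proj(pos_V) - avg_proj(base_V)
    eigvals, eigvecs = np.linalg.eigh(M)   # eigenvalues ascending
    k = max(1, min(top_k, eigvecs.shape[1]))
    w = np.clip(eigvals[-k:], 0.0, None)   # wrong-sign weights -> 0
    d = eigvecs[:, -k:] @ w
    return d / max(np.linalg.norm(d), 1e-6), eigvals
```

With a clear spectral gap this reduces to the top eigenvector: two
positive stories sharing e0 against baselines sharing e1 give M a single
eigenvalue of 1 at e0, so the weighted sum is just that eigenvector.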

New --subspace-eigen-k flag (default 5). Negative eigenvalues are
clamped to 0 in the weighting so wrong-sign directions don't contribute.
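
A toy numpy illustration of the flat-spectrum case (illustrative
dimensions): when the top eigenvalues tie, any single eigenvector is an
arbitrary basis choice within the shared plane, but the
eigenvalue-weighted sum stays inside it, and the clamp keeps wrong-sign
baseline directions out.

```python
import numpy as np

# Two positive stories share a 2-D concept plane span{e0, e1}; baselines
# share e2. Hidden size 4 for illustration.
e = np.eye(4)
M_pos = sum(V @ V.T for V in [e[:, [0, 1]], e[:, [0, 1]]]) / 2
M_base = sum(V @ V.T for V in [e[:, [2]], e[:, [2]]]) / 2
M = M_pos - M_base                     # diag(1, 1, -1, 0)

eigvals, eigvecs = np.linalg.eigh(M)   # ascending: [-1, 0, 1, 1]
w = np.clip(eigvals[-3:], 0.0, None)   # a top-4 pick would clamp the
                                       # baseline's -1 weight to 0
d = eigvecs[:, -3:] @ w
d /= np.linalg.norm(d)

# The tied pair (1, 1) makes any single top eigenvector an arbitrary
# basis choice inside span{e0, e1}; the weighted sum is still confined
# to that shared plane.
in_plane = d[0] ** 2 + d[1] ** 2       # fraction of d inside the plane
```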

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
Kent Overstreet 2026-04-18 21:49:21 -04:00
parent 2411925700
commit 1443d08dc7

@@ -282,16 +282,22 @@ def _subspace_concept_direction(
     pos_V: list[torch.Tensor],  # list of [hidden, k_i] per story
     base_V: list[torch.Tensor],
     hidden: int,
+    *,
+    top_k: int = 5,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Subspace-common-direction CAA alternative.
 
     Builds M_pos = (1/n_pos) Σ V_i V_i^T over positive stories and M_base the
-    same over baselines. Returns the top eigenvector of (M_pos - M_base),
-    the direction most-common to positives after subtracting what's generic
-    across baselines, plus its eigenvalue spectrum (for diagnostics).
-
-    The top eigenvalue approaches 1 if the concept appears in every positive
-    story's subspace with unit weight and is absent from the baseline.
+    same over baselines. Returns a weighted sum of the top-k eigenvectors of
+    (M_pos - M_base), weights = eigenvalues (so stronger common directions
+    contribute more), unit-normed. Returns the full eigenvalue spectrum for
+    diagnostics.
+
+    top_k=1 recovers the previous behavior (top eigenvector only). top_k>1
+    captures richer structure when the concept lives in a multi-dimensional
+    shared subspace, which the flat eigenvalue spectrum observed in
+    practice suggests is the common case. Selection happens AFTER the
+    eigendecomposition so nothing is lost up to that point.
     """
     device = pos_V[0].device if pos_V else torch.device("cpu")
     dtype = torch.float32
@@ -310,13 +316,18 @@ def _subspace_concept_direction(
     M_base = acc(base_V)
     M = M_pos - M_base
 
-    # Symmetric eigendecomposition — top eigenvalue/vector.
+    # Symmetric eigendecomposition.
     eigvals, eigvecs = torch.linalg.eigh(M)
-    # eigh returns ascending; top is the last column.
-    top_vec = eigvecs[:, -1]
-    # Unit-norm (eigvecs are unit already, but defensively).
-    top_vec = top_vec / top_vec.norm().clamp_min(1e-6)
-    return top_vec, eigvals
+    # eigh returns ascending; top-k are the last k columns.
+    k = max(1, min(top_k, eigvecs.shape[1]))
+    top_vals = eigvals[-k:]      # [k], ascending within top-k
+    top_vecs = eigvecs[:, -k:]   # [hidden, k]
+    # Weighted sum of top-k eigenvectors, weights = eigenvalues. Clamp
+    # negative weights to 0 (wrong-sign directions shouldn't contribute).
+    w = top_vals.clamp_min(0.0)
+    combined = top_vecs @ w      # [hidden]
+    combined = combined / combined.norm().clamp_min(1e-6)
+    return combined, eigvals
 
 
 def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
@@ -858,6 +869,17 @@ def main() -> None:
         "residual and ~500-token stories, that's ~500 vectors per story. "
         "Memory is fine: 112 × 5120 × 500 × 4 bytes ≈ 1.1 GB.",
     )
+    ap.add_argument(
+        "--subspace-eigen-k",
+        type=int,
+        default=5,
+        help="Number of top eigenvectors of M_pos - M_base to combine into "
+        "the concept direction. Weighted sum by eigenvalue (so strongest "
+        "common directions contribute most). eigen_k=1 recovers "
+        "single-eigenvector behavior. Higher values (5-10) capture "
+        "richer structure when the concept's shared-subspace spectrum "
+        "is flat (which it tends to be in practice).",
+    )
     ap.add_argument(
         "--quality-report",
         action="store_true",
@@ -1045,6 +1067,7 @@ def main() -> None:
         base_V += [bs[target_l] for bs in (base_subspaces or [])]
         top_vec, eigvals = _subspace_concept_direction(
             pos_V, base_V, hidden=hidden_dim,
+            top_k=args.subspace_eigen_k,
         )
         per_layer_vectors[l_idx, e_idx] = top_vec
         # Keep the top-20 eigenvalues for quality-report diagnostics.