amygdala: select top-k eigenvectors AFTER PCA, not per-story truncation

Kent: 'full rank is going to give you everything — you still have to
select down, but you can do that /after/ PCA'.

Previously I was discarding per-story information via k=20 truncation
of the SVD.
That destroyed per-head discriminability before we ever saw the
eigenvalue spectrum. Then the alternative 'keep full rank' run
accumulated too many shared directions, making the top-1 eigenvector
arbitrary within a flat spectrum.

Correct approach: keep per-story subspaces at full rank (no info
loss) and, at the final step, select k eigenvectors of
M = M_pos - M_base, combined as an eigenvalue-weighted sum. This
captures the multi-dimensional shared subspace when the spectrum is
flat (the common case), and reduces to the top-1 behavior when the
spectrum has a clear gap.

New --subspace-eigen-k flag (default 5). Clamps negative weights to 0
so wrong-sign directions don't contribute.
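The scheme above can be sketched end to end (a minimal illustration using
numpy in place of torch; all names, sizes, and the planted `concept`
direction are made up for the example, not taken from the patch):

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = 16

def orthonormal(cols: np.ndarray) -> np.ndarray:
    # Orthonormalize columns via QR; span (and the first column's
    # direction, up to sign) is preserved.
    q, _ = np.linalg.qr(cols)
    return q

# Plant one "concept" direction in every positive story's subspace.
concept = np.zeros(hidden)
concept[0] = 1.0

# Full-rank-ish per-story bases [hidden, k_i]; no per-story truncation.
pos_V = [orthonormal(np.column_stack([concept, rng.normal(size=(hidden, 3))]))
         for _ in range(8)]
base_V = [orthonormal(rng.normal(size=(hidden, 4))) for _ in range(8)]

def mean_projector(Vs):
    # M = (1/n) Σ V_i V_i^T — the average projector onto the subspaces.
    return sum(V @ V.T for V in Vs) / len(Vs)

M = mean_projector(pos_V) - mean_projector(base_V)

# Symmetric eigendecomposition; eigh returns eigenvalues ascending,
# so the top-k live in the last k columns.
eigvals, eigvecs = np.linalg.eigh(M)
k = 5
w = np.clip(eigvals[-k:], 0.0, None)  # clamp wrong-sign weights to 0
combined = eigvecs[:, -k:] @ w        # eigenvalue-weighted sum, [hidden]
combined /= max(np.linalg.norm(combined), 1e-6)
```

Because the planted direction sits in every positive subspace but only
incidentally overlaps the baselines, the top eigenvalue is large and
`combined` ends up dominated by it; selection happens only after the
eigendecomposition, matching the commit's approach.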

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
Kent Overstreet 2026-04-18 21:49:21 -04:00
parent 2411925700
commit 1443d08dc7


@@ -282,16 +282,22 @@ def _subspace_concept_direction(
     pos_V: list[torch.Tensor],  # list of [hidden, k_i] per story
     base_V: list[torch.Tensor],
     hidden: int,
+    *,
+    top_k: int = 5,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Subspace-common-direction CAA alternative.
 
     Builds M_pos = (1/n_pos) Σ V_i V_i^T over positive stories and M_base the
-    same over baselines. Returns the top eigenvector of (M_pos - M_base),
-    the direction most-common to positives after subtracting what's generic
-    across baselines, plus its eigenvalue spectrum (for diagnostics).
+    same over baselines. Returns a weighted sum of the top-k eigenvectors of
+    (M_pos - M_base), weights = eigenvalues (so stronger common directions
+    contribute more), unit-normed. Returns the full eigenvalue spectrum for
+    diagnostics.
 
     The top eigenvalue approaches 1 if the concept appears in every positive
     story's subspace with unit weight and is absent from the baseline.
+
+    top_k=1 recovers the previous behavior (top eigenvector only). top_k>1
+    captures richer structure when the concept lives in a multi-dimensional
+    shared subspace, which the flat eigenvalue spectrum observed in
+    practice suggests is the common case. Selection happens AFTER the
+    eigendecomposition, so nothing is lost up to that point.
     """
     device = pos_V[0].device if pos_V else torch.device("cpu")
     dtype = torch.float32
@@ -310,13 +316,18 @@ def _subspace_concept_direction(
     M_base = acc(base_V)
     M = M_pos - M_base
 
-    # Symmetric eigendecomposition — top eigenvalue/vector.
+    # Symmetric eigendecomposition.
     eigvals, eigvecs = torch.linalg.eigh(M)
-    # eigh returns ascending; top is the last column.
-    top_vec = eigvecs[:, -1]
-    # Unit-norm (eigvecs are unit already, but defensively).
-    top_vec = top_vec / top_vec.norm().clamp_min(1e-6)
-    return top_vec, eigvals
+    # eigh returns ascending; top-k are the last k columns.
+    k = max(1, min(top_k, eigvecs.shape[1]))
+    top_vals = eigvals[-k:]     # [k], ascending within top-k
+    top_vecs = eigvecs[:, -k:]  # [hidden, k]
+    # Weighted sum of top-k eigenvectors, weights = eigenvalues. Clamp
+    # negative weights to 0 (wrong-sign directions shouldn't contribute).
+    w = top_vals.clamp_min(0.0)
+    combined = top_vecs @ w     # [hidden]
+    combined = combined / combined.norm().clamp_min(1e-6)
+    return combined, eigvals
 
 
 def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
@@ -858,6 +869,17 @@ def main() -> None:
         "residual and ~500-token stories, that's ~500 vectors per story. "
         "Memory is fine: 112 × 5120 × 500 × 4 bytes ≈ 1.1 GB.",
     )
+    ap.add_argument(
+        "--subspace-eigen-k",
+        type=int,
+        default=5,
+        help="Number of top eigenvectors of M_pos - M_base to combine into "
+        "the concept direction. Weighted sum by eigenvalue (so strongest "
+        "common directions contribute most). eigen_k=1 recovers "
+        "single-eigenvector behavior. Higher values (5-10) capture "
+        "richer structure when the concept's shared-subspace spectrum "
+        "is flat (which it tends to be in practice).",
+    )
     ap.add_argument(
         "--quality-report",
         action="store_true",
@@ -1045,6 +1067,7 @@ def main() -> None:
             base_V += [bs[target_l] for bs in (base_subspaces or [])]
             top_vec, eigvals = _subspace_concept_direction(
                 pos_V, base_V, hidden=hidden_dim,
+                top_k=args.subspace_eigen_k,
             )
             per_layer_vectors[l_idx, e_idx] = top_vec
             # Keep the top-20 eigenvalues for quality-report diagnostics.