From 1443d08dc77edbce8b8a46fe181bffbeff5a09b4 Mon Sep 17 00:00:00 2001 From: Kent Overstreet Date: Sat, 18 Apr 2026 21:49:21 -0400 Subject: [PATCH] amygdala: select top-k eigenvectors AFTER PCA, not per-story truncation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Kent: 'full rank is going to give you everything — you still have to select down, but you can do that /after/ PCA'. Previously I was discarding per-story via k=20 truncation of SVD. That destroyed per-head discriminability before we ever saw the eigenvalue spectrum. Then the alternative 'keep full rank' run accumulated too many shared directions, making the top-1 eigenvector arbitrary within a flat spectrum. Correct approach: keep per-story subspaces at full rank (no info loss) and select k eigenvectors of M = M_pos - M_base at the final step, weighted sum by eigenvalue. This captures the multi-dimensional shared subspace when the spectrum is flat (common case), and reduces to the top-1 behavior when the spectrum has a clear gap. New --subspace-eigen-k flag (default 5). Clamps negative weights to 0 so wrong-sign directions don't contribute. Co-Authored-By: Proof of Concept --- .../train_steering_vectors.py | 45 ++++++++++++++----- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/training/amygdala_training/train_steering_vectors.py b/training/amygdala_training/train_steering_vectors.py index 353ebb0..6e49e2a 100644 --- a/training/amygdala_training/train_steering_vectors.py +++ b/training/amygdala_training/train_steering_vectors.py @@ -282,16 +282,22 @@ def _subspace_concept_direction( pos_V: list[torch.Tensor], # list of [hidden, k_i] per story base_V: list[torch.Tensor], hidden: int, + *, + top_k: int = 5, ) -> tuple[torch.Tensor, torch.Tensor]: """Subspace-common-direction CAA alternative. Builds M_pos = (1/n_pos) Σ V_i V_i^T over positive stories and M_base the - same over baselines. Returns the top eigenvector of (M_pos - M_base) — - the direction most-common to positives after subtracting what's generic - across baselines — plus its eigenvalue spectrum (for diagnostics). + same over baselines. Returns a weighted sum of the top-k eigenvectors of + (M_pos - M_base), weights = eigenvalues (so stronger common directions + contribute more), unit-normed. Returns the full eigenvalue spectrum for + diagnostics. - The top eigenvalue approaches 1 if the concept appears in every positive - story's subspace with unit weight and is absent from the baseline. + top_k=1 recovers the previous behavior (top eigenvector only). top_k>1 + captures richer structure when the concept lives in a multi-dimensional + shared subspace — which the flat eigenvalue spectrum observed in + practice suggests is the common case. Selection happens AFTER the + eigendecomposition so nothing is lost up to that point. """ device = pos_V[0].device if pos_V else torch.device("cpu") dtype = torch.float32 @@ -310,13 +316,18 @@ def _subspace_concept_direction( M_base = acc(base_V) M = M_pos - M_base - # Symmetric eigendecomposition — top eigenvalue/vector. + # Symmetric eigendecomposition. eigvals, eigvecs = torch.linalg.eigh(M) - # eigh returns ascending; top is the last column. - top_vec = eigvecs[:, -1] - # Unit-norm (eigvecs are unit already, but defensively). - top_vec = top_vec / top_vec.norm().clamp_min(1e-6) - return top_vec, eigvals + # eigh returns ascending; top-k are the last k columns. + k = max(1, min(top_k, eigvecs.shape[1])) + top_vals = eigvals[-k:] # [k], ascending within top-k + top_vecs = eigvecs[:, -k:] # [hidden, k] + # Weighted sum of top-k eigenvectors, weights = eigenvalues. Clamp + # negative weights to 0 (wrong-sign directions shouldn't contribute). + w = top_vals.clamp_min(0.0) + combined = top_vecs @ w # [hidden] + combined = combined / combined.norm().clamp_min(1e-6) + return combined, eigvals def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[ @@ -858,6 +869,17 @@ def main() -> None: "residual and ~500-token stories, that's ~500 vectors per story. " "Memory is fine: 112 × 5120 × 500 × 4 bytes ≈ 1.1 GB.", ) + ap.add_argument( + "--subspace-eigen-k", + type=int, + default=5, + help="Number of top eigenvectors of M_pos - M_base to combine into " + "the concept direction. Weighted sum by eigenvalue (so strongest " + "common directions contribute most). eigen_k=1 recovers " + "single-eigenvector behavior. Higher values (5-10) capture " + "richer structure when the concept's shared-subspace spectrum " + "is flat (which it tends to be in practice).", + ) ap.add_argument( "--quality-report", action="store_true", @@ -1045,6 +1067,7 @@ def main() -> None: base_V += [bs[target_l] for bs in (base_subspaces or [])] top_vec, eigvals = _subspace_concept_direction( pos_V, base_V, hidden=hidden_dim, + top_k=args.subspace_eigen_k, ) per_layer_vectors[l_idx, e_idx] = top_vec # Keep the top-20 eigenvalues for quality-report diagnostics.