amygdala: select top-k eigenvectors AFTER PCA, not per-story truncation
Kent: 'full rank is going to give you everything — you still have to select down, but you can do that /after/ PCA.'

Previously I discarded per-story information via k=20 truncation of the SVD. That destroyed per-head discriminability before we ever saw the eigenvalue spectrum. The alternative 'keep full rank' run then accumulated too many shared directions, making the top-1 eigenvector arbitrary within a flat spectrum.

Correct approach: keep per-story subspaces at full rank (no information loss) and select k eigenvectors of M = M_pos - M_base at the final step, combined as a sum weighted by eigenvalue. This captures the multi-dimensional shared subspace when the spectrum is flat (the common case), and reduces to the top-1 behavior when the spectrum has a clear gap.

New --subspace-eigen-k flag (default 5). Negative weights are clamped to 0 so that wrong-sign directions don't contribute.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
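The docstring claim below (top eigenvalue of M_pos - M_base approaches 1 when the concept sits in every positive story's subspace and is absent from the baselines) is easy to sanity-check numerically. A minimal NumPy sketch on hypothetical toy data (the real code uses torch; all names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = 32
concept = np.zeros(hidden)
concept[0] = 1.0  # the shared "concept" direction, e_0

def pos_basis(extra):
    # Orthonormal per-story basis that contains the concept exactly.
    A = np.concatenate([concept[:, None],
                        rng.standard_normal((hidden, extra))], axis=1)
    return np.linalg.qr(A)[0]  # [hidden, extra + 1]

def base_basis(dims):
    # Baseline basis kept orthogonal to the concept (coordinate 0 zeroed,
    # so QR stays inside the concept-free column space).
    A = rng.standard_normal((hidden, dims))
    A[0, :] = 0.0
    return np.linalg.qr(A)[0]

pos_V = [pos_basis(5) for _ in range(4)]
base_V = [base_basis(6) for _ in range(4)]

# M_pos = (1/n) Σ V_i V_i^T, likewise M_base; eigh returns ascending order.
M_pos = sum(V @ V.T for V in pos_V) / len(pos_V)
M_base = sum(V @ V.T for V in base_V) / len(base_V)
eigvals, eigvecs = np.linalg.eigh(M_pos - M_base)

top_val, top_vec = eigvals[-1], eigvecs[:, -1]
# top_val is 1.0 (to float precision) and top_vec aligns with e_0:
# M_pos e_0 = e_0 in every positive projector, M_base e_0 = 0, and
# M_pos - M_base ⪯ I bounds every other eigenvalue below 1.
```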
parent 2411925700
commit 1443d08dc7
1 changed file with 34 additions and 11 deletions
@@ -282,16 +282,22 @@ def _subspace_concept_direction(
     pos_V: list[torch.Tensor],  # list of [hidden, k_i] per story
     base_V: list[torch.Tensor],
     hidden: int,
     *,
+    top_k: int = 5,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Subspace-common-direction CAA alternative.

     Builds M_pos = (1/n_pos) Σ V_i V_i^T over positive stories and M_base the
-    same over baselines. Returns the top eigenvector of (M_pos - M_base) —
-    the direction most-common to positives after subtracting what's generic
-    across baselines — plus its eigenvalue spectrum (for diagnostics).
+    same over baselines. Returns a weighted sum of the top-k eigenvectors of
+    (M_pos - M_base), weights = eigenvalues (so stronger common directions
+    contribute more), unit-normed. Returns the full eigenvalue spectrum for
+    diagnostics.

     The top eigenvalue approaches 1 if the concept appears in every positive
     story's subspace with unit weight and is absent from the baseline.
+    top_k=1 recovers the previous behavior (top eigenvector only). top_k>1
+    captures richer structure when the concept lives in a multi-dimensional
+    shared subspace — which the flat eigenvalue spectrum observed in
+    practice suggests is the common case. Selection happens AFTER the
+    eigendecomposition so nothing is lost up to that point.
     """
     device = pos_V[0].device if pos_V else torch.device("cpu")
     dtype = torch.float32

@@ -310,13 +316,18 @@ def _subspace_concept_direction(
     M_base = acc(base_V)
     M = M_pos - M_base

-    # Symmetric eigendecomposition — top eigenvalue/vector.
+    # Symmetric eigendecomposition.
     eigvals, eigvecs = torch.linalg.eigh(M)
-    # eigh returns ascending; top is the last column.
-    top_vec = eigvecs[:, -1]
-    # Unit-norm (eigvecs are unit already, but defensively).
-    top_vec = top_vec / top_vec.norm().clamp_min(1e-6)
-    return top_vec, eigvals
+    # eigh returns ascending; top-k are the last k columns.
+    k = max(1, min(top_k, eigvecs.shape[1]))
+    top_vals = eigvals[-k:]  # [k], ascending within top-k
+    top_vecs = eigvecs[:, -k:]  # [hidden, k]
+    # Weighted sum of top-k eigenvectors, weights = eigenvalues. Clamp
+    # negative weights to 0 (wrong-sign directions shouldn't contribute).
+    w = top_vals.clamp_min(0.0)
+    combined = top_vecs @ w  # [hidden]
+    combined = combined / combined.norm().clamp_min(1e-6)
+    return combined, eigvals


 def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[

@@ -858,6 +869,17 @@ def main() -> None:
         "residual and ~500-token stories, that's ~500 vectors per story. "
         "Memory is fine: 112 × 5120 × 500 × 4 bytes ≈ 1.1 GB.",
     )
+    ap.add_argument(
+        "--subspace-eigen-k",
+        type=int,
+        default=5,
+        help="Number of top eigenvectors of M_pos - M_base to combine into "
+        "the concept direction. Weighted sum by eigenvalue (so strongest "
+        "common directions contribute most). eigen_k=1 recovers "
+        "single-eigenvector behavior. Higher values (5-10) capture "
+        "richer structure when the concept's shared-subspace spectrum "
+        "is flat (which it tends to be in practice).",
+    )
     ap.add_argument(
         "--quality-report",
         action="store_true",

@@ -1045,6 +1067,7 @@ def main() -> None:
         base_V += [bs[target_l] for bs in (base_subspaces or [])]
         top_vec, eigvals = _subspace_concept_direction(
             pos_V, base_V, hidden=hidden_dim,
+            top_k=args.subspace_eigen_k,
         )
         per_layer_vectors[l_idx, e_idx] = top_vec
         # Keep the top-20 eigenvalues for quality-report diagnostics.
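The new selection rule is easy to exercise in isolation. A standalone NumPy sketch (a hypothetical `combine_top_k` mirroring the patched torch code, run on a synthetic symmetric matrix) showing that top_k=1 reduces to the plain top eigenvector while top_k>1 blends a flat top pair, with negative eigenvalues clamped out of the blend:

```python
import numpy as np

def combine_top_k(M, top_k=5):
    # NumPy mirror of the patch: eigh returns ascending order; take the
    # last k eigenpairs, clamp negative weights to 0, and unit-norm the
    # eigenvalue-weighted sum of eigenvectors.
    eigvals, eigvecs = np.linalg.eigh(M)
    k = max(1, min(top_k, eigvecs.shape[1]))
    w = np.clip(eigvals[-k:], 0.0, None)
    combined = eigvecs[:, -k:] @ w
    return combined / max(np.linalg.norm(combined), 1e-6), eigvals

# Synthetic M with a flat top pair (0.9, 0.9) and a negative tail,
# standing in for the flat spectra seen in practice.
rng = np.random.default_rng(1)
Q = np.linalg.qr(rng.standard_normal((6, 6)))[0]
spectrum = np.array([0.9, 0.9, 0.2, 0.0, -0.5, -0.8])
M = Q @ np.diag(spectrum) @ Q.T

v1, ev = combine_top_k(M, top_k=1)  # plain top eigenvector
v5, _ = combine_top_k(M, top_k=5)   # blends the flat 0.9/0.9 pair
# v1 satisfies v1^T M v1 = 0.9; v5 is unit-norm and carries no
# component along the negative-eigenvalue directions.
```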