amygdala: select top-k eigenvectors AFTER PCA, not per-story truncation
Kent: "full rank is going to give you everything — you still have to select down, but you can do that /after/ PCA."

Previously we discarded information per story via k=20 truncation of the SVD. That destroyed per-head discriminability before we ever saw the eigenvalue spectrum. The alternative "keep full rank" run accumulated too many shared directions, making the top-1 eigenvector arbitrary within a flat spectrum.

Correct approach: keep per-story subspaces at full rank (no information loss) and select k eigenvectors of M = M_pos - M_base at the final step, combined as a weighted sum by eigenvalue. This captures the multi-dimensional shared subspace when the spectrum is flat (the common case) and reduces to the top-1 behavior when the spectrum has a clear gap.

New --subspace-eigen-k flag (default 5). Negative weights are clamped to 0 so wrong-sign directions don't contribute.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
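A minimal numpy sketch of the selection step described above (the real implementation uses torch; the function name and the synthetic flat-spectrum M here are illustrative only). It shows why the weighted top-k sum is well-defined when the spectrum is flat: the combined direction stays inside the shared subspace even though any single top eigenvector is arbitrary.

```python
import numpy as np

def subspace_concept_direction(M, top_k=5):
    """Weighted sum of the top-k eigenvectors of symmetric M,
    weights = eigenvalues clamped at 0, unit-normed."""
    eigvals, eigvecs = np.linalg.eigh(M)  # ascending order
    k = max(1, min(top_k, eigvecs.shape[1]))
    w = np.clip(eigvals[-k:], 0.0, None)  # wrong-sign directions -> weight 0
    combined = eigvecs[:, -k:] @ w
    return combined / max(np.linalg.norm(combined), 1e-6), eigvals

# Flat spectrum: a rank-3 shared subspace with equal eigenvalues 1, 1, 1.
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((16, 3)))  # orthonormal basis
M = Q @ Q.T
vec, eigvals = subspace_concept_direction(M, top_k=5)
# The combined unit direction lies inside the shared subspace:
print(np.linalg.norm(vec - Q @ (Q.T @ vec)) < 1e-8)  # → True
```

With top_k=1 the returned vector would be whichever of the three degenerate eigenvectors eigh happens to emit; the weighted sum removes that arbitrariness.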
parent 2411925700
commit 1443d08dc7
1 changed file with 34 additions and 11 deletions
@@ -282,16 +282,22 @@ def _subspace_concept_direction(
     pos_V: list[torch.Tensor], # list of [hidden, k_i] per story
     base_V: list[torch.Tensor],
     hidden: int,
+    *,
+    top_k: int = 5,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Subspace-common-direction CAA alternative.
 
     Builds M_pos = (1/n_pos) Σ V_i V_i^T over positive stories and M_base the
-    same over baselines. Returns the top eigenvector of (M_pos - M_base) —
-    the direction most-common to positives after subtracting what's generic
-    across baselines — plus its eigenvalue spectrum (for diagnostics).
+    same over baselines. Returns a weighted sum of the top-k eigenvectors of
+    (M_pos - M_base), weights = eigenvalues (so stronger common directions
+    contribute more), unit-normed. Returns the full eigenvalue spectrum for
+    diagnostics.
 
-    The top eigenvalue approaches 1 if the concept appears in every positive
-    story's subspace with unit weight and is absent from the baseline.
+    top_k=1 recovers the previous behavior (top eigenvector only). top_k>1
+    captures richer structure when the concept lives in a multi-dimensional
+    shared subspace — which the flat eigenvalue spectrum observed in
+    practice suggests is the common case. Selection happens AFTER the
+    eigendecomposition so nothing is lost up to that point.
     """
     device = pos_V[0].device if pos_V else torch.device("cpu")
     dtype = torch.float32
@@ -310,13 +316,18 @@ def _subspace_concept_direction(
     M_base = acc(base_V)
     M = M_pos - M_base
 
-    # Symmetric eigendecomposition — top eigenvalue/vector.
+    # Symmetric eigendecomposition.
     eigvals, eigvecs = torch.linalg.eigh(M)
-    # eigh returns ascending; top is the last column.
-    top_vec = eigvecs[:, -1]
-    # Unit-norm (eigvecs are unit already, but defensively).
-    top_vec = top_vec / top_vec.norm().clamp_min(1e-6)
-    return top_vec, eigvals
+    # eigh returns ascending; top-k are the last k columns.
+    k = max(1, min(top_k, eigvecs.shape[1]))
+    top_vals = eigvals[-k:]  # [k], ascending within top-k
+    top_vecs = eigvecs[:, -k:]  # [hidden, k]
+    # Weighted sum of top-k eigenvectors, weights = eigenvalues. Clamp
+    # negative weights to 0 (wrong-sign directions shouldn't contribute).
+    w = top_vals.clamp_min(0.0)
+    combined = top_vecs @ w  # [hidden]
+    combined = combined / combined.norm().clamp_min(1e-6)
+    return combined, eigvals
 
 
 def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
@@ -858,6 +869,17 @@ def main() -> None:
         "residual and ~500-token stories, that's ~500 vectors per story. "
         "Memory is fine: 112 × 5120 × 500 × 4 bytes ≈ 1.1 GB.",
     )
+    ap.add_argument(
+        "--subspace-eigen-k",
+        type=int,
+        default=5,
+        help="Number of top eigenvectors of M_pos - M_base to combine into "
+        "the concept direction. Weighted sum by eigenvalue (so strongest "
+        "common directions contribute most). eigen_k=1 recovers "
+        "single-eigenvector behavior. Higher values (5-10) capture "
+        "richer structure when the concept's shared-subspace spectrum "
+        "is flat (which it tends to be in practice).",
+    )
     ap.add_argument(
         "--quality-report",
         action="store_true",
@@ -1045,6 +1067,7 @@ def main() -> None:
             base_V += [bs[target_l] for bs in (base_subspaces or [])]
             top_vec, eigvals = _subspace_concept_direction(
                 pos_V, base_V, hidden=hidden_dim,
+                top_k=args.subspace_eigen_k,
             )
             per_layer_vectors[l_idx, e_idx] = top_vec
             # Keep the top-20 eigenvalues for quality-report diagnostics.