amygdala: run subspace eigh on GPU, not CPU

Previous run was grinding on CPU for 36+ minutes because the per-story
V_i tensors were stored on CPU by the collector, and
_subspace_concept_direction inherited that device. The per-concept
eigh on 5120x5120 is glacial on CPU and fast on GPU (~1s).

Add an explicit device parameter to _subspace_concept_direction and pass
the training device from main(). Transfer the results back to CPU for
storage.
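
The pattern described above, reduced to a minimal sketch. This is not the code from this commit: `top_eigvec` and `sym` are hypothetical names, the matrix is a small random stand-in for the real 5120x5120 one, and the device falls back to CPU when CUDA is unavailable.

```python
from __future__ import annotations

import torch


def top_eigvec(
    sym: torch.Tensor,
    device: torch.device | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: run eigh on `device`, return results on CPU."""
    if device is None:
        # Old behaviour: inherit whatever device the input lives on.
        device = sym.device
    # Move the symmetric matrix to the compute device before decomposing.
    m = sym.to(device=device, dtype=torch.float32)
    # torch.linalg.eigh returns eigenvalues in ascending order and the
    # matching eigenvectors as columns.
    eigvals, eigvecs = torch.linalg.eigh(m)
    top_vec = eigvecs[:, -1]
    # Bring results back to CPU so they can be stored next to CPU tensors.
    return top_vec.cpu(), eigvals.cpu()


torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
basis = torch.randn(64, 8)   # stand-in data, stored on CPU
sym = basis @ basis.T        # symmetric positive semi-definite, 64x64
top_vec, eigvals = top_eigvec(sym, device=device)
```

The explicit argument matters because the inputs may sit on CPU even when a GPU is available; without it, the decomposition silently runs wherever the inputs were stored.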

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
Kent Overstreet 2026-04-18 21:52:35 -04:00
parent 1443d08dc7
commit f9b3f00691

@@ -284,6 +284,7 @@ def _subspace_concept_direction(
     hidden: int,
     *,
     top_k: int = 5,
+    device: torch.device | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Subspace-common-direction CAA alternative.
@@ -299,6 +300,7 @@ def _subspace_concept_direction(
     practice suggests is the common case. Selection happens AFTER the
     eigendecomposition so nothing is lost up to that point.
     """
+    if device is None:
         device = pos_V[0].device if pos_V else torch.device("cpu")
     dtype = torch.float32
@@ -1068,7 +1070,10 @@ def main() -> None:
             top_vec, eigvals = _subspace_concept_direction(
                 pos_V, base_V, hidden=hidden_dim,
                 top_k=args.subspace_eigen_k,
+                device=device,
             )
+            top_vec = top_vec.cpu()
+            eigvals = eigvals.cpu()
             per_layer_vectors[l_idx, e_idx] = top_vec
             # Keep the top-20 eigenvalues for quality-report diagnostics.
             subspace_eigvals.setdefault(emotion, {})[target_l] = (