amygdala: run subspace eigh on GPU, not CPU
Previous run was grinding on CPU for 36+ minutes because the per-story V_i tensors were stored on CPU by the collector, and _subspace_concept_direction inherited that device. The per-concept eigh on 5120x5120 is glacial on CPU and fast on GPU (~1s). Add explicit device parameter; pass training device. Transfer result back to CPU for storage. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
1443d08dc7
commit
f9b3f00691
1 changed files with 6 additions and 1 deletions
|
|
@ -284,6 +284,7 @@ def _subspace_concept_direction(
|
||||||
hidden: int,
|
hidden: int,
|
||||||
*,
|
*,
|
||||||
top_k: int = 5,
|
top_k: int = 5,
|
||||||
|
device: torch.device | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Subspace-common-direction CAA alternative.
|
"""Subspace-common-direction CAA alternative.
|
||||||
|
|
||||||
|
|
@ -299,7 +300,8 @@ def _subspace_concept_direction(
|
||||||
practice suggests is the common case. Selection happens AFTER the
|
practice suggests is the common case. Selection happens AFTER the
|
||||||
eigendecomposition so nothing is lost up to that point.
|
eigendecomposition so nothing is lost up to that point.
|
||||||
"""
|
"""
|
||||||
device = pos_V[0].device if pos_V else torch.device("cpu")
|
if device is None:
|
||||||
|
device = pos_V[0].device if pos_V else torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
|
||||||
def acc(Vs: list[torch.Tensor]) -> torch.Tensor:
|
def acc(Vs: list[torch.Tensor]) -> torch.Tensor:
|
||||||
|
|
@ -1068,7 +1070,10 @@ def main() -> None:
|
||||||
top_vec, eigvals = _subspace_concept_direction(
|
top_vec, eigvals = _subspace_concept_direction(
|
||||||
pos_V, base_V, hidden=hidden_dim,
|
pos_V, base_V, hidden=hidden_dim,
|
||||||
top_k=args.subspace_eigen_k,
|
top_k=args.subspace_eigen_k,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
|
top_vec = top_vec.cpu()
|
||||||
|
eigvals = eigvals.cpu()
|
||||||
per_layer_vectors[l_idx, e_idx] = top_vec
|
per_layer_vectors[l_idx, e_idx] = top_vec
|
||||||
# Keep the top-20 eigenvalues for quality-report diagnostics.
|
# Keep the top-20 eigenvalues for quality-report diagnostics.
|
||||||
subspace_eigvals.setdefault(emotion, {})[target_l] = (
|
subspace_eigvals.setdefault(emotion, {})[target_l] = (
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue