diff --git a/training/amygdala_training/train_steering_vectors.py b/training/amygdala_training/train_steering_vectors.py index 6e49e2a..3de0877 100644 --- a/training/amygdala_training/train_steering_vectors.py +++ b/training/amygdala_training/train_steering_vectors.py @@ -284,6 +284,7 @@ def _subspace_concept_direction( hidden: int, *, top_k: int = 5, + device: torch.device | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Subspace-common-direction CAA alternative. @@ -299,7 +300,8 @@ def _subspace_concept_direction( practice suggests is the common case. Selection happens AFTER the eigendecomposition so nothing is lost up to that point. """ - device = pos_V[0].device if pos_V else torch.device("cpu") + if device is None: + device = pos_V[0].device if pos_V else torch.device("cpu") dtype = torch.float32 def acc(Vs: list[torch.Tensor]) -> torch.Tensor: @@ -1068,7 +1070,10 @@ def main() -> None: top_vec, eigvals = _subspace_concept_direction( pos_V, base_V, hidden=hidden_dim, top_k=args.subspace_eigen_k, + device=device, ) + top_vec = top_vec.cpu() + eigvals = eigvals.cpu() per_layer_vectors[l_idx, e_idx] = top_vec # Keep the top-20 eigenvalues for quality-report diagnostics. subspace_eigvals.setdefault(emotion, {})[target_l] = (