diff --git a/training/amygdala_training/train_steering_vectors.py b/training/amygdala_training/train_steering_vectors.py index 5253186..33244c8 100644 --- a/training/amygdala_training/train_steering_vectors.py +++ b/training/amygdala_training/train_steering_vectors.py @@ -464,13 +464,14 @@ def _compute_quality_report( report: dict = {} n_layers = per_layer_vectors.shape[0] - # Pre-compute per-layer W_down for single-neuron alignment. + # Pre-compute per-layer W_down for single-neuron alignment. Keep on + # CPU to match the per_layer_vectors tensor. w_down: dict[int, torch.Tensor] = {} for target_l in target_layers: w = _find_mlp_down_proj(model, target_l) if w is not None: # Unit-normalize each column (one per MLP neuron). - w = w.to(torch.float32) + w = w.to(torch.float32).cpu() norms = w.norm(dim=0, keepdim=True).clamp_min(1e-6) w_down[target_l] = w / norms # [hidden, mlp_inner]