diff --git a/training/amygdala_training/train_with_library.py b/training/amygdala_training/train_with_library.py index 52506d0..224eb3d 100644 --- a/training/amygdala_training/train_with_library.py +++ b/training/amygdala_training/train_with_library.py @@ -185,6 +185,7 @@ def main() -> None: aggregator=aggregator, batch_size=args.batch_size, show_progress=False, + move_to_cpu=True, ) # sv.layer_activations is a dict {layer_idx: tensor[hidden]} for l_idx, layer in enumerate(target_layers):