diff --git a/training/amygdala_training/train_with_library.py b/training/amygdala_training/train_with_library.py
index 224eb3d..23633eb 100644
--- a/training/amygdala_training/train_with_library.py
+++ b/training/amygdala_training/train_with_library.py
@@ -84,16 +84,31 @@ def _samples_for_concept(
     return samples
 
 
+def _fp32_wrap(inner):
+    """Wrap an aggregator so activations are cast to fp32 first.
+
+    torch.svd / torch.linalg.svd don't support bf16 on either CUDA or CPU,
+    and Qwen3.5 runs in bf16. Cast before the aggregator sees the tensors.
+    """
+
+    def wrapped(pos_acts: torch.Tensor, neg_acts: torch.Tensor) -> torch.Tensor:
+        return inner(pos_acts.to(torch.float32), neg_acts.to(torch.float32))
+
+    return wrapped
+
+
 def _aggregator_from_name(name: str):
     if name == "mean":
-        return mean_aggregator()
+        return _fp32_wrap(mean_aggregator())
     if name == "pca":
-        return pca_aggregator()
+        return _fp32_wrap(pca_aggregator())
     if name == "logistic":
-        return logistic_aggregator()
+        return _fp32_wrap(logistic_aggregator())
     if name == "logistic_l1":
-        return logistic_aggregator(
-            sklearn_kwargs={"penalty": "l1", "solver": "liblinear", "C": 0.1}
+        return _fp32_wrap(
+            logistic_aggregator(
+                sklearn_kwargs={"penalty": "l1", "solver": "liblinear", "C": 0.1}
+            )
         )
     raise ValueError(f"unknown aggregator: {name}")
 