amygdala lib: cast activations to fp32 before aggregator (bf16 svd unsupported)

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-18 22:20:39 -04:00
parent 7f6d94417e
commit 22704a9dd8

View file

@ -84,17 +84,32 @@ def _samples_for_concept(
return samples return samples
def _fp32_wrap(inner):
"""Wrap an aggregator so activations are cast to fp32 first.
torch.svd / torch.linalg.svd don't support bf16 on either CUDA or CPU,
and Qwen3.5 runs in bf16. Cast before the aggregator sees the tensors.
"""
def wrapped(pos_acts: torch.Tensor, neg_acts: torch.Tensor) -> torch.Tensor:
return inner(pos_acts.to(torch.float32), neg_acts.to(torch.float32))
return wrapped
def _aggregator_from_name(name: str):
    """Build the aggregator registered under *name*, wrapped so it always
    receives fp32 activations (see `_fp32_wrap`).

    Raises:
        ValueError: if *name* is not a known aggregator.
    """
    # Dispatch table of zero-arg factories; only the selected one is called.
    factories = {
        "mean": mean_aggregator,
        "pca": pca_aggregator,
        "logistic": logistic_aggregator,
        "logistic_l1": lambda: logistic_aggregator(
            sklearn_kwargs={"penalty": "l1", "solver": "liblinear", "C": 0.1}
        ),
    }
    factory = factories.get(name)
    if factory is None:
        raise ValueError(f"unknown aggregator: {name}")
    return _fp32_wrap(factory())