amygdala lib: cast activations to fp32 before aggregator (bf16 svd unsupported)
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
7f6d94417e
commit
22704a9dd8
1 changed files with 20 additions and 5 deletions
|
|
@ -84,16 +84,31 @@ def _samples_for_concept(
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
def _fp32_wrap(inner):
|
||||||
|
"""Wrap an aggregator so activations are cast to fp32 first.
|
||||||
|
|
||||||
|
torch.svd / torch.linalg.svd don't support bf16 on either CUDA or CPU,
|
||||||
|
and Qwen3.5 runs in bf16. Cast before the aggregator sees the tensors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapped(pos_acts: torch.Tensor, neg_acts: torch.Tensor) -> torch.Tensor:
|
||||||
|
return inner(pos_acts.to(torch.float32), neg_acts.to(torch.float32))
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
def _aggregator_from_name(name: str):
    """Build the aggregator for *name*, wrapped to receive fp32 inputs.

    Every aggregator is wrapped with ``_fp32_wrap`` because the model's
    bf16 activations are not supported by the SVD routines some
    aggregators rely on.

    Args:
        name: one of "mean", "pca", "logistic", "logistic_l1".

    Returns:
        A callable aggregator taking (pos_acts, neg_acts) tensors.

    Raises:
        ValueError: if *name* is not a known aggregator.
    """
    # Zero-arg factories so construction only happens for the chosen name.
    factories = {
        "mean": mean_aggregator,
        "pca": pca_aggregator,
        "logistic": logistic_aggregator,
        "logistic_l1": lambda: logistic_aggregator(
            sklearn_kwargs={"penalty": "l1", "solver": "liblinear", "C": 0.1}
        ),
    }
    if name not in factories:
        raise ValueError(f"unknown aggregator: {name}")
    return _fp32_wrap(factories[name]())
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue