forked from kent/consciousness
amygdala lib: cast activations to fp32 before aggregator (bf16 svd unsupported)
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
7f6d94417e
commit
22704a9dd8
1 changed files with 20 additions and 5 deletions
|
|
@ -84,16 +84,31 @@ def _samples_for_concept(
|
|||
return samples
|
||||
|
||||
|
||||
def _fp32_wrap(inner):
|
||||
"""Wrap an aggregator so activations are cast to fp32 first.
|
||||
|
||||
torch.svd / torch.linalg.svd don't support bf16 on either CUDA or CPU,
|
||||
and Qwen3.5 runs in bf16. Cast before the aggregator sees the tensors.
|
||||
"""
|
||||
|
||||
def wrapped(pos_acts: torch.Tensor, neg_acts: torch.Tensor) -> torch.Tensor:
|
||||
return inner(pos_acts.to(torch.float32), neg_acts.to(torch.float32))
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _aggregator_from_name(name: str):
|
||||
if name == "mean":
|
||||
return mean_aggregator()
|
||||
return _fp32_wrap(mean_aggregator())
|
||||
if name == "pca":
|
||||
return pca_aggregator()
|
||||
return _fp32_wrap(pca_aggregator())
|
||||
if name == "logistic":
|
||||
return logistic_aggregator()
|
||||
return _fp32_wrap(logistic_aggregator())
|
||||
if name == "logistic_l1":
|
||||
return logistic_aggregator(
|
||||
sklearn_kwargs={"penalty": "l1", "solver": "liblinear", "C": 0.1}
|
||||
return _fp32_wrap(
|
||||
logistic_aggregator(
|
||||
sklearn_kwargs={"penalty": "l1", "solver": "liblinear", "C": 0.1}
|
||||
)
|
||||
)
|
||||
raise ValueError(f"unknown aggregator: {name}")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue