amygdala: subspace-common-direction alternative to pooled CAA

New --method subspace flag. For each story, run forward pass, do SVD
on the per-token activation matrix at each target layer, and keep the
top-k right singular vectors V_i ∈ ℝ^{hidden×k}. The columns of V_i
span the subspace the story's tokens occupy in activation space — it
contains concept, narrator, topic, and style as separate directions.

For each concept:
 M_pos  = (1/n_pos)  Σ_{i in pos}   V_i V_i^T   [hidden, hidden]
 M_base = (1/n_base) Σ_{i in base}  V_i V_i^T

Top eigenvector of M_pos - M_base = direction most common across
positive stories, minus what's common across the contrast set.

Why this is richer than pooled-mean CAA: pooled reduces each story
to a single point (the last-token activation) and loses the full
trajectory. Nuisance directions (narrator, setting) cancel in the
mean only to the extent they differ at the last token; across the
full trajectory they cancel much better via subspace intersection.
The concept direction, by contrast, is present across all tokens of
every concept-bearing story.

Memory cost: per-story we keep V_i of size [5120, k=20] — about
400KB per story × 112 stories = ~45MB. M matrices are [5120, 5120]
built transiently per concept.

--method pooled (default) keeps the existing behavior; --method
subspace uses the new algorithm. Quality report works with either.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-18 21:24:11 -04:00
parent 71f6053851
commit fe0fb8253a

View file

@ -166,6 +166,159 @@ def _collect_activations(
return torch.cat(out_rows, dim=0)
def _collect_per_story_subspaces(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    k: int = 20,
    label: str = "",
) -> list[dict[int, torch.Tensor]]:
    """Run texts through the model, capture the full per-token residual-stream
    activations at each target layer, do SVD per story, and return the top-k
    right singular vectors.

    Args:
        model: Causal LM whose decoder layers are hooked via
            ``_find_layers_module``.
        tokenizer: Matching tokenizer.  NOTE(review): the non-pad slice
            ``[:n_tok]`` below assumes right-padding — confirm the
            tokenizer's ``padding_side``.
        texts: Non-empty story strings; one subspace is computed per story.
        target_layers: Indices into the model's decoder-layer list.
        device: Device tokenized batches are moved to.
        batch_size: Stories per forward pass.
        max_length: Truncation length for tokenization.
        k: Number of right singular vectors kept per story per layer.
        label: Tag shown in progress lines.

    Returns:
        list (length ``len(texts)``) of dicts; each dict maps a
        target-layer index to a CPU float32 tensor ``[hidden_dim, top]``
        of unit-normed right singular vectors, where
        ``top = min(k, n_singular_values)`` (rank is bounded by the
        story's token count).

    The per-story subspace captures *all* the directions a story occupies —
    concept, narrator, topic, style. Finding the direction common to stories
    of the same concept (via the sum of V_i V_i^T and its top eigenvector)
    cancels nuisance directions that differ across stories while preserving
    directions they share.
    """
    import time

    assert all(isinstance(t, str) and t for t in texts), (
        f"_collect_per_story_subspaces: empty or non-string text in {label!r}"
    )

    # Hook storage: layer index -> [batch, seq, hidden] hidden states.
    captures: dict[int, torch.Tensor] = {}

    def make_hook(idx: int):
        def hook(_mod, _inp, output):
            # Decoder layers may return a tuple; hidden states come first.
            hs = output[0] if isinstance(output, tuple) else output
            captures[idx] = hs.detach()
        return hook

    layers_module = _find_layers_module(model)
    handles = [
        layers_module[idx].register_forward_hook(make_hook(idx))
        for idx in target_layers
    ]

    # One entry per text: {layer_idx: V[hidden, top]}
    out: list[dict[int, torch.Tensor]] = [{} for _ in range(len(texts))]
    n_batches = (len(texts) + batch_size - 1) // batch_size
    start = time.time()
    try:
        model.eval()
        with torch.no_grad():
            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)
                # For each item in the batch, for each layer, SVD on the
                # non-pad tokens.
                attn = tok["attention_mask"]
                for t_idx_in_batch, n_tok in enumerate(attn.sum(dim=1).tolist()):
                    story_idx = i + t_idx_in_batch
                    for layer in target_layers:
                        hs = captures[layer][t_idx_in_batch, :n_tok, :]
                        # Center tokens so SVD captures variation within the
                        # story, not the story's center-of-mass.  Convert to
                        # float32 once (was converted twice) — half-precision
                        # SVD is numerically fragile.
                        hs32 = hs.to(torch.float32)
                        hs32 = hs32 - hs32.mean(dim=0)
                        # SVD: hs = U Σ V^T; rows of V^T live in hidden space.
                        # For n_tok < k, the subspace rank is bounded by n_tok.
                        try:
                            _u, _s, vh = torch.linalg.svd(
                                hs32, full_matrices=False
                            )
                        except Exception:
                            # Degenerate case (all-zero hs, n_tok=1): fall back
                            # to the last-token vector itself, unit-normed.
                            vec = captures[layer][t_idx_in_batch, n_tok - 1, :]
                            vec = vec.to(torch.float32)
                            nrm = vec.norm().clamp_min(1e-6)
                            vh = (vec / nrm).unsqueeze(0)  # [1, hidden]
                        # Take top-k rows of V^T (= top-k right singular vecs).
                        top = min(k, vh.shape[0])
                        out[story_idx][layer] = vh[:top].t().contiguous().cpu()
                # Release the big per-batch tensors before the next forward
                # pass.  Use clear() rather than `del captures`: the hooks
                # close over this variable, and the original `del` + later
                # rebind would NameError on the next batch if the rebind
                # line were ever skipped.
                del tok
                captures.clear()
                if b_idx % 10 == 0:
                    torch.cuda.empty_cache()
                if b_idx % 5 == 0 or b_idx == n_batches - 1:
                    elapsed = time.time() - start
                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {b_idx + 1}/{n_batches} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
    finally:
        for h in handles:
            h.remove()
    return out
def _subspace_concept_direction(
pos_V: list[torch.Tensor], # list of [hidden, k_i] per story
base_V: list[torch.Tensor],
hidden: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Subspace-common-direction CAA alternative.
Builds M_pos = (1/n_pos) Σ V_i V_i^T over positive stories and M_base the
same over baselines. Returns the top eigenvector of (M_pos - M_base)
the direction most-common to positives after subtracting what's generic
across baselines plus its eigenvalue spectrum (for diagnostics).
The top eigenvalue approaches 1 if the concept appears in every positive
story's subspace with unit weight and is absent from the baseline.
"""
device = pos_V[0].device if pos_V else torch.device("cpu")
dtype = torch.float32
def acc(Vs: list[torch.Tensor]) -> torch.Tensor:
if not Vs:
return torch.zeros(hidden, hidden, dtype=dtype, device=device)
M = torch.zeros(hidden, hidden, dtype=dtype, device=device)
for V in Vs:
V = V.to(dtype=dtype, device=device)
M.addmm_(V, V.t())
M /= len(Vs)
return M
M_pos = acc(pos_V)
M_base = acc(base_V)
M = M_pos - M_base
# Symmetric eigendecomposition — top eigenvalue/vector.
eigvals, eigvecs = torch.linalg.eigh(M)
# eigh returns ascending; top is the last column.
top_vec = eigvecs[:, -1]
# Unit-norm (eigvecs are unit already, but defensively).
top_vec = top_vec / top_vec.norm().clamp_min(1e-6)
return top_vec, eigvals
def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
dict[str, list[str]], # emotion -> positive texts (unpaired + within-scenario framings)
list[str], # all baseline texts (one per scenario), as scenario-agnostic negatives
@ -684,6 +837,22 @@ def main() -> None:
default=1,
help="Skip emotions with fewer positive examples than this",
)
ap.add_argument(
"--method",
default="pooled",
choices=["pooled", "subspace"],
help="Concept-extraction method: 'pooled' (classic CAA, "
"pos_mean - neg_mean on last-token activations) or 'subspace' "
"(per-story SVD; top eigenvector of Σ V_i V_i^T for positives "
"minus same for baselines — captures what's common across "
"stories' full-trajectory subspaces)",
)
ap.add_argument(
"--subspace-k",
type=int,
default=20,
help="Top-k right singular vectors per story for subspace method",
)
ap.add_argument(
"--quality-report",
action="store_true",
@ -828,6 +997,27 @@ def main() -> None:
(n_layers, n_concepts, hidden_dim), dtype=torch.float32
)
# --- Subspace method: collect per-story right-singular-vector subspaces
# and use sum-of-projection-operators per concept. --------------------
pos_subspaces: list[dict[int, torch.Tensor]] | None = None
base_subspaces: list[dict[int, torch.Tensor]] | None = None
if args.method == "subspace":
print("\nCollecting per-story subspaces (SVD, top-k right singular "
f"vectors, k={args.subspace_k})...")
pos_subspaces = _collect_per_story_subspaces(
model, tokenizer, unique_positive_texts, target_layers, device,
args.batch_size, args.max_length, k=args.subspace_k,
label="subsp-pos",
)
if baselines:
base_subspaces = _collect_per_story_subspaces(
model, tokenizer, baselines, target_layers, device,
args.batch_size, args.max_length, k=args.subspace_k,
label="subsp-base",
)
else:
base_subspaces = []
for e_idx, emotion in enumerate(emotions):
pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
# Negatives: every OTHER emotion's positives + baselines.
@ -837,25 +1027,39 @@ def main() -> None:
if text_to_emotion[t] != emotion
]
pos = positive_acts[pos_rows] # [n_pos, n_layers, hidden]
neg = positive_acts[neg_rows] # [n_neg, n_layers, hidden]
if baseline_acts.shape[0] > 0:
neg = torch.cat([neg, baseline_acts], dim=0)
if args.method == "subspace":
# For each layer, build M_pos = Σ V V^T / n_pos, baseline same
# (using all other concepts' positive subspaces + baseline
# subspaces as the contrast set), top eigenvector of difference.
for l_idx, target_l in enumerate(target_layers):
pos_V = [pos_subspaces[j][target_l] for j in pos_rows]
base_V = [pos_subspaces[j][target_l] for j in neg_rows]
base_V += [bs[target_l] for bs in (base_subspaces or [])]
top_vec, _eigvals = _subspace_concept_direction(
pos_V, base_V, hidden=hidden_dim,
)
per_layer_vectors[l_idx, e_idx] = top_vec
else:
pos = positive_acts[pos_rows] # [n_pos, n_layers, hidden]
neg = positive_acts[neg_rows] # [n_neg, n_layers, hidden]
if baseline_acts.shape[0] > 0:
neg = torch.cat([neg, baseline_acts], dim=0)
pos_mean = pos.mean(dim=0) # [n_layers, hidden]
neg_mean = neg.mean(dim=0)
diff = pos_mean - neg_mean
norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
diff = diff / norms
pos_mean = pos.mean(dim=0) # [n_layers, hidden]
neg_mean = neg.mean(dim=0)
diff = pos_mean - neg_mean
norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
diff = diff / norms
# diff[layer] -> per_layer_vectors[layer, e_idx]
for l_idx in range(n_layers):
per_layer_vectors[l_idx, e_idx] = diff[l_idx]
# diff[layer] -> per_layer_vectors[layer, e_idx]
for l_idx in range(n_layers):
per_layer_vectors[l_idx, e_idx] = diff[l_idx]
if e_idx < 5 or e_idx == len(emotions) - 1:
print(
f" [{e_idx + 1}/{len(emotions)}] {emotion}: "
f"pos={len(pos_rows)} neg={len(neg_rows) + baseline_acts.shape[0]}"
f" (method={args.method})"
)
output_dir = Path(args.output_dir)