amygdala: subspace-common-direction alternative to pooled CAA
New --method subspace flag. For each story, run forward pass, do SVD
on the per-token activation matrix at each target layer, and keep the
top-k right singular vectors V_i ∈ [hidden, k]. V_i is the subspace
the story's tokens span in activation space — it contains concept,
narrator, topic, style as separate directions.
For each concept:
M_pos = (1/n_pos) Σ_{i in pos} V_i V_i^T [hidden, hidden]
M_base = (1/n_base) Σ_{i in base} V_i V_i^T
Top eigenvector of M_pos - M_base = direction most common across
positive stories, minus what's common across the contrast set.
Why this is richer than pooled-mean CAA: pooled reduces each story
to a single point (the last-token activation) and loses the full
trajectory. Nuisance directions (narrator, setting) cancel in the
pooled mean only to the extent they vary across stories at the last
token; across the full trajectory they cancel much better via
subspace intersection.
The concept direction, by contrast, is present across all tokens of
every concept-bearing story.
Memory cost: per-story we keep V_i of size [5120, k=20] — about
400KB per story × 112 stories = ~45MB. M matrices are [5120, 5120]
built transiently per concept.
--method pooled (default) keeps the existing behavior; --method
subspace uses the new algorithm. Quality report works with either.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
71f6053851
commit
fe0fb8253a
1 changed files with 216 additions and 12 deletions
|
|
@ -166,6 +166,159 @@ def _collect_activations(
|
|||
return torch.cat(out_rows, dim=0)
|
||||
|
||||
|
||||
def _collect_per_story_subspaces(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    k: int = 20,
    label: str = "",
) -> list[dict[int, torch.Tensor]]:
    """Run texts through the model, capture the full per-token residual-stream
    activations at each target layer, do SVD per story, return the top-k right
    singular vectors.

    Returns: list (length n_texts) of dicts; each dict maps target_layer_idx to
    a tensor ``[hidden_dim, top]`` (``top = min(k, n_tokens)``) of unit-normed
    right singular vectors (the subspace the story's tokens span in activation
    space at that layer).

    The per-story subspace captures *all* the directions a story occupies —
    concept, narrator, topic, style. Finding the direction common to stories of
    the same concept (via the sum of V_i V_i^T and its top eigenvector)
    cancels nuisance directions that differ across stories while preserving
    directions they share.
    """
    import time

    assert all(isinstance(t, str) and t for t in texts), (
        f"_collect_per_story_subspaces: empty or non-string text in {label!r}"
    )

    # Written into by the forward hooks below; read after each forward pass.
    captures: dict[int, torch.Tensor] = {}

    def make_hook(idx: int):
        # Capture the residual-stream output of layer `idx`. Some layer
        # modules return a tuple (hidden_states, ...); take element 0.
        def hook(_mod, _inp, output):
            hs = output[0] if isinstance(output, tuple) else output
            captures[idx] = hs.detach()
        return hook

    layers_module = _find_layers_module(model)
    handles = [
        layers_module[idx].register_forward_hook(make_hook(idx))
        for idx in target_layers
    ]

    # One entry per text: {layer_idx: V[hidden, top]}
    out: list[dict[int, torch.Tensor]] = [
        {} for _ in range(len(texts))
    ]
    n_batches = (len(texts) + batch_size - 1) // batch_size
    start = time.time()
    try:
        model.eval()
        with torch.no_grad():
            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)

                # For each item in the batch, for each layer, SVD on the
                # non-pad tokens.
                # NOTE(review): the [:n_tok] slice assumes the tokenizer
                # right-pads batches — confirm tokenizer.padding_side ==
                # "right"; with left padding this would select pad tokens.
                attn = tok["attention_mask"]
                for t_idx_in_batch, n_tok in enumerate(attn.sum(dim=1).tolist()):
                    story_idx = i + t_idx_in_batch
                    for layer in target_layers:
                        hs = captures[layer][t_idx_in_batch, :n_tok, :]
                        # Center tokens so SVD captures variation within story,
                        # not the story's center-of-mass. Convert to fp32 once
                        # (SVD on half precision is unsupported/unstable).
                        hs32 = hs.to(torch.float32)
                        hs32 = hs32 - hs32.mean(dim=0)
                        # SVD: hs = U Σ V^T; V has hidden-dim columns.
                        # For n_tok < k, the subspace rank is bounded by n_tok.
                        try:
                            _u, _s, vh = torch.linalg.svd(hs32, full_matrices=False)
                        except RuntimeError:
                            # torch.linalg.LinAlgError subclasses RuntimeError;
                            # degenerate case (all-zero hs, n_tok=1): fall back
                            # to the last-token vector itself, unit-normed.
                            vec = captures[layer][t_idx_in_batch, n_tok - 1, :]
                            vec = vec.to(torch.float32)
                            nrm = vec.norm().clamp_min(1e-6)
                            vh = (vec / nrm).unsqueeze(0)  # [1, hidden]
                        # Take top-k rows of V^T (= top-k right singular vecs).
                        top = min(k, vh.shape[0])
                        V = vh[:top].t().contiguous().cpu()  # [hidden, top]
                        out[story_idx][layer] = V

                # Drop references to this batch's GPU tensors before the next
                # forward pass. clear() (rather than `del captures` plus a
                # later rebind) keeps the name bound for the hook closures.
                del tok
                captures.clear()
                if b_idx % 10 == 0:
                    torch.cuda.empty_cache()
                if b_idx % 5 == 0 or b_idx == n_batches - 1:
                    elapsed = time.time() - start
                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {b_idx + 1}/{n_batches} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
    finally:
        # Always detach the hooks, even if the forward pass raised.
        for h in handles:
            h.remove()

    return out
|
||||
|
||||
|
||||
def _subspace_concept_direction(
|
||||
pos_V: list[torch.Tensor], # list of [hidden, k_i] per story
|
||||
base_V: list[torch.Tensor],
|
||||
hidden: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Subspace-common-direction CAA alternative.
|
||||
|
||||
Builds M_pos = (1/n_pos) Σ V_i V_i^T over positive stories and M_base the
|
||||
same over baselines. Returns the top eigenvector of (M_pos - M_base) —
|
||||
the direction most-common to positives after subtracting what's generic
|
||||
across baselines — plus its eigenvalue spectrum (for diagnostics).
|
||||
|
||||
The top eigenvalue approaches 1 if the concept appears in every positive
|
||||
story's subspace with unit weight and is absent from the baseline.
|
||||
"""
|
||||
device = pos_V[0].device if pos_V else torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
def acc(Vs: list[torch.Tensor]) -> torch.Tensor:
|
||||
if not Vs:
|
||||
return torch.zeros(hidden, hidden, dtype=dtype, device=device)
|
||||
M = torch.zeros(hidden, hidden, dtype=dtype, device=device)
|
||||
for V in Vs:
|
||||
V = V.to(dtype=dtype, device=device)
|
||||
M.addmm_(V, V.t())
|
||||
M /= len(Vs)
|
||||
return M
|
||||
|
||||
M_pos = acc(pos_V)
|
||||
M_base = acc(base_V)
|
||||
M = M_pos - M_base
|
||||
|
||||
# Symmetric eigendecomposition — top eigenvalue/vector.
|
||||
eigvals, eigvecs = torch.linalg.eigh(M)
|
||||
# eigh returns ascending; top is the last column.
|
||||
top_vec = eigvecs[:, -1]
|
||||
# Unit-norm (eigvecs are unit already, but defensively).
|
||||
top_vec = top_vec / top_vec.norm().clamp_min(1e-6)
|
||||
return top_vec, eigvals
|
||||
|
||||
|
||||
def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
|
||||
dict[str, list[str]], # emotion -> positive texts (unpaired + within-scenario framings)
|
||||
list[str], # all baseline texts (one per scenario), as scenario-agnostic negatives
|
||||
|
|
@ -684,6 +837,22 @@ def main() -> None:
|
|||
default=1,
|
||||
help="Skip emotions with fewer positive examples than this",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--method",
|
||||
default="pooled",
|
||||
choices=["pooled", "subspace"],
|
||||
help="Concept-extraction method: 'pooled' (classic CAA, "
|
||||
"pos_mean - neg_mean on last-token activations) or 'subspace' "
|
||||
"(per-story SVD; top eigenvector of Σ V_i V_i^T for positives "
|
||||
"minus same for baselines — captures what's common across "
|
||||
"stories' full-trajectory subspaces)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--subspace-k",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Top-k right singular vectors per story for subspace method",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--quality-report",
|
||||
action="store_true",
|
||||
|
|
@ -828,6 +997,27 @@ def main() -> None:
|
|||
(n_layers, n_concepts, hidden_dim), dtype=torch.float32
|
||||
)
|
||||
|
||||
# --- Subspace method: collect per-story right-singular-vector subspaces
|
||||
# and use sum-of-projection-operators per concept. --------------------
|
||||
pos_subspaces: list[dict[int, torch.Tensor]] | None = None
|
||||
base_subspaces: list[dict[int, torch.Tensor]] | None = None
|
||||
if args.method == "subspace":
|
||||
print("\nCollecting per-story subspaces (SVD, top-k right singular "
|
||||
f"vectors, k={args.subspace_k})...")
|
||||
pos_subspaces = _collect_per_story_subspaces(
|
||||
model, tokenizer, unique_positive_texts, target_layers, device,
|
||||
args.batch_size, args.max_length, k=args.subspace_k,
|
||||
label="subsp-pos",
|
||||
)
|
||||
if baselines:
|
||||
base_subspaces = _collect_per_story_subspaces(
|
||||
model, tokenizer, baselines, target_layers, device,
|
||||
args.batch_size, args.max_length, k=args.subspace_k,
|
||||
label="subsp-base",
|
||||
)
|
||||
else:
|
||||
base_subspaces = []
|
||||
|
||||
for e_idx, emotion in enumerate(emotions):
|
||||
pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
|
||||
# Negatives: every OTHER emotion's positives + baselines.
|
||||
|
|
@ -837,6 +1027,19 @@ def main() -> None:
|
|||
if text_to_emotion[t] != emotion
|
||||
]
|
||||
|
||||
if args.method == "subspace":
|
||||
# For each layer, build M_pos = Σ V V^T / n_pos, baseline same
|
||||
# (using all other concepts' positive subspaces + baseline
|
||||
# subspaces as the contrast set), top eigenvector of difference.
|
||||
for l_idx, target_l in enumerate(target_layers):
|
||||
pos_V = [pos_subspaces[j][target_l] for j in pos_rows]
|
||||
base_V = [pos_subspaces[j][target_l] for j in neg_rows]
|
||||
base_V += [bs[target_l] for bs in (base_subspaces or [])]
|
||||
top_vec, _eigvals = _subspace_concept_direction(
|
||||
pos_V, base_V, hidden=hidden_dim,
|
||||
)
|
||||
per_layer_vectors[l_idx, e_idx] = top_vec
|
||||
else:
|
||||
pos = positive_acts[pos_rows] # [n_pos, n_layers, hidden]
|
||||
neg = positive_acts[neg_rows] # [n_neg, n_layers, hidden]
|
||||
if baseline_acts.shape[0] > 0:
|
||||
|
|
@ -856,6 +1059,7 @@ def main() -> None:
|
|||
print(
|
||||
f" [{e_idx + 1}/{len(emotions)}] {emotion}: "
|
||||
f"pos={len(pos_rows)} neg={len(neg_rows) + baseline_acts.shape[0]}"
|
||||
f" (method={args.method})"
|
||||
)
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue