amygdala: subspace-common-direction alternative to pooled CAA
New --method subspace flag. For each story, run forward pass, do SVD
on the per-token activation matrix at each target layer, and keep the
top-k right singular vectors V_i ∈ [hidden, k]. V_i is the subspace
the story's tokens span in activation space — it contains concept,
narrator, topic, style as separate directions.
For each concept:
M_pos = (1/n_pos) Σ_{i in pos} V_i V_i^T [hidden, hidden]
M_base = (1/n_base) Σ_{i in base} V_i V_i^T
Top eigenvector of M_pos - M_base = direction most common across
positive stories, minus what's common across the contrast set.
Why this is richer than pooled-mean CAA: pooled reduces each story
to a single point (the last-token activation) and loses the full
trajectory. Nuisance directions (narrator, setting) cancel in the
mean only to the extent they differ at the last token; across the
full trajectory they cancel much better via subspace intersection.
The concept direction, by contrast, is present across all tokens of
every concept-bearing story.
Memory cost: per-story we keep V_i of size [5120, k=20] — about
400KB per story × 112 stories = ~45MB. M matrices are [5120, 5120]
built transiently per concept.
--method pooled (default) keeps the existing behavior; --method
subspace uses the new algorithm. Quality report works with either.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
71f6053851
commit
fe0fb8253a
1 changed file with 216 additions and 12 deletions
|
|
@ -166,6 +166,159 @@ def _collect_activations(
|
||||||
return torch.cat(out_rows, dim=0)
|
return torch.cat(out_rows, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_per_story_subspaces(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    k: int = 20,
    label: str = "",
) -> list[dict[int, torch.Tensor]]:
    """Run texts through the model, capture the full per-token residual-stream
    activations at each target layer, do SVD per story, and return the top-k
    right singular vectors.

    Args:
        model: causal LM whose decoder layer stack is located via
            ``_find_layers_module``.
        tokenizer: matching tokenizer; must support padding and truncation.
        texts: non-empty story strings (asserted below).
        target_layers: indices into the model's layer stack to hook.
        device: device the tokenized batches are moved to.
        batch_size: forward-pass batch size.
        max_length: tokenizer truncation length.
        k: number of right singular vectors kept per story per layer.
        label: tag used in assertion and progress messages.

    Returns:
        List (length ``len(texts)``) of dicts; each dict maps target layer
        index to a ``[hidden_dim, top]`` float32 CPU tensor of unit-normed
        right singular vectors, where ``top = min(k, n_tokens)`` (the subspace
        the story's tokens span in activation space at that layer).

    The per-story subspace captures *all* the directions a story occupies —
    concept, narrator, topic, style. Finding the direction common to stories of
    the same concept (via the sum of V_i V_i^T and its top eigenvector)
    cancels nuisance directions that differ across stories while preserving
    directions they share.
    """
    import time

    assert all(isinstance(t, str) and t for t in texts), (
        f"_collect_per_story_subspaces: empty or non-string text in {label!r}"
    )

    # Forward hooks deposit the current batch's hidden states here, keyed by
    # layer index; cleared before every forward pass.
    captures: dict[int, torch.Tensor] = {}

    def make_hook(idx: int):
        def hook(_mod, _inp, output):
            # Decoder layers may return a tuple (hidden_states, ...).
            hs = output[0] if isinstance(output, tuple) else output
            captures[idx] = hs.detach()

        return hook

    layers_module = _find_layers_module(model)
    handles = [
        layers_module[idx].register_forward_hook(make_hook(idx))
        for idx in target_layers
    ]

    # One entry per text: {layer_idx: V[hidden, top]}
    out: list[dict[int, torch.Tensor]] = [{} for _ in range(len(texts))]
    n_batches = (len(texts) + batch_size - 1) // batch_size
    start = time.time()
    try:
        model.eval()
        with torch.no_grad():
            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)

                # For each item in the batch, for each layer, SVD on the
                # non-pad tokens.
                attn = tok["attention_mask"]
                for t_idx_in_batch, n_tok in enumerate(attn.sum(dim=1).tolist()):
                    story_idx = i + t_idx_in_batch
                    for layer in target_layers:
                        hs = captures[layer][t_idx_in_batch, :n_tok, :]
                        # Center tokens so SVD captures variation within the
                        # story, not the story's center-of-mass.  Convert to
                        # float32 once (the original converted twice).
                        hs32 = hs.to(torch.float32)
                        hs32 = hs32 - hs32.mean(dim=0)
                        # SVD: hs = U Σ V^T; V has hidden-dim columns.
                        # For n_tok < k, the subspace rank is bounded by n_tok.
                        try:
                            _u, _s, vh = torch.linalg.svd(
                                hs32, full_matrices=False
                            )
                        except Exception:
                            # Degenerate case (all-zero hs, n_tok=1): fall back
                            # to the last-token vector itself, unit-normed.
                            vec = captures[layer][t_idx_in_batch, n_tok - 1, :]
                            vec = vec.to(torch.float32)
                            nrm = vec.norm().clamp_min(1e-6)
                            vh = (vec / nrm).unsqueeze(0)  # [1, hidden]
                        # Take top-k rows of V^T (= top-k right singular vecs).
                        top = min(k, vh.shape[0])
                        V = vh[:top].t().contiguous().cpu()  # [hidden, top]
                        out[story_idx][layer] = V

                # BUG FIX: the original did ``del tok, captures`` here, which
                # unbinds the function-local ``captures``; on the next batch
                # the hooks (and ``captures.clear()``) raise UnboundLocalError
                # whenever len(texts) > batch_size.  Release the tensor refs
                # by clearing the dict in place and delete only ``tok``.
                captures.clear()
                del tok
                if b_idx % 10 == 0:
                    torch.cuda.empty_cache()
                if b_idx % 5 == 0 or b_idx == n_batches - 1:
                    elapsed = time.time() - start
                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {b_idx + 1}/{n_batches} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
    finally:
        # Always detach the hooks, even on error, so the model is left clean.
        for h in handles:
            h.remove()

    return out
|
||||||
|
|
||||||
|
|
||||||
|
def _subspace_concept_direction(
|
||||||
|
pos_V: list[torch.Tensor], # list of [hidden, k_i] per story
|
||||||
|
base_V: list[torch.Tensor],
|
||||||
|
hidden: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Subspace-common-direction CAA alternative.
|
||||||
|
|
||||||
|
Builds M_pos = (1/n_pos) Σ V_i V_i^T over positive stories and M_base the
|
||||||
|
same over baselines. Returns the top eigenvector of (M_pos - M_base) —
|
||||||
|
the direction most-common to positives after subtracting what's generic
|
||||||
|
across baselines — plus its eigenvalue spectrum (for diagnostics).
|
||||||
|
|
||||||
|
The top eigenvalue approaches 1 if the concept appears in every positive
|
||||||
|
story's subspace with unit weight and is absent from the baseline.
|
||||||
|
"""
|
||||||
|
device = pos_V[0].device if pos_V else torch.device("cpu")
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
|
def acc(Vs: list[torch.Tensor]) -> torch.Tensor:
|
||||||
|
if not Vs:
|
||||||
|
return torch.zeros(hidden, hidden, dtype=dtype, device=device)
|
||||||
|
M = torch.zeros(hidden, hidden, dtype=dtype, device=device)
|
||||||
|
for V in Vs:
|
||||||
|
V = V.to(dtype=dtype, device=device)
|
||||||
|
M.addmm_(V, V.t())
|
||||||
|
M /= len(Vs)
|
||||||
|
return M
|
||||||
|
|
||||||
|
M_pos = acc(pos_V)
|
||||||
|
M_base = acc(base_V)
|
||||||
|
M = M_pos - M_base
|
||||||
|
|
||||||
|
# Symmetric eigendecomposition — top eigenvalue/vector.
|
||||||
|
eigvals, eigvecs = torch.linalg.eigh(M)
|
||||||
|
# eigh returns ascending; top is the last column.
|
||||||
|
top_vec = eigvecs[:, -1]
|
||||||
|
# Unit-norm (eigvecs are unit already, but defensively).
|
||||||
|
top_vec = top_vec / top_vec.norm().clamp_min(1e-6)
|
||||||
|
return top_vec, eigvals
|
||||||
|
|
||||||
|
|
||||||
def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
|
def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
|
||||||
dict[str, list[str]], # emotion -> positive texts (unpaired + within-scenario framings)
|
dict[str, list[str]], # emotion -> positive texts (unpaired + within-scenario framings)
|
||||||
list[str], # all baseline texts (one per scenario), as scenario-agnostic negatives
|
list[str], # all baseline texts (one per scenario), as scenario-agnostic negatives
|
||||||
|
|
@ -684,6 +837,22 @@ def main() -> None:
|
||||||
default=1,
|
default=1,
|
||||||
help="Skip emotions with fewer positive examples than this",
|
help="Skip emotions with fewer positive examples than this",
|
||||||
)
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--method",
|
||||||
|
default="pooled",
|
||||||
|
choices=["pooled", "subspace"],
|
||||||
|
help="Concept-extraction method: 'pooled' (classic CAA, "
|
||||||
|
"pos_mean - neg_mean on last-token activations) or 'subspace' "
|
||||||
|
"(per-story SVD; top eigenvector of Σ V_i V_i^T for positives "
|
||||||
|
"minus same for baselines — captures what's common across "
|
||||||
|
"stories' full-trajectory subspaces)",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--subspace-k",
|
||||||
|
type=int,
|
||||||
|
default=20,
|
||||||
|
help="Top-k right singular vectors per story for subspace method",
|
||||||
|
)
|
||||||
ap.add_argument(
|
ap.add_argument(
|
||||||
"--quality-report",
|
"--quality-report",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
|
@ -828,6 +997,27 @@ def main() -> None:
|
||||||
(n_layers, n_concepts, hidden_dim), dtype=torch.float32
|
(n_layers, n_concepts, hidden_dim), dtype=torch.float32
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# --- Subspace method: collect per-story right-singular-vector subspaces
|
||||||
|
# and use sum-of-projection-operators per concept. --------------------
|
||||||
|
pos_subspaces: list[dict[int, torch.Tensor]] | None = None
|
||||||
|
base_subspaces: list[dict[int, torch.Tensor]] | None = None
|
||||||
|
if args.method == "subspace":
|
||||||
|
print("\nCollecting per-story subspaces (SVD, top-k right singular "
|
||||||
|
f"vectors, k={args.subspace_k})...")
|
||||||
|
pos_subspaces = _collect_per_story_subspaces(
|
||||||
|
model, tokenizer, unique_positive_texts, target_layers, device,
|
||||||
|
args.batch_size, args.max_length, k=args.subspace_k,
|
||||||
|
label="subsp-pos",
|
||||||
|
)
|
||||||
|
if baselines:
|
||||||
|
base_subspaces = _collect_per_story_subspaces(
|
||||||
|
model, tokenizer, baselines, target_layers, device,
|
||||||
|
args.batch_size, args.max_length, k=args.subspace_k,
|
||||||
|
label="subsp-base",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
base_subspaces = []
|
||||||
|
|
||||||
for e_idx, emotion in enumerate(emotions):
|
for e_idx, emotion in enumerate(emotions):
|
||||||
pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
|
pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
|
||||||
# Negatives: every OTHER emotion's positives + baselines.
|
# Negatives: every OTHER emotion's positives + baselines.
|
||||||
|
|
@ -837,25 +1027,39 @@ def main() -> None:
|
||||||
if text_to_emotion[t] != emotion
|
if text_to_emotion[t] != emotion
|
||||||
]
|
]
|
||||||
|
|
||||||
pos = positive_acts[pos_rows] # [n_pos, n_layers, hidden]
|
if args.method == "subspace":
|
||||||
neg = positive_acts[neg_rows] # [n_neg, n_layers, hidden]
|
# For each layer, build M_pos = Σ V V^T / n_pos, baseline same
|
||||||
if baseline_acts.shape[0] > 0:
|
# (using all other concepts' positive subspaces + baseline
|
||||||
neg = torch.cat([neg, baseline_acts], dim=0)
|
# subspaces as the contrast set), top eigenvector of difference.
|
||||||
|
for l_idx, target_l in enumerate(target_layers):
|
||||||
|
pos_V = [pos_subspaces[j][target_l] for j in pos_rows]
|
||||||
|
base_V = [pos_subspaces[j][target_l] for j in neg_rows]
|
||||||
|
base_V += [bs[target_l] for bs in (base_subspaces or [])]
|
||||||
|
top_vec, _eigvals = _subspace_concept_direction(
|
||||||
|
pos_V, base_V, hidden=hidden_dim,
|
||||||
|
)
|
||||||
|
per_layer_vectors[l_idx, e_idx] = top_vec
|
||||||
|
else:
|
||||||
|
pos = positive_acts[pos_rows] # [n_pos, n_layers, hidden]
|
||||||
|
neg = positive_acts[neg_rows] # [n_neg, n_layers, hidden]
|
||||||
|
if baseline_acts.shape[0] > 0:
|
||||||
|
neg = torch.cat([neg, baseline_acts], dim=0)
|
||||||
|
|
||||||
pos_mean = pos.mean(dim=0) # [n_layers, hidden]
|
pos_mean = pos.mean(dim=0) # [n_layers, hidden]
|
||||||
neg_mean = neg.mean(dim=0)
|
neg_mean = neg.mean(dim=0)
|
||||||
diff = pos_mean - neg_mean
|
diff = pos_mean - neg_mean
|
||||||
norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
|
norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
|
||||||
diff = diff / norms
|
diff = diff / norms
|
||||||
|
|
||||||
# diff[layer] -> per_layer_vectors[layer, e_idx]
|
# diff[layer] -> per_layer_vectors[layer, e_idx]
|
||||||
for l_idx in range(n_layers):
|
for l_idx in range(n_layers):
|
||||||
per_layer_vectors[l_idx, e_idx] = diff[l_idx]
|
per_layer_vectors[l_idx, e_idx] = diff[l_idx]
|
||||||
|
|
||||||
if e_idx < 5 or e_idx == len(emotions) - 1:
|
if e_idx < 5 or e_idx == len(emotions) - 1:
|
||||||
print(
|
print(
|
||||||
f" [{e_idx + 1}/{len(emotions)}] {emotion}: "
|
f" [{e_idx + 1}/{len(emotions)}] {emotion}: "
|
||||||
f"pos={len(pos_rows)} neg={len(neg_rows) + baseline_acts.shape[0]}"
|
f"pos={len(pos_rows)} neg={len(neg_rows) + baseline_acts.shape[0]}"
|
||||||
|
f" (method={args.method})"
|
||||||
)
|
)
|
||||||
|
|
||||||
output_dir = Path(args.output_dir)
|
output_dir = Path(args.output_dir)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue