diff --git a/training/amygdala_training/train_steering_vectors.py b/training/amygdala_training/train_steering_vectors.py
index 5584e58..ba8fa5d 100644
--- a/training/amygdala_training/train_steering_vectors.py
+++ b/training/amygdala_training/train_steering_vectors.py
@@ -166,6 +166,159 @@ def _collect_activations(
     return torch.cat(out_rows, dim=0)
 
 
+def _collect_per_story_subspaces(
+    model,
+    tokenizer,
+    texts: list[str],
+    target_layers: list[int],
+    device: torch.device,
+    batch_size: int,
+    max_length: int,
+    *,
+    k: int = 20,
+    label: str = "",
+) -> list[dict[int, torch.Tensor]]:
+    """Run texts through the model, capture the full per-token residual-stream
+    activations at each target layer, run an SVD per story, and return the
+    top-k right singular vectors.
+
+    Returns: list (length n_texts) of dicts; each dict maps target_layer_idx to
+    a tensor ``[hidden_dim, k]`` of unit-normed right singular vectors (fewer
+    columns when the story has fewer than k tokens): the subspace the story's
+    tokens span in activation space at that layer.
+
+    The per-story subspace captures *all* the directions a story occupies:
+    concept, narrator, topic, style. Finding the direction common to stories of
+    the same concept (via the mean of V_i V_i^T and its top eigenvector)
+    cancels nuisance directions that differ across stories while preserving
+    the directions they share.
+    """
+    import time
+
+    assert all(isinstance(t, str) and t for t in texts), (
+        f"_collect_per_story_subspaces: empty or non-string text in {label!r}"
+    )
+
+    captures: dict[int, torch.Tensor] = {}
+
+    def make_hook(idx: int):
+        def hook(_mod, _inp, output):
+            hs = output[0] if isinstance(output, tuple) else output
+            captures[idx] = hs.detach()
+        return hook
+
+    layers_module = _find_layers_module(model)
+    handles = [
+        layers_module[idx].register_forward_hook(make_hook(idx))
+        for idx in target_layers
+    ]
+
+    # One entry per text: {layer_idx: V[hidden, k]}
+    out: list[dict[int, torch.Tensor]] = [{} for _ in range(len(texts))]
+    n_batches = (len(texts) + batch_size - 1) // batch_size
+    start = time.time()
+    try:
+        model.eval()
+        with torch.no_grad():
+            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
+                batch = texts[i : i + batch_size]
+                tok = tokenizer(
+                    batch,
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                    max_length=max_length,
+                ).to(device)
+                captures.clear()
+                model(**tok)
+
+                # For each item in the batch, for each layer, SVD on the
+                # non-pad tokens. Boolean masking handles both left- and
+                # right-padded tokenizers.
+                attn = tok["attention_mask"].bool()
+                for t_idx_in_batch in range(attn.shape[0]):
+                    story_idx = i + t_idx_in_batch
+                    mask = attn[t_idx_in_batch]
+                    for layer in target_layers:
+                        hs = captures[layer][t_idx_in_batch][mask]
+                        # Center tokens so the SVD captures variation within
+                        # the story, not the story's center of mass.
+                        hs = hs.to(torch.float32)
+                        hs = hs - hs.mean(dim=0)
+                        # SVD: hs = U S V^T; the rows of vh are right singular
+                        # vectors in hidden space. After centering, the rank
+                        # is at most min(n_tok - 1, hidden).
+                        try:
+                            _u, _s, vh = torch.linalg.svd(hs, full_matrices=False)
+                        except Exception:
+                            # Degenerate case (all-zero hs, single token):
+                            # fall back to the last real token's activation,
+                            # unit-normed.
+                            vec = captures[layer][t_idx_in_batch][mask][-1]
+                            vec = vec.to(torch.float32)
+                            nrm = vec.norm().clamp_min(1e-6)
+                            vh = (vec / nrm).unsqueeze(0)  # [1, hidden]
+                        # Keep the top-k rows of V^T (= top-k right singular
+                        # vectors).
+                        top = min(k, vh.shape[0])
+                        V = vh[:top].t().contiguous().cpu()  # [hidden, top]
+                        out[story_idx][layer] = V
+                del tok
+                captures.clear()
+                if b_idx % 10 == 0:
+                    torch.cuda.empty_cache()
+                if b_idx % 5 == 0 or b_idx == n_batches - 1:
+                    elapsed = time.time() - start
+                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
+                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
+                    print(
+                        f"  [{label}] batch {b_idx + 1}/{n_batches} "
+                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
+                        flush=True,
+                    )
+    finally:
+        for h in handles:
+            h.remove()
+
+    return out
+
+
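The shape and orthonormality contract of the per-story `V` returned above can be checked in isolation. A minimal sketch with made-up sizes (`hidden`, `n_tok`, and `k` here are illustrative; no model is involved):

    import torch

    hidden, n_tok, k = 64, 12, 20
    hs = torch.randn(n_tok, hidden)
    hs = hs - hs.mean(dim=0)                    # center, as in the loop above
    _u, _s, vh = torch.linalg.svd(hs, full_matrices=False)  # vh: [min(n_tok, hidden), hidden]
    top = min(k, vh.shape[0])                   # at most min(n_tok, hidden) vectors exist
    V = vh[:top].t().contiguous()               # [hidden, top], orthonormal columns
    assert torch.allclose(V.t() @ V, torch.eye(top), atol=1e-5)
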
+def _subspace_concept_direction(
+    pos_V: list[torch.Tensor],  # list of [hidden, k_i] per story
+    base_V: list[torch.Tensor],
+    hidden: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Subspace-common-direction alternative to pooled CAA.
+
+    Builds M_pos = (1/n_pos) Σ V_i V_i^T over positive stories and M_base the
+    same over the contrast set. Returns the top eigenvector of (M_pos -
+    M_base), the direction most common to positives after subtracting what is
+    generic across the contrast set, plus the full eigenvalue spectrum (for
+    diagnostics).
+
+    The top eigenvalue approaches 1 if the concept direction lies in every
+    positive story's subspace and in none of the contrast subspaces. Note
+    that an eigenvector's sign is arbitrary; callers that need a signed
+    steering direction must orient the result (see main()).
+    """
+    device = pos_V[0].device if pos_V else torch.device("cpu")
+    dtype = torch.float32
+
+    def acc(Vs: list[torch.Tensor]) -> torch.Tensor:
+        if not Vs:
+            return torch.zeros(hidden, hidden, dtype=dtype, device=device)
+        M = torch.zeros(hidden, hidden, dtype=dtype, device=device)
+        for V in Vs:
+            V = V.to(dtype=dtype, device=device)
+            M.addmm_(V, V.t())  # M += V @ V^T, the projector onto span(V)
+        M /= len(Vs)
+        return M
+
+    M_pos = acc(pos_V)
+    M_base = acc(base_V)
+    M = M_pos - M_base
+
+    # Symmetric eigendecomposition; eigh returns eigenvalues in ascending
+    # order, so the top eigenvector is the last column.
+    eigvals, eigvecs = torch.linalg.eigh(M)
+    top_vec = eigvecs[:, -1]
+    # eigh's eigenvectors are already unit-norm; renormalize defensively.
+    top_vec = top_vec / top_vec.norm().clamp_min(1e-6)
+    return top_vec, eigvals
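The recovery claim in the docstring (top eigenvector aligned with the shared direction, top eigenvalue approaching 1) can be exercised on synthetic subspaces. A standalone sketch, not part of the patch; the planted `concept` and the nuisance construction are illustrative:

    import torch

    torch.manual_seed(0)
    hidden, k, n_stories = 128, 8, 40

    def rand_orthonormal(n: int) -> torch.Tensor:
        q, _ = torch.linalg.qr(torch.randn(hidden, n))  # random orthonormal columns
        return q

    concept = torch.randn(hidden)
    concept = concept / concept.norm()

    pos_V, base_V = [], []
    for _ in range(n_stories):
        nuis = rand_orthonormal(k - 1)                 # per-story nuisance directions
        c = concept - nuis @ (nuis.t() @ concept)      # orthogonalize concept vs. nuisance
        c = c / c.norm()
        pos_V.append(torch.cat([c.unsqueeze(1), nuis], dim=1))  # [hidden, k], concept present
        base_V.append(rand_orthonormal(k))                      # nuisance only

    M = sum(V @ V.t() for V in pos_V) / len(pos_V)
    M = M - sum(V @ V.t() for V in base_V) / len(base_V)
    eigvals, eigvecs = torch.linalg.eigh(M)
    cos = torch.dot(eigvecs[:, -1], concept).abs().item()
    print(f"cos(top, concept) = {cos:.3f}")            # close to 1
    print(f"top eigenvalue    = {eigvals[-1].item():.3f}")  # approaches 1 as overlaps vanish
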
+
+
 def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
     dict[str, list[str]],  # emotion -> positive texts (unpaired + within-scenario framings)
     list[str],  # all baseline texts (one per scenario), as scenario-agnostic negatives
@@ -684,6 +837,22 @@ def main() -> None:
         default=1,
         help="Skip emotions with fewer positive examples than this",
     )
+    ap.add_argument(
+        "--method",
+        default="pooled",
+        choices=["pooled", "subspace"],
+        help="Concept-extraction method: 'pooled' (classic CAA, "
+        "pos_mean - neg_mean on last-token activations) or 'subspace' "
+        "(per-story SVD; top eigenvector of the mean of V_i V_i^T over "
+        "positives minus the same over baselines, capturing what is common "
+        "across stories' full-trajectory subspaces)",
+    )
+    ap.add_argument(
+        "--subspace-k",
+        type=int,
+        default=20,
+        help="Top-k right singular vectors per story for the subspace method",
+    )
     ap.add_argument(
         "--quality-report",
         action="store_true",
@@ -828,6 +997,27 @@ def main() -> None:
         (n_layers, n_concepts, hidden_dim), dtype=torch.float32
     )
 
+    # --- Subspace method: collect per-story right-singular-vector subspaces
+    # and use the sum of projection operators per concept. ------------------
+    pos_subspaces: list[dict[int, torch.Tensor]] | None = None
+    base_subspaces: list[dict[int, torch.Tensor]] | None = None
+    if args.method == "subspace":
+        print(
+            "\nCollecting per-story subspaces (SVD, top-k right singular "
+            f"vectors, k={args.subspace_k})..."
+        )
+        pos_subspaces = _collect_per_story_subspaces(
+            model, tokenizer, unique_positive_texts, target_layers, device,
+            args.batch_size, args.max_length, k=args.subspace_k,
+            label="subsp-pos",
+        )
+        if baselines:
+            base_subspaces = _collect_per_story_subspaces(
+                model, tokenizer, baselines, target_layers, device,
+                args.batch_size, args.max_length, k=args.subspace_k,
+                label="subsp-base",
+            )
+        else:
+            base_subspaces = []
+
     for e_idx, emotion in enumerate(emotions):
         pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
         # Negatives: every OTHER emotion's positives + baselines.
@@ -837,25 +1027,39 @@ def main() -> None:
             if text_to_emotion[t] != emotion
         ]
 
-        pos = positive_acts[pos_rows]  # [n_pos, n_layers, hidden]
-        neg = positive_acts[neg_rows]  # [n_neg, n_layers, hidden]
-        if baseline_acts.shape[0] > 0:
-            neg = torch.cat([neg, baseline_acts], dim=0)
+        if args.method == "subspace":
+            # For each layer: M_pos = mean of V V^T over this emotion's
+            # stories; the contrast set is every other emotion's positive
+            # subspaces plus the baseline subspaces. Take the top eigenvector
+            # of the difference.
+            for l_idx, target_l in enumerate(target_layers):
+                pos_V = [pos_subspaces[j][target_l] for j in pos_rows]
+                base_V = [pos_subspaces[j][target_l] for j in neg_rows]
+                base_V += [bs[target_l] for bs in (base_subspaces or [])]
+                top_vec, _eigvals = _subspace_concept_direction(
+                    pos_V, base_V, hidden=hidden_dim,
+                )
+                # The eigenvector's sign is arbitrary; orient it along the
+                # pooled CAA difference so a positive coefficient steers
+                # toward the concept.
+                pooled = (
+                    positive_acts[pos_rows, l_idx].mean(dim=0)
+                    - positive_acts[neg_rows, l_idx].mean(dim=0)
+                )
+                if torch.dot(top_vec, pooled) < 0:
+                    top_vec = -top_vec
+                per_layer_vectors[l_idx, e_idx] = top_vec
+        else:
+            pos = positive_acts[pos_rows]  # [n_pos, n_layers, hidden]
+            neg = positive_acts[neg_rows]  # [n_neg, n_layers, hidden]
+            if baseline_acts.shape[0] > 0:
+                neg = torch.cat([neg, baseline_acts], dim=0)
 
-        pos_mean = pos.mean(dim=0)  # [n_layers, hidden]
-        neg_mean = neg.mean(dim=0)
-        diff = pos_mean - neg_mean
-        norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
-        diff = diff / norms
+            pos_mean = pos.mean(dim=0)  # [n_layers, hidden]
+            neg_mean = neg.mean(dim=0)
+            diff = pos_mean - neg_mean
+            norms = diff.norm(dim=-1, keepdim=True).clamp_min(1e-6)
+            diff = diff / norms
 
-        # diff[layer] -> per_layer_vectors[layer, e_idx]
-        for l_idx in range(n_layers):
-            per_layer_vectors[l_idx, e_idx] = diff[l_idx]
+            # diff[layer] -> per_layer_vectors[layer, e_idx]
+            for l_idx in range(n_layers):
+                per_layer_vectors[l_idx, e_idx] = diff[l_idx]
 
         if e_idx < 5 or e_idx == len(emotions) - 1:
             print(
                 f"  [{e_idx + 1}/{len(emotions)}] {emotion}: "
                 f"pos={len(pos_rows)} neg={len(neg_rows) + baseline_acts.shape[0]}"
+                f" (method={args.method})"
             )
 
     output_dir = Path(args.output_dir)
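The new mode is opt-in; the default remains the pooled CAA path. An invocation would look roughly like this (only `--method` and `--subspace-k` are added by this diff; any other flags are the script's existing ones and are elided):

    python training/amygdala_training/train_steering_vectors.py \
        --method subspace \
        --subspace-k 20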