diff --git a/training/amygdala_training/train_steering_vectors.py b/training/amygdala_training/train_steering_vectors.py
index 6de0865..5253186 100644
--- a/training/amygdala_training/train_steering_vectors.py
+++ b/training/amygdala_training/train_steering_vectors.py
@@ -216,6 +216,203 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
     return positives, baselines
 
 
+def _find_o_proj(layer) -> torch.nn.Module | None:
+    """Locate the attention output projection within a transformer layer."""
+    for path in (
+        "self_attn.o_proj",
+        "self_attn.out_proj",
+        "attention.o_proj",
+        "attn.out_proj",
+    ):
+        obj = layer
+        ok = True
+        for part in path.split("."):
+            if not hasattr(obj, part):
+                ok = False
+                break
+            obj = getattr(obj, part)
+        if ok:
+            return obj
+    return None
+
+
+def _collect_attention_inputs(
+    model,
+    tokenizer,
+    texts: list[str],
+    target_layers: list[int],
+    device: torch.device,
+    batch_size: int,
+    max_length: int,
+    *,
+    label: str = "",
+) -> tuple[torch.Tensor, list[int]]:
+    """Capture the INPUT to o_proj at each target layer (= the concatenation
+    of per-head attention outputs, right before the output projection).
+
+    Returns (tensor [n_texts, n_active_layers, hidden_dim], active_layers).
+    The active_layers list is the subset of target_layers whose attention
+    module exposed a recognisable o_proj path — hybrid layers (Mamba, etc.)
+    are silently skipped.
+    """
+    import time
+
+    layers_module = _find_layers_module(model)
+    captures: dict[int, torch.Tensor] = {}
+    handles = []
+    active_layers: list[int] = []
+
+    def make_hook(idx: int):
+        def hook(_mod, inputs):
+            x = inputs[0] if isinstance(inputs, tuple) else inputs
+            captures[idx] = x.detach()
+        return hook
+
+    for idx in target_layers:
+        o_proj = _find_o_proj(layers_module[idx])
+        if o_proj is not None:
+            handles.append(o_proj.register_forward_pre_hook(make_hook(idx)))
+            active_layers.append(idx)
+
+    if not active_layers:
+        return torch.zeros(0, 0, 0), []
+
+    out_rows: list[torch.Tensor] = []
+    n_batches = (len(texts) + batch_size - 1) // batch_size
+    start = time.time()
+    try:
+        model.eval()
+        with torch.no_grad():
+            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
+                batch = texts[i : i + batch_size]
+                tok = tokenizer(
+                    batch,
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                    max_length=max_length,
+                ).to(device)
+                captures.clear()
+                model(**tok)
+
+                per_layer = [
+                    _pool_last(captures[idx], tok["attention_mask"])
+                    .to(torch.float32)
+                    .cpu()
+                    for idx in active_layers
+                ]
+                out_rows.append(torch.stack(per_layer, dim=1))
+                # NB: only delete the batch tensors. The hooks close over
+                # `captures`, so it must stay bound — clear it, never del it.
+                del tok
+                captures.clear()
+                if b_idx % 10 == 0:
+                    torch.cuda.empty_cache()
+                if b_idx % 5 == 0 or b_idx == n_batches - 1:
+                    elapsed = time.time() - start
+                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
+                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
+                    print(
+                        f"  [{label}] batch {b_idx + 1}/{n_batches} "
+                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
+                        flush=True,
+                    )
+    finally:
+        for h in handles:
+            h.remove()
+
+    return torch.cat(out_rows, dim=0), active_layers
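+
+
+# Illustrative sketch (toy shapes; not called by the training path): each
+# pooled o_proj input is a flat [hidden] vector with
+# hidden == n_heads * head_dim, so per-head contributions fall out of a
+# single view():
+#
+#     hidden, n_heads = 64, 4
+#     x = torch.randn(hidden)                     # pooled o_proj input
+#     heads = x.view(n_heads, hidden // n_heads)  # [n_heads, head_dim]
+#     head_norms = heads.norm(dim=-1)             # one scalar per head
+#
+# _compute_per_head_ranking below applies this same reshape to batches of
+# these vectors and to their diff-of-means.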
+
+
+def _compute_per_head_ranking(
+    emotions: list[str],
+    attn_inputs: torch.Tensor,  # [n_stories, n_active_layers, hidden]
+    baseline_attn_inputs: torch.Tensor,
+    positives_by_emotion: dict[str, list[str]],
+    text_to_row: dict[str, int],
+    active_layers: list[int],
+    n_heads_per_layer: dict[int, int],
+    text_to_emotion: dict[str, str],
+    unique_positive_texts: list[str],
+) -> dict:
+    """For each concept, rank attention heads by contribution magnitude.
+
+    Per (concept, layer): reshape the o_proj input to [n_heads, head_dim],
+    compute the diff-of-means between positives and negatives per head, and
+    rank heads by the L2 norm of that diff. The top heads are the ones most
+    strongly implicated in the concept circuit.
+
+    Why this matters: meta-relational concepts (trust, recognition,
+    "seen") often don't give a strong residual-stream diff-of-means but
+    DO give a strong per-head signal — the concept lives in a small
+    attention circuit rather than in the residual-stream sum.
+    """
+    result: dict[str, dict] = {}
+
+    for emotion in emotions:
+        pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
+        neg_rows = [
+            i
+            for i, t in enumerate(unique_positive_texts)
+            if text_to_emotion[t] != emotion
+        ]
+        pos = attn_inputs[pos_rows]  # [n_pos, n_layers, hidden]
+        neg = attn_inputs[neg_rows]
+        if baseline_attn_inputs.shape[0] > 0:
+            neg = torch.cat([neg, baseline_attn_inputs], dim=0)
+
+        per_layer: dict[str, dict] = {}
+        for l_idx, target_l in enumerate(active_layers):
+            n_heads = n_heads_per_layer.get(target_l)
+            if not n_heads:
+                continue
+            hidden = pos.shape[-1]
+            if hidden % n_heads != 0:
+                continue
+            head_dim = hidden // n_heads
+
+            pos_l = pos[:, l_idx, :].view(-1, n_heads, head_dim)
+            neg_l = neg[:, l_idx, :].view(-1, n_heads, head_dim)
+
+            diff = pos_l.mean(dim=0) - neg_l.mean(dim=0)  # [n_heads, head_dim]
+            head_norms = diff.norm(dim=-1)  # [n_heads]
+            # Normalise by the per-head spread of the negative set so heads
+            # with large activations don't dominate purely on magnitude.
+            neg_std = neg_l.std(dim=0).norm(dim=-1).clamp_min(1e-6)
+            head_selectivity = head_norms / neg_std  # [n_heads]
+
+            k = min(10, n_heads)
+            _, top_idxs = head_selectivity.topk(k)
+            top_heads = [
+                [int(i), float(head_norms[i]), float(head_selectivity[i])]
+                for i in top_idxs
+            ]
+            per_layer[str(target_l)] = {
+                "n_heads": n_heads,
+                "head_dim": head_dim,
+                "top_heads": top_heads,  # [head_idx, raw_norm, selectivity]
+                "head_concentration": float(
+                    # fraction of total head-norm captured by the top-k
+                    head_norms[top_idxs].sum() / head_norms.sum().clamp_min(1e-6)
+                ),
+            }
+
+        result[emotion] = {"per_layer": per_layer}
+
+    return result
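+
+
+# For reference, the shape of one entry produced above (illustrative
+# values; layer and head indices are made up), as folded into
+# quality.json under each concept's "per_head" key:
+#
+#     "per_head": {
+#         "14": {
+#             "n_heads": 32,
+#             "head_dim": 128,
+#             "top_heads": [[7, 3.91, 2.48], [21, 3.40, 2.11], ...],
+#             "head_concentration": 0.63,
+#         },
+#     }
+#
+# A high head_concentration over a few top heads is the signature of a
+# concept carried by a small attention circuit.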
+
+
+def _get_n_heads_per_layer(model, target_layers: list[int]) -> dict[int, int]:
+    """Best-effort read of num_attention_heads per layer. Qwen exposes it on
+    the top-level config; falls back to config.num_attention_heads.
+    """
+    cfg = model.config
+    if hasattr(cfg, "get_text_config"):
+        cfg = cfg.get_text_config()
+    n = getattr(cfg, "num_attention_heads", None)
+    if n is None:
+        return {}
+    return {l: n for l in target_layers}
+
+
 def _find_mlp_down_proj(model, layer_idx: int) -> torch.Tensor | None:
     """Return the W_down weight for the MLP at the given transformer layer.
 
@@ -643,6 +840,49 @@ def main() -> None:
         positive_texts=unique_positive_texts,
         text_to_emotion=text_to_emotion,
     )
+
+    # Per-head attention decomposition — second pass: capture o_proj's
+    # input at each target layer and rank heads per concept by
+    # selectivity. Meta-relational concepts often live in specific
+    # attention heads rather than the residual-stream sum; this
+    # diagnostic surfaces that.
+    print("\nCollecting o_proj inputs for per-head analysis...")
+    attn_inputs, active_layers = _collect_attention_inputs(
+        model, tokenizer, unique_positive_texts, target_layers, device,
+        args.batch_size, args.max_length, label="attn-pos",
+    )
+    if active_layers and baselines:
+        baseline_attn_inputs, _ = _collect_attention_inputs(
+            model, tokenizer, baselines, active_layers, device,
+            args.batch_size, args.max_length, label="attn-base",
+        )
+    else:
+        baseline_attn_inputs = torch.zeros(0, len(active_layers), hidden_dim)
+
+    if active_layers:
+        n_heads_per_layer = _get_n_heads_per_layer(model, active_layers)
+        per_head = _compute_per_head_ranking(
+            emotions=emotions,
+            attn_inputs=attn_inputs,
+            baseline_attn_inputs=baseline_attn_inputs,
+            positives_by_emotion=positives_by_emotion,
+            text_to_row=text_to_row,
+            active_layers=active_layers,
+            n_heads_per_layer=n_heads_per_layer,
+            text_to_emotion=text_to_emotion,
+            unique_positive_texts=unique_positive_texts,
+        )
+        # Fold per-head results into the main report under each concept.
+        for emotion, ph in per_head.items():
+            if emotion in report:
+                report[emotion]["per_head"] = ph["per_layer"]
+        print(f"Per-head analysis done on layers {active_layers}")
+    else:
+        print(
+            "No layer exposed a recognisable o_proj module path — "
+            "per-head analysis skipped."
+        )
+
     (output_dir / "quality.json").write_text(
         json.dumps(report, indent=2) + "\n"
     )