amygdala: per-head attention decomposition diagnostic

As part of --quality-report, run a second forward pass capturing the
input to each target layer's o_proj (= the concatenation of per-head
attention outputs just before the output projection). For each concept,
reshape that input to [n_heads, head_dim] and rank heads by per-head
selectivity: the L2 norm of the positives-vs-negatives diff-of-means,
normalised by the negatives' per-head std. Both the raw norm and the
selectivity are reported.
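
In sketch form, per (concept, layer), with pos_l / neg_l the pooled
o_proj inputs for that layer (names illustrative; this mirrors
_compute_per_head_ranking below):

    pos_h = pos_l.view(-1, n_heads, head_dim)
    neg_h = neg_l.view(-1, n_heads, head_dim)
    diff = pos_h.mean(0) - neg_h.mean(0)      # [n_heads, head_dim]
    raw_norm = diff.norm(dim=-1)              # magnitude per head
    neg_std = neg_h.std(0).norm(dim=-1).clamp_min(1e-6)
    selectivity = raw_norm / neg_std
    top_vals, top_idx = selectivity.topk(min(10, n_heads))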

Motivation: the Wang et al. paper (2510.11328) — whose paired-scenario
methodology we already lifted — further decomposes concept circuits at
the attention-head level. Meta-relational concepts (recognition, trust,
vulnerability) plausibly live in a sparse attention-head circuit rather
than in the residual-stream sum, which would explain why diff-of-means
on the residual blurs them. This diagnostic surfaces that.

Output is folded into quality.json under each concept as "per_head":
for each analysed layer, the head count and head_dim, a list of the
top-10 heads as [head_idx, raw_norm, selectivity] triples, and
head_concentration (the fraction of total head-norm captured by those
top heads).
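
Illustrative shape of one concept's entry (layer key, head indices and
values made up; the real output lists up to 10 heads per layer):

    "per_head": {
      "24": {
        "n_heads": 32,
        "head_dim": 128,
        "top_heads": [[13, 4.21, 1.87], [7, 3.90, 1.62], [29, 2.05, 1.11]],
        "head_concentration": 0.61
      }
    }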

Interpretation:
- head_concentration > 0.5: a sparse head circuit; a handful of heads
  route the concept. Worth building a head-level readout for (a quick
  way to scan for such layers is sketched below).
- head_concentration ~= k/n for the top k of n heads (e.g. 10/32 ~= 0.31):
  the concept is spread roughly evenly across heads, and the
  residual-stream diff-of-means is already doing fine.
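
A minimal consumer that flags sparse circuits might look like this
(the output path and the assumption that quality.json is keyed by
concept are illustrative):

    import json
    from pathlib import Path

    report = json.loads(Path("out/quality.json").read_text())
    for concept, entry in report.items():
        for layer, ph in entry.get("per_head", {}).items():
            if ph["head_concentration"] > 0.5:
                heads = [h for h, *_ in ph["top_heads"][:3]]
                print(f"{concept} @ layer {layer}: top heads {heads}")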

Hybrid layers (Mamba, GatedDeltaNet) whose attention path doesn't
match the standard module layout are silently skipped.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>

@@ -216,6 +216,203 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
    return positives, baselines


def _find_o_proj(layer) -> torch.nn.Module | None:
    """Locate the attention output projection within a transformer layer."""
    for path in (
        "self_attn.o_proj",
        "self_attn.out_proj",
        "attention.o_proj",
        "attn.out_proj",
    ):
        obj = layer
        ok = True
        for part in path.split("."):
            if not hasattr(obj, part):
                ok = False
                break
            obj = getattr(obj, part)
        if ok:
            return obj
    return None


def _collect_attention_inputs(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    label: str = "",
) -> tuple[torch.Tensor, list[int]]:
    """Capture the INPUT to o_proj at each target layer (= concat of per-head
    attention outputs right before the output projection).

    Returns (tensor [n_texts, n_active_layers, hidden_dim], active_layers).
    The active_layers list is the subset of target_layers whose attention
    module exposed a recognisable o_proj path; hybrid layers (Mamba, etc.)
    may be silently skipped.
    """
    import time

    layers_module = _find_layers_module(model)
    captures: dict[int, torch.Tensor] = {}
    handles = []
    active_layers: list[int] = []

    def make_hook(idx: int):
        def hook(_mod, inputs):
            x = inputs[0] if isinstance(inputs, tuple) else inputs
            captures[idx] = x.detach()

        return hook

    for idx in target_layers:
        o_proj = _find_o_proj(layers_module[idx])
        if o_proj is not None:
            handles.append(o_proj.register_forward_pre_hook(make_hook(idx)))
            active_layers.append(idx)

    if not active_layers:
        return torch.zeros(0, 0, 0), []

    out_rows: list[torch.Tensor] = []
    n_batches = (len(texts) + batch_size - 1) // batch_size
    start = time.time()
    try:
        model.eval()
        with torch.no_grad():
            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)
                per_layer = [
                    _pool_last(captures[idx], tok["attention_mask"])
                    .to(torch.float32)
                    .cpu()
                    for idx in active_layers
                ]
                out_rows.append(torch.stack(per_layer, dim=1))
                del tok, captures
                if b_idx % 10 == 0:
                    torch.cuda.empty_cache()
                if b_idx % 5 == 0 or b_idx == n_batches - 1:
                    elapsed = time.time() - start
                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {b_idx + 1}/{n_batches} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
                captures = {}
    finally:
        for h in handles:
            h.remove()

    return torch.cat(out_rows, dim=0), active_layers


def _compute_per_head_ranking(
    emotions: list[str],
    attn_inputs: torch.Tensor,  # [n_stories, n_active_layers, hidden]
    baseline_attn_inputs: torch.Tensor,
    positives_by_emotion: dict[str, list[str]],
    text_to_row: dict[str, int],
    active_layers: list[int],
    n_heads_per_layer: dict[int, int],
    text_to_emotion: dict[str, str],
    unique_positive_texts: list[str],
) -> dict:
    """For each concept, rank attention heads by contribution magnitude.

    Per (concept, layer): reshape the o_proj input to [n_heads, head_dim],
    compute the diff-of-means between positives and negatives per head, and
    rank heads by the L2 norm of that diff normalised by the negatives'
    per-head std (selectivity). The top heads are the ones most strongly
    implicated in the concept circuit.

    Why this matters: meta-relational concepts (trust, recognition,
    "seen") often don't give a strong residual-stream diff-of-means but
    DO give a strong per-head signal: the concept lives in a small
    attention circuit rather than in the residual-stream sum.
    """
    result: dict[str, dict] = {}
    for emotion in emotions:
        pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
        neg_rows = [
            i
            for i, t in enumerate(unique_positive_texts)
            if text_to_emotion[t] != emotion
        ]
        pos = attn_inputs[pos_rows]  # [n_pos, n_layers, hidden]
        neg = attn_inputs[neg_rows]
        if baseline_attn_inputs.shape[0] > 0:
            neg = torch.cat([neg, baseline_attn_inputs], dim=0)
        per_layer: dict[str, dict] = {}
        for l_idx, target_l in enumerate(active_layers):
            n_heads = n_heads_per_layer.get(target_l)
            if not n_heads:
                continue
            hidden = pos.shape[-1]
            if hidden % n_heads != 0:
                continue
            head_dim = hidden // n_heads
            pos_l = pos[:, l_idx, :].view(-1, n_heads, head_dim)
            neg_l = neg[:, l_idx, :].view(-1, n_heads, head_dim)
            diff = pos_l.mean(dim=0) - neg_l.mean(dim=0)  # [n_heads, head_dim]
            head_norms = diff.norm(dim=-1)  # [n_heads]
            # Normalise by neg variance per head so different-scale heads
            # don't dominate purely on activation magnitude.
            neg_std = neg_l.std(dim=0).norm(dim=-1).clamp_min(1e-6)
            head_selectivity = head_norms / neg_std  # [n_heads]
            k = min(10, n_heads)
            top_vals, top_idxs = head_selectivity.topk(k)
            top_heads = [
                [int(i), float(head_norms[i]), float(head_selectivity[i])]
                for i in top_idxs
            ]
            per_layer[str(target_l)] = {
                "n_heads": n_heads,
                "head_dim": head_dim,
                "top_heads": top_heads,  # [head_idx, raw_norm, selectivity]
                "head_concentration": float(
                    # fraction of total head-norm captured by top-k
                    head_norms[top_idxs].sum() / head_norms.sum().clamp_min(1e-6)
                ),
            }
        result[emotion] = {"per_layer": per_layer}
    return result


def _get_n_heads_per_layer(model, target_layers: list[int]) -> dict[int, int]:
    """Best-effort read of num_attention_heads per layer. Prefers the text
    sub-config via get_text_config() when the model exposes it, otherwise
    the top-level config (Qwen keeps num_attention_heads there). Returns
    {} if the attribute is missing.
    """
    cfg = model.config
    if hasattr(cfg, "get_text_config"):
        cfg = cfg.get_text_config()
    n = getattr(cfg, "num_attention_heads", None)
    if n is None:
        return {}
    return {l: n for l in target_layers}


def _find_mlp_down_proj(model, layer_idx: int) -> torch.Tensor | None:
    """Return the W_down weight for the MLP at the given transformer layer.
@@ -643,6 +840,49 @@ def main() -> None:
        positive_texts=unique_positive_texts,
        text_to_emotion=text_to_emotion,
    )

    # Per-head attention decomposition — second pass, captures
    # o_proj's input at each target layer and ranks heads per concept
    # by selectivity. Meta-relational concepts often live in specific
    # attention heads rather than the residual-stream sum; this
    # diagnostic surfaces that.
    print("\nCollecting o_proj inputs for per-head analysis...")
    attn_inputs, active_layers = _collect_attention_inputs(
        model, tokenizer, unique_positive_texts, target_layers, device,
        args.batch_size, args.max_length, label="attn-pos",
    )
    if active_layers and baselines:
        baseline_attn_inputs, _ = _collect_attention_inputs(
            model, tokenizer, baselines, active_layers, device,
            args.batch_size, args.max_length, label="attn-base",
        )
    else:
        baseline_attn_inputs = torch.zeros(0, len(active_layers), hidden_dim)
    if active_layers:
        n_heads_per_layer = _get_n_heads_per_layer(model, active_layers)
        per_head = _compute_per_head_ranking(
            emotions=emotions,
            attn_inputs=attn_inputs,
            baseline_attn_inputs=baseline_attn_inputs,
            positives_by_emotion=positives_by_emotion,
            text_to_row=text_to_row,
            active_layers=active_layers,
            n_heads_per_layer=n_heads_per_layer,
            text_to_emotion=text_to_emotion,
            unique_positive_texts=unique_positive_texts,
        )
        # Fold per-head into the main report under each concept.
        for emotion, ph in per_head.items():
            if emotion in report:
                report[emotion]["per_head"] = ph["per_layer"]
        print(f"Per-head analysis done on layers {active_layers}")
    else:
        print(
            "No layer exposed a recognisable o_proj module path — "
            "per-head analysis skipped."
        )

    (output_dir / "quality.json").write_text(
        json.dumps(report, indent=2) + "\n"
    )