amygdala: per-head attention decomposition diagnostic
As part of --quality-report, run a second forward pass capturing the input to each target layer's o_proj (i.e. the concatenation of per-head attention outputs just before the output projection). For each concept, reshape that input to [n_heads, head_dim] and rank heads by diff-of-means magnitude and by per-head selectivity (the magnitude normalised by the per-head std of the negatives).

Motivation: the Wang et al. paper (2510.11328), whose paired-scenario methodology we already lifted, further decomposes concept circuits at the attention-head level. Meta-relational concepts (recognition, trust, vulnerability) plausibly live in a sparse attention-head circuit rather than in the residual-stream sum, which would explain why diff-of-means on the residual blurs them. This diagnostic surfaces that.

Output is folded into quality.json under each concept as "per_head": per layer, a list of the top-10 heads as [head_idx, raw_norm, selectivity], plus head_concentration (the fraction of total head-norm captured by those top heads).

Interpretation:
- head_concentration > 0.5: sparse head circuit; a handful of heads route the concept. Worth building a head-level readout for.
- head_concentration ~= k/n (top-k of n heads): the concept is spread roughly evenly across all heads; residual-stream diff-of-means is doing fine.

Hybrid layers (Mamba, GatedDeltaNet) whose attention path doesn't match the standard module layout are silently skipped.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
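For orientation, the per_head fragment folded under a concept such as "recognition" in quality.json would look roughly like the sketch below; the layer index, head indices, and all numbers are illustrative (and the top_heads list is truncated to three entries here), not real output:

    "per_head": {
      "18": {
        "n_heads": 32,
        "head_dim": 128,
        "top_heads": [[7, 3.41, 2.9], [21, 2.88, 2.4], [3, 2.10, 1.8]],
        "head_concentration": 0.62
      }
    }

As a sanity check on the even-spread baseline: with 32 heads and top-10, an even distribution gives head_concentration ~= 10/32 ~= 0.31, so a value like 0.62 (above the 0.5 threshold) would flag a sparse head circuit at that layer.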
parent ce24d9ce6b
commit af17b0f0df
1 changed file with 240 additions and 0 deletions
@@ -216,6 +216,203 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
    return positives, baselines


def _find_o_proj(layer) -> torch.nn.Module | None:
    """Locate the attention output projection within a transformer layer."""
    for path in (
        "self_attn.o_proj",
        "self_attn.out_proj",
        "attention.o_proj",
        "attn.out_proj",
    ):
        obj = layer
        ok = True
        for part in path.split("."):
            if not hasattr(obj, part):
                ok = False
                break
            obj = getattr(obj, part)
        if ok:
            return obj
    return None


def _collect_attention_inputs(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    label: str = "",
) -> tuple[torch.Tensor, list[int]]:
    """Capture the INPUT to o_proj at each target layer (= concat of per-head
    attention outputs right before the output projection).

    Returns (tensor [n_texts, n_active_layers, hidden_dim], active_layers).
    The active_layers list is the subset of target_layers whose attention
    module exposed a recognisable o_proj path — hybrid layers (Mamba, etc.)
    may be silently skipped.
    """
    import time

    layers_module = _find_layers_module(model)
    captures: dict[int, torch.Tensor] = {}
    handles = []
    active_layers: list[int] = []

    def make_hook(idx: int):
        def hook(_mod, inputs):
            x = inputs[0] if isinstance(inputs, tuple) else inputs
            captures[idx] = x.detach()
        return hook

    for idx in target_layers:
        o_proj = _find_o_proj(layers_module[idx])
        if o_proj is not None:
            handles.append(o_proj.register_forward_pre_hook(make_hook(idx)))
            active_layers.append(idx)

    if not active_layers:
        return torch.zeros(0, 0, 0), []

    out_rows: list[torch.Tensor] = []
    n_batches = (len(texts) + batch_size - 1) // batch_size
    start = time.time()
    try:
        model.eval()
        with torch.no_grad():
            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)

                per_layer = [
                    _pool_last(captures[idx], tok["attention_mask"])
                    .to(torch.float32)
                    .cpu()
                    for idx in active_layers
                ]
                out_rows.append(torch.stack(per_layer, dim=1))
                del tok, captures
                if b_idx % 10 == 0:
                    torch.cuda.empty_cache()
                if b_idx % 5 == 0 or b_idx == n_batches - 1:
                    elapsed = time.time() - start
                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {b_idx + 1}/{n_batches} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
                captures = {}
    finally:
        for h in handles:
            h.remove()

    return torch.cat(out_rows, dim=0), active_layers


def _compute_per_head_ranking(
    emotions: list[str],
    attn_inputs: torch.Tensor,  # [n_stories, n_active_layers, hidden]
    baseline_attn_inputs: torch.Tensor,
    positives_by_emotion: dict[str, list[str]],
    text_to_row: dict[str, int],
    active_layers: list[int],
    n_heads_per_layer: dict[int, int],
    text_to_emotion: dict[str, str],
    unique_positive_texts: list[str],
) -> dict:
    """For each concept, rank attention heads by contribution magnitude.

    Per (concept, layer): reshape o_proj input to [n_heads, head_dim],
    compute diff-of-means between positives and negatives per head, rank
    heads by the L2 norm of that diff. The top heads are the ones most
    strongly implicated in the concept circuit.

    Why this matters: meta-relational concepts (trust, recognition,
    "seen") often don't give a strong residual-stream diff-of-means but
    DO give a strong per-head signal — the concept lives in a small
    attention circuit rather than in the residual-stream sum.
    """
    result: dict[str, dict] = {}

    for e_idx, emotion in enumerate(emotions):
        pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
        neg_rows = [
            i
            for i, t in enumerate(unique_positive_texts)
            if text_to_emotion[t] != emotion
        ]
        pos = attn_inputs[pos_rows]  # [n_pos, n_layers, hidden]
        neg = attn_inputs[neg_rows]
        if baseline_attn_inputs.shape[0] > 0:
            neg = torch.cat([neg, baseline_attn_inputs], dim=0)

        per_layer: dict[str, list] = {}
        for l_idx, target_l in enumerate(active_layers):
            n_heads = n_heads_per_layer.get(target_l)
            if not n_heads:
                continue
            hidden = pos.shape[-1]
            if hidden % n_heads != 0:
                continue
            head_dim = hidden // n_heads

            pos_l = pos[:, l_idx, :].view(-1, n_heads, head_dim)
            neg_l = neg[:, l_idx, :].view(-1, n_heads, head_dim)

            diff = pos_l.mean(dim=0) - neg_l.mean(dim=0)  # [n_heads, head_dim]
            head_norms = diff.norm(dim=-1)  # [n_heads]
            # Normalise by neg variance per head so different-scale heads
            # don't dominate purely on activation magnitude.
            neg_std = neg_l.std(dim=0).norm(dim=-1).clamp_min(1e-6)
            head_selectivity = head_norms / neg_std  # [n_heads]

            k = min(10, n_heads)
            top_vals, top_idxs = head_selectivity.topk(k)
            top_heads = [
                [int(i), float(head_norms[i]), float(head_selectivity[i])]
                for i in top_idxs
            ]
            per_layer[str(target_l)] = {
                "n_heads": n_heads,
                "head_dim": head_dim,
                "top_heads": top_heads,  # [head_idx, raw_norm, selectivity]
                "head_concentration": float(
                    # fraction of total head-norm captured by top-k
                    head_norms[top_idxs].sum() / head_norms.sum().clamp_min(1e-6)
                ),
            }

        result[emotion] = {"per_layer": per_layer}

    return result


def _get_n_heads_per_layer(model, target_layers: list[int]) -> dict[int, int]:
    """Best-effort read of num_attention_heads per layer. Qwen uses the
    top-level config; falls back to config.num_attention_heads.
    """
    cfg = model.config
    if hasattr(cfg, "get_text_config"):
        cfg = cfg.get_text_config()
    n = getattr(cfg, "num_attention_heads", None)
    if n is None:
        return {}
    return {l: n for l in target_layers}


def _find_mlp_down_proj(model, layer_idx: int) -> torch.Tensor | None:
    """Return the W_down weight for the MLP at the given transformer layer.
@@ -643,6 +840,49 @@ def main() -> None:
        positive_texts=unique_positive_texts,
        text_to_emotion=text_to_emotion,
    )

    # Per-head attention decomposition — second pass, captures
    # o_proj's input at each target layer and ranks heads per concept
    # by selectivity. Meta-relational concepts often live in specific
    # attention heads rather than the residual-stream sum; this
    # diagnostic surfaces that.
    print("\nCollecting o_proj inputs for per-head analysis...")
    attn_inputs, active_layers = _collect_attention_inputs(
        model, tokenizer, unique_positive_texts, target_layers, device,
        args.batch_size, args.max_length, label="attn-pos",
    )
    if active_layers and baselines:
        baseline_attn_inputs, _ = _collect_attention_inputs(
            model, tokenizer, baselines, active_layers, device,
            args.batch_size, args.max_length, label="attn-base",
        )
    else:
        baseline_attn_inputs = torch.zeros(0, len(active_layers), hidden_dim)

    if active_layers:
        n_heads_per_layer = _get_n_heads_per_layer(model, active_layers)
        per_head = _compute_per_head_ranking(
            emotions=emotions,
            attn_inputs=attn_inputs,
            baseline_attn_inputs=baseline_attn_inputs,
            positives_by_emotion=positives_by_emotion,
            text_to_row=text_to_row,
            active_layers=active_layers,
            n_heads_per_layer=n_heads_per_layer,
            text_to_emotion=text_to_emotion,
            unique_positive_texts=unique_positive_texts,
        )
        # Fold per-head into the main report under each concept.
        for emotion, ph in per_head.items():
            if emotion in report:
                report[emotion]["per_head"] = ph["per_layer"]
        print(f"Per-head analysis done on layers {active_layers}")
    else:
        print(
            "No layer exposed a recognisable o_proj module path — "
            "per-head analysis skipped."
        )

    (output_dir / "quality.json").write_text(
        json.dumps(report, indent=2) + "\n"
    )