amygdala: per-head attention decomposition diagnostic
As part of --quality-report, run a second forward pass capturing the input to each target layer's o_proj (= concat of per-head attention outputs before the output projection). For each concept, reshape to [n_heads, head_dim] and rank heads by diff-of-means magnitude / per-head selectivity (magnitude normalised by negative std). Motivation: the Wang et al. paper (2510.11328) — whose paired-scenario methodology we already lifted — further decomposes concept circuits at the attention-head level. Meta-relational concepts (recognition, trust, vulnerability) plausibly live in a sparse attention-head circuit rather than in the residual-stream sum, which would explain why diff-of-means on the residual blurs them. This diagnostic surfaces that. Output is folded into quality.json under each concept as "per_head": per layer, a list of top-10 heads with [head_idx, raw_norm, selectivity], plus head_concentration (fraction of total head-norm captured by those top heads). Interpretation: - head_concentration > 0.5 = sparse head circuit; a handful of heads route the concept. Worth building a head-level readout for. - head_concentration ~= k/n (for the top-k of n heads) = concept is distributed across all heads ~evenly; residual-stream diff-of-means is doing fine. Hybrid layers (Mamba, GatedDeltaNet) whose attention path doesn't match the standard module layout are silently skipped. Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
ce24d9ce6b
commit
af17b0f0df
1 changed files with 240 additions and 0 deletions
|
|
@ -216,6 +216,203 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
|
||||||
return positives, baselines
|
return positives, baselines
|
||||||
|
|
||||||
|
|
||||||
|
def _find_o_proj(layer) -> torch.nn.Module | None:
|
||||||
|
"""Locate the attention output projection within a transformer layer."""
|
||||||
|
for path in (
|
||||||
|
"self_attn.o_proj",
|
||||||
|
"self_attn.out_proj",
|
||||||
|
"attention.o_proj",
|
||||||
|
"attn.out_proj",
|
||||||
|
):
|
||||||
|
obj = layer
|
||||||
|
ok = True
|
||||||
|
for part in path.split("."):
|
||||||
|
if not hasattr(obj, part):
|
||||||
|
ok = False
|
||||||
|
break
|
||||||
|
obj = getattr(obj, part)
|
||||||
|
if ok:
|
||||||
|
return obj
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_attention_inputs(
    model,
    tokenizer,
    texts: list[str],
    target_layers: list[int],
    device: torch.device,
    batch_size: int,
    max_length: int,
    *,
    label: str = "",
) -> tuple[torch.Tensor, list[int]]:
    """Capture the INPUT to o_proj at each target layer (= concat of per-head
    attention outputs right before the output projection).

    Runs a pooled forward pass over `texts` in batches of `batch_size`,
    tokenizing with truncation to `max_length`. `label` is only used to tag
    the progress output.

    Returns (tensor [n_texts, n_active_layers, hidden_dim], active_layers).
    The active_layers list is the subset of target_layers whose attention
    module exposed a recognisable o_proj path — hybrid layers (Mamba, etc.)
    may be silently skipped.
    """
    import time

    layers_module = _find_layers_module(model)
    # layer idx -> o_proj input tensor for the CURRENT batch only; refilled
    # by the hooks on every forward pass.
    captures: dict[int, torch.Tensor] = {}
    handles = []
    active_layers: list[int] = []

    def make_hook(idx: int):
        # Forward PRE-hook on o_proj: its positional input is exactly the
        # concatenated per-head attention output we want to capture.
        def hook(_mod, inputs):
            x = inputs[0] if isinstance(inputs, tuple) else inputs
            # detach only — stays on device until pooled/moved below
            captures[idx] = x.detach()
        return hook

    for idx in target_layers:
        o_proj = _find_o_proj(layers_module[idx])
        if o_proj is not None:
            handles.append(o_proj.register_forward_pre_hook(make_hook(idx)))
            active_layers.append(idx)

    if not active_layers:
        # Nothing hookable (e.g. all-hybrid layer stack): empty result,
        # no hooks were registered so nothing to clean up.
        return torch.zeros(0, 0, 0), []

    out_rows: list[torch.Tensor] = []
    # ceil division
    n_batches = (len(texts) + batch_size - 1) // batch_size
    start = time.time()
    try:
        model.eval()
        with torch.no_grad():
            for b_idx, i in enumerate(range(0, len(texts), batch_size)):
                batch = texts[i : i + batch_size]
                tok = tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                ).to(device)
                captures.clear()
                model(**tok)

                # Pool each captured layer activation over the sequence
                # (masked by attention_mask) and move off-device immediately.
                per_layer = [
                    _pool_last(captures[idx], tok["attention_mask"])
                    .to(torch.float32)
                    .cpu()
                    for idx in active_layers
                ]
                out_rows.append(torch.stack(per_layer, dim=1))
                # Drop device-tensor references so empty_cache below can
                # actually reclaim the memory. NOTE: the hooks resolve
                # `captures` through the closure at call time, so the
                # rebind at the end of this iteration restores it before
                # the next forward pass fires the hooks again.
                del tok, captures
                if b_idx % 10 == 0:
                    torch.cuda.empty_cache()
                if b_idx % 5 == 0 or b_idx == n_batches - 1:
                    elapsed = time.time() - start
                    rate = (b_idx + 1) / elapsed if elapsed > 0 else 0
                    eta = (n_batches - b_idx - 1) / rate if rate > 0 else 0
                    print(
                        f" [{label}] batch {b_idx + 1}/{n_batches} "
                        f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)",
                        flush=True,
                    )
                captures = {}
    finally:
        # Always detach the hooks, even if the forward pass raised.
        for h in handles:
            h.remove()

    return torch.cat(out_rows, dim=0), active_layers
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_per_head_ranking(
|
||||||
|
emotions: list[str],
|
||||||
|
attn_inputs: torch.Tensor, # [n_stories, n_active_layers, hidden]
|
||||||
|
baseline_attn_inputs: torch.Tensor,
|
||||||
|
positives_by_emotion: dict[str, list[str]],
|
||||||
|
text_to_row: dict[str, int],
|
||||||
|
active_layers: list[int],
|
||||||
|
n_heads_per_layer: dict[int, int],
|
||||||
|
text_to_emotion: dict[str, str],
|
||||||
|
unique_positive_texts: list[str],
|
||||||
|
) -> dict:
|
||||||
|
"""For each concept, rank attention heads by contribution magnitude.
|
||||||
|
|
||||||
|
Per (concept, layer): reshape o_proj input to [n_heads, head_dim],
|
||||||
|
compute diff-of-means between positives and negatives per head, rank
|
||||||
|
heads by the L2 norm of that diff. The top heads are the ones most
|
||||||
|
strongly implicated in the concept circuit.
|
||||||
|
|
||||||
|
Why this matters: meta-relational concepts (trust, recognition,
|
||||||
|
"seen") often don't give a strong residual-stream diff-of-means but
|
||||||
|
DO give a strong per-head signal — the concept lives in a small
|
||||||
|
attention circuit rather than in the residual-stream sum.
|
||||||
|
"""
|
||||||
|
result: dict[str, dict] = {}
|
||||||
|
|
||||||
|
for e_idx, emotion in enumerate(emotions):
|
||||||
|
pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
|
||||||
|
neg_rows = [
|
||||||
|
i
|
||||||
|
for i, t in enumerate(unique_positive_texts)
|
||||||
|
if text_to_emotion[t] != emotion
|
||||||
|
]
|
||||||
|
pos = attn_inputs[pos_rows] # [n_pos, n_layers, hidden]
|
||||||
|
neg = attn_inputs[neg_rows]
|
||||||
|
if baseline_attn_inputs.shape[0] > 0:
|
||||||
|
neg = torch.cat([neg, baseline_attn_inputs], dim=0)
|
||||||
|
|
||||||
|
per_layer: dict[str, list] = {}
|
||||||
|
for l_idx, target_l in enumerate(active_layers):
|
||||||
|
n_heads = n_heads_per_layer.get(target_l)
|
||||||
|
if not n_heads:
|
||||||
|
continue
|
||||||
|
hidden = pos.shape[-1]
|
||||||
|
if hidden % n_heads != 0:
|
||||||
|
continue
|
||||||
|
head_dim = hidden // n_heads
|
||||||
|
|
||||||
|
pos_l = pos[:, l_idx, :].view(-1, n_heads, head_dim)
|
||||||
|
neg_l = neg[:, l_idx, :].view(-1, n_heads, head_dim)
|
||||||
|
|
||||||
|
diff = pos_l.mean(dim=0) - neg_l.mean(dim=0) # [n_heads, head_dim]
|
||||||
|
head_norms = diff.norm(dim=-1) # [n_heads]
|
||||||
|
# Normalise by neg variance per head so different-scale heads
|
||||||
|
# don't dominate purely on activation magnitude.
|
||||||
|
neg_std = neg_l.std(dim=0).norm(dim=-1).clamp_min(1e-6)
|
||||||
|
head_selectivity = head_norms / neg_std # [n_heads]
|
||||||
|
|
||||||
|
k = min(10, n_heads)
|
||||||
|
top_vals, top_idxs = head_selectivity.topk(k)
|
||||||
|
top_heads = [
|
||||||
|
[int(i), float(head_norms[i]), float(head_selectivity[i])]
|
||||||
|
for i in top_idxs
|
||||||
|
]
|
||||||
|
per_layer[str(target_l)] = {
|
||||||
|
"n_heads": n_heads,
|
||||||
|
"head_dim": head_dim,
|
||||||
|
"top_heads": top_heads, # [head_idx, raw_norm, selectivity]
|
||||||
|
"head_concentration": float(
|
||||||
|
# fraction of total head-norm captured by top-k
|
||||||
|
head_norms[top_idxs].sum() / head_norms.sum().clamp_min(1e-6)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
result[emotion] = {"per_layer": per_layer}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _get_n_heads_per_layer(model, target_layers: list[int]) -> dict[int, int]:
|
||||||
|
"""Best-effort read of num_attention_heads per layer. Qwen uses the
|
||||||
|
top-level config; falls back to config.num_attention_heads.
|
||||||
|
"""
|
||||||
|
cfg = model.config
|
||||||
|
if hasattr(cfg, "get_text_config"):
|
||||||
|
cfg = cfg.get_text_config()
|
||||||
|
n = getattr(cfg, "num_attention_heads", None)
|
||||||
|
if n is None:
|
||||||
|
return {}
|
||||||
|
return {l: n for l in target_layers}
|
||||||
|
|
||||||
|
|
||||||
def _find_mlp_down_proj(model, layer_idx: int) -> torch.Tensor | None:
|
def _find_mlp_down_proj(model, layer_idx: int) -> torch.Tensor | None:
|
||||||
"""Return the W_down weight for the MLP at the given transformer layer.
|
"""Return the W_down weight for the MLP at the given transformer layer.
|
||||||
|
|
||||||
|
|
@ -643,6 +840,49 @@ def main() -> None:
|
||||||
positive_texts=unique_positive_texts,
|
positive_texts=unique_positive_texts,
|
||||||
text_to_emotion=text_to_emotion,
|
text_to_emotion=text_to_emotion,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Per-head attention decomposition — second pass, captures
|
||||||
|
# o_proj's input at each target layer and ranks heads per concept
|
||||||
|
# by selectivity. Meta-relational concepts often live in specific
|
||||||
|
# attention heads rather than the residual-stream sum; this
|
||||||
|
# diagnostic surfaces that.
|
||||||
|
print("\nCollecting o_proj inputs for per-head analysis...")
|
||||||
|
attn_inputs, active_layers = _collect_attention_inputs(
|
||||||
|
model, tokenizer, unique_positive_texts, target_layers, device,
|
||||||
|
args.batch_size, args.max_length, label="attn-pos",
|
||||||
|
)
|
||||||
|
if active_layers and baselines:
|
||||||
|
baseline_attn_inputs, _ = _collect_attention_inputs(
|
||||||
|
model, tokenizer, baselines, active_layers, device,
|
||||||
|
args.batch_size, args.max_length, label="attn-base",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
baseline_attn_inputs = torch.zeros(0, len(active_layers), hidden_dim)
|
||||||
|
|
||||||
|
if active_layers:
|
||||||
|
n_heads_per_layer = _get_n_heads_per_layer(model, active_layers)
|
||||||
|
per_head = _compute_per_head_ranking(
|
||||||
|
emotions=emotions,
|
||||||
|
attn_inputs=attn_inputs,
|
||||||
|
baseline_attn_inputs=baseline_attn_inputs,
|
||||||
|
positives_by_emotion=positives_by_emotion,
|
||||||
|
text_to_row=text_to_row,
|
||||||
|
active_layers=active_layers,
|
||||||
|
n_heads_per_layer=n_heads_per_layer,
|
||||||
|
text_to_emotion=text_to_emotion,
|
||||||
|
unique_positive_texts=unique_positive_texts,
|
||||||
|
)
|
||||||
|
# Fold per-head into the main report under each concept.
|
||||||
|
for emotion, ph in per_head.items():
|
||||||
|
if emotion in report:
|
||||||
|
report[emotion]["per_head"] = ph["per_layer"]
|
||||||
|
print(f"Per-head analysis done on layers {active_layers}")
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
"No layer exposed a recognisable o_proj module path — "
|
||||||
|
"per-head analysis skipped."
|
||||||
|
)
|
||||||
|
|
||||||
(output_dir / "quality.json").write_text(
|
(output_dir / "quality.json").write_text(
|
||||||
json.dumps(report, indent=2) + "\n"
|
json.dumps(report, indent=2) + "\n"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue