amygdala: linear-combination analysis per concept

For each concept vector, ridge-regress against all other concept
vectors. R² quantifies how much of the direction is explained by a
linear combination of peers — useful for teasing out near-duplicate
clusters (the content/cozy/sensual trio from the first L63 run is
likely 1-2 "degrees of freedom" wearing three names).

Coefficient output: top-5 contributing concepts with signed weights.
Contributors with opposite-sign large weights mean the target is
"what makes X different from Y."

Adds a 'redundant' triage bucket for concepts with R² > 0.9 —
candidates for consolidation or for writing more discriminative
training stories. Summary printed at end.

Ridge lambda defaults to 0.01 to keep coefficients stable when
concepts are near-collinear; small enough not to affect well-separated
concepts meaningfully.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-18 20:59:37 -04:00
parent f4fb6db1ee
commit 1d2c0f382c

View file

@@ -590,6 +590,67 @@ def _compute_quality_report(
return report
def _compute_linear_combinations(
emotions: list[str],
per_layer_vectors: torch.Tensor, # [n_layers, n_concepts, hidden], unit-normed
target_layers: list[int],
*,
ridge_lambda: float = 0.01,
top_k: int = 5,
) -> dict:
"""For each concept, ridge-regress its direction against all other
concept directions. Report (how much of the target direction is
explained by a linear combination of others) + top contributors.
> 0.9 = concept is essentially a linear combination of others
(redundant, or part of a cluster that needs disambiguating)
< 0.5 = concept has a substantial unique component
ridge_lambda keeps the coefficients stable when concepts are near-collinear.
"""
n_layers, n_concepts, hidden = per_layer_vectors.shape
result: dict[str, dict] = {}
# Middle layer for summary — same convention as nearest_concepts.
mid = n_layers // 2
for l_idx, target_l in enumerate(target_layers):
V = per_layer_vectors[l_idx] # [n_concepts, hidden]
for i, name in enumerate(emotions):
target = V[i] # [hidden]
mask = torch.arange(n_concepts) != i
others = V[mask] # [n-1, hidden]
# Ridge: solve (O O^T + lam I) alpha = O t
OOt = others @ others.t() # [n-1, n-1]
b = others @ target # [n-1]
A = OOt + ridge_lambda * torch.eye(n_concepts - 1, dtype=OOt.dtype)
alpha = torch.linalg.solve(A, b)
recon = others.t() @ alpha # [hidden]
resid = target - recon
t_sq = (target * target).sum().clamp_min(1e-12)
r2 = 1.0 - (resid * resid).sum() / t_sq
abs_alpha = alpha.abs()
k = min(top_k, n_concepts - 1)
top_vals, top_idxs = abs_alpha.topk(k)
other_names = [emotions[j] for j in range(n_concepts) if j != i]
top = [
[other_names[int(j)], float(alpha[j])]
for j in top_idxs
]
entry = result.setdefault(name, {})
entry.setdefault("per_layer", {})[str(target_l)] = {
"r_squared": float(r2),
"residual_norm": float(resid.norm()),
"top_contributors": top,
}
return result
def main() -> None:
ap = argparse.ArgumentParser(description=__doc__)
ap.add_argument("--model", required=True, help="HF model id or path")
@@ -884,6 +945,18 @@ def main() -> None:
"per-head analysis skipped."
)
# Linear combinations — for each concept, how much of its direction
# is explained by a ridge regression on the others. R² > 0.9 flags
# concepts that are essentially linear combinations of their peers
# (useful for teasing apart near-duplicate clusters).
print("\nComputing linear-combination analysis...")
lincomb = _compute_linear_combinations(
emotions, per_layer_vectors, target_layers
)
for emotion, lc in lincomb.items():
if emotion in report:
report[emotion]["linear_combination"] = lc["per_layer"]
(output_dir / "quality.json").write_text(
json.dumps(report, indent=2) + "\n"
)
@@ -893,6 +966,7 @@ def main() -> None:
clean_circuit = []
fragmented = []
contaminated = []
redundant = [] # R² > 0.9 — concept is near-linear combo of others
mid = n_layers // 2
mid_layer = target_layers[mid]
for emotion in emotions:
@@ -901,6 +975,12 @@ def main() -> None:
nb = per_l.get("best_neuron_cosine") or 0.0
top_near = report[emotion]["nearest_concepts"]
nearest_cos = top_near[0][1] if top_near else 0.0
lc_r2 = 0.0
lc_entry = report[emotion].get("linear_combination", {})
if str(mid_layer) in lc_entry:
lc_r2 = lc_entry[str(mid_layer)]["r_squared"]
if lc_r2 > 0.9:
redundant.append(emotion)
if nearest_cos > 0.8:
contaminated.append(emotion)
elif v > 0.7 and abs(nb) > 0.6:
@@ -914,12 +994,15 @@ def main() -> None:
f" clean (single-neuron): {len(clean_single_neuron)}\n"
f" clean (low-dim circuit): {len(clean_circuit)}\n"
f" fragmented (first-PC < 0.4): {len(fragmented)}\n"
f" contaminated (nearest > 0.8): {len(contaminated)}"
f" contaminated (nearest > 0.8): {len(contaminated)}\n"
f" redundant (R² > 0.9 vs. others): {len(redundant)}"
)
if fragmented:
print(f" fragmented sample: {fragmented[:5]}")
if contaminated:
print(f" contaminated sample: {contaminated[:5]}")
if redundant:
print(f" redundant sample: {redundant[:5]}")
print(f"\nWrote quality.json to {output_dir}")
del model