amygdala: linear-combination analysis per concept
For each concept vector, ridge-regress against all other concept vectors.
R² quantifies how much of the direction is explained by a linear
combination of peers — useful for teasing out near-duplicate clusters
(the content/cozy/sensual trio from the first L63 run is likely one or
two "degrees of freedom" wearing three names).

Coefficient output: top-5 contributing concepts with signed weights.
Contributors with large opposite-sign weights mean the target is "what
makes X different from Y."

Adds a 'redundant' triage bucket for concepts with R² > 0.9 — candidates
for consolidation, or for writing more discriminative training stories.
A summary is printed at the end.

Ridge lambda defaults to 0.01: large enough to keep coefficients stable
when concepts are near-collinear, small enough not to meaningfully
affect well-separated concepts.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
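For intuition, a minimal standalone sketch of the same dual-form ridge solve on toy data. Everything here is illustrative and not part of the commit: vector names, dimensions, and the ridge_r2 helper are made up, but the math matches the solve in the diff below.

import torch

torch.manual_seed(0)
hidden = 64

def unit(v: torch.Tensor) -> torch.Tensor:
    return v / v.norm()

a = unit(torch.randn(hidden))
b = unit(torch.randn(hidden))
blend = unit(0.7 * a + 0.3 * b)   # deliberately a linear combo of a and b
lone = unit(torch.randn(hidden))  # independent direction

def ridge_r2(target, others, lam=0.01):
    # Dual form: solve (O O^T + lam I) alpha = O t, reconstruct, score R².
    A = others @ others.t() + lam * torch.eye(others.shape[0])
    alpha = torch.linalg.solve(A, others @ target)
    resid = target - others.t() @ alpha
    r2 = 1.0 - resid.square().sum() / target.square().sum().clamp_min(1e-12)
    return float(r2), alpha

r2_blend, _ = ridge_r2(blend, torch.stack([a, b, lone]))
r2_lone, _ = ridge_r2(lone, torch.stack([a, b, blend]))
print(f"blend R2: {r2_blend:.3f}")  # ~1.0 -> would land in the 'redundant' bucket
print(f"lone  R2: {r2_lone:.3f}")   # small -> substantial unique component

# Opposite-sign weights read as "what makes a different from b":
_, alpha = ridge_r2(unit(a - b), torch.stack([a, b]))
print(alpha)  # roughly [+0.7, -0.7]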
This commit is contained in:
parent f4fb6db1ee
commit 1d2c0f382c
1 changed file with 84 additions and 1 deletion
@@ -590,6 +590,67 @@ def _compute_quality_report(
     return report
 
 
+def _compute_linear_combinations(
+    emotions: list[str],
+    per_layer_vectors: torch.Tensor,  # [n_layers, n_concepts, hidden], unit-normed
+    target_layers: list[int],
+    *,
+    ridge_lambda: float = 0.01,
+    top_k: int = 5,
+) -> dict:
+    """For each concept, ridge-regress its direction against all other
+    concept directions. Report R² (how much of the target direction is
+    explained by a linear combination of others) + top contributors.
+
+    R² > 0.9 = concept is essentially a linear combination of others
+    (redundant, or part of a cluster that needs disambiguating)
+    R² < 0.5 = concept has a substantial unique component
+    ridge_lambda keeps the coefficients stable when concepts are near-collinear.
+    """
+    n_layers, n_concepts, hidden = per_layer_vectors.shape
+    result: dict[str, dict] = {}
+
+    # Middle layer for summary — same convention as nearest_concepts.
+    mid = n_layers // 2
+
+    for l_idx, target_l in enumerate(target_layers):
+        V = per_layer_vectors[l_idx]  # [n_concepts, hidden]
+
+        for i, name in enumerate(emotions):
+            target = V[i]  # [hidden]
+            mask = torch.arange(n_concepts) != i
+            others = V[mask]  # [n-1, hidden]
+
+            # Ridge: solve (O O^T + lam I) alpha = O t
+            OOt = others @ others.t()  # [n-1, n-1]
+            b = others @ target  # [n-1]
+            A = OOt + ridge_lambda * torch.eye(n_concepts - 1, dtype=OOt.dtype, device=OOt.device)
+            alpha = torch.linalg.solve(A, b)
+
+            recon = others.t() @ alpha  # [hidden]
+            resid = target - recon
+            t_sq = (target * target).sum().clamp_min(1e-12)
+            r2 = 1.0 - (resid * resid).sum() / t_sq
+
+            abs_alpha = alpha.abs()
+            k = min(top_k, n_concepts - 1)
+            _, top_idxs = abs_alpha.topk(k)
+            other_names = [emotions[j] for j in range(n_concepts) if j != i]
+            top = [
+                [other_names[int(j)], float(alpha[j])]
+                for j in top_idxs
+            ]
+
+            entry = result.setdefault(name, {})
+            entry.setdefault("per_layer", {})[str(target_l)] = {
+                "r_squared": float(r2),
+                "residual_norm": float(resid.norm()),
+                "top_contributors": top,
+            }
+
+    return result
+
+
 def main() -> None:
     ap = argparse.ArgumentParser(description=__doc__)
     ap.add_argument("--model", required=True, help="HF model id or path")
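To make the output schema concrete, a hypothetical smoke test for the helper above. The concept names echo the trio from the commit message; the layer indices, shapes, and variable names are made up for illustration.

import torch
import torch.nn.functional as F

emotions = ["content", "cozy", "sensual"]
# [n_layers=2, n_concepts=3, hidden=16], unit-normed as the signature expects
vecs = F.normalize(torch.randn(2, 3, 16), dim=-1)
out = _compute_linear_combinations(emotions, vecs, target_layers=[10, 20])

entry = out["cozy"]["per_layer"]["10"]
print(entry["r_squared"])         # how much of 'cozy' the other two explain
print(entry["top_contributors"])  # [[name, signed weight], ...]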
@@ -884,6 +945,18 @@ def main() -> None:
         "per-head analysis skipped."
     )
 
+    # Linear combinations — for each concept, how much of its direction
+    # is explained by a ridge regression on the others. R² > 0.9 flags
+    # concepts that are essentially linear combinations of their peers
+    # (useful for teasing apart near-duplicate clusters).
+    print("\nComputing linear-combination analysis...")
+    lincomb = _compute_linear_combinations(
+        emotions, per_layer_vectors, target_layers
+    )
+    for emotion, lc in lincomb.items():
+        if emotion in report:
+            report[emotion]["linear_combination"] = lc["per_layer"]
+
     (output_dir / "quality.json").write_text(
         json.dumps(report, indent=2) + "\n"
     )
@@ -893,6 +966,7 @@ def main() -> None:
     clean_circuit = []
     fragmented = []
     contaminated = []
+    redundant = []  # R² > 0.9 — concept is near-linear combo of others
     mid = n_layers // 2
     mid_layer = target_layers[mid]
     for emotion in emotions:
@@ -901,6 +975,12 @@ def main() -> None:
         nb = per_l.get("best_neuron_cosine") or 0.0
         top_near = report[emotion]["nearest_concepts"]
         nearest_cos = top_near[0][1] if top_near else 0.0
+        lc_r2 = 0.0
+        lc_entry = report[emotion].get("linear_combination", {})
+        if str(mid_layer) in lc_entry:
+            lc_r2 = lc_entry[str(mid_layer)]["r_squared"]
+        if lc_r2 > 0.9:
+            redundant.append(emotion)
         if nearest_cos > 0.8:
             contaminated.append(emotion)
         elif v > 0.7 and abs(nb) > 0.6:
@@ -914,12 +994,15 @@ def main() -> None:
         f" clean (single-neuron): {len(clean_single_neuron)}\n"
         f" clean (low-dim circuit): {len(clean_circuit)}\n"
         f" fragmented (first-PC < 0.4): {len(fragmented)}\n"
-        f" contaminated (nearest > 0.8): {len(contaminated)}"
+        f" contaminated (nearest > 0.8): {len(contaminated)}\n"
+        f" redundant (R² > 0.9 vs. others): {len(redundant)}"
     )
     if fragmented:
         print(f" fragmented sample: {fragmented[:5]}")
     if contaminated:
         print(f" contaminated sample: {contaminated[:5]}")
+    if redundant:
+        print(f" redundant sample: {redundant[:5]}")
     print(f"\nWrote quality.json to {output_dir}")
 
     del model