diff --git a/training/amygdala_training/train_steering_vectors.py b/training/amygdala_training/train_steering_vectors.py
index 33244c8..5584e58 100644
--- a/training/amygdala_training/train_steering_vectors.py
+++ b/training/amygdala_training/train_steering_vectors.py
@@ -590,6 +590,64 @@ def _compute_quality_report(
     return report
 
 
+def _compute_linear_combinations(
+    emotions: list[str],
+    per_layer_vectors: torch.Tensor,  # [n_layers, n_concepts, hidden], unit-normed
+    target_layers: list[int],
+    *,
+    ridge_lambda: float = 0.01,
+    top_k: int = 5,
+) -> dict:
+    """For each concept, ridge-regress its direction against all other
+    concept directions. Report R² (how much of the target direction is
+    explained by a linear combination of others) + top contributors.
+
+    R² > 0.9 = concept is essentially a linear combination of others
+    (redundant, or part of a cluster that needs disambiguating)
+    R² < 0.5 = concept has a substantial unique component
+    ridge_lambda keeps the coefficients stable when concepts are near-collinear.
+    """
+    n_layers, n_concepts, hidden = per_layer_vectors.shape
+    result: dict[str, dict] = {}
+
+    for l_idx, target_l in enumerate(target_layers):
+        V = per_layer_vectors[l_idx]  # [n_concepts, hidden]
+
+        for i, name in enumerate(emotions):
+            target = V[i]  # [hidden]
+            mask = torch.arange(n_concepts, device=V.device) != i
+            others = V[mask]  # [n-1, hidden]
+
+            # Ridge normal equations: solve (O O^T + lam I) alpha = O t
+            OOt = others @ others.t()  # [n-1, n-1]
+            b = others @ target  # [n-1]
+            A = OOt + ridge_lambda * torch.eye(n_concepts - 1, dtype=OOt.dtype, device=OOt.device)
+            alpha = torch.linalg.solve(A, b)
+
+            recon = others.t() @ alpha  # [hidden]
+            resid = target - recon
+            t_sq = (target * target).sum().clamp_min(1e-12)
+            r2 = 1.0 - (resid * resid).sum() / t_sq
+
+            abs_alpha = alpha.abs()
+            k = min(top_k, n_concepts - 1)
+            _, top_idxs = abs_alpha.topk(k)
+            other_names = [emotions[j] for j in range(n_concepts) if j != i]
+            top = [
+                [other_names[int(j)], float(alpha[j])]
+                for j in top_idxs
+            ]
+
+            entry = result.setdefault(name, {})
+            entry.setdefault("per_layer", {})[str(target_l)] = {
+                "r_squared": float(r2),
+                "residual_norm": float(resid.norm()),
+                "top_contributors": top,
+            }
+
+    return result
+
+
 def main() -> None:
     ap = argparse.ArgumentParser(description=__doc__)
     ap.add_argument("--model", required=True, help="HF model id or path")
@@ -884,6 +942,19 @@ def main() -> None:
             "per-head analysis skipped."
        )
 
+    # Linear combinations — for each concept, how much of its direction
+    # is explained by a ridge regression on the others. R² > 0.9 flags
+    # concepts that are essentially linear combinations of their peers
+    # (useful for teasing apart near-duplicate clusters).
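+    # Results land in quality.json, per layer, under each concept's
+    # "linear_combination" key.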
+ print("\nComputing linear-combination analysis...") + lincomb = _compute_linear_combinations( + emotions, per_layer_vectors, target_layers + ) + for emotion, lc in lincomb.items(): + if emotion in report: + report[emotion]["linear_combination"] = lc["per_layer"] + (output_dir / "quality.json").write_text( json.dumps(report, indent=2) + "\n" ) @@ -893,6 +966,7 @@ def main() -> None: clean_circuit = [] fragmented = [] contaminated = [] + redundant = [] # R² > 0.9 — concept is near-linear combo of others mid = n_layers // 2 mid_layer = target_layers[mid] for emotion in emotions: @@ -901,6 +975,12 @@ def main() -> None: nb = per_l.get("best_neuron_cosine") or 0.0 top_near = report[emotion]["nearest_concepts"] nearest_cos = top_near[0][1] if top_near else 0.0 + lc_r2 = 0.0 + lc_entry = report[emotion].get("linear_combination", {}) + if str(mid_layer) in lc_entry: + lc_r2 = lc_entry[str(mid_layer)]["r_squared"] + if lc_r2 > 0.9: + redundant.append(emotion) if nearest_cos > 0.8: contaminated.append(emotion) elif v > 0.7 and abs(nb) > 0.6: @@ -914,12 +994,15 @@ def main() -> None: f" clean (single-neuron): {len(clean_single_neuron)}\n" f" clean (low-dim circuit): {len(clean_circuit)}\n" f" fragmented (first-PC < 0.4): {len(fragmented)}\n" - f" contaminated (nearest > 0.8): {len(contaminated)}" + f" contaminated (nearest > 0.8): {len(contaminated)}\n" + f" redundant (R² > 0.9 vs. others): {len(redundant)}" ) if fragmented: print(f" fragmented sample: {fragmented[:5]}") if contaminated: print(f" contaminated sample: {contaminated[:5]}") + if redundant: + print(f" redundant sample: {redundant[:5]}") print(f"\nWrote quality.json to {output_dir}") del model