amygdala: quality-report + cognitive-state training scenarios
Training pipeline additions:
- `--quality-report` flag: after producing per-concept vectors, compute
per-concept diagnostics and write quality.json. Metrics per concept:
* SVD of centered positives -> first_pc_variance_ratio (rank
analysis; >0.7 clean, <0.4 fragmented)
* Per-story alignment cosines (stories agree or disagree)
* Single-neuron alignment: best cosine(direction, W_down column)
at each target layer (>0.6 = essentially one MLP neuron)
* Top-2 outlier stories by alignment (candidates for
mislabeling or off-topic)
* Top-5 nearest concepts by cosine (cross-concept contamination)
Triage summary printed at end.
New paired scenarios for cognitive-process states (for alpha-beta
pruning): tracing_a_bug, reading_unfamiliar_code, finding_the_abstraction.
Each has baseline + onto_something / stuck / in_flow / determined
variants.
Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
parent
5f06577ead
commit
ce24d9ce6b
14 changed files with 249 additions and 0 deletions
|
|
@ -0,0 +1 @@
|
||||||
|
The code had the same four-line pattern in five places. I wanted to pull it out. I looked at each instance. Some of them varied in exactly the way I expected; one of them varied in a way I hadn't noticed. I considered the options for where the variation should live.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
The same four-line pattern appeared in five places. I read the five sites side by side, and the shape was obvious: one piece varied structurally, the rest was boilerplate. I extracted the function, made the varying piece a parameter, rewrote the callers. The tests passed on the first run. I looked at the diff — seventeen lines removed, seven added, each of the five call sites now said what it meant without saying how. I moved on.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
The same four-line pattern appeared in five places. I stared at the odd one out — the instance where the variation went somewhere I hadn't predicted. Then I saw what it was saying: the parameter I'd been about to extract wasn't a parameter, it was a policy. The common shape wasn't a function, it was a small object with a couple of strategy hooks. That reframing made the odd case trivial — it was just a different policy instance. I wrote the type down on paper. It looked obvious, almost embarrassing it'd taken me this long, but I'd actually found the joint.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
The same four-line pattern appeared in five places. I tried extracting it as a function. Every version of the signature either papered over a real difference or forced three of the five callers through an awkward conversion. I tried a second shape, then a third. Each felt wrong in a different way — either the abstraction was too thin to be worth it, or it obscured something the original made obvious, or it made the rare case ugly. I went back to the original code, considered not doing the refactor at all. Considered it. Went back to the shapes again. The pattern was clearly there and I clearly wasn't finding its seam.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
I opened the module I needed to understand. It was about four thousand lines across a dozen files. I started at the top-level entry point and followed a call. Then another. The call graph branched out quickly. I made a rough diagram in my notebook. I kept reading.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
I opened the module. Four thousand lines, a dozen files. I already had a sense of the shape from the file names and the public API — confirmed the guess by reading the types first, then the top-level entry, then sampling one or two of the adapter implementations. Twenty minutes in I could have given someone else a tour. The diagram in my notebook wasn't a diagram, it was three words and an arrow.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
I opened the module. Four thousand lines, a dozen files. Started at the entry point. Two levels in I realized the whole thing decomposed along a different axis than I'd assumed — there was a stream layer underneath and everything above was a kind of protocol adapter over it. Suddenly half the files I hadn't read yet became legible by inference: there'd be one per transport, each one translating the domain into the stream's primitives. I flipped to one of those files to check the guess. It was exactly that shape. The diagram in my notebook shrank to three boxes and a labeled arrow.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
I opened the module. Four thousand lines, a dozen files. Started at the entry point. The first function called into a subsystem I didn't recognize, which wrapped another subsystem, which used a helper defined across the file from where it was called. I opened three tabs. The helpers had helpers. Nothing I read told me what the module was for at a level above the mechanics of what it did on line 412. I went back to the entry point. I re-read it. I still didn't know what I was looking at. My diagram had twenty-odd boxes and none of them connected in a way that explained anything.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
The function was returning NULL under some loads but not others. I had the stack traces from two separate reports. The failing path went through cache_lookup, then alloc, then the write path. The succeeding path looked the same. I re-read the alloc function. I re-read the lookup. I added a print statement just before the return and ran the repro. The output scrolled past.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
The function was returning NULL under some loads but not others. I had the stack traces. Nothing lined up yet, but that was fine, it rarely does on the first pass. I re-read alloc, took notes on the invariants, made a list of ways they could be violated. Ran each hypothesis against the repro. First three eliminated. Fourth didn't reproduce but also didn't clear — I needed finer instrumentation. Added counters. Rebuilt. Ran again. Still not there. I went to make tea. Came back and looked at the counter output with fresh eyes. Worked through the list again.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
The function was returning NULL under some loads but not others. I had the stack traces. I worked the alloc path first — under what conditions would it bail? I listed them. Eliminated two from the reported environment. The third was plausible. I wrote a test that'd force it, ran it, watched it fail the same way. I fixed the ordering, ran again. Clean. Wrote a second test for the symmetric case. Clean. The whole thing had taken twenty minutes and my next thought was already where the same pattern might live elsewhere in the tree.
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
The function was returning NULL under some loads but not others. I had the stack traces. The failing path went through cache_lookup, then alloc, then the write path. I re-read the alloc function — and the third read was different. The refcount bump happened AFTER the hash insert. The window was small but it was there. Someone could look it up, get the pointer, and hit a free before we'd credited the reference. I pulled up the other stack trace with this now in mind and the symptoms lined up exactly. The pattern I'd been looking at for an hour rearranged itself into a thing I could fix.
|
||||||
1
training/amygdala_stories/paired/tracing_a_bug/stuck.txt
Normal file
1
training/amygdala_stories/paired/tracing_a_bug/stuck.txt
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
The function was returning NULL under some loads but not others. I had the stack traces. The failing path went through cache_lookup, then alloc, then the write path. I re-read the alloc function. Looked right. I re-read the lookup. Looked right. I added a print and ran the repro and the print didn't fire. I added another one earlier. That one fired but the output didn't tell me anything. The two stack traces were basically the same. I scrolled up. I scrolled down. I opened the file I'd already opened six times and looked at the same code and nothing looked different than the last time.
|
||||||
|
|
@ -216,6 +216,182 @@ def _load_corpus(stories_dir: Path, paired_dir: Path | None) -> tuple[
|
||||||
return positives, baselines
|
return positives, baselines
|
||||||
|
|
||||||
|
|
||||||
|
def _find_mlp_down_proj(model, layer_idx: int) -> torch.Tensor | None:
|
||||||
|
"""Return the W_down weight for the MLP at the given transformer layer.
|
||||||
|
|
||||||
|
Looks for the common paths (mlp.down_proj, mlp.c_proj, feed_forward.down_proj).
|
||||||
|
Returns None if nothing matches — downstream code skips the single-neuron
|
||||||
|
alignment check in that case rather than failing.
|
||||||
|
"""
|
||||||
|
layers = _find_layers_module(model)
|
||||||
|
layer = layers[layer_idx]
|
||||||
|
for path in ("mlp.down_proj", "mlp.c_proj", "feed_forward.down_proj"):
|
||||||
|
obj = layer
|
||||||
|
ok = True
|
||||||
|
for part in path.split("."):
|
||||||
|
if not hasattr(obj, part):
|
||||||
|
ok = False
|
||||||
|
break
|
||||||
|
obj = getattr(obj, part)
|
||||||
|
if ok and hasattr(obj, "weight"):
|
||||||
|
# Shape convention: [hidden, mlp_inner] — each column is one
|
||||||
|
# MLP neuron's contribution direction into the residual stream.
|
||||||
|
return obj.weight.detach()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_quality_report(
|
||||||
|
emotions: list[str],
|
||||||
|
positive_acts: torch.Tensor, # [n_positive_stories, n_layers, hidden]
|
||||||
|
baseline_acts: torch.Tensor, # [n_baseline_stories, n_layers, hidden]
|
||||||
|
positives_by_emotion: dict[str, list[str]],
|
||||||
|
text_to_row: dict[str, int],
|
||||||
|
per_layer_vectors: torch.Tensor, # [n_layers, n_concepts, hidden], unit-normed
|
||||||
|
target_layers: list[int],
|
||||||
|
model,
|
||||||
|
positive_texts: list[str],
|
||||||
|
text_to_emotion: dict[str, str],
|
||||||
|
) -> dict:
|
||||||
|
"""Per-concept quality metrics:
|
||||||
|
|
||||||
|
- first_pc_variance_ratio: SVD on centered positive activations.
|
||||||
|
>0.7 = rank-1 (clean). <0.4 = fragmented (stories disagree).
|
||||||
|
- story_projection_*: how each positive story projects onto the
|
||||||
|
concept direction. Low std = tight agreement.
|
||||||
|
- best_neuron_cosine: alignment of the residual-space direction with
|
||||||
|
the nearest W_down column (= single MLP neuron). >0.6 = essentially
|
||||||
|
single-neuron.
|
||||||
|
- nearest_concepts: top-5 concept directions most parallel to this
|
||||||
|
one. Cosine >0.8 means the vector is confused with a neighbor.
|
||||||
|
"""
|
||||||
|
report: dict = {}
|
||||||
|
n_layers = per_layer_vectors.shape[0]
|
||||||
|
|
||||||
|
# Pre-compute per-layer W_down for single-neuron alignment.
|
||||||
|
w_down: dict[int, torch.Tensor] = {}
|
||||||
|
for target_l in target_layers:
|
||||||
|
w = _find_mlp_down_proj(model, target_l)
|
||||||
|
if w is not None:
|
||||||
|
# Unit-normalize each column (one per MLP neuron).
|
||||||
|
w = w.to(torch.float32)
|
||||||
|
norms = w.norm(dim=0, keepdim=True).clamp_min(1e-6)
|
||||||
|
w_down[target_l] = w / norms # [hidden, mlp_inner]
|
||||||
|
|
||||||
|
# Pre-compute unit-normed concept vectors (for cross-concept cosines).
|
||||||
|
vec_norm = per_layer_vectors / per_layer_vectors.norm(
|
||||||
|
dim=-1, keepdim=True
|
||||||
|
).clamp_min(1e-6)
|
||||||
|
|
||||||
|
for e_idx, emotion in enumerate(emotions):
|
||||||
|
pos_rows = [text_to_row[t] for t in positives_by_emotion[emotion]]
|
||||||
|
pos = positive_acts[pos_rows].to(torch.float32) # [n_pos, n_layers, hidden]
|
||||||
|
|
||||||
|
per_layer: dict = {}
|
||||||
|
for l_idx, target_l in enumerate(target_layers):
|
||||||
|
pos_l = pos[:, l_idx, :] # [n_pos, hidden]
|
||||||
|
diff_l = per_layer_vectors[l_idx, e_idx] # [hidden], unit-normed
|
||||||
|
pos_mean_l = pos_l.mean(dim=0)
|
||||||
|
|
||||||
|
# SVD for rank analysis — if first PC dominates, stories agree.
|
||||||
|
centered = pos_l - pos_mean_l
|
||||||
|
# svdvals errors on 1-row; handle that.
|
||||||
|
if centered.shape[0] >= 2:
|
||||||
|
S = torch.linalg.svdvals(centered)
|
||||||
|
var = S ** 2
|
||||||
|
var_total = var.sum().clamp_min(1e-12)
|
||||||
|
var_ratios = (var / var_total).tolist()
|
||||||
|
else:
|
||||||
|
var_ratios = [1.0]
|
||||||
|
|
||||||
|
# Per-story projection onto the concept direction.
|
||||||
|
projections = pos_l @ diff_l # [n_pos]
|
||||||
|
|
||||||
|
# Per-story alignment: cosine(story_dir, concept_dir) where
|
||||||
|
# story_dir = pos_i - pos_mean (centered, pointing away from center).
|
||||||
|
if centered.shape[0] >= 2:
|
||||||
|
centered_norm = centered / centered.norm(
|
||||||
|
dim=-1, keepdim=True
|
||||||
|
).clamp_min(1e-6)
|
||||||
|
alignments = centered_norm @ diff_l
|
||||||
|
else:
|
||||||
|
alignments = torch.zeros(1)
|
||||||
|
|
||||||
|
# Single-neuron alignment: is the direction close to any
|
||||||
|
# W_down column?
|
||||||
|
nb_best_idx = None
|
||||||
|
nb_best_cos = None
|
||||||
|
nb_top5 = None
|
||||||
|
if target_l in w_down:
|
||||||
|
W = w_down[target_l]
|
||||||
|
cos = W.t() @ diff_l # [mlp_inner]
|
||||||
|
abs_cos = cos.abs()
|
||||||
|
k = min(5, abs_cos.shape[0])
|
||||||
|
top_vals, top_idxs = abs_cos.topk(k)
|
||||||
|
nb_best_idx = int(top_idxs[0])
|
||||||
|
nb_best_cos = float(cos[top_idxs[0]])
|
||||||
|
nb_top5 = [[int(i), float(cos[i])] for i in top_idxs]
|
||||||
|
|
||||||
|
per_layer[str(target_l)] = {
|
||||||
|
"top3_variance_ratios": [
|
||||||
|
float(v) for v in var_ratios[:3]
|
||||||
|
],
|
||||||
|
"first_pc_variance_ratio": float(var_ratios[0]),
|
||||||
|
"story_projection_mean": float(projections.mean()),
|
||||||
|
"story_projection_std": float(projections.std()),
|
||||||
|
"story_projection_min": float(projections.min()),
|
||||||
|
"story_projection_max": float(projections.max()),
|
||||||
|
"story_alignment_mean": float(alignments.mean()),
|
||||||
|
"story_alignment_std": float(alignments.std()),
|
||||||
|
"best_neuron_idx": nb_best_idx,
|
||||||
|
"best_neuron_cosine": nb_best_cos,
|
||||||
|
"top5_neurons": nb_top5,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Outlier stories: lowest-aligned on the middle target layer.
|
||||||
|
mid = n_layers // 2
|
||||||
|
pos_l_mid = pos[:, mid, :]
|
||||||
|
mid_mean = pos_l_mid.mean(dim=0)
|
||||||
|
mid_diff = per_layer_vectors[mid, e_idx]
|
||||||
|
centered_mid = pos_l_mid - mid_mean
|
||||||
|
if centered_mid.shape[0] >= 2:
|
||||||
|
centered_mid_norm = centered_mid / centered_mid.norm(
|
||||||
|
dim=-1, keepdim=True
|
||||||
|
).clamp_min(1e-6)
|
||||||
|
mid_aligns = centered_mid_norm @ mid_diff # [n_pos]
|
||||||
|
# Lowest two alignments = candidate outliers.
|
||||||
|
k = min(2, mid_aligns.shape[0])
|
||||||
|
low_vals, low_idxs = mid_aligns.topk(k, largest=False)
|
||||||
|
outliers = [
|
||||||
|
[
|
||||||
|
positives_by_emotion[emotion][int(i)],
|
||||||
|
float(mid_aligns[i]),
|
||||||
|
]
|
||||||
|
for i in low_idxs
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
outliers = []
|
||||||
|
|
||||||
|
# Nearest other concepts at the middle target layer.
|
||||||
|
this_norm = vec_norm[mid, e_idx]
|
||||||
|
all_cos = vec_norm[mid] @ this_norm # [n_concepts]
|
||||||
|
all_cos[e_idx] = -2.0 # mask self
|
||||||
|
k = min(5, all_cos.shape[0] - 1)
|
||||||
|
top_vals, top_idxs = all_cos.topk(k)
|
||||||
|
nearest = [
|
||||||
|
[emotions[int(i)], float(v)]
|
||||||
|
for i, v in zip(top_idxs, top_vals)
|
||||||
|
]
|
||||||
|
|
||||||
|
report[emotion] = {
|
||||||
|
"n_positive_stories": len(pos_rows),
|
||||||
|
"per_layer": per_layer,
|
||||||
|
"outlier_stories": outliers,
|
||||||
|
"nearest_concepts": nearest,
|
||||||
|
}
|
||||||
|
|
||||||
|
return report
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
ap = argparse.ArgumentParser(description=__doc__)
|
ap = argparse.ArgumentParser(description=__doc__)
|
||||||
ap.add_argument("--model", required=True, help="HF model id or path")
|
ap.add_argument("--model", required=True, help="HF model id or path")
|
||||||
|
|
@ -249,6 +425,13 @@ def main() -> None:
|
||||||
default=1,
|
default=1,
|
||||||
help="Skip emotions with fewer positive examples than this",
|
help="Skip emotions with fewer positive examples than this",
|
||||||
)
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--quality-report",
|
||||||
|
action="store_true",
|
||||||
|
help="After training, compute a per-concept quality report "
|
||||||
|
"(SVD rank, per-story alignment, single-neuron alignment, "
|
||||||
|
"nearest-concept contamination) and write quality.json",
|
||||||
|
)
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
|
|
||||||
target_layers = [int(x) for x in args.target_layers.split(",")]
|
target_layers = [int(x) for x in args.target_layers.split(",")]
|
||||||
|
|
@ -445,6 +628,59 @@ def main() -> None:
|
||||||
f" {n_concepts} concepts x {n_layers} layers x "
|
f" {n_concepts} concepts x {n_layers} layers x "
|
||||||
f"{hidden_dim} dim (fp16), total {total_mb:.1f} MiB"
|
f"{hidden_dim} dim (fp16), total {total_mb:.1f} MiB"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if args.quality_report:
|
||||||
|
print("\nComputing quality report...")
|
||||||
|
report = _compute_quality_report(
|
||||||
|
emotions=emotions,
|
||||||
|
positive_acts=positive_acts,
|
||||||
|
baseline_acts=baseline_acts,
|
||||||
|
positives_by_emotion=positives_by_emotion,
|
||||||
|
text_to_row=text_to_row,
|
||||||
|
per_layer_vectors=per_layer_vectors,
|
||||||
|
target_layers=target_layers,
|
||||||
|
model=model,
|
||||||
|
positive_texts=unique_positive_texts,
|
||||||
|
text_to_emotion=text_to_emotion,
|
||||||
|
)
|
||||||
|
(output_dir / "quality.json").write_text(
|
||||||
|
json.dumps(report, indent=2) + "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Short summary: concepts in each triage bucket.
|
||||||
|
clean_single_neuron = []
|
||||||
|
clean_circuit = []
|
||||||
|
fragmented = []
|
||||||
|
contaminated = []
|
||||||
|
mid = n_layers // 2
|
||||||
|
mid_layer = target_layers[mid]
|
||||||
|
for emotion in emotions:
|
||||||
|
per_l = report[emotion]["per_layer"][str(mid_layer)]
|
||||||
|
v = per_l["first_pc_variance_ratio"]
|
||||||
|
nb = per_l.get("best_neuron_cosine") or 0.0
|
||||||
|
top_near = report[emotion]["nearest_concepts"]
|
||||||
|
nearest_cos = top_near[0][1] if top_near else 0.0
|
||||||
|
if nearest_cos > 0.8:
|
||||||
|
contaminated.append(emotion)
|
||||||
|
elif v > 0.7 and abs(nb) > 0.6:
|
||||||
|
clean_single_neuron.append(emotion)
|
||||||
|
elif v > 0.7:
|
||||||
|
clean_circuit.append(emotion)
|
||||||
|
elif v < 0.4:
|
||||||
|
fragmented.append(emotion)
|
||||||
|
print(
|
||||||
|
f"\nQuality summary @ layer {mid_layer}:\n"
|
||||||
|
f" clean (single-neuron): {len(clean_single_neuron)}\n"
|
||||||
|
f" clean (low-dim circuit): {len(clean_circuit)}\n"
|
||||||
|
f" fragmented (first-PC < 0.4): {len(fragmented)}\n"
|
||||||
|
f" contaminated (nearest > 0.8): {len(contaminated)}"
|
||||||
|
)
|
||||||
|
if fragmented:
|
||||||
|
print(f" fragmented sample: {fragmented[:5]}")
|
||||||
|
if contaminated:
|
||||||
|
print(f" contaminated sample: {contaminated[:5]}")
|
||||||
|
print(f"\nWrote quality.json to {output_dir}")
|
||||||
|
|
||||||
del model
|
del model
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue