consciousness/sa-schedule-topblock-swap.py
Kent Overstreet 4225294d16 replace try_lock() with lock_blocking() across UI thread
Add lock_blocking() to TrackedMutex: blocks current thread using
block_in_place + futures::executor::block_on, safe for sync contexts.

Replace all try_lock() calls with lock_blocking() in slash commands,
UI rendering, and status reads. Lock hold times are fast enough that
blocking briefly is fine, and this eliminates the spurious 'lock
unavailable' paths that were never actually hit.

Kept rx_mutex.try_lock() in mod.rs (std::sync::Mutex for stderr rx).
2026-04-25 15:35:14 -04:00

498 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Top-block replacement experiment: test the SA-schedule hypothesis by
replacing the top block of decoder layers (by default layers 28-35, the last
8 layers of Qwen3-4B) with variants that progressively strip out the learned
schedule / specialization.

Configuration (environment variables):
    MODEL                   model to load (default "Qwen/Qwen3-4B").
    SCHEDULE_A, SCHEDULE_B  fitted γ schedule γ(L) = A · exp(B · L)
                            (defaults 3.53, 0.1191 — the Qwen3-4B fit).
    BLOCK_START, BLOCK_END  inclusive single block (default 28-35).
    BLOCKS                  optional "s1-e1,s2-e2,..." ranges used by the
                            multi-block variants (merged_op, aligned_merged_op,
                            flat_merged_op, reverse_order).
    TIES_DENSITY            trim density for ties_op (default 0.2).

Variants:
    baseline          — unmodified reference (PPL sanity check).
    schedule_fit      — rescale each block layer's input_ln.γ to the fitted
                        magnitude γ(L). Directions preserved, projection
                        weights untouched.
    single_op         — copy the block's middle layer's projection weights
                        (and q_norm/k_norm) into the other block layers (strip
                        specialization), combined with the fitted schedule
                        γ(L). Tests whether per-layer specialization in the
                        top block is load-bearing or replaceable by schedule.
    uniform_gamma     — copy the middle block layer's whole input_ln.γ vector
                        (direction and magnitude) into every block layer, so
                        there is no schedule at all in the block. Tests the
                        necessity of the schedule itself.
    merged_op         — arithmetic mean of the projection weights over each
                        block, input_ln.γ rescaled to γ(L).
    aligned_merged_op — like merged_op, but attention heads are first
                        Procrustes-aligned to the block's middle layer.
    flat_merged_op    — mean projections AND mean input / post-attention γ
                        over each block (no schedule).
    ties_op           — TIES-style trim / elect-sign / disjoint merge of the
                        projections, input_ln.γ rescaled to γ(L).
    reverse_order     — reverse the layer order within each block.
`--variant all` runs baseline, schedule_fit, single_op and uniform_gamma.

Eval: token-weighted perplexity over the short CALIB prompts (each prompt is
scored independently), plus greedy generations for a handful of diagnostic
prompts.
"""
import argparse
import math
import os
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
# From sa-schedule-fit-gamma.py on Qwen3-4B null-residual data:
# input_ln.γ magnitude ≈ 3.53 · exp(0.119 · L), R² = 0.95
# Defaults for 4B. Override via env SCHEDULE_A / SCHEDULE_B for other models.
# 32B fit: a=1.02, b=0.0873
SCHEDULE_A = float(os.environ.get("SCHEDULE_A", "3.53")) if "SCHEDULE_A" in os.environ else 3.53
SCHEDULE_B = float(os.environ.get("SCHEDULE_B", "0.1191")) if "SCHEDULE_B" in os.environ else 0.1191
BLOCK_START = int(os.environ.get("BLOCK_START", 28))
BLOCK_END = int(os.environ.get("BLOCK_END", 35))
# Optional: comma-separated "s1-e1,s2-e2,..." blocks for multi-block merge
BLOCKS_ENV = os.environ.get("BLOCKS", "")
if BLOCKS_ENV:
BLOCKS = [tuple(int(x) for x in p.split("-")) for p in BLOCKS_ENV.split(",")]
else:
BLOCKS = [(BLOCK_START, BLOCK_END)]
# Perplexity prompts: short, mixed-domain openings (facts, arithmetic, code,
# SQL, narrative, chat-style Q&A). perplexity() tokenizes and scores each
# prompt independently and weights them by token count.
CALIB = [
    "The Eiffel Tower is located in",
    "Photosynthesis is the process by which",
    "The three branches of the US government are the legislative, executive, and",
    "If a train travels 60 miles per hour for 2.5 hours, the total distance covered is",
    "Solve for x: 3x + 7 = 22. The answer is x =",
    "The derivative of x^3 + 2x^2 is",
    "def fibonacci(n):\n if n < 2:\n return n\n return",
    "# Python list comprehension to square even numbers in 0-9\nresult = ",
    "SELECT name, age FROM users WHERE",
    "She opened the old wooden box and found",
    "The argument in favor of renewable energy is",
    "User: What is the capital of Australia?\nAssistant:",
    "Write a haiku about autumn:\n",
    "Albert Einstein was born in the year",
    "The speed of light in vacuum is approximately",
    "I really loved that movie because",
    "The main difference between a virus and a bacterium is",
    "The French word for 'apple' is",
    "1 + 1 = ",
    "Once upon a time, in a land far away,",
    "The key insight of general relativity is that gravity is not a force but",
    "Water boils at 100 degrees Celsius at standard atmospheric pressure. At higher",
    "In object-oriented programming, encapsulation refers to",
    "The mitochondria is often called the powerhouse of the cell because it",
    "Shakespeare's Hamlet begins with the famous line",
]
# Diagnostic prompts whose greedy continuations are printed for every variant.
GEN_PROMPTS = [
    "The capital of France is",
    "2 + 2 =",
    "def reverse_string(s):\n return",
    "Albert Einstein developed the theory of",
]
def load_model(name=None):
    """Load a causal LM and its tokenizer for the experiment.

    Args:
        name: Hugging Face model id or path. When None, $MODEL is used,
            falling back to "Qwen/Qwen3-4B".

    Returns:
        (model, tokenizer): the model in bfloat16 on CUDA with eager
        attention, switched to eval mode.

    NOTE(review): trust_remote_code=True lets the checkpoint's own modeling
    code run — only point this at trusted model names.
    """
    name = name if name is not None else os.environ.get("MODEL", "Qwen/Qwen3-4B")
    print(f"Loading {name}...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
        attn_implementation="eager",
    )
    model = AutoModelForCausalLM.from_pretrained(name, **load_kwargs)
    model.eval()
    return model, tokenizer
def _merge_block(model, block_start, block_end):
    """Replace every projection in layers [block_start, block_end] with the block mean.

    For each attention (q/k/v/o) and MLP (gate/up/down) weight, the block's
    layers are averaged in fp32, cast back to the original dtype and copied
    into every layer of the block. Afterwards each layer's input_layernorm γ
    keeps its direction but is rescaled to the scheduled L2 norm
    SCHEDULE_A · exp(SCHEDULE_B · L). Mutates ``model`` in place.
    """
    block = range(block_start, block_end + 1)
    members = [model.model.layers[idx] for idx in block]

    def weight_of(layer, path):
        owner, proj = path.split(".")
        return getattr(getattr(layer, owner), proj).weight

    paths = (
        "self_attn.q_proj",
        "self_attn.k_proj",
        "self_attn.v_proj",
        "self_attn.o_proj",
        "mlp.gate_proj",
        "mlp.up_proj",
        "mlp.down_proj",
    )
    averaged = {}
    for path in paths:
        fp32_stack = torch.stack([weight_of(m, path).data.float() for m in members], dim=0)
        averaged[path] = fp32_stack.mean(dim=0).to(weight_of(members[0], path).data.dtype)

    for member in members:
        for path, value in averaged.items():
            weight_of(member, path).data.copy_(value)

    for idx in block:
        target_norm = SCHEDULE_A * math.exp(SCHEDULE_B * idx)
        g = model.model.layers[idx].input_layernorm.weight.data
        g.mul_(target_norm / g.norm().item())
def _procrustes(M):
"""Orthogonal R = U V^T maximizing tr(R M) where M = U Σ V^T."""
U, _, Vh = torch.linalg.svd(M.float(), full_matrices=False)
return U @ Vh
def _aligned_merge_block(model, block_start, block_end, align_ff=False):
    """Procrustes-align attention heads across a block, then mean-merge it in place.

    Layers [block_start, block_end] (inclusive) are processed as follows:

    * Attention: every layer except the reference (the block middle,
      ref_L = (block_start + block_end) // 2) is rotated in the per-head d_h
      basis towards the reference before q/k/v/o are averaged. One orthogonal
      R_g is fitted per KV group g — the KV head together with all query
      heads that share it under GQA. K and V are shared by the group, so a
      different rotation per query head would change q·k and o·v. Applying
      R_g to the q/k/v rows and R_gᵀ to the matching o columns leaves the raw
      projections' q·k and o·v products unchanged.
    * MLP: gate/up/down are averaged without alignment (a d_ff rotation does
      not commute with the elementwise SiLU gate).
    * q_norm/k_norm γ (when present) are copied from the reference layer,
      since they are per-dimension and have no clean average across the
      rotated frames.
    * Each layer's input_layernorm γ keeps its direction and is rescaled to
      the scheduled norm SCHEDULE_A · exp(SCHEDULE_B · L).

    NOTE(review): invariance is only approximate in the full model. The
    elementwise q_norm/k_norm γ, and any positional rotation (e.g. RoPE)
    applied to q/k, do not commute with a general d_h rotation.

    Args:
        model: HF-style causal LM exposing ``config`` and ``model.layers``.
        block_start, block_end: inclusive layer range to merge.
        align_ff: d_ff alignment is not implemented. True raises instead of
            being silently ignored.

    Raises:
        NotImplementedError: if ``align_ff`` is True.
        ValueError: if num_attention_heads is not a multiple of
            num_key_value_heads.
    """
    if align_ff:
        raise NotImplementedError(
            "align_ff=True is not implemented; FF weights are only averaged")
    cfg = model.config
    num_heads = cfg.num_attention_heads
    num_kv = getattr(cfg, "num_key_value_heads", num_heads)
    hidden = cfg.hidden_size
    d_h = getattr(cfg, "head_dim", hidden // num_heads)
    if num_heads % num_kv:
        raise ValueError(
            f"num_attention_heads={num_heads} is not a multiple of "
            f"num_key_value_heads={num_kv}")
    # Query heads per KV head. Query head h = g * group + j reads KV head g,
    # matching the original mapping (h * num_kv) // num_heads.
    group = num_heads // num_kv

    def attn_views(layer):
        # fp32 per-group views:
        #   Q (num_kv, group, d_h, hidden), K/V (num_kv, d_h, hidden),
        #   O (num_kv, group, hidden, d_h), i.e. o_proj's input columns split per head.
        attn = layer.self_attn
        q = attn.q_proj.weight.data.float().reshape(num_kv, group, d_h, hidden)
        k = attn.k_proj.weight.data.float().reshape(num_kv, d_h, hidden)
        v = attn.v_proj.weight.data.float().reshape(num_kv, d_h, hidden)
        o = (attn.o_proj.weight.data.float()
             .reshape(hidden, num_heads, d_h)
             .permute(1, 0, 2)
             .reshape(num_kv, group, hidden, d_h))
        return q, k, v, o

    ref_L = (block_start + block_end) // 2
    ref = model.model.layers[ref_L]
    dtype = ref.self_attn.q_proj.weight.dtype
    Qr, Kr, Vr, Or = attn_views(ref)

    rotated = []
    for L in range(block_start, block_end + 1):
        Q, K, V, O = attn_views(model.model.layers[L])
        if L == ref_L:
            Q_new, K_new, V_new, O_new = Q.clone(), K.clone(), V.clone(), O.clone()
        else:
            # Per-group cross-correlation M_g (num_kv, d_h, d_h): the sum over
            # the group's query heads of Qr·Qᵀ + Kr·Kᵀ + Vr·Vᵀ + Orᵀ·O. The shared
            # K/V terms therefore count once per query head. R_g = _procrustes(M_g)
            # minimises the summed squared misfit of the group to the reference.
            M = ((Qr @ Q.transpose(-1, -2)).sum(dim=1)
                 + (Or.transpose(-1, -2) @ O).sum(dim=1)
                 + group * (Kr @ K.transpose(-1, -2) + Vr @ V.transpose(-1, -2)))
            R = torch.stack([_procrustes(M[g]) for g in range(num_kv)])  # (num_kv, d_h, d_h)
            Q_new = R.unsqueeze(1) @ Q                    # R_g · Q_h for every head h in group g
            K_new = R @ K
            V_new = R @ V
            O_new = O @ R.transpose(-1, -2).unsqueeze(1)  # O_h · R_gᵀ
        rotated.append({
            "q": Q_new.reshape(num_heads * d_h, hidden),
            "k": K_new.reshape(num_kv * d_h, hidden),
            "v": V_new.reshape(num_kv * d_h, hidden),
            "o": (O_new.reshape(num_heads, hidden, d_h)
                  .permute(1, 0, 2)
                  .reshape(hidden, num_heads * d_h)),
        })

    # Average rotated attention.
    q_avg = torch.stack([r["q"] for r in rotated]).mean(0).to(dtype)
    k_avg = torch.stack([r["k"] for r in rotated]).mean(0).to(dtype)
    v_avg = torch.stack([r["v"] for r in rotated]).mean(0).to(dtype)
    o_avg = torch.stack([r["o"] for r in rotated]).mean(0).to(dtype)

    # FF: naive mean (no rotation gauge through SiLU).
    layers = [model.model.layers[L] for L in range(block_start, block_end + 1)]
    gate_avg = torch.stack([l.mlp.gate_proj.weight.data.float() for l in layers]).mean(0).to(dtype)
    up_avg = torch.stack([l.mlp.up_proj.weight.data.float() for l in layers]).mean(0).to(dtype)
    down_avg = torch.stack([l.mlp.down_proj.weight.data.float() for l in layers]).mean(0).to(dtype)

    # q_norm/k_norm γ: copy from the reference (basis-dependent; no clean average).
    ref_qn = ref.self_attn.q_norm.weight.data.clone() if getattr(ref.self_attn, "q_norm", None) is not None else None
    ref_kn = ref.self_attn.k_norm.weight.data.clone() if getattr(ref.self_attn, "k_norm", None) is not None else None
    for l in layers:
        l.self_attn.q_proj.weight.data.copy_(q_avg)
        l.self_attn.k_proj.weight.data.copy_(k_avg)
        l.self_attn.v_proj.weight.data.copy_(v_avg)
        l.self_attn.o_proj.weight.data.copy_(o_avg)
        l.mlp.gate_proj.weight.data.copy_(gate_avg)
        l.mlp.up_proj.weight.data.copy_(up_avg)
        l.mlp.down_proj.weight.data.copy_(down_avg)
        if ref_qn is not None and getattr(l.self_attn, "q_norm", None) is not None:
            l.self_attn.q_norm.weight.data.copy_(ref_qn)
        if ref_kn is not None and getattr(l.self_attn, "k_norm", None) is not None:
            l.self_attn.k_norm.weight.data.copy_(ref_kn)

    # Keep each input_ln.γ direction, set its L2 norm to the fitted schedule.
    for L in range(block_start, block_end + 1):
        predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L)
        gamma = model.model.layers[L].input_layernorm.weight.data
        gamma.mul_(predicted / gamma.norm().item())
def apply_variant(model, variant):
    """Modify ``model`` in place according to ``variant``.

    Single-block variants (layers BLOCK_START..BLOCK_END, inclusive):
        schedule_fit   — rescale each layer's input_ln.γ to the fitted L2 norm
                         SCHEDULE_A · exp(SCHEDULE_B · L), direction preserved.
        single_op      — copy the middle layer's projections (and
                         q_norm/k_norm) into the other block layers. Every
                         block layer, the reference included, gets its
                         input_ln.γ rescaled to the schedule.
        ties_op        — TIES trim / elect-sign / disjoint merge of the
                         projections (density from $TIES_DENSITY, default 0.2),
                         then input_ln.γ rescaled to the schedule.
        uniform_gamma  — copy the middle layer's whole input_ln.γ vector
                         (direction and magnitude) into every block layer.
        merged_op_OLD_UNREACHABLE — legacy alias: merged_op on this block only.
    Multi-block variants (every (start, end) range in BLOCKS):
        merged_op, aligned_merged_op, flat_merged_op, reverse_order.
    ``baseline`` leaves the model untouched.

    Raises:
        ValueError: unknown ``variant``, or $TIES_DENSITY outside (0, 1].
    """
    proj_paths = (
        ("self_attn", "q_proj"),
        ("self_attn", "k_proj"),
        ("self_attn", "v_proj"),
        ("self_attn", "o_proj"),
        ("mlp", "gate_proj"),
        ("mlp", "up_proj"),
        ("mlp", "down_proj"),
    )

    def proj_weight(layer, path):
        # Raw .data tensor of one projection, e.g. layer.self_attn.q_proj.weight.data.
        module_name, proj_name = path
        return getattr(getattr(layer, module_name), proj_name).weight.data

    def fp32_mean(tensors):
        # Average in fp32, then cast back to the first tensor's dtype.
        return torch.stack([t.float() for t in tensors], dim=0).mean(dim=0).to(tensors[0].dtype)

    def write_projections(layers, merged):
        for layer in layers:
            for path, value in merged.items():
                proj_weight(layer, path).copy_(value)

    def rescale_gamma_to_schedule(L):
        # Keep input_ln.γ's direction; set its L2 norm to the fitted γ(L).
        predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L)
        gamma = model.model.layers[L].input_layernorm.weight.data
        gamma.mul_(predicted / gamma.norm().item())

    if variant == "baseline":
        return
    if variant == "schedule_fit":
        for L in range(BLOCK_START, BLOCK_END + 1):
            rescale_gamma_to_schedule(L)
    elif variant == "single_op":
        # Use middle-of-block as reference, not end (more representative).
        ref_L = (BLOCK_START + BLOCK_END) // 2
        ref = model.model.layers[ref_L]
        for L in range(BLOCK_START, BLOCK_END + 1):
            if L != ref_L:
                tgt = model.model.layers[L]
                for path in proj_paths:
                    proj_weight(tgt, path).copy_(proj_weight(ref, path))
                for norm_name in ("q_norm", "k_norm"):
                    if getattr(tgt.self_attn, norm_name, None) is not None:
                        getattr(tgt.self_attn, norm_name).weight.data.copy_(
                            getattr(ref.self_attn, norm_name).weight.data)
            # Each layer keeps its OWN input_ln.γ direction with the scheduled
            # magnitude. This includes the reference layer, which the schedule
            # must also cover.
            rescale_gamma_to_schedule(L)
        # post_attn_ln γ: left as-is for now (could also be fitted & set).
    elif variant == "ties_op":
        # TIES-Merging (Yadav et al. 2023): trim, elect-sign, disjoint merge,
        # per parameter family across the N block layers.
        # NOTE(review): TIES is defined on task vectors (fine-tuned minus a
        # shared base). Here raw layer weights are merged directly, so the
        # trim step zeroes raw weights, not task deltas.
        density = float(os.environ.get("TIES_DENSITY", "0.2"))
        if not 0.0 < density <= 1.0:
            raise ValueError(f"TIES_DENSITY must be in (0, 1], got {density}")

        def ties_merge(tensors, density):
            # tensors: list of same-shape weight tensors, one per block layer.
            stack = torch.stack([t.float() for t in tensors], dim=0)  # (N, out, in)
            n = stack.shape[0]
            flat = stack.view(n, -1)
            # Step 1: per tensor, keep the top-`density` fraction of entries by
            # magnitude. k >= 1 keeps topk valid when density * numel < 1.
            k = max(int(flat.shape[1] * density), 1)
            abs_flat = flat.abs()
            threshold = abs_flat.topk(k=k, dim=1).values[:, -1:]  # (N, 1), broadcast
            trimmed = (flat * (abs_flat >= threshold).float()).view_as(stack)
            # Step 2: elect a per-entry sign by the signed sum (majority by mass).
            elected = torch.sign(trimmed.sum(dim=0))  # +1 / -1 / 0
            # Step 3: disjoint merge, averaging only entries agreeing with the sign.
            agree = (torch.sign(trimmed) == elected.unsqueeze(0)).float()
            return (trimmed * agree).sum(dim=0) / agree.sum(dim=0).clamp_min(1)

        layers = [model.model.layers[L] for L in range(BLOCK_START, BLOCK_END + 1)]
        merged = {}
        for path in proj_paths:
            tensors = [proj_weight(layer, path) for layer in layers]
            merged[path] = ties_merge(tensors, density).to(proj_weight(layers[0], path).dtype)
        write_projections(layers, merged)
        for L in range(BLOCK_START, BLOCK_END + 1):
            rescale_gamma_to_schedule(L)
    elif variant == "merged_op":
        # Arithmetic mean, for each block in BLOCKS (can be multiple).
        for (bs, be) in BLOCKS:
            _merge_block(model, bs, be)
        return
    elif variant == "aligned_merged_op":
        # Procrustes-align per-head d_h basis to block-middle, then mean.
        # FF averaged naively (SiLU breaks rotation gauge for FF).
        for (bs, be) in BLOCKS:
            _aligned_merge_block(model, bs, be, align_ff=False)
        return
    elif variant == "flat_merged_op":
        # Mean projections AND flatten γ across each block, so every layer in
        # the block becomes an identical copy of the same operator. If the
        # block is truly high-T diffusion, PPL should match merged_op (schedule
        # is gauge, not load-bearing). If the schedule helps, flattening γ hurts.
        for (bs, be) in BLOCKS:
            layers = [model.model.layers[L] for L in range(bs, be + 1)]
            merged = {path: fp32_mean([proj_weight(layer, path) for layer in layers])
                      for path in proj_paths}
            gamma_mean = fp32_mean([layer.input_layernorm.weight.data for layer in layers])
            post_attn_mean = fp32_mean(
                [layer.post_attention_layernorm.weight.data for layer in layers])
            write_projections(layers, merged)
            for layer in layers:
                layer.input_layernorm.weight.data.copy_(gamma_mean)
                layer.post_attention_layernorm.weight.data.copy_(post_attn_mean)
        return
    elif variant == "reverse_order":
        # Reverse the order of layers within each block to test whether the
        # block implements a trajectory (order-dependent) or iid diffusion
        # (order-free).
        layers_list = list(model.model.layers)
        for (bs, be) in BLOCKS:
            layers_list[bs:be + 1] = layers_list[bs:be + 1][::-1]
        model.model.layers = torch.nn.ModuleList(layers_list)
        # Re-set layer_idx on each layer so attention/cache uses the current
        # position, not the original one.
        for i, layer in enumerate(model.model.layers):
            if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "layer_idx"):
                layer.self_attn.layer_idx = i
        return
    elif variant == "merged_op_OLD_UNREACHABLE":
        # Legacy name. This branch used to inline _merge_block's exact body for
        # the single [BLOCK_START, BLOCK_END] block; delegate instead.
        _merge_block(model, BLOCK_START, BLOCK_END)
    elif variant == "uniform_gamma":
        # Copies the middle layer's entire γ vector (direction and magnitude).
        mid_L = (BLOCK_START + BLOCK_END) // 2
        mid_gamma = model.model.layers[mid_L].input_layernorm.weight.data.clone()
        for L in range(BLOCK_START, BLOCK_END + 1):
            model.model.layers[L].input_layernorm.weight.data.copy_(mid_gamma)
    else:
        raise ValueError(f"Unknown variant {variant}")
@torch.no_grad()
def perplexity(model, tok, texts, max_len=512):
    """Token-weighted perplexity of ``model`` over ``texts``.

    Each text is tokenized and scored independently (truncated to ``max_len``
    tokens). Texts shorter than 2 tokens have nothing to predict and are
    skipped. Inputs are placed on ``model.device`` rather than a hard-coded
    device.

    Assumes ``model(..., labels=input_ids).loss`` is the mean next-token NLL
    over the length − 1 predicted positions (HF causal-LM convention). It is
    re-weighted by that count so longer texts contribute proportionally.

    Returns:
        exp(total NLL / total predicted tokens), or 1.0 if nothing was scored.
    """
    device = model.device
    total_nll = 0.0
    total_tok = 0
    for text in texts:
        enc = tok(text, return_tensors="pt", truncation=True, max_length=max_len).to(device)
        n = enc.input_ids.shape[1] - 1  # next-token predictions in this text
        if n < 1:
            continue
        out = model(**enc, labels=enc.input_ids)
        total_nll += out.loss.item() * n
        total_tok += n
    return math.exp(total_nll / max(total_tok, 1))
@torch.no_grad()
def generate_sample(model, tok, prompt, max_new=40):
    """Greedily decode up to ``max_new`` tokens after ``prompt``.

    Inputs are placed on ``model.device`` (same convention as perplexity).
    EOS is used as the pad token id.

    Returns:
        The decoded first sequence with special tokens skipped. For HF
        decoder-only ``generate`` this includes the prompt text.
    """
    enc = tok(prompt, return_tensors="pt").to(model.device)
    out = model.generate(**enc, max_new_tokens=max_new, do_sample=False,
                         pad_token_id=tok.eos_token_id)
    return tok.decode(out[0], skip_special_tokens=True)
def run_variant(variant):
    """Evaluate one variant on a freshly loaded model and return its perplexity.

    Loads the model, applies ``variant`` in place, then prints the CALIB
    perplexity and the first 200 characters of the greedy continuation for
    each GEN_PROMPTS entry. The model and CUDA cache are released before
    returning, so the next variant starts from clean weights.
    """
    model, tok = load_model()
    apply_variant(model, variant)
    print(f"\n=== variant: {variant} ===", flush=True)

    ppl = perplexity(model, tok, CALIB)
    print(f" perplexity: {ppl:.3f}", flush=True)

    for p in GEN_PROMPTS:
        out = generate_sample(model, tok, p)
        print(f" [{p!r}] -> {out[:200]!r}", flush=True)

    # Drop this variant's weights before the caller loads the next model.
    del model
    torch.cuda.empty_cache()
    return ppl
def main():
    """CLI entry point: evaluate one variant, or the default set, and summarize PPL.

    ``--variant all`` runs baseline, schedule_fit, single_op and uniform_gamma.
    Any other choice runs just that variant. When more than one variant ran,
    a summary table is printed, with PPL relative to baseline when available.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--variant", default="all",
                    choices=["all", "baseline", "schedule_fit",
                             "single_op", "uniform_gamma", "merged_op",
                             "aligned_merged_op", "flat_merged_op",
                             "reverse_order", "ties_op"])
    ap.add_argument("--ties-density", type=float, default=None,
                    help="TIES trim density (fraction of top-magnitude params "
                         "to keep); defaults to $TIES_DENSITY, else 0.2")
    args = ap.parse_args()
    # apply_variant() reads the TIES density from $TIES_DENSITY, so forward
    # the flag there; otherwise it would be parsed and ignored. Only forward
    # it when given, so an explicitly exported value is not clobbered.
    if args.ties_density is not None:
        os.environ["TIES_DENSITY"] = str(args.ties_density)
    variants = (["baseline", "schedule_fit", "single_op", "uniform_gamma"]
                if args.variant == "all" else [args.variant])
    results = {}
    for v in variants:
        results[v] = run_variant(v)
    if len(results) > 1:
        print("\n=== Summary ===")
        b = results.get("baseline", None)
        for v, ppl in results.items():
            rel = f" (×{ppl/b:.2f} baseline)" if b else ""
            print(f" {v:<15} PPL {ppl:>8.3f}{rel}")


if __name__ == "__main__":
    main()