"""Top-block replacement experiment: test SA-schedule hypothesis by replacing the last 8 layers of Qwen3-4B with variants that progressively strip out the learned schedule / specialization. Variants: baseline — unmodified reference (PPL sanity check) schedule_fit — replace input_ln.γ magnitude in top block with fitted Kirkpatrick γ(L) = 3.53·exp(0.119·L). Directions preserved, projection weights untouched. single_op — use layer 35's projection weights for ALL top-block layers (strip specialization), combined with the fitted schedule γ(L). Tests if per-layer specialization in top block is load-bearing or replaceable by schedule. uniform_gamma — set all top-block input_ln.γ magnitudes to the middle layer's value (no schedule at all in top block). Tests necessity of schedule itself. Eval: perplexity on a concatenation of calibration prompts + a short excerpt. Also generation quality on a handful of diagnostic prompts. """ import argparse import math import os import torch import torch.nn.functional as F from transformers import AutoModelForCausalLM, AutoTokenizer # From sa-schedule-fit-gamma.py on Qwen3-4B null-residual data: # input_ln.γ magnitude ≈ 3.53 · exp(0.119 · L), R² = 0.95 # Defaults for 4B. Override via env SCHEDULE_A / SCHEDULE_B for other models. # 32B fit: a=1.02, b=0.0873 SCHEDULE_A = float(os.environ.get("SCHEDULE_A", "3.53")) if "SCHEDULE_A" in os.environ else 3.53 SCHEDULE_B = float(os.environ.get("SCHEDULE_B", "0.1191")) if "SCHEDULE_B" in os.environ else 0.1191 BLOCK_START = int(os.environ.get("BLOCK_START", 28)) BLOCK_END = int(os.environ.get("BLOCK_END", 35)) # Optional: comma-separated "s1-e1,s2-e2,..." blocks for multi-block merge BLOCKS_ENV = os.environ.get("BLOCKS", "") if BLOCKS_ENV: BLOCKS = [tuple(int(x) for x in p.split("-")) for p in BLOCKS_ENV.split(",")] else: BLOCKS = [(BLOCK_START, BLOCK_END)] CALIB = [ "The Eiffel Tower is located in", "Photosynthesis is the process by which", "The three branches of the US government are the legislative, executive, and", "If a train travels 60 miles per hour for 2.5 hours, the total distance covered is", "Solve for x: 3x + 7 = 22. The answer is x =", "The derivative of x^3 + 2x^2 is", "def fibonacci(n):\n if n < 2:\n return n\n return", "# Python list comprehension to square even numbers in 0-9\nresult = ", "SELECT name, age FROM users WHERE", "She opened the old wooden box and found", "The argument in favor of renewable energy is", "User: What is the capital of Australia?\nAssistant:", "Write a haiku about autumn:\n", "Albert Einstein was born in the year", "The speed of light in vacuum is approximately", "I really loved that movie because", "The main difference between a virus and a bacterium is", "The French word for 'apple' is", "1 + 1 = ", "Once upon a time, in a land far away,", "The key insight of general relativity is that gravity is not a force but", "Water boils at 100 degrees Celsius at standard atmospheric pressure. At higher", "In object-oriented programming, encapsulation refers to", "The mitochondria is often called the powerhouse of the cell because it", "Shakespeare's Hamlet begins with the famous line", ] GEN_PROMPTS = [ "The capital of France is", "2 + 2 =", "def reverse_string(s):\n return", "Albert Einstein developed the theory of", ] def load_model(name=None): if name is None: name = os.environ.get("MODEL", "Qwen/Qwen3-4B") print(f"Loading {name}...", flush=True) tok = AutoTokenizer.from_pretrained(name, trust_remote_code=True) m = AutoModelForCausalLM.from_pretrained( name, torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True, attn_implementation="eager", ) m.eval() return m, tok def _merge_block(model, block_start, block_end): """Arithmetic-mean merge projections in [block_start, block_end]; set γ per schedule.""" layers = [model.model.layers[L] for L in range(block_start, block_end + 1)] param_names = [ ("self_attn.q_proj.weight", lambda l: l.self_attn.q_proj.weight), ("self_attn.k_proj.weight", lambda l: l.self_attn.k_proj.weight), ("self_attn.v_proj.weight", lambda l: l.self_attn.v_proj.weight), ("self_attn.o_proj.weight", lambda l: l.self_attn.o_proj.weight), ("mlp.gate_proj.weight", lambda l: l.mlp.gate_proj.weight), ("mlp.up_proj.weight", lambda l: l.mlp.up_proj.weight), ("mlp.down_proj.weight", lambda l: l.mlp.down_proj.weight), ] merged = {} for name, getter in param_names: stack = torch.stack([getter(l).data.float() for l in layers], dim=0) merged[name] = stack.mean(dim=0).to(getter(layers[0]).data.dtype) for l in layers: l.self_attn.q_proj.weight.data.copy_(merged["self_attn.q_proj.weight"]) l.self_attn.k_proj.weight.data.copy_(merged["self_attn.k_proj.weight"]) l.self_attn.v_proj.weight.data.copy_(merged["self_attn.v_proj.weight"]) l.self_attn.o_proj.weight.data.copy_(merged["self_attn.o_proj.weight"]) l.mlp.gate_proj.weight.data.copy_(merged["mlp.gate_proj.weight"]) l.mlp.up_proj.weight.data.copy_(merged["mlp.up_proj.weight"]) l.mlp.down_proj.weight.data.copy_(merged["mlp.down_proj.weight"]) for L in range(block_start, block_end + 1): predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L) gamma = model.model.layers[L].input_layernorm.weight.data gamma.mul_(predicted / gamma.norm().item()) def _procrustes(M): """Orthogonal R = U V^T maximizing tr(R M) where M = U Σ V^T.""" U, _, Vh = torch.linalg.svd(M.float(), full_matrices=False) return U @ Vh def _aligned_merge_block(model, block_start, block_end, align_ff=False): """Procrustes-align per-head d_h basis (and optionally d_ff) of each layer in [block_start, block_end] to a reference (middle), then arithmetic-mean. Attention rotation is a true gauge; FF rotation is not (SiLU breaks it) — align_ff defaults off.""" cfg = model.config num_heads = cfg.num_attention_heads num_kv = getattr(cfg, "num_key_value_heads", num_heads) hidden = cfg.hidden_size d_h = getattr(cfg, "head_dim", hidden // num_heads) ref_L = (block_start + block_end) // 2 ref = model.model.layers[ref_L] dev = ref.self_attn.q_proj.weight.device dtype = ref.self_attn.q_proj.weight.dtype # Reference views, fp32 on device Qr = ref.self_attn.q_proj.weight.data.float().reshape(num_heads, d_h, hidden) Kr = ref.self_attn.k_proj.weight.data.float().reshape(num_kv, d_h, hidden) Vr = ref.self_attn.v_proj.weight.data.float().reshape(num_kv, d_h, hidden) Or = ref.self_attn.o_proj.weight.data.float().reshape(hidden, num_heads, d_h).permute(1, 0, 2).contiguous() if align_ff: d_ff = cfg.intermediate_size Gr = ref.mlp.gate_proj.weight.data.float() Ur = ref.mlp.up_proj.weight.data.float() Dr = ref.mlp.down_proj.weight.data.float() rotated = [] for L in range(block_start, block_end + 1): layer = model.model.layers[L] Q = layer.self_attn.q_proj.weight.data.float().reshape(num_heads, d_h, hidden) K = layer.self_attn.k_proj.weight.data.float().reshape(num_kv, d_h, hidden) V = layer.self_attn.v_proj.weight.data.float().reshape(num_kv, d_h, hidden) O = layer.self_attn.o_proj.weight.data.float().reshape(hidden, num_heads, d_h).permute(1, 0, 2).contiguous() if L == ref_L: Q_new, K_new, V_new, O_new = Q.clone(), K.clone(), V.clone(), O.clone() else: Q_new = torch.empty_like(Q) K_new = torch.empty_like(K) V_new = torch.empty_like(V) O_new = torch.empty_like(O) for h in range(num_heads): kv_h = (h * num_kv) // num_heads # Cross-correlation: want R s.t. R @ Q ≈ Qr (row-space align). # For per-head (d_h, hidden): M = Qr @ Q.T + Kr @ K.T + Vr @ V.T + Or^T @ O # (Or, O are (hidden, d_h) per head) M = (Qr[h] @ Q[h].T + Kr[kv_h] @ K[kv_h].T + Vr[kv_h] @ V[kv_h].T + Or[h].T @ O[h]) R = _procrustes(M) Q_new[h] = R @ Q[h] K_new[kv_h] = R @ K[kv_h] V_new[kv_h] = R @ V[kv_h] O_new[h] = O[h] @ R.T rotated.append({ "q": Q_new.reshape(num_heads * d_h, hidden), "k": K_new.reshape(num_kv * d_h, hidden), "v": V_new.reshape(num_kv * d_h, hidden), "o": O_new.permute(1, 0, 2).reshape(hidden, num_heads * d_h), }) # Average rotated attention q_avg = torch.stack([r["q"] for r in rotated]).mean(0).to(dtype) k_avg = torch.stack([r["k"] for r in rotated]).mean(0).to(dtype) v_avg = torch.stack([r["v"] for r in rotated]).mean(0).to(dtype) o_avg = torch.stack([r["o"] for r in rotated]).mean(0).to(dtype) # FF: naive mean (rotation gauge is fake through SiLU) layers = [model.model.layers[L] for L in range(block_start, block_end + 1)] gate_avg = torch.stack([l.mlp.gate_proj.weight.data.float() for l in layers]).mean(0).to(dtype) up_avg = torch.stack([l.mlp.up_proj.weight.data.float() for l in layers]).mean(0).to(dtype) down_avg = torch.stack([l.mlp.down_proj.weight.data.float() for l in layers]).mean(0).to(dtype) # q_norm/k_norm γ: copy from reference (they're basis-dependent; no clean average in rotated frame) ref_qn = ref.self_attn.q_norm.weight.data.clone() if getattr(ref.self_attn, "q_norm", None) is not None else None ref_kn = ref.self_attn.k_norm.weight.data.clone() if getattr(ref.self_attn, "k_norm", None) is not None else None for l in layers: l.self_attn.q_proj.weight.data.copy_(q_avg) l.self_attn.k_proj.weight.data.copy_(k_avg) l.self_attn.v_proj.weight.data.copy_(v_avg) l.self_attn.o_proj.weight.data.copy_(o_avg) l.mlp.gate_proj.weight.data.copy_(gate_avg) l.mlp.up_proj.weight.data.copy_(up_avg) l.mlp.down_proj.weight.data.copy_(down_avg) if ref_qn is not None and getattr(l.self_attn, "q_norm", None) is not None: l.self_attn.q_norm.weight.data.copy_(ref_qn) if ref_kn is not None and getattr(l.self_attn, "k_norm", None) is not None: l.self_attn.k_norm.weight.data.copy_(ref_kn) for L in range(block_start, block_end + 1): predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L) gamma = model.model.layers[L].input_layernorm.weight.data gamma.mul_(predicted / gamma.norm().item()) def apply_variant(model, variant): """Modify model in place according to variant.""" if variant == "baseline": return if variant == "schedule_fit": for L in range(BLOCK_START, BLOCK_END + 1): predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L) layer = model.model.layers[L] gamma = layer.input_layernorm.weight.data cur_norm = gamma.norm().item() # Preserve direction, scale to predicted magnitude gamma.mul_(predicted / cur_norm) elif variant == "single_op": # Use middle-of-block as reference, not end (more representative) ref_L = (BLOCK_START + BLOCK_END) // 2 ref = model.model.layers[ref_L] for L in range(BLOCK_START, BLOCK_END + 1): if L == ref_L: continue tgt = model.model.layers[L] tgt.self_attn.q_proj.weight.data.copy_(ref.self_attn.q_proj.weight.data) tgt.self_attn.k_proj.weight.data.copy_(ref.self_attn.k_proj.weight.data) tgt.self_attn.v_proj.weight.data.copy_(ref.self_attn.v_proj.weight.data) tgt.self_attn.o_proj.weight.data.copy_(ref.self_attn.o_proj.weight.data) tgt.mlp.gate_proj.weight.data.copy_(ref.mlp.gate_proj.weight.data) tgt.mlp.up_proj.weight.data.copy_(ref.mlp.up_proj.weight.data) tgt.mlp.down_proj.weight.data.copy_(ref.mlp.down_proj.weight.data) # q_norm, k_norm: copy too if hasattr(tgt.self_attn, "q_norm") and tgt.self_attn.q_norm is not None: tgt.self_attn.q_norm.weight.data.copy_(ref.self_attn.q_norm.weight.data) if hasattr(tgt.self_attn, "k_norm") and tgt.self_attn.k_norm is not None: tgt.self_attn.k_norm.weight.data.copy_(ref.self_attn.k_norm.weight.data) # Keep each layer's OWN input_ln.γ direction but set magnitude to schedule predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L) gamma = tgt.input_layernorm.weight.data gamma.mul_(predicted / gamma.norm().item()) # post_attn_ln γ: leave as-is for now (could also fit & set) elif variant == "ties_op": # TIES-Merging (Yadav et al. 2023): trim, elect-sign, disjoint merge. # Operates per parameter family across the N block layers. density = float(os.environ.get("TIES_DENSITY", "0.2")) layers = [model.model.layers[L] for L in range(BLOCK_START, BLOCK_END + 1)] param_names = [ ("self_attn.q_proj.weight", lambda l: l.self_attn.q_proj.weight), ("self_attn.k_proj.weight", lambda l: l.self_attn.k_proj.weight), ("self_attn.v_proj.weight", lambda l: l.self_attn.v_proj.weight), ("self_attn.o_proj.weight", lambda l: l.self_attn.o_proj.weight), ("mlp.gate_proj.weight", lambda l: l.mlp.gate_proj.weight), ("mlp.up_proj.weight", lambda l: l.mlp.up_proj.weight), ("mlp.down_proj.weight", lambda l: l.mlp.down_proj.weight), ] def ties_merge(tensors, density): # tensors: list of (out, in) float tensors, same shape stack = torch.stack([t.float() for t in tensors], dim=0) # (N, out, in) # --- Step 1: Trim to top-density fraction per tensor --- n = stack.shape[0] flat = stack.view(n, -1) k = int(flat.shape[1] * density) abs_flat = flat.abs() # Find magnitude threshold per tensor at top-k topk_vals, _ = abs_flat.topk(k=k, dim=1) threshold = topk_vals[:, -1:].expand_as(abs_flat) mask = abs_flat >= threshold trimmed = (flat * mask.float()).view_as(stack) # --- Step 2: Elect sign (majority by total magnitude) --- mag_per_sign = trimmed.sum(dim=0) # (out, in), signed sum elected = torch.sign(mag_per_sign) # +1/-1/0 # --- Step 3: Disjoint merge (average params agreeing with elected sign) --- agree = (torch.sign(trimmed) == elected.unsqueeze(0)).float() contributing_count = agree.sum(dim=0).clamp_min(1) merged_sum = (trimmed * agree).sum(dim=0) merged = merged_sum / contributing_count return merged merged = {} for name, getter in param_names: tensors = [getter(l).data for l in layers] merged[name] = ties_merge(tensors, density).to(getter(layers[0]).data.dtype) for l in layers: l.self_attn.q_proj.weight.data.copy_(merged["self_attn.q_proj.weight"]) l.self_attn.k_proj.weight.data.copy_(merged["self_attn.k_proj.weight"]) l.self_attn.v_proj.weight.data.copy_(merged["self_attn.v_proj.weight"]) l.self_attn.o_proj.weight.data.copy_(merged["self_attn.o_proj.weight"]) l.mlp.gate_proj.weight.data.copy_(merged["mlp.gate_proj.weight"]) l.mlp.up_proj.weight.data.copy_(merged["mlp.up_proj.weight"]) l.mlp.down_proj.weight.data.copy_(merged["mlp.down_proj.weight"]) for L in range(BLOCK_START, BLOCK_END + 1): predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L) gamma = model.model.layers[L].input_layernorm.weight.data gamma.mul_(predicted / gamma.norm().item()) elif variant == "merged_op": # Arithmetic mean, for each block in BLOCKS (can be multiple) for (bs, be) in BLOCKS: _merge_block(model, bs, be) return elif variant == "aligned_merged_op": # Procrustes-align per-head d_h basis to block-middle, then mean. # FF averaged naively (SiLU breaks rotation gauge for FF). for (bs, be) in BLOCKS: _aligned_merge_block(model, bs, be, align_ff=False) return elif variant == "flat_merged_op": # Mean projections AND flatten γ across block. Everything in block # becomes N copies of the same operator. If block is truly high-T # diffusion, PPL should match merged_op (schedule is gauge, not # load-bearing). If schedule helps, flattening γ will hurt. for (bs, be) in BLOCKS: layers = [model.model.layers[L] for L in range(bs, be + 1)] param_names = [ ("self_attn.q_proj.weight", lambda l: l.self_attn.q_proj.weight), ("self_attn.k_proj.weight", lambda l: l.self_attn.k_proj.weight), ("self_attn.v_proj.weight", lambda l: l.self_attn.v_proj.weight), ("self_attn.o_proj.weight", lambda l: l.self_attn.o_proj.weight), ("mlp.gate_proj.weight", lambda l: l.mlp.gate_proj.weight), ("mlp.up_proj.weight", lambda l: l.mlp.up_proj.weight), ("mlp.down_proj.weight", lambda l: l.mlp.down_proj.weight), ] merged = {} for name, getter in param_names: stack = torch.stack([getter(l).data.float() for l in layers], dim=0) merged[name] = stack.mean(dim=0).to(getter(layers[0]).data.dtype) gamma_mean = torch.stack([l.input_layernorm.weight.data.float() for l in layers]).mean(0).to(layers[0].input_layernorm.weight.data.dtype) post_attn_mean = torch.stack([l.post_attention_layernorm.weight.data.float() for l in layers]).mean(0).to(layers[0].post_attention_layernorm.weight.data.dtype) for l in layers: l.self_attn.q_proj.weight.data.copy_(merged["self_attn.q_proj.weight"]) l.self_attn.k_proj.weight.data.copy_(merged["self_attn.k_proj.weight"]) l.self_attn.v_proj.weight.data.copy_(merged["self_attn.v_proj.weight"]) l.self_attn.o_proj.weight.data.copy_(merged["self_attn.o_proj.weight"]) l.mlp.gate_proj.weight.data.copy_(merged["mlp.gate_proj.weight"]) l.mlp.up_proj.weight.data.copy_(merged["mlp.up_proj.weight"]) l.mlp.down_proj.weight.data.copy_(merged["mlp.down_proj.weight"]) l.input_layernorm.weight.data.copy_(gamma_mean) l.post_attention_layernorm.weight.data.copy_(post_attn_mean) return elif variant == "reverse_order": # Reverse the order of layers within each block to test whether # the block implements a trajectory (order-dependent) or iid # diffusion (order-free). import torch.nn as nn layers_list = list(model.model.layers) for (bs, be) in BLOCKS: rev = layers_list[bs:be + 1][::-1] layers_list[bs:be + 1] = rev model.model.layers = nn.ModuleList(layers_list) # Re-set layer_idx on each layer so attention/cache uses the # current position, not the original one. for i, l in enumerate(model.model.layers): if hasattr(l, "self_attn") and hasattr(l.self_attn, "layer_idx"): l.self_attn.layer_idx = i return elif variant == "merged_op_OLD_UNREACHABLE": layers = [model.model.layers[L] for L in range(BLOCK_START, BLOCK_END + 1)] n = len(layers) param_names = [ ("self_attn.q_proj.weight", lambda l: l.self_attn.q_proj.weight), ("self_attn.k_proj.weight", lambda l: l.self_attn.k_proj.weight), ("self_attn.v_proj.weight", lambda l: l.self_attn.v_proj.weight), ("self_attn.o_proj.weight", lambda l: l.self_attn.o_proj.weight), ("mlp.gate_proj.weight", lambda l: l.mlp.gate_proj.weight), ("mlp.up_proj.weight", lambda l: l.mlp.up_proj.weight), ("mlp.down_proj.weight", lambda l: l.mlp.down_proj.weight), ] merged = {} for name, getter in param_names: stack = torch.stack([getter(l).data.float() for l in layers], dim=0) merged[name] = stack.mean(dim=0).to(getter(layers[0]).data.dtype) for l in layers: l.self_attn.q_proj.weight.data.copy_(merged["self_attn.q_proj.weight"]) l.self_attn.k_proj.weight.data.copy_(merged["self_attn.k_proj.weight"]) l.self_attn.v_proj.weight.data.copy_(merged["self_attn.v_proj.weight"]) l.self_attn.o_proj.weight.data.copy_(merged["self_attn.o_proj.weight"]) l.mlp.gate_proj.weight.data.copy_(merged["mlp.gate_proj.weight"]) l.mlp.up_proj.weight.data.copy_(merged["mlp.up_proj.weight"]) l.mlp.down_proj.weight.data.copy_(merged["mlp.down_proj.weight"]) # Set γ to scheduled values per layer for L in range(BLOCK_START, BLOCK_END + 1): predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L) gamma = model.model.layers[L].input_layernorm.weight.data gamma.mul_(predicted / gamma.norm().item()) elif variant == "uniform_gamma": mid_L = (BLOCK_START + BLOCK_END) // 2 mid_gamma = model.model.layers[mid_L].input_layernorm.weight.data.clone() for L in range(BLOCK_START, BLOCK_END + 1): model.model.layers[L].input_layernorm.weight.data.copy_(mid_gamma) else: raise ValueError(f"Unknown variant {variant}") @torch.no_grad() def perplexity(model, tok, texts, max_len=512): total_nll = 0.0 total_tok = 0 for text in texts: enc = tok(text, return_tensors="pt", truncation=True, max_length=max_len).to("cuda") if enc.input_ids.shape[1] < 2: continue out = model(**enc, labels=enc.input_ids) n = enc.input_ids.shape[1] - 1 total_nll += float(out.loss.item()) * n total_tok += n return math.exp(total_nll / max(total_tok, 1)) @torch.no_grad() def generate_sample(model, tok, prompt, max_new=40): enc = tok(prompt, return_tensors="pt").to("cuda") out = model.generate(**enc, max_new_tokens=max_new, do_sample=False, pad_token_id=tok.eos_token_id) return tok.decode(out[0], skip_special_tokens=True) def run_variant(variant): model, tok = load_model() apply_variant(model, variant) print(f"\n=== variant: {variant} ===", flush=True) ppl = perplexity(model, tok, CALIB) print(f" perplexity: {ppl:.3f}", flush=True) for p in GEN_PROMPTS: out = generate_sample(model, tok, p) print(f" [{p!r}] -> {out[:200]!r}", flush=True) del model torch.cuda.empty_cache() return ppl def main(): ap = argparse.ArgumentParser() ap.add_argument("--variant", default="all", choices=["all", "baseline", "schedule_fit", "single_op", "uniform_gamma", "merged_op", "aligned_merged_op", "flat_merged_op", "reverse_order", "ties_op"]) ap.add_argument("--ties-density", type=float, default=0.2, help="TIES trim density (fraction of top-magnitude params to keep)") args = ap.parse_args() variants = (["baseline", "schedule_fit", "single_op", "uniform_gamma"] if args.variant == "all" else [args.variant]) results = {} for v in variants: results[v] = run_variant(v) if len(results) > 1: print("\n=== Summary ===") b = results.get("baseline", None) for v, ppl in results.items(): rel = f" (×{ppl/b:.2f} baseline)" if b else "" print(f" {v:<15} PPL {ppl:>8.3f}{rel}") if __name__ == "__main__": main()