forked from kent/consciousness
Add lock_blocking() to TrackedMutex: blocks current thread using block_in_place + futures::executor::block_on, safe for sync contexts. Replace all try_lock() calls with lock_blocking() in slash commands, UI rendering, and status reads. Lock hold times are fast enough that blocking briefly is fine, and this eliminates the spurious 'lock unavailable' paths that were never actually hit. Kept rx_mutex.try_lock() in mod.rs (std::sync::Mutex for stderr rx).
498 lines
24 KiB
Python
498 lines
24 KiB
Python
"""Top-block replacement experiment: test SA-schedule hypothesis by
|
||
replacing the last 8 layers of Qwen3-4B with variants that progressively
|
||
strip out the learned schedule / specialization.
|
||
|
||
Variants:
|
||
baseline — unmodified reference (PPL sanity check)
|
||
schedule_fit — replace input_ln.γ magnitude in top block with
|
||
fitted Kirkpatrick γ(L) = 3.53·exp(0.119·L). Directions
|
||
preserved, projection weights untouched.
|
||
single_op — use layer 35's projection weights for ALL top-block
|
||
layers (strip specialization), combined with the fitted
|
||
schedule γ(L). Tests if per-layer specialization in top
|
||
block is load-bearing or replaceable by schedule.
|
||
uniform_gamma — set all top-block input_ln.γ magnitudes to the middle
|
||
layer's value (no schedule at all in top block). Tests
|
||
necessity of schedule itself.
|
||
|
||
Eval: perplexity on a concatenation of calibration prompts + a short
|
||
excerpt. Also generation quality on a handful of diagnostic prompts.
|
||
"""
|
||
import argparse
|
||
import math
|
||
import os
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||
|
||
|
||
# From sa-schedule-fit-gamma.py on Qwen3-4B null-residual data:
|
||
# input_ln.γ magnitude ≈ 3.53 · exp(0.119 · L), R² = 0.95
|
||
# Defaults for 4B. Override via env SCHEDULE_A / SCHEDULE_B for other models.
|
||
# 32B fit: a=1.02, b=0.0873
|
||
SCHEDULE_A = float(os.environ.get("SCHEDULE_A", "3.53")) if "SCHEDULE_A" in os.environ else 3.53
|
||
SCHEDULE_B = float(os.environ.get("SCHEDULE_B", "0.1191")) if "SCHEDULE_B" in os.environ else 0.1191
|
||
|
||
BLOCK_START = int(os.environ.get("BLOCK_START", 28))
|
||
BLOCK_END = int(os.environ.get("BLOCK_END", 35))
|
||
# Optional: comma-separated "s1-e1,s2-e2,..." blocks for multi-block merge
|
||
BLOCKS_ENV = os.environ.get("BLOCKS", "")
|
||
if BLOCKS_ENV:
|
||
BLOCKS = [tuple(int(x) for x in p.split("-")) for p in BLOCKS_ENV.split(",")]
|
||
else:
|
||
BLOCKS = [(BLOCK_START, BLOCK_END)]
|
||
|
||
# Calibration prompts scored by perplexity(): a mix of factual recall,
# arithmetic, code, SQL, narrative, dialogue and definition-style prefixes.
# NOTE(review): whitespace after "\n" inside the code prompts is a single
# space here — confirm against the upstream file that no indentation was lost.
CALIB = [
    # facts / civics
    "The Eiffel Tower is located in",
    "Photosynthesis is the process by which",
    "The three branches of the US government are the legislative, executive, and",
    # arithmetic / algebra / calculus
    "If a train travels 60 miles per hour for 2.5 hours, the total distance covered is",
    "Solve for x: 3x + 7 = 22. The answer is x =",
    "The derivative of x^3 + 2x^2 is",
    # code / SQL
    "def fibonacci(n):\n if n < 2:\n return n\n return",
    "# Python list comprehension to square even numbers in 0-9\nresult = ",
    "SELECT name, age FROM users WHERE",
    # narrative / argument / dialogue / instruction
    "She opened the old wooden box and found",
    "The argument in favor of renewable energy is",
    "User: What is the capital of Australia?\nAssistant:",
    "Write a haiku about autumn:\n",
    # short factual completions
    "Albert Einstein was born in the year",
    "The speed of light in vacuum is approximately",
    "I really loved that movie because",
    "The main difference between a virus and a bacterium is",
    "The French word for 'apple' is",
    "1 + 1 = ",
    "Once upon a time, in a land far away,",
    # longer explanatory prefixes
    "The key insight of general relativity is that gravity is not a force but",
    "Water boils at 100 degrees Celsius at standard atmospheric pressure. At higher",
    "In object-oriented programming, encapsulation refers to",
    "The mitochondria is often called the powerhouse of the cell because it",
    "Shakespeare's Hamlet begins with the famous line",
]
|
||
|
||
# Diagnostic prompts for greedy generation in run_variant(): one factual,
# one arithmetic, one code and one named-entity completion.
GEN_PROMPTS = [
    "The capital of France is",
    "2 + 2 =",
    "def reverse_string(s):\n return",
    "Albert Einstein developed the theory of",
]
|
||
|
||
|
||
def load_model(name=None):
    """Load a causal LM and its tokenizer, returning ``(model, tokenizer)``.

    Args:
        name: model identifier passed to ``from_pretrained``. When ``None``,
            the ``MODEL`` environment variable is used, falling back to
            ``"Qwen/Qwen3-4B"``.

    The model is loaded in bfloat16 onto ``"cuda"`` with eager attention and
    switched to eval mode before being returned.
    """
    model_id = name if name is not None else os.environ.get("MODEL", "Qwen/Qwen3-4B")
    print(f"Loading {model_id}...", flush=True)

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
        attn_implementation="eager",
    )
    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    model.eval()
    return model, tokenizer
|
||
|
||
|
||
def _merge_block(model, block_start, block_end):
    """Replace every projection in layers [block_start, block_end] by the block mean.

    For each of the seven projection weights (q/k/v/o and gate/up/down) the
    arithmetic mean across the block is computed in fp32, cast back to the
    first layer's dtype and copied into every layer of the block. Afterwards
    each layer's input_layernorm γ is rescaled in place (direction kept) so
    its norm equals SCHEDULE_A * exp(SCHEDULE_B * layer_index).
    """
    layer_ids = range(block_start, block_end + 1)
    block_layers = [model.model.layers[idx] for idx in layer_ids]

    # One accessor per projection family; order is irrelevant but fixed.
    accessors = (
        lambda lyr: lyr.self_attn.q_proj.weight,
        lambda lyr: lyr.self_attn.k_proj.weight,
        lambda lyr: lyr.self_attn.v_proj.weight,
        lambda lyr: lyr.self_attn.o_proj.weight,
        lambda lyr: lyr.mlp.gate_proj.weight,
        lambda lyr: lyr.mlp.up_proj.weight,
        lambda lyr: lyr.mlp.down_proj.weight,
    )

    # All means are computed before any layer is overwritten.
    averaged = []
    for weight_of in accessors:
        target_dtype = weight_of(block_layers[0]).data.dtype
        as_fp32 = [weight_of(lyr).data.float() for lyr in block_layers]
        averaged.append(torch.stack(as_fp32, dim=0).mean(dim=0).to(target_dtype))

    for lyr in block_layers:
        for weight_of, mean_weight in zip(accessors, averaged):
            weight_of(lyr).data.copy_(mean_weight)

    for idx, lyr in zip(layer_ids, block_layers):
        target_norm = SCHEDULE_A * math.exp(SCHEDULE_B * idx)
        ln_gamma = lyr.input_layernorm.weight.data
        ln_gamma.mul_(target_norm / ln_gamma.norm().item())
|
||
|
||
|
||
def _procrustes(M):
|
||
"""Orthogonal R = U V^T maximizing tr(R M) where M = U Σ V^T."""
|
||
U, _, Vh = torch.linalg.svd(M.float(), full_matrices=False)
|
||
return U @ Vh
|
||
|
||
|
||
def _aligned_merge_block(model, block_start, block_end, align_ff=False):
    """Procrustes-align attention heads to the block-middle layer, then mean-merge.

    For every layer in [block_start, block_end] the per-head d_h basis of
    q/k/v/o is rotated towards the reference (middle) layer, the rotated
    attention weights are averaged in fp32, and the result is written to all
    layers of the block. FF projections are averaged without alignment,
    q_norm/k_norm γ (when present) are copied from the reference layer, and
    each input_layernorm γ is rescaled to SCHEDULE_A * exp(SCHEDULE_B * L).

    One rotation is computed per KV head and shared by every query head that
    reads from it. With grouped-query attention a per-query-head rotation
    would leave the shared K/V rows consistent with only one head of the
    group, so the rotated layer would no longer match the original layer.

    Args:
        model: causal LM exposing ``model.config`` and ``model.model.layers``.
        block_start, block_end: inclusive layer-index range to merge.
        align_ff: accepted for interface compatibility; FF weights are always
            averaged naively (a rotation does not commute with SiLU).

    NOTE(review): if the architecture applies RoPE and/or per-head
    q_norm/k_norm between the q/k projections and the dot product (the
    q_norm handling below suggests it may), a dense d_h rotation of Q/K is
    only an approximate symmetry; the V/O rotation is exact. Confirm that
    this approximation is intended.
    """
    cfg = model.config
    num_heads = cfg.num_attention_heads
    num_kv = getattr(cfg, "num_key_value_heads", num_heads)
    hidden = cfg.hidden_size
    d_h = getattr(cfg, "head_dim", hidden // num_heads)

    layer_ids = range(block_start, block_end + 1)
    layers = [model.model.layers[L] for L in layer_ids]
    ref_L = (block_start + block_end) // 2
    ref = model.model.layers[ref_L]
    dtype = ref.self_attn.q_proj.weight.dtype

    # Query heads served by each KV head (identity mapping for plain MHA).
    heads_of_kv = [[] for _ in range(num_kv)]
    for h in range(num_heads):
        heads_of_kv[(h * num_kv) // num_heads].append(h)

    def per_head_views(layer):
        """fp32 per-head views: Q (H, d_h, hidden), K/V (KV, d_h, hidden), O (H, hidden, d_h)."""
        attn = layer.self_attn
        Q = attn.q_proj.weight.data.float().reshape(num_heads, d_h, hidden)
        K = attn.k_proj.weight.data.float().reshape(num_kv, d_h, hidden)
        V = attn.v_proj.weight.data.float().reshape(num_kv, d_h, hidden)
        O = (attn.o_proj.weight.data.float()
             .reshape(hidden, num_heads, d_h).permute(1, 0, 2).contiguous())
        return Q, K, V, O

    Qr, Kr, Vr, Or = per_head_views(ref)

    rotated = []
    for L, layer in zip(layer_ids, layers):
        Q, K, V, O = per_head_views(layer)

        if L == ref_L:
            Q_new, K_new, V_new, O_new = Q.clone(), K.clone(), V.clone(), O.clone()
        else:
            Q_new = torch.empty_like(Q)
            K_new = torch.empty_like(K)
            V_new = torch.empty_like(V)
            O_new = torch.empty_like(O)
            for kv_h, heads in enumerate(heads_of_kv):
                # Cross-correlation: want R s.t. R @ X ≈ X_ref for X in
                # {Q_h, K, V} and O_h @ R.T ≈ O_ref,h, jointly over the group.
                M = Kr[kv_h] @ K[kv_h].T + Vr[kv_h] @ V[kv_h].T
                for h in heads:
                    M = M + Qr[h] @ Q[h].T + Or[h].T @ O[h]
                R = _procrustes(M)
                K_new[kv_h] = R @ K[kv_h]
                V_new[kv_h] = R @ V[kv_h]
                for h in heads:
                    Q_new[h] = R @ Q[h]
                    O_new[h] = O[h] @ R.T

        rotated.append({
            "q": Q_new.reshape(num_heads * d_h, hidden),
            "k": K_new.reshape(num_kv * d_h, hidden),
            "v": V_new.reshape(num_kv * d_h, hidden),
            "o": O_new.permute(1, 0, 2).reshape(hidden, num_heads * d_h),
        })

    # Average rotated attention
    q_avg = torch.stack([r["q"] for r in rotated]).mean(0).to(dtype)
    k_avg = torch.stack([r["k"] for r in rotated]).mean(0).to(dtype)
    v_avg = torch.stack([r["v"] for r in rotated]).mean(0).to(dtype)
    o_avg = torch.stack([r["o"] for r in rotated]).mean(0).to(dtype)

    # FF: naive mean (rotation gauge is fake through SiLU)
    gate_avg = torch.stack([l.mlp.gate_proj.weight.data.float() for l in layers]).mean(0).to(dtype)
    up_avg = torch.stack([l.mlp.up_proj.weight.data.float() for l in layers]).mean(0).to(dtype)
    down_avg = torch.stack([l.mlp.down_proj.weight.data.float() for l in layers]).mean(0).to(dtype)

    # q_norm/k_norm γ: copy from reference (they're basis-dependent; no clean
    # average in rotated frame). Cloned first because ref is itself overwritten.
    ref_qn = ref.self_attn.q_norm.weight.data.clone() if getattr(ref.self_attn, "q_norm", None) is not None else None
    ref_kn = ref.self_attn.k_norm.weight.data.clone() if getattr(ref.self_attn, "k_norm", None) is not None else None

    for l in layers:
        l.self_attn.q_proj.weight.data.copy_(q_avg)
        l.self_attn.k_proj.weight.data.copy_(k_avg)
        l.self_attn.v_proj.weight.data.copy_(v_avg)
        l.self_attn.o_proj.weight.data.copy_(o_avg)
        l.mlp.gate_proj.weight.data.copy_(gate_avg)
        l.mlp.up_proj.weight.data.copy_(up_avg)
        l.mlp.down_proj.weight.data.copy_(down_avg)
        if ref_qn is not None and getattr(l.self_attn, "q_norm", None) is not None:
            l.self_attn.q_norm.weight.data.copy_(ref_qn)
        if ref_kn is not None and getattr(l.self_attn, "k_norm", None) is not None:
            l.self_attn.k_norm.weight.data.copy_(ref_kn)

    # Keep each layer's own input_ln γ direction; set its norm to the schedule.
    for L, l in zip(layer_ids, layers):
        predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L)
        gamma = l.input_layernorm.weight.data
        gamma.mul_(predicted / gamma.norm().item())
|
||
|
||
|
||
# (sub-module, projection) attribute pairs of the seven per-layer weight
# matrices that the merge variants operate on.
_PROJECTION_ATTRS = (
    ("self_attn", "q_proj"),
    ("self_attn", "k_proj"),
    ("self_attn", "v_proj"),
    ("self_attn", "o_proj"),
    ("mlp", "gate_proj"),
    ("mlp", "up_proj"),
    ("mlp", "down_proj"),
)


def _projection_weights(layer):
    """Return the layer's seven projection weight parameters in a fixed order."""
    return [getattr(getattr(layer, module), proj).weight
            for module, proj in _PROJECTION_ATTRS]


def _set_projections(layers, merged):
    """Copy merged[i] into the i-th projection weight of every layer."""
    for layer in layers:
        for param, new_weight in zip(_projection_weights(layer), merged):
            param.data.copy_(new_weight)


def _mean_projections(layers):
    """Arithmetic mean (in fp32, cast back) of each projection family across layers."""
    merged = []
    for family in zip(*(_projection_weights(layer) for layer in layers)):
        stack = torch.stack([p.data.float() for p in family], dim=0)
        merged.append(stack.mean(dim=0).to(family[0].data.dtype))
    return merged


def _rescale_input_gamma(layer, layer_idx):
    """Scale input_layernorm γ in place to the scheduled norm, keeping its direction."""
    predicted = SCHEDULE_A * math.exp(SCHEDULE_B * layer_idx)
    gamma = layer.input_layernorm.weight.data
    gamma.mul_(predicted / gamma.norm().item())


def _ties_merge(tensors, density):
    """TIES-style trim / elect-sign / disjoint-mean merge of same-shape tensors.

    Returns an fp32 tensor with the shape of one input.
    """
    stack = torch.stack([t.float() for t in tensors], dim=0)  # (N, out, in)
    n = stack.shape[0]
    flat = stack.view(n, -1)
    numel = flat.shape[1]
    # --- Step 1: Trim to top-density fraction per tensor ---
    # Clamp k so density <= 0 or > 1 cannot make topk / the threshold fail.
    k = min(max(int(numel * density), 1), numel)
    abs_flat = flat.abs()
    topk_vals, _ = abs_flat.topk(k=k, dim=1)
    mask = abs_flat >= topk_vals[:, -1:]  # per-tensor magnitude threshold
    trimmed = (flat * mask.float()).view_as(stack)
    # --- Step 2: Elect sign (sign of the signed sum across tensors) ---
    elected = torch.sign(trimmed.sum(dim=0))  # +1/-1/0
    # --- Step 3: Disjoint merge (average entries agreeing with elected sign) ---
    agree = (torch.sign(trimmed) == elected.unsqueeze(0)).float()
    contributing_count = agree.sum(dim=0).clamp_min(1)
    return (trimmed * agree).sum(dim=0) / contributing_count


def apply_variant(model, variant):
    """Modify model in place according to variant; always returns None.

    schedule_fit, single_op, ties_op, uniform_gamma and
    merged_op_OLD_UNREACHABLE act on the single block
    [BLOCK_START, BLOCK_END]; merged_op, aligned_merged_op, flat_merged_op
    and reverse_order act on every (start, end) pair in BLOCKS.

    Raises:
        ValueError: if variant is not one of the known names.
    """
    if variant == "baseline":
        return

    if variant == "schedule_fit":
        # Preserve γ direction, scale to predicted magnitude.
        for L in range(BLOCK_START, BLOCK_END + 1):
            _rescale_input_gamma(model.model.layers[L], L)

    elif variant == "single_op":
        # Use middle-of-block as reference, not end (more representative)
        ref_L = (BLOCK_START + BLOCK_END) // 2
        ref = model.model.layers[ref_L]
        for L in range(BLOCK_START, BLOCK_END + 1):
            tgt = model.model.layers[L]
            if L != ref_L:
                for dst, src in zip(_projection_weights(tgt), _projection_weights(ref)):
                    dst.data.copy_(src.data)
                # q_norm, k_norm: copy too (only where the layer has them)
                for norm_name in ("q_norm", "k_norm"):
                    if getattr(tgt.self_attn, norm_name, None) is not None:
                        getattr(tgt.self_attn, norm_name).weight.data.copy_(
                            getattr(ref.self_attn, norm_name).weight.data)
            # Keep each layer's OWN input_ln.γ direction but set magnitude to
            # schedule. Applied to the reference layer as well, so the whole
            # block follows γ(L) like the other scheduled variants.
            _rescale_input_gamma(tgt, L)
            # post_attn_ln γ: leave as-is for now (could also fit & set)

    elif variant == "ties_op":
        # TIES-Merging (Yadav et al. 2023): trim, elect-sign, disjoint merge.
        # Operates per parameter family across the N block layers.
        density = float(os.environ.get("TIES_DENSITY", "0.2"))
        layers = [model.model.layers[L] for L in range(BLOCK_START, BLOCK_END + 1)]
        merged = [
            _ties_merge([p.data for p in family], density).to(family[0].data.dtype)
            for family in zip(*(_projection_weights(l) for l in layers))
        ]
        _set_projections(layers, merged)
        for L in range(BLOCK_START, BLOCK_END + 1):
            _rescale_input_gamma(model.model.layers[L], L)

    elif variant == "merged_op":
        # Arithmetic mean, for each block in BLOCKS (can be multiple)
        for (bs, be) in BLOCKS:
            _merge_block(model, bs, be)
        return

    elif variant == "aligned_merged_op":
        # Procrustes-align per-head d_h basis to block-middle, then mean.
        # FF averaged naively (SiLU breaks rotation gauge for FF).
        for (bs, be) in BLOCKS:
            _aligned_merge_block(model, bs, be, align_ff=False)
        return

    elif variant == "flat_merged_op":
        # Mean projections AND flatten γ across block. Everything in block
        # becomes N copies of the same operator. If block is truly high-T
        # diffusion, PPL should match merged_op (schedule is gauge, not
        # load-bearing). If schedule helps, flattening γ will hurt.
        for (bs, be) in BLOCKS:
            layers = [model.model.layers[L] for L in range(bs, be + 1)]
            merged = _mean_projections(layers)
            gamma_mean = torch.stack(
                [l.input_layernorm.weight.data.float() for l in layers]
            ).mean(0).to(layers[0].input_layernorm.weight.data.dtype)
            post_attn_mean = torch.stack(
                [l.post_attention_layernorm.weight.data.float() for l in layers]
            ).mean(0).to(layers[0].post_attention_layernorm.weight.data.dtype)
            _set_projections(layers, merged)
            for l in layers:
                l.input_layernorm.weight.data.copy_(gamma_mean)
                l.post_attention_layernorm.weight.data.copy_(post_attn_mean)
        return

    elif variant == "reverse_order":
        # Reverse the order of layers within each block to test whether
        # the block implements a trajectory (order-dependent) or iid
        # diffusion (order-free).
        import torch.nn as nn
        layers_list = list(model.model.layers)
        for (bs, be) in BLOCKS:
            layers_list[bs:be + 1] = layers_list[bs:be + 1][::-1]
        model.model.layers = nn.ModuleList(layers_list)
        # Re-set layer_idx on each layer so attention/cache uses the
        # current position, not the original one.
        for i, l in enumerate(model.model.layers):
            if hasattr(l, "self_attn") and hasattr(l.self_attn, "layer_idx"):
                l.self_attn.layer_idx = i
        return

    elif variant == "merged_op_OLD_UNREACHABLE":
        # Legacy single-block mean merge; identical to _merge_block on
        # [BLOCK_START, BLOCK_END]. Not offered by the CLI.
        _merge_block(model, BLOCK_START, BLOCK_END)

    elif variant == "uniform_gamma":
        # Every block layer gets a copy of the middle layer's full γ vector.
        mid_L = (BLOCK_START + BLOCK_END) // 2
        mid_gamma = model.model.layers[mid_L].input_layernorm.weight.data.clone()
        for L in range(BLOCK_START, BLOCK_END + 1):
            model.model.layers[L].input_layernorm.weight.data.copy_(mid_gamma)

    else:
        raise ValueError(f"Unknown variant {variant}")
|
||
|
||
|
||
@torch.no_grad()
def perplexity(model, tok, texts, max_len=512):
    """Token-weighted perplexity of model over texts, each scored separately.

    Args:
        model: causal LM; called as ``model(**enc, labels=input_ids)`` and
            expected to return an object whose ``.loss`` is the mean
            next-token NLL of that sequence.
        tok: tokenizer; called with ``return_tensors="pt"`` and truncation
            to max_len.
        texts: iterable of strings.
        max_len: maximum number of tokens kept per text.

    Returns:
        exp(total NLL / total predicted tokens). Texts shorter than two
        tokens are skipped; if nothing is scored the result is 1.0.
    """
    # Follow the model's own device instead of assuming CUDA; fall back to
    # "cuda" for objects that do not expose a .device attribute.
    device = getattr(model, "device", "cuda")
    total_nll = 0.0
    total_tok = 0
    for text in texts:
        enc = tok(text, return_tensors="pt", truncation=True, max_length=max_len).to(device)
        # A sequence of T tokens yields T - 1 next-token predictions.
        n = enc.input_ids.shape[1] - 1
        if n < 1:
            continue
        out = model(**enc, labels=enc.input_ids)
        # .loss is a per-token mean, so weight it by n to recover the sum.
        total_nll += float(out.loss.item()) * n
        total_tok += n
    return math.exp(total_nll / max(total_tok, 1))
|
||
|
||
|
||
@torch.no_grad()
def generate_sample(model, tok, prompt, max_new=40):
    """Greedy-decode up to max_new tokens after prompt and return the decoded text.

    The returned string is the decoding of the full generated sequence
    (``out[0]``) with special tokens skipped. ``tok.eos_token_id`` is passed
    as the pad token id.
    """
    # Follow the model's own device instead of assuming CUDA; fall back to
    # "cuda" for objects that do not expose a .device attribute.
    device = getattr(model, "device", "cuda")
    enc = tok(prompt, return_tensors="pt").to(device)
    out = model.generate(
        **enc,
        max_new_tokens=max_new,
        do_sample=False,
        pad_token_id=tok.eos_token_id,
    )
    return tok.decode(out[0], skip_special_tokens=True)
|
||
|
||
|
||
def run_variant(variant):
    """Load a fresh model, apply variant, report PPL and sample generations.

    Prints the perplexity over CALIB and a greedy completion (first 200
    characters) for each prompt in GEN_PROMPTS, then drops the model and
    empties the CUDA cache. Returns the perplexity.
    """
    model, tokenizer = load_model()
    apply_variant(model, variant)

    print(f"\n=== variant: {variant} ===", flush=True)
    score = perplexity(model, tokenizer, CALIB)
    print(f" perplexity: {score:.3f}", flush=True)

    for prompt in GEN_PROMPTS:
        completion = generate_sample(model, tokenizer, prompt)
        preview = completion[:200]
        print(f" [{prompt!r}] -> {preview!r}", flush=True)

    # Drop the local reference before emptying the CUDA cache.
    del model
    torch.cuda.empty_cache()
    return score
|
||
|
||
|
||
def main():
    """CLI entry point: run one variant, or the four core variants for "all".

    ``--variant all`` runs baseline, schedule_fit, single_op and
    uniform_gamma and prints a summary relative to baseline; any other
    choice runs just that variant. ``--ties-density`` is exported as
    ``TIES_DENSITY``, which the ties_op variant reads; when the flag is
    omitted an existing ``TIES_DENSITY`` (or the 0.2 default) is used.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--variant", default="all",
                    choices=["all", "baseline", "schedule_fit",
                             "single_op", "uniform_gamma", "merged_op",
                             "aligned_merged_op", "flat_merged_op",
                             "reverse_order", "ties_op"])
    ap.add_argument("--ties-density", type=float, default=None,
                    help="TIES trim density (fraction of top-magnitude params to keep); "
                         "defaults to $TIES_DENSITY, else 0.2")
    args = ap.parse_args()

    # The flag was previously parsed but never reached apply_variant, which
    # takes the density from the environment.
    if args.ties_density is not None:
        os.environ["TIES_DENSITY"] = str(args.ties_density)

    variants = (["baseline", "schedule_fit", "single_op", "uniform_gamma"]
                if args.variant == "all" else [args.variant])
    results = {v: run_variant(v) for v in variants}

    if len(results) > 1:
        print("\n=== Summary ===")
        b = results.get("baseline", None)
        for v, ppl in results.items():
            rel = f" (×{ppl/b:.2f} baseline)" if b else ""
            print(f" {v:<15} PPL {ppl:>8.3f}{rel}")


if __name__ == "__main__":
    main()
|