forked from kent/consciousness
498 lines
24 KiB
Python
498 lines
24 KiB
Python
|
|
"""Top-block replacement experiment: test SA-schedule hypothesis by
|
|||
|
|
replacing the last 8 layers of Qwen3-4B with variants that progressively
strip out the learned schedule / specialization.

Variants:
  baseline       — unmodified reference (PPL sanity check)
  schedule_fit   — replace input_ln.γ magnitude in top block with
                   fitted Kirkpatrick γ(L) = 3.53·exp(0.119·L). Directions
                   preserved, projection weights untouched.
  single_op      — use the middle top-block layer's projection weights for
                   ALL top-block layers (strip specialization), combined
                   with the fitted schedule γ(L). Tests if per-layer
                   specialization in top block is load-bearing or
                   replaceable by schedule.
  uniform_gamma  — set all top-block input_ln.γ vectors to the middle
                   layer's value (no schedule at all in top block). Tests
                   necessity of schedule itself.

Eval: token-weighted perplexity over the calibration prompts (each prompt
is scored separately). Also generation quality on a handful of diagnostic
prompts.
|
|||
|
|
"""
|
|||
|
|
import argparse
|
|||
|
|
import math
|
|||
|
|
import os
|
|||
|
|
import torch
|
|||
|
|
import torch.nn.functional as F
|
|||
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|||
|
|
|
|||
|
|
|
|||
|
|
# From sa-schedule-fit-gamma.py on Qwen3-4B null-residual data:
#   input_ln.γ magnitude ≈ 3.53 · exp(0.119 · L),  R² = 0.95
# Defaults are the 4B fit. Override via env SCHEDULE_A / SCHEDULE_B for other
# models (32B fit: a=1.02, b=0.0873).
SCHEDULE_A = float(os.environ.get("SCHEDULE_A", "3.53"))
SCHEDULE_B = float(os.environ.get("SCHEDULE_B", "0.1191"))

# Inclusive layer range of the default block (env BLOCK_START / BLOCK_END).
BLOCK_START = int(os.environ.get("BLOCK_START", 28))
BLOCK_END = int(os.environ.get("BLOCK_END", 35))


def parse_blocks(spec):
    """Parse a "s1-e1,s2-e2,..." string into a list of (start, end) int tuples.

    Whitespace around entries is ignored and empty entries (e.g. a trailing
    comma) are skipped.

    Raises:
        ValueError: if an entry is not of the form "<int>-<int>". Failing here
            gives a clear message instead of a tuple-unpacking error later,
            when a variant iterates over BLOCKS.
    """
    blocks = []
    for piece in spec.split(","):
        piece = piece.strip()
        if not piece:
            continue
        bounds = piece.split("-")
        if len(bounds) != 2:
            raise ValueError(
                f"BLOCKS entry {piece!r} is not of the form 'start-end'")
        start, end = (int(b) for b in bounds)
        blocks.append((start, end))
    return blocks


# Optional: comma-separated "s1-e1,s2-e2,..." blocks for multi-block merge.
# Falls back to the single (BLOCK_START, BLOCK_END) block when BLOCKS is
# unset or contains no entries.
BLOCKS_ENV = os.environ.get("BLOCKS", "")
BLOCKS = parse_blocks(BLOCKS_ENV) or [(BLOCK_START, BLOCK_END)]
|
|||
|
|
|
|||
|
|
# Calibration prompts scored by perplexity(). A deliberately mixed bag:
# factual recall, arithmetic, code, SQL, narrative, chat format, translation.
# Order is kept stable so PPL numbers stay comparable between runs.
CALIB = [
    "The Eiffel Tower is located in",
    "Photosynthesis is the process by which",
    "The three branches of the US government are the legislative, executive, and",
    "If a train travels 60 miles per hour for 2.5 hours, the total distance covered is",
    "Solve for x: 3x + 7 = 22. The answer is x =",
    "The derivative of x^3 + 2x^2 is",
    "def fibonacci(n):\n if n < 2:\n return n\n return",
    "# Python list comprehension to square even numbers in 0-9\nresult = ",
    "SELECT name, age FROM users WHERE",
    "She opened the old wooden box and found",
    "The argument in favor of renewable energy is",
    "User: What is the capital of Australia?\nAssistant:",
    "Write a haiku about autumn:\n",
    "Albert Einstein was born in the year",
    "The speed of light in vacuum is approximately",
    "I really loved that movie because",
    "The main difference between a virus and a bacterium is",
    "The French word for 'apple' is",
    "1 + 1 = ",
    "Once upon a time, in a land far away,",
    "The key insight of general relativity is that gravity is not a force but",
    "Water boils at 100 degrees Celsius at standard atmospheric pressure. At higher",
    "In object-oriented programming, encapsulation refers to",
    "The mitochondria is often called the powerhouse of the cell because it",
    "Shakespeare's Hamlet begins with the famous line",
]
|
|||
|
|
|
|||
|
|
# Short diagnostic prompts for greedy generation; the completions are
# printed by run_variant() for eyeballing, not scored.
GEN_PROMPTS = [
    "The capital of France is",
    "2 + 2 =",
    "def reverse_string(s):\n return",
    "Albert Einstein developed the theory of",
]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_model(name=None):
    """Load a causal LM and its tokenizer onto the GPU in eval mode.

    Args:
        name: model id or path handed to ``from_pretrained``. When None, the
            MODEL environment variable is used, falling back to
            "Qwen/Qwen3-4B".

    Returns:
        ``(model, tokenizer)``. The model is loaded in bfloat16 with
        ``device_map="cuda"`` and eager attention.
    """
    model_id = os.environ.get("MODEL", "Qwen/Qwen3-4B") if name is None else name
    print(f"Loading {model_id}...", flush=True)

    # NOTE: trust_remote_code=True lets the checkpoint's own Python code run
    # at load time; only point MODEL at sources you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    load_kwargs = {
        "torch_dtype": torch.bfloat16,
        "device_map": "cuda",
        "trust_remote_code": True,
        "attn_implementation": "eager",
    }
    lm = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    lm.eval()
    return lm, tokenizer
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _merge_block(model, block_start, block_end):
    """Replace every layer's projections in a block with their element-wise mean.

    For layers ``block_start..block_end`` (inclusive), the seven projection
    matrices (q/k/v/o and gate/up/down) are averaged in fp32 across the block
    and the result is written back into every layer of the block. Afterwards
    each layer's ``input_layernorm`` weight is rescaled, direction unchanged,
    so that its L2 norm equals ``SCHEDULE_A * exp(SCHEDULE_B * L)``.

    The model is modified in place; nothing is returned.
    """
    layer_ids = range(block_start, block_end + 1)
    block = [model.model.layers[idx] for idx in layer_ids]

    # name -> accessor for the weight Parameter on a decoder layer
    weight_of = {
        "self_attn.q_proj.weight": lambda lyr: lyr.self_attn.q_proj.weight,
        "self_attn.k_proj.weight": lambda lyr: lyr.self_attn.k_proj.weight,
        "self_attn.v_proj.weight": lambda lyr: lyr.self_attn.v_proj.weight,
        "self_attn.o_proj.weight": lambda lyr: lyr.self_attn.o_proj.weight,
        "mlp.gate_proj.weight": lambda lyr: lyr.mlp.gate_proj.weight,
        "mlp.up_proj.weight": lambda lyr: lyr.mlp.up_proj.weight,
        "mlp.down_proj.weight": lambda lyr: lyr.mlp.down_proj.weight,
    }

    # Compute all averages first; nothing is overwritten until every mean
    # has been taken from the untouched weights.
    averaged = {}
    for key, fetch in weight_of.items():
        out_dtype = fetch(block[0]).data.dtype
        as_fp32 = [fetch(lyr).data.float() for lyr in block]
        averaged[key] = torch.stack(as_fp32, dim=0).mean(dim=0).to(out_dtype)

    for lyr in block:
        for key, fetch in weight_of.items():
            fetch(lyr).data.copy_(averaged[key])

    # Per-layer γ: keep direction, set magnitude to the fitted schedule.
    for idx in layer_ids:
        ln_weight = model.model.layers[idx].input_layernorm.weight.data
        target_norm = SCHEDULE_A * math.exp(SCHEDULE_B * idx)
        ln_weight.mul_(target_norm / ln_weight.norm().item())
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _procrustes(M):
    """Return the orthogonal matrix R = U @ Vh from the SVD M = U Σ Vh.

    This R maximizes tr(Rᵀ M) over orthogonal matrices. With
    M = A_ref @ A.T it is therefore the least-squares solution of
    ``R @ A ≈ A_ref``. The SVD is computed in fp32 regardless of M's dtype.
    """
    left, _singular, right_h = torch.linalg.svd(M.float(), full_matrices=False)
    rotation = left @ right_h
    return rotation
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _aligned_merge_block(model, block_start, block_end, align_ff=False):
    """Procrustes-align attention heads to a reference layer, then average.

    For layers ``block_start..block_end`` (inclusive):

    1. The middle layer of the block is the reference.
    2. For every other layer, one orthogonal ``d_h x d_h`` rotation R is
       solved per key/value head. It is applied to that K and V head
       (``R @ W``) and to every query head and o_proj slice sharing it
       (``R @ Wq``, ``Wo @ R.T``).
    3. The rotated q/k/v/o weights are averaged in fp32 across the block and
       written back into every layer. gate/up/down are averaged without any
       alignment.
    4. q_norm / k_norm weights, when present, are copied from the reference.
    5. Each ``input_layernorm`` weight is rescaled (direction kept) to norm
       ``SCHEDULE_A * exp(SCHEDULE_B * L)``.

    Under grouped-query attention several query heads share one K/V head.
    The rotation must therefore be shared by the whole group: a per-query-head
    rotation would leave K/V rotated by whichever head was processed last and
    inconsistent with the other heads in the group. With
    ``num_key_value_heads == num_attention_heads`` each group has one head and
    this reduces to a plain per-head alignment.

    Args:
        model: model exposing ``config`` and ``model.layers`` with
            ``self_attn`` / ``mlp`` / ``input_layernorm`` submodules.
        block_start, block_end: inclusive layer indices.
        align_ff: accepted for interface compatibility but currently has no
            effect. The FF projections are always averaged naively, because a
            rotation of d_ff does not pass through the SiLU nonlinearity.

    The model is modified in place; nothing is returned.
    """
    # NOTE(review): rotating Q/K only preserves the layer's function if
    # nothing basis-dependent sits between the projection and the dot product
    # (per-head q_norm/k_norm weights, rotary embeddings). Confirm for the
    # target architecture before treating the alignment as an exact gauge.
    cfg = model.config
    num_heads = cfg.num_attention_heads
    num_kv = getattr(cfg, "num_key_value_heads", None) or num_heads
    hidden = cfg.hidden_size
    d_h = getattr(cfg, "head_dim", None) or hidden // num_heads

    # Query heads grouped by the KV head they attend with.
    groups = [[] for _ in range(num_kv)]
    for h in range(num_heads):
        groups[(h * num_kv) // num_heads].append(h)

    def per_head(layer):
        """fp32 per-head views: q (H, d_h, hidden), k/v (KV, d_h, hidden), o (H, hidden, d_h)."""
        attn = layer.self_attn
        q = attn.q_proj.weight.data.float().reshape(num_heads, d_h, hidden)
        k = attn.k_proj.weight.data.float().reshape(num_kv, d_h, hidden)
        v = attn.v_proj.weight.data.float().reshape(num_kv, d_h, hidden)
        o = (attn.o_proj.weight.data.float()
             .reshape(hidden, num_heads, d_h).permute(1, 0, 2).contiguous())
        return q, k, v, o

    ref_L = (block_start + block_end) // 2
    ref = model.model.layers[ref_L]
    dtype = ref.self_attn.q_proj.weight.dtype
    Qr, Kr, Vr, Or = per_head(ref)

    layer_ids = range(block_start, block_end + 1)
    layers = [model.model.layers[L] for L in layer_ids]

    rotated = []
    for L, layer in zip(layer_ids, layers):
        Q, K, V, O = per_head(layer)

        if L == ref_L:
            # The reference defines the target basis; it is not rotated.
            Q_new, K_new, V_new, O_new = Q.clone(), K.clone(), V.clone(), O.clone()
        else:
            Q_new = torch.empty_like(Q)
            K_new = torch.empty_like(K)
            V_new = torch.empty_like(V)
            O_new = torch.empty_like(O)
            for kv_h, heads in enumerate(groups):
                # Cross-correlation whose Procrustes solution R makes
                # R @ W ≈ W_ref for Q/K/V and O @ R.T ≈ O_ref, summed over
                # everything that has to share this rotation.
                M = Kr[kv_h] @ K[kv_h].T + Vr[kv_h] @ V[kv_h].T
                for h in heads:
                    M = M + Qr[h] @ Q[h].T + Or[h].T @ O[h]
                R = _procrustes(M)
                K_new[kv_h] = R @ K[kv_h]
                V_new[kv_h] = R @ V[kv_h]
                for h in heads:
                    Q_new[h] = R @ Q[h]
                    O_new[h] = O[h] @ R.T

        rotated.append({
            "q": Q_new.reshape(num_heads * d_h, hidden),
            "k": K_new.reshape(num_kv * d_h, hidden),
            "v": V_new.reshape(num_kv * d_h, hidden),
            "o": O_new.permute(1, 0, 2).reshape(hidden, num_heads * d_h),
        })

    # Average rotated attention
    q_avg = torch.stack([r["q"] for r in rotated]).mean(0).to(dtype)
    k_avg = torch.stack([r["k"] for r in rotated]).mean(0).to(dtype)
    v_avg = torch.stack([r["v"] for r in rotated]).mean(0).to(dtype)
    o_avg = torch.stack([r["o"] for r in rotated]).mean(0).to(dtype)

    # FF: naive mean (rotation gauge is fake through SiLU)
    gate_avg = torch.stack([l.mlp.gate_proj.weight.data.float() for l in layers]).mean(0).to(dtype)
    up_avg = torch.stack([l.mlp.up_proj.weight.data.float() for l in layers]).mean(0).to(dtype)
    down_avg = torch.stack([l.mlp.down_proj.weight.data.float() for l in layers]).mean(0).to(dtype)

    # q_norm/k_norm γ: copy from reference (they're basis-dependent; no clean
    # average in rotated frame)
    ref_qn_mod = getattr(ref.self_attn, "q_norm", None)
    ref_kn_mod = getattr(ref.self_attn, "k_norm", None)
    ref_qn = ref_qn_mod.weight.data.clone() if ref_qn_mod is not None else None
    ref_kn = ref_kn_mod.weight.data.clone() if ref_kn_mod is not None else None

    for l in layers:
        l.self_attn.q_proj.weight.data.copy_(q_avg)
        l.self_attn.k_proj.weight.data.copy_(k_avg)
        l.self_attn.v_proj.weight.data.copy_(v_avg)
        l.self_attn.o_proj.weight.data.copy_(o_avg)
        l.mlp.gate_proj.weight.data.copy_(gate_avg)
        l.mlp.up_proj.weight.data.copy_(up_avg)
        l.mlp.down_proj.weight.data.copy_(down_avg)
        if ref_qn is not None and getattr(l.self_attn, "q_norm", None) is not None:
            l.self_attn.q_norm.weight.data.copy_(ref_qn)
        if ref_kn is not None and getattr(l.self_attn, "k_norm", None) is not None:
            l.self_attn.k_norm.weight.data.copy_(ref_kn)

    # Per-layer γ: keep direction, set magnitude to the fitted schedule.
    for L in layer_ids:
        predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L)
        gamma = model.model.layers[L].input_layernorm.weight.data
        gamma.mul_(predicted / gamma.norm().item())
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Dotted paths, relative to one decoder layer, of the seven projection
# matrices that the block-level variants overwrite.
_PROJECTION_WEIGHTS = (
    "self_attn.q_proj.weight",
    "self_attn.k_proj.weight",
    "self_attn.v_proj.weight",
    "self_attn.o_proj.weight",
    "mlp.gate_proj.weight",
    "mlp.up_proj.weight",
    "mlp.down_proj.weight",
)


def _layer_param(layer, dotted):
    """Resolve a dotted attribute path such as "mlp.up_proj.weight" on *layer*."""
    obj = layer
    for part in dotted.split("."):
        obj = getattr(obj, part)
    return obj


def _mean_weight(layers, dotted):
    """fp32 mean of one parameter across *layers*, cast back to the first layer's dtype."""
    weights = [_layer_param(layer, dotted).data for layer in layers]
    return torch.stack([w.float() for w in weights], dim=0).mean(dim=0).to(weights[0].dtype)


def _write_projections(layers, merged):
    """Copy ``merged[name]`` into every layer for each projection weight name."""
    for layer in layers:
        for name in _PROJECTION_WEIGHTS:
            _layer_param(layer, name).data.copy_(merged[name])


def _rescale_input_gamma(model, start, end):
    """Scale input_layernorm.weight of layers start..end (inclusive) to the schedule norm.

    Direction is preserved; the L2 norm becomes SCHEDULE_A * exp(SCHEDULE_B * L).
    """
    for L in range(start, end + 1):
        predicted = SCHEDULE_A * math.exp(SCHEDULE_B * L)
        gamma = model.model.layers[L].input_layernorm.weight.data
        gamma.mul_(predicted / gamma.norm().item())


def ties_merge(tensors, density):
    """TIES-merge same-shaped tensors (Yadav et al. 2023): trim, elect sign, disjoint mean.

    Args:
        tensors: list of tensors with identical shape.
        density: fraction of entries, by magnitude, kept per tensor in the
            trim step. The resulting count is clamped to ``[1, numel]`` so
            that very small or >1 densities do not make ``topk`` fail.

    Returns:
        fp32 tensor of the common shape.
    """
    stack = torch.stack([t.float() for t in tensors], dim=0)  # (N, ...)
    n = stack.shape[0]
    flat = stack.reshape(n, -1)
    numel = flat.shape[1]

    # --- Step 1: trim each tensor to its top-`density` fraction by magnitude.
    k = min(max(int(numel * density), 1), numel)
    abs_flat = flat.abs()
    topk_vals, _ = abs_flat.topk(k=k, dim=1)
    threshold = topk_vals[:, -1:]  # per-tensor magnitude cut-off, broadcast over columns
    trimmed = torch.where(abs_flat >= threshold, flat, torch.zeros_like(flat)).view_as(stack)

    # --- Step 2: elect a sign per position from the signed sum of survivors.
    elected = torch.sign(trimmed.sum(dim=0))  # +1 / -1 / 0

    # --- Step 3: average only the entries whose sign matches the elected one.
    agree = (torch.sign(trimmed) == elected.unsqueeze(0)).to(trimmed.dtype)
    contributing_count = agree.sum(dim=0).clamp_min(1)
    return (trimmed * agree).sum(dim=0) / contributing_count


def apply_variant(model, variant):
    """Modify *model* in place according to *variant*.

    Variants (layer ranges are inclusive):

    * ``baseline``: no change.
    * ``schedule_fit``: rescale input_layernorm γ in BLOCK_START..BLOCK_END
      to the fitted schedule norm, direction preserved.
    * ``single_op``: copy the middle layer's projections (and q_norm /
      k_norm when present) to every other layer of the block, then apply the
      schedule γ to the whole block.
    * ``ties_op``: TIES-merge each projection across the block (density from
      env TIES_DENSITY, default 0.2), then apply the schedule γ.
    * ``merged_op``: ``_merge_block`` on every block in BLOCKS.
    * ``aligned_merged_op``: ``_aligned_merge_block`` on every block in BLOCKS.
    * ``flat_merged_op``: for every block in BLOCKS, mean projections and
      mean input/post-attention layernorm weights, so that all layers of the
      block become copies of one operator.
    * ``reverse_order``: reverse the layer order inside every block in
      BLOCKS and renumber ``self_attn.layer_idx``.
    * ``uniform_gamma``: copy the middle layer's input_layernorm weight to
      every layer in BLOCK_START..BLOCK_END.

    Note that schedule_fit / single_op / ties_op / uniform_gamma use the
    single BLOCK_START..BLOCK_END range and ignore BLOCKS.

    Raises:
        ValueError: for an unknown variant name.
    """
    if variant == "baseline":
        return

    if variant == "schedule_fit":
        # Preserve direction, scale to predicted magnitude.
        _rescale_input_gamma(model, BLOCK_START, BLOCK_END)

    elif variant == "single_op":
        # Use middle-of-block as reference, not end (more representative).
        ref_L = (BLOCK_START + BLOCK_END) // 2
        ref = model.model.layers[ref_L]
        for L in range(BLOCK_START, BLOCK_END + 1):
            if L == ref_L:
                continue
            tgt = model.model.layers[L]
            for name in _PROJECTION_WEIGHTS:
                _layer_param(tgt, name).data.copy_(_layer_param(ref, name).data)
            # q_norm, k_norm: copy too
            for norm_name in ("q_norm", "k_norm"):
                if getattr(tgt.self_attn, norm_name, None) is not None:
                    getattr(tgt.self_attn, norm_name).weight.data.copy_(
                        getattr(ref.self_attn, norm_name).weight.data)
        # Keep each layer's OWN input_ln.γ direction but set magnitude to the
        # schedule. This now includes the reference layer; it used to be
        # skipped together with the weight copy, leaving one layer of the
        # block off-schedule.
        _rescale_input_gamma(model, BLOCK_START, BLOCK_END)
        # post_attn_ln γ: leave as-is for now (could also fit & set)

    elif variant == "ties_op":
        # TIES-Merging: operates per parameter family across the N block layers.
        density = float(os.environ.get("TIES_DENSITY", "0.2"))
        layers = [model.model.layers[L] for L in range(BLOCK_START, BLOCK_END + 1)]
        merged = {}
        for name in _PROJECTION_WEIGHTS:
            tensors = [_layer_param(l, name).data for l in layers]
            merged[name] = ties_merge(tensors, density).to(tensors[0].dtype)
        _write_projections(layers, merged)
        _rescale_input_gamma(model, BLOCK_START, BLOCK_END)

    elif variant == "merged_op":
        # Arithmetic mean, for each block in BLOCKS (can be multiple)
        for (bs, be) in BLOCKS:
            _merge_block(model, bs, be)
        return

    elif variant == "aligned_merged_op":
        # Procrustes-align per-head d_h basis to block-middle, then mean.
        # FF averaged naively (SiLU breaks rotation gauge for FF).
        for (bs, be) in BLOCKS:
            _aligned_merge_block(model, bs, be, align_ff=False)
        return

    elif variant == "flat_merged_op":
        # Mean projections AND flatten γ across block. Everything in block
        # becomes N copies of the same operator. If block is truly high-T
        # diffusion, PPL should match merged_op (schedule is gauge, not
        # load-bearing). If schedule helps, flattening γ will hurt.
        for (bs, be) in BLOCKS:
            layers = [model.model.layers[L] for L in range(bs, be + 1)]
            # Take every mean before writing anything back.
            merged = {name: _mean_weight(layers, name) for name in _PROJECTION_WEIGHTS}
            gamma_mean = _mean_weight(layers, "input_layernorm.weight")
            post_attn_mean = _mean_weight(layers, "post_attention_layernorm.weight")
            _write_projections(layers, merged)
            for l in layers:
                l.input_layernorm.weight.data.copy_(gamma_mean)
                l.post_attention_layernorm.weight.data.copy_(post_attn_mean)
        return

    elif variant == "reverse_order":
        # Reverse the order of layers within each block to test whether
        # the block implements a trajectory (order-dependent) or iid
        # diffusion (order-free).
        layers_list = list(model.model.layers)
        for (bs, be) in BLOCKS:
            layers_list[bs:be + 1] = layers_list[bs:be + 1][::-1]
        model.model.layers = torch.nn.ModuleList(layers_list)
        # Re-set layer_idx on each layer so attention/cache uses the
        # current position, not the original one.
        for i, l in enumerate(model.model.layers):
            if hasattr(l, "self_attn") and hasattr(l.self_attn, "layer_idx"):
                l.self_attn.layer_idx = i
        return

    elif variant == "merged_op_OLD_UNREACHABLE":
        # Legacy name, not offered on the CLI. The old inline code was a
        # copy of _merge_block for the single default block, so delegate.
        _merge_block(model, BLOCK_START, BLOCK_END)

    elif variant == "uniform_gamma":
        mid_L = (BLOCK_START + BLOCK_END) // 2
        mid_gamma = model.model.layers[mid_L].input_layernorm.weight.data.clone()
        for L in range(BLOCK_START, BLOCK_END + 1):
            model.model.layers[L].input_layernorm.weight.data.copy_(mid_gamma)

    else:
        raise ValueError(f"Unknown variant {variant}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
@torch.no_grad()
def perplexity(model, tok, texts, max_len=512):
    """Token-weighted perplexity of *model* over *texts*.

    Each text is tokenized on its own (truncated to *max_len* tokens) and
    scored with ``labels=input_ids``. The returned loss is weighted by the
    number of predicted tokens (sequence length minus one) so that the result
    is the exponential of the mean per-token loss over all texts. Texts that
    tokenize to fewer than two tokens are skipped.

    Args:
        model: causal LM exposing ``.device`` and returning an object with a
            scalar ``.loss`` when called with ``labels``.
        tok: tokenizer; calling it returns an encoding with ``input_ids`` and
            a ``.to(device)`` method.
        texts: iterable of strings.
        max_len: truncation length in tokens.

    Returns:
        Perplexity as a float. If no text contributes any token the result
        is ``exp(0) == 1.0``.
    """
    # Follow the model's placement instead of assuming "cuda".
    device = model.device
    total_nll = 0.0
    total_tok = 0
    for text in texts:
        enc = tok(text, return_tensors="pt", truncation=True, max_length=max_len).to(device)
        if enc.input_ids.shape[1] < 2:
            continue
        out = model(**enc, labels=enc.input_ids)
        # The loss is averaged over the shifted targets, i.e. length - 1 of them.
        n = enc.input_ids.shape[1] - 1
        total_nll += float(out.loss.item()) * n
        total_tok += n
    return math.exp(total_nll / max(total_tok, 1))
|
|||
|
|
|
|||
|
|
|
|||
|
|
@torch.no_grad()
def generate_sample(model, tok, prompt, max_new=40):
    """Greedy-decode a continuation of *prompt* and return the decoded text.

    Args:
        model: causal LM exposing ``.device`` and ``.generate``.
        tok: tokenizer providing ``eos_token_id`` and ``decode``.
        prompt: input string.
        max_new: maximum number of newly generated tokens.

    Returns:
        The first returned sequence, decoded with special tokens skipped.
    """
    # Follow the model's placement instead of assuming "cuda".
    enc = tok(prompt, return_tensors="pt").to(model.device)
    out = model.generate(
        **enc,
        max_new_tokens=max_new,
        do_sample=False,  # greedy, so runs are comparable across variants
        pad_token_id=tok.eos_token_id,
    )
    return tok.decode(out[0], skip_special_tokens=True)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_variant(variant):
    """Load a fresh model, apply *variant*, report PPL and sample generations.

    Prints the perplexity over CALIB and one greedy completion (first 200
    characters) per entry of GEN_PROMPTS, then drops the model and empties
    the CUDA cache so the next variant starts from a clean load.

    Returns:
        The perplexity as a float.
    """
    lm, tokenizer = load_model()
    apply_variant(lm, variant)

    print(f"\n=== variant: {variant} ===", flush=True)
    score = perplexity(lm, tokenizer, CALIB)
    print(f" perplexity: {score:.3f}", flush=True)

    for prompt in GEN_PROMPTS:
        completion = generate_sample(lm, tokenizer, prompt)
        print(f" [{prompt!r}] -> {completion[:200]!r}", flush=True)

    # Release the weights before the caller loads the next variant.
    del lm
    torch.cuda.empty_cache()
    return score
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
    """CLI entry point: run one variant, or the four core variants with ``all``.

    ``--variant all`` runs baseline, schedule_fit, single_op and
    uniform_gamma and prints a summary relative to baseline. Any other choice
    runs just that variant.

    ``--ties-density`` is exported as the TIES_DENSITY environment variable,
    which is where apply_variant reads the density from. When the flag is
    omitted, an existing TIES_DENSITY is left alone (default 0.2).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--variant", default="all",
                    choices=["all", "baseline", "schedule_fit",
                             "single_op", "uniform_gamma", "merged_op",
                             "aligned_merged_op", "flat_merged_op",
                             "reverse_order", "ties_op"])
    ap.add_argument("--ties-density", type=float, default=None,
                    help="TIES trim density (fraction of top-magnitude params to keep); "
                         "defaults to the TIES_DENSITY env var, else 0.2")
    args = ap.parse_args()

    if args.ties_density is not None:
        if not 0.0 < args.ties_density <= 1.0:
            ap.error("--ties-density must be in (0, 1]")
        # Previously this flag was parsed and then silently ignored.
        os.environ["TIES_DENSITY"] = str(args.ties_density)

    variants = (["baseline", "schedule_fit", "single_op", "uniform_gamma"]
                if args.variant == "all" else [args.variant])
    results = {v: run_variant(v) for v in variants}

    if len(results) > 1:
        print("\n=== Summary ===")
        b = results.get("baseline", None)
        for v, ppl in results.items():
            rel = f" (×{ppl/b:.2f} baseline)" if b else ""
            print(f" {v:<15} PPL {ppl:>8.3f}{rel}")


if __name__ == "__main__":
    main()
|