forked from kent/consciousness
145 lines
5.9 KiB
Python
145 lines
5.9 KiB
Python
|
|
"""Fit a functional form to the LN γ trajectory across layers; derive the
|
|||
|
|
effective attention temperature T(L) from known coupling formulas.
|
|||
|
|
|
|||
|
|
Rules of what scales with depth (from literature):
|
|||
|
|
DeepNorm: α_dec = (2M)^(1/4), β_dec = (8M)^(-1/4). Same per layer — does
|
|||
|
|
NOT depend on layer index l. The free variation across layers has to
|
|||
|
|
live in LN γ.
|
|||
|
|
Depth-μP: block multiplier a/√L, LR η/√L. Same per layer.
|
|||
|
|
So γ(L) is the family carrying the per-layer schedule.
|
|||
|
|
|
|||
|
|
Try fitting forms:
    γ(L) = a · L^b               (power law in layer index)
    γ(L) = a · exp(b·L)          (exponential)
    γ(L) = a + b·L               (linear)
    γ(L) = a + b·L^c (free c)    (power law with free exponent; not fitted by this script)
    piecewise exponential        (two exponential segments joined at the best split layer)
|
|||
|
|
|
|||
|
|
Report fit quality (R², residual statistics), and for the best fit, compute
|
|||
|
|
the derived T(L) curve.
|
|||
|
|
"""
|
|||
|
|
import json
|
|||
|
|
import numpy as np
|
|||
|
|
from math import log, exp
|
|||
|
|
|
|||
|
|
|
|||
|
|
def fit_power(L, y):
    """Fit the power law ``y ≈ a · L^b`` via a straight line in log-log space.

    Taking logarithms turns the model into ``log y ≈ log a + b · log L``, so a
    degree-1 ``np.polyfit`` yields the exponent (slope) and ``log a``
    (intercept). Only points with ``L > 0`` and ``y > 0`` take part in the
    regression; the prediction and R² are evaluated in the original space
    over every point.

    Returns a dict with keys ``form``, ``a``, ``b``, ``r2`` and ``yhat``.
    """
    usable = (L > 0) & (y > 0)
    log_x = np.log(L[usable])
    log_y = np.log(y[usable])
    exponent, log_prefactor = np.polyfit(log_x, log_y, 1)

    predicted = np.exp(log_prefactor) * (L**exponent)
    ss_res = ((y - predicted)**2).sum()
    ss_tot = ((y - y.mean())**2).sum()
    r_squared = 1 - ss_res / ss_tot

    return {
        "form": "a*L^b",
        "a": float(np.exp(log_prefactor)),
        "b": float(exponent),
        "r2": float(r_squared),
        "yhat": predicted,
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def fit_exponential(L, y):
    """Fit the exponential ``y ≈ a · exp(b·L)`` via a line in semi-log space.

    The model is linear after taking the log of ``y``:
    ``log y ≈ log a + b·L``. Only points with ``y > 0`` enter the regression;
    the prediction and R² are evaluated in the original space over every
    point.

    Returns a dict with keys ``form``, ``a``, ``b``, ``r2`` and ``yhat``.
    """
    positive = y > 0
    rate, log_prefactor = np.polyfit(L[positive], np.log(y[positive]), 1)

    predicted = np.exp(log_prefactor) * np.exp(rate * L)
    ss_res = ((y - predicted)**2).sum()
    ss_tot = ((y - y.mean())**2).sum()
    r_squared = 1 - ss_res / ss_tot

    return {
        "form": "a*exp(b*L)",
        "a": float(np.exp(log_prefactor)),
        "b": float(rate),
        "r2": float(r_squared),
        "yhat": predicted,
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def fit_linear(L, y):
    """Fit the straight line ``y ≈ a + b·L`` by ordinary least squares.

    Returns a dict with keys ``form``, ``a`` (intercept), ``b`` (slope),
    ``r2`` and ``yhat`` (the fitted values at every ``L``).
    """
    slope, intercept = np.polyfit(L, y, 1)
    predicted = intercept + slope * L

    ss_res = ((y - predicted)**2).sum()
    ss_tot = ((y - y.mean())**2).sum()
    r_squared = 1 - ss_res / ss_tot

    return {
        "form": "a+b*L",
        "a": float(intercept),
        "b": float(slope),
        "r2": float(r_squared),
        "yhat": predicted,
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def fit_piecewise_two(L, y):
    """Fit two exponential segments and pick the best split point ``L*``.

    For every candidate split the points with ``L < L*`` and ``L >= L*`` are
    fitted separately as ``log y ≈ log a + b·L`` (a straight line in semi-log
    space). The split with the highest R², measured in the original space
    over all points, wins.

    Candidate splits are drawn from the sorted values of ``L`` itself, and a
    candidate is only used when both sides hold at least three points. Taking
    the candidates from the data (rather than from ``range(len(L))``) keeps
    the segments non-empty and non-degenerate whatever offset ``L`` carries,
    e.g. 0-based or 1-based layer indices.

    Args:
        L: 1-D array-like of abscissae (layer indices).
        y: 1-D array-like of values, same length as ``L``.

    Returns:
        A dict with keys ``form``, ``split``, ``a1``, ``b1``, ``a2``, ``b2``,
        ``r2`` and ``yhat``, or ``None`` when no candidate split is usable
        (fewer than seven points, or a non-positive ``y`` on either side of
        every candidate).
    """
    L = np.asarray(L, dtype=float)
    y = np.asarray(y, dtype=float)

    min_pts = 3  # fewest points allowed in either segment
    n = len(L)
    if n < 2 * min_pts + 1:
        return None

    ordered = np.sort(L)
    ss_tot = ((y - y.mean())**2).sum()

    best = None
    for k in range(min_pts, n - min_pts):
        split = ordered[k]
        left, right = L < split, L >= split
        # Repeated L values can leave a side short of points; skip those.
        if left.sum() < min_pts or right.sum() < min_pts:
            continue
        # The log-space fit needs strictly positive values on both sides.
        if (y[left] <= 0).any() or (y[right] <= 0).any():
            continue
        b_left, loga_left = np.polyfit(L[left], np.log(y[left]), 1)
        b_right, loga_right = np.polyfit(L[right], np.log(y[right]), 1)
        yhat = np.where(left,
                        np.exp(loga_left + b_left * L),
                        np.exp(loga_right + b_right * L))
        r2 = 1 - ((y - yhat)**2).sum() / ss_tot
        if best is None or r2 > best["r2"]:
            # Report integer-valued splits as int, as callers format them so.
            split_out = int(split) if float(split).is_integer() else float(split)
            best = {"form": f"piecewise-exp-split@L={split:g}", "split": split_out,
                    "a1": float(np.exp(loga_left)), "b1": float(b_left),
                    "a2": float(np.exp(loga_right)), "b2": float(b_right),
                    "r2": float(r2), "yhat": yhat}
    return best
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
    """Print depth-trajectory fits for each scale family and a derived T(L).

    Reads ``/tmp/qwen3-4b-null.json`` and uses its ``"scales"`` entry, a
    mapping from family name to a per-layer list of numbers. The layer count
    is taken from the ``input_ln`` family.

    Three reports are written to stdout:

    1. For every family of interest: linear, power-law, exponential and
       piecewise-exponential fits with their R².
    2. A per-layer table of the ``input_ln`` values, their ratio to layer 0
       and their log.
    3. The relative temperature ``T(L)/T(0) = γ(0)/γ(L)`` and a log-linear
       (Kirkpatrick-style exponential cooling) fit to it.
    """
    # Context manager so the file handle is closed deterministically.
    with open("/tmp/qwen3-4b-null.json", encoding="utf-8") as fh:
        d = json.load(fh)
    scales = d["scales"]
    num_layers = len(scales["input_ln"])
    L = np.arange(num_layers, dtype=float)

    families_of_interest = ["input_ln", "post_attn_ln", "q_norm", "k_norm",
                            "q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"]

    print("=" * 72)
    # Layer count comes from the data instead of being hard-coded.
    print(f"γ-trajectory fits per family (Qwen3-4B, {num_layers} layers)")
    print("=" * 72)

    for fam in families_of_interest:
        y = np.array(scales[fam], dtype=float)
        print(f"\n--- {fam} ---")
        print(f" L=0: {y[0]:.3f} L={len(y) - 1}: {y[-1]:.3f} ratio: {y[-1]/y[0]:+.2f}×")
        fits = [
            fit_linear(L, y),
            fit_power(L + 1, y),  # L+1 so L=0 doesn't explode log
            fit_exponential(L, y),
            # The piecewise fit is linear in L (no log of L), so it takes the
            # 0-based index; its split then matches the layer table below.
            fit_piecewise_two(L, y),
        ]
        for f in fits:
            if f is None:
                continue
            extras = ""
            if "b" in f:
                extras = f" (a={f['a']:.3g}, b={f['b']:+.4f})"
            elif "split" in f:
                extras = f" (split={f['split']}, b1={f['b1']:+.4f}, b2={f['b2']:+.4f})"
            print(f" {f['form']:<32} R²={f['r2']:+.4f}{extras}")

    # For input_ln specifically: plot the curve (text) and derive T(L)
    y = np.array(scales["input_ln"], dtype=float)
    print("\n" + "=" * 72)
    print("input_ln γ magnitude across layers (the schedule signal)")
    print("=" * 72)
    print(f" {'L':>3} {'γ_L':>12} {'γ_L / γ_0':>10} {'log γ_L':>10}")
    for l_idx, gamma in enumerate(y):
        print(f" {l_idx:>3} {gamma:>12.3f} {gamma/y[0]:>10.3f} {log(gamma):>+10.4f}")

    # Classical SA schedules for comparison
    # - Linear: T(k) = T0 - k * (T0 - Tf)/N
    # - Exponential / Kirkpatrick: T(k) = T0 * α^k
    # - Logarithmic / Hajek: T(k) = c / log(k+2)
    # For γ (which grows = temperature drops, since larger γ → sharper attention):
    # γ growing corresponds to T cooling
    print("\n" + "=" * 72)
    print("Derived attention-temperature T(L) interpretation")
    print("=" * 72)
    print(" Attention logit ∝ (γ * W_Q * W_K * ||residual||²) / √d_head.")
    print(" With γ_L the schedule dial and other factors ~constant across layers,")
    print(" effective attention temperature T(L) ∝ 1/γ(L).")
    print("\n T(L)/T(0) = γ(0)/γ(L):")
    print(f" {'L':>3} {'T(L)/T(0)':>10} (smaller = cooler = sharper attention)")
    for l_idx, gamma in enumerate(y):
        print(f" {l_idx:>3} {y[0]/gamma:>10.4f}")

    # Comparison with classical SA cooling laws:
    # Kirkpatrick: T(L) = T0 · α^L → log T(L) = log T0 + L log α
    logT = -np.log(y / y[0])  # because T ∝ 1/γ
    b_kirk, a_kirk = np.polyfit(L, logT, 1)
    # Hajek (log-cooling): T(L) = c/log(L+2)
    # Predicts: log T = log c - log(log(L+2))
    # Fit T(L) to c / log(L+c2)
    # NOTE: only the Kirkpatrick fit is computed; no Hajek fit is performed.
    print("\n Kirkpatrick-law fit (exponential cooling):")
    print(f" log T(L) = {a_kirk:+.3f} + {b_kirk:+.4f} * L → T(L) = exp({a_kirk:+.3f}) · exp({b_kirk:+.4f}·L)")
    logT_hat = a_kirk + b_kirk * L
    r2_kirk = 1 - ((logT - logT_hat)**2).sum() / ((logT - logT.mean())**2).sum()
    print(f" R² (in log space) = {r2_kirk:+.4f} — ideally ≈ 1 if cooling is pure exponential")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|