consciousness/research/sparse-kernel-napkin-math.md


# Sparse Kernel Compilation: Napkin Math for Qwen3.5-27B
## Architecture recap
| Parameter | Value |
|-----------|-------|
| Layers | 64 (48 linear attention, 16 full attention) |
| Hidden dim (H) | 5,120 |
| Query heads | 24 |
| KV heads | 4 (GQA) |
| Head dim | 256 |
| FFN intermediate | 17,408 |
| Full attention interval | every 4th layer |
| Total params | ~27.5B |
**Key discovery**: Qwen3.5 already uses sparse attention — 48/64 layers use
linear attention (O(N)), only 16 use full O(N^2) attention. The attention
sparsity question is partially answered by the architecture itself. The
remaining opportunity is **weight sparsity** in the projection and FFN matrices.
## Measured weight sparsity (from B200, all 11 shards)
| Component | Matrices | Total params | \|w\|<0.001 | \|w\|<0.005 | \|w\|<0.01 | \|w\|<0.02 |
|-----------|-------|-------------|--------|--------|-------|-------|
| down_proj | 65 | 5,793,382,400 | 7.4% | 35.8% | 64.3% | 92.8% |
| gate_proj | 65 | 5,793,382,400 | 7.2% | 35.0% | 63.3% | 92.3% |
| up_proj | 65 | 5,793,382,400 | 7.3% | 35.2% | 63.5% | 92.5% |
| q_proj | 17 | 1,069,547,520 | 5.7% | 27.9% | 51.9% | 82.8% |
| k_proj | 17 | 89,128,960 | 6.0% | 29.4% | 54.1% | 84.1% |
| v_proj | 17 | 89,128,960 | 4.8% | 23.7% | 44.7% | 74.2% |
| o_proj | 17 | 534,773,760 | 5.3% | 25.9% | 48.7% | 79.6% |
| other | 605 | 6,070,652,272 | 5.4% | 26.4% | 49.6% | 80.9% |
**Key findings**:
- FFN layers (gate/up/down) are remarkably sparse: ~64% of weights below 0.01
- At threshold 0.02, FFN sparsity exceeds 92%
- Attention projections are less sparse but still significant: 45-54% below 0.01
- v_proj is the least sparse component (44.7% below 0.01)
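The per-component percentages above come down to a simple magnitude scan. A minimal sketch, using a synthetic Gaussian matrix as a stand-in for a real checkpoint shard (the actual measurement iterates over every tensor in all 11 shards):

```python
import numpy as np

def magnitude_sparsity(w: np.ndarray, thresholds=(0.001, 0.005, 0.01, 0.02)) -> dict:
    """Fraction of weights whose magnitude falls below each threshold."""
    a = np.abs(w).ravel()
    return {t: float((a < t).mean()) for t in thresholds}

# Synthetic stand-in: std chosen arbitrarily for illustration only.
rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.02, size=(1024, 1024)).astype(np.float32)
stats = magnitude_sparsity(w)
for t, frac in stats.items():
    print(f"|w| < {t}: {frac:.1%}")
```

The same function applied per component name (down_proj, gate_proj, ...) and summed over layers yields the table above.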
## Per-layer parameter breakdown
- **QKV projection**: H × (Q_dim + K_dim + V_dim) ≈ 5120 × 13312 ≈ 68M
- **Output projection**: ~26M
- **FFN (gate + up + down)**: 5120 × 17408 × 3 ≈ 267M
- **Total per layer**: ~361M
- **64 layers**: ~23.1B (rest is embeddings, norms, etc.)
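As a sanity check on the arithmetic, the breakdown in Python (using this note's own approximations for the projection dims):

```python
H = 5120          # hidden dim
FFN = 17408       # FFN intermediate dim
QKV_OUT = 13312   # combined Q/K/V output dim used in the estimate above
LAYERS = 64

qkv = H * QKV_OUT        # ~68M
o_proj = H * H           # ~26M (the note's approximation)
ffn = H * FFN * 3        # gate + up + down, ~267M
per_layer = qkv + o_proj + ffn
total = LAYERS * per_layer
print(f"per layer: {per_layer/1e6:.0f}M, 64 layers: {total/1e9:.1f}B")
```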
## Dense baseline: FLOPs per token
Each weight parameter contributes 2 FLOPs per token (multiply + accumulate).
- Per layer: ~722M FLOPs/token
- 64 layers: **~46.2B FLOPs/token**
On a B200 (theoretical ~4.5 PFLOPS FP8, ~1.1 PFLOPS BF16):
- BF16 throughput: 46.2B / 1.1e15 ≈ 0.042ms per token (compute-bound limit)
- But inference is usually **memory-bandwidth-bound** for small batch sizes
B200 HBM bandwidth: ~8 TB/s. At BF16 (2 bytes/param):
- Loading all weights once: 27.5B × 2 = 55GB → 55/8000 ≈ **6.9ms per token** (batch=1)
- This is the real bottleneck. Compute is cheap; loading weights is expensive.
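A two-line roofline comparison makes the point concrete (using the B200 figures assumed above):

```python
FLOPS_PER_TOKEN = 46.2e9       # dense forward pass, from the per-layer math
PEAK_BF16 = 1.1e15             # assumed BF16 throughput
PARAMS, BYTES, BW = 27.5e9, 2, 8e12   # weights, BF16 bytes/param, ~8 TB/s HBM

compute_ms = FLOPS_PER_TOKEN / PEAK_BF16 * 1e3   # compute-bound floor
memory_ms = PARAMS * BYTES / BW * 1e3            # weight-loading floor
print(f"compute: {compute_ms:.3f} ms  memory: {memory_ms:.1f} ms")
# at batch=1, memory dominates by more than two orders of magnitude
```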
## What sparsity buys you
At **X% sparsity** (X% of the weights are exactly zero), you only need to load the remaining (100 - X)% of the weights:
| Sparsity | Params loaded | Time (batch=1) | Speedup |
|----------|--------------|-----------------|---------|
| 0% (dense) | 27.5B | 6.9ms | 1.0x |
| 50% | 13.8B | 3.4ms | 2.0x |
| 75% | 6.9B | 1.7ms | 4.0x |
| 90% | 2.8B | 0.7ms | 10x |
**This is the key insight**: inference at small batch sizes is memory-bandwidth-bound.
Sparse weights = fewer bytes to load = directly proportional speedup,
**IF** the sparse kernel can avoid the gather/scatter overhead.
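The table reduces to one formula. A small calculator under the same assumptions (27.5B params, BF16, ~8 TB/s):

```python
PARAMS, BYTES, BW = 27.5e9, 2, 8e12   # params, bytes/param, bandwidth (assumed)

def token_ms(sparsity: float) -> float:
    """Bandwidth-bound decode time at batch=1: bytes loaded / bandwidth."""
    return PARAMS * (1 - sparsity) * BYTES / BW * 1e3

dense = token_ms(0.0)   # ~6.9 ms
for s in (0.0, 0.5, 0.75, 0.9):
    print(f"{s:>4.0%} sparse: {token_ms(s):.1f} ms, {dense/token_ms(s):.1f}x")
```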
## The compilation problem
Dense GEMM is fast because:
1. Weights are contiguous in memory → sequential reads
2. Tiling fits perfectly in SRAM → high reuse
3. Hardware tensor cores expect dense blocks
Naive sparse matmul kills this:
- Irregular memory access → low bandwidth utilization
- Poor SRAM tiling → cache thrashing
- Tensor cores can't help
### The FlashAttention analogy
FlashAttention's insight: the N×N attention matrix doesn't fit in SRAM,
but you can tile it so each tile does. You recompute instead of materializing.
**Sparse kernel compilation insight**: the sparsity pattern is **known at compile time**.
A compiler can:
1. **Analyze the sparsity graph** of each weight matrix
2. **Find blocks** of non-zero weights that are close in memory
3. **Generate a tiling schedule** that loads these blocks into SRAM efficiently
4. **Emit fused kernels** where the memory access pattern is baked in as constants
The resulting kernel looks like a dense kernel to the hardware —
sequential reads, high SRAM reuse, maybe even tensor core compatible
(if the compiler finds dense sub-blocks within the sparse matrix).
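Step 2 in miniature: a compile-time scan that lists which fixed-size tiles of a weight matrix contain any non-zeros, i.e. the only tiles a generated kernel would need to load. All names and the tile size here are illustrative:

```python
import numpy as np

def nonzero_blocks(w: np.ndarray, bs: int = 16):
    """Return (tile_row, tile_col) coordinates of bs x bs tiles that
    contain at least one non-zero weight."""
    R, C = w.shape
    assert R % bs == 0 and C % bs == 0
    mask = np.abs(w).reshape(R // bs, bs, C // bs, bs).max(axis=(1, 3)) > 0
    return [(int(r), int(c)) for r, c in zip(*np.nonzero(mask))]

# Toy pattern: weight lives in two tiles; everything else can be skipped.
w = np.zeros((64, 64), dtype=np.float32)
w[0:16, 0:16] = 1.0
w[32:48, 16:32] = 1.0
print(nonzero_blocks(w))  # [(0, 0), (2, 1)]
```

In the compiled-kernel picture, this tile list is baked into the emitted kernel as constants, so the runtime loop has no indirection to chase.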
## Block sparsity vs unstructured sparsity
**Block sparse** (e.g., 4×4 or 16×16 blocks zeroed out):
- GPU-friendly: blocks map to tensor core operations
- Less flexible: coarser pruning granularity → less achievable sparsity
- NVIDIA's 2:4 structured sparsity fixes sparsity at exactly 50%, with tensor core support
- Real-world: typically 50-70% sparsity achievable without quality loss
**Unstructured sparse** (individual weights zeroed):
- Maximally flexible: fine-grained pruning → higher achievable sparsity
- GPU-hostile: the gather/scatter problem
- Real-world: 80-95% sparsity achievable in many layers without quality loss
**The compiled kernel approach bridges this**: take unstructured sparsity
(maximally flexible, high compression) and compile it into a kernel that
runs as efficiently as block-sparse. Best of both worlds.
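For contrast, the structured constraint is easy to state in code: 2:4 sparsity requires every group of 4 consecutive weights along a row to have at most 2 non-zeros. A quick checker sketch (function name is illustrative):

```python
import numpy as np

def fits_2to4(w: np.ndarray) -> bool:
    """True iff every contiguous group of 4 weights along a row has at
    most 2 non-zeros (the 2:4 structured-sparsity constraint)."""
    groups = (w.reshape(w.shape[0], -1, 4) != 0).sum(axis=-1)
    return bool((groups <= 2).all())

ok  = np.array([[1.0, 0.0, 2.0, 0.0]])   # 2 non-zeros per group of 4
bad = np.array([[1.0, 2.0, 3.0, 0.0]])   # 3 non-zeros: not 2:4-compatible
print(fits_2to4(ok), fits_2to4(bad))     # True False
```

Unstructured magnitude pruning rarely satisfies this constraint by accident, which is exactly why the higher sparsity levels in the measured table are unreachable with 2:4 alone.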
## Recurrent depth composability
From our April 10 discussion: middle transformer layers are doing
open-coded simulated annealing — similar weights, similar computation.
If 16 middle layers (say, layers 8-23) have pairwise cosine similarity > 0.95:
- Replace 16 layers with 1 layer × 16 iterations
- **Parameter reduction**: 16 × 361M = 5.8B → 361M (16x reduction for those layers)
- **Memory bandwidth**: load one layer's weights, iterate in SRAM
- Combined with 50% sparsity on the remaining unique layers:
- Unique layers (48): 48 × 361M × 0.5 = 8.7B params
- Recurrent layer: 361M × 0.5 = 180M params (but iterated 16x in SRAM)
- Total loaded per token: ~8.9B × 2 bytes = 17.8GB
- Time: 17.8/8000 ≈ **2.2ms per token** (vs 6.9ms dense) — **3.1x speedup**
With higher sparsity (75%) + recurrence:
- Unique layers: 48 × 361M × 0.25 = 4.3B
- Recurrent: 90M
- Total: ~4.4B × 2 = 8.8GB → **1.1ms per token**, a **6.3x speedup**
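The combined arithmetic above, as a small calculator (same assumed bandwidth and per-layer size; `density` = 1 minus sparsity):

```python
PER_LAYER, BYTES, BW = 361e6, 2, 8e12   # params/layer, BF16, ~8 TB/s (assumed)
DENSE_MS = 27.5e9 * BYTES / BW * 1e3    # ~6.9 ms dense baseline

def recurrent_ms(unique_layers: int, density: float) -> float:
    """Weights loaded per token: the unique layers plus ONE shared
    recurrent layer (iterated in SRAM, so loaded once), all pruned
    to the given density."""
    params = (unique_layers + 1) * PER_LAYER * density
    return params * BYTES / BW * 1e3

for d in (0.5, 0.25):
    ms = recurrent_ms(48, d)
    print(f"density {d}: {ms:.1f} ms/token, {DENSE_MS/ms:.1f}x")
```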
## What needs to happen
### Phase 1: Measure (can do now with B200 access)
1. Extract all weight matrices from Qwen3.5-27B
2. For each matrix, compute:
- Magnitude distribution (what % of weights are near-zero?)
- Achievable sparsity at various thresholds (L1 magnitude pruning)
- Dense sub-block statistics (how many 4×4, 16×16 blocks are all-zero?)
3. Layer similarity: pairwise cosine similarity of weight matrices across layers
- Which layers are nearly identical? (recurrence candidates)
4. Validate quality: run perplexity eval at various sparsity levels
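Step 3's screen is a one-liner per layer pair. A sketch on synthetic matrices (real use would load the actual per-layer weight tensors in place of `w8`, `w9`, `w40`, which are illustrative stand-ins):

```python
import numpy as np

def layer_cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two weight matrices, flattened."""
    x, y = a.ravel(), b.ravel()
    return float(x @ y / (np.linalg.norm(x) * np.linalg.norm(y)))

rng = np.random.default_rng(0)
w8 = rng.normal(size=(256, 256))
w9 = w8 + rng.normal(scale=0.05, size=w8.shape)   # a near-duplicate layer
w40 = rng.normal(size=(256, 256))                 # an unrelated layer
print(layer_cosine(w8, w9), layer_cosine(w8, w40))
```

Pairs scoring above the 0.95 threshold become recurrence candidates; unrelated layers score near zero.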
### Phase 2: Compile (research project)
1. For a single sparse weight matrix, generate an optimized Triton kernel
2. Benchmark vs dense GEMM and vs NVIDIA's 2:4 sparse
3. Iterate on the tiling strategy
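Before writing Triton, a CPU reference pins down what the compiled kernel must compute: an ahead-of-time pass packs the non-zero tiles contiguously, and the runtime loop touches only those. A numpy sketch (tile size and all names are illustrative, not the actual kernel):

```python
import numpy as np

BS = 16  # tile size the "compiled" schedule is expressed in

def compile_schedule(w: np.ndarray):
    """Ahead-of-time: record which BS x BS tiles are non-zero and pack
    their values contiguously (the layout a generated kernel would read)."""
    R, C = w.shape
    coords, data = [], []
    for r in range(0, R, BS):
        for c in range(0, C, BS):
            tile = w[r:r+BS, c:c+BS]
            if np.any(tile):
                coords.append((r, c))
                data.append(tile.copy())
    return coords, data

def sparse_matvec(coords, data, x, out_dim):
    """Runtime: touch only the scheduled tiles."""
    y = np.zeros(out_dim, dtype=x.dtype)
    for (r, c), tile in zip(coords, data):
        y[r:r+BS] += tile @ x[c:c+BS]
    return y

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 64)) * (rng.random((64, 64)) < 0.05)  # ~95% sparse
coords, data = compile_schedule(w)
x = rng.normal(size=64)
assert np.allclose(sparse_matvec(coords, data, x, 64), w @ x)
```

The Triton version would unroll the `coords` list into the kernel itself, turning the per-tile indirection into straight-line loads.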
### Phase 3: End-to-end
1. Full model with compiled sparse kernels
2. Perplexity + latency benchmarks
3. Compare: dense, 2:4 structured, compiled unstructured
## Related work to read
- **SparseGPT** (Frantar & Alistarh, 2023): one-shot pruning to 50-60% unstructured sparsity
with minimal quality loss. Key result: large models are more prunable than small ones.
- **Wanda** (Sun et al., 2023): pruning by weight magnitude × input activation.
Simpler than SparseGPT, comparable results.
- **NVIDIA 2:4 sparsity**: hardware-supported structured sparsity on Ampere+.
50% sparsity, ~2x speedup on tensor cores. The existence proof that sparse can be fast.
- **Triton** (Tillet et al.): Python DSL for GPU kernel generation.
The right compilation target — can express arbitrary tiling strategies.
- **TACO** (Kjolstad et al.): tensor algebra compiler. Generates kernels for
specific sparse tensor formats. Academic but the ideas are right.
- **FlashAttention** (Dao et al.): the tiling strategy to learn from.
- **DejaVu** (Liu et al., 2023): contextual sparsity — predicting which neurons
to activate per input. Dynamic sparsity, complementary to weight sparsity.
## The bigger picture
Current state of the art: dense models with FlashAttention for the N×N attention part.
Weight sparsity is known to work (SparseGPT, Wanda) but isn't deployed because
the GPU kernels don't exist to exploit it efficiently.
The gap: nobody has built a compiler that takes a specific sparse weight matrix
and emits a kernel optimized for that exact pattern. FlashAttention proved
that custom kernels for specific computational patterns beat general-purpose ones.
The same should hold for sparse weight patterns.
**The bet**: a compiled sparse kernel for Qwen3.5-27B's actual sparsity pattern
would be within 80% of the theoretical bandwidth-bound speedup. If true,
50% sparsity → 1.6x real speedup, 75% → 3.2x, composing with recurrent depth
for potentially 5-6x total.
That would make 27B inference as fast as a 5B dense model, with 27B quality.