# Sparse Kernel Compilation: Napkin Math for Qwen3.5-27B

## Architecture recap

| Parameter | Value |
|-----------|-------|
| Layers | 64 (48 linear attention, 16 full attention) |
| Hidden dim (H) | 5,120 |
| Query heads | 24 |
| KV heads | 4 (GQA) |
| Head dim | 256 |
| FFN intermediate | 17,408 |
| Full attention interval | every 4th layer |
| Total params | ~27.5B |

**Key discovery**: Qwen3.5 already uses sparse attention — 48/64 layers use
linear attention (O(N)), only 16 use full O(N^2) attention. The attention
sparsity question is partially answered by the architecture itself. The
remaining opportunity is **weight sparsity** in the projection and FFN matrices.

## Measured weight sparsity (from B200, all 11 shards)

Each threshold column gives the fraction of weights whose absolute value falls below that threshold.

| Component | Count | Total params | <0.001 | <0.005 | <0.01 | <0.02 |
|-----------|-------|--------------|--------|--------|-------|-------|
| down_proj | 65 | 5,793,382,400 | 7.4% | 35.8% | 64.3% | 92.8% |
| gate_proj | 65 | 5,793,382,400 | 7.2% | 35.0% | 63.3% | 92.3% |
| up_proj | 65 | 5,793,382,400 | 7.3% | 35.2% | 63.5% | 92.5% |
| q_proj | 17 | 1,069,547,520 | 5.7% | 27.9% | 51.9% | 82.8% |
| k_proj | 17 | 89,128,960 | 6.0% | 29.4% | 54.1% | 84.1% |
| v_proj | 17 | 89,128,960 | 4.8% | 23.7% | 44.7% | 74.2% |
| o_proj | 17 | 534,773,760 | 5.3% | 25.9% | 48.7% | 79.6% |
| other | 605 | 6,070,652,272 | 5.4% | 26.4% | 49.6% | 80.9% |

**Key findings**:
- FFN layers (gate/up/down) are remarkably sparse: ~64% of weights below 0.01
- At threshold 0.02, FFN sparsity exceeds 92%
- Attention projections are less sparse but still significant: 45-54% below 0.01
- v_proj is the least sparse component (44.7% below 0.01)

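These fractions can be reproduced with a short script. A minimal sketch, pooling counts over all matrices of one component (the function name is mine; real `safetensors` loading is shown only as a comment since shard and tensor names vary):

```python
import numpy as np

THRESHOLDS = (0.001, 0.005, 0.01, 0.02)

def sparsity_profile(weights):
    """Fraction of weights below each magnitude threshold, pooled
    over a list of arrays (e.g. every down_proj in the model)."""
    total = sum(w.size for w in weights)
    counts = {t: 0 for t in THRESHOLDS}
    for w in weights:
        a = np.abs(w.astype(np.float32))
        for t in THRESHOLDS:
            counts[t] += int((a < t).sum())
    return {t: counts[t] / total for t in THRESHOLDS}

# Loading real shards might look like (safetensors package assumed):
#   from safetensors import safe_open
#   with safe_open("model-00001-of-00011.safetensors", framework="np") as f:
#       w = f.get_tensor("model.layers.0.mlp.down_proj.weight")
```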
## Per-layer parameter breakdown

- **QKV projection**: H × (Q_dim + K_dim + V_dim) ≈ 5120 × 13312 ≈ 68M
- **Output projection**: ~26M
- **FFN (gate + up + down)**: 5120 × 17408 × 3 ≈ 267M
- **Total per layer**: ~361M
- **64 layers**: ~23.1B (rest is embeddings, norms, etc.)

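The arithmetic above as executable napkin math (the combined QKV output dim of 13,312 and the ~26M output projection are taken as given from the note):

```python
H, FFN_DIM, LAYERS = 5_120, 17_408, 64

qkv = H * 13_312        # QKV projection (combined output dim, per the note)
o_proj = 26_000_000     # output projection, ~26M as estimated above
ffn = 3 * H * FFN_DIM   # gate + up + down
per_layer = qkv + o_proj + ffn

print(f"per layer: {per_layer / 1e6:.1f}M")             # ~361.5M
print(f"all layers: {LAYERS * per_layer / 1e9:.1f}B")   # ~23.1B
```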
## Dense baseline: FLOPs per token

Each weight parameter contributes 2 FLOPs per token (multiply + accumulate).

- Per layer: ~722M FLOPs/token
- 64 layers: **~46.2B FLOPs/token**

On a B200 (theoretical ~4.5 PFLOPS FP8, ~1.1 PFLOPS BF16):

- BF16 compute time: 46.2B / 1.1e15 ≈ 0.042ms per token (compute-bound limit)
- But inference is usually **memory-bandwidth-bound** at small batch sizes

B200 HBM bandwidth: ~8 TB/s. At BF16 (2 bytes/param):

- Loading all weights once: 27.5B × 2 = 55GB → 55/8000 ≈ **6.9ms per token** (batch=1)
- This is the real bottleneck. Compute is cheap; loading weights is expensive.

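The compute-bound vs bandwidth-bound comparison in one snippet (all constants are the napkin figures from above):

```python
FLOPS_PER_TOKEN = 46.2e9
BF16_FLOPS = 1.1e15        # ~1.1 PFLOPS BF16, theoretical
HBM_BW = 8e12              # ~8 TB/s
PARAMS, BYTES = 27.5e9, 2  # BF16

t_compute = FLOPS_PER_TOKEN / BF16_FLOPS * 1e3  # ms/token, compute-bound
t_memory = PARAMS * BYTES / HBM_BW * 1e3        # ms/token, bandwidth-bound
print(f"compute {t_compute:.3f} ms vs memory {t_memory:.1f} ms")
# the bandwidth-bound time dominates by roughly 160x at batch=1
```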
## What sparsity buys you

At **X% sparsity** (X% of weights are zero), you only need to load (100-X)% of the weights:

| Sparsity | Params loaded | Time (batch=1) | Speedup |
|----------|---------------|----------------|---------|
| 0% (dense) | 27.5B | 6.9ms | 1.0x |
| 50% | 13.8B | 3.4ms | 2.0x |
| 75% | 6.9B | 1.7ms | 4.0x |
| 90% | 2.8B | 0.7ms | 10x |

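The table falls straight out of the bandwidth model:

```python
HBM_BW_GBPS = 8_000       # ~8 TB/s
DENSE_PARAMS = 27.5e9

def ms_per_token(sparsity):
    """Bandwidth-bound latency at batch=1, BF16 (2 bytes/param)."""
    loaded = DENSE_PARAMS * (1 - sparsity)
    return loaded * 2 / 1e9 / HBM_BW_GBPS * 1e3

for s in (0.0, 0.5, 0.75, 0.9):
    t = ms_per_token(s)
    print(f"{s:>4.0%}: {t:.1f} ms/token, {ms_per_token(0.0) / t:.1f}x")
```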
**This is the key insight**: inference at small batch sizes is memory-bandwidth-bound.
Sparse weights = fewer bytes to load = directly proportional speedup,
**IF** the sparse kernel can avoid the gather/scatter overhead.

## The compilation problem

Dense GEMM is fast because:
1. Weights are contiguous in memory → sequential reads
2. Tiling fits perfectly in SRAM → high reuse
3. Hardware tensor cores expect dense blocks

Naive sparse matmul kills all three:
- Irregular memory access → low bandwidth utilization
- Poor SRAM tiling → cache thrashing
- Tensor cores can't help

### The FlashAttention analogy

FlashAttention's insight: the N×N attention matrix doesn't fit in SRAM,
but you can tile it so each tile does. You recompute instead of materializing.

**Sparse kernel compilation insight**: the sparsity pattern is **known at compile time**.
A compiler can:

1. **Analyze the sparsity graph** of each weight matrix
2. **Find blocks** of non-zero weights that are close in memory
3. **Generate a tiling schedule** that loads these blocks into SRAM efficiently
4. **Emit fused kernels** where the memory access pattern is baked in as constants

The resulting kernel looks like a dense kernel to the hardware —
sequential reads, high SRAM reuse, maybe even tensor core compatible
(if the compiler finds dense sub-blocks within the sparse matrix).

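Steps 2 and 3 can be sketched in a few lines. A toy tile scheduler (illustrative only; the tile size, density cutoff, and function name are made up here, and a real compiler would emit the result as constants in the kernel source):

```python
import numpy as np

def tile_schedule(W, tile=16, min_density=0.5):
    """Return (row, col) offsets of tiles dense enough to keep.

    A compiled kernel would iterate exactly this list -- a static,
    contiguous load pattern -- and skip every other tile entirely.
    """
    keep = []
    for r in range(0, W.shape[0], tile):
        for c in range(0, W.shape[1], tile):
            blk = W[r:r + tile, c:c + tile]
            if np.count_nonzero(blk) / blk.size >= min_density:
                keep.append((r, c))
    return keep
```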
## Block sparsity vs unstructured sparsity

**Block sparse** (e.g., 4×4 or 16×16 blocks zeroed out):
- GPU-friendly: blocks map to tensor core operations
- Less flexible: coarser pruning granularity → less achievable sparsity
- NVIDIA's 2:4 structured sparsity gets ~50% sparsity with tensor core support
- Real-world: typically 50-70% sparsity achievable without quality loss

**Unstructured sparse** (individual weights zeroed):
- Maximally flexible: fine-grained pruning → higher achievable sparsity
- GPU-hostile: the gather/scatter problem
- Real-world: 80-95% sparsity achievable in many layers without quality loss

**The compiled kernel approach bridges this**: take unstructured sparsity
(maximally flexible, high compression) and compile it into a kernel that
runs as efficiently as block-sparse. Best of both worlds.

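For reference, the 2:4 constraint mentioned above is easy to state precisely: at most 2 non-zeros in every aligned group of 4 weights along each row. A quick checker (a hypothetical helper of mine, not NVIDIA's API):

```python
import numpy as np

def satisfies_2_to_4(W):
    """True if every aligned group of 4 weights in each row has at
    most 2 non-zeros (row length must be a multiple of 4)."""
    groups = W.reshape(W.shape[0], -1, 4)
    return bool((np.count_nonzero(groups, axis=2) <= 2).all())
```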
## Recurrent depth composability

From our April 10 discussion: middle transformer layers are doing
open-coded simulated annealing — similar weights, similar computation.

If layers 8-23 (16 consecutive layers) have cosine similarity > 0.95:
- Replace those 16 layers with 1 layer × 16 iterations
- **Parameter reduction**: 16 × 361M = 5.8B → 361M (16x reduction for those layers)
- **Memory bandwidth**: load one layer's weights, iterate in SRAM
- Combined with 50% sparsity on the remaining unique layers:
  - Unique layers (48): 48 × 361M × 0.5 = 8.7B params
  - Recurrent layer: 361M × 0.5 = 180M params (but iterated 16x in SRAM)
  - Total loaded per token: ~8.9B × 2 bytes = 17.8GB
  - Time: 17.8/8000 ≈ **2.2ms per token** (vs 6.9ms dense) — **3.1x speedup**

With higher sparsity (75%) + recurrence:
- Unique layers: 48 × 361M × 0.25 = 4.3B
- Recurrent: 90M
- Total: ~4.4B × 2 = 8.8GB → **1.1ms per token** — **6.3x speedup**

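Both scenarios in code (the recurrent layer is loaded once per token and iterated in SRAM, so it contributes a single sparse layer's worth of bytes; constants are the napkin figures from above):

```python
HBM_BW_GBPS = 8_000
PER_LAYER = 361e6   # params per layer

def ms_per_token(sparsity, unique_layers=48):
    unique = unique_layers * PER_LAYER * (1 - sparsity)
    recurrent = PER_LAYER * (1 - sparsity)  # loaded once, iterated 16x in SRAM
    gb = (unique + recurrent) * 2 / 1e9     # BF16
    return gb / HBM_BW_GBPS * 1e3

print(f"50% sparsity + recurrence: {ms_per_token(0.5):.1f} ms/token")   # ~2.2
print(f"75% sparsity + recurrence: {ms_per_token(0.75):.1f} ms/token")  # ~1.1
```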
## What needs to happen

### Phase 1: Measure (can do now with B200 access)
1. Extract all weight matrices from Qwen3.5-27B
2. For each matrix, compute:
   - Magnitude distribution (what % of weights are near-zero?)
   - Achievable sparsity at various thresholds (L1 magnitude pruning)
   - Dense sub-block statistics (how many 4×4, 16×16 blocks are all-zero?)
3. Layer similarity: pairwise cosine similarity of weight matrices across layers
   - Which layers are nearly identical? (recurrence candidates)
4. Validate quality: run perplexity eval at various sparsity levels

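Step 3's similarity matrix is a few lines of numpy. A sketch, where each entry of `weights` is one layer's matrix for a given component (e.g. every layer's `down_proj`):

```python
import numpy as np

def layer_similarity(weights):
    """Pairwise cosine similarity between flattened per-layer weight
    matrices. Off-diagonal values near 1.0 flag recurrence candidates."""
    flat = np.stack([w.ravel() / np.linalg.norm(w) for w in weights])
    return flat @ flat.T
```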
### Phase 2: Compile (research project)
1. For a single sparse weight matrix, generate an optimized Triton kernel
2. Benchmark vs dense GEMM and vs NVIDIA's 2:4 sparse
3. Iterate on the tiling strategy

### Phase 3: End-to-end
1. Full model with compiled sparse kernels
2. Perplexity + latency benchmarks
3. Compare: dense, 2:4 structured, compiled unstructured

## Related work to read

- **SparseGPT** (Frantar & Alistarh, 2023): one-shot pruning to 50-60% unstructured sparsity
  with minimal quality loss. Key result: large models are more prunable than small ones.
- **Wanda** (Sun et al., 2023): pruning by weight magnitude × input activation.
  Simpler than SparseGPT, comparable results.
- **NVIDIA 2:4 sparsity**: hardware-supported structured sparsity on Ampere+.
  50% sparsity, ~2x speedup on tensor cores. The existence proof that sparse can be fast.
- **Triton** (Tillet et al.): Python DSL for GPU kernel generation.
  The right compilation target — can express arbitrary tiling strategies.
- **TACO** (Kjolstad et al.): tensor algebra compiler. Generates kernels for
  specific sparse tensor formats. Academic but the ideas are right.
- **FlashAttention** (Dao et al.): the tiling strategy to learn from.
- **DejaVu** (Liu et al., 2023): contextual sparsity — predicting which neurons
  to activate per input. Dynamic sparsity, complementary to weight sparsity.

## The bigger picture

Current state of the art: dense models with FlashAttention for the N×N attention part.
Weight sparsity is known to work (SparseGPT, Wanda) but isn't deployed because
the GPU kernels don't exist to exploit it efficiently.

The gap: nobody has built a compiler that takes a specific sparse weight matrix
and emits a kernel optimized for that exact pattern. FlashAttention proved
that custom kernels for specific computational patterns beat general-purpose ones.
The same should hold for sparse weight patterns.

**The bet**: a compiled sparse kernel for Qwen3.5-27B's actual sparsity pattern
would get within 80% of the theoretical bandwidth-bound speedup. If true,
50% sparsity → 1.6x real speedup, 75% → 3.2x, composing with recurrent depth
for potentially 5-6x total.

That would make 27B inference as fast as a 5B dense model, with 27B quality.
|