# Sparse Kernel Compilation: Napkin Math for Qwen3.5-27B
## Architecture recap

| Parameter | Value |
|-----------|-------|
| Layers | 64 (48 linear attention, 16 full attention) |
| Hidden dim (H) | 5,120 |
| Query heads | 24 |
| KV heads | 4 (GQA) |
| Head dim | 256 |
| FFN intermediate | 17,408 |
| Full attention interval | every 4th layer |
| Total params | ~27.5B |

**Key discovery**: Qwen3.5 already uses sparse attention — 48/64 layers use
linear attention (O(N)), only 16 use full O(N^2) attention. The attention
sparsity question is partially answered by the architecture itself. The
remaining opportunity is **weight sparsity** in the projection and FFN matrices.
## Measured weight sparsity (from B200, all 11 shards)

Each threshold column reports the fraction of weights whose absolute value falls below that threshold.

| Component | Count | Total params | <0.001 | <0.005 | <0.01 | <0.02 |
|-----------|-------|-------------|--------|--------|-------|-------|
| down_proj | 65 | 5,793,382,400 | 7.4% | 35.8% | 64.3% | 92.8% |
| gate_proj | 65 | 5,793,382,400 | 7.2% | 35.0% | 63.3% | 92.3% |
| up_proj | 65 | 5,793,382,400 | 7.3% | 35.2% | 63.5% | 92.5% |
| q_proj | 17 | 1,069,547,520 | 5.7% | 27.9% | 51.9% | 82.8% |
| k_proj | 17 | 89,128,960 | 6.0% | 29.4% | 54.1% | 84.1% |
| v_proj | 17 | 89,128,960 | 4.8% | 23.7% | 44.7% | 74.2% |
| o_proj | 17 | 534,773,760 | 5.3% | 25.9% | 48.7% | 79.6% |
| other | 605 | 6,070,652,272 | 5.4% | 26.4% | 49.6% | 80.9% |

**Key findings**:
- FFN layers (gate/up/down) are remarkably sparse: ~64% of weights fall below 0.01
- At threshold 0.02, FFN sparsity exceeds 92%
- Attention projections are less sparse but still significant: 45-54% below 0.01
- v_proj is the least sparse component (44.7% below 0.01)
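
A minimal sketch of how these fractions can be recomputed from the checkpoint shards. It assumes safetensors shards; the directory path and the component-name matching are illustrative, not the exact measurement script.

```python
# Sketch: per-component fraction of weights below magnitude thresholds.
# Shard paths and component-name matching are illustrative assumptions.
import glob
from collections import defaultdict

from safetensors import safe_open

THRESHOLDS = [1e-3, 5e-3, 1e-2, 2e-2]
COMPONENTS = ["down_proj", "gate_proj", "up_proj", "q_proj", "k_proj", "v_proj", "o_proj"]

totals = defaultdict(lambda: {"count": 0, "params": 0, "below": [0] * len(THRESHOLDS)})

for shard in sorted(glob.glob("qwen3.5-27b/*.safetensors")):  # hypothetical path
    with safe_open(shard, framework="pt", device="cpu") as f:
        for name in f.keys():
            w = f.get_tensor(name).float().abs()
            comp = next((c for c in COMPONENTS if c in name), "other")
            stats = totals[comp]
            stats["count"] += 1
            stats["params"] += w.numel()
            for i, t in enumerate(THRESHOLDS):
                stats["below"][i] += (w < t).sum().item()

for comp, s in totals.items():
    fracs = ", ".join(f"<{t:g}: {b / s['params']:.1%}" for t, b in zip(THRESHOLDS, s["below"]))
    print(f"{comp}: count={s['count']} params={s['params']:,} {fracs}")
```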
## Per-layer parameter breakdown

- **QKV projection**: H × (Q_dim + K_dim + V_dim) ≈ 5120 × 13312 ≈ 68M
- **Output projection**: ~26M
- **FFN (gate + up + down)**: 5120 × 17408 × 3 ≈ 267M
- **Total per layer**: ~361M
- **64 layers**: ~23.1B (rest is embeddings, norms, etc.)
## Dense baseline: FLOPs per token

Each weight parameter contributes 2 FLOPs per token (multiply + accumulate).

- Per layer: ~722M FLOPs/token
- 64 layers: **~46.2B FLOPs/token**

On a B200 (theoretical ~4.5 PFLOPS FP8, ~1.1 PFLOPS BF16):
- BF16 throughput: 46.2B / 1.1e15 ≈ 0.042ms per token (compute-bound limit)
- But inference is usually **memory-bandwidth-bound** for small batch sizes

B200 HBM bandwidth: ~8 TB/s. At BF16 (2 bytes/param):
- Loading all weights once: 27.5B × 2 = 55GB → 55/8000 ≈ **6.9ms per token** (batch=1)
- This is the real bottleneck. Compute is cheap; loading weights is expensive.
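
The same napkin math in a few lines of Python, using the assumptions above (27.5B params, BF16, ~8 TB/s); it also reproduces the table in the next section.

```python
# Napkin math: memory-bandwidth-bound time per token at batch=1.
PARAMS = 27.5e9          # total parameters
BYTES_PER_PARAM = 2      # BF16
HBM_BW = 8e12            # ~8 TB/s on B200 (assumed)

def ms_per_token(sparsity: float) -> float:
    """Time to stream the non-zero weights once through HBM, in milliseconds."""
    bytes_loaded = PARAMS * (1 - sparsity) * BYTES_PER_PARAM
    return bytes_loaded / HBM_BW * 1e3

dense = ms_per_token(0.0)
for s in (0.0, 0.5, 0.75, 0.9):
    t = ms_per_token(s)
    print(f"{s:>4.0%} sparsity: {t:.1f} ms/token  ({dense / t:.1f}x vs dense)")
# 0% ≈ 6.9 ms, 50% ≈ 3.4 ms, 75% ≈ 1.7 ms, 90% ≈ 0.7 ms
```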
## What sparsity buys you

At **X% sparsity** (X% of weights are zero), you need to load only (100 − X)% of the weights:

| Sparsity | Params loaded | Time (batch=1) | Speedup |
|----------|--------------|-----------------|---------|
| 0% (dense) | 27.5B | 6.9ms | 1.0x |
| 50% | 13.8B | 3.4ms | 2.0x |
| 75% | 6.9B | 1.7ms | 4.0x |
| 90% | 2.8B | 0.7ms | 10x |

**This is the key insight**: inference at small batch sizes is memory-bandwidth-bound.
Sparse weights = fewer bytes to load = directly proportional speedup,
**IF** the sparse kernel can avoid the gather/scatter overhead.
## The compilation problem

Dense GEMM is fast because:
1. Weights are contiguous in memory → sequential reads
2. Tiling fits perfectly in SRAM → high reuse
3. Hardware tensor cores expect dense blocks

Naive sparse matmul kills this:
- Irregular memory access → low bandwidth utilization
- Poor SRAM tiling → cache thrashing
- Tensor cores can't help
### The FlashAttention analogy

FlashAttention's insight: the N×N attention matrix doesn't fit in SRAM,
but you can tile it so each tile does. You recompute instead of materializing.

**Sparse kernel compilation insight**: the sparsity pattern is **known at compile time**.
A compiler can:

1. **Analyze the sparsity graph** of each weight matrix
2. **Find blocks** of non-zero weights that are close in memory
3. **Generate a tiling schedule** that loads these blocks into SRAM efficiently
4. **Emit fused kernels** where the memory access pattern is baked in as constants

The resulting kernel looks like a dense kernel to the hardware —
sequential reads, high SRAM reuse, maybe even tensor core compatible
(if the compiler finds dense sub-blocks within the sparse matrix).
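
A minimal sketch of steps 2-3 under simplifying assumptions (fixed-size tiles, a block-CSR layout; a real compiler would also search over block shapes and reorder rows/columns):

```python
# Sketch: build a compile-time block schedule from a pruned weight matrix.
# Tiles whose weights are all zero are skipped; surviving tiles are packed
# contiguously so the runtime kernel only does sequential reads.
import torch

def build_block_schedule(w: torch.Tensor, bm: int = 64, bk: int = 64):
    """Return (packed_blocks, block_cols, block_ptr) in a block-CSR layout.

    block_ptr[i] .. block_ptr[i+1] indexes the non-zero column-tiles of
    row-tile i; packed_blocks holds those tiles back to back.
    """
    n, k = w.shape
    assert n % bm == 0 and k % bk == 0, "pad the matrix to tile boundaries first"
    tiles = w.reshape(n // bm, bm, k // bk, bk).permute(0, 2, 1, 3)  # [NT, KT, bm, bk]
    nonzero = tiles.abs().amax(dim=(2, 3)) > 0                        # [NT, KT]

    packed, block_cols, block_ptr = [], [], [0]
    for i in range(nonzero.shape[0]):
        cols = torch.nonzero(nonzero[i], as_tuple=False).flatten()
        block_cols.append(cols)
        packed.append(tiles[i, cols])            # keep only non-empty tiles
        block_ptr.append(block_ptr[-1] + cols.numel())

    return (torch.cat(packed).contiguous(),
            torch.cat(block_cols),
            torch.tensor(block_ptr))

# Example on a synthetic 90%-sparse matrix:
w = torch.randn(512, 512) * (torch.rand(512, 512) > 0.9)
blocks, cols, ptr = build_block_schedule(w)
kept = blocks.shape[0] / ((512 // 64) ** 2)
print(f"kept {kept:.0%} of tiles")  # unstructured sparsity rarely empties whole tiles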
## Block sparsity vs unstructured sparsity

**Block sparse** (e.g., 4×4 or 16×16 blocks zeroed out):
- GPU-friendly: blocks map to tensor core operations
- Less flexible: coarser pruning granularity → less achievable sparsity
- NVIDIA's 2:4 structured sparsity gets ~50% sparsity with tensor core support
- Real-world: typically 50-70% sparsity achievable without quality loss

**Unstructured sparse** (individual weights zeroed):
- Maximally flexible: fine-grained pruning → higher achievable sparsity
- GPU-hostile: the gather/scatter problem
- Real-world: 80-95% sparsity achievable in many layers without quality loss

**The compiled kernel approach bridges this**: take unstructured sparsity
(maximally flexible, high compression) and compile it into a kernel that
runs as efficiently as block-sparse. Best of both worlds.
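
One way to quantify the gap for a given matrix: at a fixed magnitude threshold, compare element-level sparsity with the fraction of whole b×b blocks that are entirely prunable. A small sketch; the threshold, block sizes, and synthetic input are placeholders:

```python
# Sketch: element-level vs block-level sparsity for one weight matrix.
# Blocks only help if entire b×b tiles are prunable; unstructured pruning
# usually achieves far higher element sparsity than block sparsity.
import torch

def sparsity_report(w: torch.Tensor, threshold: float = 0.01, block_sizes=(4, 16)):
    mask = w.abs() < threshold                      # prunable elements
    print(f"element sparsity @ {threshold}: {mask.float().mean().item():.1%}")
    for b in block_sizes:
        n, k = (w.shape[0] // b) * b, (w.shape[1] // b) * b  # crop to a multiple of b
        m = mask[:n, :k].reshape(n // b, b, k // b, b)
        all_zero = m.all(dim=3).all(dim=1)          # block prunable iff every element is
        print(f"  {b}x{b} blocks fully prunable: {all_zero.float().mean().item():.1%}")

sparsity_report(torch.randn(1024, 1024) * 0.02)     # synthetic stand-in for a real layer
```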
## Recurrent depth composability

From our April 10 discussion: middle transformer layers are doing
open-coded simulated annealing — similar weights, similar computation.

If layers 8-24 have cosine similarity > 0.95:
- Replace 16 layers with 1 layer × 16 iterations
- **Parameter reduction**: 16 × 361M = 5.8B → 361M (16x reduction for those layers)
- **Memory bandwidth**: load one layer's weights, iterate in SRAM
- Combined with 50% sparsity on the remaining unique layers:
  - Unique layers (48): 48 × 361M × 0.5 = 8.7B params
  - Recurrent layer: 361M × 0.5 = 180M params (but iterated 16x in SRAM)
  - Total loaded per token: ~8.9B × 2 bytes = 17.8GB
  - Time: 17.8/8000 ≈ **2.2ms per token** (vs 6.9ms dense) — **3.1x speedup**

With higher sparsity (75%) + recurrence:
- Unique layers: 48 × 361M × 0.25 = 4.3B
- Recurrent: 90M
- Total: ~4.4B × 2 = 8.8GB → **1.1ms per token** — **6.3x speedup**
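
The same two scenarios as a parametric napkin calc (361M/layer from the breakdown above; the recurrent layer is counted once because it stays resident in SRAM):

```python
# Napkin math: recurrence (shared middle layer loaded once) composed with sparsity.
PER_LAYER = 361e6      # params per layer, from the breakdown above
HBM_BW = 8e12          # ~8 TB/s

def ms_per_token(unique_layers: int, sparsity: float) -> float:
    # +1 accounts for the single recurrent layer: loaded once, iterated in
    # SRAM, so it contributes only one layer's worth of HBM traffic.
    loaded_params = (unique_layers + 1) * PER_LAYER * (1 - sparsity)
    return loaded_params * 2 / HBM_BW * 1e3        # BF16, milliseconds

for s in (0.5, 0.75):
    print(f"48 unique layers, {s:.0%} sparsity: {ms_per_token(48, s):.1f} ms/token")
# ≈2.2 ms at 50%, ≈1.1 ms at 75% (vs ≈6.9 ms loading the full dense 27.5B)
```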
## What needs to happen

### Phase 1: Measure (can do now with B200 access)
1. Extract all weight matrices from Qwen3.5-27B
2. For each matrix, compute:
   - Magnitude distribution (what % of weights are near-zero?)
   - Achievable sparsity at various thresholds (L1 magnitude pruning)
   - Dense sub-block statistics (how many 4×4, 16×16 blocks are all-zero?)
3. Layer similarity: pairwise cosine similarity of weight matrices across layers
   - Which layers are nearly identical? (recurrence candidates)
4. Validate quality: run perplexity eval at various sparsity levels
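
A sketch for item 3, assuming the usual HF `model.layers.{i}...` tensor naming (the `weights` dict would come from loading the shards as in the measurement sketch above):

```python
# Sketch: which layers look like recurrence candidates?
# Compares the flattened down_proj weights across all layer pairs; tensor
# names assume the usual HF "model.layers.{i}..." convention.
import torch
import torch.nn.functional as F

def layer_cosine_similarity(weights: dict, n_layers: int = 64,
                            key: str = "model.layers.{}.mlp.down_proj.weight"):
    flats = [weights[key.format(i)].flatten().float() for i in range(n_layers)]
    sims = torch.empty(n_layers, n_layers)
    for i in range(n_layers):
        for j in range(n_layers):
            sims[i, j] = F.cosine_similarity(flats[i], flats[j], dim=0)
    return sims

# Usage (weights loaded e.g. with safetensors, as in the measurement sketch):
# sims = layer_cosine_similarity(weights)
# candidates = (sims > 0.95).nonzero()   # pairs of near-identical layers
```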
### Phase 2: Compile (research project)
1. For a single sparse weight matrix, generate an optimized Triton kernel
2. Benchmark vs dense GEMM and vs NVIDIA's 2:4 sparse
3. Iterate on the tiling strategy
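
Not the deliverable itself, but a minimal Triton starting point for item 1: a block-sparse matrix-vector kernel that walks a precomputed block-CSR schedule (like the one built in the earlier sketch). Block sizes, dtypes, and the GEMV (rather than GEMM) shape are simplifying assumptions:

```python
# Sketch: block-sparse y = W @ x using a precomputed block-CSR schedule.
# One program per row-tile; it walks only that row's non-zero column-tiles,
# which are packed contiguously, so the weight loads stay sequential.
import triton
import triton.language as tl

@triton.jit
def block_sparse_gemv(y_ptr, x_ptr, blocks_ptr, block_cols_ptr, block_ptr_ptr,
                      BM: tl.constexpr, BK: tl.constexpr):
    pid = tl.program_id(0)
    start = tl.load(block_ptr_ptr + pid)
    end = tl.load(block_ptr_ptr + pid + 1)
    rm = tl.arange(0, BM)
    rk = tl.arange(0, BK)
    acc = tl.zeros((BM,), dtype=tl.float32)
    for b in range(start, end):
        col = tl.load(block_cols_ptr + b)
        # Load one packed BM x BK tile and the matching slice of x.
        w = tl.load(blocks_ptr + b * BM * BK + rm[:, None] * BK + rk[None, :])
        xv = tl.load(x_ptr + col * BK + rk)
        acc += tl.sum(w.to(tl.float32) * xv.to(tl.float32)[None, :], axis=1)
    tl.store(y_ptr + pid * BM + rm, acc)

# Launch: one program per row-tile of W.
# blocks, cols, ptr = build_block_schedule(w, bm=64, bk=64)   # earlier sketch
# grid = (w.shape[0] // 64,)
# block_sparse_gemv[grid](y, x, blocks, cols, ptr, BM=64, BK=64)
```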
### Phase 3: End-to-end
1. Full model with compiled sparse kernels
2. Perplexity + latency benchmarks
3. Compare: dense, 2:4 structured, compiled unstructured
## Related work to read
- **SparseGPT** (Frantar & Alistarh, 2023): one-shot pruning to 50-60% unstructured sparsity
  with minimal quality loss. Key result: large models are more prunable than small ones.
- **Wanda** (Sun et al., 2023): pruning by weight magnitude × input activation.
  Simpler than SparseGPT, comparable results.
- **NVIDIA 2:4 sparsity**: hardware-supported structured sparsity on Ampere+.
  50% sparsity, ~2x speedup on tensor cores. The existence proof that sparse can be fast.
- **Triton** (Tillet et al.): Python DSL for GPU kernel generation.
  The right compilation target — can express arbitrary tiling strategies.
- **TACO** (Kjolstad et al.): tensor algebra compiler. Generates kernels for
  specific sparse tensor formats. Academic but the ideas are right.
- **FlashAttention** (Dao et al.): the tiling strategy to learn from.
- **DejaVu** (Liu et al., 2023): contextual sparsity — predicting which neurons
  to activate per input. Dynamic sparsity, complementary to weight sparsity.
## The bigger picture

Current state of the art: dense models with FlashAttention for the N×N attention part.
Weight sparsity is known to work (SparseGPT, Wanda) but isn't deployed because
the GPU kernels don't exist to exploit it efficiently.

The gap: nobody has built a compiler that takes a specific sparse weight matrix
and emits a kernel optimized for that exact pattern. FlashAttention proved
that custom kernels for specific computational patterns beat general-purpose ones.
The same should hold for sparse weight patterns.

**The bet**: a compiled sparse kernel for Qwen3.5-27B's actual sparsity pattern
would achieve at least 80% of the theoretical bandwidth-bound speedup. If true,
50% sparsity → 1.6x real speedup, 75% → 3.2x, composing with recurrent depth
for potentially 5-6x total.

That would make 27B inference as fast as a 5B dense model, with 27B quality.