scripts: FP8 quantize Qwen3.6-27B for vLLM (multimodal + MTP)

Quantization recipe targeting the multimodal Qwen3.6-27B for vLLM
serving. Three pitfalls the script avoids, each documented inline:

1. Loader strip: `AutoModelForCausalLM` silently drops the vision
   tower; we load via the config-declared
   `Qwen3_5ForConditionalGeneration` instead.

2. Pattern anchor: llmcompressor matches the `ignore` list against
   module names (no `.weight` suffix) when walking `named_modules()`,
   not against full tensor names. Patterns now anchor on `$` at the
   module name; the earlier `\.weight$` form silently quantized
   lm_head and every linear_attn projection.

3. vLLM fusion: vLLM fuses {q,k,v}_proj into qkv_proj, gate+up into
   gate_up_proj, and in_proj_qkv+in_proj_z into in_proj_qkvz. The
   compressed_tensors loader rejects mixed schemes within a fused
   layer, so the `ignore` list is shaped to keep all sub-components
   of a fused layer consistent.

After `oneshot()` writes the FP8 output, MTP tensors (which the HF
class doesn't expose) are spliced in at BF16 from the upstream cached
snapshot, with the compressed_tensors metadata header preserved.

Recipe follows Unsloth's UD-Q8_K_XL late-stack overrides (FFN: 50,
51, 59, 62, 63; ATTN: 51, 59, 63), extended to include `v_proj` for
fusion compat. Final checkpoint is ~35 GB (matches Unsloth's GGUF
size to within ~1%) with vision tower BF16, MTP head BF16, and most
mlp/self_attn Linears at FP8_DYNAMIC.

Co-Authored-By: Proof of Concept <poc@bcachefs.org>
This commit is contained in:
Kent Overstreet 2026-04-24 22:15:31 -04:00
commit 11a7e4043e

View file

@ -0,0 +1,327 @@
"""Quantize Qwen3.6-27B (multimodal) to FP8 for vLLM serving.
Why this exists
---------------
The earlier `quantize_qwen3_6.py` (in shell history, never committed)
loaded the model with `AutoModelForCausalLM`, which silently strips
the multimodal arch. Result: an FP8 checkpoint with no vision tower
weights at all. vLLM happily instantiated the vision tower from the
config and ran it with default/uninitialized weights, producing
gibberish image features and `!!!!!!`-style output. We chased that
through the protocol layer for a long time before tracing it back
to the quant. This script avoids that trap by loading via the
config-declared class explicitly.
Recipe
------
FP8_DYNAMIC (per-channel weight scales, per-token dynamic activation
scales, both E4M3) for Linear weights, with an `ignore` list derived
from Unsloth's UD-Q8_K_XL (`unsloth/Qwen3.6-27B-GGUF`). Their
sensitivity sweep flagged specific layers as quantization-fragile;
we honor those layer indices even though their algorithm is
GGUF-native Q8_K and ours is FP8 sensitivity is a layer property,
not an algorithm property.
vLLM fusion constraint
~~~~~~~~~~~~~~~~~~~~~~
vLLM's Qwen3.5/3.6 model code fuses sub-modules at load time:
qkv_proj q_proj, k_proj, v_proj
gate_up_proj gate_proj, up_proj
in_proj_qkvz in_proj_qkv, in_proj_z
in_proj_ba in_proj_b, in_proj_a
compressed_tensors rejects checkpoints where sub-modules of a fused
layer have different quantization schemes. Our ignore list is shaped
around this within any fused layer, all components share a scheme.
That's the reason `in_proj_qkv` is ignored even though Unsloth's
sweep doesn't single it out, and the reason late-stack attn override
covers q/k/v rather than just q/k.
MTP merge
---------
`Qwen3_5ForConditionalGeneration` doesn't expose the MTP submodule,
so `oneshot()` produces a checkpoint with the 15 `mtp.*` tensors
silently dropped. After quantization we read the MTP weights back
out of the upstream cached snapshot and splice them into the saved
safetensors at BF16. They're small (~850 MB) so quantizing them
isn't worth the calibration risk; speculative-decoding code paths
in vLLM expect the MTP head present.
Output
------
`OUTPUT_DIR` gets the FP8 model.safetensors + config + processor +
recipe.yaml. Vision tower stays BF16 (in `ignore`); LM Linears go
to FP8; norms, SSM internals (not Linear), and MTP tensors stay
BF16 untouched.
Verification at end: re-opens the saved safetensors and asserts
- vision .weight tensors present (>= 150; full count is 167)
- lm_head + embed_tokens at fp16/bf16 (NOT FP8)
- a sampled FP8'd Linear actually has float8 dtype
- 15 mtp.* tensors present
Run
---
~/vllm-venv/bin/python quantize_qwen3_6_mm.py
"""
from __future__ import annotations
import glob
import json
import sys
from pathlib import Path
import torch
from huggingface_hub import snapshot_download
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from safetensors import safe_open
from safetensors.torch import save_file
from transformers import AutoProcessor
from transformers.models.qwen3_5.modeling_qwen3_5 import (
Qwen3_5ForConditionalGeneration,
)
MODEL = "Qwen/Qwen3.6-27B"
OUTPUT_DIR = "/home/ubuntu/amygdala-training/Qwen3.6-27B-FP8-mm"
# Layers Unsloth's UD-Q8_K_XL keeps at F16 (perplexity-sensitive
# in their sweep). Late-stack clustering is consistent with the
# general finding that errors near the output propagate directly
# to logits.
LATE_FFN_LAYERS = (50, 51, 59, 62, 63)
LATE_ATTN_LAYERS = (51, 59, 63)
# Build the ignore regex list. Note: llmcompressor matches these
# patterns against MODULE names (no `.weight` suffix) when walking
# `named_modules()` for `targets=["Linear"]`. The first pass of
# this script used `\.weight$` patterns and silently quantized
# lm_head + every linear_attn projection — verified post-hoc by
# inspecting the saved safetensors. Patterns now anchor on `$`
# at the module name.
IGNORE_PATTERNS: list[str] = [
# Original recipe: lm_head and embeddings always full-precision.
# (embed_tokens is an Embedding, not a Linear, so it's already
# ignored by `targets=["Linear"]`. Pattern kept as belt-and-
# suspenders in case future llmcompressor versions widen the
# target set.)
"re:lm_head$",
"re:.*embed_tokens$",
# Vision tower — entire `model.visual.*` subtree (vision
# transformer blocks + merger + patch_embed + pos_embed).
# Unsloth ships the vision tower as a separate `mmproj-BF16.gguf`
# for GGUF consumers; in our single-file FP8 setup we just leave
# them at BF16.
"re:model\\.visual\\..*",
# MTP (multi-token prediction) module — Unsloth's GGUF doesn't
# carry MTP weights so we have no precision signal from them;
# safest to keep BF16.
"re:mtp\\..*",
# Linear-attention block — keep ENTIRELY at BF16. vLLM fuses
# `in_proj_qkv` and `in_proj_z` into a single `in_proj_qkvz`
# layer, and compressed_tensors rejects mixed schemes within a
# fused layer. Unsloth's recipe keeps z, a, b, out at F16/F32
# (gate/SSM internals are quantization-fragile in the GatedDeltaNet
# update), so the principled choice is to also keep `in_proj_qkv`
# at BF16 rather than FP8'ing the gate to match. We give up ~1 GB
# of FP8 coverage; in exchange we follow Unsloth's quality intent
# and load cleanly under vLLM. (`in_proj_a` + `in_proj_b` are
# likewise fused as `in_proj_ba` — both ignored, consistent.)
"re:model\\.language_model\\.layers\\.\\d+\\.linear_attn\\.in_proj_qkv$",
"re:model\\.language_model\\.layers\\.\\d+\\.linear_attn\\.in_proj_z$",
"re:model\\.language_model\\.layers\\.\\d+\\.linear_attn\\.in_proj_a$",
"re:model\\.language_model\\.layers\\.\\d+\\.linear_attn\\.in_proj_b$",
"re:model\\.language_model\\.layers\\.\\d+\\.linear_attn\\.out_proj$",
# Per-layer high-precision MLP (Unsloth flagged exactly these
# late-stack indices in their UD-Q8_K_XL sensitivity sweep, all
# three of {gate, up, down} per layer). vLLM fuses gate+up into
# `gate_up_proj`; ignoring both keeps the fused layer consistent.
# `down_proj` is its own (non-fused) layer.
"re:model\\.language_model\\.layers\\.("
+ "|".join(str(n) for n in LATE_FFN_LAYERS)
+ ")\\.mlp\\.(down|gate|up)_proj$",
# Per-layer high-precision attention q/k/v (Unsloth's sweep upgrades
# only q and k; we extend to v because vLLM fuses q/k/v into
# `qkv_proj` and rejects mixed schemes. `o_proj` is its own
# non-fused layer and stays at FP8.
"re:model\\.language_model\\.layers\\.("
+ "|".join(str(n) for n in LATE_ATTN_LAYERS)
+ ")\\.self_attn\\.(q|k|v)_proj$",
]
def main() -> None:
print(f"Loading {MODEL} as multimodal "
f"(Qwen3_5ForConditionalGeneration)...", flush=True)
model = Qwen3_5ForConditionalGeneration.from_pretrained(
MODEL,
dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
print(f" loaded: {model.__class__.__name__}", flush=True)
print(f"Loading processor (text + image preprocessing)...", flush=True)
processor = AutoProcessor.from_pretrained(MODEL, trust_remote_code=True)
print("Running FP8_DYNAMIC oneshot quantization...", flush=True)
print(f" ignore list: {len(IGNORE_PATTERNS)} patterns",
flush=True)
recipe = QuantizationModifier(
targets=["Linear"],
scheme="FP8_DYNAMIC",
ignore=IGNORE_PATTERNS,
)
oneshot(model=model, recipe=recipe, output_dir=OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)
print(f" wrote model + processor to {OUTPUT_DIR}", flush=True)
merge_mtp(OUTPUT_DIR)
verify_output(OUTPUT_DIR)
def merge_mtp(out_dir: str) -> None:
"""Splice upstream MTP tensors into the saved FP8 safetensors.
`Qwen3_5ForConditionalGeneration` skips the MTP submodule on load,
so oneshot's output is missing the 15 `mtp.*` tensors. We resolve
the upstream snapshot via the HF cache (already populated by
from_pretrained), pull just the MTP tensors out at BF16, and
rewrite the safetensors with them merged in. The compressed_tensors
metadata header (which carries the FP8 format identifier vLLM
needs to dequantize) is preserved verbatim.
Atomic-rename is used so a crash mid-write doesn't corrupt the
33+ GB checkpoint we just spent minutes producing.
"""
print("\nMerging upstream MTP tensors...", flush=True)
upstream_dir = Path(snapshot_download(
MODEL,
allow_patterns=["model.safetensors.index.json",
"model-*-of-*.safetensors"],
))
with open(upstream_dir / "model.safetensors.index.json") as f:
idx = json.load(f)
mtp_shards = sorted({v for k, v in idx["weight_map"].items()
if k.startswith("mtp.")})
print(f" MTP tensors live in shards: {mtp_shards}", flush=True)
mtp_tensors: dict[str, torch.Tensor] = {}
for shard in mtp_shards:
with safe_open(upstream_dir / shard, framework="pt") as f:
for k in f.keys():
if k.startswith("mtp."):
mtp_tensors[k] = f.get_tensor(k).contiguous()
mtp_bytes = sum(t.numel() * t.element_size()
for t in mtp_tensors.values())
print(f" loaded {len(mtp_tensors)} mtp tensors "
f"({mtp_bytes/1e6:.1f} MB)", flush=True)
fp8_files = sorted(Path(out_dir).glob("*.safetensors"))
if len(fp8_files) != 1:
sys.exit(f"FAIL: expected single safetensors shard, "
f"got {fp8_files}")
existing_path = fp8_files[0]
with safe_open(existing_path, framework="pt") as f:
metadata = f.metadata() or {}
all_tensors = {k: f.get_tensor(k) for k in f.keys()}
overlap = set(all_tensors) & set(mtp_tensors)
if overlap:
sys.exit(f"FAIL: MTP key collision with FP8 output: "
f"{sorted(overlap)[:5]}")
all_tensors.update(mtp_tensors)
tmp_path = existing_path.with_name(existing_path.name + ".new")
print(f" rewriting {existing_path.name} "
f"({len(all_tensors)} tensors)...", flush=True)
save_file(all_tensors, str(tmp_path), metadata=metadata)
tmp_path.replace(existing_path)
print(" done", flush=True)
def verify_output(out_dir: str) -> None:
"""Open the saved safetensors and assert the recipe actually
landed: vision tower present at BF16, FP8 dtype on at least one
quantized Linear, lm_head not FP8."""
print(f"\nVerifying {out_dir}...", flush=True)
files = sorted(glob.glob(f"{out_dir}/*.safetensors"))
if not files:
sys.exit(f"FAIL: no safetensors in {out_dir}")
vision_keys: list[tuple[str, str]] = []
fp8_sample: tuple[str, str] | None = None
lm_head_dtype: str | None = None
mtp_keys: list[str] = []
for fp in files:
with safe_open(fp, framework="pt") as f:
for k in f.keys():
if k.startswith("mtp."):
mtp_keys.append(k)
# Some FP8 quants write a sibling `_scale` / `_zero_point`;
# we just care about the .weight tensors.
if not k.endswith(".weight"):
continue
t = f.get_tensor(k)
dtype = str(t.dtype).replace("torch.", "")
if "model.visual." in k:
vision_keys.append((k, dtype))
if k == "lm_head.weight":
lm_head_dtype = dtype
if (fp8_sample is None
and "float8" in dtype
and "language_model.layers" in k):
fp8_sample = (k, dtype)
# Qwen3.6-27B has 167 vision `.weight` tensors (333 vision tensors
# total, the rest are `.bias` and per-block norms). 150 is a
# sanity floor that catches "vision tower didn't make it through"
# without being brittle to minor arch revisions.
if len(vision_keys) < 150:
sys.exit(f"FAIL: only {len(vision_keys)} vision tensors found "
f"(expected >= 150). Vision tower didn't make it "
f"through the quant.")
bad_vision = [(k, d) for k, d in vision_keys if "float8" in d]
if bad_vision:
sys.exit(f"FAIL: vision weights got quantized to FP8: "
f"{bad_vision[:3]}...")
if lm_head_dtype is None:
sys.exit("FAIL: lm_head.weight not found in output.")
if "float8" in lm_head_dtype:
sys.exit(f"FAIL: lm_head.weight is FP8 ({lm_head_dtype}); "
f"should be BF16/FP16.")
if fp8_sample is None:
sys.exit("FAIL: no FP8 weights found in language_model.layers — "
"the recipe didn't quantize anything.")
# Upstream Qwen3.6-27B has exactly 15 mtp.* tensors (1 fused
# transformer block + projection + norms). merge_mtp() should
# have spliced all of them in.
if len(mtp_keys) != 15:
sys.exit(f"FAIL: expected 15 mtp.* tensors, found "
f"{len(mtp_keys)}. merge_mtp() missed some.")
print(f"{len(vision_keys)} vision tensors at "
f"{vision_keys[0][1]} (not FP8)")
print(f" ✓ lm_head.weight at {lm_head_dtype} (not FP8)")
print(f" ✓ FP8 sample: {fp8_sample[0]} = {fp8_sample[1]}")
print(f"{len(mtp_keys)} mtp.* tensors present")
print("DONE")
if __name__ == "__main__":
main()