forked from kent/consciousness
327 lines
13 KiB
Python
327 lines
13 KiB
Python
|
|
"""Quantize Qwen3.6-27B (multimodal) to FP8 for vLLM serving.
|
||
|
|
|
||
|
|
Why this exists
|
||
|
|
---------------
|
||
|
|
The earlier `quantize_qwen3_6.py` (in shell history, never committed)
|
||
|
|
loaded the model with `AutoModelForCausalLM`, which silently strips
|
||
|
|
the multimodal arch. Result: an FP8 checkpoint with no vision tower
|
||
|
|
weights at all. vLLM happily instantiated the vision tower from the
|
||
|
|
config and ran it with default/uninitialized weights, producing
|
||
|
|
gibberish image features and `!!!!!!`-style output. We chased that
|
||
|
|
through the protocol layer for a long time before tracing it back
|
||
|
|
to the quant. This script avoids that trap by loading via the
|
||
|
|
config-declared class explicitly.
|
||
|
|
|
||
|
|
Recipe
|
||
|
|
------
|
||
|
|
FP8_DYNAMIC (per-channel weight scales, per-token dynamic activation
|
||
|
|
scales, both E4M3) for Linear weights, with an `ignore` list derived
|
||
|
|
from Unsloth's UD-Q8_K_XL (`unsloth/Qwen3.6-27B-GGUF`). Their
|
||
|
|
sensitivity sweep flagged specific layers as quantization-fragile;
|
||
|
|
we honor those layer indices even though their algorithm is
|
||
|
|
GGUF-native Q8_K and ours is FP8 — sensitivity is a layer property,
|
||
|
|
not an algorithm property.
|
||
|
|
|
||
|
|
vLLM fusion constraint
|
||
|
|
~~~~~~~~~~~~~~~~~~~~~~
|
||
|
|
vLLM's Qwen3.5/3.6 model code fuses sub-modules at load time:
|
||
|
|
qkv_proj ← q_proj, k_proj, v_proj
|
||
|
|
gate_up_proj ← gate_proj, up_proj
|
||
|
|
in_proj_qkvz ← in_proj_qkv, in_proj_z
|
||
|
|
in_proj_ba ← in_proj_b, in_proj_a
|
||
|
|
compressed_tensors rejects checkpoints where sub-modules of a fused
|
||
|
|
layer have different quantization schemes. Our ignore list is shaped
|
||
|
|
around this — within any fused layer, all components share a scheme.
|
||
|
|
That's the reason `in_proj_qkv` is ignored even though Unsloth's
|
||
|
|
sweep doesn't single it out, and the reason late-stack attn override
|
||
|
|
covers q/k/v rather than just q/k.
|
||
|
|
|
||
|
|
MTP merge
|
||
|
|
---------
|
||
|
|
`Qwen3_5ForConditionalGeneration` doesn't expose the MTP submodule,
|
||
|
|
so `oneshot()` produces a checkpoint with the 15 `mtp.*` tensors
|
||
|
|
silently dropped. After quantization we read the MTP weights back
|
||
|
|
out of the upstream cached snapshot and splice them into the saved
|
||
|
|
safetensors at BF16. They're small (~850 MB) so quantizing them
|
||
|
|
isn't worth the calibration risk; speculative-decoding code paths
|
||
|
|
in vLLM expect the MTP head present.
|
||
|
|
|
||
|
|
Output
|
||
|
|
------
|
||
|
|
`OUTPUT_DIR` gets the FP8 model.safetensors + config + processor +
|
||
|
|
recipe.yaml. Vision tower stays BF16 (in `ignore`); LM Linears go
|
||
|
|
to FP8; norms, SSM internals (not Linear), and MTP tensors stay
|
||
|
|
BF16 untouched.
|
||
|
|
|
||
|
|
Verification at end: re-opens the saved safetensors and asserts
|
||
|
|
- vision .weight tensors present (>= 150; full count is 167)
|
||
|
|
- lm_head + embed_tokens at fp16/bf16 (NOT FP8)
|
||
|
|
- a sampled FP8'd Linear actually has float8 dtype
|
||
|
|
- 15 mtp.* tensors present
|
||
|
|
|
||
|
|
Run
|
||
|
|
---
|
||
|
|
~/vllm-venv/bin/python quantize_qwen3_6_mm.py
|
||
|
|
"""
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import glob
|
||
|
|
import json
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import torch
|
||
|
|
from huggingface_hub import snapshot_download
|
||
|
|
from llmcompressor import oneshot
|
||
|
|
from llmcompressor.modifiers.quantization import QuantizationModifier
|
||
|
|
from safetensors import safe_open
|
||
|
|
from safetensors.torch import save_file
|
||
|
|
from transformers import AutoProcessor
|
||
|
|
from transformers.models.qwen3_5.modeling_qwen3_5 import (
|
||
|
|
Qwen3_5ForConditionalGeneration,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
# Upstream Hugging Face repo to quantize, and the directory that receives
# the FP8 weights together with the config, processor files and recipe.
MODEL = "Qwen/Qwen3.6-27B"
OUTPUT_DIR = "/home/ubuntu/amygdala-training/Qwen3.6-27B-FP8-mm"


# Decoder-layer indices that Unsloth's UD-Q8_K_XL leaves at F16 because
# their sweep found them perplexity-sensitive. They cluster late in the
# stack, which matches the general finding that error introduced close
# to the output propagates straight into the logits.
LATE_FFN_LAYERS = (50, 51, 59, 62, 63)
LATE_ATTN_LAYERS = (51, 59, 63)

# Common prefix of every per-layer pattern below (includes the `re:`
# marker); each pattern appends a layer index or an alternation of them.
_LM_LAYER_RE = r"re:model\.language_model\.layers\."
_LATE_FFN_ALT = "|".join(map(str, LATE_FFN_LAYERS))
_LATE_ATTN_ALT = "|".join(map(str, LATE_ATTN_LAYERS))

# Regexes for modules that must NOT be quantized. llmcompressor matches
# them against MODULE names (no `.weight` suffix) while walking
# `named_modules()` for `targets=["Linear"]`. The first version of this
# script used `\.weight$` patterns and silently quantized lm_head plus
# every linear_attn projection — found only by inspecting the saved
# safetensors afterwards. Hence every pattern now anchors with `$` on
# the module name itself.
IGNORE_PATTERNS: list[str] = [
    # From the original recipe: lm_head and the embeddings are always
    # full precision. embed_tokens is an Embedding rather than a Linear,
    # so `targets=["Linear"]` already skips it; the pattern stays as
    # belt-and-suspenders in case a future llmcompressor release widens
    # the target set.
    "re:lm_head$",
    "re:.*embed_tokens$",

    # Vision tower: the whole `model.visual.*` subtree (vision
    # transformer blocks, merger, patch_embed, pos_embed). Unsloth ships
    # it as a separate `mmproj-BF16.gguf` for GGUF consumers; in this
    # single-file FP8 layout it simply stays at BF16.
    r"re:model\.visual\..*",

    # MTP (multi-token prediction) module. Unsloth's GGUF carries no MTP
    # weights, so there is no precision signal to borrow; BF16 is the
    # safe choice.
    r"re:mtp\..*",

    # Linear-attention block: kept ENTIRELY at BF16. vLLM fuses
    # `in_proj_qkv` with `in_proj_z` into one `in_proj_qkvz` layer, and
    # compressed_tensors rejects mixed schemes inside a fused layer.
    # Unsloth keeps z, a, b and out at F16/F32 (gate/SSM internals are
    # quantization-fragile in the GatedDeltaNet update), so the
    # principled move is to leave `in_proj_qkv` at BF16 as well instead
    # of FP8'ing the gate to match. That costs ~1 GB of FP8 coverage and
    # buys Unsloth's quality intent plus a clean vLLM load. `in_proj_a`
    # and `in_proj_b` are fused the same way (`in_proj_ba`); both are
    # ignored, so that pair is consistent too.
    *(
        rf"{_LM_LAYER_RE}\d+\.linear_attn\.{proj}$"
        for proj in (
            "in_proj_qkv",
            "in_proj_z",
            "in_proj_a",
            "in_proj_b",
            "out_proj",
        )
    ),

    # High-precision MLP on exactly the late-stack layers Unsloth's
    # UD-Q8_K_XL sweep flagged (all of gate, up and down per layer).
    # vLLM fuses gate+up into `gate_up_proj`, so ignoring both keeps the
    # fused layer on one scheme; `down_proj` is a standalone layer.
    rf"{_LM_LAYER_RE}({_LATE_FFN_ALT})\.mlp\.(down|gate|up)_proj$",

    # High-precision attention q/k/v on the flagged late layers.
    # Unsloth's sweep upgrades only q and k; v is included here because
    # vLLM fuses q/k/v into `qkv_proj` and rejects mixed schemes.
    # `o_proj` is not fused and stays at FP8.
    rf"{_LM_LAYER_RE}({_LATE_ATTN_ALT})\.self_attn\.(q|k|v)_proj$",
]
|
||
|
|
|
||
|
|
|
||
|
|
def main() -> None:
    """Quantize ``MODEL`` to FP8 and write the result to ``OUTPUT_DIR``.

    Steps, in order: load the checkpoint through the multimodal class,
    load its processor, run a oneshot ``FP8_DYNAMIC`` pass over every
    ``Linear`` module not matched by ``IGNORE_PATTERNS``, save the
    processor alongside the weights, then call ``merge_mtp`` and
    ``verify_output`` on the output directory.
    """

    def say(message: str) -> None:
        # Progress lines are flushed immediately so they show up while
        # the long-running steps are still in flight.
        print(message, flush=True)

    say(f"Loading {MODEL} as multimodal (Qwen3_5ForConditionalGeneration)...")
    load_kwargs = {
        "dtype": torch.bfloat16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    # The concrete class is named explicitly (not an Auto* class) so the
    # model is loaded with its multimodal architecture.
    model = Qwen3_5ForConditionalGeneration.from_pretrained(MODEL, **load_kwargs)
    say(f" loaded: {model.__class__.__name__}")

    say("Loading processor (text + image preprocessing)...")
    processor = AutoProcessor.from_pretrained(MODEL, trust_remote_code=True)

    say("Running FP8_DYNAMIC oneshot quantization...")
    say(f" ignore list: {len(IGNORE_PATTERNS)} patterns")
    fp8_recipe = QuantizationModifier(
        targets=["Linear"],
        scheme="FP8_DYNAMIC",
        ignore=IGNORE_PATTERNS,
    )
    oneshot(model=model, recipe=fp8_recipe, output_dir=OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    say(f" wrote model + processor to {OUTPUT_DIR}")

    # Post-processing: splice the MTP tensors back in, then sanity-check
    # the file that will actually be served.
    merge_mtp(OUTPUT_DIR)
    verify_output(OUTPUT_DIR)
|
||
|
|
|
||
|
|
|
||
|
|
def merge_mtp(out_dir: str) -> None:
    """Splice the upstream ``mtp.*`` tensors into the saved FP8 checkpoint.

    The quantized output is missing the ``mtp.*`` tensors, so they are
    read from the upstream snapshot of ``MODEL`` (resolved through
    ``snapshot_download``; normally already cached because the model was
    just loaded) and added, unmodified, to the single safetensors file in
    ``out_dir``. The file-level safetensors metadata is passed through
    unchanged.

    The merged file is written to a ``.new`` sibling and then renamed
    over the original, so a crash mid-write does not corrupt the existing
    checkpoint. If the write or rename fails, the partial temp file is
    removed before the error propagates.

    Note: every tensor of the existing checkpoint is loaded into memory
    at once for the rewrite.

    Args:
        out_dir: Directory containing exactly one ``*.safetensors`` file
            produced by the quantization step.

    Raises:
        SystemExit: if no ``mtp.*`` tensors are found upstream, if
            ``out_dir`` does not hold exactly one safetensors file, or if
            an ``mtp.*`` key already exists in that file.
    """
    print("\nMerging upstream MTP tensors...", flush=True)
    upstream_dir = Path(snapshot_download(
        MODEL,
        allow_patterns=["model.safetensors.index.json",
                        "model-*-of-*.safetensors"],
    ))

    index_path = upstream_dir / "model.safetensors.index.json"
    with open(index_path, encoding="utf-8") as f:
        idx = json.load(f)
    # Only open the shards the index says contain MTP weights.
    mtp_shards = sorted({shard for key, shard in idx["weight_map"].items()
                         if key.startswith("mtp.")})
    print(f" MTP tensors live in shards: {mtp_shards}", flush=True)

    mtp_tensors: dict[str, torch.Tensor] = {}
    for shard in mtp_shards:
        with safe_open(upstream_dir / shard, framework="pt") as f:
            for k in f.keys():
                if k.startswith("mtp."):
                    mtp_tensors[k] = f.get_tensor(k).contiguous()

    # Fail fast: without this guard the whole checkpoint would be
    # rewritten for nothing and the problem would only surface later,
    # in verification, with a less specific message.
    if not mtp_tensors:
        sys.exit(f"FAIL: no mtp.* tensors found in upstream snapshot "
                 f"{upstream_dir}; nothing to merge.")

    mtp_bytes = sum(t.numel() * t.element_size()
                    for t in mtp_tensors.values())
    print(f" loaded {len(mtp_tensors)} mtp tensors "
          f"({mtp_bytes/1e6:.1f} MB)", flush=True)

    fp8_files = sorted(Path(out_dir).glob("*.safetensors"))
    if len(fp8_files) != 1:
        sys.exit(f"FAIL: expected single safetensors shard, "
                 f"got {fp8_files}")
    existing_path = fp8_files[0]

    with safe_open(existing_path, framework="pt") as f:
        metadata = f.metadata() or {}
        all_tensors = {k: f.get_tensor(k) for k in f.keys()}

    # A collision means the output already carries MTP tensors (e.g. a
    # second run); refuse rather than silently overwrite them.
    overlap = set(all_tensors) & set(mtp_tensors)
    if overlap:
        sys.exit(f"FAIL: MTP key collision with FP8 output: "
                 f"{sorted(overlap)[:5]}")
    all_tensors.update(mtp_tensors)

    # The temp name ends in `.new`, so a leftover from a killed run is
    # never picked up by the `*.safetensors` glob above.
    tmp_path = existing_path.with_name(existing_path.name + ".new")
    print(f" rewriting {existing_path.name} "
          f"({len(all_tensors)} tensors)...", flush=True)
    try:
        save_file(all_tensors, str(tmp_path), metadata=metadata)
        tmp_path.replace(existing_path)
    except BaseException:
        # Don't leave a partial multi-GB temp file behind on failure or
        # Ctrl-C; the original checkpoint is still intact at this point.
        tmp_path.unlink(missing_ok=True)
        raise
    print(" done", flush=True)
|
||
|
|
|
||
|
|
|
||
|
|
def verify_output(out_dir: str) -> None:
    """Re-open the saved safetensors and check that the recipe landed.

    Checks, each of which terminates the process with a ``FAIL: ...``
    message when violated:

    - at least 150 vision ``.weight`` tensors are present, none of them FP8
    - ``lm_head.weight`` is present and not FP8
    - an ``embed_tokens.weight`` tensor is present and not FP8
    - at least one ``language_model.layers`` weight has a float8 dtype
    - exactly 15 ``mtp.*`` tensors are present

    Every ``.weight`` tensor is loaded to read its dtype, so this pass
    reads the whole checkpoint.

    Args:
        out_dir: Directory whose ``*.safetensors`` files are inspected.

    Raises:
        SystemExit: if ``out_dir`` has no safetensors files or any of the
            checks above fails.
    """
    print(f"\nVerifying {out_dir}...", flush=True)

    files = sorted(glob.glob(f"{out_dir}/*.safetensors"))
    if not files:
        sys.exit(f"FAIL: no safetensors in {out_dir}")

    vision_keys: list[tuple[str, str]] = []
    embed_keys: list[tuple[str, str]] = []
    fp8_sample: tuple[str, str] | None = None
    lm_head_dtype: str | None = None
    mtp_keys: list[str] = []

    for fp in files:
        with safe_open(fp, framework="pt") as f:
            for k in f.keys():
                # MTP tensors are counted regardless of suffix.
                if k.startswith("mtp."):
                    mtp_keys.append(k)
                # Some FP8 quants write a sibling `_scale` / `_zero_point`;
                # only the .weight tensors matter for the dtype checks.
                if not k.endswith(".weight"):
                    continue
                t = f.get_tensor(k)
                dtype = str(t.dtype).replace("torch.", "")
                if "model.visual." in k:
                    vision_keys.append((k, dtype))
                if k.endswith("embed_tokens.weight"):
                    embed_keys.append((k, dtype))
                if k == "lm_head.weight":
                    lm_head_dtype = dtype
                # Remember the first FP8 language-model weight as proof
                # that quantization actually happened.
                if (fp8_sample is None
                        and "float8" in dtype
                        and "language_model.layers" in k):
                    fp8_sample = (k, dtype)

    # Qwen3.6-27B has 167 vision `.weight` tensors (333 vision tensors
    # total, the rest are `.bias` and per-block norms). 150 is a
    # sanity floor that catches "vision tower didn't make it through"
    # without being brittle to minor arch revisions.
    if len(vision_keys) < 150:
        sys.exit(f"FAIL: only {len(vision_keys)} vision tensors found "
                 f"(expected >= 150). Vision tower didn't make it "
                 f"through the quant.")

    bad_vision = [(k, d) for k, d in vision_keys if "float8" in d]
    if bad_vision:
        sys.exit(f"FAIL: vision weights got quantized to FP8: "
                 f"{bad_vision[:3]}...")

    if lm_head_dtype is None:
        sys.exit("FAIL: lm_head.weight not found in output.")
    if "float8" in lm_head_dtype:
        sys.exit(f"FAIL: lm_head.weight is FP8 ({lm_head_dtype}); "
                 f"should be BF16/FP16.")

    # The embeddings are covered by the ignore list just like lm_head,
    # so hold them to the same standard: present and not FP8.
    if not embed_keys:
        sys.exit("FAIL: embed_tokens.weight not found in output.")
    bad_embed = [(k, d) for k, d in embed_keys if "float8" in d]
    if bad_embed:
        sys.exit(f"FAIL: embed_tokens got quantized to FP8: "
                 f"{bad_embed[:3]}; should be BF16/FP16.")

    if fp8_sample is None:
        sys.exit("FAIL: no FP8 weights found in language_model.layers — "
                 "the recipe didn't quantize anything.")

    # Upstream Qwen3.6-27B has exactly 15 mtp.* tensors (1 fused
    # transformer block + projection + norms). merge_mtp() should
    # have spliced all of them in.
    if len(mtp_keys) != 15:
        sys.exit(f"FAIL: expected 15 mtp.* tensors, found "
                 f"{len(mtp_keys)}. merge_mtp() missed some.")

    # The dtype shown for the vision tower is that of the first vision
    # weight encountered; the FP8 check above covers all of them.
    print(f" ✓ {len(vision_keys)} vision tensors at "
          f"{vision_keys[0][1]} (not FP8)")
    print(f" ✓ lm_head.weight at {lm_head_dtype} (not FP8)")
    print(f" ✓ embed_tokens.weight at {embed_keys[0][1]} (not FP8)")
    print(f" ✓ FP8 sample: {fp8_sample[0]} = {fp8_sample[1]}")
    print(f" ✓ {len(mtp_keys)} mtp.* tensors present")
    print("DONE")
|
||
|
|
|
||
|
|
|
||
|
|
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|