GGUF: optional Metal dequant fast path via kernels-community #45975
ArthurZucker wants to merge 3 commits
Conversation
When the `kernels` package is installed and `kernels-community/gguf-dequant` has a build for the current torch/Apple-Silicon combo, route Q4_0 / Q8_0 / Q4_K / Q5_K / Q6_K / IQ4_NL / IQ4_XS dequant on MPS through one Metal compute kernel per tensor — 2.3–7.5× faster than the chained PyTorch path on M3 Max at real-world tensor sizes (apples-to-apples, with the input already on device). Lazy import on first dequant; opt out with `TRANSFORMERS_GGUF_USE_METAL_KERNELS=0`. Falls back to the existing pure-torch path for CPU/CUDA, unsupported quant types, or when the Hub doesn't have a build variant for the current env. No behavioural change for environments without the package installed.
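For illustration, a minimal sketch of what that dispatch could look like; `dequantize`, `_torch_dequant`, and the kernel's `dequantize` attribute are placeholder names, not the PR's actual API:

```python
import os

import torch

_METAL_SUPPORTED = {"Q4_0", "Q8_0", "Q4_K", "Q5_K", "Q6_K", "IQ4_NL", "IQ4_XS"}
_metal_kernel = None  # resolved lazily on the first MPS dequant


def _get_metal_kernel():
    """Lazy import: `kernels` is only touched when an MPS dequant is requested."""
    global _metal_kernel
    if _metal_kernel is None:
        from kernels import get_kernel  # ImportError if `kernels` isn't installed

        repo = os.environ.get(
            "TRANSFORMERS_GGUF_METAL_KERNELS_REPO", "kernels-community/gguf-dequant"
        )
        _metal_kernel = get_kernel(repo)  # raises if no build variant for this env
    return _metal_kernel


def _torch_dequant(blocks: torch.Tensor, quant_type: str) -> torch.Tensor:
    raise NotImplementedError("stand-in for the existing chained pure-torch path")


def dequantize(blocks: torch.Tensor, quant_type: str) -> torch.Tensor:
    use_metal = (
        blocks.device.type == "mps"
        and quant_type in _METAL_SUPPORTED
        and os.environ.get("TRANSFORMERS_GGUF_USE_METAL_KERNELS", "1") != "0"
    )
    if use_metal:
        try:
            # One Metal compute kernel launch per tensor; the `dequantize`
            # entry point on the loaded kernel is an assumption.
            return _get_metal_kernel().dequantize(blocks, quant_type)
        except Exception:
            pass  # missing package / no build variant -> pure-torch fallback
    return _torch_dequant(blocks, quant_type)
```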
When `transformers serve` runs on Apple Silicon (`--device auto` or `mps`) with `kernels` installed and no explicit `--attn-implementation` flag, default the attention to `kernels-community/metal-flash-sdpa` instead of plain SDPA. On the 100-sample gsm8k benchmark (Qwen2.5-0.5B-Instruct, MPS fp16, `generate_batch`) it's a 1.66× throughput improvement (158 → 256 tok/s) with token-for-token parity under greedy decoding. Users who don't want it can opt out with `--attn-implementation sdpa`. The help text for the `--attn-implementation` flag now also lists the kernels-hub syntax explicitly.
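A sketch of that default-selection logic under assumed names (the actual `transformers serve` wiring may differ):

```python
import importlib.util


def resolve_attn_implementation(user_flag: str | None, device: str) -> str:
    """Hypothetical helper: pick the default attention implementation."""
    # An explicit --attn-implementation always wins (opt out with "sdpa").
    if user_flag is not None:
        return user_flag
    # kernels-hub syntax: an "<org>/<repo>" string resolved via the `kernels` package.
    if device == "mps" and importlib.util.find_spec("kernels") is not None:
        return "kernels-community/metal-flash-sdpa"
    return "sdpa"
```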
Test plan results (M3 Max, torch 2.11, kernel loaded via `get_local_kernel`):
| Model | baseline (s) | Metal kernel (s) | speedup |
|---|---|---|---|
| TinyLlama-1.1B Q4_K_M | 3.96 | 3.44 | 1.15× |
| Qwen2.5-0.5B Q4_0 | 9.99 | 9.96 | 1.00× |
| Llama-3.2-3B Q4_K_M | 13.22 | 12.79 | 1.03× |
Modest at the load-wall level (1.0–1.15×) — dequant is one phase of model load; the I/O, converter chain, MPS allocator, and module init dominate at these sizes. The kernel's microbench-level 2.3–7.5× speedup compresses through the rest of the pipeline. Expect a bigger relative win on larger models where dequant is a larger share of total time.
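The load-wall numbers can be reproduced with a harness along these lines (assumed script, not shipped in the PR; the GGUF filename is assumed to be the Q4_K_M file from the TheBloke repo):

```python
import os
import time

os.environ["TRANSFORMERS_GGUF_USE_METAL_KERNELS"] = "1"  # "0" for the baseline run

from transformers import AutoModelForCausalLM

start = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
    gguf_file="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",  # assumed filename
    device_map="mps",
)
print(f"load wall time: {time.perf_counter() - start:.2f}s")
```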
Fallback (`TRANSFORMERS_GGUF_USE_METAL_KERNELS=0`)
The "baseline" column above is exactly this case — the existing pure-torch path. No regression introduced by the patch.
Caveats reproducing locally
- `kernel-builder` emits Metal builds for torch 2.8 / 2.9 / 2.10 only. On torch 2.11 nightly the Hub-fetch path 404s on the variant lookup, so the test used `get_local_kernel` to load the torch 2.10 build directly (see the sketch after this list). A Hub-side rebuild (or staging the 2.10 build as 2.11) is needed before merge for the default install path to work on the current nightly.
- Kernel repo still at `ArthurZ/gguf-dequant`; transfer to `kernels-community` in flight. Override via `TRANSFORMERS_GGUF_METAL_KERNELS_REPO=…`.
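A workaround sketch under stated assumptions: the checkout path and package name are illustrative, and `get_local_kernel`'s exact signature may differ across `kernels` versions:

```python
from pathlib import Path

from kernels import get_local_kernel

# Load the torch-2.10 build straight from a local clone of the kernel repo,
# bypassing the Hub variant lookup that 404s on torch 2.11 nightly.
kernel = get_local_kernel(
    Path("~/kernels/gguf-dequant").expanduser(),  # assumed local checkout path
    package_name="gguf_dequant",                  # assumed package name
)
```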
Wall-time bench, full set (median of 2 runs):
| Model | baseline (s) | Metal kernel (s) | speedup |
|---|---|---|---|
| TinyLlama-1.1B Q4_K_M | 3.27 | 3.15 | 1.04× |
| Qwen2.5-0.5B Q4_0 | 8.62 | 8.80 | 0.98× |
| Llama-3.2-3B Q4_K_M | 11.16 | 11.55 | 0.97× |
| Gemma-2-2B Q4_K_M | 13.31 | 13.18 | 1.01× |
| Gemma-3-4B Q4_0 | 12.04 | 12.23 | 0.98× |
| Qwen1.5-MoE-A2.7B Q4_K_M | 14.15 | 10.63 | 1.33× 🚀 |
The MoE case is the headline — 60 experts × multiple Q4_K projections per layer, so dequant is a meaningful chunk of total load wall time. 3.5 seconds saved on a 14 s load.
Dense models sit at 0.97–1.04× — within run-to-run noise. At these sizes file I/O, the converter chain, and module init dominate over per-tensor dequant. Where the kernel pays off in absolute terms is models with lots of expert tensors (Mixtral, Qwen-MoE, DeepSeek-MoE).
Summary
Adds an optional Metal dequant fast path in `GGUFDequantize`. When the `kernels` package is installed and `kernels-community/gguf-dequant` is reachable, MPS GGUF loads (Q4_0 / Q8_0 / Q4_K / Q5_K / Q6_K / IQ4_NL / IQ4_XS) route through one Metal compute kernel per tensor — 2.3–7.5× faster than the chained PyTorch path on M3 Max at real-world tensor sizes. Falls back to the existing pure-torch path on CPU/CUDA, for unsupported quant types, or when the Hub doesn't have a build variant for the current env. Lazy import + opt-out via `TRANSFORMERS_GGUF_USE_METAL_KERNELS=0`. Stacked on #44794.

Kernel currently published at `ArthurZ/gguf-dequant` while the transfer to `kernels-community` is in flight; override with `TRANSFORMERS_GGUF_METAL_KERNELS_REPO=<repo-id>`.

Bench (M3 Max, 8 M elems per case, input already on GPU)
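A measurement harness matching that setup could look like this (assumed sketch; the timed `fn` would be either dequant path, shown here with a placeholder op so the snippet runs standalone):

```python
import time

import torch


def bench_mps(fn, *args, iters: int = 10) -> float:
    """Mean wall time per call, with MPS sync around the timed region."""
    fn(*args)  # warm-up: the first call pays pipeline-compile cost
    torch.mps.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    torch.mps.synchronize()  # MPS launches are async; sync before reading the clock
    return (time.perf_counter() - t0) / iters


# 8M elements, already resident on the GPU (matches the bench setup above).
x = torch.randint(0, 255, (8_000_000,), dtype=torch.uint8, device="mps")
ms = bench_mps(lambda t: t.to(torch.float16) * 0.5, x) * 1e3  # placeholder op
print(f"{ms:.2f} ms/call")
```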
Test plan
- `RUN_SLOW=1 pytest tests/quantization/ggml/test_gguf_load_completeness.py -k Q4_K` with `pip install kernels` + `TRANSFORMERS_GGUF_METAL_KERNELS_REPO=ArthurZ/gguf-dequant`; all 7 cells pass with no missing/unexpected keys.
- Same run with `TRANSFORMERS_GGUF_USE_METAL_KERNELS=0` — confirms the fallback path still works.
- `AutoModelForCausalLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", gguf_file="…Q4_K_M.gguf", device_map="mps")` — end-to-end load wall time before/after.