GgufLinear: inference-time GGUF matmul on Apple Silicon – llama.cpp parity #45977
Open
ArthurZucker wants to merge 3 commits into update-gguf
Conversation
Opt-in path that keeps GGUF weights at their native quantization after load and runs the forward pass through the kernels-community Metal kernels (ArthurZ/gguf-kernels). Same MSL kernels as llama.cpp: at decode (batch-1, memory-bound) we hit llama.cpp parity on M3 Max (266 tok/s vs 261, Qwen2.5-0.5B Q4_0).

What lands:

- `integrations/gguf_linear.GgufLinear` – drop-in nn.Linear replacement storing raw GGUF block bytes in `qweight`. Forward picks `mul_mat_vec_<fmt>_f32` for batch-1 (memory-bound) and `mul_mat_<fmt>_f32` for batch>1 (compute-bound). Supported quant types: Q4_0, Q8_0, Q4_K, Q5_K, Q6_K, IQ4_NL, IQ4_XS. CPU/CUDA fall back to dequant + torch.nn.functional.linear.
- `integrations/gguf_linear.replace_with_gguf_linear(model, qmap)` – walks `model.named_modules()` and swaps each `nn.Linear` whose param name is in `qmap` with a `GgufLinear`. Re-quantizes the loaded fp32 weight via gguf-py so the swap is self-contained.
- `quantizers.quantizer_gguf.GGUFQuantizer` – new `linear_mode` and `gguf_tensors` kwargs. When `linear_mode=True`, the post-load hook walks the GGUF tensor dict, applies the same rename rules the loader used, and calls `replace_with_gguf_linear` with the resulting `hf_name → quant_type` map.
- `modeling_utils.from_pretrained` – picks up the new `gguf_linear` kwarg (or the `TRANSFORMERS_GGUF_LINEAR=1` env var) and threads it into the quantizer. Default off, so existing behaviour is unchanged.
- `tests/quantization/ggml/test_gguf_linear.py` – unit tests for the module (forward matches dequant + nn.linear) and the swap helper.

Memory positioning: weights stay at 4.5 bpw (Q4_K) instead of 16 bpw bf16, i.e. 3.5× less RAM at inference time. This is the inference-side dual of PR #44794's loader refactor: that PR makes load fast, this PR makes inference fit in memory.
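Enabling it would look roughly like this sketch (repo and file names are illustrative; `gguf_file` is the existing GGUF-loading kwarg, and `gguf_linear` / `TRANSFORMERS_GGUF_LINEAR=1` are the switches added here):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B-Instruct-GGUF"     # illustrative GGUF repo
gguf_file = "qwen2.5-0.5b-instruct-q4_0.gguf"    # illustrative file name

tok = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    gguf_file=gguf_file,
    gguf_linear=True,   # keep weights quantized and dispatch matmuls to the Metal kernels
).to("mps")

inputs = tok("Hello", return_tensors="pt").to("mps")
print(tok.decode(model.generate(**inputs, max_new_tokens=16)[0]))
```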
When swapping nn.Linear → GgufLinear after load, the old path called gguf.quantize(fp32_weight), which (a) is unimplemented for K-quants and IQ4 in gguf-py and (b) gives non-bit-exact results vs llama.cpp.

New path: pull the original GGUFQuantizedTensor's raw block bytes off the gguf_tensors map the quantizer captured at load, and copy them verbatim into GgufLinear.qweight. Byte-identical to llama.cpp on disk, works for every quant type, no precision round-trip.

For attn_q / attn_k the GGUF bytes are stored in llama.cpp's permuted layout, so we still round-trip via fp32 + gguf.quantize (works for Q4_0/Q8_0; permuted K-quant Linears stay as nn.Linear with the dequantized weight, since K-quants have no Python re-quantizer).

Validated on Q4_0 (TinyLlama, Qwen2.5-0.5B): 100% Linear → GgufLinear swap, max|Δ| = 0.0 vs the dequant baseline. On Q4_K_M (Llama-3.2-3B): ~71% swap (140/197); the remaining attn_q/k stay as nn.Linear, and logits are still bit-identical because the un-swapped Linears use the same dequantized fp32 weight either way.
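A minimal sketch of the two swap paths, assuming GGUFReader-style tensors with `.data` and `.tensor_type`; the helper name and the suffix-based permutation check are illustrative, not the PR's actual API:

```python
import gguf
import numpy as np
import torch

PERMUTED_SUFFIXES = ("attn_q.weight", "attn_k.weight")   # stored permuted by llama.cpp
REQUANTIZABLE = {gguf.GGMLQuantizationType.Q4_0, gguf.GGMLQuantizationType.Q8_0}

def qweight_bytes(gguf_name, gguf_tensor, fp32_weight):
    """Return uint8 block bytes to store in GgufLinear.qweight, or None to keep nn.Linear."""
    if not gguf_name.endswith(PERMUTED_SUFFIXES):
        # Preferred path: copy the original GGUF block bytes verbatim (byte-identical to disk).
        raw = np.frombuffer(gguf_tensor.data.tobytes(), dtype=np.uint8).copy()
        return torch.from_numpy(raw)
    if gguf_tensor.tensor_type in REQUANTIZABLE:
        # Permuted attn_q / attn_k: round-trip via fp32 + gguf.quantize (Q4_0 / Q8_0 only).
        requant = gguf.quantize(fp32_weight.numpy(), gguf_tensor.tensor_type)
        return torch.from_numpy(np.frombuffer(requant.tobytes(), dtype=np.uint8).copy())
    return None  # permuted K-quants: layer stays nn.Linear with the dequantized weight
```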
Qwen2-MoE-style models hold all expert weights in a single fused module
(``Qwen2MoeExperts``) with two large fp32 parameters per layer; these
parameters are 90%+ of the model's total memory footprint. ``GgufLinear``
swap doesn't touch them because they aren't ``nn.Linear``.
New ``GgufQwen2MoeExperts`` mirrors the original forward pass but keeps
the gate / up / down expert weights as flat uint8 quantized buffers
(one per projection, three buffers total). Forward iterates over
activated experts and per expert dispatches the matching
``mul_mat_vec_<fmt>_f32`` / ``mul_mat_<fmt>_f32`` Metal kernel against
the right byte slice.
Q4_K_M GGUFs mix quant types per tensor (typically gate/up = Q4_K,
down = Q8_0), so each projection carries its own quant type + format.
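A rough sketch of the per-expert dispatch this describes; the flat-buffer layout, the ``run_kernel`` callable, and the batch-size cut-over are written out for illustration and are not the module's actual attribute names:

```python
def expert_proj(x, flat_qweight, expert_id, bytes_per_expert, fmt, run_kernel):
    """Apply one expert's quantized projection (gate/up/down) to activations x."""
    start = expert_id * bytes_per_expert
    blocks = flat_qweight[start:start + bytes_per_expert]   # this expert's GGUF block bytes
    # batch-1 decode is memory-bound -> matvec kernel; larger batches -> matmul kernel
    kernel = f"mul_mat_vec_{fmt}_f32" if x.shape[0] == 1 else f"mul_mat_{fmt}_f32"
    return run_kernel(kernel, blocks, x)
```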
Wired into ``GGUFQuantizer._process_model_after_weight_loading``: after
the Linear swap, it groups ``ffn_{gate,up,down}_exps.weight`` tensors
by their HF parent path and hands them to ``replace_qwen2_moe_experts``.
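For illustration only, the grouping step could look like the sketch below; the ``blk.<i>.ffn_{gate,up,down}_exps.weight`` names follow llama.cpp's GGUF convention, while the HF parent path format and the helper name are assumptions, not the quantizer's actual code:

```python
import re
from collections import defaultdict

def group_expert_tensors(gguf_name_to_quant_type):
    """Group MoE expert tensors by the HF module that will become GgufQwen2MoeExperts."""
    groups = defaultdict(dict)
    for name, qtype in gguf_name_to_quant_type.items():
        m = re.match(r"blk\.(\d+)\.ffn_(gate|up|down)_exps\.weight$", name)
        if m:
            parent = f"model.layers.{m.group(1)}.mlp.experts"   # assumed HF parent path
            groups[parent][m.group(2)] = qtype
    return groups  # e.g. {"model.layers.0.mlp.experts": {"gate": "Q4_K", "up": "Q4_K", "down": "Q8_0"}}
```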
Validated on Qwen1.5-MoE-A2.7B Q4_K_M: 12/24 MoE layers swapped (the
other 12 use Q5_0 down_proj which we don't ship a mat/matvec kernel for
yet), forward agrees with the fp32 baseline to within 8e-6 (fp32
accumulator noise).
[For maintainers] Suggested jobs to run (before merge): run-slow: ggml
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45977&sha=cb6ba1
What does this PR do?
Opt-in inference path for GGUF models that keeps weights at their native quantization (Q4_0 / Q8_0 / Q4_K / Q5_K / Q6_K / IQ4_NL / IQ4_XS) and runs the forward matmul through the same Metal kernels llama.cpp ships (lifted from candle, packaged as `ArthurZ/gguf-kernels` via `kernel-builder`).

Two new pieces:

- `integrations/gguf_linear.GgufLinear` – drop-in `nn.Linear` replacement. Forward picks `mul_mat_vec_<fmt>_f32` for batch-1 (decode) and `mul_mat_<fmt>_f32` for batch>1 (prefill). CPU/CUDA fall back to dequant + `nn.functional.linear`.
- `GGUFQuantizer.linear_mode` – when on, the post-load hook walks the GGUF tensor map, re-applies the same renames the loader used, and swaps each matching `nn.Linear` for a `GgufLinear` with raw block bytes.

Enabled via `from_pretrained(..., gguf_file=..., gguf_linear=True)` or `TRANSFORMERS_GGUF_LINEAR=1`. Default off, no change to existing behaviour.

Builds on top of #44794 (this PR targets `update-gguf`). #44794 makes load fast; this PR makes inference fit in memory: weights stay at 4.5 bpw (Q4_K) instead of 16 bpw bf16 = 3.5× less RAM at inference time – the qualitative win on Apple Silicon.

Performance vs llama.cpp
Measured on M3 Max (96 GB), Qwen2.5-0.5B-Instruct Q4_0: batch-1 decode reaches 266 tok/s vs 261 tok/s for llama.cpp.
The kernels are byte-identical to llama.cpp's (via candle). When dispatch is batched into a single command buffer per token (which PyTorch's MPS stream does naturally for ops on the same forward pass), throughput matches.
Per-kernel ceilings were also measured against PyTorch's MPS bf16 path (same M3 Max, 8M-element shapes, input already on GPU). Compute-bound prefill is a small throughput loss vs bf16 GEMM; bandwidth-bound batch-1 decode is a clear win. Memory savings are 3.5× regardless of workload.
Verification
- Results match references computed via `gguf.dequantize` to 0 ULP (Q4_0 / Q8_0 / IQ4_NL) or within fp32 reduction noise (~1e-4 for K-quants).
- `tests/quantization/ggml/test_gguf_linear.py` checks `GgufLinear` forward against `dequant_gguf_tensor` + `nn.functional.linear` for Q4_0 and Q4_K, batch sizes 0/1/8.
- An end-to-end test against `ArthurZ/gguf-kernels` (`tests/test_real_model_e2e.py`) loads real Qwen2.5-0.5B Q4_0 with all 169 linears swapped to `GgufLinear`; top-5 logits match the dequant baseline to fp32 ULP.
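A condensed sketch of the parity check the unit test performs, with `layer` (a constructed GgufLinear) and `w_ref` (its dequantized fp32 weight) assumed as inputs; the batch sizes and tolerance follow the numbers quoted above:

```python
import torch

def check_parity(layer, w_ref, batch_sizes=(0, 1, 8), atol=1e-4):
    """Compare the GgufLinear Metal path against the dequant + nn.functional.linear baseline."""
    for bs in batch_sizes:
        x = torch.randn(bs, w_ref.shape[1])
        expected = torch.nn.functional.linear(x, w_ref)       # dequantized fp32 reference
        actual = layer(x.to("mps")).float().cpu()             # quantized Metal kernel path
        torch.testing.assert_close(actual, expected, rtol=0, atol=atol)  # ~fp32 reduction noise
```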
Kernels package

`ArthurZ/gguf-kernels` ships pre-compiled metallibs for torch 2.8 / 2.9 / 2.10 / 2.11 on Apple Silicon. Built via `kernel-builder` (Nix + Xcode 26 + Metal Toolchain). Source: same MSL kernels as llama.cpp / candle. Override the repo with `TRANSFORMERS_GGUF_METAL_KERNELS_REPO=...` (so the personal-namespace location works ahead of any kernels-community transfer).

What's out of scope
Test plan
- `python -m pytest tests/quantization/ggml/test_gguf_linear.py` – 4 passed
- `from_pretrained(..., gguf_linear=True)` end-to-end (blocked on a separate shape bug in `update-gguf`'s dequant for some configs – surfaces as `replace_with_gguf_linear` warnings and the layer stays as `nn.Linear`, so behaviour is safe but the GgufLinear path isn't exercised end-to-end on that config yet)