add TurboQuant KV cache compression#232
Conversation
Implements TurboQuant (arXiv:2504.19874) for KV cache compression: WHT rotation + Lloyd-Max codebook quantization with asymmetric K/V bit-width support. Compresses KV cache 3-7x with minimal quality loss. New files: - TurboQuantKVCache.swift: two-phase cache (raw prefill, compressed decode) - TurboQuantKernels.swift: Metal kernels for fused encode and attention Modified: - KVCache.swift: kvScheme routing for turbo schemes - AttentionUtils.swift: TurboQuantKVCache dispatch - Evaluate.swift: kvScheme parameter threading
| ], | ||
| dependencies: [ | ||
| .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.3")), | ||
| .package(url: "https://github.com/ekryski/mlx-swift", branch: "alpha"), |
There was a problem hiding this comment.
You cannot change this library, I mean this needs to reference the official ml-explore mlx-swift, if you need to add there, you need to first add a pull request there and get it merged, this is no sense.
There was a problem hiding this comment.
This PR is still in draft, pointing to the wrong library. I would hold off on review. Defiant pointing at the wrong library
There was a problem hiding this comment.
This will be fixed in the next push. No intention of switching the dependencies.
There was a problem hiding this comment.
Out-of-bounds threadgroup memory in fused encode kernels (can corrupt results/crash for head_dim > 128).
In both dense and WHT encode kernels, shared_norm is hardcoded to 4 entries, but indexed by sg_id = d/32 and read up to num_groups = (Dim+31)/32. For Dim=256, this accesses indices 0...7.
There was a problem hiding this comment.
Good catch, fixed in next push. shared_norm is now sized [(Dim + 31) / 32] instead of hardcoded 4. Current models all use head_dim=128 so this wasn't hit in practice but would break on 256.
| } | ||
|
|
||
| // KV head mapping (use first query's head — same assumption as non-causal NR0) | ||
| uint q_head_idx_0 = (query_group * NR0) / L; |
There was a problem hiding this comment.
NR0 path derives kv_idx from only the first query row in the group. If a group spans multiple query heads (possible when grouping is not aligned to L), some rows will read/write against the wrong KV head. Guarding only on totalQ % nr0 == 0 is insufficient for correctness.
There was a problem hiding this comment.
Added a queryChunkLength % nr0 == 0 in, next push, guard so the NR0 path only activates when L aligns with the group size. Falls back to per-row dispatch otherwise.
There was a problem hiding this comment.
TurboQuant cache cannot be restored correctly from prompt cache. TurboQuantKVCache is not represented in cacheClassName, so it is serialized as KVCache. On load it is restored as KVCacheSimple, but TurboQuant compressed state carries 3/4 arrays rather than KVCacheSimple’s required 2 arrays, leading to invalid restoration behavior. Add explicit TurboQuant class name mapping and restore path.
There was a problem hiding this comment.
Fixed in nex push. Added TurboQuantKVCache to cacheClassName and restoreCacheFromMetaState. metaState now carries bits, keyBits, valueBits, and seed so the cache can be reconstructed correctly on load.
|
@aleroot I appreciate the early reviews but this is still in draft. |
| self.maxTokens = parameters.maxTokens | ||
| self.numDraftTokens = numDraftTokens | ||
|
|
||
| self.quantizeKVCache = { cache in |
There was a problem hiding this comment.
kvScheme dropped in speculative decoding quantization.TurboQuant scheme selection is ignored for speculative generation despite being part of GenerateParameters...
There was a problem hiding this comment.
Yep, missed that. Threaded kvScheme through the speculative quantization closure now.
aleroot
left a comment
There was a problem hiding this comment.
The implementation is ambitious and has strong ideas, and honestly is something I wanted to work on as well... At the moment this pull request is just a prototype as it does not look like a full inner-product TurboQuant formulation ...
There are no tests as well that are really proving is working.
Appreciate your early feedback. Still a work in progress and know that this is an ambitious PR as is. Thank you for the draft comments! |
Sorry, I did not notice it was a draft . |
All good let's work together on this. It's a big lift and it's hard to separate into individual PRs. Lots of work ahead. |
Follow up as i was on the road for my last comment: Still polishing this, the PPL data in the description shows it working across Llama, Qwen, Mistral, and Qwen3.5 MoE. Happy to collaborate on getting this to a state that works for the project. If you have ideas on the formulation or testing approach I'm all ears. (we can also do email or discord or whatever works) |
…kvScheme - size shared_norm dynamically in Metal encode kernels for head_dim > 128 - guard NR0 flash path on query chunk alignment to prevent KV head mismatch - add TurboQuantKVCache to prompt cache save/restore with metaState - thread kvScheme through speculative decoding quantization closure - point Package.swift at ml-explore/mlx-swift 0.31.3
4921784 to
2864b83
Compare
Summary
Implements TurboQuant KV cache compression -- WHT rotation + Lloyd-Max codebook quantization with support for asymmetric K/V bit-widths. Compresses KV cache 3-7x with minimal quality loss.
TurboQuantKVCache-- two-phase cache (raw prefill, compressed decode)TurboQuantKernels-- Metal kernels for fused encode and compressed-domain attentionKVCache.swift,AttentionUtils.swift,Evaluate.swift-- scheme routing and parameter threadingDependencies
Requires ml-explore/mlx-c#113 -- guards
mlx_array_dimagainst 0-dim arrays. Without this, Swift metadata init on 0-dim MLXArray crashes during TurboQuant codec setup.Schemes
turbo4turbo3turbo2turbo4v2turbo4v3turbo0v4turbo0v2PPL data (Genesis text, 512 tokens)
Multi-model (turbo4 -- recommended default)
All schemes (Llama-3.2-3B)
Context length stability (Llama-3.2-3B)
Design
turbo0v*schemes keep keys at FP16 while compressing only values -- best quality/compression tradeoff when memory allowsKnown limitations
turbo8(8-bit) encode is prohibitively slow due to 256-centroid codebook search -- practical limit is 4-bitFeedback welcome
This is a draft -- would love feedback on:
metalKernelvs framework-level dispatch)🤖 Generated with Claude Code