diff --git a/Package.swift b/Package.swift
index 17a4178f..166d7c61 100644
--- a/Package.swift
+++ b/Package.swift
@@ -203,10 +203,12 @@ let cmlx = Target.target(
     cSettings: [
         .headerSearchPath("mlx"),
         .headerSearchPath("mlx-c"),
+        .headerSearchPath("turbo-quant"),
     ],
     cxxSettings: cxxSettings + [
         .headerSearchPath("mlx"),
         .headerSearchPath("mlx-c"),
+        .headerSearchPath("turbo-quant"),
         .headerSearchPath("json/single_include/nlohmann"),
         .headerSearchPath("fmt/include"),
         .define("MLX_VERSION", to: "\"0.31.1\""),
diff --git a/Source/Cmlx/include/mlx/c/fast.h b/Source/Cmlx/include/mlx/c/fast.h
index c825d00e..aabdea6d 100644
--- a/Source/Cmlx/include/mlx/c/fast.h
+++ b/Source/Cmlx/include/mlx/c/fast.h
@@ -197,6 +197,34 @@ int mlx_fast_scaled_dot_product_attention(
     const mlx_array sinks /* may be null */,
     const mlx_stream s);
 
+// TurboQuant KV cache compression.
+//
+// Encodes `keys`/`values` (last dim 128 or 256) into packed uint8 records.
+// `k_bits` is accepted but currently not consulted by the bridge, and
+// `res_residual_k`/`res_residual_v` are replaced with fresh mlx_array_new()
+// handles. Returns 0 on success, 1 on error (reported through mlx_error).
+int mlx_fast_turbo_encode(
+    mlx_array* res_polar_k,
+    mlx_array* res_polar_v,
+    mlx_array* res_residual_k,
+    mlx_array* res_residual_v,
+    const mlx_array keys,
+    const mlx_array values,
+    int k_bits,
+    const mlx_stream s);
+
+// Decodes packed K records (last dim 68 or 136) back to float32.
+int mlx_fast_turbo_decode_k(
+    mlx_array* res,
+    const mlx_array packed,
+    const mlx_stream s);
+
+// Decodes packed V records (last dim 50 or 100) back to float32.
+int mlx_fast_turbo_decode_v(
+    mlx_array* res,
+    const mlx_array packed,
+    const mlx_stream s);
+
 /**@}*/
 
 #ifdef __cplusplus
diff --git a/Source/Cmlx/turbo-quant/turbo_quant.h b/Source/Cmlx/turbo-quant/turbo_quant.h
new file mode 100644
index 00000000..8b941c87
--- /dev/null
+++ b/Source/Cmlx/turbo-quant/turbo_quant.h
@@ -0,0 +1,419 @@
+// Copyright © 2026 SharpAI
+// turbo_quant.h — TurboQuant KV Cache compression for MLX
+//
+// Ported from TheTom/llama-cpp-turboquant (feature/turboquant-kv-cache)
+// Primary sources:
+//   ggml/src/ggml-turbo-quant.c — CPU quantize/dequantize logic
+//   ggml/src/ggml-metal/turbo-wht.h — WHT sign arrays & rotation math
+// Python validation: TheTom/turboquant_plus
+// Paper: Zandieh et al., "TurboQuant", AISTATS/ICLR 2026
+//
+// Algorithm summary:
+//   Stage 1 (PolarQuant, 3 bits per coordinate for both K and V):
+//     1. Compute L2 norm of the head_dim vector
+//     2. Normalize to unit sphere
+//     3. Apply WHT rotation: scale by S1, FWHT, scale by S2 (O(d log d))
+//     4. Quantize each coordinate to nearest Lloyd-Max centroid
+//     5. Store the norm (V: corrected by grp_norm / recon_norm; K: raw L2 norm)
+//   Stage 2 (QJL residual, 1 bit — K cache only, for inner-product bias removal):
+//     1. Reconstruct MSE approximation, compute residual
+//     2. Project residual with a second sign-randomized WHT (QJL sign arrays)
+//     3. Store sign bits of the projected residual
+
+#pragma once
+
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <vector>
+
+#include "mlx/array.h"
+#include "mlx/ops.h"
+#include "mlx/utils.h"
+
+namespace mlx::core::fast {
+
+// ---------------------------------------------------------------------------
+// Constants — must match turbo-wht.h and ggml-turbo-quant.c exactly (seed=42)
+// ---------------------------------------------------------------------------
+
+static constexpr int TURBO_D = 128; // head_dim (rotation group)
+static constexpr float TURBO_QJL_CONST = 1.2533141373155003f; // sqrt(pi/2)
+static constexpr int TURBO_SEED_ROTATION = 42;
+static constexpr int TURBO_SEED_QJL = 1042;
+
+// 3-bit Lloyd-Max centroids for N(0, 1/128) — from ggml-turbo-quant.c
+static constexpr float CENTROIDS_3BIT[8] = {
+  -0.190685f, -0.117832f, -0.065717f, -0.021460f,
+  0.021460f, 0.065717f, 0.117832f, 0.190685f
+};
+
+// 3-bit centroid decision boundaries (midpoints between adjacent centroids)
+static constexpr float BOUNDARIES_3BIT[7] = {
+  -0.154259f, -0.091775f, -0.043589f, 0.000000f,
+  0.043589f, 0.091775f, 0.154259f
+};
+
+// WHT sign arrays — seed=42, must match turbo-wht.h exactly
+static constexpr float TURBO_S1[128] = {
+  -1,1,1,-1,-1,1,-1,1,-1,-1,1,1,1,1,1,1,1,-1,1,-1,1,-1,-1,1,1,1,-1,1,1,-1,-1,-1,
+  -1,1,1,-1,1,1,-1,1,-1,1,1,-1,-1,1,-1,1,1,1,1,-1,-1,-1,-1,-1,1,-1,1,1,1,1,-1,1,
+  -1,-1,1,-1,-1,-1,1,-1,-1,-1,1,-1,-1,-1,1,1,1,-1,-1,1,1,1,-1,-1,1,1,-1,1,1,-1,1,-1,
+  -1,1,1,-1,1,-1,1,-1,1,1,1,1,-1,1,-1,1,1,-1,1,1,-1,-1,-1,-1,-1,1,1,-1,1,1,-1,1
+};
+
+static constexpr float TURBO_S2[128] = {
+  1,1,1,1,-1,1,1,-1,1,-1,-1,-1,1,-1,-1,-1,1,1,-1,-1,1,-1,1,-1,1,-1,-1,1,-1,1,1,1,
+  1,1,-1,-1,-1,1,-1,-1,-1,-1,-1,-1,1,1,1,-1,1,-1,1,1,1,-1,-1,1,-1,-1,-1,-1,-1,-1,1,1,
+  1,-1,1,-1,-1,-1,-1,1,-1,1,-1,1,-1,-1,1,1,-1,1,-1,1,1,-1,1,-1,-1,-1,-1,1,-1,-1,1,-1,
+  1,-1,1,1,1,-1,-1,1,-1,1,-1,1,1,-1,-1,1,-1,1,-1,1,1,-1,1,-1,1,-1,-1,-1,-1,-1,1,-1
+};
+
+// QJL sign arrays — seed=1042, must match turbo-wht.h exactly
+static constexpr float TURBO_QJL_S1[128] = {
+  1,-1,-1,-1,-1,1,-1,1,1,-1,-1,1,-1,1,-1,1,1,-1,1,-1,-1,-1,1,1,-1,1,1,-1,1,-1,-1,1,
+  1,1,1,1,-1,-1,1,1,-1,1,-1,-1,1,-1,1,1,1,-1,1,1,1,-1,-1,1,-1,1,-1,1,1,-1,1,1,
+  -1,-1,-1,1,1,1,1,1,1,-1,-1,1,1,-1,-1,-1,-1,-1,1,1,1,1,-1,1,1,-1,1,1,1,1,1,1,
+  1,-1,1,-1,-1,1,-1,-1,-1,-1,1,-1,1,1,1,-1,-1,1,-1,1,1,1,-1,-1,1,-1,-1,-1,-1,-1,-1,-1
+};
+
+static constexpr float TURBO_QJL_S2[128] = {
+  1,1,-1,1,1,-1,1,1,-1,-1,1,1,1,-1,1,1,-1,-1,-1,1,-1,1,1,1,-1,1,-1,-1,-1,-1,1,1,
+  -1,-1,1,-1,1,1,-1,-1,-1,-1,-1,1,1,1,1,1,1,1,1,1,-1,-1,1,1,1,1,1,1,1,-1,1,1,
+  -1,-1,1,-1,1,1,-1,1,-1,-1,1,1,1,-1,1,-1,1,1,1,1,1,1,-1,1,-1,1,-1,1,-1,1,1,-1,
+  1,-1,-1,1,1,-1,1,1,-1,1,1,1,-1,1,1,1,-1,-1,1,-1,1,-1,-1,1,-1,1,-1,1,1,1,1,-1
+};
+
+// ---------------------------------------------------------------------------
+// Fast Walsh-Hadamard Transform (in-place, normalized by 1/sqrt(n))
+// ---------------------------------------------------------------------------
+
+// Transforms x[0..n) in place; n must be a power of two. The butterfly
+// passes are followed by a 1/sqrt(n) scale, which makes the transform
+// orthonormal and therefore its own inverse.
+static inline void turbo_fwht(float* x, int n) {
+  for (int h = 1; h < n; h *= 2) {
+    for (int i = 0; i < n; i += h * 2) {
+      for (int j = i; j < i + h; j++) {
+        float a = x[j], b = x[j + h];
+        x[j] = a + b;
+        x[j + h] = a - b;
+      }
+    }
+  }
+  // Keep the exact literal for n == 128 so the common case stays bit-identical.
+  // Other power-of-two sizes use a computed factor (exactly 0.125f for
+  // n == 64) instead of a constant that is only right for one size.
+  const float inv_sqrt = (n == 128)
+      ? 0.08838834764831845f
+      : 1.0f / std::sqrt(static_cast<float>(n));
+  for (int i = 0; i < n; i++) x[i] *= inv_sqrt;
+}
+
+// Forward rotation: x <- S2 * FWHT(S1 * x) (element-wise sign flips).
+static inline void turbo_rotate_forward(float* x, int n) {
+  for (int i = 0; i < n; i++) x[i] *= TURBO_S1[i];
+  turbo_fwht(x, n);
+  for (int i = 0; i < n; i++) x[i] *= TURBO_S2[i];
+}
+
+// Inverse rotation: x <- S1 * FWHT(S2 * x). The sign flips and the normalized
+// FWHT are each self-inverse, so undoing them in reverse order inverts forward.
+static inline void turbo_rotate_inverse(float* x, int n) {
+  for (int i = 0; i < n; i++) x[i] *= TURBO_S2[i];
+  turbo_fwht(x, n);
+  for (int i = 0; i < n; i++) x[i] *= TURBO_S1[i];
+}
+
+// QJL rotation (different seed): x <- QJL_S2 * FWHT(QJL_S1 * x).
+static inline void turbo_qjl_rotate(float* x, int n) {
+  for (int i = 0; i < n; i++) x[i] *= TURBO_QJL_S1[i];
+  turbo_fwht(x, n);
+  for (int i = 0; i < n; i++) x[i] *= TURBO_QJL_S2[i];
+}
+
+// Inverse (transpose) of turbo_qjl_rotate: x <- QJL_S1 * FWHT(QJL_S2 * x).
+static inline void turbo_qjl_rotate_inverse(float* x, int n) {
+  for (int i = 0; i < n; i++) x[i] *= TURBO_QJL_S2[i];
+  turbo_fwht(x, n);
+  for (int i = 0; i < n; i++) x[i] *= TURBO_QJL_S1[i];
+}
+
+// ---------------------------------------------------------------------------
+// Nearest 3-bit centroid (O(log 8) binary search on boundaries)
+// ---------------------------------------------------------------------------
+
+static inline int nearest_centroid_3bit(float v) {
+  if (v < BOUNDARIES_3BIT[3]) { // v < 0.0
+    if (v < BOUNDARIES_3BIT[1]) return (v < BOUNDARIES_3BIT[0]) ? 0 : 1;
+    return (v < BOUNDARIES_3BIT[2]) ? 2 : 3;
+  } else {
+    if (v < BOUNDARIES_3BIT[5]) return (v < BOUNDARIES_3BIT[4]) ? 4 : 5;
+    return (v < BOUNDARIES_3BIT[6]) ? 6 : 7;
+  }
+}
+
+// ---------------------------------------------------------------------------
+// TurboQuant storage — packed bit arrays for a single head_dim=128 vector
+// ---------------------------------------------------------------------------
+
+// TURBO3: 3-bit PolarQuant (V cache — MSE optimal)
+// Storage: 48 bytes indices (3 bits × 128 = 384 bits) + 2 bytes norm (fp16)
+struct TurboQuantV {
+  uint8_t indices[48]; // 3 bits per coordinate, packed
+  uint16_t norm_fp16; // corrected L2 norm as fp16
+};
+
+// TURBO4: 3-bit PolarQuant + 1-bit QJL (K cache — inner product optimal)
+// Storage: 48 bytes indices + 16 bytes QJL signs + 2 bytes norm + 2 bytes rnorm
+struct TurboQuantK {
+  uint8_t indices[48]; // 3-bit PolarQuant indices, packed
+  uint8_t qjl_signs[16]; // 1-bit QJL sign per coordinate (128 bits)
+  uint16_t norm_fp16; // original L2 norm as fp16
+  uint16_t rnorm_fp16; // residual norm as fp16
+};
+
+// ---------------------------------------------------------------------------
+// fp16 <-> fp32 helpers (portable, no intrinsics needed)
+// ---------------------------------------------------------------------------
+
+// Converts fp32 to IEEE half bits. Deliberately simple: the mantissa is
+// truncated (no rounding), results below the normal half range flush to
+// signed zero, and anything at or above the half exponent limit — including
+// NaN — becomes signed infinity.
+static inline uint16_t fp32_to_fp16(float f) {
+  // Fast but portable fp32->fp16 conversion
+  union { float f; uint32_t u; } v = {f};
+  uint32_t u = v.u;
+  uint16_t sign = (u >> 16) & 0x8000;
+  int32_t exp = (int32_t)((u >> 23) & 0xFF) - 127 + 15;
+  uint32_t mant = u & 0x7FFFFF;
+  if (exp <= 0) return sign;
+  if (exp >= 31) return sign | 0x7C00;
+  return (uint16_t)(sign | (exp << 10) | (mant >> 13));
+}
+
+// Expands IEEE half bits to fp32, including subnormals, infinities and NaNs.
+static inline float fp16_to_fp32(uint16_t h) {
+  uint32_t sign = (h & 0x8000) << 16;
+  uint32_t exp = (h >> 10) & 0x1F;
+  uint32_t mant = h & 0x3FF;
+  if (exp == 0) {
+    if (mant == 0) { union{uint32_t u;float f;} v={sign}; return v.f; }
+    // Subnormal: normalize by shifting until the implicit bit (0x400) is set.
+    while (!(mant & 0x400)) { mant <<= 1; exp--; }
+    mant &= 0x3FF; exp++;
+  } else if (exp == 31) {
+    union{uint32_t u;float f;} v={sign|(0xFF<<23)|mant}; return v.f;
+  }
+  union{uint32_t u;float f;} v={sign|((exp+127-15)<<23)|(mant<<13)};
+  return v.f;
+}
+
+// ---------------------------------------------------------------------------
+// Pack / unpack 3-bit indices into byte arrays
+// ---------------------------------------------------------------------------
+
+// Packs d 3-bit indices LSB-first into `packed`. Bits are OR-ed in, so the
+// caller must zero `packed` (ceil(3*d/8) bytes) beforehand.
+static inline void pack_3bit(const uint8_t* idx, uint8_t* packed, int d) {
+  // 3 bits per element → 3 bytes per 8 elements
+  for (int i = 0; i < d; i++) {
+    int bit_offset = i * 3;
+    int byte_idx = bit_offset / 8;
+    int bit_pos = bit_offset % 8;
+    packed[byte_idx] |= (uint8_t)((idx[i] & 0x7) << bit_pos);
+    // The value straddles a byte boundary: spill its high bits into the next byte.
+    if (bit_pos > 5) {
+      packed[byte_idx + 1] |= (uint8_t)((idx[i] & 0x7) >> (8 - bit_pos));
+    }
+  }
+}
+
+// Inverse of pack_3bit: extracts d 3-bit indices from `packed`.
+static inline void unpack_3bit(const uint8_t* packed, uint8_t* idx, int d) {
+  for (int i = 0; i < d; i++) {
+    int bit_offset = i * 3;
+    int byte_idx = bit_offset / 8;
+    int bit_pos = bit_offset % 8;
+    uint16_t raw = (uint16_t)packed[byte_idx];
+    // Only read the following byte when it lies inside the packed buffer.
+    if (byte_idx + 1 < (d * 3 + 7) / 8)
+      raw |= (uint16_t)packed[byte_idx + 1] << 8;
+    idx[i] = (uint8_t)((raw >> bit_pos) & 0x7);
+  }
+}
+
+// ---------------------------------------------------------------------------
+// Quantize one head_dim vector → TurboQuantV (3-bit PolarQuant, V cache)
+// ---------------------------------------------------------------------------
+
+// `src` holds d floats; d must not exceed TURBO_D because the scratch buffers
+// and the sign tables are sized for TURBO_D.
+static inline TurboQuantV turbo_quantize_v(const float* src, int d) {
+  TurboQuantV out;
+  std::memset(&out, 0, sizeof(out));
+
+  // 1. Compute L2 norm
+  float norm_sq = 0.f;
+  float buf[TURBO_D];
+  for (int i = 0; i < d; i++) { buf[i] = src[i]; norm_sq += buf[i] * buf[i]; }
+  float grp_norm = std::sqrt(norm_sq);
+  float inv_norm = (grp_norm > 1e-10f) ? 1.f / grp_norm : 0.f;
+
+  // 2. Normalize
+  for (int i = 0; i < d; i++) buf[i] *= inv_norm;
+
+  // 3. WHT rotation
+  turbo_rotate_forward(buf, d);
+
+  // 4. Quantize, accumulate reconstructed norm²
+  uint8_t indices[TURBO_D];
+  float recon_sq = 0.f;
+  for (int i = 0; i < d; i++) {
+    indices[i] = (uint8_t)nearest_centroid_3bit(buf[i]);
+    recon_sq += CENTROIDS_3BIT[indices[i]] * CENTROIDS_3BIT[indices[i]];
+  }
+
+  // 5. Corrected norm: grp_norm / recon_norm
+  float recon_norm = std::sqrt(recon_sq);
+  float corrected = (recon_norm > 1e-10f) ? grp_norm / recon_norm : grp_norm;
+  out.norm_fp16 = fp32_to_fp16(corrected);
+
+  // 6. Pack 3-bit indices
+  pack_3bit(indices, out.indices, d);
+  return out;
+}
+
+// ---------------------------------------------------------------------------
+// Dequantize TurboQuantV → float vector (CPU debug path; GPU uses Metal)
+// ---------------------------------------------------------------------------
+
+// Writes d floats to `dst`; d must not exceed TURBO_D.
+static inline void turbo_dequantize_v(const TurboQuantV& v, float* dst, int d) {
+  uint8_t indices[TURBO_D];
+  unpack_3bit(v.indices, indices, d);
+
+  float norm = fp16_to_fp32(v.norm_fp16);
+  float buf[TURBO_D];
+  for (int i = 0; i < d; i++) buf[i] = CENTROIDS_3BIT[indices[i]];
+
+  turbo_rotate_inverse(buf, d);
+  for (int i = 0; i < d; i++) dst[i] = buf[i] * norm;
+}
+
+// ---------------------------------------------------------------------------
+// Quantize one head_dim vector → TurboQuantK (3-bit PolarQuant + 1-bit QJL)
+// ---------------------------------------------------------------------------
+
+// `src` holds d floats; d must not exceed TURBO_D because the scratch buffers
+// and the sign tables are sized for TURBO_D.
+static inline TurboQuantK turbo_quantize_k(const float* src, int d) {
+  TurboQuantK out;
+  std::memset(&out, 0, sizeof(out));
+
+  // 1. Norm + normalize
+  float norm_sq = 0.f;
+  float normalized[TURBO_D];
+  for (int i = 0; i < d; i++) { norm_sq += src[i] * src[i]; }
+  float norm = std::sqrt(norm_sq);
+  float inv = (norm > 1e-10f) ? 1.f / norm : 0.f;
+  for (int i = 0; i < d; i++) normalized[i] = src[i] * inv;
+
+  // 2. WHT rotation
+  float rotated[TURBO_D];
+  std::memcpy(rotated, normalized, d * sizeof(float));
+  turbo_rotate_forward(rotated, d);
+
+  // 3. 3-bit quantization
+  uint8_t indices[TURBO_D];
+  for (int i = 0; i < d; i++) indices[i] = (uint8_t)nearest_centroid_3bit(rotated[i]);
+
+  // 4. Reconstruct MSE approximation → residual
+  float mse_recon[TURBO_D];
+  for (int i = 0; i < d; i++) mse_recon[i] = CENTROIDS_3BIT[indices[i]];
+  turbo_rotate_inverse(mse_recon, d); // back to original space
+
+  float residual[TURBO_D];
+  float rnorm_sq = 0.f;
+  for (int i = 0; i < d; i++) {
+    residual[i] = normalized[i] - mse_recon[i];
+    rnorm_sq += residual[i] * residual[i];
+  }
+  float rnorm = std::sqrt(rnorm_sq);
+
+  // 5. QJL: WHT-based projection of residual, store sign bits
+  float projected[TURBO_D];
+  std::memcpy(projected, residual, d * sizeof(float));
+  turbo_qjl_rotate(projected, d);
+
+  for (int i = 0; i < d; i++) {
+    if (projected[i] >= 0.f)
+      out.qjl_signs[i / 8] |= (uint8_t)(1 << (i % 8));
+  }
+
+  // 6. Pack
+  out.norm_fp16 = fp32_to_fp16(norm);
+  out.rnorm_fp16 = fp32_to_fp16(rnorm);
+  pack_3bit(indices, out.indices, d);
+  return out;
+}
+
+// ---------------------------------------------------------------------------
+// Dequantize TurboQuantK → float vector (CPU debug path; GPU uses Metal)
+// ---------------------------------------------------------------------------
+
+// Writes d floats to `dst`; d must not exceed TURBO_D.
+static inline void turbo_dequantize_k(const TurboQuantK& k, float* dst, int d) {
+  uint8_t indices[TURBO_D];
+  unpack_3bit(k.indices, indices, d);
+
+  float norm = fp16_to_fp32(k.norm_fp16);
+  float rnorm = fp16_to_fp32(k.rnorm_fp16);
+
+  // Stage 1: PolarQuant reconstruction
+  float mse_recon[TURBO_D];
+  for (int i = 0; i < d; i++) mse_recon[i] = CENTROIDS_3BIT[indices[i]];
+  turbo_rotate_inverse(mse_recon, d);
+
+  // Stage 2: QJL reconstruction
+  float signs[TURBO_D];
+  for (int i = 0; i < d; i++)
+    signs[i] = (k.qjl_signs[i / 8] & (1 << (i % 8))) ? 1.f : -1.f;
+
+  // Map the signs back with the transposed QJL rotation. The forward rotation
+  // uses two different sign arrays, so re-applying it would not invert it.
+  turbo_qjl_rotate_inverse(signs, d);
+  // NOTE(review): the 1/d factor mirrors the dense-Gaussian QJL estimator;
+  // confirm it is the intended scale for this orthonormal WHT projection.
+  const float qjl_scale = TURBO_QJL_CONST / (float)d * rnorm;
+  for (int i = 0; i < d; i++) signs[i] *= qjl_scale;
+
+  // Combine and scale by original norm
+  for (int i = 0; i < d; i++) dst[i] = (mse_recon[i] + signs[i]) * norm;
+}
+
+// ---------------------------------------------------------------------------
+// MLX array-level API (used by KVCache.swift via C bridge)
+// ---------------------------------------------------------------------------
+
+/**
+ * Encode a batch of KV vectors into TurboQuant format.
+ *
+ * keys: [batch, num_heads, seq_len, head_dim] fp16/bf16/fp32
+ * values: same shape
+ *
+ * Returns the compressed storage as opaque uint8 buffers.
+ * The Metal attention kernel reads these directly during SDPA.
+ */
+struct TurboQuantKV {
+  // Packed storage: each entry is one head_dim-sized compressed vector
+  std::vector<TurboQuantK> k_data; // K cache (PolarQuant + QJL)
+  std::vector<TurboQuantV> v_data; // V cache (PolarQuant only)
+
+  int num_heads;
+  int seq_len;
+  int head_dim;
+};
+
+} // namespace mlx::core::fast
diff --git a/Source/Cmlx/turbo-quant/turbo_quant_bridge.cpp b/Source/Cmlx/turbo-quant/turbo_quant_bridge.cpp
new file mode 100644
index 00000000..1180aaa5
--- /dev/null
+++ b/Source/Cmlx/turbo-quant/turbo_quant_bridge.cpp
@@ -0,0 +1,79 @@
+// TurboQuant C bridge for Swift bindings
+#include "mlx/c/fast.h"
+#include "mlx/c/error.h"
+#include "mlx/c/private/mlx.h"
+#include "turbo_quant_decl.h"
+
+// Encodes keys and values into packed TurboQuant records.
+// `k_bits` is not consulted: the encoders called here take no bit-width
+// argument. The residual outputs carry no data; they are replaced with fresh
+// handles so the four-output signature stays stable for callers.
+extern "C" int mlx_fast_turbo_encode(
+    mlx_array* res_polar_k,
+    mlx_array* res_polar_v,
+    mlx_array* res_residual_k,
+    mlx_array* res_residual_v,
+    const mlx_array keys,
+    const mlx_array values,
+    int k_bits,
+    const mlx_stream s) {
+  (void)k_bits;
+  try {
+    mlx_array_set_(
+        *res_polar_k,
+        mlx::core::fast::turbo_encode_k(
+            mlx_array_get_(keys),
+            mlx_stream_get_(s)));
+    mlx_array_set_(
+        *res_polar_v,
+        mlx::core::fast::turbo_encode_v(
+            mlx_array_get_(values),
+            mlx_stream_get_(s)));
+    // Release whatever the caller's handles held before replacing them, so a
+    // reused handle does not leak its array.
+    mlx_array_free(*res_residual_k);
+    mlx_array_free(*res_residual_v);
+    *res_residual_k = mlx_array_new();
+    *res_residual_v = mlx_array_new();
+  } catch (std::exception& e) {
+    mlx_error(e.what());
+    return 1;
+  }
+  return 0;
+}
+
+// Decodes packed K records into `res`; returns 0 on success, 1 on error.
+extern "C" int mlx_fast_turbo_decode_k(
+    mlx_array* res,
+    const mlx_array packed,
+    const mlx_stream s) {
+  try {
+    mlx_array_set_(
+        *res,
+        mlx::core::fast::turbo_decode_k(
+            mlx_array_get_(packed),
+            mlx_stream_get_(s)));
+  } catch (std::exception& e) {
+    mlx_error(e.what());
+    return 1;
+  }
+  return 0;
+}
+
+// Decodes packed V records into `res`; returns 0 on success, 1 on error.
+extern "C" int mlx_fast_turbo_decode_v(
+    mlx_array* res,
+    const mlx_array packed,
+    const mlx_stream s) {
+  try {
+    mlx_array_set_(
+        *res,
+        mlx::core::fast::turbo_decode_v(
+            mlx_array_get_(packed),
+            mlx_stream_get_(s)));
+  } catch (std::exception& e) {
+    mlx_error(e.what());
+    return 1;
+  }
+  return 0;
+}
diff --git a/Source/Cmlx/turbo-quant/turbo_quant_decl.h b/Source/Cmlx/turbo-quant/turbo_quant_decl.h
new file mode 100644
index 00000000..9e9d67bf
--- /dev/null
+++ b/Source/Cmlx/turbo-quant/turbo_quant_decl.h
@@ -0,0 +1,16 @@
+// TurboQuant function declarations for C++ namespace
+// These extend mlx::core::fast with TurboQuant operations
+#pragma once
+
+#include "mlx/mlx.h"
+
+namespace mlx::core::fast {
+
+// Encoders: last dim must be 128 or 256; return packed uint8 records.
+array turbo_encode_k(const array& keys, StreamOrDevice s = {});
+array turbo_encode_v(const array& values, StreamOrDevice s = {});
+// Decoders: take the packed records back to float32.
+array turbo_decode_k(const array& packed, StreamOrDevice s = {});
+array turbo_decode_v(const array& packed, StreamOrDevice s = {});
+
+} // namespace mlx::core::fast
diff --git a/Source/Cmlx/turbo-quant/turbo_quant_ops.cpp b/Source/Cmlx/turbo-quant/turbo_quant_ops.cpp
new file mode 100644
index 00000000..ee20313c
--- /dev/null
+++ b/Source/Cmlx/turbo-quant/turbo_quant_ops.cpp
@@ -0,0 +1,174 @@
+// TurboQuant KV cache compression operations
+// Based on TurboQuant paper (Zandieh et al., arXiv 2504.19874)
+
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <stdexcept>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "mlx/mlx.h"
+#include "turbo_quant.h"
+
+namespace {
+// Bytes per packed 128-wide K record: 48 index + 16 QJL sign + 2 norm + 2 rnorm.
+static constexpr int TURBO_K_RECORD = 68;
+// Bytes per packed 128-wide V record: 48 index + 2 norm.
+static constexpr int TURBO_V_RECORD = 50;
+} // anonymous namespace
+
+namespace mlx::core::fast {
+
+// Materializes `x` as row-contiguous float32 and returns it together with a
+// pointer to its data. The returned array must outlive any use of the pointer.
+// `contiguous` matters because callers index the buffer linearly; a strided
+// view (e.g. a transposed K/V tensor) would otherwise be read out of order.
+static std::pair<mlx::core::array, const float*>
+turbo_to_f32(const mlx::core::array& x, mlx::core::StreamOrDevice s) {
+  auto x_f32 = mlx::core::contiguous(
+      mlx::core::astype(x, mlx::core::float32, s), false, s);
+  mlx::core::eval(x_f32);
+  return {x_f32, x_f32.data<float>()};
+}
+
+// Packs each 128-wide group of the last axis of `keys` into a 68-byte record
+// (3-bit indices, QJL signs, fp16 norm, fp16 residual norm). The output keeps
+// the input's leading dimensions; the last dimension becomes the record size.
+array turbo_encode_k(const array& keys, StreamOrDevice s_) {
+  auto s = to_stream(s_);
+  const int head_dim = static_cast<int>(keys.shape(-1));
+  if (head_dim != 128 && head_dim != 256) {
+    throw std::invalid_argument(
+        "[turbo_encode_k] last dim must be 128 or 256, got " +
+        std::to_string(head_dim));
+  }
+  const int n_subgroups = head_dim / TURBO_D;
+  const int record_bytes = TURBO_K_RECORD * n_subgroups;
+  auto [keys_f32, src] = turbo_to_f32(keys, s);
+  // size_t arithmetic: byte/element offsets can exceed INT_MAX for long caches.
+  const size_t N = keys_f32.size() / static_cast<size_t>(head_dim);
+  std::vector<uint8_t> buf(N * record_bytes, 0u);
+  for (size_t i = 0; i < N; ++i) {
+    uint8_t* dst = buf.data() + i * record_bytes;
+    for (int g = 0; g < n_subgroups; ++g) {
+      TurboQuantK rec = turbo_quantize_k(
+          src + i * head_dim + static_cast<size_t>(g) * TURBO_D, TURBO_D);
+      uint8_t* sub_dst = dst + g * TURBO_K_RECORD;
+      std::memcpy(sub_dst, rec.indices, 48);
+      std::memcpy(sub_dst + 48, rec.qjl_signs, 16);
+      std::memcpy(sub_dst + 64, &rec.norm_fp16, 2);
+      std::memcpy(sub_dst + 66, &rec.rnorm_fp16, 2);
+    }
+  }
+  Shape out_shape = keys.shape();
+  out_shape.back() = record_bytes;
+  return array(buf.data(), out_shape, uint8);
+}
+
+// Packs each 128-wide group of the last axis of `values` into a 50-byte record
+// (3-bit indices, fp16 corrected norm). Output shape follows turbo_encode_k.
+array turbo_encode_v(const array& values, StreamOrDevice s_) {
+  auto s = to_stream(s_);
+  const int head_dim = static_cast<int>(values.shape(-1));
+  if (head_dim != 128 && head_dim != 256) {
+    throw std::invalid_argument(
+        "[turbo_encode_v] last dim must be 128 or 256, got " +
+        std::to_string(head_dim));
+  }
+  const int n_subgroups = head_dim / TURBO_D;
+  const int record_bytes = TURBO_V_RECORD * n_subgroups;
+  auto [vals_f32, src] = turbo_to_f32(values, s);
+  const size_t N = vals_f32.size() / static_cast<size_t>(head_dim);
+  std::vector<uint8_t> buf(N * record_bytes, 0u);
+  for (size_t i = 0; i < N; ++i) {
+    uint8_t* dst = buf.data() + i * record_bytes;
+    for (int g = 0; g < n_subgroups; ++g) {
+      TurboQuantV rec = turbo_quantize_v(
+          src + i * head_dim + static_cast<size_t>(g) * TURBO_D, TURBO_D);
+      uint8_t* sub_dst = dst + g * TURBO_V_RECORD;
+      std::memcpy(sub_dst, rec.indices, 48);
+      std::memcpy(sub_dst + 48, &rec.norm_fp16, 2);
+    }
+  }
+  Shape out_shape = values.shape();
+  out_shape.back() = record_bytes;
+  return array(buf.data(), out_shape, uint8);
+}
+
+// Expands packed K records (last dim 68 or 136 bytes) to float32 vectors of
+// 128 or 256 elements. Leading dimensions are preserved.
+array turbo_decode_k(const array& packed, StreamOrDevice s_) {
+  auto s = to_stream(s_);
+  const int record_bytes = static_cast<int>(packed.shape(-1));
+  if (record_bytes != TURBO_K_RECORD && record_bytes != TURBO_K_RECORD * 2) {
+    throw std::invalid_argument(
+        "[turbo_decode_k] last dim must be 68 or 136, got " +
+        std::to_string(record_bytes));
+  }
+  const int n_subgroups = record_bytes / TURBO_K_RECORD;
+  const int head_dim = n_subgroups * TURBO_D;
+  // Force row-contiguous layout: records are addressed by linear offset below.
+  auto packed_u8 = contiguous(astype(packed, uint8, s), false, s);
+  eval(packed_u8);
+  const uint8_t* src = packed_u8.data<uint8_t>();
+  const size_t N = packed_u8.size() / static_cast<size_t>(record_bytes);
+  std::vector<float> buf(N * head_dim);
+  for (size_t i = 0; i < N; ++i) {
+    for (int g = 0; g < n_subgroups; ++g) {
+      const uint8_t* sub_src = src + i * record_bytes + g * TURBO_K_RECORD;
+      TurboQuantK rec;
+      std::memset(&rec, 0, sizeof(rec));
+      std::memcpy(rec.indices, sub_src, 48);
+      std::memcpy(rec.qjl_signs, sub_src + 48, 16);
+      std::memcpy(&rec.norm_fp16, sub_src + 64, 2);
+      std::memcpy(&rec.rnorm_fp16, sub_src + 66, 2);
+      turbo_dequantize_k(
+          rec,
+          buf.data() + i * head_dim + static_cast<size_t>(g) * TURBO_D,
+          TURBO_D);
+    }
+  }
+  Shape out_shape = packed.shape();
+  out_shape.back() = head_dim;
+  return array(buf.data(), out_shape, float32);
+}
+
+// Expands packed V records (last dim 50 or 100 bytes) to float32 vectors of
+// 128 or 256 elements. Leading dimensions are preserved.
+array turbo_decode_v(const array& packed, StreamOrDevice s_) {
+  auto s = to_stream(s_);
+  const int record_bytes = static_cast<int>(packed.shape(-1));
+  if (record_bytes != TURBO_V_RECORD && record_bytes != TURBO_V_RECORD * 2) {
+    throw std::invalid_argument(
+        "[turbo_decode_v] last dim must be 50 or 100, got " +
+        std::to_string(record_bytes));
+  }
+  const int n_subgroups = record_bytes / TURBO_V_RECORD;
+  const int head_dim = n_subgroups * TURBO_D;
+  // Force row-contiguous layout: records are addressed by linear offset below.
+  auto packed_u8 = contiguous(astype(packed, uint8, s), false, s);
+  eval(packed_u8);
+  const uint8_t* src = packed_u8.data<uint8_t>();
+  const size_t N = packed_u8.size() / static_cast<size_t>(record_bytes);
+  std::vector<float> buf(N * head_dim);
+  for (size_t i = 0; i < N; ++i) {
+    for (int g = 0; g < n_subgroups; ++g) {
+      const uint8_t* sub_src = src + i * record_bytes + g * TURBO_V_RECORD;
+      TurboQuantV rec;
+      std::memset(&rec, 0, sizeof(rec));
+      std::memcpy(rec.indices, sub_src, 48);
+      std::memcpy(&rec.norm_fp16, sub_src + 48, 2);
+      turbo_dequantize_v(
+          rec,
+          buf.data() + i * head_dim + static_cast<size_t>(g) * TURBO_D,
+          TURBO_D);
+    }
+  }
+  Shape out_shape = packed.shape();
+  out_shape.back() = head_dim;
+  return array(buf.data(), out_shape, float32);
+}
+
+} // namespace mlx::core::fast
diff --git a/Source/MLX/MLXFast.swift b/Source/MLX/MLXFast.swift
index 92c96da8..31884e96 100644
--- a/Source/MLX/MLXFast.swift
+++ b/Source/MLX/MLXFast.swift
@@ -259,6 +259,57 @@ public enum MLXFast {
         return MLXArray(result)
     }
 
+    // MARK: - TurboQuant KV Cache Compression
+
+    /// Compress K and V cache tensors using TurboQuant (3-bit PolarQuant + 1-bit QJL for keys).
+    ///
+    /// - Parameters:
+    ///   - keys: K tensor `[B, H, T, D]` where D is 128 or 256
+    ///   - values: V tensor `[B, H, T, D]` where D is 128 or 256
+    ///   - bits: Compression bits (default 3). Forwarded to the C bridge, which
+    ///     currently ignores it and always uses the 3-bit layout.
+    ///   - stream: stream or device to evaluate on
+    /// - Returns: `((polarK, residualK), (polarV, residualV))`. The polar entries are
+    ///   packed uint8 arrays; the residual entries are placeholder handles that the
+    ///   bridge does not populate, so do not read from them.
+    public static func turboQuantEncode(
+        keys: MLXArray, values: MLXArray, bits: Int = 3, stream: StreamOrDevice = .default
+    ) -> ((MLXArray, MLXArray), (MLXArray, MLXArray)) {
+        var resPolarK = mlx_array_new()
+        var resPolarV = mlx_array_new()
+        var resResidualK = mlx_array_new()
+        var resResidualV = mlx_array_new()
+        mlx_fast_turbo_encode(
+            &resPolarK, &resPolarV, &resResidualK, &resResidualV,
+            keys.ctx, values.ctx, Int32(bits), stream.ctx)
+        return ((MLXArray(resPolarK), MLXArray(resResidualK)),
+                (MLXArray(resPolarV), MLXArray(resResidualV)))
+    }
+
+    /// Decode TurboKV compressed key history back to float32.
+    ///
+    /// - Parameter packed: packed K records whose last dimension is 68 or 136 bytes
+    /// - Returns: float32 array with the last dimension restored to 128 or 256
+    public static func turboDecodeK(
+        packed: MLXArray, stream: StreamOrDevice = .default
+    ) -> MLXArray {
+        var result = mlx_array_new()
+        mlx_fast_turbo_decode_k(&result, packed.ctx, stream.ctx)
+        return MLXArray(result)
+    }
+
+    /// Decode TurboKV compressed value history back to float32.
+    ///
+    /// - Parameter packed: packed V records whose last dimension is 50 or 100 bytes
+    /// - Returns: float32 array with the last dimension restored to 128 or 256
+    public static func turboDecodeV(
+        packed: MLXArray, stream: StreamOrDevice = .default
+    ) -> MLXArray {
+        var result = mlx_array_new()
+        mlx_fast_turbo_decode_v(&result, packed.ctx, stream.ctx)
+        return MLXArray(result)
+    }
+
 }
 
 /// Optimized implementation of `NN.RoPE`.