From 369ba3a38e0c34b8a974315af87e67d9dde505c3 Mon Sep 17 00:00:00 2001 From: TheTom Date: Tue, 21 Apr 2026 13:16:02 -0500 Subject: [PATCH 1/2] add TurboQuant KV cache compression Implements TurboQuant (arXiv:2504.19874) for KV cache compression: WHT rotation + Lloyd-Max codebook quantization with asymmetric K/V bit-width support. Compresses KV cache 3-7x with minimal quality loss. New files: - TurboQuantKVCache.swift: two-phase cache (raw prefill, compressed decode) - TurboQuantKernels.swift: Metal kernels for fused encode and attention Modified: - KVCache.swift: kvScheme routing for turbo schemes - AttentionUtils.swift: TurboQuantKVCache dispatch - Evaluate.swift: kvScheme parameter threading --- .gitignore | 2 + Libraries/MLXLLM/Models/Gemma3Text.swift | 10 - Libraries/MLXLLM/Models/Gemma3nText.swift | 9 - Libraries/MLXLMCommon/AttentionUtils.swift | 7 +- Libraries/MLXLMCommon/Evaluate.swift | 13 +- Libraries/MLXLMCommon/KVCache.swift | 71 +- Libraries/MLXLMCommon/TurboQuantKVCache.swift | 1185 +++++++++++++++++ Libraries/MLXLMCommon/TurboQuantKernels.swift | 1125 ++++++++++++++++ Package.swift | 10 +- Tools/Gemma4MoETest/main.swift | 58 + 10 files changed, 2460 insertions(+), 30 deletions(-) create mode 100644 Libraries/MLXLMCommon/TurboQuantKVCache.swift create mode 100644 Libraries/MLXLMCommon/TurboQuantKernels.swift create mode 100644 Tools/Gemma4MoETest/main.swift diff --git a/.gitignore b/.gitignore index 1f32eac74..2238332f1 100644 --- a/.gitignore +++ b/.gitignore @@ -95,3 +95,5 @@ iOSInjectionProject/ .idea .vscode +.claude/ +default.profraw diff --git a/Libraries/MLXLLM/Models/Gemma3Text.swift b/Libraries/MLXLLM/Models/Gemma3Text.swift index c5d4a060b..140d4f478 100644 --- a/Libraries/MLXLLM/Models/Gemma3Text.swift +++ b/Libraries/MLXLLM/Models/Gemma3Text.swift @@ -416,19 +416,9 @@ public class Gemma3TextModel: Module, LLMModel { return caches } - /// Handles prompt processing for sequences public func prepare( _ input: LMInput, cache: [KVCache], windowSize: Int? = nil ) throws -> PrepareResult { - let promptTokens = input.text.tokens - let promptCount = promptTokens.dim(0) - - guard promptCount > 0 else { - print("Warning: Preparing with empty prompt tokens.") - let emptyToken = MLXArray(Int32(0))[0 ..< 0] - return .tokens(.init(tokens: emptyToken)) - } - return .tokens(input.text) } } diff --git a/Libraries/MLXLLM/Models/Gemma3nText.swift b/Libraries/MLXLLM/Models/Gemma3nText.swift index 1aaa1d75a..06d51a3df 100644 --- a/Libraries/MLXLLM/Models/Gemma3nText.swift +++ b/Libraries/MLXLLM/Models/Gemma3nText.swift @@ -1029,15 +1029,6 @@ public class Gemma3nTextModel: Module, LLMModel { public func prepare( _ input: LMInput, cache: [KVCache], windowSize: Int? = nil ) throws -> PrepareResult { - let promptTokens = input.text.tokens - let promptCount = promptTokens.dim(0) - - guard promptCount > 0 else { - print("Warning: Preparing with empty prompt tokens.") - let emptyToken = MLXArray(Int32(0))[0 ..< 0] - return .tokens(.init(tokens: emptyToken)) - } - return .tokens(input.text) } } diff --git a/Libraries/MLXLMCommon/AttentionUtils.swift b/Libraries/MLXLMCommon/AttentionUtils.swift index 5b4e6c76b..8a6f43d9b 100644 --- a/Libraries/MLXLMCommon/AttentionUtils.swift +++ b/Libraries/MLXLMCommon/AttentionUtils.swift @@ -51,7 +51,12 @@ public func attentionWithCacheUpdate( mask: mask ) } - if let quantizedKVCache = cache as? QuantizedKVCacheProtocol { + if let turboCache = cache as? TurboQuantKVCache { + return turboCache.compressedAttention( + queries: queries, keys: keys, values: values, + scale: scale, mask: mask + ) + } else if let quantizedKVCache = cache as? QuantizedKVCacheProtocol { let (quantizedKeys, quantizedValues) = quantizedKVCache.updateQuantized( keys: keys, values: values) return quantizedScaledDotProductAttention( diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift index c59b03e71..3b113e176 100644 --- a/Libraries/MLXLMCommon/Evaluate.swift +++ b/Libraries/MLXLMCommon/Evaluate.swift @@ -72,6 +72,10 @@ public struct GenerateParameters: Sendable { /// Step to begin using a quantized KV cache when kvBits is non-nil (default: 0) public var quantizedKVStart: Int + /// KV cache compression scheme. Overrides kvBits when set. + /// Built-in: "affine4", "affine8", "turbo2", "turbo3", "turbo4", "turbo4v2", etc. + public var kvScheme: String? + /// Sampling temperature public var temperature: Float @@ -108,6 +112,7 @@ public struct GenerateParameters: Sendable { kvBits: Int? = nil, kvGroupSize: Int = 64, quantizedKVStart: Int = 0, + kvScheme: String? = nil, temperature: Float = 0.6, topP: Float = 1.0, topK: Int = 0, @@ -125,6 +130,7 @@ public struct GenerateParameters: Sendable { self.kvBits = kvBits self.kvGroupSize = kvGroupSize self.quantizedKVStart = quantizedKVStart + self.kvScheme = kvScheme self.temperature = temperature self.topP = topP self.topK = topK @@ -536,6 +542,7 @@ public struct TokenIterator: TokenIteratorProtocol { let kvBits: Int? let kvGroupSize: Int let quantizedKVStart: Int + let kvScheme: String? // Internal metrics var promptPrefillTime: TimeInterval = 0.0 @@ -564,6 +571,7 @@ public struct TokenIterator: TokenIteratorProtocol { self.kvBits = parameters.kvBits self.kvGroupSize = parameters.kvGroupSize self.quantizedKVStart = parameters.quantizedKVStart + self.kvScheme = parameters.kvScheme self.promptPrefillTime = try measure { try prepare(input: .init(text: y), windowSize: parameters.prefillStepSize) @@ -597,6 +605,7 @@ public struct TokenIterator: TokenIteratorProtocol { self.kvBits = parameters.kvBits self.kvGroupSize = parameters.kvGroupSize self.quantizedKVStart = parameters.quantizedKVStart + self.kvScheme = parameters.kvScheme self.promptPrefillTime = try measure { try prepare(input: input, windowSize: parameters.prefillStepSize) @@ -630,6 +639,7 @@ public struct TokenIterator: TokenIteratorProtocol { self.kvBits = nil self.kvGroupSize = 64 self.quantizedKVStart = 0 + self.kvScheme = nil self.promptPrefillTime = try measure { try prepare(input: input, windowSize: prefillStepSize) @@ -680,7 +690,8 @@ public struct TokenIterator: TokenIteratorProtocol { cache: &cache, kvBits: kvBits, kvGroupSize: kvGroupSize, - quantizedKVStart: quantizedKVStart + quantizedKVStart: quantizedKVStart, + kvScheme: kvScheme ) return convertToToken(logits: result.logits) diff --git a/Libraries/MLXLMCommon/KVCache.swift b/Libraries/MLXLMCommon/KVCache.swift index 7561d3017..285f77ac9 100644 --- a/Libraries/MLXLMCommon/KVCache.swift +++ b/Libraries/MLXLMCommon/KVCache.swift @@ -1776,14 +1776,73 @@ public func quantizedScaledDotProductAttention( /// - kvBits: Number of bits for quantization (nil = no quantization) /// - kvGroupSize: Group size for quantization /// - quantizedKVStart: Token count threshold to begin quantizing +/// Resolve a kvScheme string to cache configuration. +/// Returns nil for unrecognized schemes. +public func resolveKVScheme(_ scheme: String?) -> (type: String, keyBits: Int, valueBits: Int)? { + switch scheme { + case "affine4": return ("affine", 4, 4) + case "affine8": return ("affine", 8, 8) + case "turbo2": return ("turbo", 2, 2) + case "turbo3": return ("turbo", 3, 3) + case "turbo4": return ("turbo", 4, 4) + case "turbo8": return ("turbo", 8, 8) + case "turbo4v2": return ("turbo", 4, 2) // asymmetric: 4-bit K, 2-bit V + case "turbo4v3": return ("turbo", 4, 3) + case "turbo0v2": return ("turbo", 0, 2) // raw K (FP16), 2-bit V + case "turbo0v4": return ("turbo", 0, 4) // raw K (FP16), 4-bit V + default: return nil + } +} + public func maybeQuantizeKVCache( cache: inout [KVCache], kvBits: Int?, kvGroupSize: Int = 64, - quantizedKVStart: Int = 0 + quantizedKVStart: Int = 0, + kvScheme: String? = nil ) { - guard let kvBits = kvBits, - !cache.isEmpty, + // TurboQuant schemes: compress KVCacheSimple layers, skip MambaCache/CacheList + if let scheme = kvScheme, let resolved = resolveKVScheme(scheme), resolved.type == "turbo" { + guard !cache.isEmpty else { return } + // Check if any KVCacheSimple layer is ready for compression + let hasCompressible = cache.contains { $0 is KVCacheSimple && $0.offset > quantizedKVStart } + guard hasCompressible else { return } + + for i in 0 ..< cache.count { + if cache[i] is KVCacheSimple, cache[i].offset > quantizedKVStart { + let turbo = TurboQuantKVCache( + bits: resolved.keyBits, + keyBits: resolved.keyBits, + valueBits: resolved.valueBits + ) + // Transfer existing KV data, trimmed to actual offset + let actualOffset = cache[i].offset + let state = cache[i].innerState() + if state.count >= 2, actualOffset > 0 { + let keys = state[0][.ellipsis, .. quantizedKVStart else { @@ -1791,12 +1850,8 @@ public func maybeQuantizeKVCache( } for i in 0 ..< cache.count { - // Handle cache types that support quantization if let simpleCache = cache[i] as? KVCacheSimple { - cache[i] = simpleCache.toQuantized(groupSize: kvGroupSize, bits: kvBits) + cache[i] = simpleCache.toQuantized(groupSize: effectiveGroupSize, bits: effectiveBits) } - // TODO: RotatingKVCache.toQuantized() is not implemented yet, like in Python. - // When implemented, add: else if let rotatingCache = cache[i] as? RotatingKVCache { ... } - // MambaCache and CacheList don't use traditional KV quantization } } diff --git a/Libraries/MLXLMCommon/TurboQuantKVCache.swift b/Libraries/MLXLMCommon/TurboQuantKVCache.swift new file mode 100644 index 000000000..52d274389 --- /dev/null +++ b/Libraries/MLXLMCommon/TurboQuantKVCache.swift @@ -0,0 +1,1185 @@ +// Copyright © 2026 Apple Inc. + +import Foundation +import MLX +import MLXNN + +extension DType { + fileprivate var bytesPerScalar: Int { + switch self { + case .bfloat16, .float16: return 2 + case .float32: return 4 + case .int32, .uint32: return 4 + case .int16, .uint16: return 2 + case .int8, .uint8: return 1 + default: return 4 + } + } +} + +/// Lloyd-Max codebook centroids for Beta-distributed coordinates. +public enum TurboQuantCodebook { + + /// Pre-computed centroids for common (dim, bits) pairs. + nonisolated(unsafe) private static var precomputed: [Int: [Int: [Float]]] = [ + 64: [ + 2: [-0.18745463, -0.05649366, 0.05649367, 0.18745449], + 3: [ + -0.26375133, -0.16599470, -0.09368263, -0.03040462, 0.03040464, 0.09368261, + 0.16599482, 0.26375186, + ], + 4: [ + -0.32913971, -0.25096416, -0.19681059, -0.15295772, -0.11478586, -0.08000945, + -0.04726735, -0.01563822, 0.01563822, 0.04723797, 0.07994876, 0.11472529, + 0.15289739, 0.19675052, 0.25090477, 0.32908401, + ], + ], + 128: [ + 2: [-0.13302007, -0.03998107, 0.03998102, 0.13302033], + 3: [ + -0.18828832, -0.11801215, -0.06648001, -0.02156330, 0.02156329, 0.06648005, + 0.11801218, 0.18828897, + ], + 4: [ + -0.23639172, -0.17934021, -0.14023653, -0.10881814, -0.08157559, -0.05678632, + -0.03350975, -0.01108178, 0.01108178, 0.03350975, 0.05678631, 0.08157560, + 0.10881804, 0.14023650, 0.17934017, 0.23639278, + ], + ], + 256: [ + 2: [-0.09420358, -0.02827190, 0.02827190, 0.09420330], + 3: [ + -0.13371243, -0.08361249, -0.04704370, -0.01524900, 0.01524901, 0.04704368, + 0.08361248, 0.13371260, + ], + 4: [ + -0.16852295, -0.12754069, -0.09961203, -0.07719406, -0.05781249, -0.04021866, + -0.02370371, -0.00783269, 0.00783269, 0.02370371, 0.04021868, 0.05781246, + 0.07719407, 0.09961203, 0.12754090, 0.16852276, + ], + ], + ] + + private static let centroidLock = NSLock() + + /// Non-power-of-2 dims lazily populated on first access (e.g. 80 for Qwen3-4B). + private static let lazyDims: [Int] = [80, 96] + private static let lazyBits: [Int] = [2, 3, 4, 8] + + /// Ensure centroids for a given dim are populated. + private static func ensureCentroidsPopulated(dim: Int) { + centroidLock.lock() + let exists = precomputed[dim] != nil + centroidLock.unlock() + guard !exists else { return } + + // Generate all bit-widths for this dim + var dimTable: [Int: [Float]] = [:] + for bits in lazyBits { + dimTable[bits] = generateCentroids(dim: dim, bits: bits) + } + + centroidLock.lock() + // Double-check after lock (another thread may have populated) + if precomputed[dim] == nil { + precomputed[dim] = dimTable + } + centroidLock.unlock() + } + + /// Codebook centroids for the given (dim, bits) pair. + public static func codebook(dim: Int, bits: Int) -> MLXArray { + if let dimTable = precomputed[dim], let centroids = dimTable[bits] { + return MLXArray(centroids) + } + // Lazy populate for known model dims + if lazyDims.contains(dim) { + ensureCentroidsPopulated(dim: dim) + if let dimTable = precomputed[dim], let centroids = dimTable[bits] { + return MLXArray(centroids) + } + } + let centroids = generateCentroids(dim: dim, bits: bits) + return MLXArray(centroids) + } + + /// Midpoint boundaries between adjacent centroids. + public static func boundaries(dim: Int, bits: Int) -> MLXArray { + let centroids: [Float] + if let dimTable = precomputed[dim], let cached = dimTable[bits] { + centroids = cached + } else if lazyDims.contains(dim) { + ensureCentroidsPopulated(dim: dim) + if let dimTable = precomputed[dim], let cached = dimTable[bits] { + centroids = cached + } else { + centroids = generateCentroids(dim: dim, bits: bits) + } + } else { + centroids = generateCentroids(dim: dim, bits: bits) + } + var bounds = [Float]() + for i in 0 ..< centroids.count - 1 { + bounds.append((centroids[i] + centroids[i + 1]) / 2.0) + } + return MLXArray(bounds) + } + + /// Generate codebook centroids via weighted k-means on Beta distribution. + static func generateCentroids(dim: Int, bits: Int) -> [Float] { + let levels = 1 << bits + let gridSize = 32768 + let sigma = 1.0 / sqrt(Float(dim)) + + var grid = [Float](repeating: 0, count: gridSize) + var weights = [Float](repeating: 0, count: gridSize) + for i in 0 ..< gridSize { + let x = -1.0 + 2.0 * Float(i) / Float(gridSize - 1) + grid[i] = x + let exponent = Float(dim - 3) / 2.0 + let w = pow(max(1.0 - x * x, 1e-30), exponent) + weights[i] = w + } + + let totalW = weights.reduce(0, +) + var centroids = [Float](repeating: 0, count: levels) + var cumW: Float = 0 + var ci = 0 + for i in 0 ..< gridSize { + cumW += weights[i] + let target = (Float(ci) + 0.5) / Float(levels) * totalW + if cumW >= target && ci < levels { + centroids[ci] = grid[i] + ci += 1 + } + } + // Fill remaining + while ci < levels { + centroids[ci] = centroids[ci - 1] + sigma + ci += 1 + } + + for _ in 0 ..< 100 { + var sums = [Float](repeating: 0, count: levels) + var counts = [Float](repeating: 0, count: levels) + for i in 0 ..< gridSize { + var bestJ = 0 + var bestDist = Float.infinity + for j in 0 ..< levels { + let d = abs(grid[i] - centroids[j]) + if d < bestDist { + bestDist = d + bestJ = j + } + } + sums[bestJ] += grid[i] * weights[i] + counts[bestJ] += weights[i] + } + for j in 0 ..< levels { + if counts[j] > 0 { centroids[j] = sums[j] / counts[j] } + } + } + + return centroids.sorted() + } +} + +/// Random orthogonal rotation matrix generation. +public enum TurboQuantRotation { + + /// Generate a deterministic random orthogonal rotation matrix via QR decomposition. + public static func rotationMatrix(dim: Int, seed: UInt64) -> MLXArray { + let key = MLXRandom.key(seed) + let gaussian = MLXRandom.normal([dim, dim], key: key) + + let (q, r) = MLXLinalg.qr(gaussian, stream: .cpu) + let diagR = r.diagonal(stream: .cpu) + let signs = sign(diagR, stream: .cpu) + let result = q * expandedDimensions(signs, axis: 0) + return result + } + + /// Generate a Hadamard matrix of size dim x dim. Requires dim to be a power of 2. + public static func hadamardMatrix(dim: Int) -> MLXArray { + precondition(dim > 0 && (dim & (dim - 1)) == 0, "dim must be power of 2") + var h: [[Float]] = [[1.0]] + var size = 1 + while size < dim { + var newH = [[Float]](repeating: [Float](repeating: 0, count: size * 2), count: size * 2) + for i in 0 ..< size { + for j in 0 ..< size { + newH[i][j] = h[i][j] + newH[i][j + size] = h[i][j] + newH[i + size][j] = h[i][j] + newH[i + size][j + size] = -h[i][j] + } + } + h = newH + size *= 2 + } + let flat = h.flatMap { $0 } + let result = MLXArray(flat, [dim, dim]) + return result + } + + /// Generate random ±1 sign vector for WHT rotation. + public static func whtSigns(dim: Int, seed: UInt64) -> MLXArray { + let key = MLXRandom.key(seed) + let uniform = MLXRandom.uniform(low: 0, high: 1, [dim], key: key) + let signs = MLX.where(uniform .> Float(0.5), Float(1.0), Float(-1.0)) + return signs + } + + /// Apply WHT butterfly on the last dimension of x. + private static func whtButterfly(_ x: MLXArray) -> MLXArray { + let dim = x.dim(-1) + let logDim = Int(log2(Double(dim))) + let origShape = x.shape + let N = origShape.dropLast().reduce(1, *) + var y = x.reshaped([N, dim]) + + for s in 0 ..< logDim { + let halfBlock = 1 << s + let blockSize = halfBlock << 1 + let numBlocks = dim / blockSize + y = y.reshaped([N, numBlocks, blockSize]) + let a = y[0..., 0..., .. MLXArray { + let dim = x.dim(-1) + precondition(dim > 0 && (dim & (dim - 1)) == 0, "dim must be power of 2") + let signed = x * signs + let transformed = whtButterfly(signed) + let invSqrtDim = MLXArray(1.0 / sqrt(Float(dim)), dtype: x.dtype) + return transformed * invSqrtDim + } + + /// Apply SRHT inverse rotation on the last dimension. + public static func fwhtInverse(_ y: MLXArray, signs: MLXArray) -> MLXArray { + let dim = y.dim(-1) + precondition(dim > 0 && (dim & (dim - 1)) == 0, "dim must be power of 2") + let transformed = whtButterfly(y) + let invSqrtDim = MLXArray(1.0 / sqrt(Float(dim)), dtype: y.dtype) + return transformed * invSqrtDim * signs + } +} + +/// Bit packing/unpacking for codebook indices. +public enum TurboQuantPacking { + + /// Number of uint32 words needed to pack `count` values at `bits` each. + public static func packedWidth(count: Int, bits: Int) -> Int { + (count * bits + 31) / 32 + } + + /// Pack b-bit indices into uint32 words. + public static func packLowBit(_ indices: MLXArray, bits: Int) -> MLXArray { + let count = indices.dim(-1) + let batchShape = Array(indices.shape.dropLast()) + let rows = batchShape.reduce(1, *) + let flat = indices.reshaped([rows, count]) + let pw = packedWidth(count: count, bits: bits) + let mask = UInt32((1 << bits) - 1) + + var wordArrays = [MLXArray]() + for w in 0 ..< pw { + var word = MLXArray.zeros([rows], dtype: .uint32) + for d in 0 ..< count { + let bitOffset = d * bits + let wordIdx = bitOffset / 32 + let offset = bitOffset % 32 + let spill = offset + bits - 32 + + if wordIdx == w { + let shifted = + (flat[0..., d].asType(.uint32) & MLXArray(mask)) << MLXArray(UInt32(offset)) + word = word | shifted + } + if spill > 0 && wordIdx + 1 == w { + let shifted = + (flat[0..., d].asType(.uint32) & MLXArray(mask)) + >> MLXArray(UInt32(bits - spill)) + word = word | shifted + } + } + wordArrays.append(expandedDimensions(word, axis: -1)) + } + let packed = concatenated(wordArrays, axis: -1) + return packed.reshaped(batchShape + [pw]) + } + + /// Unpack b-bit indices from uint32 words. + public static func unpackLowBit(_ packed: MLXArray, bits: Int, count: Int) -> MLXArray { + let shape = packed.shape + let batchShape = Array(shape.dropLast()) + let rows = batchShape.reduce(1, *) + let flat = packed.reshaped([rows, -1]) + let mask = UInt32((1 << bits) - 1) + + var dimArrays = [MLXArray]() + for d in 0 ..< count { + let bitOffset = d * bits + let wordIdx = bitOffset / 32 + let offset = bitOffset % 32 + let spill = offset + bits - 32 + + var value = (flat[0..., wordIdx] >> MLXArray(UInt32(offset))) & MLXArray(mask) + if spill > 0 { + let high = + (flat[0..., wordIdx + 1] << MLXArray(UInt32(bits - spill))) & MLXArray(mask) + value = value | high + } + dimArrays.append(expandedDimensions(value, axis: -1)) + } + let unpacked = concatenated(dimArrays, axis: -1) + return unpacked.reshaped(batchShape + [count]) + } +} + +/// State for MSE-quantized vectors. +public struct MSECodecState { + public var norms: MLXArray + public var packedIndices: MLXArray + public var tokenCount: Int + public let dim: Int + public let bits: Int + + public init(norms: MLXArray, packedIndices: MLXArray, tokenCount: Int, dim: Int, bits: Int) { + self.norms = norms + self.packedIndices = packedIndices + self.tokenCount = tokenCount + self.dim = dim + self.bits = bits + } +} + +/// MSE-optimal codec: rotate, quantize to codebook indices, pack bits. +public class MSECodec { + public let dim: Int + public let bits: Int + public let seed: UInt64 + + public let codebook: MLXArray + public let boundaries: MLXArray + public let useWHT: Bool + public let whtSigns: MLXArray? + public let rotation: MLXArray + public let rotationT: MLXArray + + public init(dim: Int, bits: Int, seed: UInt64 = 42) { + self.dim = dim + self.bits = bits + self.seed = seed + self.codebook = TurboQuantCodebook.codebook(dim: dim, bits: bits) + self.boundaries = TurboQuantCodebook.boundaries(dim: dim, bits: bits) + + let isPowerOf2 = dim > 0 && (dim & (dim - 1)) == 0 + self.useWHT = isPowerOf2 && dim <= 1024 + if useWHT { + let signs = TurboQuantRotation.whtSigns(dim: dim, seed: seed) + self.whtSigns = signs + let hadamard = TurboQuantRotation.hadamardMatrix(dim: dim) + let signsDiag = expandedDimensions(signs, axis: 0) + let whtRot = hadamard * signsDiag / Float(sqrt(Float(dim))) + // bf16 to match model dtype and avoid promoting inputs through matmul + self.rotation = whtRot.asType(.bfloat16) + self.rotationT = self.rotation.transposed() + } else { + self.whtSigns = nil + let rot = TurboQuantRotation.rotationMatrix(dim: dim, seed: seed) + self.rotation = rot.asType(.bfloat16) + self.rotationT = self.rotation.transposed() + } + } + + /// Encode vectors to packed codebook indices with norm extraction. + public func encode(_ vectors: MLXArray) -> MSECodecState { + let norms = sqrt((vectors * vectors).sum(axis: -1)) + let safeNorms = maximum(norms, Float(1e-8)) + let unit = vectors / expandedDimensions(safeNorms, axis: -1) + + let rotated = matmul(unit, rotationT) + + let indices = boundaryQuantize(rotated) + + let storedNorms: MLXArray + if useWHT { + // WHT is orthogonal — norms preserved, no correction needed + storedNorms = norms + } else { + // Norm correction compensates for quantization error in dense rotation + let reconstructed = codebook[indices] + let reconNormSq = (reconstructed * reconstructed).sum(axis: -1) + let reconNorms = sqrt(maximum(reconNormSq, Float(1e-16))) + storedNorms = norms / reconNorms + } + + let packed = TurboQuantPacking.packLowBit(indices, bits: bits) + + return MSECodecState( + norms: storedNorms, + packedIndices: packed, + tokenCount: vectors.dim(2), + dim: dim, + bits: bits + ) + } + + /// Decode from compressed state back to dense vectors. + public func decode(_ state: MSECodecState) -> MLXArray { + let indices = TurboQuantPacking.unpackLowBit(state.packedIndices, bits: bits, count: dim) + + let approx = codebook[indices] + + let unrotated = matmul(approx, rotation) + + return expandedDimensions(state.norms, axis: -1) * unrotated + } + + /// Decode in rotated space without inverse rotation. + public func decodeRotated(_ state: MSECodecState) -> MLXArray { + let indices = TurboQuantPacking.unpackLowBit(state.packedIndices, bits: bits, count: dim) + let approx = codebook[indices] + return expandedDimensions(state.norms, axis: -1) * approx + } + + /// Pre-rotate queries for compressed-domain scoring. + public func prepareQueries(_ queries: MLXArray) -> MLXArray { + return matmul(queries, rotationT) + } + + /// Quantize via boundary comparison. Returns uint32 codebook indices. + func boundaryQuantize(_ rotated: MLXArray) -> MLXArray { + let ndim = rotated.ndim + let expanded = expandedDimensions(rotated, axis: -1) + var bShape = [Int](repeating: 1, count: ndim + 1) + bShape[ndim] = boundaries.count + let b = boundaries.reshaped(bShape) + let greater = (expanded .> b).asType(.uint32) + let indices = greater.sum(axis: -1) + return indices.asType(.uint32) + } +} + +/// KV cache with TurboQuant compression. Stores raw K/V during prefill, +/// compresses on first decode call, then encodes new tokens incrementally. +public class TurboQuantKVCache: BaseKVCache { + + public let bits: Int + public let keyBits: Int + public let valueBits: Int + private let seed: UInt64 + + /// When true, keys stay FP16 and only values are compressed (keyBits == 0). + public let rawKeyMode: Bool + + private var keyMSECodec: MSECodec? + private var valueMSECodec: MSECodec? + + private var rawKeys: MLXArray? + private var rawValues: MLXArray? + private var rawAllocSteps = 0 + + private var keyPackedMSE: MLXArray? + private var keyNorms: MLXArray? + private var valPackedMSE: MLXArray? + private var valNorms: MLXArray? + private var compressedAllocSteps = 0 + + public private(set) var isCompressed = false + + private var pendingRawKeys: [MLXArray] = [] + private var pendingRawValues: [MLXArray] = [] + private var uncompressedCount = 0 + private let recompressInterval: Int + private let step: Int + + public init( + bits: Int = 4, keyBits: Int? = nil, valueBits: Int? = nil, step: Int = 1024, + recompressInterval: Int = 64, seed: UInt64 = 42 + ) { + self.bits = bits + self.keyBits = keyBits ?? bits + self.valueBits = valueBits ?? bits + self.rawKeyMode = (keyBits ?? bits) == 0 + self.seed = seed + self.step = step + self.recompressInterval = recompressInterval + super.init() + } + + override public var isTrimmable: Bool { true } + + private static let codecLock = NSLock() + nonisolated(unsafe) private static var sharedCodecs: [String: MSECodec] = [:] + + private static func getOrCreateCodec(dim: Int, bits: Int, seed: UInt64) -> MSECodec { + let key = "\(dim)_\(bits)_\(seed)" + codecLock.lock() + if let cached = sharedCodecs[key] { + codecLock.unlock() + return cached + } + codecLock.unlock() + let codec = MSECodec(dim: dim, bits: bits, seed: seed) + codecLock.lock() + sharedCodecs[key] = codec + codecLock.unlock() + return codec + } + + /// Initialize codecs if needed, using the shared cache. + private func ensureCodecs(headDim: Int) { + guard valueMSECodec == nil else { return } + if !rawKeyMode { + keyMSECodec = Self.getOrCreateCodec(dim: headDim, bits: keyBits, seed: seed) + } + valueMSECodec = Self.getOrCreateCodec(dim: headDim, bits: valueBits, seed: seed + 1) + } + + /// Dispatch to the appropriate fused encode kernel (WHT or dense). + private func fusedEncodeDispatch( + input: MLXArray, codec: MSECodec, headDim: Int + ) -> (packed: MLXArray, norms: MLXArray) { + if codec.useWHT, let signs = codec.whtSigns { + return TurboQuantKernelOps.fusedEncodeWHT( + input: input, whtSigns: signs, + boundaries: codec.boundaries, codebook: codec.codebook, + bits: codec.bits, dim: headDim + ) + } else { + return TurboQuantKernelOps.fusedEncode( + input: input, rotation: codec.rotation, + boundaries: codec.boundaries, codebook: codec.codebook, + bits: codec.bits, dim: headDim + ) + } + } + + /// Store raw K/V during prefill. Use ``updateAndDequant`` for decode. + override public func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) { + let previous = self.offset + + let reset = + if let currentKeys = self.rawKeys, (previous + keys.dim(2)) > currentKeys.dim(2) { + true + } else { + self.rawKeys == nil + } + if reset { + let B = keys.dim(0) + let kvHeads = keys.dim(1) + let kHeadDim = keys.dim(3) + let vHeadDim = values.dim(3) + + let nSteps = (step + keys.dim(2) - 1) / step + let kShape = [B, kvHeads, nSteps * step, kHeadDim] + let vShape = [B, kvHeads, nSteps * step, vHeadDim] + let newK = MLXArray.zeros(kShape, dtype: keys.dtype) + let newV = MLXArray.zeros(vShape, dtype: values.dtype) + + if var currentKeys = self.rawKeys, var currentValues = self.rawValues { + if previous % step != 0 { + currentKeys = currentKeys[.ellipsis, .. 0 else { return } + let allKeys = rk[.ellipsis, .. rawAllocSteps { + let newAlloc = ((prev + numSteps + step - 1) / step) * step + let newRK = MLXArray.zeros([B, H, newAlloc, headDim], dtype: keys.dtype) + if prev > 0, let rk = rawKeys { + newRK[.ellipsis, .. compressedAllocSteps { + let newAlloc = ((prev + numSteps + step - 1) / step) * step + let newVP = MLXArray.zeros([B, H, newAlloc, vpw], dtype: .uint32) + let newVN = MLXArray.zeros([B, H, newAlloc]) + if prev > 0 { + newVP[.ellipsis, .. compressedAllocSteps { + let newAlloc = ((prev + numSteps + step - 1) / step) * step + let newKP = MLXArray.zeros([B, H, newAlloc, kpw], dtype: .uint32) + let newKN = MLXArray.zeros([B, H, newAlloc]) + let newVP = MLXArray.zeros([B, H, newAlloc, vpw], dtype: .uint32) + let newVN = MLXArray.zeros([B, H, newAlloc]) + if prev > 0 { + newKP[.ellipsis, .. ( + MLXArray, MLXArray + ) { + let headDim = newKeys.dim(-1) + ensureCodecs(headDim: headDim) + + guard let valueMSECodec else { + return (newKeys, newValues) + } + + if !isCompressed { + isCompressed = true + let tokenCount = offset + if tokenCount > 0, let rk = rawKeys, let rv = rawValues { + let rawK = rk[.ellipsis, ..= adaptiveInterval { + flushPendingEncode(headDim: newKeys.dim(-1)) + } + + let dequantNewKeys: MLXArray + if rawKeyMode { + dequantNewKeys = newKeys + } else { + guard let keyMSECodec else { return (newKeys, newValues) } + dequantNewKeys = keyMSECodec.prepareQueries(newKeys) + } + let rotNewValues = valueMSECodec.prepareQueries(newValues) + + let reset = + if let dk = self.dequantKeys, prevOffset + newKeys.dim(2) > dk.dim(2) { + true + } else { + self.dequantKeys == nil + } + if reset { + let B = newKeys.dim(0) + let H = newKeys.dim(1) + let nSteps = (step + newKeys.dim(2) - 1) / step + let kShape = [B, H, nSteps * step, headDim] + let newDK = MLXArray.zeros(kShape, dtype: newKeys.dtype) + let newDV = MLXArray.zeros(kShape, dtype: newKeys.dtype) + + if var currentKeys = self.dequantKeys, var currentValues = self.dequantValues { + if prevOffset % step != 0 { + currentKeys = currentKeys[.ellipsis, .. MLXArray { + if rawKeyMode { return queries } + guard let keyMSECodec else { return queries } + return keyMSECodec.prepareQueries(queries) + } + + /// Inverse-rotate SDPA output back to original space. + public func inverseRotateOutput(_ rotatedOutput: MLXArray) -> MLXArray { + guard let valueMSECodec else { return rotatedOutput } + return matmul(rotatedOutput, valueMSECodec.rotation) + } + + /// Use compressed-domain Metal kernels instead of dequant + SDPA. + public var useCompressedAttention: Bool = false + + /// Compressed-domain attention via Metal kernels. + public func compressedAttention( + queries: MLXArray, + keys newKeys: MLXArray, + values newValues: MLXArray, + scale: Float, + mask: MLXFast.ScaledDotProductAttentionMaskMode = .none + ) -> MLXArray { + let headDim = newKeys.dim(-1) + let B = queries.dim(0) + let nQHeads = queries.dim(1) + let nKVHeads = newKeys.dim(1) + let L = queries.dim(2) + let nRepeats = nQHeads / nKVHeads + + if !isCompressed { + compressRawCache() + } + + pendingRawKeys.append(newKeys) + pendingRawValues.append(newValues) + uncompressedCount += newKeys.dim(2) + offset += newKeys.dim(2) + flushPendingEncode(headDim: newKeys.dim(-1)) + + guard let valueMSECodec else { + return queries + } + + let tokenCount = offset + + let flatValPacked = valPackedMSE![0..., 0..., .. 1 { + let expanded = expandedDimensions(allKeys, axis: 2) + let tiledKeys = MLX.tiled(expanded, repetitions: [1, 1, nRepeats, 1, 1]) + expandedKeys = tiledKeys.reshaped([B, nQHeads, tokenCount, headDim]) + } else { + expandedKeys = allKeys + } + + var scores = matmul(queries, expandedKeys.transposed(0, 1, 3, 2)) * scale + switch mask { + case .array(let maskArray): + if maskArray.dtype == .bool { + scores = MLX.where( + maskArray, scores, MLXArray(Float.leastNormalMagnitude, dtype: scores.dtype) + ) + } else { + scores = scores + maskArray + } + case .causal: + let queryOffset = tokenCount - L + let causalMask = MLXArray.tri(L, m: tokenCount, k: queryOffset, type: Bool.self) + let expandedMask = expandedDimensions( + expandedDimensions(causalMask, axis: 0), axis: 0) + scores = MLX.where( + expandedMask, scores, MLXArray(Float.leastNormalMagnitude, dtype: scores.dtype)) + case .none: break + default: break + } + + let attnWeights = softmax(scores, axis: -1) + // Materialize before Metal kernel — prevents lazy graph overflow in rawKeyMode + eval(attnWeights) + + let flatWeights = attnWeights.reshaped([B * nQHeads * L, tokenCount]) + let rotatedOutput = TurboQuantKernelOps.mseWeightedSum( + weights: flatWeights, packed: flatValPacked, norms: flatValNorms, + codebook: valueMSECodec.codebook, tokenCount: tokenCount, + repeatCount: nRepeats, bits: self.valueBits, dim: headDim + ) + + output = matmul( + rotatedOutput.reshaped([B, nQHeads, L, headDim]), + valueMSECodec.rotation + ) + } else { + guard let keyMSECodec else { return queries } + + let qRot = keyMSECodec.prepareQueries(queries) * scale + let flatQ = qRot.reshaped([B * nQHeads * L, headDim]) + + let flatKeyPacked = keyPackedMSE![0..., 0..., .. 1 { + expandedK = MLX.repeated( + dequantKRot.reshaped([B * nKVHeads, 1, tokenCount, headDim]), + count: nRepeats, axis: 1 + ).reshaped([B * nQHeads * L, tokenCount, headDim]) + } else { + expandedK = dequantKRot + } + + var scores = matmul( + flatQ.reshaped([B * nQHeads * L, 1, headDim]), + expandedK.transposed(0, 2, 1) + ).squeezed(axis: 1).reshaped([B, nQHeads, L, tokenCount]) + + switch mask { + case .array(let maskArray): + if maskArray.dtype == .bool { + scores = MLX.where( + maskArray, scores, + MLXArray(Float.leastNormalMagnitude, dtype: scores.dtype)) + } else { + scores = scores + maskArray + } + case .none: break + default: break + } + + let attnWeights = softmax(scores, axis: -1) + // Materialize before dequant+matmul chain — prevents lazy graph overflow + eval(attnWeights) + + let flatWeights = attnWeights.reshaped([B * nQHeads * L, tokenCount]).contiguous() + + let dequantV = valueMSECodec.decodeRotated(MSECodecState( + norms: flatValNorms.reshaped([B * nKVHeads * tokenCount]), + packedIndices: flatValPacked.reshaped([B * nKVHeads * tokenCount, -1]), + tokenCount: tokenCount, dim: headDim, bits: self.valueBits + )).reshaped([B * nKVHeads, tokenCount, headDim]) + + let expandedV: MLXArray + if nRepeats > 1 { + expandedV = MLX.repeated( + dequantV.reshaped([B * nKVHeads, 1, tokenCount, headDim]), + count: nRepeats, axis: 1 + ).reshaped([B * nQHeads * L, tokenCount, headDim]) + } else { + expandedV = dequantV + } + + let rotatedOutput = matmul( + flatWeights.expandedDimensions(axis: 1), + expandedV + ).squeezed(axis: 1) + + output = matmul( + rotatedOutput.reshaped([B, nQHeads, L, headDim]), + valueMSECodec.rotation + ) + } + } + + return output + } + + /// Approximate memory footprint in bytes (excludes shared codec overhead). + public var memoryBytes: Int { + var total = 0 + if let rk = rawKeys { total += rk.shape.reduce(1, *) * rk.dtype.bytesPerScalar } + if let rv = rawValues { total += rv.shape.reduce(1, *) * rv.dtype.bytesPerScalar } + if let kp = keyPackedMSE { total += kp.shape.reduce(1, *) * kp.dtype.bytesPerScalar } + if let kn = keyNorms { total += kn.shape.reduce(1, *) * kn.dtype.bytesPerScalar } + if let vp = valPackedMSE { total += vp.shape.reduce(1, *) * vp.dtype.bytesPerScalar } + if let vn = valNorms { total += vn.shape.reduce(1, *) * vn.dtype.bytesPerScalar } + if let dk = dequantKeys { total += dk.shape.reduce(1, *) * dk.dtype.bytesPerScalar } + if let dv = dequantValues { total += dv.shape.reduce(1, *) * dv.dtype.bytesPerScalar } + return total + } + + override public var state: [MLXArray] { + get { + if isCompressed { + if rawKeyMode { + guard let rk = rawKeys, + let vpm = valPackedMSE, let vn = valNorms, + offset > 0 + else { return [] } + return [ + rk[0..., 0..., .. 0 + else { return [] } + return [ + kpm[0..., 0..., .. 0 else { return [] } + return [rk[0..., 0..., .. Int { + guard n > 0, offset > 0 else { return 0 } + + if !pendingRawKeys.isEmpty { + flushPendingEncode(headDim: pendingRawKeys[0].dim(-1)) + } + + let trimCount = min(n, offset) + offset -= trimCount + if offset == 0 { + rawKeys = nil + rawValues = nil + rawAllocSteps = 0 + keyPackedMSE = nil + keyNorms = nil + valPackedMSE = nil + valNorms = nil + dequantKeys = nil + dequantValues = nil + compressedAllocSteps = 0 + isCompressed = false + pendingRawKeys.removeAll() + pendingRawValues.removeAll() + uncompressedCount = 0 + } + return trimCount + } +} diff --git a/Libraries/MLXLMCommon/TurboQuantKernels.swift b/Libraries/MLXLMCommon/TurboQuantKernels.swift new file mode 100644 index 000000000..f914af372 --- /dev/null +++ b/Libraries/MLXLMCommon/TurboQuantKernels.swift @@ -0,0 +1,1125 @@ +// Copyright © 2026 Apple Inc. + +import Foundation +import MLX + +enum TurboQuantMetalKernels { + + /// Fused encode: norm, rotate, quantize, pack, and norm-correct in one dispatch. + static let fusedEncodeSource = """ + constexpr uint LEVELS = 1u << Bits; + + uint d = thread_position_in_threadgroup.x; // dimension index (0..Dim-1) + uint row = thread_position_in_grid.y; // vector index (B*H*T) + + // --- Step 1: Load input value --- + float val = input[row * Dim + d]; + + // --- Step 2: Compute L2 norm (SIMD reduction) --- + float sq = val * val; + float norm_sq = simd_sum(sq); + // For Dim > 32, need threadgroup reduction + threadgroup float shared_norm[4]; // up to 4 SIMD groups + uint sg_id = d / 32; + if (d % 32 == 0) { + shared_norm[sg_id] = norm_sq; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float total_norm_sq = 0; + uint num_groups = (Dim + 31) / 32; + for (uint i = 0; i < num_groups; i++) { + total_norm_sq += shared_norm[i]; + } + float norm_val = sqrt(total_norm_sq); + float inv_norm = (norm_val > 1e-8f) ? (1.0f / norm_val) : 0.0f; + + // --- Step 3: Normalize --- + float unit_val = val * inv_norm; + + // --- Step 4: Rotate (y = Π · x_unit) via shared memory matmul --- + // Each thread d computes: y[d] = Σ_j rotation[d * Dim + j] * x_unit[j] + threadgroup float shared_unit[1024]; // max Dim = 1024 + shared_unit[d] = unit_val; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float rotated = 0.0f; + for (uint j = 0; j < Dim; j++) { + rotated += rotation[d * Dim + j] * shared_unit[j]; + } + + // --- Step 5: Quantize via branchless boundary comparison --- + // V2.1 optimization: use arithmetic sum of comparisons instead of branching. + // Metal compiles (rotated > boundaries[b]) to a predicated 0/1 — summing these + // is branchless and avoids SIMD lane divergence. + uint idx = 0; + for (uint b = 0; b < LEVELS - 1; b++) { + idx += (uint)(rotated > boundaries[b]); + } + + // --- Step 6: Pack bits into uint32 word (atomic OR) --- + uint bit_offset = d * Bits; + uint word_idx = bit_offset / 32; + uint shift = bit_offset % 32; + uint masked = idx & ((1u << Bits) - 1u); + + // Pack bits — use threadgroup shared memory to avoid atomic contention + // Each thread writes its index bits to shared, then thread 0 per word writes output + threadgroup uint shared_packed[64]; // max PackedWidth = 64 words + if (d < PackedWidth) shared_packed[d] = 0; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Each dimension contributes its bits via atomic OR on threadgroup memory + uint primary_val = masked << shift; + atomic_fetch_or_explicit((threadgroup atomic_uint*)&shared_packed[word_idx], primary_val, memory_order_relaxed); + + int spill = (int)shift + (int)Bits - 32; + if (spill > 0) { + uint spill_val = masked >> ((uint)Bits - (uint)spill); + atomic_fetch_or_explicit((threadgroup atomic_uint*)&shared_packed[word_idx + 1], spill_val, memory_order_relaxed); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write packed words to output (one thread per word) + if (d < PackedWidth) { + packed_out[row * PackedWidth + d] = shared_packed[d]; + } + + // --- Step 7: Norm correction --- + // Compute reconstruction norm: ||codebook[idx]||₂ for the quantized unit vector. + // Store corrected_norm = original_norm / recon_norm so that + // decode(centroid[idx] * corrected_norm) better approximates the original vector. + float centroid_val = codebook[idx]; + float recon_sq = centroid_val * centroid_val; + float recon_norm_sq = simd_sum(recon_sq); + // Threadgroup reduction for Dim > 32 + if (d % 32 == 0) { + shared_norm[sg_id] = recon_norm_sq; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float total_recon_sq = 0; + for (uint i = 0; i < num_groups; i++) { + total_recon_sq += shared_norm[i]; + } + float recon_norm = sqrt(total_recon_sq); + float corrected_norm = (recon_norm > 1e-8f) ? (norm_val / recon_norm) : norm_val; + + if (d == 0) { + norms_out[row] = corrected_norm; + } + """ + + /// Fused WHT encode: norm, WHT rotation, quantize, and pack (no norm correction). + static let fusedEncodeWHTSource = """ + constexpr uint LEVELS = 1u << Bits; + + uint d = thread_position_in_threadgroup.x; // dimension index (0..Dim-1) + uint row = thread_position_in_grid.y; // vector index (B*H*T) + + // --- Step 1: Load input value --- + float val = input[row * Dim + d]; + + // --- Step 2: Compute L2 norm (SIMD reduction) --- + float sq = val * val; + float norm_sq = simd_sum(sq); + threadgroup float shared_norm[4]; + uint sg_id = d / 32; + if (d % 32 == 0) { + shared_norm[sg_id] = norm_sq; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float total_norm_sq = 0; + uint num_groups = (Dim + 31) / 32; + for (uint i = 0; i < num_groups; i++) { + total_norm_sq += shared_norm[i]; + } + float norm_val = sqrt(total_norm_sq); + float inv_norm = (norm_val > 1e-8f) ? (1.0f / norm_val) : 0.0f; + + // --- Step 3: Normalize + sign flip (fused) --- + // V2.1 optimization: pre-compute inv_norm * sign to eliminate one multiply per element. + // Instead of: unit_val = val * inv_norm; wht_val = sign * unit_val (2 muls) + // We do: wht_val = val * (inv_norm * sign) (1 mul + 1 FMA-friendly product) + float inv_norm_sign = inv_norm * wht_signs[d]; + float wht_val = val * inv_norm_sign; + + // --- Step 4: WHT rotation via cooperative SIMD shuffle --- + // V2.1 optimization: use simd_shuffle_xor for intra-SIMD butterfly stages + // (register-to-register, no shared memory or barriers needed for first 5 stages) + + // Phase 1: Intra-SIMD butterfly via simd_shuffle_xor (stages 0..min(LogDim,5)-1) + // Each stage s XORs lane indices at distance 2^s — effectively free on Apple GPU + // Use metal::min: MLX-injected headers add overloads named `min` (bf16_math.h), so + // unqualified min(LogDim, 5u) is ambiguous vs metal::min on newer toolchains. + uint log_dim_u = static_cast(LogDim); + uint simd_stages = metal::min(log_dim_u, 5u); // 5 stages covers 32 lanes (2^5 = 32) + uint lane_in_simd = d % 32; + for (uint s = 0; s < simd_stages; s++) { + uint step = 1u << s; + float other = simd_shuffle_xor(wht_val, step); + wht_val = (lane_in_simd & step) ? (other - wht_val) : (other + wht_val); + } + + // Phase 2: Cross-SIMD-group butterfly via shared memory (stages 5..LogDim-1) + // Only needed when Dim > 32 — these stages cross SIMD group boundaries + threadgroup float shared_buf[1024]; // max Dim = 1024 + if (log_dim_u > 5u) { + shared_buf[d] = wht_val; + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint s = simd_stages; s < log_dim_u; s++) { + uint half_block = 1u << s; + uint block_size = half_block << 1; + uint block_id = d / block_size; + uint pos_in_block = d % block_size; + + float a, b; + if (pos_in_block < half_block) { + a = shared_buf[block_id * block_size + pos_in_block]; + b = shared_buf[block_id * block_size + pos_in_block + half_block]; + shared_buf[d] = a + b; + } else { + a = shared_buf[block_id * block_size + pos_in_block - half_block]; + b = shared_buf[block_id * block_size + pos_in_block]; + shared_buf[d] = a - b; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + wht_val = shared_buf[d]; + } + + // Normalize: WHT has scale factor sqrt(Dim) + float inv_sqrt_dim = 1.0f / sqrt((float)Dim); + float rotated = wht_val * inv_sqrt_dim; + + // --- Step 5: Quantize via branchless boundary comparison --- + // V2.1 optimization: arithmetic sum avoids SIMD lane divergence + uint idx = 0; + for (uint b = 0; b < LEVELS - 1; b++) { + idx += (uint)(rotated > boundaries[b]); + } + + // --- Step 6: Pack bits into uint32 word (atomic OR) --- + uint bit_offset = d * Bits; + uint word_idx = bit_offset / 32; + uint shift = bit_offset % 32; + uint masked = idx & ((1u << Bits) - 1u); + + threadgroup uint shared_packed[64]; + if (d < PackedWidth) shared_packed[d] = 0; + threadgroup_barrier(mem_flags::mem_threadgroup); + + uint primary_val = masked << shift; + atomic_fetch_or_explicit((threadgroup atomic_uint*)&shared_packed[word_idx], primary_val, memory_order_relaxed); + + int spill = (int)shift + (int)Bits - 32; + if (spill > 0) { + uint spill_val = masked >> ((uint)Bits - (uint)spill); + atomic_fetch_or_explicit((threadgroup atomic_uint*)&shared_packed[word_idx + 1], spill_val, memory_order_relaxed); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (d < PackedWidth) { + packed_out[row * PackedWidth + d] = shared_packed[d]; + } + + // --- Step 7: Store raw norm (WHT is orthogonal — no norm correction needed) --- + // WHT preserves norms: ||WHT(x)||₂ = ||x||₂. Reconstruction norm ≈ original norm, + // so the correction ratio ≈ 1.0. Skipping saves codebook lookup + norm + division. + if (d == 0) { + norms_out[row] = norm_val; + } + """ + + /// Flash attention pass 1 (causal): per-block partial attention with causal masking. + static let turboFlashPass1CausalSource = """ + constexpr uint KEY_MASK = (1u << KeyBits) - 1u; + constexpr uint KEY_LEVELS = 1u << KeyBits; + constexpr uint VAL_MASK = (1u << ValueBits) - 1u; + constexpr uint VAL_LEVELS = 1u << ValueBits; + constexpr uint DIMS_PER_LANE = (Dim + 31) / 32; + + // Runtime params from input buffers + uint token_count = uint(tc_buf[0]); + uint repeat_count = uint(rc_buf[0]); + uint num_blocks = uint(nb_buf[0]); + uint BlockSize = uint(bs_buf[0]); + uint L = uint(L_buf[0]); + uint q_offset = uint(qo_buf[0]); + + uint lane = thread_position_in_grid.x; // SIMD lane (0-31) + uint q_idx = thread_position_in_grid.y; // query index (B*nQHeads*L) + uint block_idx = thread_position_in_grid.z; // token block index + + uint q_within_L = q_idx % L; + uint q_head_idx = q_idx / L; + uint kv_idx = q_head_idx / repeat_count; + + uint q_abs = q_offset + q_within_L; + + uint t_start = block_idx * BlockSize; + uint t_end = t_start + BlockSize; + if (t_end > token_count) t_end = token_count; + + // Early exit: entire block is future-masked + if (t_start > q_abs) { + uint partial_base = (q_idx * num_blocks + block_idx) * Dim; + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + if (d < Dim) o_partials[partial_base + d] = 0.0f; + } + if (lane == 0) { + uint ml_idx = q_idx * num_blocks + block_idx; + m_partials[ml_idx] = -INFINITY; + l_partials[ml_idx] = 0.0f; + } + return; + } + + // Clamp t_end to causal boundary + if (t_end > q_abs + 1) t_end = q_abs + 1; + + // Load key codebook into registers + float key_cb[KEY_LEVELS]; + for (uint i = 0; i < KEY_LEVELS; i++) { + key_cb[i] = key_codebook[i]; + } + + // Load value codebook into registers + float val_cb[VAL_LEVELS]; + for (uint i = 0; i < VAL_LEVELS; i++) { + val_cb[i] = val_codebook[i]; + } + + // Load query values for this lane's dimensions + float q_vals[DIMS_PER_LANE]; + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + q_vals[i] = (d < Dim) ? q_rot[q_idx * Dim + d] : 0.0f; + } + + // Online softmax state for this block + float m = -INFINITY; + float l = 0.0f; + float o[DIMS_PER_LANE]; + for (uint i = 0; i < DIMS_PER_LANE; i++) o[i] = 0.0f; + + // Process tokens in this block (up to causal boundary) + for (uint t = t_start; t < t_end; t++) { + // --- Score: Q×K dot product --- + const device uint32_t* k_packed_ptr = key_packed + kv_idx * token_count * KeyPackedWidth + t * KeyPackedWidth; + float k_norm = key_norms[kv_idx * token_count + t]; + + float dot_partial = 0.0f; + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + if (d >= Dim) break; + + uint k_bit_offset = d * KeyBits; + uint k_word_idx = k_bit_offset / 32; + uint k_shift = k_bit_offset % 32; + uint k_value = (k_packed_ptr[k_word_idx] >> k_shift); + int k_spill = (int)k_shift + (int)KeyBits - 32; + if (k_spill > 0) { + k_value |= (k_packed_ptr[k_word_idx + 1] << ((uint)KeyBits - (uint)k_spill)); + } + k_value &= KEY_MASK; + + dot_partial += q_vals[i] * key_cb[k_value]; + } + + float score = simd_sum(dot_partial) * k_norm; + + // --- Online softmax update + V accumulation --- + float new_m = max(m, score); + float exp_diff = exp(m - new_m); + float exp_score = exp(score - new_m); + + const device uint32_t* v_packed_ptr = val_packed + kv_idx * token_count * ValuePackedWidth + t * ValuePackedWidth; + float v_norm = val_norms[kv_idx * token_count + t]; + + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + if (d >= Dim) break; + + uint v_bit_offset = d * ValueBits; + uint v_word_idx = v_bit_offset / 32; + uint v_shift = v_bit_offset % 32; + uint v_value = (v_packed_ptr[v_word_idx] >> v_shift); + int v_spill = (int)v_shift + (int)ValueBits - 32; + if (v_spill > 0) { + v_value |= (v_packed_ptr[v_word_idx + 1] << ((uint)ValueBits - (uint)v_spill)); + } + v_value &= VAL_MASK; + + o[i] = o[i] * exp_diff + exp_score * (val_cb[v_value] * v_norm); + } + + l = l * exp_diff + exp_score; + m = new_m; + } + + // Write partial results: o[D], m, l + uint partial_base = (q_idx * num_blocks + block_idx) * Dim; + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + if (d < Dim) { + o_partials[partial_base + d] = o[i]; + } + } + if (lane == 0) { + uint ml_idx = q_idx * num_blocks + block_idx; + m_partials[ml_idx] = m; + l_partials[ml_idx] = l; + } + """ + + /// Flash attention pass 1 NR0 (causal): multi-row amortized KV dequant with per-row masking. + static let turboFlashPass1NR0CausalSource = """ + constexpr uint KEY_MASK = (1u << KeyBits) - 1u; + constexpr uint KEY_LEVELS = 1u << KeyBits; + constexpr uint VAL_MASK = (1u << ValueBits) - 1u; + constexpr uint VAL_LEVELS = 1u << ValueBits; + constexpr uint DIMS_PER_LANE = (Dim + 31) / 32; + + // Runtime params from input buffers (avoids per-token pipeline recompilation) + uint token_count = uint(tc_buf[0]); + uint repeat_count = uint(rc_buf[0]); + uint num_blocks = uint(nb_buf[0]); + uint BlockSize = uint(bs_buf[0]); + uint L = uint(L_buf[0]); + uint q_offset = uint(qo_buf[0]); + + uint lane = thread_position_in_grid.x; + uint query_group = thread_position_in_grid.y; + uint block_idx = thread_position_in_grid.z; + + // Token range for this block + uint t_start = block_idx * BlockSize; + uint t_end = t_start + BlockSize; + if (t_end > token_count) t_end = token_count; + + // Compute per-row causal boundaries and find the maximum (most permissive) + // for the shared token loop. Per-row masking happens inside the score loop. + uint q_abs[NR0]; + uint max_q_abs = 0; + for (uint r = 0; r < NR0; r++) { + uint q_idx = query_group * NR0 + r; + uint q_within_L = q_idx % L; + q_abs[r] = q_offset + q_within_L; + if (q_abs[r] > max_q_abs) max_q_abs = q_abs[r]; + } + + // Early exit: entire block is future-masked for ALL NR0 queries + if (t_start > max_q_abs) { + for (uint r = 0; r < NR0; r++) { + uint q_idx = query_group * NR0 + r; + uint partial_base = (q_idx * num_blocks + block_idx) * Dim; + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + if (d < Dim) o_partials[partial_base + d] = 0.0f; + } + if (lane == 0) { + uint ml_idx = q_idx * num_blocks + block_idx; + m_partials[ml_idx] = -INFINITY; + l_partials[ml_idx] = 0.0f; + } + } + return; + } + + // Clamp t_end to the most permissive causal boundary + if (t_end > max_q_abs + 1) t_end = max_q_abs + 1; + + // Load codebooks (shared across all NR0 queries) + float key_cb[KEY_LEVELS]; + for (uint i = 0; i < KEY_LEVELS; i++) key_cb[i] = key_codebook[i]; + float val_cb[VAL_LEVELS]; + for (uint i = 0; i < VAL_LEVELS; i++) val_cb[i] = val_codebook[i]; + + // Load query values for all NR0 rows + float q_vals[NR0 * DIMS_PER_LANE]; + for (uint r = 0; r < NR0; r++) { + uint q_idx = query_group * NR0 + r; + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + q_vals[r * DIMS_PER_LANE + i] = (d < Dim) ? q_rot[q_idx * Dim + d] : 0.0f; + } + } + + // KV head mapping (use first query's head — same assumption as non-causal NR0) + uint q_head_idx_0 = (query_group * NR0) / L; + uint kv_idx = q_head_idx_0 / repeat_count; + + // Online softmax state — NR0 independent streams + float m_state[NR0]; + float l_state[NR0]; + float o_state[NR0 * DIMS_PER_LANE]; + for (uint r = 0; r < NR0; r++) { + m_state[r] = -INFINITY; + l_state[r] = 0.0f; + for (uint i = 0; i < DIMS_PER_LANE; i++) o_state[r * DIMS_PER_LANE + i] = 0.0f; + } + + // Process tokens — KV dequant once, score per-row with causal mask + for (uint t = t_start; t < t_end; t++) { + // Dequant K once + float k_decoded[DIMS_PER_LANE]; + const device uint32_t* k_packed_ptr = key_packed + kv_idx * token_count * KeyPackedWidth + t * KeyPackedWidth; + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + if (d >= Dim) { k_decoded[i] = 0.0f; continue; } + uint k_bit_offset = d * KeyBits; + uint k_word_idx = k_bit_offset / 32; + uint k_shift = k_bit_offset % 32; + uint k_value = (k_packed_ptr[k_word_idx] >> k_shift); + int k_spill = (int)k_shift + (int)KeyBits - 32; + if (k_spill > 0) { + k_value |= (k_packed_ptr[k_word_idx + 1] << ((uint)KeyBits - (uint)k_spill)); + } + k_value &= KEY_MASK; + k_decoded[i] = key_cb[k_value]; + } + float k_norm = key_norms[kv_idx * token_count + t]; + + // Dequant V once + float v_decoded[DIMS_PER_LANE]; + const device uint32_t* v_packed_ptr = val_packed + kv_idx * token_count * ValuePackedWidth + t * ValuePackedWidth; + float v_norm = val_norms[kv_idx * token_count + t]; + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + if (d >= Dim) { v_decoded[i] = 0.0f; continue; } + uint v_bit_offset = d * ValueBits; + uint v_word_idx = v_bit_offset / 32; + uint v_shift = v_bit_offset % 32; + uint v_value = (v_packed_ptr[v_word_idx] >> v_shift); + int v_spill = (int)v_shift + (int)ValueBits - 32; + if (v_spill > 0) { + v_value |= (v_packed_ptr[v_word_idx + 1] << ((uint)ValueBits - (uint)v_spill)); + } + v_value &= VAL_MASK; + v_decoded[i] = val_cb[v_value] * v_norm; + } + + // Score + softmax + V for each query row (with per-row causal mask) + for (uint r = 0; r < NR0; r++) { + // Per-row causal: skip if this token is future for this specific query + if (t > q_abs[r]) continue; + + float dot_partial = 0.0f; + for (uint i = 0; i < DIMS_PER_LANE; i++) { + dot_partial += q_vals[r * DIMS_PER_LANE + i] * k_decoded[i]; + } + float score = simd_sum(dot_partial) * k_norm; + + float new_m = max(m_state[r], score); + float exp_diff = exp(m_state[r] - new_m); + float exp_score = exp(score - new_m); + + for (uint i = 0; i < DIMS_PER_LANE; i++) { + o_state[r * DIMS_PER_LANE + i] = o_state[r * DIMS_PER_LANE + i] * exp_diff + exp_score * v_decoded[i]; + } + l_state[r] = l_state[r] * exp_diff + exp_score; + m_state[r] = new_m; + } + } + + // Write partial results for all NR0 queries + for (uint r = 0; r < NR0; r++) { + uint q_idx = query_group * NR0 + r; + uint partial_base = (q_idx * num_blocks + block_idx) * Dim; + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + if (d < Dim) o_partials[partial_base + d] = o_state[r * DIMS_PER_LANE + i]; + } + if (lane == 0) { + uint ml_idx = q_idx * num_blocks + block_idx; + m_partials[ml_idx] = m_state[r]; + l_partials[ml_idx] = l_state[r]; + } + } + """ + + /// Flash attention pass 2: cross-block reduction of partial softmax states. + static let turboFlashPass2Source = """ + constexpr uint DIMS_PER_LANE = (Dim + 31) / 32; + + // Runtime params from input buffers (avoids per-token pipeline recompilation) + uint num_blocks = uint(nb_buf[0]); + + uint lane = thread_position_in_grid.x; + uint q_idx = thread_position_in_grid.y; + + float m = -INFINITY; + float l = 0.0f; + float o[DIMS_PER_LANE]; + for (uint i = 0; i < DIMS_PER_LANE; i++) o[i] = 0.0f; + + for (uint b = 0; b < num_blocks; b++) { + uint ml_idx = q_idx * num_blocks + b; + + // All lanes read the same m/l (broadcast read from device memory) + float block_m = m_partials[ml_idx]; + float block_l = l_partials[ml_idx]; + + // Skip empty blocks + if (block_l == 0.0f) continue; + + float new_m = max(m, block_m); + float exp_old = exp(m - new_m); + float exp_block = exp(block_m - new_m); + + uint partial_base = (q_idx * num_blocks + b) * Dim; + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + if (d < Dim) { + o[i] = o[i] * exp_old + o_partials[partial_base + d] * exp_block; + } + } + + l = l * exp_old + block_l * exp_block; + m = new_m; + } + + // Write normalized output + float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f; + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + if (d < Dim) { + output[q_idx * Dim + d] = o[i] * inv_l; + } + } + """ + + /// Flash attention pass 2 with fused inverse value rotation. + static let turboFlashPass2FusedRotSource = """ + constexpr uint DIMS_PER_LANE = (Dim + 31) / 32; + + // Runtime params from input buffers (avoids per-token pipeline recompilation) + uint num_blocks = uint(nb_buf[0]); + + uint lane = thread_position_in_grid.x; + uint q_idx = thread_position_in_grid.y; + + float m = -INFINITY; + float l = 0.0f; + float o[DIMS_PER_LANE]; + for (uint i = 0; i < DIMS_PER_LANE; i++) o[i] = 0.0f; + + for (uint b = 0; b < num_blocks; b++) { + uint ml_idx = q_idx * num_blocks + b; + + float block_m = m_partials[ml_idx]; + float block_l = l_partials[ml_idx]; + + if (block_l == 0.0f) continue; + + float new_m = max(m, block_m); + float exp_old = exp(m - new_m); + float exp_block = exp(block_m - new_m); + + uint partial_base = (q_idx * num_blocks + b) * Dim; + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + if (d < Dim) { + o[i] = o[i] * exp_old + o_partials[partial_base + d] * exp_block; + } + } + + l = l * exp_old + block_l * exp_block; + m = new_m; + } + + // Normalize + float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f; + + // Gather normalized output into threadgroup shared memory for rotation + threadgroup float shared_out[Dim]; + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + if (d < Dim) { + shared_out[d] = o[i] * inv_l; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Apply inverse value rotation: output[d] = Σ_j shared_out[j] * Π_val[j][d] + // matmul(x, Π_val) reads column d of Π_val for output dimension d. + // Π_val is stored row-major [Dim, Dim], so column d = val_rotation[j * Dim + d] + for (uint i = 0; i < DIMS_PER_LANE; i++) { + uint d = lane + i * 32; + if (d < Dim) { + float acc = 0.0f; + for (uint j = 0; j < Dim; j++) { + acc += shared_out[j] * val_rotation[j * Dim + d]; + } + output[q_idx * Dim + d] = acc; + } + } + """ + + /// Sparse V skip threshold. Override via `TURBO_SPARSE_V_THRESHOLD` env var. + static let sparseVThreshold: Float = { + if let envValue = ProcessInfo.processInfo.environment["TURBO_SPARSE_V_THRESHOLD"], + let parsed = Float(envValue) + { + return parsed + } + return 1e-6 + }() + + /// Value aggregation: weighted sum of codebook-quantized values in rotated space. + static var valueKernelSource: String { + let threshold = String(format: "%e", sparseVThreshold) + return """ + constexpr uint MASK = (1u << Bits) - 1u; + constexpr uint LEVELS = 1u << Bits; + + // Runtime params from input buffers (avoids per-token pipeline recompilation) + uint token_count = uint(tc_buf[0]); + uint repeat_count = uint(rc_buf[0]); + + uint lane = thread_position_in_grid.x; + uint head_idx = thread_position_in_grid.y; + uint dim_block = thread_position_in_grid.z; + + uint d = dim_block * 32 + lane; + if (d >= Dim) return; + + uint kv_head = head_idx / repeat_count; + + // Load codebook + float cb[LEVELS]; + for (uint i = 0; i < LEVELS; i++) { + cb[i] = codebook[i]; + } + + float acc = 0.0f; + for (uint t = 0; t < token_count; t++) { + float w = weights[head_idx * token_count + t]; + if (w < \(threshold)f) continue; // Sparse V: skip negligible attention weights + + float norm_val = norms[kv_head * token_count + t]; + const device uint32_t* packed_ptr = packed + kv_head * token_count * PackedWidth + t * PackedWidth; + + uint bit_offset = d * Bits; + uint word_idx = bit_offset / 32; + uint shift = bit_offset % 32; + uint value = (packed_ptr[word_idx] >> shift); + + int spill = (int)shift + (int)Bits - 32; + if (spill > 0) { + value |= (packed_ptr[word_idx + 1] << ((uint)Bits - (uint)spill)); + } + value &= MASK; + + acc += w * norm_val * cb[value]; + } + + output[head_idx * Dim + d] = acc; + """ + } +} + +public enum TurboQuantKernelOps { + nonisolated(unsafe) private static var valueKernels: [String: MLXFast.MLXFastKernel] = [:] + private static let lock = NSLock() + + /// Fused encode with dense rotation. + nonisolated(unsafe) private static var encodeKernelCache: [String: MLXFast.MLXFastKernel] = [:] + private static let encodeLock = NSLock() + + private static func getEncodeKernel(bits: Int, dim: Int) -> MLXFast.MLXFastKernel { + let key = "encode_\(bits)_\(dim)" + encodeLock.lock() + if let cached = encodeKernelCache[key] { + encodeLock.unlock() + return cached + } + let kernel = MLXFast.metalKernel( + name: "turbo_fused_encode_\(bits)_\(dim)", + inputNames: ["input", "rotation", "boundaries", "codebook"], + outputNames: ["packed_out", "norms_out"], + source: TurboQuantMetalKernels.fusedEncodeSource + ) + encodeKernelCache[key] = kernel + encodeLock.unlock() + return kernel + } + + public static func fusedEncode( + input: MLXArray, + rotation: MLXArray, + boundaries: MLXArray, + codebook: MLXArray, + bits: Int, + dim: Int + ) -> (packed: MLXArray, norms: MLXArray) { + let numRows = input.dim(0) + let packedWidth = TurboQuantPacking.packedWidth(count: dim, bits: bits) + let kernel = getEncodeKernel(bits: bits, dim: dim) + + let results = kernel( + [input, rotation, boundaries, codebook], + template: [ + ("Bits", bits), + ("Dim", dim), + ("PackedWidth", packedWidth), + ], + grid: (dim, numRows, 1), + threadGroup: (dim, 1, 1), + outputShapes: [[numRows, packedWidth], [numRows]], + outputDTypes: [.uint32, .float32] + ) + return (packed: results[0], norms: results[1]) + } + + /// Fused encode with WHT rotation (power-of-2 dims only). + nonisolated(unsafe) private static var encodeWHTKernelCache: [String: MLXFast.MLXFastKernel] = + [:] + private static let encodeWHTLock = NSLock() + + private static func getEncodeWHTKernel(bits: Int, dim: Int) -> MLXFast.MLXFastKernel { + let key = "encode_wht_\(bits)_\(dim)" + encodeWHTLock.lock() + if let cached = encodeWHTKernelCache[key] { + encodeWHTLock.unlock() + return cached + } + let kernel = MLXFast.metalKernel( + name: "turbo_fused_encode_wht_\(bits)_\(dim)", + inputNames: ["input", "wht_signs", "boundaries"], + outputNames: ["packed_out", "norms_out"], + source: TurboQuantMetalKernels.fusedEncodeWHTSource + ) + encodeWHTKernelCache[key] = kernel + encodeWHTLock.unlock() + return kernel + } + + public static func fusedEncodeWHT( + input: MLXArray, + whtSigns: MLXArray, + boundaries: MLXArray, + codebook: MLXArray, + bits: Int, + dim: Int + ) -> (packed: MLXArray, norms: MLXArray) { + let numRows = input.dim(0) + let packedWidth = TurboQuantPacking.packedWidth(count: dim, bits: bits) + let kernel = getEncodeWHTKernel(bits: bits, dim: dim) + + let results = kernel( + [input, whtSigns, boundaries], + template: [ + ("Bits", bits), + ("Dim", dim), + ("PackedWidth", packedWidth), + ("LogDim", Int(log2(Double(dim)))), + ], + grid: (dim, numRows, 1), + threadGroup: (dim, 1, 1), + outputShapes: [[numRows, packedWidth], [numRows]], + outputDTypes: [.uint32, .float32] + ) + return (packed: results[0], norms: results[1]) + } + + nonisolated(unsafe) private static var flashPass1Kernels: [String: MLXFast.MLXFastKernel] = [:] + nonisolated(unsafe) private static var flashPass2Kernels: [String: MLXFast.MLXFastKernel] = [:] + + /// Query rows per SIMD group. Override via `TURBO_FLASH_NR0` env var. + public static let flashNR0: Int = { + if let envValue = ProcessInfo.processInfo.environment["TURBO_FLASH_NR0"], + let parsed = Int(envValue), parsed > 0, (parsed & (parsed - 1)) == 0 + { + return parsed + } + return 2 + }() + + /// Tokens per block in two-pass flash attention. Override via `TURBO_FLASH_BLOCK_SIZE` env var. + public static let flashBlockSize: Int = { + if let envValue = ProcessInfo.processInfo.environment["TURBO_FLASH_BLOCK_SIZE"], + let parsed = Int(envValue), parsed > 0 + { + return parsed + } + return 64 + }() + + /// Current sparse V skip threshold. + public static var sparseVThreshold: Float { TurboQuantMetalKernels.sparseVThreshold } + + private static let flashPass1Lock = NSLock() + + private static func getFlashPass1Kernel( + source: String, cachePrefix: String, + keyBits: Int, valueBits: Int, dim: Int, + extraInputNames: [String] + ) -> MLXFast.MLXFastKernel { + let key = "\(cachePrefix)_\(keyBits)_\(valueBits)_\(dim)" + flashPass1Lock.lock() + if let cached = flashPass1Kernels[key] { + flashPass1Lock.unlock() + return cached + } + let baseInputs = [ + "q_rot", "key_packed", "key_norms", "key_codebook", + "val_packed", "val_norms", "val_codebook", + "tc_buf", "rc_buf", "nb_buf", "bs_buf", + ] + let kernel = MLXFast.metalKernel( + name: "turbo_flash_p1_\(cachePrefix)_\(keyBits)_\(valueBits)_\(dim)", + inputNames: baseInputs + extraInputNames, + outputNames: ["o_partials", "m_partials", "l_partials"], + source: source + ) + flashPass1Kernels[key] = kernel + flashPass1Lock.unlock() + return kernel + } + + private static func dispatchFlashPass1( + source: String, cachePrefix: String, + rotatedQueries: MLXArray, + keyPacked: MLXArray, keyNorms: MLXArray, keyCodebook: MLXArray, + valPacked: MLXArray, valNorms: MLXArray, valCodebook: MLXArray, + tokenCount: Int, repeatCount: Int, + keyBits: Int, valueBits: Int, dim: Int, + blockSize: Int, + extraInputNames: [String] = [], + extraInputBuffers: [MLXArray] = [] + ) -> (oPartials: MLXArray, mPartials: MLXArray, lPartials: MLXArray) { + let numBlocks = (tokenCount + blockSize - 1) / blockSize + let totalQ = rotatedQueries.dim(0) + let kernel = getFlashPass1Kernel( + source: source, cachePrefix: cachePrefix, + keyBits: keyBits, valueBits: valueBits, dim: dim, + extraInputNames: extraInputNames) + + let runtimeBufs: [MLXArray] = [ + MLXArray([Int32(tokenCount)]), MLXArray([Int32(repeatCount)]), + MLXArray([Int32(numBlocks)]), MLXArray([Int32(blockSize)]), + ] + let keyPW = TurboQuantPacking.packedWidth(count: dim, bits: keyBits) + let valPW = TurboQuantPacking.packedWidth(count: dim, bits: valueBits) + + let results = kernel( + [ + rotatedQueries, keyPacked, keyNorms, keyCodebook, + valPacked, valNorms, valCodebook, + ] + runtimeBufs + extraInputBuffers, + template: [ + ("KeyBits", keyBits), + ("ValueBits", valueBits), + ("Dim", dim), + ("KeyPackedWidth", keyPW), + ("ValuePackedWidth", valPW), + ], + grid: (32, totalQ, numBlocks), + threadGroup: (32, 1, 1), + outputShapes: [[totalQ, numBlocks, dim], [totalQ, numBlocks], [totalQ, numBlocks]], + outputDTypes: [.float32, .float32, .float32] + ) + return (oPartials: results[0], mPartials: results[1], lPartials: results[2]) + } + + /// NR0 multi-row causal pass 1 dispatch. + private static func dispatchFlashPass1NR0Causal( + rotatedQueries: MLXArray, + keyPacked: MLXArray, keyNorms: MLXArray, keyCodebook: MLXArray, + valPacked: MLXArray, valNorms: MLXArray, valCodebook: MLXArray, + tokenCount: Int, repeatCount: Int, + keyBits: Int, valueBits: Int, dim: Int, + blockSize: Int, nr0: Int, + queryChunkLength: Int, queryOffset: Int + ) -> (oPartials: MLXArray, mPartials: MLXArray, lPartials: MLXArray) { + let numBlocks = (tokenCount + blockSize - 1) / blockSize + let totalQ = rotatedQueries.dim(0) + let kernel = getFlashPass1Kernel( + source: TurboQuantMetalKernels.turboFlashPass1NR0CausalSource, + cachePrefix: "nr0_causal", keyBits: keyBits, valueBits: valueBits, dim: dim, + extraInputNames: ["L_buf", "qo_buf"]) + let runtimeBufs: [MLXArray] = [ + MLXArray([Int32(tokenCount)]), MLXArray([Int32(repeatCount)]), + MLXArray([Int32(numBlocks)]), MLXArray([Int32(blockSize)]), + ] + let extraBufs = [MLXArray([Int32(queryChunkLength)]), MLXArray([Int32(queryOffset)])] + let keyPW = TurboQuantPacking.packedWidth(count: dim, bits: keyBits) + let valPW = TurboQuantPacking.packedWidth(count: dim, bits: valueBits) + let results = kernel( + [ + rotatedQueries, keyPacked, keyNorms, keyCodebook, + valPacked, valNorms, valCodebook, + ] + runtimeBufs + extraBufs, + template: [ + ("KeyBits", keyBits), ("ValueBits", valueBits), ("Dim", dim), + ("KeyPackedWidth", keyPW), ("ValuePackedWidth", valPW), ("NR0", nr0), + ], + grid: (32, totalQ / nr0, numBlocks), + threadGroup: (32, 1, 1), + outputShapes: [[totalQ, numBlocks, dim], [totalQ, numBlocks], [totalQ, numBlocks]], + outputDTypes: [.float32, .float32, .float32] + ) + return (oPartials: results[0], mPartials: results[1], lPartials: results[2]) + } + + /// Pass 2 dispatch with optional fused output rotation. + private static let flashPass2Lock = NSLock() + + private static func getFlashPass2Kernel(fused: Bool, dim: Int) -> MLXFast.MLXFastKernel { + let key = "p2_\(fused ? "fused" : "plain")_\(dim)" + flashPass2Lock.lock() + if let cached = flashPass2Kernels[key] { + flashPass2Lock.unlock() + return cached + } + let source = + fused + ? TurboQuantMetalKernels.turboFlashPass2FusedRotSource + : TurboQuantMetalKernels.turboFlashPass2Source + let inputs = + fused + ? ["o_partials", "m_partials", "l_partials", "val_rotation", "nb_buf"] + : ["o_partials", "m_partials", "l_partials", "nb_buf"] + let kernel = MLXFast.metalKernel( + name: "turbo_flash_p2_\(key)", + inputNames: inputs, + outputNames: ["output"], + source: source + ) + flashPass2Kernels[key] = kernel + flashPass2Lock.unlock() + return kernel + } + + private static func dispatchFlashPass2( + oPartials: MLXArray, mPartials: MLXArray, lPartials: MLXArray, + dim: Int, numBlocks: Int, totalQ: Int, + valRotation: MLXArray? = nil + ) -> MLXArray { + let fused = valRotation != nil + let kernel = getFlashPass2Kernel(fused: fused, dim: dim) + var inputs: [MLXArray] = [oPartials, mPartials, lPartials] + if let valRotation { inputs.append(valRotation) } + inputs.append(MLXArray([Int32(numBlocks)])) + + let results = kernel( + inputs, + template: [("Dim", dim)], + grid: (dim, totalQ, 1), + threadGroup: (dim, 1, 1), + outputShapes: [[totalQ, dim]], + outputDTypes: [.float32] + ) + return results[0] + } + + /// Two-pass flash attention with causal masking. + public static func turboFlashAttentionCausal( + rotatedQueries: MLXArray, + keyPacked: MLXArray, + keyNorms: MLXArray, + keyCodebook: MLXArray, + valPacked: MLXArray, + valNorms: MLXArray, + valCodebook: MLXArray, + tokenCount: Int, + repeatCount: Int, + keyBits: Int, + valueBits: Int, + dim: Int, + queryChunkLength: Int, + queryOffset: Int, + valRotation: MLXArray? = nil, + blockSize: Int? = nil + ) -> MLXArray { + let blockSize = blockSize ?? flashBlockSize + let numBlocks = (tokenCount + blockSize - 1) / blockSize + let totalQ = rotatedQueries.dim(0) + let nr0 = flashNR0 + + let useNR0 = nr0 > 1 && totalQ % nr0 == 0 && totalQ >= nr0 + + let oPartials: MLXArray + let mPartials: MLXArray + let lPartials: MLXArray + + if useNR0 { + (oPartials, mPartials, lPartials) = dispatchFlashPass1NR0Causal( + rotatedQueries: rotatedQueries, + keyPacked: keyPacked, keyNorms: keyNorms, keyCodebook: keyCodebook, + valPacked: valPacked, valNorms: valNorms, valCodebook: valCodebook, + tokenCount: tokenCount, repeatCount: repeatCount, + keyBits: keyBits, valueBits: valueBits, dim: dim, + blockSize: blockSize, nr0: nr0, + queryChunkLength: queryChunkLength, queryOffset: queryOffset + ) + } else { + (oPartials, mPartials, lPartials) = dispatchFlashPass1( + source: TurboQuantMetalKernels.turboFlashPass1CausalSource, + cachePrefix: "flash_p1_causal", + rotatedQueries: rotatedQueries, + keyPacked: keyPacked, keyNorms: keyNorms, keyCodebook: keyCodebook, + valPacked: valPacked, valNorms: valNorms, valCodebook: valCodebook, + tokenCount: tokenCount, repeatCount: repeatCount, + keyBits: keyBits, valueBits: valueBits, dim: dim, + blockSize: blockSize, + extraInputNames: ["L_buf", "qo_buf"], + extraInputBuffers: [ + MLXArray([Int32(queryChunkLength)]), MLXArray([Int32(queryOffset)]), + ] + ) + } + + return dispatchFlashPass2( + oPartials: oPartials, mPartials: mPartials, lPartials: lPartials, + dim: dim, numBlocks: numBlocks, totalQ: totalQ, + valRotation: valRotation + ) + } + + /// Weighted sum of packed codebook values. Result is in rotated space. + private static let valueLock = NSLock() + + public static func mseWeightedSum( + weights: MLXArray, + packed: MLXArray, + norms: MLXArray, + codebook: MLXArray, + tokenCount: Int, + repeatCount: Int, + bits: Int, + dim: Int + ) -> MLXArray { + let key = "value_\(bits)_\(dim)" + valueLock.lock() + if valueKernels[key] == nil { + valueKernels[key] = MLXFast.metalKernel( + name: "turbo_value_\(bits)_\(dim)", + inputNames: ["weights", "packed", "norms", "codebook", "tc_buf", "rc_buf"], + outputNames: ["output"], + source: TurboQuantMetalKernels.valueKernelSource + ) + } + let kernel = valueKernels[key]! + valueLock.unlock() + + let totalHeads = weights.dim(0) + let packedWidth = TurboQuantPacking.packedWidth(count: dim, bits: bits) + let results = kernel( + [ + weights, packed, norms, codebook, + MLXArray([Int32(tokenCount)]), MLXArray([Int32(repeatCount)]), + ], + template: [("Bits", bits), ("Dim", dim), ("PackedWidth", packedWidth)], + grid: (32, totalHeads, (dim + 31) / 32), + threadGroup: (32, 1, 1), + outputShapes: [[totalHeads, dim]], + outputDTypes: [.float32] + ) + return results[0] + } +} diff --git a/Package.swift b/Package.swift index 3924f7df9..d48af5dbf 100644 --- a/Package.swift +++ b/Package.swift @@ -36,7 +36,7 @@ let package = Package( targets: ["IntegrationTestHelpers"]), ], dependencies: [ - .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.3")), + .package(url: "https://github.com/ekryski/mlx-swift", branch: "alpha"), .package(url: "https://github.com/swiftlang/swift-syntax.git", "600.0.0" ..< "604.0.0"), ], targets: [ @@ -146,6 +146,14 @@ let package = Package( ], path: "Libraries/MLXHuggingFace" ), + .executableTarget( + name: "Gemma4MoETest", + dependencies: [ + "MLXLLM", "MLXLMCommon", "BenchmarkHelpers", + .product(name: "MLX", package: "mlx-swift"), + ], + path: "Tools/Gemma4MoETest" + ), ] ) diff --git a/Tools/Gemma4MoETest/main.swift b/Tools/Gemma4MoETest/main.swift new file mode 100644 index 000000000..1465400e7 --- /dev/null +++ b/Tools/Gemma4MoETest/main.swift @@ -0,0 +1,58 @@ +import Foundation +import MLX +import BenchmarkHelpers +import MLXLLM +import MLXLMCommon + +let args = CommandLine.arguments +guard args.count >= 3 else { print("Usage: TQBench [scheme]"); exit(1) } +let test = args[1]; let modelPath = args[2]; let scheme = args.count > 3 ? args[3] : "none" + +let container = try await loadModelContainer(from: URL(fileURLWithPath: modelPath), using: NoOpTokenizerLoader()) + +if test == "ppl" { + let data = try JSONSerialization.jsonObject(with: Data(contentsOf: URL(fileURLWithPath: "/tmp/ppl_test_data.json"))) as! [String: Any] + let ids = (data["ids"] as! [NSNumber]).map { Int32($0.intValue) } + try await container.perform { context in + let model = context.model; var cache = model.newCache(parameters: nil) + let halfLen = ids.count / 2 + let prefillOut = model(MLXArray(Array(ids[0.. 4 ? args[4] : "512")! + let tokens = MLXArray(Array(repeating: Int32(1), count: ctx))[.newAxis, .ellipsis] + let output = model(tokens, cache: cache); eval(output, cache) + if scheme != "none" { maybeQuantizeKVCache(cache: &cache, kvBits: nil, quantizedKVStart: 0, kvScheme: scheme); eval(cache) } + // Decode 5 tokens to trigger compression + var cur = MLXArray([Int32(1)])[.newAxis, .ellipsis] + for _ in 0..<5 { let out = model(cur, cache: cache); eval(out, cache); cur = MLXArray([Int32(1)])[.newAxis, .ellipsis] } + var turboMem = 0; var turboCount = 0 + for c in cache { if let tc = c as? TurboQuantKVCache { turboMem += tc.memoryBytes; turboCount += 1 } } + print("MEM:\(turboMem/1024)\t\(turboCount)\t\(cache.count)") + } +} From 2864b83116748af33b7fcbe2ea87f2b4ae758601 Mon Sep 17 00:00:00 2001 From: TheTom Date: Wed, 22 Apr 2026 15:06:22 -0500 Subject: [PATCH 2/2] fix TurboQuant encode overflow, cache serialization, and speculative kvScheme - size shared_norm dynamically in Metal encode kernels for head_dim > 128 - guard NR0 flash path on query chunk alignment to prevent KV head mismatch - add TurboQuantKVCache to prompt cache save/restore with metaState - thread kvScheme through speculative decoding quantization closure - point Package.swift at ml-explore/mlx-swift 0.31.3 --- Libraries/MLXLMCommon/Evaluate.swift | 3 ++- Libraries/MLXLMCommon/KVCache.swift | 17 +++++++++++++++++ Libraries/MLXLMCommon/TurboQuantKVCache.swift | 12 ++++++++++++ Libraries/MLXLMCommon/TurboQuantKernels.swift | 5 ++++- Package.swift | 2 +- 5 files changed, 36 insertions(+), 3 deletions(-) diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift index 3b113e176..bd0a9525f 100644 --- a/Libraries/MLXLMCommon/Evaluate.swift +++ b/Libraries/MLXLMCommon/Evaluate.swift @@ -809,7 +809,8 @@ public struct SpeculativeTokenIterator: TokenIteratorProtocol { cache: &cache, kvBits: parameters.kvBits, kvGroupSize: parameters.kvGroupSize, - quantizedKVStart: parameters.quantizedKVStart + quantizedKVStart: parameters.quantizedKVStart, + kvScheme: parameters.kvScheme ) } diff --git a/Libraries/MLXLMCommon/KVCache.swift b/Libraries/MLXLMCommon/KVCache.swift index 285f77ac9..2065cd042 100644 --- a/Libraries/MLXLMCommon/KVCache.swift +++ b/Libraries/MLXLMCommon/KVCache.swift @@ -1368,6 +1368,7 @@ private func cacheClassName(_ cache: KVCache) -> String { case is ArraysCache: return "ArraysCache" case is RotatingKVCache: return "RotatingKVCache" case is QuantizedKVCache: return "QuantizedKVCache" + case is TurboQuantKVCache: return "TurboQuantKVCache" case is KVCacheSimple: return "KVCache" case is CacheList: return "CacheList" default: return "KVCache" @@ -1521,6 +1522,22 @@ private func restoreCacheFromMetaState( cache.restoreFromMetaState(state: state, savedMetaState: metaState) return cache + case "TurboQuantKVCache": + guard metaState.count >= 5, + let bits = Int(metaState[1]), + let keyBits = Int(metaState[2]), + let valueBits = Int(metaState[3]), + let seed = UInt64(metaState[4]) + else { + throw KVCacheError( + message: "Invalid TurboQuantKVCache metaState") + } + let cache = TurboQuantKVCache( + bits: bits, keyBits: keyBits, valueBits: valueBits, seed: seed) + cache.state = state + cache.metaState = metaState + return cache + case "CacheList": return try CacheList.fromState(state: state, metaState: metaState) diff --git a/Libraries/MLXLMCommon/TurboQuantKVCache.swift b/Libraries/MLXLMCommon/TurboQuantKVCache.swift index 52d274389..ef7d75697 100644 --- a/Libraries/MLXLMCommon/TurboQuantKVCache.swift +++ b/Libraries/MLXLMCommon/TurboQuantKVCache.swift @@ -1154,6 +1154,18 @@ public class TurboQuantKVCache: BaseKVCache { } } + override public var metaState: [String] { + get { + ["\(offset)", "\(bits)", "\(keyBits)", "\(valueBits)", "\(seed)"] + } + set { + guard newValue.count >= 5, + let o = Int(newValue[0]) + else { return } + offset = o + } + } + @discardableResult override public func trim(_ n: Int) -> Int { guard n > 0, offset > 0 else { return 0 } diff --git a/Libraries/MLXLMCommon/TurboQuantKernels.swift b/Libraries/MLXLMCommon/TurboQuantKernels.swift index f914af372..bf7bf9122 100644 --- a/Libraries/MLXLMCommon/TurboQuantKernels.swift +++ b/Libraries/MLXLMCommon/TurboQuantKernels.swift @@ -19,7 +19,7 @@ enum TurboQuantMetalKernels { float sq = val * val; float norm_sq = simd_sum(sq); // For Dim > 32, need threadgroup reduction - threadgroup float shared_norm[4]; // up to 4 SIMD groups + threadgroup float shared_norm[(Dim + 31) / 32]; uint sg_id = d / 32; if (d % 32 == 0) { shared_norm[sg_id] = norm_sq; @@ -1041,7 +1041,10 @@ public enum TurboQuantKernelOps { let totalQ = rotatedQueries.dim(0) let nr0 = flashNR0 + // NR0 multi-row: only safe when L is divisible by NR0, otherwise + // grouped queries can span KV head boundaries within a group. let useNR0 = nr0 > 1 && totalQ % nr0 == 0 && totalQ >= nr0 + && queryChunkLength % nr0 == 0 let oPartials: MLXArray let mPartials: MLXArray diff --git a/Package.swift b/Package.swift index d48af5dbf..15c9aa0b3 100644 --- a/Package.swift +++ b/Package.swift @@ -36,7 +36,7 @@ let package = Package( targets: ["IntegrationTestHelpers"]), ], dependencies: [ - .package(url: "https://github.com/ekryski/mlx-swift", branch: "alpha"), + .package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.3")), .package(url: "https://github.com/swiftlang/swift-syntax.git", "600.0.0" ..< "604.0.0"), ], targets: [