diff --git a/Libraries/MLXLMCommon/AttentionUtils.swift b/Libraries/MLXLMCommon/AttentionUtils.swift index 5b4e6c76b..e57b9ad2c 100644 --- a/Libraries/MLXLMCommon/AttentionUtils.swift +++ b/Libraries/MLXLMCommon/AttentionUtils.swift @@ -1,39 +1,43 @@ import Foundation import MLX -/// Attention utilities that match Python mlx-lm's interface -/// -/// This provides a single function that automatically routes to quantized or regular -/// attention based on cache type, matching Python's `scaled_dot_product_attention` +// MARK: - TurboKV Telemetry + +enum TurboKVTelemetry { + nonisolated(unsafe) private static var totalTokens: UInt64 = 0 + nonisolated(unsafe) private static var totalOrigBytes: UInt64 = 0 + nonisolated(unsafe) private static var totalPackedBytes: UInt64 = 0 + nonisolated(unsafe) private static var hasLogged = false + + static func logOnce(compressedOffset: Int, keys: MLXArray, values: MLXArray, headDim: Int) { + let B = keys.dim(0) + let nKVH = keys.dim(1) + let newTokens = keys.dim(2) + let packedBytes = UInt64(keys.nbytes + values.nbytes) + let origBytes = UInt64(B * nKVH * newTokens * headDim * 2 * 2) + record(tokens: newTokens, origBytes: Int(origBytes), packedBytes: Int(packedBytes)) + } + + static func record(tokens: Int, origBytes: Int, packedBytes: Int) { + totalTokens += UInt64(tokens) + totalOrigBytes += UInt64(origBytes) + totalPackedBytes += UInt64(packedBytes) + if !hasLogged && totalTokens > 0 { + hasLogged = true + let ratio = totalPackedBytes > 0 ? Double(totalOrigBytes) / Double(totalPackedBytes) : 0 + print("[TurboKV] \(totalTokens)t compressed, \(String(format: "%.1f", ratio))x ratio") + } + } +} + +typealias TurboKVCacheTelemetry = TurboKVTelemetry + +// MARK: - Attention Utilities /// Automatic attention with cache update /// -/// This function matches Python's `scaled_dot_product_attention` in base.py: -/// - Detects if cache is `QuantizedKVCache` using `isinstance` pattern -/// - Routes to `quantizedScaledDotProductAttention` or `MLXFast.scaledDotProductAttention` -/// - Handles cache updating automatically -/// - Transparent to models - they just call this function -/// -/// **Usage in models:** -/// ```swift -/// let output = attentionWithCacheUpdate( -/// queries: queries, -/// keys: keys, -/// values: values, -/// cache: cache, -/// scale: scale, -/// mask: mask -/// ) -/// ``` -/// -/// - Parameters: -/// - queries: Query tensor [B, nHeads, L, D] -/// - keys: Raw key tensor to be cached [B, nKVHeads, L, D] -/// - values: Raw value tensor to be cached [B, nKVHeads, L, D] -/// - cache: Cache instance (any type) -/// - scale: Attention scale factor -/// - mask: Attention mask -/// - Returns: Attention output [B, nHeads, L, D] +/// Routes to quantized, TurboQuant, or standard attention based on cache type. +/// Handles cache updating, TurboQuant decode, and mask slicing transparently. public func attentionWithCacheUpdate( queries: MLXArray, keys: MLXArray, @@ -51,6 +55,7 @@ public func attentionWithCacheUpdate( mask: mask ) } + if let quantizedKVCache = cache as? QuantizedKVCacheProtocol { let (quantizedKeys, quantizedValues) = quantizedKVCache.updateQuantized( keys: keys, values: values) @@ -66,12 +71,48 @@ public func attentionWithCacheUpdate( ) } else { let (cachedKeys, cachedValues) = cache.update(keys: keys, values: values) + + var fullKeys = cachedKeys + var fullValues = cachedValues + if let kvCache = cache as? KVCacheSimple, + let pk = kvCache.polarKeys, + let pv = kvCache.polarValues, + kvCache.compressedOffset > 0 { + let historyK = MLXFast.turboDecodeK(packed: pk).asType(cachedKeys.dtype) + let historyV = MLXFast.turboDecodeV(packed: pv).asType(cachedValues.dtype) + var mergedK = historyK + var mergedV = historyV + if kvCache.turboSplitHeads { + let B = historyK.dim(0), H2 = historyK.dim(1), T = historyK.dim(2) + mergedK = historyK.reshaped(B, H2 / 2, T, 512) + mergedV = historyV.reshaped(B, H2 / 2, T, 512) + } + fullKeys = concatenated([mergedK, cachedKeys], axis: 2) + fullValues = concatenated([mergedV, cachedValues], axis: 2) + } + + let targetS = fullKeys.dim(2) + var safeMask = mask + if case .array(let customMask) = mask { + if customMask.dim(-1) != targetS && customMask.dim(-1) > targetS { + let sliced: MLXArray + if customMask.ndim == 2 { + sliced = customMask[0..., .. MLXArray? { public class KVCacheSimple: BaseKVCache, CustomDebugStringConvertible { internal var keys: MLXArray? internal var values: MLXArray? + + // MARK: - TurboQuant State + public var turboQuantEnabled: Bool = false + nonisolated(unsafe) private static var turboWarnedHeadDims: Set = [] + public var turboSplitHeads: Bool = false + public var polarKeys: MLXArray? + public var polarValues: MLXArray? + public var residualKeys: MLXArray? + public var residualValues: MLXArray? + public var compressedOffset: Int = 0 + public var turboMinActivationTokens: Int = 2048 + public var turboHotWindowSize: Int = 256 + public var step = 256 public override init() { @@ -381,6 +394,64 @@ public class KVCacheSimple: BaseKVCache, CustomDebugStringConvertible { self.keys?[.ellipsis, previous ..< self.offset, 0...] = keys self.values?[.ellipsis, previous ..< self.offset, 0...] = values + // MARK: TurboKV hot-window eviction + if turboQuantEnabled { + let headDim = keys.dim(-1) + let supportedDim = (headDim == 128 || headDim == 256) + let splittableDim = (headDim == 512) + if !supportedDim && !splittableDim { + if !Self.turboWarnedHeadDims.contains(headDim) { + Self.turboWarnedHeadDims.insert(headDim) + print("[TurboKV] head_dim \(headDim) unsupported (requires 128 or 256). Falling back to fp16.") + } + turboQuantEnabled = false + } else if self.offset > turboMinActivationTokens { + let coldEnd = self.offset - turboHotWindowSize + let newColdCount = coldEnd - self.compressedOffset + + if newColdCount >= step { + if let fullK = self.keys, let fullV = self.values { + var coldK = fullK[.ellipsis, self.compressedOffset..