Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 74 additions & 33 deletions Libraries/MLXLMCommon/AttentionUtils.swift
Original file line number Diff line number Diff line change
@@ -1,39 +1,43 @@
import Foundation
import MLX

/// Attention utilities that match Python mlx-lm's interface
///
/// This provides a single function that automatically routes to quantized or regular
/// attention based on cache type, matching Python's `scaled_dot_product_attention`
// MARK: - TurboKV Telemetry

/// Process-wide running totals for TurboKV cache compression, with a single
/// summary line printed the first time a non-zero token count is recorded.
///
/// The counters are `nonisolated(unsafe)` statics: nothing in this type
/// synchronizes access to them.
/// NOTE(review): callers presumably invoke this from one thread at a time —
/// confirm before using it from concurrent sessions.
enum TurboKVTelemetry {
    nonisolated(unsafe) private static var totalTokens: UInt64 = 0
    nonisolated(unsafe) private static var totalOrigBytes: UInt64 = 0
    nonisolated(unsafe) private static var totalPackedBytes: UInt64 = 0
    nonisolated(unsafe) private static var hasLogged = false

    /// Derives token and byte counts from a freshly encoded chunk and forwards
    /// them to `record(tokens:origBytes:packedBytes:)`.
    ///
    /// Despite the name, every call updates the totals; only the printed
    /// summary is emitted once.
    ///
    /// - Parameters:
    ///   - compressedOffset: Currently unused; kept so existing call sites compile.
    ///   - keys: Encoded keys. Dimensions 0, 1 and 2 are read as the batch,
    ///     KV-head and token counts; `nbytes` contributes to the packed size.
    ///   - values: Encoded values; only `nbytes` is read.
    ///   - headDim: Per-head feature width used to estimate the uncompressed size.
    static func logOnce(compressedOffset: Int, keys: MLXArray, values: MLXArray, headDim: Int) {
        // NOTE(review): the two factors of 2 presumably stand for bytes per
        // fp16 element and the K/V tensor pair — confirm.
        let bytesPerElement = 2
        let tensorsPerToken = 2

        let batchSize = keys.dim(0)
        let kvHeadCount = keys.dim(1)
        let tokenCount = keys.dim(2)

        let elementCount = batchSize * kvHeadCount * tokenCount * headDim
        let uncompressedBytes = elementCount * bytesPerElement * tensorsPerToken
        let compressedBytes = keys.nbytes + values.nbytes

        record(tokens: tokenCount, origBytes: uncompressedBytes, packedBytes: compressedBytes)
    }

    /// Adds one sample to the running totals and prints the summary line once.
    ///
    /// The summary is printed on the first call that leaves the cumulative
    /// token count above zero; later calls only accumulate.
    ///
    /// - Parameters:
    ///   - tokens: Number of tokens covered by this sample.
    ///   - origBytes: Estimated size of those tokens before compression.
    ///   - packedBytes: Size of those tokens after compression.
    static func record(tokens: Int, origBytes: Int, packedBytes: Int) {
        totalTokens += UInt64(tokens)
        totalOrigBytes += UInt64(origBytes)
        totalPackedBytes += UInt64(packedBytes)

        // Nothing to report yet, or the one-time summary was already printed.
        guard totalTokens > 0, !hasLogged else { return }
        hasLogged = true

        // Report a ratio of 0 rather than dividing by an empty packed total.
        let ratio: Double
        if totalPackedBytes > 0 {
            ratio = Double(totalOrigBytes) / Double(totalPackedBytes)
        } else {
            ratio = 0
        }
        print("[TurboKV] \(totalTokens)t compressed, \(String(format: "%.1f", ratio))x ratio")
    }
}

typealias TurboKVCacheTelemetry = TurboKVTelemetry

// MARK: - Attention Utilities

/// Automatic attention with cache update
///
/// This function matches Python's `scaled_dot_product_attention` in base.py:
/// - Detects whether the cache conforms to `QuantizedKVCacheProtocol` via an `as?` cast (the Swift analogue of Python's `isinstance`)
/// - Routes to `quantizedScaledDotProductAttention` or `MLXFast.scaledDotProductAttention`
/// - Handles cache updating automatically
/// - Transparent to models - they just call this function
///
/// **Usage in models:**
/// ```swift
/// let output = attentionWithCacheUpdate(
/// queries: queries,
/// keys: keys,
/// values: values,
/// cache: cache,
/// scale: scale,
/// mask: mask
/// )
/// ```
///
/// - Parameters:
/// - queries: Query tensor [B, nHeads, L, D]
/// - keys: Raw key tensor to be cached [B, nKVHeads, L, D]
/// - values: Raw value tensor to be cached [B, nKVHeads, L, D]
/// - cache: Cache instance (any type)
/// - scale: Attention scale factor
/// - mask: Attention mask
/// - Returns: Attention output [B, nHeads, L, D]
/// Routes to quantized, TurboQuant, or standard attention based on cache type.
/// Handles cache updating, TurboQuant decode, and mask slicing transparently.
public func attentionWithCacheUpdate(
queries: MLXArray,
keys: MLXArray,
Expand All @@ -51,6 +55,7 @@ public func attentionWithCacheUpdate(
mask: mask
)
}

if let quantizedKVCache = cache as? QuantizedKVCacheProtocol {
let (quantizedKeys, quantizedValues) = quantizedKVCache.updateQuantized(
keys: keys, values: values)
Expand All @@ -66,12 +71,48 @@ public func attentionWithCacheUpdate(
)
} else {
let (cachedKeys, cachedValues) = cache.update(keys: keys, values: values)

var fullKeys = cachedKeys
var fullValues = cachedValues
if let kvCache = cache as? KVCacheSimple,
let pk = kvCache.polarKeys,
let pv = kvCache.polarValues,
kvCache.compressedOffset > 0 {
let historyK = MLXFast.turboDecodeK(packed: pk).asType(cachedKeys.dtype)
let historyV = MLXFast.turboDecodeV(packed: pv).asType(cachedValues.dtype)
var mergedK = historyK
var mergedV = historyV
if kvCache.turboSplitHeads {
let B = historyK.dim(0), H2 = historyK.dim(1), T = historyK.dim(2)
mergedK = historyK.reshaped(B, H2 / 2, T, 512)
mergedV = historyV.reshaped(B, H2 / 2, T, 512)
}
fullKeys = concatenated([mergedK, cachedKeys], axis: 2)
fullValues = concatenated([mergedV, cachedValues], axis: 2)
}

let targetS = fullKeys.dim(2)
var safeMask = mask
if case .array(let customMask) = mask {
if customMask.dim(-1) != targetS && customMask.dim(-1) > targetS {
let sliced: MLXArray
if customMask.ndim == 2 {
sliced = customMask[0..., ..<targetS]
} else if customMask.ndim == 4 {
sliced = customMask[0..., 0..., 0..., ..<targetS]
} else {
fatalError("Unsupported mask dimensionality: \(customMask.ndim)")
}
safeMask = .array(sliced)
}
}

return MLXFast.scaledDotProductAttention(
queries: queries,
keys: cachedKeys,
values: cachedValues,
keys: fullKeys,
values: fullValues,
scale: scale,
mask: mask
mask: safeMask
)
}
}
22 changes: 21 additions & 1 deletion Libraries/MLXLMCommon/ChatSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ public final class ChatSession {
/// Speculative decoding configuration, nil if disabled.
public let speculativeDecoding: SpeculativeDecodingConfig?

/// When true, enables TurboQuant KV cache compression on each KVCacheSimple layer.
public var turboQuantEnabled: Bool = false

/// Minimum token count before TurboQuant compression activates (default 2048).
public var turboMinActivationTokens: Int = 2048

/// Number of recent tokens kept in full fp16 precision (default 256).
public var turboHotWindowSize: Int = 256

/// Initialize the `ChatSession`.
///
/// - Parameters:
Expand Down Expand Up @@ -417,7 +426,8 @@ public final class ChatSession {
[
model,
instructions, processing, tools, toolDispatch,
additionalContext, cache, generateParameters, speculativeDecoding
additionalContext, cache, generateParameters, speculativeDecoding,
turboQuantEnabled, turboMinActivationTokens, turboHotWindowSize
] in
do {
try await cache.update { cache in
Expand Down Expand Up @@ -467,6 +477,16 @@ public final class ChatSession {
messages.append(contentsOf: history)
}

if turboQuantEnabled {
for layer in kvCache {
if let simple = layer as? KVCacheSimple {
simple.turboQuantEnabled = true
simple.turboMinActivationTokens = turboMinActivationTokens
simple.turboHotWindowSize = turboHotWindowSize
}
}
}

// prepare the input
messages.append(message.consume())

Expand Down
71 changes: 71 additions & 0 deletions Libraries/MLXLMCommon/KVCache.swift
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,19 @@ public func createSSMMask(h: MLXArray, cache: MambaCache?) -> MLXArray? {
public class KVCacheSimple: BaseKVCache, CustomDebugStringConvertible {
internal var keys: MLXArray?
internal var values: MLXArray?

// MARK: - TurboQuant State
public var turboQuantEnabled: Bool = false
nonisolated(unsafe) private static var turboWarnedHeadDims: Set<Int> = []
public var turboSplitHeads: Bool = false
public var polarKeys: MLXArray?
public var polarValues: MLXArray?
public var residualKeys: MLXArray?
public var residualValues: MLXArray?
public var compressedOffset: Int = 0
public var turboMinActivationTokens: Int = 2048
public var turboHotWindowSize: Int = 256

public var step = 256

public override init() {
Expand Down Expand Up @@ -381,6 +394,64 @@ public class KVCacheSimple: BaseKVCache, CustomDebugStringConvertible {
self.keys?[.ellipsis, previous ..< self.offset, 0...] = keys
self.values?[.ellipsis, previous ..< self.offset, 0...] = values

// MARK: TurboKV hot-window eviction
if turboQuantEnabled {
let headDim = keys.dim(-1)
let supportedDim = (headDim == 128 || headDim == 256)
let splittableDim = (headDim == 512)
if !supportedDim && !splittableDim {
if !Self.turboWarnedHeadDims.contains(headDim) {
Self.turboWarnedHeadDims.insert(headDim)
print("[TurboKV] head_dim \(headDim) unsupported (requires 128 or 256). Falling back to fp16.")
}
turboQuantEnabled = false
} else if self.offset > turboMinActivationTokens {
let coldEnd = self.offset - turboHotWindowSize
let newColdCount = coldEnd - self.compressedOffset

if newColdCount >= step {
if let fullK = self.keys, let fullV = self.values {
var coldK = fullK[.ellipsis, self.compressedOffset..<coldEnd, 0...]
var coldV = fullV[.ellipsis, self.compressedOffset..<coldEnd, 0...]

if headDim == 512 {
turboSplitHeads = true
let B = coldK.dim(0), H = coldK.dim(1), T = coldK.dim(2)
coldK = coldK.reshaped(B, H * 2, T, 256)
coldV = coldV.reshaped(B, H * 2, T, 256)
}

let (qK, qV) = MLXFast.turboQuantEncode(keys: coldK, values: coldV, bits: 3)

if let existingPK = self.polarKeys, let existingPV = self.polarValues {
self.polarKeys = concatenated([existingPK, qK.0], axis: 2)
self.polarValues = concatenated([existingPV, qV.0], axis: 2)
} else {
self.polarKeys = qK.0
self.polarValues = qV.0
}
self.residualKeys = qK.1
self.residualValues = qV.1
self.compressedOffset += newColdCount

let hotK = fullK[.ellipsis, coldEnd..<self.offset, 0...]
let hotV = fullV[.ellipsis, coldEnd..<self.offset, 0...]
let sparK = MLXArray.zeros(
[keys.dim(0), keys.dim(1), step, keys.dim(3)], dtype: keys.dtype)
let sparV = MLXArray.zeros(
[values.dim(0), values.dim(1), step, values.dim(3)], dtype: values.dtype)
self.keys = concatenated([hotK, sparK], axis: 2)
self.values = concatenated([hotV, sparV], axis: 2)
self.offset = turboHotWindowSize

TurboKVCacheTelemetry.logOnce(
compressedOffset: newColdCount, keys: qK.0, values: qV.0,
headDim: turboSplitHeads ? 256 : keys.dim(-1))
}
}
}
}

let returnedKeys = self.keys![.ellipsis, ..<self.offset, 0...]
let returnedValues = self.values![.ellipsis, ..<self.offset, 0...]

Expand Down
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ let package = Package(
targets: ["IntegrationTestHelpers"]),
],
dependencies: [
.package(url: "https://github.com/ml-explore/mlx-swift", .upToNextMinor(from: "0.31.3")),
.package(url: "https://github.com/joelnishanth/mlx-swift", branch: "feature/turboquant"),
.package(url: "https://github.com/swiftlang/swift-syntax.git", "600.0.0" ..< "604.0.0"),
],
targets: [
Expand Down