Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,5 @@ iOSInjectionProject/
.idea
.vscode

.claude/
default.profraw
10 changes: 0 additions & 10 deletions Libraries/MLXLLM/Models/Gemma3Text.swift
Original file line number Diff line number Diff line change
Expand Up @@ -416,19 +416,9 @@ public class Gemma3TextModel: Module, LLMModel {
return caches
}

/// Handles prompt processing for sequences
public func prepare(
_ input: LMInput, cache: [KVCache], windowSize: Int? = nil
) throws -> PrepareResult {
let promptTokens = input.text.tokens
let promptCount = promptTokens.dim(0)

guard promptCount > 0 else {
print("Warning: Preparing with empty prompt tokens.")
let emptyToken = MLXArray(Int32(0))[0 ..< 0]
return .tokens(.init(tokens: emptyToken))
}

return .tokens(input.text)
}
}
Expand Down
9 changes: 0 additions & 9 deletions Libraries/MLXLLM/Models/Gemma3nText.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1029,15 +1029,6 @@ public class Gemma3nTextModel: Module, LLMModel {
public func prepare(
_ input: LMInput, cache: [KVCache], windowSize: Int? = nil
) throws -> PrepareResult {
let promptTokens = input.text.tokens
let promptCount = promptTokens.dim(0)

guard promptCount > 0 else {
print("Warning: Preparing with empty prompt tokens.")
let emptyToken = MLXArray(Int32(0))[0 ..< 0]
return .tokens(.init(tokens: emptyToken))
}

return .tokens(input.text)
}
}
Expand Down
7 changes: 6 additions & 1 deletion Libraries/MLXLMCommon/AttentionUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ public func attentionWithCacheUpdate(
mask: mask
)
}
if let quantizedKVCache = cache as? QuantizedKVCacheProtocol {
if let turboCache = cache as? TurboQuantKVCache {
return turboCache.compressedAttention(
queries: queries, keys: keys, values: values,
scale: scale, mask: mask
)
} else if let quantizedKVCache = cache as? QuantizedKVCacheProtocol {
let (quantizedKeys, quantizedValues) = quantizedKVCache.updateQuantized(
keys: keys, values: values)
return quantizedScaledDotProductAttention(
Expand Down
13 changes: 12 additions & 1 deletion Libraries/MLXLMCommon/Evaluate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ public struct GenerateParameters: Sendable {
/// Step to begin using a quantized KV cache when kvBits is non-nil (default: 0)
public var quantizedKVStart: Int

/// KV cache compression scheme. Overrides kvBits when set.
/// Built-in: "affine4", "affine8", "turbo2", "turbo3", "turbo4", "turbo4v2", etc.
public var kvScheme: String?

/// Sampling temperature
public var temperature: Float

Expand Down Expand Up @@ -108,6 +112,7 @@ public struct GenerateParameters: Sendable {
kvBits: Int? = nil,
kvGroupSize: Int = 64,
quantizedKVStart: Int = 0,
kvScheme: String? = nil,
temperature: Float = 0.6,
topP: Float = 1.0,
topK: Int = 0,
Expand All @@ -125,6 +130,7 @@ public struct GenerateParameters: Sendable {
self.kvBits = kvBits
self.kvGroupSize = kvGroupSize
self.quantizedKVStart = quantizedKVStart
self.kvScheme = kvScheme
self.temperature = temperature
self.topP = topP
self.topK = topK
Expand Down Expand Up @@ -536,6 +542,7 @@ public struct TokenIterator: TokenIteratorProtocol {
let kvBits: Int?
let kvGroupSize: Int
let quantizedKVStart: Int
let kvScheme: String?

// Internal metrics
var promptPrefillTime: TimeInterval = 0.0
Expand Down Expand Up @@ -564,6 +571,7 @@ public struct TokenIterator: TokenIteratorProtocol {
self.kvBits = parameters.kvBits
self.kvGroupSize = parameters.kvGroupSize
self.quantizedKVStart = parameters.quantizedKVStart
self.kvScheme = parameters.kvScheme

self.promptPrefillTime = try measure {
try prepare(input: .init(text: y), windowSize: parameters.prefillStepSize)
Expand Down Expand Up @@ -597,6 +605,7 @@ public struct TokenIterator: TokenIteratorProtocol {
self.kvBits = parameters.kvBits
self.kvGroupSize = parameters.kvGroupSize
self.quantizedKVStart = parameters.quantizedKVStart
self.kvScheme = parameters.kvScheme

self.promptPrefillTime = try measure {
try prepare(input: input, windowSize: parameters.prefillStepSize)
Expand Down Expand Up @@ -630,6 +639,7 @@ public struct TokenIterator: TokenIteratorProtocol {
self.kvBits = nil
self.kvGroupSize = 64
self.quantizedKVStart = 0
self.kvScheme = nil

self.promptPrefillTime = try measure {
try prepare(input: input, windowSize: prefillStepSize)
Expand Down Expand Up @@ -680,7 +690,8 @@ public struct TokenIterator: TokenIteratorProtocol {
cache: &cache,
kvBits: kvBits,
kvGroupSize: kvGroupSize,
quantizedKVStart: quantizedKVStart
quantizedKVStart: quantizedKVStart,
kvScheme: kvScheme
)

return convertToToken(logits: result.logits)
Expand Down
71 changes: 63 additions & 8 deletions Libraries/MLXLMCommon/KVCache.swift
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TurboQuant cache cannot be restored correctly from prompt cache. TurboQuantKVCache is not represented in cacheClassName, so it is serialized as KVCache. On load it is restored as KVCacheSimple, but TurboQuant compressed state carries 3/4 arrays rather than KVCacheSimple’s required 2 arrays, leading to invalid restoration behavior. Add explicit TurboQuant class name mapping and restore path.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in nex push. Added TurboQuantKVCache to cacheClassName and restoreCacheFromMetaState. metaState now carries bits, keyBits, valueBits, and seed so the cache can be reconstructed correctly on load.

Original file line number Diff line number Diff line change
Expand Up @@ -1776,27 +1776,82 @@ public func quantizedScaledDotProductAttention(
/// - kvBits: Number of bits for quantization (nil = no quantization)
/// - kvGroupSize: Group size for quantization
/// - quantizedKVStart: Token count threshold to begin quantizing
/// Resolve a kvScheme string to cache configuration.
/// Returns nil for unrecognized schemes.
public func resolveKVScheme(_ scheme: String?) -> (type: String, keyBits: Int, valueBits: Int)? {
switch scheme {
case "affine4": return ("affine", 4, 4)
case "affine8": return ("affine", 8, 8)
case "turbo2": return ("turbo", 2, 2)
case "turbo3": return ("turbo", 3, 3)
case "turbo4": return ("turbo", 4, 4)
case "turbo8": return ("turbo", 8, 8)
case "turbo4v2": return ("turbo", 4, 2) // asymmetric: 4-bit K, 2-bit V
case "turbo4v3": return ("turbo", 4, 3)
case "turbo0v2": return ("turbo", 0, 2) // raw K (FP16), 2-bit V
case "turbo0v4": return ("turbo", 0, 4) // raw K (FP16), 4-bit V
default: return nil
}
}

public func maybeQuantizeKVCache(
cache: inout [KVCache],
kvBits: Int?,
kvGroupSize: Int = 64,
quantizedKVStart: Int = 0
quantizedKVStart: Int = 0,
kvScheme: String? = nil
) {
guard let kvBits = kvBits,
!cache.isEmpty,
// TurboQuant schemes: compress KVCacheSimple layers, skip MambaCache/CacheList
if let scheme = kvScheme, let resolved = resolveKVScheme(scheme), resolved.type == "turbo" {
guard !cache.isEmpty else { return }
// Check if any KVCacheSimple layer is ready for compression
let hasCompressible = cache.contains { $0 is KVCacheSimple && $0.offset > quantizedKVStart }
guard hasCompressible else { return }

for i in 0 ..< cache.count {
if cache[i] is KVCacheSimple, cache[i].offset > quantizedKVStart {
let turbo = TurboQuantKVCache(
bits: resolved.keyBits,
keyBits: resolved.keyBits,
valueBits: resolved.valueBits
)
// Transfer existing KV data, trimmed to actual offset
let actualOffset = cache[i].offset
let state = cache[i].innerState()
if state.count >= 2, actualOffset > 0 {
let keys = state[0][.ellipsis, ..<actualOffset, 0...]
let values = state[1][.ellipsis, ..<actualOffset, 0...]
let _ = turbo.update(keys: keys, values: values)
}
cache[i] = turbo
}
}
return
}

// Affine quantization (existing behavior)
let effectiveBits: Int
let effectiveGroupSize: Int
if let scheme = kvScheme, let resolved = resolveKVScheme(scheme) {
effectiveBits = resolved.keyBits
effectiveGroupSize = kvGroupSize
} else if let kvBits {
effectiveBits = kvBits
effectiveGroupSize = kvGroupSize
} else {
return
}

guard !cache.isEmpty,
!(cache[0] is QuantizedKVCache),
cache[0].offset > quantizedKVStart
else {
return
}

for i in 0 ..< cache.count {
// Handle cache types that support quantization
if let simpleCache = cache[i] as? KVCacheSimple {
cache[i] = simpleCache.toQuantized(groupSize: kvGroupSize, bits: kvBits)
cache[i] = simpleCache.toQuantized(groupSize: effectiveGroupSize, bits: effectiveBits)
}
// TODO: RotatingKVCache.toQuantized() is not implemented yet, like in Python.
// When implemented, add: else if let rotatingCache = cache[i] as? RotatingKVCache { ... }
// MambaCache and CacheList don't use traditional KV quantization
}
}
Loading