Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions mlx_lm/models/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,25 @@ def merge(cls, caches):
def empty(self):
return self.cache[0] is None

def is_trimmable(self):
return True

def trim(self, n):
"""Trim the cache by n tokens.

ArraysCache holds compressed recurrent state that cannot be partially
rolled back. When trimming is required, the cache is fully reset so
the recurrent layers recompute from scratch. KVCache layers in the
same hybrid model still benefit from their own trim.

For exact-match reuse (n=0), this method is not called and the full
recurrent state is preserved — the common server case.
"""
if n <= 0:
return 0
self.cache = [None] * len(self.cache)
return n

@property
def nbytes(self):
return sum(c.nbytes for c in self.cache if c is not None)
Expand Down
13 changes: 9 additions & 4 deletions tests/test_prompt_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,22 +216,27 @@ def test_trim_cache(self):
num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 3)

# Can't trim arrays cache
# Trimming arrays cache resets recurrent state
cache = [ArraysCache(size=2) for _ in range(2)]
for c in cache:
c[0] = mx.zeros((5, 5))
c[1] = mx.zeros((5, 5))
num_trimmed = trim_prompt_cache(cache, 7)
self.assertEqual(num_trimmed, 0)
self.assertEqual(num_trimmed, 7)
for c in cache:
self.assertTrue(c.empty())

# All cache's have to be trimmable
# Hybrid cache (ArraysCache + KVCache) is trimmable
cache = [ArraysCache(size=2), KVCache()]
cache[0][0] = mx.zeros((5, 5))
cache[0][1] = mx.zeros((5, 5))
x = mx.random.uniform(shape=(1, 8, 10, 4))
cache[1].update_and_fetch(x, x)
num_trimmed = trim_prompt_cache(cache, 1)
self.assertEqual(num_trimmed, 0)
self.assertEqual(num_trimmed, 1)
# ArraysCache resets, KVCache trims normally
self.assertTrue(cache[0].empty())
self.assertEqual(cache[1].offset, 9)

cache = [RotatingKVCache(max_size=6) for _ in range(2)]
for c in cache:
Expand Down