diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index b84c9d650..708203782 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -723,6 +723,25 @@ def merge(cls, caches): def empty(self): return self.cache[0] is None + def is_trimmable(self): + return True + + def trim(self, n): + """Trim the cache by n tokens. + + ArraysCache holds compressed recurrent state that cannot be partially + rolled back. When trimming is required, the cache is fully reset so + the recurrent layers recompute from scratch. KVCache layers in the + same hybrid model still benefit from their own trim. + + For exact-match reuse (n=0), this method is not called and the full + recurrent state is preserved — the common server case. + """ + if n <= 0: + return 0 + self.cache = [None] * len(self.cache) + return n + @property def nbytes(self): return sum(c.nbytes for c in self.cache if c is not None) diff --git a/tests/test_prompt_cache.py b/tests/test_prompt_cache.py index bd1bc75ba..8ebf4e89d 100644 --- a/tests/test_prompt_cache.py +++ b/tests/test_prompt_cache.py @@ -216,22 +216,27 @@ def test_trim_cache(self): num_trimmed = trim_prompt_cache(cache, 4) self.assertEqual(num_trimmed, 3) - # Can't trim arrays cache + # Trimming arrays cache resets recurrent state cache = [ArraysCache(size=2) for _ in range(2)] for c in cache: c[0] = mx.zeros((5, 5)) c[1] = mx.zeros((5, 5)) num_trimmed = trim_prompt_cache(cache, 7) - self.assertEqual(num_trimmed, 0) + self.assertEqual(num_trimmed, 7) + for c in cache: + self.assertTrue(c.empty()) - # All cache's have to be trimmable + # Hybrid cache (ArraysCache + KVCache) is trimmable cache = [ArraysCache(size=2), KVCache()] cache[0][0] = mx.zeros((5, 5)) cache[0][1] = mx.zeros((5, 5)) x = mx.random.uniform(shape=(1, 8, 10, 4)) cache[1].update_and_fetch(x, x) num_trimmed = trim_prompt_cache(cache, 1) - self.assertEqual(num_trimmed, 0) + self.assertEqual(num_trimmed, 1) + # ArraysCache resets, KVCache trims normally + self.assertTrue(cache[0].empty()) + self.assertEqual(cache[1].offset, 9) cache = [RotatingKVCache(max_size=6) for _ in range(2)] for c in cache: