diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py
index b84c9d650..eabbc0ff9 100644
--- a/mlx_lm/models/cache.py
+++ b/mlx_lm/models/cache.py
@@ -1312,7 +1312,7 @@ def meta_state(self, v):
             int,
             v[:3],
         )
-        self.rotated = bool(v[3])
+        self.rotated = v[3] == "True"  # bool("False") is True; compare text
 
     def is_trimmable(self):
         return self._offset < self.max_size
diff --git a/tests/test_prompt_cache.py b/tests/test_prompt_cache.py
index bd1bc75ba..940a71177 100644
--- a/tests/test_prompt_cache.py
+++ b/tests/test_prompt_cache.py
@@ -98,6 +98,35 @@ def test_save_load_rotating_cache(self):
             self.assertTrue(mx.array_equal(k, lk))
             self.assertTrue(mx.array_equal(v, lv))
 
+    def test_save_load_batch_rotating_cache(self):
+        cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
+
+        # rotated=False must survive save/load (regression: bool("False") is True)
+        cache = [BatchRotatingKVCache(max_size=10, left_padding=[0])]
+        cache[0].update_and_fetch(
+            mx.random.uniform(shape=(1, 4, 3, 4)),
+            mx.random.uniform(shape=(1, 4, 3, 4)),
+        )
+        self.assertFalse(cache[0].rotated)
+
+        save_prompt_cache(cache_file, cache)
+        loaded = load_prompt_cache(cache_file)
+        self.assertEqual(cache[0].rotated, loaded[0].rotated)
+        self.assertEqual(cache[0].max_size, loaded[0].max_size)
+
+        # rotated=True must survive too; 6 updates exceed max_size=4 to force rotation
+        cache = [BatchRotatingKVCache(max_size=4, left_padding=[0])]
+        for _ in range(6):
+            cache[0].update_and_fetch(
+                mx.random.uniform(shape=(1, 4, 1, 4)),
+                mx.random.uniform(shape=(1, 4, 1, 4)),
+            )
+        self.assertTrue(cache[0].rotated)
+
+        save_prompt_cache(cache_file, cache)
+        loaded = load_prompt_cache(cache_file)
+        self.assertEqual(cache[0].rotated, loaded[0].rotated)
+
     def test_save_load_mixed_cache(self):
         cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
 