From b64ff9d9e8ae4d5226a600574f1d79c6693251de Mon Sep 17 00:00:00 2001
From: enjoysport2022 <946691288@qq.com>
Date: Fri, 17 Apr 2026 15:10:28 +0000
Subject: [PATCH] feat: add prefix cache for KV cache reuse during inference

Add caching infrastructure to avoid redundant computation of prefix
embeddings (images + language) when the observation hasn't changed between
consecutive inference calls. This can provide a significant speedup in
continuous robot control scenarios.

Changes:
- Add prefix_cache.py with PrefixCacheKey and PrefixCache classes
- Add sample_actions_with_cache() method to the Pi0 model
- Add enable_prefix_cache parameter to Policy and create_trained_policy
- Add clear_cache() and reset() methods to Policy

Usage:
    policy = create_trained_policy(config, checkpoint, enable_prefix_cache=True)
    result = policy.infer(obs)  # Auto-caches prefix
    policy.reset()  # Clear cache for new episode
---
 src/openpi/models/pi0.py               |  84 ++++++++
 src/openpi/models/prefix_cache.py      | 119 +++++++++++
 src/openpi/models/prefix_cache_test.py | 278 +++++++++++++++++++++++++
 src/openpi/policies/policy.py          |  39 +++-
 src/openpi/policies/policy_config.py   |   7 +
 5 files changed, 526 insertions(+), 1 deletion(-)
 create mode 100644 src/openpi/models/prefix_cache.py
 create mode 100644 src/openpi/models/prefix_cache_test.py

diff --git a/src/openpi/models/pi0.py b/src/openpi/models/pi0.py
index ae7c4590f3..d1b54da25d 100644
--- a/src/openpi/models/pi0.py
+++ b/src/openpi/models/pi0.py
@@ -9,6 +9,7 @@
 from openpi.models import model as _model
 from openpi.models import pi0_config
+from openpi.models import prefix_cache as _prefix_cache
 import openpi.models.gemma as _gemma
 import openpi.models.siglip as _siglip
 from openpi.shared import array_typing as at
@@ -277,3 +278,86 @@ def cond(carry):
         x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
         return x_0
+
+    def sample_actions_with_cache(
+        self,
+        rng: at.KeyArrayLike,
+        observation: _model.Observation,
+        *,
+        num_steps: int | at.Int[at.Array, ""] = 10,
+        noise: at.Float[at.Array, "b ah ad"] | None = None,
+        prefix_cache: _prefix_cache.PrefixCache | None = None,
+    ) -> tuple[_model.Actions, _prefix_cache.PrefixCache]:
+        """Sample actions with optional prefix cache reuse.
+
+        This method extends sample_actions() to support caching of the prefix
+        (image + language) embeddings and KV cache. When the observation hasn't
+        changed between calls, this can provide a significant speedup by avoiding
+        redundant computation.
+
+        Args:
+            rng: Random key for noise generation.
+            observation: The observation containing images, state, and prompt.
+            num_steps: Number of denoising steps.
+            noise: Optional pre-generated noise.
+            prefix_cache: Optional cached prefix from a previous call.
+
+        Returns:
+            A tuple of (actions, prefix_cache) where prefix_cache can be passed
+            to subsequent calls for reuse.
+ """ + observation = _model.preprocess_observation(None, observation, train=False) + dt = -1.0 / num_steps + batch_size = observation.state.shape[0] + if noise is None: + noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim)) + + if prefix_cache is not None and prefix_cache.is_valid_for(observation): + prefix_tokens = prefix_cache.prefix_tokens + prefix_mask = prefix_cache.prefix_mask + kv_cache = prefix_cache.kv_cache + logger.debug("Prefix cache hit, reusing cached KV cache") + else: + prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation) + prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask) + positions = jnp.cumsum(prefix_mask, axis=1) - 1 + _, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions) + + cache_key = _prefix_cache.PrefixCacheKey.from_observation(observation) + prefix_cache = _prefix_cache.PrefixCache( + key=cache_key, + prefix_tokens=prefix_tokens, + prefix_mask=prefix_mask, + prefix_ar_mask=prefix_ar_mask, + kv_cache=kv_cache, + ) + logger.debug("Prefix cache miss, computed new KV cache") + + def step(carry): + x_t, time = carry + suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix( + observation, x_t, jnp.broadcast_to(time, batch_size) + ) + suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask) + prefix_attn_mask_step = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1]) + full_attn_mask = jnp.concatenate([prefix_attn_mask_step, suffix_attn_mask], axis=-1) + positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1 + + (prefix_out, suffix_out), _ = self.PaliGemma.llm( + [None, suffix_tokens], + mask=full_attn_mask, + positions=positions, + kv_cache=kv_cache, + adarms_cond=[None, adarms_cond], + ) + assert prefix_out is None + v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :]) + + return x_t + dt * v_t, time + dt + + def cond(carry): + x_t, time = carry + return time >= -dt / 2 + + x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0)) + return x_0, prefix_cache diff --git a/src/openpi/models/prefix_cache.py b/src/openpi/models/prefix_cache.py new file mode 100644 index 0000000000..71c508f99d --- /dev/null +++ b/src/openpi/models/prefix_cache.py @@ -0,0 +1,119 @@ +"""Prefix cache for KV cache reuse during continuous inference. + +This module provides caching infrastructure to avoid redundant computation +of prefix embeddings (images + language) when the observation hasn't changed +between consecutive inference calls. +""" + +from __future__ import annotations + +import dataclasses +import hashlib +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import jax + + from openpi.models import model as _model + +import numpy as np + + +def fast_array_hash(arr: np.ndarray | jax.Array, sample_size: int = 1000) -> str: + """Compute a fast hash of an array using sampling for large arrays. + + For small arrays, computes the hash of the entire content plus metadata. + For large arrays, samples uniformly distributed points to balance + speed and collision resistance. + + Args: + arr: The array to hash. + sample_size: Number of samples for large arrays. + + Returns: + MD5 hex digest of the array content. 
+ """ + arr = np.asarray(arr) + flat = arr.ravel() + + if flat.size <= sample_size: + data = flat.tobytes() + str(arr.shape).encode() + str(arr.dtype).encode() + else: + indices = np.linspace(0, flat.size - 1, sample_size, dtype=np.int64) + sampled = flat[indices] + data = sampled.tobytes() + str(arr.shape).encode() + str(arr.dtype).encode() + + return hashlib.md5(data).hexdigest() + + +@dataclasses.dataclass(frozen=True) +class PrefixCacheKey: + """Cache key based on image and prompt content hashes. + + This is used to determine whether the cached prefix can be reused. + The key is immutable and hashable. + """ + + image_hashes: tuple[tuple[str, str], ...] + prompt_hash: str | None + batch_size: int + + @classmethod + def from_observation(cls, observation: _model.Observation) -> PrefixCacheKey: + """Compute cache key from an Observation. + + Args: + observation: The observation to compute the key from. + + Returns: + A PrefixCacheKey that uniquely identifies this observation's prefix. + """ + image_hashes = [] + for name in sorted(observation.images.keys()): + image = observation.images[name] + h = fast_array_hash(image) + image_hashes.append((name, h)) + + prompt_hash = None + if observation.tokenized_prompt is not None: + prompt_hash = fast_array_hash(observation.tokenized_prompt) + + batch_size = int(np.asarray(observation.state).shape[0]) + + return cls( + image_hashes=tuple(image_hashes), + prompt_hash=prompt_hash, + batch_size=batch_size, + ) + + +@dataclasses.dataclass +class PrefixCache: + """Stores computed prefix results for reuse. + + Contains the cache key and all intermediate results needed to skip + prefix computation on subsequent inference calls. + """ + + key: PrefixCacheKey + + prefix_tokens: jax.Array + prefix_mask: jax.Array + prefix_ar_mask: jax.Array + kv_cache: tuple[jax.Array, jax.Array] + + def is_valid_for(self, observation: _model.Observation) -> bool: + """Check if this cache is valid for the given observation. + + Args: + observation: The observation to check against. + + Returns: + True if the cache can be reused, False otherwise. 
+ """ + new_key = PrefixCacheKey.from_observation(observation) + return self.key == new_key + + def get_prefix_len(self) -> int: + """Get the length of the cached prefix sequence.""" + return int(self.prefix_tokens.shape[1]) diff --git a/src/openpi/models/prefix_cache_test.py b/src/openpi/models/prefix_cache_test.py new file mode 100644 index 0000000000..9740f3d6a8 --- /dev/null +++ b/src/openpi/models/prefix_cache_test.py @@ -0,0 +1,278 @@ +"""Tests for prefix_cache module.""" + +import numpy as np +import pytest + +from openpi.models import prefix_cache + + +class TestFastArrayHash: + """Tests for _fast_array_hash function.""" + + def test_same_array_same_hash(self): + """Same array content produces same hash.""" + arr = np.random.rand(100, 100) + hash1 = prefix_cache.fast_array_hash(arr) + hash2 = prefix_cache.fast_array_hash(arr) + assert hash1 == hash2 + + def test_different_array_different_hash(self): + """Different array content produces different hash.""" + arr1 = np.zeros((100, 100)) + arr2 = np.ones((100, 100)) + hash1 = prefix_cache.fast_array_hash(arr1) + hash2 = prefix_cache.fast_array_hash(arr2) + assert hash1 != hash2 + + def test_small_array_full_hash(self): + """Small arrays use full content for hash.""" + arr = np.array([1, 2, 3]) + hash1 = prefix_cache.fast_array_hash(arr, sample_size=1000) + hash2 = prefix_cache.fast_array_hash(arr, sample_size=1000) + assert hash1 == hash2 + + def test_large_array_sampled_hash(self): + """Large arrays use sampling for hash.""" + arr = np.random.rand(10000) + hash1 = prefix_cache.fast_array_hash(arr, sample_size=100) + hash2 = prefix_cache.fast_array_hash(arr, sample_size=100) + assert hash1 == hash2 + + def test_different_shapes_different_hash(self): + """Arrays with same values but different shapes have different hashes.""" + arr1 = np.ones((2, 3)) + arr2 = np.ones((3, 2)) + hash1 = prefix_cache.fast_array_hash(arr1) + hash2 = prefix_cache.fast_array_hash(arr2) + assert hash1 != hash2 + + +class TestPrefixCacheKey: + """Tests for PrefixCacheKey.""" + + def test_equality(self): + """Same content produces equal keys.""" + key1 = prefix_cache.PrefixCacheKey( + image_hashes=(("img1", "abc123"),), + prompt_hash="prompt_hash", + batch_size=1, + ) + key2 = prefix_cache.PrefixCacheKey( + image_hashes=(("img1", "abc123"),), + prompt_hash="prompt_hash", + batch_size=1, + ) + assert key1 == key2 + + def test_different_image_hash(self): + """Different image hashes produce different keys.""" + key1 = prefix_cache.PrefixCacheKey( + image_hashes=(("img1", "abc123"),), + prompt_hash="prompt_hash", + batch_size=1, + ) + key2 = prefix_cache.PrefixCacheKey( + image_hashes=(("img1", "def456"),), + prompt_hash="prompt_hash", + batch_size=1, + ) + assert key1 != key2 + + def test_different_prompt_hash(self): + """Different prompt hashes produce different keys.""" + key1 = prefix_cache.PrefixCacheKey( + image_hashes=(("img1", "abc123"),), + prompt_hash="prompt_hash_1", + batch_size=1, + ) + key2 = prefix_cache.PrefixCacheKey( + image_hashes=(("img1", "abc123"),), + prompt_hash="prompt_hash_2", + batch_size=1, + ) + assert key1 != key2 + + def test_different_batch_size(self): + """Different batch sizes produce different keys.""" + key1 = prefix_cache.PrefixCacheKey( + image_hashes=(("img1", "abc123"),), + prompt_hash="prompt_hash", + batch_size=1, + ) + key2 = prefix_cache.PrefixCacheKey( + image_hashes=(("img1", "abc123"),), + prompt_hash="prompt_hash", + batch_size=2, + ) + assert key1 != key2 + + def test_hashable(self): + """Keys are hashable and can be 
+        key = prefix_cache.PrefixCacheKey(
+            image_hashes=(("img1", "abc123"),),
+            prompt_hash="prompt_hash",
+            batch_size=1,
+        )
+        key_set = {key}
+        assert key in key_set
+
+
+class TestPrefixCacheKeyFromObservation:
+    """Tests for PrefixCacheKey.from_observation."""
+
+    @pytest.fixture
+    def mock_observation(self):
+        """Create a mock observation for testing."""
+        from dataclasses import dataclass
+
+        @dataclass
+        class MockObservation:
+            images: dict
+            image_masks: dict
+            state: np.ndarray
+            tokenized_prompt: np.ndarray | None = None
+            tokenized_prompt_mask: np.ndarray | None = None
+
+        return MockObservation
+
+    def test_same_observation_same_key(self, mock_observation):
+        """Same observation produces same key."""
+        obs = mock_observation(
+            images={"img1": np.zeros((1, 224, 224, 3))},
+            image_masks={"img1": np.array([True])},
+            state=np.zeros((1, 14)),
+            tokenized_prompt=np.array([[1, 2, 3]]),
+            tokenized_prompt_mask=np.array([[True, True, True]]),
+        )
+        key1 = prefix_cache.PrefixCacheKey.from_observation(obs)
+        key2 = prefix_cache.PrefixCacheKey.from_observation(obs)
+        assert key1 == key2
+
+    def test_different_image_different_key(self, mock_observation):
+        """Different image produces different key."""
+        obs1 = mock_observation(
+            images={"img1": np.zeros((1, 224, 224, 3))},
+            image_masks={"img1": np.array([True])},
+            state=np.zeros((1, 14)),
+        )
+        obs2 = mock_observation(
+            images={"img1": np.ones((1, 224, 224, 3))},
+            image_masks={"img1": np.array([True])},
+            state=np.zeros((1, 14)),
+        )
+        key1 = prefix_cache.PrefixCacheKey.from_observation(obs1)
+        key2 = prefix_cache.PrefixCacheKey.from_observation(obs2)
+        assert key1 != key2
+
+    def test_no_prompt_key(self, mock_observation):
+        """Observation without prompt has None prompt_hash."""
+        obs = mock_observation(
+            images={"img1": np.zeros((1, 224, 224, 3))},
+            image_masks={"img1": np.array([True])},
+            state=np.zeros((1, 14)),
+            tokenized_prompt=None,
+        )
+        key = prefix_cache.PrefixCacheKey.from_observation(obs)
+        assert key.prompt_hash is None
+
+    def test_multiple_images_sorted(self, mock_observation):
+        """Multiple images are sorted by name for consistent ordering."""
+        obs1 = mock_observation(
+            images={
+                "img_b": np.zeros((1, 224, 224, 3)),
+                "img_a": np.ones((1, 224, 224, 3)),
+            },
+            image_masks={
+                "img_b": np.array([True]),
+                "img_a": np.array([True]),
+            },
+            state=np.zeros((1, 14)),
+        )
+        obs2 = mock_observation(
+            images={
+                "img_a": np.ones((1, 224, 224, 3)),
+                "img_b": np.zeros((1, 224, 224, 3)),
+            },
+            image_masks={
+                "img_a": np.array([True]),
+                "img_b": np.array([True]),
+            },
+            state=np.zeros((1, 14)),
+        )
+        key1 = prefix_cache.PrefixCacheKey.from_observation(obs1)
+        key2 = prefix_cache.PrefixCacheKey.from_observation(obs2)
+        assert key1 == key2
+
+
+class TestPrefixCache:
+    """Tests for PrefixCache."""
+
+    @pytest.fixture
+    def mock_observation(self):
+        """Create a mock observation for testing."""
+        from dataclasses import dataclass
+
+        @dataclass
+        class MockObservation:
+            images: dict
+            image_masks: dict
+            state: np.ndarray
+            tokenized_prompt: np.ndarray | None = None
+            tokenized_prompt_mask: np.ndarray | None = None
+
+        return MockObservation
+
+    def test_is_valid_for_same_observation(self, mock_observation):
+        """Cache is valid for same observation."""
+        obs = mock_observation(
+            images={"img1": np.zeros((1, 224, 224, 3))},
+            image_masks={"img1": np.array([True])},
+            state=np.zeros((1, 14)),
+        )
+        key = prefix_cache.PrefixCacheKey.from_observation(obs)
+        cache = prefix_cache.PrefixCache(
+            key=key,
+            prefix_tokens=np.zeros((1, 100, 256)),
+            prefix_mask=np.ones((1, 100), dtype=bool),
+            prefix_ar_mask=np.zeros((100,), dtype=bool),
+            kv_cache=(np.zeros((18, 1, 100, 1, 256)), np.zeros((18, 1, 100, 1, 256))),
+        )
+        assert cache.is_valid_for(obs)
+
+    def test_is_valid_for_different_observation(self, mock_observation):
+        """Cache is invalid for different observation."""
+        obs1 = mock_observation(
+            images={"img1": np.zeros((1, 224, 224, 3))},
+            image_masks={"img1": np.array([True])},
+            state=np.zeros((1, 14)),
+        )
+        obs2 = mock_observation(
+            images={"img1": np.ones((1, 224, 224, 3))},
+            image_masks={"img1": np.array([True])},
+            state=np.zeros((1, 14)),
+        )
+        key = prefix_cache.PrefixCacheKey.from_observation(obs1)
+        cache = prefix_cache.PrefixCache(
+            key=key,
+            prefix_tokens=np.zeros((1, 100, 256)),
+            prefix_mask=np.ones((1, 100), dtype=bool),
+            prefix_ar_mask=np.zeros((100,), dtype=bool),
+            kv_cache=(np.zeros((18, 1, 100, 1, 256)), np.zeros((18, 1, 100, 1, 256))),
+        )
+        assert not cache.is_valid_for(obs2)
+
+    def test_get_prefix_len(self):
+        """get_prefix_len returns correct length."""
+        key = prefix_cache.PrefixCacheKey(
+            image_hashes=(),
+            prompt_hash=None,
+            batch_size=1,
+        )
+        cache = prefix_cache.PrefixCache(
+            key=key,
+            prefix_tokens=np.zeros((1, 150, 256)),
+            prefix_mask=np.ones((1, 150), dtype=bool),
+            prefix_ar_mask=np.zeros((150,), dtype=bool),
+            kv_cache=(np.zeros((18, 1, 150, 1, 256)), np.zeros((18, 1, 150, 1, 256))),
+        )
+        assert cache.get_prefix_len() == 150
diff --git a/src/openpi/policies/policy.py b/src/openpi/policies/policy.py
index b9b708bdca..71450deb22 100644
--- a/src/openpi/policies/policy.py
+++ b/src/openpi/policies/policy.py
@@ -15,6 +15,7 @@
 from openpi import transforms as _transforms
 from openpi.models import model as _model
+from openpi.models import prefix_cache as _prefix_cache
 from openpi.shared import array_typing as at
 from openpi.shared import nnx_utils
@@ -33,6 +34,7 @@ def __init__(
         metadata: dict[str, Any] | None = None,
         pytorch_device: str = "cpu",
         is_pytorch: bool = False,
+        enable_prefix_cache: bool = False,
     ):
         """Initialize the Policy.
@@ -46,6 +48,10 @@ def __init__(
             pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda:0").
                 Only relevant when is_pytorch=True.
             is_pytorch: Whether the model is a PyTorch model. If False, assumes JAX model.
+            enable_prefix_cache: Whether to enable prefix caching for JAX models.
+                When enabled, the policy caches prefix embeddings and KV cache
+                to avoid redundant computation when the observation hasn't changed.
+                Only effective for JAX models (ignored for PyTorch).
""" self._model = model self._input_transform = _transforms.compose(transforms) @@ -54,6 +60,8 @@ def __init__( self._metadata = metadata or {} self._is_pytorch_model = is_pytorch self._pytorch_device = pytorch_device + self._enable_prefix_cache = enable_prefix_cache and not is_pytorch + self._prefix_cache: _prefix_cache.PrefixCache | None = None if self._is_pytorch_model: self._model = self._model.to(pytorch_device) @@ -62,6 +70,8 @@ def __init__( else: # JAX model setup self._sample_actions = nnx_utils.module_jit(model.sample_actions) + if self._enable_prefix_cache: + self._sample_actions_with_cache = nnx_utils.module_jit(model.sample_actions_with_cache) self._rng = rng or jax.random.key(0) @override @@ -89,9 +99,18 @@ def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict: # type: observation = _model.Observation.from_dict(inputs) start_time = time.monotonic() + + if self._enable_prefix_cache: + sample_kwargs["prefix_cache"] = self._prefix_cache + actions, self._prefix_cache = self._sample_actions_with_cache( + sample_rng_or_pytorch_device, observation, **sample_kwargs + ) + else: + actions = self._sample_actions(sample_rng_or_pytorch_device, observation, **sample_kwargs) + outputs = { "state": inputs["state"], - "actions": self._sample_actions(sample_rng_or_pytorch_device, observation, **sample_kwargs), + "actions": actions, } model_time = time.monotonic() - start_time if self._is_pytorch_model: @@ -103,12 +122,30 @@ def infer(self, obs: dict, *, noise: np.ndarray | None = None) -> dict: # type: outputs["policy_timing"] = { "infer_ms": model_time * 1000, } + if self._enable_prefix_cache: + outputs["policy_timing"]["cache_enabled"] = True return outputs @property def metadata(self) -> dict[str, Any]: return self._metadata + def clear_cache(self) -> None: + """Clear the prefix cache. + + Call this when starting a new episode or when you want to force + recomputation of the prefix embeddings. + """ + self._prefix_cache = None + + def reset(self) -> None: + """Reset the policy state. + + This clears the prefix cache and can be extended by subclasses + to reset additional state. + """ + self.clear_cache() + class PolicyRecorder(_base_policy.BasePolicy): """Records the policy's behavior to disk.""" diff --git a/src/openpi/policies/policy_config.py b/src/openpi/policies/policy_config.py index 6570df05ed..9ca7ded854 100644 --- a/src/openpi/policies/policy_config.py +++ b/src/openpi/policies/policy_config.py @@ -22,6 +22,7 @@ def create_trained_policy( default_prompt: str | None = None, norm_stats: dict[str, transforms.NormStats] | None = None, pytorch_device: str | None = None, + enable_prefix_cache: bool = False, ) -> _policy.Policy: """Create a policy from a trained checkpoint. @@ -37,6 +38,11 @@ def create_trained_policy( from the checkpoint directory. pytorch_device: Device to use for PyTorch models (e.g., "cpu", "cuda", "cuda:0"). If None and is_pytorch=True, will use "cuda" if available, otherwise "cpu". + enable_prefix_cache: Whether to enable prefix caching for JAX models. When enabled, + the policy caches prefix embeddings and KV cache to avoid redundant + computation when observation hasn't changed between inference calls. + This can provide significant speedup in continuous control scenarios. + Only effective for JAX models (ignored for PyTorch). 
 
     Note:
         The function automatically detects whether the model is PyTorch-based by checking for the
@@ -91,4 +97,5 @@
         metadata=train_config.policy_metadata,
         is_pytorch=is_pytorch,
         pytorch_device=pytorch_device if is_pytorch else None,
+        enable_prefix_cache=enable_prefix_cache,
     )
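
--
A minimal end-to-end usage sketch of the surface added by this patch, in the
style of the commit message's Usage block. `train_config`, `checkpoint_dir`,
`get_observation()`, and `apply_actions()` are placeholders for the caller's
own config and robot I/O; only `enable_prefix_cache`, `infer()`, and `reset()`
come from this change.

    from openpi.policies import policy_config

    # Placeholders: supply your own train config and checkpoint directory.
    policy = policy_config.create_trained_policy(
        train_config, checkpoint_dir, enable_prefix_cache=True
    )

    for _ in range(num_episodes):
        policy.reset()  # drop the cached prefix so a stale KV cache is never reused
        for _ in range(max_steps):
            obs = get_observation()  # placeholder: images + state + prompt dict
            result = policy.infer(obs)  # prefix recomputed only if images/prompt changed
            apply_actions(result["actions"])  # placeholder actuation call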