84 changes: 84 additions & 0 deletions src/openpi/models/pi0.py
@@ -9,6 +9,7 @@

from openpi.models import model as _model
from openpi.models import pi0_config
from openpi.models import prefix_cache as _prefix_cache
import openpi.models.gemma as _gemma
import openpi.models.siglip as _siglip
from openpi.shared import array_typing as at
@@ -277,3 +278,86 @@ def cond(carry):

x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
return x_0

def sample_actions_with_cache(
self,
rng: at.KeyArrayLike,
observation: _model.Observation,
*,
num_steps: int | at.Int[at.Array, ""] = 10,
noise: at.Float[at.Array, "b ah ad"] | None = None,
prefix_cache: _prefix_cache.PrefixCache | None = None,
) -> tuple[_model.Actions, _prefix_cache.PrefixCache]:
"""Sample actions with optional prefix cache reuse.

This method extends sample_actions() to support caching of the prefix
(image + language) embeddings and the corresponding KV cache. When the
observation has not changed between calls, the prefix embedding and the
prefix forward pass are skipped, which can provide a significant speedup.

Args:
rng: Random key for noise generation.
observation: The observation containing images, state, and prompt.
num_steps: Number of denoising steps.
noise: Optional pre-generated noise.
prefix_cache: Optional cached prefix from a previous call.

Returns:
A tuple of (actions, prefix_cache) where prefix_cache can be passed
to subsequent calls for reuse.
"""
observation = _model.preprocess_observation(None, observation, train=False)
dt = -1.0 / num_steps
batch_size = observation.state.shape[0]
if noise is None:
noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))

if prefix_cache is not None and prefix_cache.is_valid_for(observation):
prefix_tokens = prefix_cache.prefix_tokens
prefix_mask = prefix_cache.prefix_mask
kv_cache = prefix_cache.kv_cache
logger.debug("Prefix cache hit, reusing cached KV cache")
else:
prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
positions = jnp.cumsum(prefix_mask, axis=1) - 1
_, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)

cache_key = _prefix_cache.PrefixCacheKey.from_observation(observation)
prefix_cache = _prefix_cache.PrefixCache(
key=cache_key,
prefix_tokens=prefix_tokens,
prefix_mask=prefix_mask,
prefix_ar_mask=prefix_ar_mask,
kv_cache=kv_cache,
)
logger.debug("Prefix cache miss, computed new KV cache")

def step(carry):
x_t, time = carry
suffix_tokens, suffix_mask, suffix_ar_mask, adarms_cond = self.embed_suffix(
observation, x_t, jnp.broadcast_to(time, batch_size)
)
suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
prefix_attn_mask_step = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
full_attn_mask = jnp.concatenate([prefix_attn_mask_step, suffix_attn_mask], axis=-1)
positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1

(prefix_out, suffix_out), _ = self.PaliGemma.llm(
[None, suffix_tokens],
mask=full_attn_mask,
positions=positions,
kv_cache=kv_cache,
adarms_cond=[None, adarms_cond],
)
assert prefix_out is None
v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])

return x_t + dt * v_t, time + dt

def cond(carry):
x_t, time = carry
return time >= -dt / 2

x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
return x_0, prefix_cache
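
A minimal sketch of how a serving loop could drive the new entry point, reusing the returned cache across consecutive control steps. Here `model` is assumed to be a loaded Pi0 instance, and `make_observation()` / `execute()` are hypothetical placeholders for building an `_model.Observation` from sensor data and sending actions to the robot; they are not part of this diff.

```python
import jax

rng = jax.random.PRNGKey(0)
prefix_cache = None  # no cache available on the first call

for _ in range(100):
    observation = make_observation()  # images + state + tokenized prompt
    rng, step_rng = jax.random.split(rng)
    # On a cache hit (same images and prompt as the previous step) the prefix
    # embedding and prefix KV computation are skipped; otherwise a fresh cache
    # is computed and returned for the next iteration.
    actions, prefix_cache = model.sample_actions_with_cache(
        step_rng,
        observation,
        num_steps=10,
        prefix_cache=prefix_cache,
    )
    execute(actions)
```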
119 changes: 119 additions & 0 deletions src/openpi/models/prefix_cache.py
@@ -0,0 +1,119 @@
"""Prefix cache for KV cache reuse during continuous inference.

This module provides caching infrastructure to avoid redundant computation
of prefix embeddings (images + language) when the observation hasn't changed
between consecutive inference calls.
"""

from __future__ import annotations

import dataclasses
import hashlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import jax

from openpi.models import model as _model

import numpy as np


def fast_array_hash(arr: np.ndarray | jax.Array, sample_size: int = 1000) -> str:
"""Compute a fast hash of an array using sampling for large arrays.

For small arrays, hashes the entire content plus shape and dtype metadata.
For large arrays, hashes evenly spaced samples (plus the same metadata) to
balance speed and collision resistance.

Args:
arr: The array to hash.
sample_size: Number of samples for large arrays.

Returns:
MD5 hex digest of the array content.
"""
arr = np.asarray(arr)
flat = arr.ravel()

if flat.size <= sample_size:
data = flat.tobytes() + str(arr.shape).encode() + str(arr.dtype).encode()
else:
indices = np.linspace(0, flat.size - 1, sample_size, dtype=np.int64)
sampled = flat[indices]
data = sampled.tobytes() + str(arr.shape).encode() + str(arr.dtype).encode()

return hashlib.md5(data).hexdigest()


@dataclasses.dataclass(frozen=True)
class PrefixCacheKey:
"""Cache key based on image and prompt content hashes.

This is used to determine whether the cached prefix can be reused.
The key is immutable and hashable.
"""

image_hashes: tuple[tuple[str, str], ...]
prompt_hash: str | None
batch_size: int

@classmethod
def from_observation(cls, observation: _model.Observation) -> PrefixCacheKey:
"""Compute cache key from an Observation.

Args:
observation: The observation to compute the key from.

Returns:
A PrefixCacheKey that identifies this observation's prefix content
(up to collisions in the sampled hashing).
"""
image_hashes = []
for name in sorted(observation.images.keys()):
image = observation.images[name]
h = fast_array_hash(image)
image_hashes.append((name, h))

prompt_hash = None
if observation.tokenized_prompt is not None:
prompt_hash = fast_array_hash(observation.tokenized_prompt)

batch_size = int(np.asarray(observation.state).shape[0])

return cls(
image_hashes=tuple(image_hashes),
prompt_hash=prompt_hash,
batch_size=batch_size,
)


@dataclasses.dataclass
class PrefixCache:
"""Stores computed prefix results for reuse.

Contains the cache key and all intermediate results needed to skip
prefix computation on subsequent inference calls.
"""

key: PrefixCacheKey

prefix_tokens: jax.Array
prefix_mask: jax.Array
prefix_ar_mask: jax.Array
kv_cache: tuple[jax.Array, jax.Array]

def is_valid_for(self, observation: _model.Observation) -> bool:
"""Check if this cache is valid for the given observation.

Args:
observation: The observation to check against.

Returns:
True if the cache can be reused, False otherwise.
"""
new_key = PrefixCacheKey.from_observation(observation)
return self.key == new_key

def get_prefix_len(self) -> int:
"""Get the length of the cached prefix sequence."""
return int(self.prefix_tokens.shape[1])
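
A small sketch of how the key machinery behaves, useful as a sanity check of the hashing approach. The image name "base_0_rgb" and the constructed keys are illustrative assumptions, not values taken from this diff.

```python
import numpy as np

from openpi.models import prefix_cache as _prefix_cache

# Identical content hashes to the same digest (shape and dtype are included).
img = np.random.default_rng(0).integers(0, 255, (224, 224, 3), dtype=np.uint8)
h = _prefix_cache.fast_array_hash(img)
assert _prefix_cache.fast_array_hash(img.copy()) == h

# Note: arrays larger than `sample_size` are subsampled, so a change that falls
# between sampled elements may not alter the digest; the cache trades exact
# change detection for hashing speed.

# Key equality is what PrefixCache.is_valid_for() ultimately relies on.
key_a = _prefix_cache.PrefixCacheKey(image_hashes=(("base_0_rgb", h),), prompt_hash=None, batch_size=1)
key_b = _prefix_cache.PrefixCacheKey(image_hashes=(("base_0_rgb", h),), prompt_hash=None, batch_size=1)
assert key_a == key_b
```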