From 66e3d77a8d069abacc402c08b75f124116719e62 Mon Sep 17 00:00:00 2001
From: Gason
Date: Tue, 17 Mar 2026 17:29:47 +0800
Subject: [PATCH 1/3] Fix Git user and push via SSH

---
 examples/download_dataset.py                  |   3 +
 .../libero/convert_libero_data_to_lerobot.py  | 111 ++--
 prune_distill/README.md                       |  43 ++
 prune_distill/train_prefix_distill.py         | 519 ++++++++++++++++++
 pyproject.toml                                |   1 +
 scripts/compute_norm_stats.py                 |   9 +-
 scripts/train_pytorch.py                      |   7 +
 src/openpi/models/gemma_pruning.py            | 440 +++++++++++++++
 src/openpi/models_pytorch/gemma_pytorch.py    |   2 +-
 .../models_pytorch/preprocessing_pytorch.py   |   7 +-
 src/openpi/training/config.py                 |  59 +-
 11 files changed, 1147 insertions(+), 54 deletions(-)
 create mode 100644 examples/download_dataset.py
 create mode 100644 prune_distill/README.md
 create mode 100644 prune_distill/train_prefix_distill.py
 create mode 100644 src/openpi/models/gemma_pruning.py

diff --git a/examples/download_dataset.py b/examples/download_dataset.py
new file mode 100644
index 0000000000..0138216a6b
--- /dev/null
+++ b/examples/download_dataset.py
@@ -0,0 +1,3 @@
+from datasets import load_dataset
+
+dataset = load_dataset("openvla/modified_libero_rlds", cache_dir="./")
diff --git a/examples/libero/convert_libero_data_to_lerobot.py b/examples/libero/convert_libero_data_to_lerobot.py
index 51db6f138e..c83a348b9e 100644
--- a/examples/libero/convert_libero_data_to_lerobot.py
+++ b/examples/libero/convert_libero_data_to_lerobot.py
@@ -1,52 +1,60 @@
-"""
-Minimal example script for converting a dataset to LeRobot format.
-
-We use the Libero dataset (stored in RLDS) for this example, but it can be easily
-modified for any other data you have saved in a custom format.
+"""Convert a local LIBERO RLDS mirror into a local LeRobot dataset.
 
 Usage:
-uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data
-
-If you want to push your dataset to the Hugging Face Hub, you can use the following command:
-uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
+uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /root/pi_train/modified_libero_rlds
 
-Note: to run the script, you need to install tensorflow_datasets:
-`uv pip install tensorflow tensorflow_datasets`
-
-You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
-The resulting dataset will get saved to the $HF_LEROBOT_HOME directory.
-Running this conversion script will take approximately 30 minutes.
+This writes a LeRobot dataset under HF_LEROBOT_HOME/<repo_name>.
+By default it targets physical-intelligence/libero so the standard pi05_libero
+training config can run fully offline from the local cache.
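+
+Reading the shards requires the `tfrecord` and `pillow` packages; the RLDS
+data is parsed directly from TFRecord files, so TensorFlow is not needed.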
""" +from collections.abc import Iterator, Sequence +from io import BytesIO +from pathlib import Path import shutil from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -import tensorflow_datasets as tfds +import numpy as np +from PIL import Image +from tfrecord.reader import sequence_loader import tyro -REPO_NAME = "your_hf_username/libero" # Name of the output dataset, also used for the Hugging Face Hub -RAW_DATASET_NAMES = [ +DEFAULT_REPO_NAME = "physical-intelligence/libero" +DEFAULT_RAW_DATASET_NAMES = ( "libero_10_no_noops", "libero_goal_no_noops", "libero_object_no_noops", "libero_spatial_no_noops", -] # For simplicity we will combine multiple Libero datasets into one training dataset +) + + +def _decode_image(value: bytes | np.bytes_) -> np.ndarray: + return np.asarray(Image.open(BytesIO(bytes(value))).convert("RGB")) -def main(data_dir: str, *, push_to_hub: bool = False): - # Clean up any existing dataset in the output directory - output_path = HF_LEROBOT_HOME / REPO_NAME +def _iter_episodes(dataset_root: Path) -> Iterator[dict[str, np.ndarray]]: + for shard_path in sorted((dataset_root / "1.0.0").glob("*.tfrecord-*")): + for context, _ in sequence_loader(str(shard_path), None): + yield context + + +def main( + data_dir: str, + *, + repo_name: str = DEFAULT_REPO_NAME, + raw_dataset_names: Sequence[str] = DEFAULT_RAW_DATASET_NAMES, + max_episodes_per_dataset: int | None = None, +): + output_path = HF_LEROBOT_HOME / repo_name if output_path.exists(): shutil.rmtree(output_path) - # Create LeRobot dataset, define features to store - # OpenPi assumes that proprio is stored in `state` and actions in `action` - # LeRobot assumes that dtype of image data is `image` dataset = LeRobotDataset.create( - repo_id=REPO_NAME, + repo_id=repo_name, robot_type="panda", fps=10, + use_videos=False, features={ "image": { "dtype": "image", @@ -69,35 +77,44 @@ def main(data_dir: str, *, push_to_hub: bool = False): "names": ["actions"], }, }, - image_writer_threads=10, - image_writer_processes=5, ) - # Loop over raw Libero datasets and write episodes to the LeRobot dataset - # You can modify this for your own data format - for raw_dataset_name in RAW_DATASET_NAMES: - raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train") - for episode in raw_dataset: - for step in episode["steps"].as_numpy_iterator(): + base_dir = Path(data_dir) + total_episodes = 0 + for raw_dataset_name in raw_dataset_names: + dataset_root = base_dir / raw_dataset_name + saved_for_dataset = 0 + for context in _iter_episodes(dataset_root): + num_steps = len(context["steps/is_last"]) + actions = np.asarray(context["steps/action"], dtype=np.float32).reshape(num_steps, 7) + states = np.asarray(context["steps/observation/state"], dtype=np.float32).reshape(num_steps, 8) + main_images = context["steps/observation/image"] + wrist_images = context["steps/observation/wrist_image"] + tasks = context["steps/language_instruction"] + + for step_idx in range(num_steps): dataset.add_frame( { - "image": step["observation"]["image"], - "wrist_image": step["observation"]["wrist_image"], - "state": step["observation"]["state"], - "actions": step["action"], - "task": step["language_instruction"].decode(), + "image": _decode_image(main_images[step_idx]), + "wrist_image": _decode_image(wrist_images[step_idx]), + "state": states[step_idx], + "actions": actions[step_idx], + "task": bytes(tasks[step_idx]).decode("utf-8"), } ) + dataset.save_episode() + 
            saved_for_dataset += 1
+            total_episodes += 1
+            print(
+                f"saved {raw_dataset_name} episode {saved_for_dataset} -> total={total_episodes}",
+                flush=True,
+            )
+
+            if max_episodes_per_dataset is not None and saved_for_dataset >= max_episodes_per_dataset:
+                break
 
-    # Optionally push to the Hugging Face Hub
-    if push_to_hub:
-        dataset.push_to_hub(
-            tags=["libero", "panda", "rlds"],
-            private=False,
-            push_videos=True,
-            license="apache-2.0",
-        )
+    print(f"Finished writing {total_episodes} episodes to {output_path}")
 
 
 if __name__ == "__main__":
diff --git a/prune_distill/README.md b/prune_distill/README.md
new file mode 100644
index 0000000000..62ab92af59
--- /dev/null
+++ b/prune_distill/README.md
@@ -0,0 +1,43 @@
+# Prefix Distillation
+
+This directory contains a low-memory distillation path that trains `gemma_prune` from the frozen `gemma_2b` prefix branch used by the LIBERO `pi05` checkpoint.
+
+The runner keeps memory down by:
+- freezing SigLIP
+- freezing the teacher Gemma-2B
+- dropping the action branch from the distillation graph
+- mixing real LIBERO batches with fake random batches
+
+The loss is a weighted sum of:
+- hidden-state MSE on all valid prefix tokens
+- cosine distance on the same tokens
+
+The student is warm-started from the teacher by:
+- copying the full token embedder and final norm
+- copying the first 14 attention and norm layers
+- slicing the teacher MLP weights down to the pruned hidden size
+
+Run it from the repo root:
+
+```bash
+.venv/bin/python prune_distill/train_prefix_distill.py --exp-name gemma_prune_prefix --overwrite
+```
+
+TensorBoard logs are written to `checkpoints/prune_distill/<exp-name>/tensorboard`.
+
+```bash
+.venv/bin/python -m tensorboard.main --logdir checkpoints/prune_distill
+```
+
+Useful flags:
+
+```bash
+.venv/bin/python prune_distill/train_prefix_distill.py \
+  --exp-name gemma_prune_prefix \
+  --batch-size 8 \
+  --num-train-steps 10000 \
+  --real-batch-prob 0.8 \
+  --log-interval 20 \
+  --save-interval 500 \
+  --overwrite
+```
diff --git a/prune_distill/train_prefix_distill.py b/prune_distill/train_prefix_distill.py
new file mode 100644
index 0000000000..82ae687a85
--- /dev/null
+++ b/prune_distill/train_prefix_distill.py
@@ -0,0 +1,519 @@
+from __future__ import annotations
+
+import dataclasses
+import functools
+import json
+import logging
+import pathlib
+import random
+import shutil
+from typing import Any
+
+from flax import nnx
+from flax import struct
+import flax.nnx.bridge as nnx_bridge
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+import orbax.checkpoint as ocp
+from torch.utils.tensorboard import SummaryWriter
+import tyro
+
+from openpi.models import gemma as gemma_teacher
+from openpi.models import gemma_pruning
+from openpi.models import model as _model
+from openpi.models import pi0 as pi0_model
+from openpi.models import pi0_config
+from openpi.models import siglip
+from openpi.shared import nnx_utils
+import openpi.training.config as training_config
+import openpi.training.data_loader as data_loader
+
+
+LOGGER = logging.getLogger("prune_distill")
+
+
+TRAINABLE_FILTER = nnx.All(nnx.Param, nnx_utils.PathRegex("student/.*"))
+FROZEN_FILTER = nnx.All(nnx.Param, nnx_utils.PathRegex("(vision|teacher)/.*"))
+
+
+def teacher_module_kwargs(embed_dtype: str) -> dict[str, Any]:
+    teacher_cfg = gemma_teacher.get_config("gemma_2b")
+    return {
+        "variant": "gemma_2b",
+        "width": teacher_cfg.width,
+        "depth": teacher_cfg.depth,
+        "mlp_dim": teacher_cfg.mlp_dim,
+        "num_heads": teacher_cfg.num_heads,
"num_kv_heads": teacher_cfg.num_kv_heads, + "head_dim": teacher_cfg.head_dim, + "norm_eps": 1e-6, + "vocab_size": gemma_teacher.PALIGEMMA_VOCAB_SIZE, + "embed_dtype": embed_dtype, + "scan": True, + "remat_policy": "nothing_saveable", + } + + +@dataclasses.dataclass(frozen=True) +class DistillConfig: + exp_name: str = "gemma_prefix_distill" + teacher_checkpoint: str = "/root/pi_train/pi05_libero/params" + output_dir: str = "/root/openpi-wr/checkpoints/prune_distill" + train_config_name: str = "pi05_libero" + batch_size: int = 8 + num_workers: int = 2 + num_train_steps: int = 10_000 + log_interval: int = 20 + save_interval: int = 500 + seed: int = 42 + learning_rate: float = 1e-4 + warmup_steps: int = 200 + weight_decay: float = 1e-2 + real_batch_prob: float = 0.8 + hidden_loss_weight: float = 1.0 + cosine_loss_weight: float = 0.1 + dtype: str = "bfloat16" + overwrite: bool = False + + +@struct.dataclass +class DistillState: + step: jax.Array + params: nnx.State + model_def: nnx.GraphDef[Any] + opt_state: optax.OptState + tx: optax.GradientTransformation = struct.field(pytree_node=False) + + +class PrefixDistillModel(nnx.Module): + def __init__(self, config: DistillConfig, rngs: nnx.Rngs): + student_cfg = gemma_pruning.get_config("gemma_prune") + + self.vision = nnx_bridge.ToNNX( + siglip.Module( + num_classes=teacher_module_kwargs(config.dtype)["width"], + variant="So400m/14", + pool_type="none", + scan=True, + dtype_mm=config.dtype, + ) + ) + self.teacher = nnx_bridge.ToNNX( + gemma_pruning.Module(**teacher_module_kwargs(config.dtype)) + ) + self.student = nnx_bridge.ToNNX( + gemma_pruning.Module(**student_cfg, embed_dtype=config.dtype) + ) + + fake_image = next(iter(pi0_config.Pi0Config(pi05=True).fake_obs().images.values())) + self.vision.lazy_init(fake_image, train=False, rngs=rngs) + self.teacher.lazy_init(rngs=rngs, method="init") + self.student.lazy_init(rngs=rngs, method="init") + + def _embed_images( + self, observation: _model.Observation + ) -> tuple[jax.Array, jax.Array]: + tokens = [] + masks = [] + for name in _model.IMAGE_KEYS: + image_tokens, _ = self.vision(observation.images[name], train=False) + image_tokens = jax.lax.stop_gradient(image_tokens) + tokens.append(image_tokens) + masks.append( + jnp.repeat( + observation.image_masks[name][:, None], + image_tokens.shape[1], + axis=1, + ) + ) + return jnp.concatenate(tokens, axis=1), jnp.concatenate(masks, axis=1) + + def compute_loss( + self, + rng: jax.Array, + observation: _model.Observation, + *, + hidden_loss_weight: float, + cosine_loss_weight: float, + train: bool, + ) -> tuple[jax.Array, dict[str, jax.Array]]: + observation = _model.preprocess_observation(rng, observation, train=train) + image_tokens, image_mask = self._embed_images(observation) + + if observation.tokenized_prompt is None or observation.tokenized_prompt_mask is None: + raise ValueError("Prefix distillation requires tokenized prompts.") + + input_mask = jnp.concatenate([image_mask, observation.tokenized_prompt_mask], axis=1) + ar_mask = jnp.zeros((input_mask.shape[1],), dtype=jnp.bool_) + attn_mask = pi0_model.make_attn_mask(input_mask, ar_mask) + positions = jnp.cumsum(input_mask, axis=1) - 1 + + teacher_hidden, _, _ = self.teacher( + tokens=observation.tokenized_prompt, + embedded_prefix=image_tokens, + positions=positions, + mask=attn_mask, + return_prelogits=True, + deterministic=True, + ) + teacher_hidden = jax.lax.stop_gradient(teacher_hidden.astype(jnp.float32)) + + student_hidden, _, _ = self.student( + tokens=observation.tokenized_prompt, + 
embedded_prefix=image_tokens, + positions=positions, + mask=attn_mask, + return_prelogits=True, + deterministic=not train, + ) + student_hidden = student_hidden.astype(jnp.float32) + + valid = input_mask.astype(jnp.float32) + denom = jnp.maximum(jnp.sum(valid), 1.0) + + sq_error = jnp.square(student_hidden - teacher_hidden) + hidden_mse = jnp.sum(sq_error * valid[..., None]) / denom + + teacher_norm = teacher_hidden / jnp.maximum(jnp.linalg.norm(teacher_hidden, axis=-1, keepdims=True), 1e-6) + student_norm = student_hidden / jnp.maximum(jnp.linalg.norm(student_hidden, axis=-1, keepdims=True), 1e-6) + cosine_dist = 1.0 - jnp.sum(student_norm * teacher_norm, axis=-1) + cosine_loss = jnp.sum(cosine_dist * valid) / denom + + loss = hidden_loss_weight * hidden_mse + cosine_loss_weight * cosine_loss + return loss, { + "hidden_mse": hidden_mse, + "cosine_loss": cosine_loss, + "valid_tokens": denom, + } + + +class MixedBatchIterator: + def __init__(self, real_iter, fake_iter, *, real_batch_prob: float, seed: int): + self._real_iter = real_iter + self._fake_iter = fake_iter + self._real_batch_prob = real_batch_prob + self._rng = random.Random(seed) + + def __iter__(self): + return self + + def __next__(self): + use_real = self._rng.random() < self._real_batch_prob + if use_real: + return next(self._real_iter), "libero" + return next(self._fake_iter), "random" + + +def init_logging() -> None: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%H:%M:%S", + ) + + +def _repo_root() -> pathlib.Path: + return pathlib.Path(__file__).resolve().parents[1] + + +def build_data_config(config: DistillConfig) -> training_config.TrainConfig: + repo_root = _repo_root() + base = training_config.get_config(config.train_config_name) + return dataclasses.replace( + base, + batch_size=config.batch_size, + num_workers=config.num_workers, + assets_base_dir=str(repo_root / "assets"), + checkpoint_base_dir=str(repo_root / "checkpoints"), + ) + + +def load_checkpoint_params(params_path: str) -> dict[str, Any]: + LOGGER.info("Loading teacher checkpoint from %s", params_path) + return _model.restore_params(params_path, restore_type=np.ndarray) + + +def extract_siglip_params(teacher_params: dict[str, Any]) -> dict[str, Any]: + return teacher_params["PaliGemma"]["img"] + + +def extract_teacher_prefix_params(teacher_params: dict[str, Any]) -> dict[str, Any]: + llm = teacher_params["PaliGemma"]["llm"] + return { + "embedder": llm["embedder"], + "final_norm": llm["final_norm"], + "layers": { + "attn": { + "attn_vec_einsum": llm["layers"]["attn"]["attn_vec_einsum"], + "kv_einsum": llm["layers"]["attn"]["kv_einsum"], + "q_einsum": llm["layers"]["attn"]["q_einsum"], + }, + "mlp": llm["layers"]["mlp"], + "pre_attention_norm": llm["layers"]["pre_attention_norm"], + "pre_ffw_norm": llm["layers"]["pre_ffw_norm"], + }, + } + + +def extract_student_init_params(teacher_params: dict[str, Any]) -> dict[str, Any]: + llm = teacher_params["PaliGemma"]["llm"] + student_cfg = gemma_pruning.get_config("gemma_prune") + depth = student_cfg.depth + mlp_dim = student_cfg.mlp_dim + return { + "embedder": llm["embedder"], + "final_norm": llm["final_norm"], + "layers": { + "attn": { + "attn_vec_einsum": {"w": llm["layers"]["attn"]["attn_vec_einsum"]["w"][:depth]}, + "kv_einsum": {"w": llm["layers"]["attn"]["kv_einsum"]["w"][:depth]}, + "q_einsum": {"w": llm["layers"]["attn"]["q_einsum"]["w"][:depth]}, + }, + "mlp": { + "gating_einsum": llm["layers"]["mlp"]["gating_einsum"][:depth, :, :, :mlp_dim], + 
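+                # `linear` has shape (depth, mlp_dim, width); slice its mlp_dim axis to match the pruned gating weights.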
"linear": llm["layers"]["mlp"]["linear"][:depth, :mlp_dim, :], + }, + "pre_attention_norm": {"scale": llm["layers"]["pre_attention_norm"]["scale"][:depth]}, + "pre_ffw_norm": {"scale": llm["layers"]["pre_ffw_norm"]["scale"][:depth]}, + }, + } + + +def init_state(config: DistillConfig) -> DistillState: + rng = jax.random.key(config.seed) + model = PrefixDistillModel(config, rngs=nnx.Rngs(rng)) + graphdef, state = nnx.split(model) + + teacher_params = load_checkpoint_params(config.teacher_checkpoint) + state.replace_by_pure_dict( + { + "vision": extract_siglip_params(teacher_params), + "teacher": extract_teacher_prefix_params(teacher_params), + "student": extract_student_init_params(teacher_params), + } + ) + del teacher_params + + state = nnx_utils.state_map(state, FROZEN_FILTER, lambda p: p.replace(p.value.astype(jnp.bfloat16))) + model = nnx.merge(graphdef, state) + params = nnx.state(model) + + schedule = optax.warmup_cosine_decay_schedule( + init_value=0.0, + peak_value=config.learning_rate, + warmup_steps=config.warmup_steps, + decay_steps=max(config.num_train_steps, config.warmup_steps + 1), + end_value=config.learning_rate * 0.1, + ) + tx = optax.adamw(schedule, weight_decay=config.weight_decay) + + return DistillState( + step=jnp.asarray(0, dtype=jnp.int32), + params=params, + model_def=nnx.graphdef(model), + opt_state=tx.init(params.filter(TRAINABLE_FILTER)), + tx=tx, + ) + + +def train_step( + state: DistillState, + rng: jax.Array, + batch: tuple[_model.Observation, jax.Array], + *, + hidden_loss_weight: float, + cosine_loss_weight: float, +) -> tuple[DistillState, dict[str, jax.Array]]: + model = nnx.merge(state.model_def, state.params) + observation, _ = batch + + def loss_fn(module: PrefixDistillModel): + return module.compute_loss( + rng, + observation, + hidden_loss_weight=hidden_loss_weight, + cosine_loss_weight=cosine_loss_weight, + train=True, + ) + + diff_state = nnx.DiffState(0, TRAINABLE_FILTER) + (loss, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True, argnums=diff_state)(model) + + trainable_params = state.params.filter(TRAINABLE_FILTER) + updates, new_opt_state = state.tx.update(grads, state.opt_state, trainable_params) + new_trainable_params = optax.apply_updates(trainable_params, updates) + nnx.update(model, new_trainable_params) + new_params = nnx.state(model) + + info = { + "loss": loss, + "hidden_mse": metrics["hidden_mse"], + "cosine_loss": metrics["cosine_loss"], + "valid_tokens": metrics["valid_tokens"], + "grad_norm": optax.global_norm(grads), + "param_norm": optax.global_norm(new_trainable_params), + } + return dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state), info + + +def count_params(state: nnx.State, filter_: nnx.filterlib.Filter | None = None) -> int: + subset = state if filter_ is None else state.filter(filter_) + total = 0 + for value in subset.flat_state().values(): + total += int(np.prod(value.value.shape)) + return total + + +def key_param_stats(state: nnx.State) -> dict[str, float]: + wanted = { + "embed": "student/embedder/input_embedding", + "q": "student/layers/attn/q_einsum/w", + "mlp_gate": "student/layers/mlp/gating_einsum", + "mlp_out": "student/layers/mlp/linear", + "final_norm": "student/final_norm/scale", + } + flat = state.flat_state() + stats = {} + for label, path in wanted.items(): + for key, value in flat.items(): + joined = "/".join(str(part) for part in key) + if joined == path: + stats[label] = float(jnp.linalg.norm(value.value.astype(jnp.float32))) + break + return stats + + 
+def save_student_checkpoint(output_dir: pathlib.Path, state: DistillState) -> None: + model = nnx.merge(state.model_def, state.params) + student_params = {"params": nnx.state(model.student).to_pure_dict()} + ckpt = output_dir / "student" + if ckpt.exists(): + shutil.rmtree(ckpt) + with ocp.PyTreeCheckpointer() as checkpointer: + checkpointer.save(ckpt, student_params) + + +def prepare_output_dir(config: DistillConfig) -> pathlib.Path: + output_dir = pathlib.Path(config.output_dir) / config.exp_name + if output_dir.exists(): + if not config.overwrite: + raise FileExistsError(f"{output_dir} already exists. Pass --overwrite to replace it.") + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + (output_dir / "config.json").write_text(json.dumps(dataclasses.asdict(config), indent=2)) + return output_dir + + +def log_tensorboard_scalars( + writer: SummaryWriter, + step: int, + info: dict[str, float], + *, + source: str, + key_params: dict[str, float] | None = None, +) -> None: + writer.add_scalar("train/loss", info["loss"], step) + writer.add_scalar("train/hidden_mse", info["hidden_mse"], step) + writer.add_scalar("train/cosine_loss", info["cosine_loss"], step) + writer.add_scalar("train/grad_norm", info["grad_norm"], step) + writer.add_scalar("train/param_norm", info["param_norm"], step) + writer.add_scalar("train/valid_tokens", info["valid_tokens"], step) + writer.add_scalar("train/source_is_libero", 1.0 if source == "libero" else 0.0, step) + if key_params is not None: + for name, value in key_params.items(): + writer.add_scalar(f"key_params/{name}", value, step) + + +def main(config: DistillConfig) -> None: + init_logging() + LOGGER.info("Starting prefix distillation with config: %s", config) + + output_dir = prepare_output_dir(config) + tensorboard_dir = output_dir / "tensorboard" + base_data_config = build_data_config(config) + fake_data_config = dataclasses.replace(base_data_config, data=training_config.FakeDataConfig()) + + real_loader = data_loader.create_data_loader(base_data_config, shuffle=True) + fake_loader = data_loader.create_data_loader(fake_data_config, shuffle=True, skip_norm_stats=True) + batch_iter = MixedBatchIterator(iter(real_loader), iter(fake_loader), real_batch_prob=config.real_batch_prob, seed=config.seed) + + writer = SummaryWriter(log_dir=str(tensorboard_dir)) + try: + state = init_state(config) + total_params = count_params(state.params) + frozen_params = count_params(state.params, FROZEN_FILTER) + trainable_params = count_params(state.params, TRAINABLE_FILTER) + initial_key_params = key_param_stats(state.params) + + LOGGER.info( + "Params: total=%d frozen=%d trainable=%d", + total_params, + frozen_params, + trainable_params, + ) + LOGGER.info("Initial key params: %s", initial_key_params) + writer.add_text( + "run/config", + json.dumps(dataclasses.asdict(config), indent=2), + 0, + ) + writer.add_scalar("params/total", total_params, 0) + writer.add_scalar("params/frozen", frozen_params, 0) + writer.add_scalar("params/trainable", trainable_params, 0) + for name, value in initial_key_params.items(): + writer.add_scalar(f"key_params/{name}", value, 0) + + ptrain_step = jax.jit( + functools.partial( + train_step, + hidden_loss_weight=config.hidden_loss_weight, + cosine_loss_weight=config.cosine_loss_weight, + ), + donate_argnums=(0,), + ) + + rng = jax.random.key(config.seed + 1) + for step in range(config.num_train_steps): + batch, source = next(batch_iter) + rng, step_rng = jax.random.split(rng) + state, info = ptrain_step(state, step_rng, 
batch) + + step_num = int(state.step) + host_info = jax.tree.map(lambda x: float(x), info) + log_tensorboard_scalars(writer, step_num, host_info, source=source) + + if step_num % config.log_interval == 0 or step_num == 1: + key_params = key_param_stats(state.params) + for name, value in key_params.items(): + writer.add_scalar(f"key_params/{name}", value, step_num) + LOGGER.info( + "step=%d source=%s loss=%.6f hidden_mse=%.6f cosine=%.6f grad_norm=%.6f param_norm=%.6f key_params=%s", + step_num, + source, + host_info["loss"], + host_info["hidden_mse"], + host_info["cosine_loss"], + host_info["grad_norm"], + host_info["param_norm"], + key_params, + ) + + if step_num % config.save_interval == 0 or step_num == config.num_train_steps: + step_dir = output_dir / f"step_{step_num:07d}" + step_dir.mkdir(parents=True, exist_ok=True) + save_student_checkpoint(step_dir, state) + writer.flush() + + final_key_params = key_param_stats(state.params) + LOGGER.info("Finished distillation. Final key params: %s", final_key_params) + for name, value in final_key_params.items(): + writer.add_scalar(f"key_params_final/{name}", value, int(state.step)) + writer.flush() + finally: + writer.close() + + +if __name__ == "__main__": + main(tyro.cli(DistillConfig)) diff --git a/pyproject.toml b/pyproject.toml index c4a06e5328..4b1e613f83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "orbax-checkpoint==0.11.13", "pillow>=11.0.0", "sentencepiece>=0.2.0", + "tensorboard>=2.15.2", "torch==2.7.1", "tqdm-loggable>=0.2", "typing-extensions>=4.12.2", diff --git a/scripts/compute_norm_stats.py b/scripts/compute_norm_stats.py index c8aef87222..0c1d728150 100644 --- a/scripts/compute_norm_stats.py +++ b/scripts/compute_norm_stats.py @@ -90,6 +90,14 @@ def main(config_name: str, max_frames: int | None = None): config = _config.get_config(config_name) data_config = config.data.create(config.assets_dirs, config.model) + if data_config.repo_id is None: + raise ValueError("Data config must have a repo_id") + + output_path = config.assets_dirs / data_config.repo_id + if (output_path / "norm_stats.json").exists(): + print(f"Norm stats already exist at: {output_path}") + return + if data_config.rlds_data_dir is not None: data_loader, num_batches = create_rlds_dataloader( data_config, config.model.action_horizon, config.batch_size, max_frames @@ -108,7 +116,6 @@ def main(config_name: str, max_frames: int | None = None): norm_stats = {key: stats.get_statistics() for key, stats in stats.items()} - output_path = config.assets_dirs / data_config.repo_id print(f"Writing stats to: {output_path}") normalize.save(output_path, norm_stats) diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index c7ddd2b595..86b2ff79b5 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -75,6 +75,13 @@ def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = T wandb.init(mode="disabled") return + has_api_key = bool(os.environ.get("WANDB_API_KEY")) + interactive = os.isatty(0) and os.isatty(1) + if not has_api_key and not interactive: + logging.warning("WANDB_API_KEY is not set in a non-interactive session; disabling wandb.") + wandb.init(mode="disabled") + return + ckpt_dir = config.checkpoint_dir if not ckpt_dir.exists(): raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.") diff --git a/src/openpi/models/gemma_pruning.py b/src/openpi/models/gemma_pruning.py new file mode 100644 index 0000000000..2dae02198b --- /dev/null +++ 
b/src/openpi/models/gemma_pruning.py @@ -0,0 +1,440 @@ +# Copyright 2024 Big Vision Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Gemma model implementation from big_vision/models/ppp/gemma.py (with small modifications for NNX compatibility) +Used for FAST autoregressive policies. +""" + +import dataclasses +from typing import Literal, TypeAlias + +import einops +import flax.linen as nn +import jax +import jax.numpy as jnp +import ml_collections + +import openpi.models.lora as lora +import openpi.shared.array_typing as at + +Variant = Literal["gemma_prune", "gemma_prune_lora"] + +# Layers: 18 → 14 +# Heads: 8 +# MLP: 12_288 +# params: 2.1B → +def get_config(variant): + """Returns config for specified gemma variant.""" + if variant == "gemma_prune": + return ml_collections.ConfigDict( + { + "variant": variant, + "width": 2048, + "depth": 14, + "mlp_dim": 12_288, + "num_heads": 8, + "num_kv_heads": 1, + "head_dim": 256, + "norm_eps": 1e-6, + "vocab_size": 257_152, + "scan": True, + "remat_policy": "nothing_saveable", + } + ) + if variant == "gemma_prune_lora": + return ml_collections.ConfigDict( + { + "variant": variant, + "width": 2048, + "depth": 14, + "mlp_dim": 12_288, + "num_heads": 8, + "num_kv_heads": 1, + "head_dim": 256, + "norm_eps": 1e-6, + "vocab_size": 257_152, + "scan": True, + "remat_policy": "nothing_saveable", + "lora_configs": { + "attn": lora.LoRAConfig(rank=16, alpha=16.0), + "ffn": lora.LoRAConfig(rank=16, alpha=16.0), + }, + } + ) + raise ValueError(f"Unknown variant: {variant}") + + +@at.typecheck +class Einsum(nn.Module): + shape: tuple[int, ...] 
+ + @nn.compact + def __call__(self, eqn, x): + dtype = x.dtype # original dtype, could be half-precision + w = self.param("w", nn.initializers.zeros_init(), self.shape).astype(dtype) + return jnp.einsum(eqn, x, w) + + +@at.typecheck +class RMSNorm(nn.Module): + @nn.compact + def __call__(self, x): + dtype = x.dtype # original dtype, could be half-precision + scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1])) + var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32 + normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32 + normed_inputs = normed_inputs * ( + 1 + scale + ) # scale by learned parameter in float32 (matches Flax implementation) + return normed_inputs.astype(dtype) # return in original dtype + + +@at.typecheck +class Embedder(nn.Module): + """Embedder module.""" + + vocab_size: int + embed_dim: int + + def setup(self): + self.input_embedding_table = self.param( + "input_embedding", + nn.initializers.zeros_init(), + (self.vocab_size, self.embed_dim), + ) + + def encode(self, x): + x = self.input_embedding_table[(x,)] + x *= jnp.sqrt(self.embed_dim).astype(x.dtype) + return x + + def decode(self, x): + return jnp.dot(x, self.input_embedding_table.T) + + +@at.typecheck +class Attention(nn.Module): + """Attention module.""" + + num_heads: int + num_kv_heads: int + features: int + head_dim: int + + cache_dtype: str | None = None + + lora_config: lora.LoRAConfig | None = None + + def setup(self): + if self.num_kv_heads == self.num_heads: + self.qkv_einsum = lora.Einsum( + shape=(3, self.num_heads, self.features, self.head_dim), + name="qkv_einsum", + init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)), + lora_config=self.lora_config, + ) + else: + self.q_einsum = lora.Einsum( + shape=(self.num_heads, self.features, self.head_dim), + name="q_einsum", + init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), + lora_config=self.lora_config, + ) + self.kv_einsum = lora.Einsum( + shape=(2, self.num_kv_heads, self.features, self.head_dim), + name="kv_einsum", + init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)), + lora_config=self.lora_config, + ) + self.attn_vec_einsum = lora.Einsum( + shape=(self.num_heads, self.head_dim, self.features), + name="attn_vec_einsum", + init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)), + lora_config=self.lora_config, + ) + + def _init_cache(self, k, v, cache_size): + """Initialize KV cache""" + prefill_len = k.shape[1] + pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0)) + cache_dtype = self.cache_dtype or k.dtype + k_cache = jnp.pad(k.astype(cache_dtype), pad_width) + v_cache = jnp.pad(v.astype(cache_dtype), pad_width) + idx = jnp.zeros((k.shape[0],), dtype=jnp.int32) + prefill_len + return idx, k_cache, v_cache + + def _update_cache(self, k, v, idx, k_cache, v_cache): + """Update KV cache with new values""" + assert k.shape[1] == 1, "Only support kv-cache updates of length 1" + indices = (0, idx[0], 0, 0) + cache_dtype = self.cache_dtype or k.dtype + k_new = jax.lax.dynamic_update_slice(k_cache, k.astype(cache_dtype), indices) + v_new = jax.lax.dynamic_update_slice(v_cache, v.astype(cache_dtype), indices) + idx_new = idx + 1 + return idx_new, k_new, v_new + + @nn.compact + def __call__(self, x, positions, attn_mask, kv_cache, decode, deterministic=True): # noqa: FBT002 + dtype = x.dtype # original dtype, could 
be half-precision + if self.num_kv_heads == self.num_heads: + q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x) + else: + q = self.q_einsum("BTD,NDH->BTNH", x) + k, v = self.kv_einsum("BSD,2KDH->2BSKH", x) + + q = _apply_rope(q, positions=positions) # promotes to float32 + q *= self.head_dim**-0.5 + + k = _apply_rope(k, positions=positions) # promotes to float32 + + if kv_cache is None: + idx, k_cache, v_cache = self._init_cache(k, v, attn_mask.shape[-1]) + else: + idx, k_cache, v_cache = kv_cache + idx, k_cache, v_cache = self._update_cache(k, v, idx, k_cache, v_cache) + + k, v = k_cache, v_cache + kv_cache = (idx, k_cache, v_cache) + + q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.num_kv_heads) + logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32) + + if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]): + raise ValueError( + f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}" + ) + + # big_neg = jnp.finfo(logits.dtype).min + big_neg = -2.3819763e38 # See gemma/modules.py + masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg) + + probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype) + + encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v) + encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H") + return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache + + +@at.typecheck +class Block(nn.Module): + """Transformer block.""" + + num_heads: int + num_kv_heads: int + embed_dim: int + head_dim: int + hidden_dim: int + + dropout: float = 0.0 + dropout_bdims: tuple[int, ...] = () + cache_dtype: str | None = None + lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict) + + def setup(self): + self.pre_attention_norm = RMSNorm() + self.attn = Attention( + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + features=self.embed_dim, + head_dim=self.head_dim, + cache_dtype=self.cache_dtype, + lora_config=self.lora_configs.get("attn"), + ) + self.pre_ffw_norm = RMSNorm() + self.mlp = lora.FeedForward( + features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn") + ) + if self.dropout: + self.drop = nn.Dropout(self.dropout, self.dropout_bdims) + else: + self.drop = lambda x, _: x + + def __call__(self, x, kv_cache, positions, attn_mask, decode, deterministic=True): # noqa: FBT002 + x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb")) + inputs_normalized = self.pre_attention_norm(x) + attn_output, kv_cache = self.attn(inputs_normalized, positions, attn_mask, kv_cache, decode, deterministic) + attn_output = self.drop(attn_output, deterministic) + attn_output += x + residual = attn_output + attn_output = self.pre_ffw_norm(attn_output) + outputs = self.mlp(attn_output) + outputs = self.drop(outputs, deterministic) + outputs = residual + outputs + return outputs, kv_cache + + +KVCache: TypeAlias = tuple[at.Int[at.Array, " b"], at.Float[at.Array, "b _t _k _h"], at.Float[at.Array, "b _t _v _h"]] + + +@at.typecheck +class Module(nn.Module): + """gemma model.""" + + variant: str + + width: int + depth: int + mlp_dim: int + num_heads: int + num_kv_heads: int + head_dim: int + norm_eps: float + vocab_size: int + embed_dtype: str + + dropout: float = 0.0 + dropout_bdims: tuple[int, ...] = () # Every float is dropped independently. 
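+    # Optional override for the dtype used to store KV-cache entries (defaults to the key/value dtype).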
+ cache_dtype: str | None = None + + scan: bool = False + remat_policy: str = "none" + lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict) + + @nn.compact + def __call__( + self, + tokens=None, + embedded_prefix=None, + embed_only=False, # noqa: FBT002 + pre_logits=None, + positions=None, + mask=None, + decode=False, # noqa: FBT002 + kv_cache=None, + deterministic=True, # noqa: FBT002 + return_prelogits=False, # noqa: FBT002 + ): + """Embed only, or complete forward pass. + + Args: + tokens: Embedded, then and appended to `embedded_prefix`. Can be None. + embedded_prefix: Optional prefix that is already embedded. + embed_only: Whether to compute embeddings only. + pre_logits: If present computes logits from pre_logits and returns. + positions: Optional `[B, T]` allows to specify the absolute position of + the tokens. + mask: Optional attention mask `[B, T, S]`. + decode: Whether to use kv-cache. Caller must pass masks and positions. + deterministic: Forwarded to all dropout layers. + return_prelogits: Whether to return the pre-logits. + + Returns: + If `embed_only=False`, then `(logits, out)` will be returned. + If `embed_only=True`, then the embeddings will be returned. + If `return_prelogits=True`, then the pre-logits will be returned. + """ + out = {} + + embedder = Embedder(vocab_size=self.vocab_size, embed_dim=self.width, name="embedder") + + if pre_logits is not None: + x = out["pre_logits"] = pre_logits + logits = out["logits"] = embedder.decode(x) + return logits, out + + x = [] + if embedded_prefix is not None: + x.append(embedded_prefix) + if tokens is not None: + x.append(embedder.encode(tokens)) + + x = jnp.concatenate(x, axis=-2) + x = x.astype(self.embed_dtype) + batch_size, seq_len, width = x.shape + + if embed_only: + return x + + if decode: + assert positions is not None and mask is not None, ( # noqa: PT018 + "Must explicitly pass positions and mask for decoding." + ) + + if positions is None: + positions = jnp.arange(seq_len).astype(jnp.int32)[None, :] + assert positions.shape[1] == x.shape[1], (positions.shape, x.shape) + + if mask is None: + mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len])) + if mask.ndim == 3: + mask = mask[:, None, :, :] + cache_size = max(seq_len, mask.shape[-1]) + assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape + + if self.remat_policy == "none": + block_cls = Block + else: + block_cls = nn.remat( + Block, + prevent_cse=not self.scan, + static_argnums=(5, 6), # 0=self, 5=decode, 6=deterministic + policy=getattr(jax.checkpoint_policies, self.remat_policy), + ) + + block_kw = { + "num_heads": self.num_heads, + "head_dim": self.head_dim, + "num_kv_heads": self.num_kv_heads, + "embed_dim": width, + "hidden_dim": self.mlp_dim, + "dropout": self.dropout, + "dropout_bdims": self.dropout_bdims, + "cache_dtype": self.cache_dtype, + "lora_configs": self.lora_configs, + } + layers = self.scope.push("layers") + blocks = [ + nn.scan( + block_cls, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), # 0=kv_cache, 1=positions, 2=mask + length=self.depth, + )(parent=layers, **block_kw) + ] + for block in blocks: + x, kv_cache = block(x, kv_cache, positions, mask, decode, deterministic) + + assert x.dtype == jnp.dtype(self.embed_dtype) # Sanity check. 
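+        # Expose the hidden states before the final RMSNorm.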
+ out["encoded"] = x + + x = RMSNorm(name="final_norm")(x) + out["pre_logits"] = x + if return_prelogits: + return x, kv_cache, out + + x = embedder.decode(x) + out["logits"] = x + + return x, kv_cache, out + + def init(self): + """Convenience method for initializing all parameters, necessary due to the quirks of linen.""" + self(jnp.zeros((1, 1), dtype=jnp.int32)) + + +def _apply_rope(x, *, positions, max_wavelength=10_000): + """Applies RoPE positions [B, L] to x [B, L, H, D].""" + freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32) + timescale = max_wavelength**freq_exponents + radians = positions[..., None] / timescale[None, None, :] + radians = radians[..., None, :] + assert radians.dtype == jnp.float32 + # radians.shape = [...,L,1,d=D/2] + sin, cos = jnp.sin(radians), jnp.cos(radians) + x1, x2 = jnp.split(x, 2, axis=-1) + res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) + assert res.dtype == jnp.float32 + return res.astype(x.dtype) diff --git a/src/openpi/models_pytorch/gemma_pytorch.py b/src/openpi/models_pytorch/gemma_pytorch.py index 203b36be8a..3127956326 100644 --- a/src/openpi/models_pytorch/gemma_pytorch.py +++ b/src/openpi/models_pytorch/gemma_pytorch.py @@ -36,7 +36,7 @@ def __init__( vlm_config_hf.text_config.use_adarms = use_adarms[0] vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None vlm_config_hf.vision_config.intermediate_size = 4304 - vlm_config_hf.vision_config.projection_dim = 2048 + vlm_config_hf.vision_config.projection_dim = vlm_config.width vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" vlm_config_hf.vision_config.torch_dtype = "float32" diff --git a/src/openpi/models_pytorch/preprocessing_pytorch.py b/src/openpi/models_pytorch/preprocessing_pytorch.py index 33c94a59b1..d7427d6480 100644 --- a/src/openpi/models_pytorch/preprocessing_pytorch.py +++ b/src/openpi/models_pytorch/preprocessing_pytorch.py @@ -142,9 +142,10 @@ def preprocess_observation_pytorch( image = image * 2.0 - 1.0 # Convert back to [B, C, H, W] format if it was originally channels-first - if is_channels_first: - image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] - + #if is_channels_first: + # image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + if image.ndim == 4 and image.shape[-1] == 3: + image = image.permute(0, 3, 1, 2).contiguous() out_images[key] = image # obtain mask diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index 4ca47e1286..a4b55966d6 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -757,8 +757,63 @@ def __post_init__(self) -> None: ), optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), ema_decay=0.999, - weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi05_base/params"), - pytorch_weight_path="/path/to/your/pytorch_weight_path", + weight_loader=weight_loaders.CheckpointWeightLoader("/root/pi_train/pi05_libero/params"), + pytorch_weight_path="/root/pi_train/", + num_train_steps=30_000, + ), + TrainConfig( + name="pi05_libero_low_mem_local", + model=pi0_config.Pi0Config( + pi05=True, + action_horizon=10, + discrete_state_input=False, + paligemma_variant="gemma_2b_lora", + action_expert_variant="gemma_300m_lora", + ), + data=LeRobotLiberoDataConfig( + repo_id="physical-intelligence/libero", + base_config=DataConfig(prompt_from_task=True), + extra_delta_transform=False, + ), + batch_size=16, + lr_schedule=_optimizer.CosineDecaySchedule( + 
warmup_steps=10_000, + peak_lr=5e-5, + decay_steps=1_000_000, + decay_lr=5e-5, + ), + optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), + ema_decay=None, + freeze_filter=pi0_config.Pi0Config( + pi05=True, + action_horizon=10, + discrete_state_input=False, + paligemma_variant="gemma_2b_lora", + action_expert_variant="gemma_300m_lora", + ).get_freeze_filter(), + weight_loader=weight_loaders.CheckpointWeightLoader("/root/pi_train/pi05_libero/params"), + pytorch_weight_path="/root/pi_train/", + num_train_steps=30_000, + ), + TrainConfig( + name="pi05_libero_local", + model=pi0_config.Pi0Config(pi05=True, action_horizon=10, discrete_state_input=False), + data=LeRobotLiberoDataConfig( + repo_id="openvla/modified_libero_rlds", + base_config=DataConfig(prompt_from_task=True), + extra_delta_transform=False, + ), + batch_size=256, + lr_schedule=_optimizer.CosineDecaySchedule( + warmup_steps=10_000, + peak_lr=5e-5, + decay_steps=1_000_000, + decay_lr=5e-5, + ), + optimizer=_optimizer.AdamW(clip_gradient_norm=1.0), + ema_decay=0.999, + weight_loader=weight_loaders.CheckpointWeightLoader("/root/pi_train/pi05_libero/params"), + pytorch_weight_path="/root/pi_train/", num_train_steps=30_000, ), # From 448a843629a1f34b239663bef1cfc5c69674b5ca Mon Sep 17 00:00:00 2001 From: Gason Date: Tue, 17 Mar 2026 18:17:47 +0800 Subject: [PATCH 2/3] add resume --- prune_distill/README.md | 9 ++ prune_distill/train_prefix_distill.py | 169 +++++++++++++++++++++++--- src/openpi/models/gemma_pruning.py | 2 +- 3 files changed, 160 insertions(+), 20 deletions(-) diff --git a/prune_distill/README.md b/prune_distill/README.md index 62ab92af59..26cc8652d2 100644 --- a/prune_distill/README.md +++ b/prune_distill/README.md @@ -41,3 +41,12 @@ Useful flags: --save-interval 500 \ --overwrite ``` + +Resume from the latest saved `step_*` checkpoint: + +```bash +.venv/bin/python prune_distill/train_prefix_distill.py \ + --exp-name gemma_prune_prefix \ + --num-train-steps 10000 \ + --resume +``` diff --git a/prune_distill/train_prefix_distill.py b/prune_distill/train_prefix_distill.py index 82ae687a85..78d9641d65 100644 --- a/prune_distill/train_prefix_distill.py +++ b/prune_distill/train_prefix_distill.py @@ -75,6 +75,7 @@ class DistillConfig: hidden_loss_weight: float = 1.0 cosine_loss_weight: float = 0.1 dtype: str = "bfloat16" + resume: bool = False overwrite: bool = False @@ -190,11 +191,13 @@ def compute_loss( class MixedBatchIterator: - def __init__(self, real_iter, fake_iter, *, real_batch_prob: float, seed: int): + def __init__(self, real_iter, fake_iter, *, real_batch_prob: float, seed: int, skip_steps: int = 0): self._real_iter = real_iter self._fake_iter = fake_iter self._real_batch_prob = real_batch_prob self._rng = random.Random(seed) + for _ in range(skip_steps): + self._rng.random() def __iter__(self): return self @@ -385,24 +388,140 @@ def key_param_stats(state: nnx.State) -> dict[str, float]: return stats -def save_student_checkpoint(output_dir: pathlib.Path, state: DistillState) -> None: +def extract_student_params(state: DistillState) -> dict[str, Any]: model = nnx.merge(state.model_def, state.params) - student_params = {"params": nnx.state(model.student).to_pure_dict()} + return nnx.state(model.student).to_pure_dict() + + +def save_student_checkpoint(output_dir: pathlib.Path, student_params: dict[str, Any]) -> None: ckpt = output_dir / "student" if ckpt.exists(): shutil.rmtree(ckpt) with ocp.PyTreeCheckpointer() as checkpointer: - checkpointer.save(ckpt, student_params) + checkpointer.save(ckpt, {"params": 
student_params}) + + +def save_resume_checkpoint( + output_dir: pathlib.Path, + *, + student_params: dict[str, Any], + opt_state: optax.OptState, + step: int, + train_rng: jax.Array, +) -> None: + ckpt = output_dir / "resume_state" + if ckpt.exists(): + shutil.rmtree(ckpt) + with ocp.PyTreeCheckpointer() as checkpointer: + checkpointer.save( + ckpt, + { + "student_params": student_params, + "opt_state": opt_state, + "step": np.asarray(step, dtype=np.int32), + "train_rng": np.asarray(train_rng), + }, + ) + + +def save_step_checkpoint(output_dir: pathlib.Path, state: DistillState, train_rng: jax.Array) -> None: + student_params = extract_student_params(state) + step = int(state.step) + save_student_checkpoint(output_dir, student_params) + save_resume_checkpoint(output_dir, student_params=student_params, opt_state=state.opt_state, step=step, train_rng=train_rng) + + +def _parse_step_dir(step_dir: pathlib.Path) -> int: + return int(step_dir.name.split("_", 1)[1]) + + +def _latest_step_dir(output_dir: pathlib.Path) -> pathlib.Path | None: + step_dirs = sorted((p for p in output_dir.glob("step_*") if p.is_dir()), key=_parse_step_dir) + if not step_dirs: + return None + return step_dirs[-1] + + +def _has_count_field(node: Any) -> bool: + if hasattr(node, "_fields") and "count" in node._fields: + return True + if dataclasses.is_dataclass(node): + return any(field.name == "count" for field in dataclasses.fields(node)) + return False + + +def _set_opt_state_step(opt_state: optax.OptState, step: int) -> optax.OptState: + def replace_count(node: Any) -> Any: + if hasattr(node, "_fields") and "count" in node._fields: + return node._replace(count=jnp.asarray(step, dtype=jnp.asarray(node.count).dtype)) + if dataclasses.is_dataclass(node): + return dataclasses.replace(node, count=jnp.asarray(step, dtype=jnp.asarray(node.count).dtype)) + return node + + return jax.tree.map(replace_count, opt_state, is_leaf=_has_count_field) + + +def maybe_resume_state( + config: DistillConfig, + output_dir: pathlib.Path, + state: DistillState, +) -> tuple[DistillState, jax.Array]: + if not config.resume: + return state, jax.random.key(config.seed + 1) + + step_dir = _latest_step_dir(output_dir) + if step_dir is None: + raise FileNotFoundError(f"No step_* checkpoints found under {output_dir} to resume from.") + + resume_dir = step_dir / "resume_state" + if resume_dir.exists(): + with ocp.PyTreeCheckpointer() as checkpointer: + restored = checkpointer.restore(resume_dir) + state.params.replace_by_pure_dict({"student": restored["student_params"]}) + resumed_state = dataclasses.replace( + state, + step=jnp.asarray(restored["step"], dtype=jnp.int32), + opt_state=restored["opt_state"], + ) + train_rng = jnp.asarray(restored["train_rng"]) + LOGGER.info("Resumed exact distill state from %s at step=%d", step_dir, int(resumed_state.step)) + return resumed_state, train_rng + + student_params = _model.restore_params(step_dir / "student", restore_type=np.ndarray) + resumed_step = _parse_step_dir(step_dir) + state.params.replace_by_pure_dict({"student": student_params}) + resumed_state = dataclasses.replace( + state, + step=jnp.asarray(resumed_step, dtype=jnp.int32), + opt_state=_set_opt_state_step(state.opt_state, resumed_step), + ) + train_rng = jax.random.fold_in(jax.random.key(config.seed + 1), resumed_step) + LOGGER.warning( + "Resumed from legacy student-only checkpoint %s at step=%d. 
" + "Student weights were restored, but optimizer moments were reinitialized.", + step_dir, + resumed_step, + ) + return resumed_state, train_rng def prepare_output_dir(config: DistillConfig) -> pathlib.Path: output_dir = pathlib.Path(config.output_dir) / config.exp_name if output_dir.exists(): + if config.resume and config.overwrite: + raise ValueError("Cannot use resume and overwrite at the same time.") if not config.overwrite: - raise FileExistsError(f"{output_dir} already exists. Pass --overwrite to replace it.") - shutil.rmtree(output_dir) + if not config.resume: + raise FileExistsError(f"{output_dir} already exists. Pass --overwrite or --resume.") + else: + shutil.rmtree(output_dir) + elif config.resume: + raise FileNotFoundError(f"{output_dir} does not exist, so there is nothing to resume.") + output_dir.mkdir(parents=True, exist_ok=True) - (output_dir / "config.json").write_text(json.dumps(dataclasses.asdict(config), indent=2)) + config_path = output_dir / "config.json" + if not config.resume or not config_path.exists(): + config_path.write_text(json.dumps(dataclasses.asdict(config), indent=2)) return output_dir @@ -437,15 +556,23 @@ def main(config: DistillConfig) -> None: real_loader = data_loader.create_data_loader(base_data_config, shuffle=True) fake_loader = data_loader.create_data_loader(fake_data_config, shuffle=True, skip_norm_stats=True) - batch_iter = MixedBatchIterator(iter(real_loader), iter(fake_loader), real_batch_prob=config.real_batch_prob, seed=config.seed) writer = SummaryWriter(log_dir=str(tensorboard_dir)) try: state = init_state(config) + state, rng = maybe_resume_state(config, output_dir, state) + start_step = int(state.step) + batch_iter = MixedBatchIterator( + iter(real_loader), + iter(fake_loader), + real_batch_prob=config.real_batch_prob, + seed=config.seed, + skip_steps=start_step, + ) total_params = count_params(state.params) frozen_params = count_params(state.params, FROZEN_FILTER) trainable_params = count_params(state.params, TRAINABLE_FILTER) - initial_key_params = key_param_stats(state.params) + current_key_params = key_param_stats(state.params) LOGGER.info( "Params: total=%d frozen=%d trainable=%d", @@ -453,17 +580,22 @@ def main(config: DistillConfig) -> None: frozen_params, trainable_params, ) - LOGGER.info("Initial key params: %s", initial_key_params) + LOGGER.info("Current key params at step=%d: %s", start_step, current_key_params) writer.add_text( "run/config", json.dumps(dataclasses.asdict(config), indent=2), - 0, + start_step, ) - writer.add_scalar("params/total", total_params, 0) - writer.add_scalar("params/frozen", frozen_params, 0) - writer.add_scalar("params/trainable", trainable_params, 0) - for name, value in initial_key_params.items(): - writer.add_scalar(f"key_params/{name}", value, 0) + writer.add_scalar("params/total", total_params, start_step) + writer.add_scalar("params/frozen", frozen_params, start_step) + writer.add_scalar("params/trainable", trainable_params, start_step) + for name, value in current_key_params.items(): + writer.add_scalar(f"key_params/{name}", value, start_step) + + if start_step >= config.num_train_steps: + LOGGER.info("Current step %d is already at or beyond num_train_steps=%d.", start_step, config.num_train_steps) + writer.flush() + return ptrain_step = jax.jit( functools.partial( @@ -474,8 +606,7 @@ def main(config: DistillConfig) -> None: donate_argnums=(0,), ) - rng = jax.random.key(config.seed + 1) - for step in range(config.num_train_steps): + for _ in range(start_step, config.num_train_steps): batch, 
source = next(batch_iter) rng, step_rng = jax.random.split(rng) state, info = ptrain_step(state, step_rng, batch) @@ -503,7 +634,7 @@ def main(config: DistillConfig) -> None: if step_num % config.save_interval == 0 or step_num == config.num_train_steps: step_dir = output_dir / f"step_{step_num:07d}" step_dir.mkdir(parents=True, exist_ok=True) - save_student_checkpoint(step_dir, state) + save_step_checkpoint(step_dir, state, rng) writer.flush() final_key_params = key_param_stats(state.params) diff --git a/src/openpi/models/gemma_pruning.py b/src/openpi/models/gemma_pruning.py index 2dae02198b..86c796f1f7 100644 --- a/src/openpi/models/gemma_pruning.py +++ b/src/openpi/models/gemma_pruning.py @@ -34,7 +34,7 @@ # Layers: 18 → 14 # Heads: 8 # MLP: 12_288 -# params: 2.1B → +# params: 2.1B → 1.36B def get_config(variant): """Returns config for specified gemma variant.""" if variant == "gemma_prune": From 1609b66b9be9d2e510ea557cf80a8561be43ce11 Mon Sep 17 00:00:00 2001 From: Gason Date: Wed, 18 Mar 2026 13:54:09 +0800 Subject: [PATCH 3/3] Add sensitivity analysis and benchmarking scripts for pi05 models - Introduced `analyze_pi05_sensitivity.py` for conducting sensitivity analysis on model parameters, including functionality for quantization and pruning perturbations. - Added `benchmark_pi05_models.py` to benchmark origin and pruned models on datasets and LIBERO rollouts, with detailed logging and CSV output for metrics. - Implemented configuration management using dataclasses for both scripts to streamline parameter handling. - Enhanced logging for better traceability during benchmarking and sensitivity evaluations. --- prune_distill/README.md | 98 +++- prune_distill/analyze_pi05_sensitivity.py | 518 ++++++++++++++++++ prune_distill/benchmark_pi05_models.py | 617 ++++++++++++++++++++++ prune_distill/train_prefix_distill.py | 290 +++++++--- 4 files changed, 1441 insertions(+), 82 deletions(-) create mode 100644 prune_distill/analyze_pi05_sensitivity.py create mode 100644 prune_distill/benchmark_pi05_models.py diff --git a/prune_distill/README.md b/prune_distill/README.md index 26cc8652d2..9689be8f8c 100644 --- a/prune_distill/README.md +++ b/prune_distill/README.md @@ -6,7 +6,9 @@ The runner keeps memory down by: - freezing SigLIP - freezing the teacher Gemma-2B - dropping the action branch from the distillation graph -- mixing real LIBERO batches with fake random batches +- training only the pruned student on a local LeRobot-format dataset at `/root/flatten_fold_v2` +- reusing the LIBERO normalization stats from `/root/pi_train/pi05_libero/assets` +- limiting the loaded training subset by default to about `50_000` examples The loss is a weighted sum of: - hidden-state MSE on all valid prefix tokens @@ -20,7 +22,11 @@ The student is warm-started from the teacher by: Run it from the repo root: ```bash -.venv/bin/python prune_distill/train_prefix_distill.py --exp-name gemma_prune_prefix --overwrite +.venv/bin/python prune_distill/train_prefix_distill.py \ + --exp-name gemma_prune_prefix \ + --dataset-path /root/flatten_fold_v2 \ + --max-examples 50000 \ + --overwrite ``` TensorBoard logs are written to `checkpoints/prune_distill//tensorboard`. 
@@ -34,19 +40,105 @@ Useful flags:
 
 ```bash
 .venv/bin/python prune_distill/train_prefix_distill.py \
   --exp-name gemma_prune_prefix \
+  --dataset-path /root/flatten_fold_v2 \
   --batch-size 8 \
+  --max-examples 50000 \
   --num-train-steps 10000 \
-  --real-batch-prob 0.8 \
   --log-interval 20 \
   --save-interval 500 \
   --overwrite
 ```
 
+If you still hit disk pressure during dataset materialization, lower the subset further:
+
+```bash
+.venv/bin/python prune_distill/train_prefix_distill.py \
+  --dataset-path /root/flatten_fold_v2 \
+  --max-examples 10000
+```
+
 Resume from the latest saved `step_*` checkpoint:
 
 ```bash
 .venv/bin/python prune_distill/train_prefix_distill.py \
   --exp-name gemma_prune_prefix \
+  --dataset-path /root/flatten_fold_v2 \
   --num-train-steps 10000 \
   --resume
 ```
+
+## PI0.5 Sensitivity Analysis
+
+Use the sensitivity analyzer to score which `pi05` student tensors are most sensitive to quantization, pruning, and distillation drift.
+
+It will:
+- rank student layers by distillation gradient / Taylor sensitivity
+- run fake-quant perturbations on the top-ranked tensors
+- run magnitude-prune perturbations on the same tensors
+- save `summary.json`, `candidate_scores.csv`, and `family_summary.csv`
+
+Example:
+
+```bash
+.venv/bin/python prune_distill/analyze_pi05_sensitivity.py \
+  --exp-name pi05_sensitivity \
+  --dataset-path /liujinxin/ZZF/openpi/datasets/piper/flatten_fold_v2 \
+  --student-checkpoint /root/openpi-wr/checkpoints/prune_distill/gemma_prune_prefix/step_0072001/student \
+  --max-examples 2048 \
+  --max-batches 4 \
+  --eval-top-k 24 \
+  --quant-bits 8 4 \
+  --prune-ratios 0.1 0.3 0.5 \
+  --overwrite
+```
+
+Results are written to `checkpoints/prune_distill/sensitivity/<exp_name>/`.
+
+## PI0.5 Benchmark
+
+Use the benchmark runner to evaluate:
+- the original full `pi05` checkpoint on offline datasets
+- the distilled pruned prefix checkpoint on offline teacher-agreement metrics
+- the original full `pi05` checkpoint on real LIBERO rollout success
+
+The current distilled student checkpoint is only a pruned prefix model, so it does not support LIBERO action rollouts yet.
+
+Offline benchmark on a custom dataset:
+
+```bash
+.venv/bin/python prune_distill/benchmark_pi05_models.py \
+  --exp-name pi05_dataset_benchmark \
+  --origin-checkpoint-dir /root/pi_train/pi05_libero \
+  --pruned-student-checkpoint /root/openpi-wr/checkpoints/prune_distill/gemma_prune_prefix/step_0072001/student \
+  --dataset-path /liujinxin/ZZF/openpi/datasets/piper/flatten_fold_v2 \
+  --max-examples 50000 \
+  --max-eval-examples 256 \
+  --overwrite
+```
+
+LIBERO rollout success for the original full `pi05` checkpoint:
+
+```bash
+.venv/bin/python prune_distill/benchmark_pi05_models.py \
+  --exp-name pi05_libero_rollout \
+  --origin-checkpoint-dir /root/pi_train/pi05_libero \
+  --no-run-dataset-benchmark \
+  --run-libero-rollout \
+  --libero-task-suite-names libero_spatial libero_object libero_goal libero_10 \
+  --libero-num-trials-per-task 10 \
+  --overwrite
+```
+
+Benchmark outputs are written to `checkpoints/prune_distill/benchmark/<exp_name>/`.
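+
+For the offline dataset benchmark, teacher agreement is scored by aligning each predicted action chunk with the recorded chunk over their shared horizon and dimensions, then thresholding the L2 errors with `--action-tolerance`. A simplified sketch of the per-example metric (the function name is illustrative, not the exact code in `benchmark_pi05_models.py`):
+
+```python
+import numpy as np
+
+def chunk_agreement(pred: np.ndarray, target: np.ndarray, tol: float = 0.05) -> dict:
+    """Compare a predicted [horizon, dims] action chunk against the dataset chunk."""
+    h = min(pred.shape[0], target.shape[0])  # shared horizon
+    d = min(pred.shape[1], target.shape[1])  # shared action dims
+    diff = pred[:h, :d] - target[:h, :d]
+    first_l2 = float(np.linalg.norm(diff[0]))             # error of the first action
+    chunk_l2 = float(np.sqrt(np.mean(np.square(diff))))   # RMS error over the chunk
+    return {
+        "first_action_l2": first_l2,
+        "chunk_l2": chunk_l2,
+        "first_action_match": first_l2 <= tol,
+        "chunk_match": chunk_l2 <= tol,
+    }
+```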
+ +If the pruned prefix benchmark hits JAX GPU OOM during model init, rerun it on CPU: + +```bash +JAX_PLATFORMS=cpu .venv/bin/python prune_distill/benchmark_pi05_models.py \ + --exp-name pi05_dataset_benchmark_cpu \ + --origin-checkpoint-dir None \ + --pruned-student-checkpoint /root/openpi-wr/checkpoints/prune_distill/gemma_prune_prefix/step_0072001/student \ + --dataset-path /liujinxin/ZZF/openpi/datasets/piper/flatten_fold_v2 \ + --max-eval-examples 64 \ + --overwrite +``` diff --git a/prune_distill/analyze_pi05_sensitivity.py b/prune_distill/analyze_pi05_sensitivity.py new file mode 100644 index 0000000000..6019195bb3 --- /dev/null +++ b/prune_distill/analyze_pi05_sensitivity.py @@ -0,0 +1,518 @@ +from __future__ import annotations + +import csv +import dataclasses +import json +import logging +import os +import pathlib +import re +import shutil +import sys +from collections.abc import Callable +from typing import Any + +REPO_ROOT = pathlib.Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") +os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true") + +from flax import nnx +import jax +import jax.numpy as jnp +import numpy as np +import tyro + +from openpi.models import model as _model +from openpi.shared import nnx_utils +from prune_distill import train_prefix_distill as distill + + +LOGGER = logging.getLogger("pi05_sensitivity") + +TARGET_SUFFIXES = ( + "student/embedder/input_embedding", + "student/layers/attn/q_einsum/w", + "student/layers/attn/kv_einsum/w", + "student/layers/attn/attn_vec_einsum/w", + "student/layers/mlp/gating_einsum", + "student/layers/mlp/linear", + "student/layers/pre_attention_norm/scale", + "student/layers/pre_ffw_norm/scale", + "student/final_norm/scale", +) + + +@dataclasses.dataclass(frozen=True) +class SensitivityConfig: + exp_name: str = "pi05_sensitivity" + teacher_checkpoint: str = "/root/pi_train/pi05_libero/params" + student_checkpoint: str | None = None + output_dir: str = "/root/openpi-wr/checkpoints/prune_distill/sensitivity" + train_config_name: str = "pi05_libero" + dataset_path: str = "/root/flatten_fold_v2" + norm_stats_assets_dir: str = "/root/pi_train/pi05_libero/assets" + norm_stats_asset_id: str = "physical-intelligence/libero" + max_examples: int | None = 2_048 + max_episodes: int | None = 8 + batch_size: int = 2 + num_workers: int = 2 + max_batches: int = 4 + eval_top_k: int = 24 + seed: int = 42 + dtype: str = "bfloat16" + hidden_loss_weight: float = 1.0 + cosine_loss_weight: float = 0.1 + quant_bits: tuple[int, ...] = (8, 4) + prune_ratios: tuple[float, ...] = (0.1, 0.3, 0.5) + overwrite: bool = False + + +@dataclasses.dataclass(frozen=True) +class Candidate: + path: str + key: tuple[Any, ...] 
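+    # For stacked per-layer tensors, layer_idx (the next field) selects one layer slice; None covers the whole tensor.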
+ layer_idx: int | None + num_params: int + + @property + def name(self) -> str: + if self.layer_idx is None: + return self.path + return f"{self.path}:layer_{self.layer_idx:02d}" + + @property + def family(self) -> str: + return self.path.removeprefix("student/") + + +def init_logging() -> None: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%H:%M:%S", + ) + + +def to_distill_config(config: SensitivityConfig) -> distill.DistillConfig: + return distill.DistillConfig( + exp_name=config.exp_name, + teacher_checkpoint=config.teacher_checkpoint, + output_dir=str(pathlib.Path(config.output_dir).parent), + train_config_name=config.train_config_name, + dataset_path=config.dataset_path, + norm_stats_assets_dir=config.norm_stats_assets_dir, + norm_stats_asset_id=config.norm_stats_asset_id, + max_examples=config.max_examples, + max_episodes=config.max_episodes, + batch_size=config.batch_size, + num_workers=config.num_workers, + num_train_steps=1, + log_interval=1, + save_interval=1, + seed=config.seed, + hidden_loss_weight=config.hidden_loss_weight, + cosine_loss_weight=config.cosine_loss_weight, + dtype=config.dtype, + ) + + +def prepare_output_dir(config: SensitivityConfig) -> pathlib.Path: + output_dir = pathlib.Path(config.output_dir) / config.exp_name + if output_dir.exists(): + if not config.overwrite: + raise FileExistsError(f"{output_dir} already exists. Pass --overwrite to replace it.") + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + return output_dir + + +def maybe_load_student_checkpoint(model: distill.PrefixDistillModel, config: SensitivityConfig) -> None: + if config.student_checkpoint is None: + return + student_params = _model.restore_params(config.student_checkpoint, restore_type=np.ndarray) + student_state = nnx.state(model.student) + student_state.replace_by_pure_dict(student_params) + nnx.update(model.student, student_state) + LOGGER.info("Loaded student checkpoint from %s", config.student_checkpoint) + + +def collect_batches(config: SensitivityConfig) -> list[tuple[_model.Observation, jax.Array]]: + distill_config = to_distill_config(config) + train_config = distill.build_data_config(distill_config) + loader = distill.create_distill_data_loader(train_config, distill_config) + batches = [] + batch_iter = iter(loader) + for _ in range(config.max_batches): + try: + batches.append(next(batch_iter)) + except StopIteration: + break + if not batches: + raise ValueError("No batches were loaded for sensitivity analysis.") + LOGGER.info("Collected %d batches for sensitivity scoring.", len(batches)) + return batches + + +def init_model(config: SensitivityConfig) -> distill.PrefixDistillModel: + state = distill.init_state(to_distill_config(config)) + model = nnx.merge(state.model_def, state.params) + maybe_load_student_checkpoint(model, config) + return model + + +def should_analyze_path(path: str) -> bool: + return any(path == suffix for suffix in TARGET_SUFFIXES) + + +def collect_candidates(model: distill.PrefixDistillModel) -> list[Candidate]: + trainable_state = nnx.state(model).filter(distill.TRAINABLE_FILTER) + candidates: list[Candidate] = [] + for key, value in trainable_state.flat_state().items(): + path = "/".join(str(part) for part in key) + if not should_analyze_path(path): + continue + shape = value.value.shape + if "/layers/" in path and shape: + for layer_idx in range(shape[0]): + candidates.append( + Candidate( + path=path, + key=key, + layer_idx=layer_idx, + 
num_params=int(np.prod(shape[1:])), + ) + ) + else: + candidates.append( + Candidate( + path=path, + key=key, + layer_idx=None, + num_params=int(np.prod(shape)), + ) + ) + LOGGER.info("Collected %d candidate tensors/slices.", len(candidates)) + return candidates + + +def batch_rng(config: SensitivityConfig, batch_idx: int) -> jax.Array: + return jax.random.fold_in(jax.random.key(config.seed + 17), batch_idx) + + +def evaluate_model( + model: distill.PrefixDistillModel, + batches: list[tuple[_model.Observation, jax.Array]], + config: SensitivityConfig, +) -> dict[str, float]: + totals = { + "loss": 0.0, + "hidden_mse": 0.0, + "cosine_loss": 0.0, + "valid_tokens": 0.0, + } + for batch_idx, (observation, _) in enumerate(batches): + loss, metrics = model.compute_loss( + batch_rng(config, batch_idx), + observation, + hidden_loss_weight=config.hidden_loss_weight, + cosine_loss_weight=config.cosine_loss_weight, + train=False, + ) + totals["loss"] += float(loss) + totals["hidden_mse"] += float(metrics["hidden_mse"]) + totals["cosine_loss"] += float(metrics["cosine_loss"]) + totals["valid_tokens"] += float(metrics["valid_tokens"]) + count = float(len(batches)) + return {name: value / count for name, value in totals.items()} + + +def compute_candidate_scores( + model: distill.PrefixDistillModel, + candidates: list[Candidate], + batches: list[tuple[_model.Observation, jax.Array]], + config: SensitivityConfig, +) -> dict[str, dict[str, float]]: + scores = { + candidate.name: { + "grad_norm": 0.0, + "taylor_sum": 0.0, + "taylor_mean": 0.0, + "param_norm": 0.0, + } + for candidate in candidates + } + + for batch_idx, (observation, _) in enumerate(batches): + def loss_fn(module: distill.PrefixDistillModel): + return module.compute_loss( + batch_rng(config, batch_idx), + observation, + hidden_loss_weight=config.hidden_loss_weight, + cosine_loss_weight=config.cosine_loss_weight, + train=False, + ) + + diff_state = nnx.DiffState(0, distill.TRAINABLE_FILTER) + (_, _), grads = nnx.value_and_grad(loss_fn, has_aux=True, argnums=diff_state)(model) + param_flat = nnx.state(model).filter(distill.TRAINABLE_FILTER).flat_state() + grad_flat = grads.flat_state() + + for candidate in candidates: + param = param_flat[candidate.key].value.astype(jnp.float32) + grad = grad_flat[candidate.key].value.astype(jnp.float32) + if candidate.layer_idx is not None: + param = param[candidate.layer_idx] + grad = grad[candidate.layer_idx] + grad_norm = float(jnp.linalg.norm(grad)) + taylor = float(jnp.sum(jnp.abs(param * grad))) + taylor_mean = taylor / max(candidate.num_params, 1) + param_norm = float(jnp.linalg.norm(param)) + item = scores[candidate.name] + item["grad_norm"] += grad_norm + item["taylor_sum"] += taylor + item["taylor_mean"] += taylor_mean + item["param_norm"] += param_norm + + count = float(len(batches)) + for item in scores.values(): + for name in ("grad_norm", "taylor_sum", "taylor_mean", "param_norm"): + item[name] /= count + return scores + + +def fake_quantize(x: jax.Array, bits: int) -> jax.Array: + if bits < 2: + raise ValueError(f"bits must be >= 2, got {bits}") + x32 = x.astype(jnp.float32) + max_abs = jnp.max(jnp.abs(x32)) + if float(max_abs) == 0.0: + return x + qmax = (1 << (bits - 1)) - 1 + scale = max_abs / qmax + quantized = jnp.clip(jnp.round(x32 / scale), -qmax, qmax) * scale + return quantized.astype(x.dtype) + + +def magnitude_prune(x: jax.Array, ratio: float) -> jax.Array: + if not 0.0 <= ratio <= 1.0: + raise ValueError(f"prune ratio must be in [0, 1], got {ratio}") + if ratio == 0.0: + return x 
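+    # Full pruning is returned directly; the threshold path below would also zero everything, just more slowly.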
+ if ratio == 1.0: + return jnp.zeros_like(x) + + x32 = x.astype(jnp.float32) + flat_abs = jnp.abs(x32).reshape(-1) + num_pruned = int(round(ratio * flat_abs.size)) + if num_pruned <= 0: + return x + if num_pruned >= flat_abs.size: + return jnp.zeros_like(x) + + threshold = jnp.partition(flat_abs, num_pruned - 1)[num_pruned - 1] + mask = jnp.abs(x32) > threshold + return jnp.where(mask, x32, 0.0).astype(x.dtype) + + +def perturb_candidate( + model: distill.PrefixDistillModel, + candidate: Candidate, + transform: Callable[[jax.Array], jax.Array], +) -> tuple[nnx.State, jax.Array]: + candidate_filter = nnx.All(nnx.Param, nnx_utils.PathRegex(re.escape(candidate.path))) + candidate_state = nnx.state(model, candidate_filter) + flat = candidate_state.flat_state() + if len(flat) != 1: + raise ValueError(f"Expected exactly one tensor for {candidate.path}, found {list(flat)}") + variable = next(iter(flat.values())) + original = variable.value + updated = transform(original if candidate.layer_idx is None else original[candidate.layer_idx]) + if candidate.layer_idx is None: + variable.value = updated + else: + variable.value = original.at[candidate.layer_idx].set(updated) + nnx.update(model, candidate_state) + return candidate_state, original + + +def restore_candidate( + model: distill.PrefixDistillModel, + candidate_state: nnx.State, + original: jax.Array, +) -> None: + variable = next(iter(candidate_state.flat_state().values())) + variable.value = original + nnx.update(model, candidate_state) + + +def evaluate_perturbation( + model: distill.PrefixDistillModel, + candidate: Candidate, + batches: list[tuple[_model.Observation, jax.Array]], + config: SensitivityConfig, + transform: Callable[[jax.Array], jax.Array], +) -> dict[str, float]: + candidate_state, original = perturb_candidate(model, candidate, transform) + try: + return evaluate_model(model, batches, config) + finally: + restore_candidate(model, candidate_state, original) + + +def metric_key(prefix: str, value: int | float, name: str) -> str: + if isinstance(value, int): + return f"{prefix}_{value}_{name}" + return f"{prefix}_{int(round(value * 100)):02d}p_{name}" + + +def summarize_by_family(records: list[dict[str, Any]]) -> list[dict[str, Any]]: + grouped: dict[str, dict[str, float]] = {} + counts: dict[str, int] = {} + for record in records: + family = record["family"] + if family not in grouped: + grouped[family] = {"taylor_sum": 0.0, "grad_norm": 0.0} + counts[family] = 0 + grouped[family]["taylor_sum"] += float(record["taylor_sum"]) + grouped[family]["grad_norm"] += float(record["grad_norm"]) + counts[family] += 1 + + rows = [] + for family, values in grouped.items(): + count = counts[family] + rows.append( + { + "family": family, + "count": count, + "mean_taylor_sum": values["taylor_sum"] / count, + "mean_grad_norm": values["grad_norm"] / count, + } + ) + rows.sort(key=lambda row: row["mean_taylor_sum"], reverse=True) + return rows + + +def write_csv(path: pathlib.Path, rows: list[dict[str, Any]]) -> None: + if not rows: + return + fieldnames = list(rows[0].keys()) + with path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + +def main(config: SensitivityConfig) -> None: + init_logging() + output_dir = prepare_output_dir(config) + (output_dir / "config.json").write_text(json.dumps(dataclasses.asdict(config), indent=2)) + + LOGGER.info("Building pi05 sensitivity model and data.") + model = init_model(config) + batches = collect_batches(config) + 
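+    # The same fixed batch list feeds the baseline, gradient scoring, and every perturbation, keeping metric deltas comparable.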
candidates = collect_candidates(model) + + LOGGER.info("Evaluating baseline distillation metrics.") + baseline = evaluate_model(model, batches, config) + LOGGER.info("Baseline metrics: %s", baseline) + + LOGGER.info("Computing distillation gradient/Taylor sensitivity.") + scores = compute_candidate_scores(model, candidates, batches, config) + + records: list[dict[str, Any]] = [] + for candidate in candidates: + item = scores[candidate.name] + records.append( + { + "name": candidate.name, + "path": candidate.path, + "family": candidate.family, + "layer_idx": candidate.layer_idx, + "num_params": candidate.num_params, + "param_norm": item["param_norm"], + "grad_norm": item["grad_norm"], + "taylor_sum": item["taylor_sum"], + "taylor_mean": item["taylor_mean"], + } + ) + + records.sort(key=lambda row: row["taylor_sum"], reverse=True) + eval_records = records[: config.eval_top_k] + LOGGER.info( + "Running fake-quant and magnitude-prune perturbations on top %d candidates.", + len(eval_records), + ) + + record_map = {record["name"]: record for record in records} + candidate_map = {candidate.name: candidate for candidate in candidates} + for rank, record in enumerate(eval_records, start=1): + candidate = candidate_map[record["name"]] + LOGGER.info( + "[%d/%d] %s taylor_sum=%.4f", + rank, + len(eval_records), + candidate.name, + record["taylor_sum"], + ) + + for bits in config.quant_bits: + perturbed = evaluate_perturbation( + model, + candidate, + batches, + config, + lambda x, bits=bits: fake_quantize(x, bits), + ) + record_map[candidate.name][metric_key("quant", bits, "loss")] = perturbed["loss"] + record_map[candidate.name][metric_key("quant", bits, "loss_delta")] = perturbed["loss"] - baseline["loss"] + record_map[candidate.name][metric_key("quant", bits, "hidden_mse_delta")] = ( + perturbed["hidden_mse"] - baseline["hidden_mse"] + ) + record_map[candidate.name][metric_key("quant", bits, "cosine_delta")] = ( + perturbed["cosine_loss"] - baseline["cosine_loss"] + ) + + for ratio in config.prune_ratios: + perturbed = evaluate_perturbation( + model, + candidate, + batches, + config, + lambda x, ratio=ratio: magnitude_prune(x, ratio), + ) + record_map[candidate.name][metric_key("prune", ratio, "loss")] = perturbed["loss"] + record_map[candidate.name][metric_key("prune", ratio, "loss_delta")] = perturbed["loss"] - baseline["loss"] + record_map[candidate.name][metric_key("prune", ratio, "hidden_mse_delta")] = ( + perturbed["hidden_mse"] - baseline["hidden_mse"] + ) + record_map[candidate.name][metric_key("prune", ratio, "cosine_delta")] = ( + perturbed["cosine_loss"] - baseline["cosine_loss"] + ) + + family_summary = summarize_by_family(records) + + summary = { + "baseline": baseline, + "top_by_taylor_sum": records[: min(20, len(records))], + "family_summary": family_summary[: min(20, len(family_summary))], + } + (output_dir / "summary.json").write_text(json.dumps(summary, indent=2)) + write_csv(output_dir / "candidate_scores.csv", records) + write_csv(output_dir / "family_summary.csv", family_summary) + + LOGGER.info("Top sensitivity candidates:") + for record in records[:10]: + LOGGER.info( + " %s | taylor_sum=%.4f grad_norm=%.4f", + record["name"], + record["taylor_sum"], + record["grad_norm"], + ) + LOGGER.info("Sensitivity analysis written to %s", output_dir) + + +if __name__ == "__main__": + main(tyro.cli(SensitivityConfig)) diff --git a/prune_distill/benchmark_pi05_models.py b/prune_distill/benchmark_pi05_models.py new file mode 100644 index 0000000000..e0ac9f147b --- /dev/null +++ 
b/prune_distill/benchmark_pi05_models.py @@ -0,0 +1,617 @@ +from __future__ import annotations + +import collections +import csv +import dataclasses +import json +import logging +import math +import os +import pathlib +import shutil +import sys +from typing import Any + +REPO_ROOT = pathlib.Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") +os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true") + +from flax import nnx +import jax +import numpy as np +from openpi_client import image_tools +import tyro + +from openpi.models import model as _model +from openpi.policies import policy_config +import openpi.training.config as training_config +import openpi.training.data_loader as data_loader +import openpi.transforms as _transforms +from prune_distill import train_prefix_distill as distill + +import lerobot.common.datasets.lerobot_dataset as lerobot_dataset + + +LOGGER = logging.getLogger("pi05_benchmark") +LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0] +LIBERO_ENV_RESOLUTION = 256 + + +@dataclasses.dataclass(frozen=True) +class BenchmarkConfig: + exp_name: str = "pi05_benchmark" + output_dir: str = "/root/openpi-wr/checkpoints/prune_distill/benchmark" + + train_config_name: str = "pi05_libero" + origin_checkpoint_dir: str | None = "/root/pi_train/pi05_libero" + teacher_checkpoint: str = "/root/pi_train/pi05_libero/params" + pruned_student_checkpoint: str | None = None + + dataset_path: str = "/root/flatten_fold_v2" + norm_stats_assets_dir: str = "/root/pi_train/pi05_libero/assets" + norm_stats_asset_id: str = "physical-intelligence/libero" + max_examples: int | None = 50_000 + max_episodes: int | None = None + max_eval_examples: int = 256 + batch_size: int = 4 + num_workers: int = 2 + + run_dataset_benchmark: bool = True + run_libero_rollout: bool = False + + sample_actions_num_steps: int = 10 + disable_policy_norm_stats: bool = False + action_tolerance: float = 0.05 + + libero_task_suite_names: tuple[str, ...] = ("libero_spatial",) + libero_num_trials_per_task: int = 10 + libero_max_tasks: int | None = None + libero_replan_steps: int = 5 + libero_num_steps_wait: int = 10 + libero_resize_size: int = 224 + libero_video_out_dir: str | None = None + + hidden_loss_weight: float = 1.0 + cosine_loss_weight: float = 0.1 + dtype: str = "bfloat16" + seed: int = 7 + overwrite: bool = False + + +def init_logging() -> None: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%H:%M:%S", + ) + + +def write_csv(path: pathlib.Path, rows: list[dict[str, Any]]) -> None: + if not rows: + return + fieldnames = list(rows[0].keys()) + with path.open("w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + +def _augment_oom_error(exc: Exception, context: str) -> RuntimeError: + err = RuntimeError( + f"{context} failed due to JAX/GPU OOM. " + "Retry with CPU fallback, for example: `JAX_PLATFORMS=cpu .venv/bin/python ...`" + ) + err.__cause__ = exc + return err + + +def prepare_output_dir(config: BenchmarkConfig) -> pathlib.Path: + output_dir = pathlib.Path(config.output_dir) / config.exp_name + if output_dir.exists(): + if not config.overwrite: + raise FileExistsError(f"{output_dir} already exists. 
Pass --overwrite to replace it.") + shutil.rmtree(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + return output_dir + + +def resolve_policy_checkpoint_dir(checkpoint_dir: str) -> pathlib.Path: + path = pathlib.Path(checkpoint_dir).expanduser().resolve() + if path.name == "params" and path.parent.exists(): + return path.parent + if (path / "params").exists(): + return path + raise FileNotFoundError(f"Expected a checkpoint directory containing params/: {checkpoint_dir}") + + +def resolve_teacher_checkpoint(path: str) -> str: + resolved = pathlib.Path(path).expanduser().resolve() + if resolved.name == "params": + return str(resolved) + if (resolved / "params").exists(): + return str((resolved / "params").resolve()) + raise FileNotFoundError(f"Expected teacher checkpoint params at {path}") + + +def resolve_student_checkpoint(path: str) -> pathlib.Path: + resolved = pathlib.Path(path).expanduser().resolve() + if resolved.name == "student": + return resolved + if (resolved / "student").exists(): + return (resolved / "student").resolve() + raise FileNotFoundError(f"Expected pruned student checkpoint at {path}") + + +def to_distill_config(config: BenchmarkConfig) -> distill.DistillConfig: + return distill.DistillConfig( + exp_name=config.exp_name, + teacher_checkpoint=resolve_teacher_checkpoint(config.teacher_checkpoint), + output_dir=str(pathlib.Path(config.output_dir).parent), + train_config_name=config.train_config_name, + dataset_path=config.dataset_path, + norm_stats_assets_dir=config.norm_stats_assets_dir, + norm_stats_asset_id=config.norm_stats_asset_id, + max_examples=config.max_examples, + max_episodes=config.max_episodes, + batch_size=config.batch_size, + num_workers=config.num_workers, + num_train_steps=1, + log_interval=1, + save_interval=1, + seed=config.seed, + hidden_loss_weight=config.hidden_loss_weight, + cosine_loss_weight=config.cosine_loss_weight, + dtype=config.dtype, + ) + + +@dataclasses.dataclass(frozen=True) +class DatasetResources: + train_config: training_config.TrainConfig + data_config: training_config.DataConfig + dataset: Any + dataset_meta: lerobot_dataset.LeRobotDatasetMetadata + + +def create_dataset_resources(config: BenchmarkConfig) -> DatasetResources: + distill_config = to_distill_config(config) + train_config = distill.build_data_config(distill_config) + data_config = train_config.data.create(train_config.assets_dirs, train_config.model) + + dataset_root = pathlib.Path(config.dataset_path).expanduser().resolve() + dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(dataset_root.name, root=dataset_root) + selected_episodes = distill.select_episode_subset(distill_config, dataset_meta) + + dataset = lerobot_dataset.LeRobotDataset( + dataset_root.name, + root=dataset_root, + episodes=selected_episodes, + delta_timestamps={ + key: [t / dataset_meta.fps for t in range(train_config.model.action_horizon)] + for key in data_config.action_sequence_keys + }, + ) + data_config = distill.adapt_data_config_to_dataset(data_config, dataset) + if data_config.prompt_from_task: + dataset = data_loader.TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)]) + + return DatasetResources( + train_config=train_config, + data_config=data_config, + dataset=dataset, + dataset_meta=dataset_meta, + ) + + +def choose_policy_norm_stats( + data_config: training_config.DataConfig, + sample: dict[str, Any], + *, + disable_policy_norm_stats: bool, +) -> dict[str, Any] | None: + if disable_policy_norm_stats: + LOGGER.warning("Policy norm stats 
disabled by config.") + return {} + + norm_stats = data_config.norm_stats + if norm_stats is None: + LOGGER.warning("No norm stats available. Using no-op normalization for policy inference.") + return {} + + state_stats = norm_stats.get("state") + if state_stats is None: + return norm_stats + + sample_state = np.asarray(sample["state"]) + sample_state_dim = int(sample_state.shape[-1]) + stats_state_dim = int(state_stats.mean.shape[-1]) + if sample_state_dim > stats_state_dim: + LOGGER.warning( + "Dataset state dim (%d) exceeds available norm stats dim (%d). Using no-op normalization for policy inference.", + sample_state_dim, + stats_state_dim, + ) + return {} + return norm_stats + + +def _safe_string(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + if hasattr(value, "item"): + try: + return str(value.item()) + except Exception: + pass + return str(value) + + +def align_actions(pred: np.ndarray, target: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + pred = np.asarray(pred, dtype=np.float32) + target = np.asarray(target, dtype=np.float32) + if pred.ndim == 1: + pred = pred[None, :] + if target.ndim == 1: + target = target[None, :] + horizon = min(pred.shape[0], target.shape[0]) + dims = min(pred.shape[1], target.shape[1]) + if horizon <= 0 or dims <= 0: + raise ValueError(f"Cannot align action arrays with shapes {pred.shape} and {target.shape}") + return pred[:horizon, :dims], target[:horizon, :dims] + + +def aggregate_dataset_rows(rows: list[dict[str, Any]]) -> dict[str, Any]: + if not rows: + return { + "examples_evaluated": 0, + } + return { + "examples_evaluated": len(rows), + "avg_first_action_l2": float(np.mean([row["first_action_l2"] for row in rows])), + "avg_chunk_l2": float(np.mean([row["chunk_l2"] for row in rows])), + "avg_chunk_l1": float(np.mean([row["chunk_l1"] for row in rows])), + "first_action_match_rate": float(np.mean([row["first_action_match"] for row in rows])), + "chunk_match_rate": float(np.mean([row["chunk_match"] for row in rows])), + "avg_infer_ms": float(np.mean([row["infer_ms"] for row in rows])), + "avg_compare_horizon": float(np.mean([row["compare_horizon"] for row in rows])), + "avg_compare_dims": float(np.mean([row["compare_dims"] for row in rows])), + } + + +def benchmark_origin_dataset(config: BenchmarkConfig, output_dir: pathlib.Path) -> dict[str, Any] | None: + if config.origin_checkpoint_dir is None: + LOGGER.info("Origin pi05 checkpoint not provided. 
Skipping offline dataset benchmark.") + return None + + resources = create_dataset_resources(config) + if len(resources.dataset) == 0: + raise ValueError("Dataset is empty.") + + sample = resources.dataset[0] + norm_stats = choose_policy_norm_stats( + resources.data_config, + sample, + disable_policy_norm_stats=config.disable_policy_norm_stats, + ) + try: + policy = policy_config.create_trained_policy( + resources.train_config, + resolve_policy_checkpoint_dir(config.origin_checkpoint_dir), + repack_transforms=resources.data_config.repack_transforms, + norm_stats=norm_stats, + sample_kwargs={"num_steps": config.sample_actions_num_steps}, + ) + except Exception as e: + if "RESOURCE_EXHAUSTED" in str(e) or "Out of memory" in str(e): + raise _augment_oom_error(e, "Origin pi05 offline dataset benchmark") + raise + + num_examples = min(len(resources.dataset), config.max_eval_examples) + rows: list[dict[str, Any]] = [] + LOGGER.info("Running offline origin pi05 dataset benchmark on %d examples.", num_examples) + for idx in range(num_examples): + item = resources.dataset[idx] + result = policy.infer(item) + pred_actions, target_actions = align_actions(result["actions"], item["actions"]) + diff = pred_actions - target_actions + first_action_l2 = float(np.linalg.norm(diff[0])) + chunk_l2 = float(np.sqrt(np.mean(np.square(diff)))) + chunk_l1 = float(np.mean(np.abs(diff))) + rows.append( + { + "index": idx, + "task": _safe_string(item.get("task")), + "prompt": _safe_string(item.get("prompt")), + "compare_horizon": int(pred_actions.shape[0]), + "compare_dims": int(pred_actions.shape[1]), + "first_action_l2": first_action_l2, + "chunk_l2": chunk_l2, + "chunk_l1": chunk_l1, + "first_action_match": float(first_action_l2 <= config.action_tolerance), + "chunk_match": float(chunk_l2 <= config.action_tolerance), + "infer_ms": float(result["policy_timing"]["infer_ms"]), + } + ) + + summary = aggregate_dataset_rows(rows) + write_csv(output_dir / "origin_pi05_dataset_metrics.csv", rows) + return summary + + +def maybe_load_pruned_student(model: distill.PrefixDistillModel, checkpoint: pathlib.Path) -> None: + student_params = _model.restore_params(checkpoint, restore_type=np.ndarray) + student_state = nnx.state(model.student) + student_state.replace_by_pure_dict(student_params) + nnx.update(model.student, student_state) + + +def benchmark_pruned_dataset(config: BenchmarkConfig, output_dir: pathlib.Path) -> dict[str, Any] | None: + if config.pruned_student_checkpoint is None: + LOGGER.info("Pruned student checkpoint not provided. 
Skipping pruned offline benchmark.")
+        return None
+
+    distill_config = to_distill_config(config)
+    train_config = distill.build_data_config(distill_config)
+    loader = distill.create_distill_data_loader(train_config, distill_config)
+
+    try:
+        state = distill.init_state(distill_config)
+    except Exception as e:
+        if "RESOURCE_EXHAUSTED" in str(e) or "Out of memory" in str(e):
+            raise _augment_oom_error(e, "Pruned prefix offline benchmark")
+        raise
+    model = nnx.merge(state.model_def, state.params)
+    maybe_load_pruned_student(model, resolve_student_checkpoint(config.pruned_student_checkpoint))
+
+    rows: list[dict[str, Any]] = []
+    max_batches = max(1, math.ceil(config.max_eval_examples / max(config.batch_size, 1)))
+    LOGGER.info("Running pruned prefix offline benchmark on up to %d batches.", max_batches)
+    for batch_idx, (observation, _) in enumerate(loader):
+        if batch_idx >= max_batches:
+            break
+        loss, metrics = model.compute_loss(
+            jax.random.fold_in(jax.random.key(config.seed + 101), batch_idx),
+            observation,
+            hidden_loss_weight=config.hidden_loss_weight,
+            cosine_loss_weight=config.cosine_loss_weight,
+            train=False,
+        )
+        rows.append(
+            {
+                "batch_idx": batch_idx,
+                "loss": float(loss),
+                "hidden_mse": float(metrics["hidden_mse"]),
+                "cosine_loss": float(metrics["cosine_loss"]),
+                "valid_tokens": float(metrics["valid_tokens"]),
+            }
+        )
+
+    if not rows:
+        return {"batches_evaluated": 0}
+
+    summary = {
+        "batches_evaluated": len(rows),
+        "avg_loss": float(np.mean([row["loss"] for row in rows])),
+        "avg_hidden_mse": float(np.mean([row["hidden_mse"] for row in rows])),
+        "avg_cosine_loss": float(np.mean([row["cosine_loss"] for row in rows])),
+        "avg_valid_tokens": float(np.mean([row["valid_tokens"] for row in rows])),
+    }
+    write_csv(output_dir / "pruned_prefix_dataset_metrics.csv", rows)
+    return summary
+
+
+def _quat2axisangle(quat: np.ndarray) -> np.ndarray:
+    if quat[3] > 1.0:
+        quat[3] = 1.0
+    elif quat[3] < -1.0:
+        quat[3] = -1.0
+
+    den = np.sqrt(1.0 - quat[3] * quat[3])
+    if math.isclose(den, 0.0):
+        return np.zeros(3)
+    return (quat[:3] * 2.0 * math.acos(quat[3])) / den
+
+
+def _libero_max_steps(task_suite_name: str) -> int:
+    suite_to_steps = {
+        "libero_spatial": 220,
+        "libero_object": 280,
+        "libero_goal": 300,
+        "libero_10": 520,
+        "libero_90": 400,
+    }
+    if task_suite_name not in suite_to_steps:
+        raise ValueError(f"Unknown LIBERO task suite: {task_suite_name}")
+    return suite_to_steps[task_suite_name]
+
+
+def benchmark_origin_libero(config: BenchmarkConfig, output_dir: pathlib.Path) -> dict[str, Any] | None:
+    if not config.run_libero_rollout:
+        return None
+    if config.origin_checkpoint_dir is None:
+        LOGGER.info("Origin pi05 checkpoint not provided. 
Skipping LIBERO rollout benchmark.") + return None + + try: + import imageio + from libero.libero import benchmark + from libero.libero import get_libero_path + from libero.libero.envs import OffScreenRenderEnv + except ImportError as e: + raise ImportError("LIBERO rollout benchmark requires the LIBERO env dependencies to be installed.") from e + + train_config = training_config.get_config(config.train_config_name) + try: + policy = policy_config.create_trained_policy( + train_config, + resolve_policy_checkpoint_dir(config.origin_checkpoint_dir), + sample_kwargs={"num_steps": config.sample_actions_num_steps}, + ) + except Exception as e: + if "RESOURCE_EXHAUSTED" in str(e) or "Out of memory" in str(e): + raise _augment_oom_error(e, "Origin pi05 LIBERO rollout benchmark") + raise + + suite_summaries: list[dict[str, Any]] = [] + episode_rows: list[dict[str, Any]] = [] + video_dir = pathlib.Path(config.libero_video_out_dir).resolve() if config.libero_video_out_dir else None + if video_dir is not None: + video_dir.mkdir(parents=True, exist_ok=True) + + benchmark_dict = benchmark.get_benchmark_dict() + np.random.seed(config.seed) + for suite_name in config.libero_task_suite_names: + task_suite = benchmark_dict[suite_name]() + num_tasks = task_suite.n_tasks + if config.libero_max_tasks is not None: + num_tasks = min(num_tasks, config.libero_max_tasks) + max_steps = _libero_max_steps(suite_name) + + total_episodes = 0 + total_successes = 0 + for task_id in range(num_tasks): + task = task_suite.get_task(task_id) + initial_states = task_suite.get_task_init_states(task_id) + task_description = task.language + task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file + env = OffScreenRenderEnv( + bddl_file_name=task_bddl_file, + camera_heights=LIBERO_ENV_RESOLUTION, + camera_widths=LIBERO_ENV_RESOLUTION, + ) + env.seed(config.seed) + + task_episodes = 0 + task_successes = 0 + for episode_idx in range(config.libero_num_trials_per_task): + env.reset() + obs = env.set_init_state(initial_states[episode_idx]) + action_plan = collections.deque() + replay_images = [] + done = False + t = 0 + + while t < max_steps + config.libero_num_steps_wait: + if t < config.libero_num_steps_wait: + obs, _, done, _ = env.step(LIBERO_DUMMY_ACTION) + t += 1 + continue + + img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1]) + wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1]) + img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, config.libero_resize_size, config.libero_resize_size)) + wrist_img = image_tools.convert_to_uint8( + image_tools.resize_with_pad(wrist_img, config.libero_resize_size, config.libero_resize_size) + ) + if video_dir is not None: + replay_images.append(img) + + if not action_plan: + element = { + "observation/image": img, + "observation/wrist_image": wrist_img, + "observation/state": np.concatenate( + ( + obs["robot0_eef_pos"], + _quat2axisangle(obs["robot0_eef_quat"].copy()), + obs["robot0_gripper_qpos"], + ) + ), + "prompt": str(task_description), + } + action_chunk = policy.infer(element)["actions"] + if len(action_chunk) < config.libero_replan_steps: + raise ValueError( + f"Policy only predicted {len(action_chunk)} actions, expected at least {config.libero_replan_steps}." 
+ ) + action_plan.extend(action_chunk[: config.libero_replan_steps]) + + action = action_plan.popleft() + obs, _, done, _ = env.step(action.tolist()) + t += 1 + if done: + task_successes += 1 + total_successes += 1 + break + + if video_dir is not None and replay_images: + suffix = "success" if done else "failure" + suite_segment = suite_name.replace(" ", "_") + task_segment = task_description.replace(" ", "_") + imageio.mimwrite( + video_dir / f"{suite_segment}_task{task_id:02d}_ep{episode_idx:03d}_{task_segment}_{suffix}.mp4", + [np.asarray(x) for x in replay_images], + fps=10, + ) + + task_episodes += 1 + total_episodes += 1 + episode_rows.append( + { + "suite": suite_name, + "task_id": task_id, + "task": task_description, + "episode_idx": episode_idx, + "success": float(done), + "steps": t, + } + ) + + suite_summaries.append( + { + "suite": suite_name, + "episodes": total_episodes, + "successes": total_successes, + "success_rate": float(total_successes / total_episodes) if total_episodes else 0.0, + } + ) + + LOGGER.info("LIBERO suite %s success rate: %.4f", suite_name, suite_summaries[-1]["success_rate"]) + + write_csv(output_dir / "origin_pi05_libero_rollout.csv", episode_rows) + return { + "suites": suite_summaries, + } + + +def main(config: BenchmarkConfig) -> None: + init_logging() + output_dir = prepare_output_dir(config) + (output_dir / "config.json").write_text(json.dumps(dataclasses.asdict(config), indent=2)) + + summary: dict[str, Any] = { + "origin_pi05_dataset": None, + "pruned_prefix_dataset": None, + "origin_pi05_libero": None, + "notes": [], + } + + if config.run_dataset_benchmark: + LOGGER.info("Starting offline dataset benchmark.") + summary["origin_pi05_dataset"] = benchmark_origin_dataset(config, output_dir) + summary["pruned_prefix_dataset"] = benchmark_pruned_dataset(config, output_dir) + if config.pruned_student_checkpoint is None: + summary["notes"].append("Pruned prefix checkpoint not provided, so only origin dataset metrics were run.") + else: + summary["notes"].append("Offline dataset benchmark disabled.") + + if config.run_libero_rollout: + LOGGER.info("Starting LIBERO rollout benchmark.") + summary["origin_pi05_libero"] = benchmark_origin_libero(config, output_dir) + if config.pruned_student_checkpoint is not None: + summary["notes"].append( + "Pruned student rollout success is not supported yet because the saved student checkpoint is only a distilled prefix model, not a full pi05 action policy." 
+ ) + else: + summary["notes"].append("LIBERO rollout benchmark disabled.") + + (output_dir / "summary.json").write_text(json.dumps(summary, indent=2)) + LOGGER.info("Benchmark results written to %s", output_dir) + + +if __name__ == "__main__": + main(tyro.cli(BenchmarkConfig)) diff --git a/prune_distill/train_prefix_distill.py b/prune_distill/train_prefix_distill.py index 78d9641d65..1b7e7f7c90 100644 --- a/prune_distill/train_prefix_distill.py +++ b/prune_distill/train_prefix_distill.py @@ -5,7 +5,6 @@ import json import logging import pathlib -import random import shutil from typing import Any @@ -14,6 +13,7 @@ import flax.nnx.bridge as nnx_bridge import jax import jax.numpy as jnp +import lerobot.common.datasets.lerobot_dataset as lerobot_dataset import numpy as np import optax import orbax.checkpoint as ocp @@ -29,6 +29,7 @@ from openpi.shared import nnx_utils import openpi.training.config as training_config import openpi.training.data_loader as data_loader +import openpi.transforms as _transforms LOGGER = logging.getLogger("prune_distill") @@ -62,6 +63,11 @@ class DistillConfig: teacher_checkpoint: str = "/root/pi_train/pi05_libero/params" output_dir: str = "/root/openpi-wr/checkpoints/prune_distill" train_config_name: str = "pi05_libero" + dataset_path: str = "/root/flatten_fold_v2" + norm_stats_assets_dir: str = "/root/pi_train/pi05_libero/assets" + norm_stats_asset_id: str = "physical-intelligence/libero" + max_examples: int | None = 50_000 + max_episodes: int | None = None batch_size: int = 8 num_workers: int = 2 num_train_steps: int = 10_000 @@ -71,7 +77,6 @@ class DistillConfig: learning_rate: float = 1e-4 warmup_steps: int = 200 weight_decay: float = 1e-2 - real_batch_prob: float = 0.8 hidden_loss_weight: float = 1.0 cosine_loss_weight: float = 0.1 dtype: str = "bfloat16" @@ -190,25 +195,6 @@ def compute_loss( } -class MixedBatchIterator: - def __init__(self, real_iter, fake_iter, *, real_batch_prob: float, seed: int, skip_steps: int = 0): - self._real_iter = real_iter - self._fake_iter = fake_iter - self._real_batch_prob = real_batch_prob - self._rng = random.Random(seed) - for _ in range(skip_steps): - self._rng.random() - - def __iter__(self): - return self - - def __next__(self): - use_real = self._rng.random() < self._real_batch_prob - if use_real: - return next(self._real_iter), "libero" - return next(self._fake_iter), "random" - - def init_logging() -> None: logging.basicConfig( level=logging.INFO, @@ -224,8 +210,17 @@ def _repo_root() -> pathlib.Path: def build_data_config(config: DistillConfig) -> training_config.TrainConfig: repo_root = _repo_root() base = training_config.get_config(config.train_config_name) + custom_data = dataclasses.replace( + base.data, + repo_id=config.dataset_path, + assets=training_config.AssetsConfig( + assets_dir=config.norm_stats_assets_dir, + asset_id=config.norm_stats_asset_id, + ), + ) return dataclasses.replace( base, + data=custom_data, batch_size=config.batch_size, num_workers=config.num_workers, assets_base_dir=str(repo_root / "assets"), @@ -233,6 +228,180 @@ def build_data_config(config: DistillConfig) -> training_config.TrainConfig: ) +def _dataset_root(config: DistillConfig) -> pathlib.Path: + return pathlib.Path(config.dataset_path).expanduser().resolve() + + +def _dataset_repo_id(dataset_root: pathlib.Path) -> str: + return dataset_root.name + + +def select_episode_subset( + config: DistillConfig, + dataset_meta: lerobot_dataset.LeRobotDatasetMetadata, +) -> list[int] | None: + if config.max_examples is None and 
config.max_episodes is None: + return None + + selected: list[int] = [] + total_examples = 0 + episodes = sorted(dataset_meta.episodes.items()) + for episode_idx, episode in episodes: + if config.max_episodes is not None and len(selected) >= config.max_episodes: + break + + length = int(episode["length"]) + if config.max_examples is not None and selected and total_examples + length > config.max_examples: + break + + selected.append(int(episode_idx)) + total_examples += length + + if config.max_examples is not None and total_examples >= config.max_examples: + break + + if not selected and episodes: + first_idx, first_episode = episodes[0] + selected = [int(first_idx)] + total_examples = int(first_episode["length"]) + + LOGGER.info( + "Using %d episodes with about %d examples from %s", + len(selected), + total_examples, + _dataset_root(config), + ) + return selected + + +def _make_repack_transform(mapping: dict[str, str]) -> _transforms.Group: + return _transforms.Group(inputs=[_transforms.RepackTransform(mapping)]) + + +def _first_present(keys: set[str], *candidates: str) -> str | None: + for key in candidates: + if key in keys: + return key + return None + + +def adapt_data_config_to_dataset( + data_config: training_config.DataConfig, + dataset: lerobot_dataset.LeRobotDataset, +) -> training_config.DataConfig: + raw_keys = set(dataset.features) + preview_keys = sorted( + key + for key in raw_keys + if key in {"actions", "prompt", "state", "task_index"} or any(token in key for token in ("image", "state", "hand", "head", "camera")) + ) + LOGGER.info("Detected raw dataset keys: %s", preview_keys[:16]) + + base_image_key = _first_present( + raw_keys, + "observation/image", + "observation.image", + "observation/images/base_0_rgb", + "observation.images.base_0_rgb", + "front_head", + "image", + ) + wrist_image_key = _first_present( + raw_keys, + "observation/wrist_image", + "observation.wrist_image", + "observation/images/left_wrist_0_rgb", + "observation.images.left_wrist_0_rgb", + "left_hand", + "right_hand", + "wrist_image", + ) + state_key = _first_present( + raw_keys, + "observation/state", + "observation.state", + "state", + ) + + repack_transforms = data_config.repack_transforms + if ( + base_image_key == "observation/image" + and wrist_image_key == "observation/wrist_image" + and state_key == "observation/state" + ): + LOGGER.info("Dataset already uses observation/* LIBERO keys. Disabling repack transform.") + repack_transforms = _transforms.Group() + elif base_image_key is not None and wrist_image_key is not None and state_key is not None: + LOGGER.info( + "Adapting dataset schema to LIBERO keys: base=%s wrist=%s state=%s", + base_image_key, + wrist_image_key, + state_key, + ) + mapping = { + "observation/image": base_image_key, + "observation/wrist_image": wrist_image_key, + "observation/state": state_key, + "actions": "actions", + } + if "prompt" in raw_keys or data_config.prompt_from_task: + mapping["prompt"] = "prompt" + repack_transforms = _make_repack_transform(mapping) + else: + LOGGER.warning( + "Could not infer a custom dataset image/state schema from keys %s. Keeping the configured repack transform.", + preview_keys[:16], + ) + + prompt_from_task = data_config.prompt_from_task + if "prompt" in raw_keys and prompt_from_task: + LOGGER.info("Dataset already contains a prompt column. 
Disabling prompt_from_task.") + prompt_from_task = False + + return dataclasses.replace( + data_config, + repack_transforms=repack_transforms, + prompt_from_task=prompt_from_task, + ) + + +def create_distill_data_loader( + train_config: training_config.TrainConfig, + distill_config: DistillConfig, +): + data_config = train_config.data.create(train_config.assets_dirs, train_config.model) + dataset_root = _dataset_root(distill_config) + dataset_meta = lerobot_dataset.LeRobotDatasetMetadata(_dataset_repo_id(dataset_root), root=dataset_root) + selected_episodes = select_episode_subset(distill_config, dataset_meta) + + dataset = lerobot_dataset.LeRobotDataset( + _dataset_repo_id(dataset_root), + root=dataset_root, + episodes=selected_episodes, + delta_timestamps={ + key: [t / dataset_meta.fps for t in range(train_config.model.action_horizon)] + for key in data_config.action_sequence_keys + }, + ) + data_config = adapt_data_config_to_dataset(data_config, dataset) + if data_config.prompt_from_task: + dataset = data_loader.TransformedDataset(dataset, [_transforms.PromptFromLeRobotTask(dataset_meta.tasks)]) + + # Prefix distillation only uses images and tokenized prompts, so we can skip dataset-specific + # state/action normalization and avoid mismatches with custom robot state dimensions. + dataset = data_loader.transform_dataset(dataset, data_config, skip_norm_stats=True) + local_batch_size = train_config.batch_size // jax.process_count() + torch_loader = data_loader.TorchDataLoader( + dataset, + local_batch_size=local_batch_size, + shuffle=True, + num_workers=train_config.num_workers, + seed=train_config.seed, + framework="jax", + ) + return data_loader.DataLoaderImpl(data_config, torch_loader) + + def load_checkpoint_params(params_path: str) -> dict[str, Any]: LOGGER.info("Loading teacher checkpoint from %s", params_path) return _model.restore_params(params_path, restore_type=np.ndarray) @@ -404,31 +573,16 @@ def save_student_checkpoint(output_dir: pathlib.Path, student_params: dict[str, def save_resume_checkpoint( output_dir: pathlib.Path, *, - student_params: dict[str, Any], - opt_state: optax.OptState, step: int, - train_rng: jax.Array, ) -> None: - ckpt = output_dir / "resume_state" - if ckpt.exists(): - shutil.rmtree(ckpt) - with ocp.PyTreeCheckpointer() as checkpointer: - checkpointer.save( - ckpt, - { - "student_params": student_params, - "opt_state": opt_state, - "step": np.asarray(step, dtype=np.int32), - "train_rng": np.asarray(train_rng), - }, - ) + (output_dir / "resume_state.json").write_text(json.dumps({"step": step}, indent=2)) -def save_step_checkpoint(output_dir: pathlib.Path, state: DistillState, train_rng: jax.Array) -> None: +def save_step_checkpoint(output_dir: pathlib.Path, state: DistillState) -> None: student_params = extract_student_params(state) step = int(state.step) save_student_checkpoint(output_dir, student_params) - save_resume_checkpoint(output_dir, student_params=student_params, opt_state=state.opt_state, step=step, train_rng=train_rng) + save_resume_checkpoint(output_dir, step=step) def _parse_step_dir(step_dir: pathlib.Path) -> int: @@ -465,44 +619,33 @@ def maybe_resume_state( config: DistillConfig, output_dir: pathlib.Path, state: DistillState, -) -> tuple[DistillState, jax.Array]: +) -> DistillState: if not config.resume: - return state, jax.random.key(config.seed + 1) + return state step_dir = _latest_step_dir(output_dir) if step_dir is None: raise FileNotFoundError(f"No step_* checkpoints found under {output_dir} to resume from.") - resume_dir = step_dir 
/ "resume_state" - if resume_dir.exists(): - with ocp.PyTreeCheckpointer() as checkpointer: - restored = checkpointer.restore(resume_dir) - state.params.replace_by_pure_dict({"student": restored["student_params"]}) - resumed_state = dataclasses.replace( - state, - step=jnp.asarray(restored["step"], dtype=jnp.int32), - opt_state=restored["opt_state"], - ) - train_rng = jnp.asarray(restored["train_rng"]) - LOGGER.info("Resumed exact distill state from %s at step=%d", step_dir, int(resumed_state.step)) - return resumed_state, train_rng - student_params = _model.restore_params(step_dir / "student", restore_type=np.ndarray) - resumed_step = _parse_step_dir(step_dir) + resume_json = step_dir / "resume_state.json" + if resume_json.exists(): + resumed_step = int(json.loads(resume_json.read_text())["step"]) + else: + resumed_step = _parse_step_dir(step_dir) state.params.replace_by_pure_dict({"student": student_params}) resumed_state = dataclasses.replace( state, step=jnp.asarray(resumed_step, dtype=jnp.int32), opt_state=_set_opt_state_step(state.opt_state, resumed_step), ) - train_rng = jax.random.fold_in(jax.random.key(config.seed + 1), resumed_step) LOGGER.warning( - "Resumed from legacy student-only checkpoint %s at step=%d. " + "Resumed from %s at step=%d. " "Student weights were restored, but optimizer moments were reinitialized.", step_dir, resumed_step, ) - return resumed_state, train_rng + return resumed_state def prepare_output_dir(config: DistillConfig) -> pathlib.Path: @@ -530,7 +673,6 @@ def log_tensorboard_scalars( step: int, info: dict[str, float], *, - source: str, key_params: dict[str, float] | None = None, ) -> None: writer.add_scalar("train/loss", info["loss"], step) @@ -539,7 +681,6 @@ def log_tensorboard_scalars( writer.add_scalar("train/grad_norm", info["grad_norm"], step) writer.add_scalar("train/param_norm", info["param_norm"], step) writer.add_scalar("train/valid_tokens", info["valid_tokens"], step) - writer.add_scalar("train/source_is_libero", 1.0 if source == "libero" else 0.0, step) if key_params is not None: for name, value in key_params.items(): writer.add_scalar(f"key_params/{name}", value, step) @@ -552,23 +693,14 @@ def main(config: DistillConfig) -> None: output_dir = prepare_output_dir(config) tensorboard_dir = output_dir / "tensorboard" base_data_config = build_data_config(config) - fake_data_config = dataclasses.replace(base_data_config, data=training_config.FakeDataConfig()) - - real_loader = data_loader.create_data_loader(base_data_config, shuffle=True) - fake_loader = data_loader.create_data_loader(fake_data_config, shuffle=True, skip_norm_stats=True) + train_loader = create_distill_data_loader(base_data_config, config) writer = SummaryWriter(log_dir=str(tensorboard_dir)) try: state = init_state(config) - state, rng = maybe_resume_state(config, output_dir, state) + state = maybe_resume_state(config, output_dir, state) start_step = int(state.step) - batch_iter = MixedBatchIterator( - iter(real_loader), - iter(fake_loader), - real_batch_prob=config.real_batch_prob, - seed=config.seed, - skip_steps=start_step, - ) + batch_iter = iter(train_loader) total_params = count_params(state.params) frozen_params = count_params(state.params, FROZEN_FILTER) trainable_params = count_params(state.params, TRAINABLE_FILTER) @@ -606,23 +738,23 @@ def main(config: DistillConfig) -> None: donate_argnums=(0,), ) + base_rng = jax.random.key(config.seed + 1) for _ in range(start_step, config.num_train_steps): - batch, source = next(batch_iter) - rng, step_rng = jax.random.split(rng) + 
batch = next(batch_iter) + step_rng = jax.random.fold_in(base_rng, int(state.step)) state, info = ptrain_step(state, step_rng, batch) step_num = int(state.step) host_info = jax.tree.map(lambda x: float(x), info) - log_tensorboard_scalars(writer, step_num, host_info, source=source) + log_tensorboard_scalars(writer, step_num, host_info) if step_num % config.log_interval == 0 or step_num == 1: key_params = key_param_stats(state.params) for name, value in key_params.items(): writer.add_scalar(f"key_params/{name}", value, step_num) LOGGER.info( - "step=%d source=%s loss=%.6f hidden_mse=%.6f cosine=%.6f grad_norm=%.6f param_norm=%.6f key_params=%s", + "step=%d loss=%.6f hidden_mse=%.6f cosine=%.6f grad_norm=%.6f param_norm=%.6f key_params=%s", step_num, - source, host_info["loss"], host_info["hidden_mse"], host_info["cosine_loss"], @@ -634,7 +766,7 @@ def main(config: DistillConfig) -> None: if step_num % config.save_interval == 0 or step_num == config.num_train_steps: step_dir = output_dir / f"step_{step_num:07d}" step_dir.mkdir(parents=True, exist_ok=True) - save_step_checkpoint(step_dir, state, rng) + save_step_checkpoint(step_dir, state) writer.flush() final_key_params = key_param_stats(state.params)