diff --git a/.env.example b/.env.example
index 5e6f4c0..8f8cd0b 100644
--- a/.env.example
+++ b/.env.example
@@ -1,42 +1,191 @@
# =============================================================================
-# dllm-jax environment configuration
-# Copy this file to .env and fill in your values.
+# dllm-jax environment configuration — full reference
+#
+# Copy to `.env` (training) or `.env.deploy` (deploy/run scripts) and fill in
+# the values you need. EVERY variable below has a sensible default in the
+# code; uncomment only what you want to override.
+#
+# Sections:
+# 1. Model & tokenizer
+# 2. Dataset
+# 3. Training schedule
+# 4. Learning-rate schedule
+# 5. Optimizer
+# 6. Sharding (TP / FSDP)
+# 7. DMax (block-diffusion masked LM)
+# 8. SFT (chat-template fine-tuning)
+# 9. Performance knobs (splash, remat, mask-embed init)
+# 10. GCS checkpointing
+# 11. Resume from checkpoint
+# 12. HuggingFace Hub (auth + checkpoint mirror)
+# 13. Weights & Biases
+# 14. Profiling (xprof / jax profile)
+# 15. Inference (scripts/tpu_infer.py)
# =============================================================================
-# ---- GCS Checkpoints --------------------------------------------------------
-# Option A: Regional bucket (auto-detected from TPU zone)
-# The script builds the bucket name as: gs://${CHECKPOINT_BUCKET_PREFIX}-${region}
-# e.g. gs://dllm-jax-us-east1 or gs://dllm-jax-europe-west4
-CHECKPOINT_BUCKET_PREFIX=dllm-jax
-# Option B: Fixed (non-regional) bucket name — overrides auto-detection
-# CHECKPOINT_BUCKET=my-bucket-name
+# ── 1. Model & tokenizer ────────────────────────────────────────────────────
+MODEL_NAME=Qwen/Qwen3-8B
+# DTYPE=bfloat16 # compute dtype: bfloat16 | float16 | float32
+# MASK_TOKEN_ID= # default: vocab_size-1 (e.g. 151935 for Qwen3)
+ # try 151662 (<|fim_pad|>) for a pretrained warm start
+# EOS_TOKEN_ID= # default: tokenizer's eos_token_id
-# Option C: Full checkpoint directory — overrides both A and B
-# CHECKPOINT_DIR=gs://my-bucket/checkpoints/my-run
-# Checkpoint frequency and retention
-CHECKPOINT_STEPS=500
-CHECKPOINT_KEEP=2
+# ── 2. Dataset ──────────────────────────────────────────────────────────────
+DATASET=tinystories # tinystories | wikipedia | openthoughts | synthetic | parquet
+# DATASET_PATH= # required when DATASET=parquet (file or dir of *.parquet)
-# ---- Weights & Biases -------------------------------------------------------
-# Enable W&B logging (set to 1 to enable)
-WANDB_LOG=0
-WANDB_PROJECT=dllm-jax
-# WANDB_API_KEY=your-wandb-api-key-here
-# Get your API key: https://wandb.ai/authorize
-# Or run `wandb login` on each TPU worker instead.
-# ---- Model & Training -------------------------------------------------------
-MODEL_NAME=Qwen/Qwen3-8B
-DATASET=tinystories
+# ── 3. Training schedule ────────────────────────────────────────────────────
MAX_LEN=16384
GLOBAL_BATCH=8
-NUM_STEPS=20
+NUM_STEPS=20 # set to 0 to use NUM_EPOCHS instead
NUM_EPOCHS=0
WARMUP_STEPS=5
PEAK_LR=1e-4
+# RUN_NAME= # default: ${MODEL_SLUG}-${DATASET}-${unix_ts}
+# LOGGING_STEPS=1
+# LOAD_PRETRAINED=1 # 0 = random init (debugging only)
+
+
+# ── 4. Learning-rate schedule (post-warmup) ─────────────────────────────────
+# LR_SCHEDULE=constant # constant | cosine
+# LR_DECAY_STEPS=0 # cosine: total decay steps after warmup
+# LR_DECAY_ALPHA=0.1 # cosine: final LR = PEAK_LR * alpha
+
+
+# ── 5. Optimizer ────────────────────────────────────────────────────────────
+OPTIMIZER=adamw # adamw (recommended) | adafactor
+ # NOTE: Adafactor has been observed to diverge
+ # on Qwen3-8B MDLM/DMax around step ~60.
+ # Default to AdamW for real training.
+
+
+# ── 6. Sharding ─────────────────────────────────────────────────────────────
+# TP determines the tensor-parallel axis size; FSDP is derived as
+# fsdp = jax.device_count() // TP
+# v4-32 / v6e-256: TP=8 is fine.
+# v5e-64: prefer TP=2 — TP=8 forces FSDP=8 and bloats optimizer state ~4x.
+TP=8
+
+
+# ── 7. DMax (block-diffusion masked LM) ─────────────────────────────────────
+# DMAX_ENABLE=0 # 1 = train block-diffusion / DMax objective
+# DMAX_BLOCK_SIZE=32 # tokens per noising block
+# DMAX_ON_POLICY_RATIO=0.5 # fraction of steps using on-policy noise
+# DMAX_NOISE_LOW=0.75
+# DMAX_NOISE_HIGH=0.75
+
+
+# ── 8. SFT (chat-template fine-tuning) ──────────────────────────────────────
+# Used when DATASET=openthoughts (or other chat-templated SFT dataset).
+# SFT_TRAIN_ON_ANSWERS_ONLY=0 # 0 = train on full chat-templated text (default,
+ # packs like pretraining)
+ # 1 = mask prompt tokens with -100 (per-example
+ # padded batches)
+
+
+# ── 9. Performance knobs ────────────────────────────────────────────────────
+# SPLASH_BLOCK=512 # splash_attention tile size for training
+# SPLASH_FUSED_BWD=1 # use fused backward kernel
+# DISABLE_SPLASH_ATTN=0 # 1 = fall back to dense attention (debugging)
+# REMAT_POLICY=nothing_saveable # nothing_saveable | dots_saveable | minimal_checkpoint | None
+# # see scripts/tpu_train.py for the full list
+
+# Mask-token embedding warm start (DMax). On v5e-64 the .at[].set() leaks
+# ~9.5GB of HBM when run pre-shard on replicated weights → use 'none' there.
+# MASK_EMBED_INIT=mean # mean | none
+
+# PEAK_TFLOPS_PER_CHIP=918 # v6e default; v4=275, v5e=197, v5p=459
+
+
+# ── 10. GCS checkpointing ───────────────────────────────────────────────────
+# Pick ONE of A / B / C below.
+# A. Regional bucket (auto-detected from TPU zone): bucket name becomes
+# gs://${CHECKPOINT_BUCKET_PREFIX}-${region}
+CHECKPOINT_BUCKET_PREFIX=dllm-jax
+# B. Fixed bucket — overrides A.
+# CHECKPOINT_BUCKET=my-bucket-name
+# C. Full directory — overrides A and B.
+# CHECKPOINT_DIR=gs://my-bucket/checkpoints/my-run
+
+# CHECKPOINT_SUBDIR=checkpoints # path under bucket
+CHECKPOINT_STEPS=500 # save every N optimizer steps
+# CHECKPOINT_SECONDS=0 # also save every N seconds (0 = disabled)
+CHECKPOINT_KEEP=2 # how many recent checkpoints to retain
+# CHECKPOINT_ON_FINISH=0 # 1 = always write a final checkpoint
+# CHECKPOINT_USE_GCS=1 # 0 = force local even if CHECKPOINT_BUCKET set
+# LOCAL_CHECKPOINT_DIR=/tmp/dllm-jax-checkpoints # used when GCS not configured
+
+# Orbax internals (rarely changed):
+# CHECKPOINT_ORBAX_SYNC_DIRS=1
+# CHECKPOINT_ORBAX_SIGNAL_FALLBACK=1
+
+
+# ── 11. Resume from checkpoint ──────────────────────────────────────────────
+# RESUME_DIR=gs://your-bucket/checkpoints/old-run
+# RESUME_STEP=0 # 0 = latest step in RESUME_DIR
+# RESUME_RESTORE_OPTIMIZER=1 # 0 = restore weights only (e.g. pretrain → SFT)
+# RESUME_RESET_STEP=0 # 1 = zero global_step/epoch counters after load
+ # (use when continuing onto a new dataset
+ # so NUM_STEPS/NUM_EPOCHS budget restarts)
+
+
+# ── 12. HuggingFace Hub ─────────────────────────────────────────────────────
+# HF_TOKEN=hf_... # for private model downloads / checkpoint upload
+# HF_CHECKPOINT_REPO=username/repo
+# HF_CHECKPOINT_REPO_TYPE=model # model | dataset
+# HF_CHECKPOINT_PATH=checkpoints # path within the HF repo
+# HF_CHECKPOINT_PRIVATE=1
+
+
+# ── 13. Weights & Biases ────────────────────────────────────────────────────
+WANDB_LOG=1 # always enable on training runs (incl. smokes)
+WANDB_PROJECT=dllm-jax
+# WANDB_RUN_NAME= # default: same as RUN_NAME
+# WANDB_MODE=online # online | offline | disabled
+# WANDB_API_KEY= # or run `wandb login` once on each TPU worker.
+ # Get a key: https://wandb.ai/authorize
+ # NEVER commit a real key. Keep it in .env.deploy
+ # (which is gitignored) or in ~/.netrc.
+
+
+# ── 14. Profiling ───────────────────────────────────────────────────────────
+# XPROF_ENABLE=0 # capture xprof traces
+# XPROF_DIR=/tmp/xprof/run
+# XPROF_START_STEP=4
+# XPROF_STOP_STEP=7
+
+# JAX programmatic profile (separate from xprof):
+# JAX_PROFILE_DIR=
+# JAX_PROFILE_START_STEP=0
+# JAX_PROFILE_STEPS=0
+
+
+# ── 15. Inference (scripts/tpu_infer.py) ────────────────────────────────────
+# Required for inference: RESUME_DIR (see section 11). RESUME_STEP is optional; 0 = latest step.
+# PROMPT="Once upon a time"
+# PROMPTS_FILE= # path to file with one prompt per line
+ # (overrides PROMPT; '#' lines = comments)
+# GEN_LENGTH=32
+# BLOCK_LENGTH=32
+# STEPS=8 # max denoising steps per block
+# THRESHOLD=0.95 # per-token confidence to accept early
+# CONFIDENCE_STOP=0.9 # block-level early-exit threshold
+# TOP_K=1 # soft-mix top-K for SPD (3 is good for Qwen3)
+# TEMPERATURE=0.0 # 0.0 = greedy
+# TEMPS= # comma list runs a sweep in one process,
+ # e.g. "0.0,0.3,0.5,0.7,1.0"
+# SEED= # int; defaults to time-based
+
+# INFER_IMPL=fast # fast | kv_fast | legacy
+# INFER_SPLASH=0 # 1 = splash kernel for fast-path block-causal mask
+# INFER_SPLASH_BLOCK=512
+# INFER_KV_DTYPE=bf16 # bf16 (default, ~2x cache HBM savings) | fp32
+# FAST_BUCKET_LENGTH=4096
-# ---- HuggingFace Hub (optional checkpoint upload) ---------------------------
-# HF_TOKEN=your-hf-token-here
-# HF_CHECKPOINT_REPO=your-username/your-repo
+# WARMUP_RUNS=0 # throwaway generates before measured runs
+# MEASURED_RUNS=1 # report median over N timed runs
+# RESTORE_OPTIMIZER=0 # 1 = restore optimizer state too (rarely useful for inference)
+# SUPPRESS_MASK_TOKEN=0 # 1 = force-disable mask-token logits at argmax
diff --git a/.gitignore b/.gitignore
index 8986761..cf3ec24 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+.env*
+!.env.example
*.lock
__pycache__/
*.egg-info/
diff --git a/README.md b/README.md
index 0cbd3a9..8b16483 100644
--- a/README.md
+++ b/README.md
@@ -2,195 +2,173 @@
+A pure-JAX stack for **diffusion language modeling** on TPU.
-A standalone **JAX backend for Diffusion Language Modeling (dLLM)** with **zero
-PyTorch or CUDA dependency**. Designed for TPU training of MDLM, BD3LM, Dream,
-and EditFlow objectives on pretrained HuggingFace checkpoints.
+No PyTorch, no CUDA. Loads pretrained HuggingFace checkpoints (Qwen3, Llama,
+…) directly into Flax NNX, then trains and decodes them under five
+diffusion objectives — **MDLM**, **BD3LM**, **Dream**, **DMax/OPUT**,
+**EditFlow** — with multi-host TPU sharding, GCS-resumable Orbax
+checkpoints, and a KV-cached fast inference path.
-## Features
+> 📝 **Background reading:**
+> [dLLM into TPU: An End-to-End Diffusion LM Stack in Pure JAX](https://medium.com/@JunbumLee/dllm-into-tpu-an-end-to-end-diffusion-lm-stack-in-pure-jax-5fc33c840ebb)
-- **Torch-free weight loading** — `safetensors` + `huggingface_hub` + numpy, no `torch.load`
-- **TPU-first** — Pallas flash attention via `shard_map`, 1D FSDP and 2D FSDP+TP
-- **Pretrained init** — load Qwen3, Llama, and other HF causal LMs directly into Flax NNX
-- **Five training objectives** — MDLM, BD3LM, Dream, DMax/OPUT, EditFlow
-- **Clean API** — public exports, no stub boilerplate
-- **DMax / OPUT end-to-end** — port of [`czg1225/DMax`](https://github.com/czg1225/DMax)
- training + Soft Parallel Decoding inference, with a KV-cached fast path
- matching reference's `cache='prefix'` setting
-`transformers` is used only for `AutoConfig` / `AutoTokenizer` (works without torch).
+## Concepts
-## Blog Post
+A **diffusion language model (dLLM)** trains a transformer to *denoise*
+masked sequences in parallel, instead of predicting tokens one-at-a-time.
+At inference it can refine many positions per step, generating text in a
+fraction of the autoregressive forward passes.
-[dLLM into TPU: An End-to-End Diffusion LM Stack in Pure JAX](https://medium.com/@JunbumLee/dllm-into-tpu-an-end-to-end-diffusion-lm-stack-in-pure-jax-5fc33c840ebb)
+Five training objectives ship in this repo:
-## Prerequisites
+| Objective | What it does | Trainer / Generator |
+|---|---|---|
+| **MDLM** (Sahoo et al.) | Mask-and-denoise with a configurable α schedule (linear / cosine) | `MDLMTrainer` |
+| **BD3LM** | Block-diffusion variant of MDLM | `BD3LMTrainer` |
+| **Dream** | SFT-friendly diffusion with prefix conditioning | `DreamTrainer` |
+| **DMax / OPUT** | High-noise, on-policy block-diffusion + Soft Parallel Decoding | `DMaxTrainer`, `dmax_generate_spd_*` |
+| **EditFlow** | Edit-style flow matching | `EditFlowTrainer` |
-### Google Cloud CLI (`gcloud`)
+The polish in this repo concentrates on **DMax** — training, inference,
+and checkpointing are all wired up for multi-host TPU. The other four
+objectives share the same model loading and sharding infrastructure.
-```bash
-# Install gcloud CLI (see https://cloud.google.com/sdk/docs/install)
-# After installation:
-gcloud auth login
-gcloud config set project YOUR_PROJECT_ID
-# Create a GCS bucket for checkpoints (pick your TPU region):
-gcloud storage buckets create gs://YOUR_BUCKET_NAME --location=us-east1
+## Installation
+
+```bash
+pip install -e .
+pip install -e '.[tpu]' # TPU runtime
+pip install -e '.[dev]' # pytest, etc.
```
-### Environment setup
+Requires Python ≥ 3.10. Validated end-to-end on TPU v4-32 / v5e-64 / v6e-256.
-Copy the example env file and fill in your values:
+For multi-host TPU runs you also need `gcloud` CLI and a regional GCS
+bucket (default `gs://dllm-jax-${region}`). For TPU VM packaging quirks
+and the verified pin set, see [`docs/install.md`](docs/install.md).
+
+
+## Configuration
+
+Every script reads its config from environment variables. The single source
+of truth is [`.env.example`](.env.example) — 15 sections covering model,
+data, training, LR schedule, optimizer, sharding, DMax, SFT, perf,
+checkpointing, resume, HuggingFace, W&B, profiling, and inference.
```bash
cp .env.example .env
+$EDITOR .env
```
-Key variables in `.env`:
+Five variables to know on day one:
-| Variable | Description | Default |
-|----------|-------------|---------|
-| `CHECKPOINT_BUCKET_PREFIX` | Prefix for regional buckets (`gs://{prefix}-{region}`) | `dllm-jax` |
-| `CHECKPOINT_BUCKET` | Fixed bucket name (overrides regional auto-detection) | — |
-| `WANDB_API_KEY` | Weights & Biases API key ([get one here](https://wandb.ai/authorize)) | — |
-| `WANDB_LOG` | Enable W&B logging (`1` to enable) | `0` |
-| `MODEL_NAME` | HuggingFace model ID | `Qwen/Qwen3-8B` |
+| Variable | Meaning | Default |
+|---|---|---|
+| `MODEL_NAME` | HF model ID for tokenizer + config + initial weights | `Qwen/Qwen3-8B` |
+| `DATASET` | `tinystories`, `wikipedia`, `openthoughts`, `synthetic`, `parquet` | `tinystories` |
+| `RUN_NAME` | name written under `${CHECKPOINT_DIR}/${RUN_NAME}` and to W&B | auto |
+| `WANDB_LOG` | `1` to stream loss/MFU to W&B | `0` (`.env.example` ships `1`) |
+| `TP` | tensor-parallel axis size | `8` |
-See [`.env.example`](.env.example) for the full list.
+> 🔐 Never commit `WANDB_API_KEY` or `HF_TOKEN`. `.gitignore` excludes
+> `.env`, `.env.local`, and `.env.deploy`. You can also `wandb login` and
+> `huggingface-cli login` once per worker and skip the env vars.
-> **Tip:** Instead of setting `WANDB_API_KEY` in `.env`, you can run `wandb login`
-> on each TPU worker. Never commit credentials to the repository.
-## Installation
+## Train
+
+The production training entry point is
+[`scripts/tpu_train.py`](scripts/tpu_train.py). It runs on **v4-32, v5e-64,
+v6e-256** and handles distributed init, 2D sharding, GCS DCP checkpoints,
+W&B, MFU logging, and DMax masking.
+
+A first run, end-to-end, from any TPU worker:
```bash
-pip install -e .
+gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=all \
+ --command="cd ~/dllm-jax && \
+ PYTHONPATH=~/dllm-jax \
+ RUN_NAME=hello-dmax \
+ MODEL_NAME=Qwen/Qwen3-0.6B \
+ DATASET=tinystories \
+ DMAX_ENABLE=1 \
+ MAX_LEN=4096 GLOBAL_BATCH=8 \
+ NUM_STEPS=20 WARMUP_STEPS=5 \
+ PEAK_LR=1e-4 OPTIMIZER=adamw \
+ WANDB_LOG=1 \
+ python3 scripts/tpu_train.py"
+```
-# TPU runtime
-pip install -e '.[tpu]'
+In the first ~3 minutes you'll see HF download → weight sharding →
+optimizer init → W&B URL → step-by-step loss lines. Successful exit:
-# Dev (pytest)
-pip install -e '.[dev]'
```
+[Worker 0] step=20 loss=… mfu=…
+[Worker 0] training complete in …s
+```
+
+For the full reference — datasets, DMax/OPUT knobs, optimizer choice (and
+why **AdamW**, not Adafactor), LR schedule, mask-token warm start, memory
+tuning, checkpointing, three resume modes, sharding tables per hardware,
+and gotchas — see [`docs/training.md`](docs/training.md). Throughput
+methodology is in [`docs/mfu-optimization.md`](docs/mfu-optimization.md);
+runs land at ~38–48% MFU after splash + remat tuning.
-Requires Python >= 3.10, `jax >= 0.4.20`, `flax >= 0.10.0`,
-`orbax-checkpoint`, `gcsfs <= 2026.2.0`, `optax >= 0.2.0`.
-### TPU VM packaging caveat
+## Infer
-Some Python 3.10 TPU VM images have an older packaging stack. On those hosts,
-`pip install -e '.[tpu]'` can fail with a missing `build_editable` hook, and
-`pip install '.[tpu]'` can misread the project metadata as `UNKNOWN-0.0.0`
-without installing dependencies. In that case, run from the synced checkout
-with `PYTHONPATH` and install TPU dependencies explicitly:
+The multi-host inference entry point is
+[`scripts/tpu_infer.py`](scripts/tpu_infer.py). It restores a sharded DCP
+checkpoint, re-shards onto the inference mesh (which can differ from the
+training mesh), and runs Soft Parallel Decoding (SPD).
```bash
-python3 -m pip install --user -U 'jax[tpu]' \
- -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
- 'flax>=0.10.0,<0.11' orbax-checkpoint 'gcsfs<=2026.2.0' 'fsspec<=2026.2.0' \
- 'optax>=0.2.0' numpy 'transformers>=4.40.0' safetensors \
- huggingface_hub datasets wandb
+gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=all \
+ --command="cd ~/dllm-jax && \
+ PYTHONPATH=~/dllm-jax \
+ RESUME_DIR=gs://dllm-jax-us-east1/checkpoints/$RUN_NAME \
+ MODEL_NAME=Qwen/Qwen3-8B \
+ PROMPT='Once upon a time' \
+ GEN_LENGTH=1024 BLOCK_LENGTH=32 STEPS=32 \
+ INFER_IMPL=fast INFER_SPLASH=1 \
+ TOP_K=3 THRESHOLD=0.5 CONFIDENCE_STOP=0.9 \
+ TP=8 \
+ python3 scripts/tpu_infer.py"
```
-#### Verified TPU versions
-
-The training and inference paths are validated on TPU v4-32 (`us-central2-b`)
-with this stack — pin to these if `pip install '.[tpu]'` surfaces version
-drift:
-
-| Package | Version |
-|---------|---------|
-| Python | 3.10.12 |
-| jax / jaxlib | 0.6.2 |
-| libtpu | 0.0.17 |
-| flax | 0.10.7 |
-| optax | 0.2.8 |
-| orbax-checkpoint | 0.11.34 |
-| transformers | 5.5.3 |
-| safetensors | 0.7.0 |
-| datasets | 4.8.4 |
-| gcsfs / fsspec | 2025.3.2 |
-| huggingface_hub | 1.10.1 |
-| numpy | 2.2.6 |
-
-The 0.10.7 flax version forces the `_nnx_list = getattr(nnx, "List", list)`
-compat shim noted under **Gotchas**; newer flax (0.12+) on JAX 0.7+ should
-also work but hasn't been re-verified end-to-end on this repo.
-
-### Regional GCS checkpoints for TPU runs
-
-`scripts/tpu_v6e_smoke.py` saves sharded Orbax DCP checkpoints every
-`CHECKPOINT_STEPS` steps (default: 500). By default it detects the TPU
-zone, derives the region, and writes to a matching bucket named
-`gs://${CHECKPOINT_BUCKET_PREFIX}-${region}`. For example, a TPU in
-`us-east1-d` writes under `gs://dllm-jax-us-east1/checkpoints/${RUN_NAME}`.
+Three implementations are available — `fast` (default; fixed-shape,
+splash-friendly), `kv_fast` (KV-cached, best for ≥1024-token gens),
+`legacy` (debug-only Python loop). All three produce byte-identical output
+at matching settings.
-```bash
-PYTHONPATH=/path/to/dllm-jax \
- RUN_NAME=my-run-name \
- CHECKPOINT_STEPS=500 CHECKPOINT_KEEP=2 \
- MODEL_NAME=Qwen/Qwen3-8B DATASET=tinystories \
- MAX_LEN=16384 GLOBAL_BATCH=8 \
- NUM_STEPS=0 NUM_EPOCHS=3 WANDB_LOG=1 \
- python3 scripts/tpu_v6e_smoke.py
-```
+`INFER_SPLASH=1` swaps in a Pallas splash kernel matched to the
+block-causal mask. On Qwen3-8B at `GEN_LENGTH=1024` it's **3.5× faster**
+*and* fixes a latent dense-kernel quality bug at non-128-aligned sequence
+lengths.
-Override detection with `TPU_REGION=us-east1`, `CHECKPOINT_BUCKET=gs://...`,
-or a full `CHECKPOINT_DIR=gs://...`. Hub uploads are only used for local
-checkpoint directories; `gs://` is the durable distributed checkpoint target.
-On TPU VM images with JAX 0.6.x and Orbax 0.11.x, the script enables
-`CHECKPOINT_ORBAX_SYNC_DIRS=1` and `CHECKPOINT_ORBAX_SIGNAL_FALLBACK=1` by
-default to use JAX multihost barriers for distributed GCS checkpoint writes.
+For the full reference — implementation comparison, splash setup, prompt
+sweeps, env var glossary, throughput tables, and the single-host CLI —
+see [`docs/inference.md`](docs/inference.md). The deep-dive writeup with
+end-to-end timing is [`docs/inference-optimization.md`](docs/inference-optimization.md).
-## Quick Start
+
+## Python API (single host)
+
+For sanity checks on a laptop or single TPU host:
```python
from transformers import AutoTokenizer
-from dllm_jax import build_model_from_pretrained, MDLMConfig, MDLMTrainer, LinearAlphaScheduler
+from dllm_jax import (
+ build_model_from_pretrained,
+ DMaxConfig, DMaxTrainer, DMaxDataCollator,
+)
model, config = build_model_from_pretrained("Qwen/Qwen3-0.6B", task="llada")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
-trainer = MDLMTrainer(
- model=model,
- tokenizer=tokenizer,
- args=MDLMConfig(output_dir="./out", max_steps=1000, learning_rate=1e-4),
- train_dataset=dataset,
- data_collator=collator,
- scheduler=LinearAlphaScheduler(),
-)
-trainer.train()
-```
-
-## DMax / OPUT
-
-`dllm_jax` includes a JAX/Flax port of the DMax training and inference
-stack from [`czg1225/DMax`](https://github.com/czg1225/DMax):
-
-- **OPUT training** — fixed high-noise masking, two-stream `[noised | clean]`
- layout with block-diffusion attention, per-step on-policy rollout that
- replaces masked tokens with the model's own greedy predictions, gradient
- accumulation support.
-- **Soft Parallel Decoding (SPD) inference** with three implementations
- that produce byte-exact identical outputs at matching settings:
- - `dmax_generate_spd` — Python-loop reference path with host-side early
- breaks. Slow on TPU, useful for debugging.
- - `dmax_generate_spd_fast` — fixed-shape fori_loop compiled path; step-level
- and block-level early breaks via `jax.lax.while_loop`. Default for
- short/medium generations.
- - `dmax_generate_spd_kv_fast` — KV-cached variant matching reference's
- `cache='prefix'` path. Each step only projects K/V for the active block
- and attention runs over the cached prefix. ~1.6× speedup at 1024-token
- generation; overhead dominates at 128-token generation.
-- **Reference knobs** — `top_k` (soft-mix top-k, reference default 1),
- `temperature` + gumbel sampling, `seed`, `threshold`, `confidence_stop`,
- `suppress_mask_token`, and post-EOS fill matching reference's `early_stop`.
-
-### Training
-
-```python
-from dllm_jax import DMaxConfig, DMaxTrainer, DMaxDataCollator
-
trainer = DMaxTrainer(
model=model,
tokenizer=tokenizer,
@@ -198,241 +176,53 @@ trainer = DMaxTrainer(
output_dir="./out-dmax",
max_steps=1000,
learning_rate=2e-6,
- noise_range_low=0.75,
- noise_range_high=0.75,
- on_policy_ratio=0.5,
block_size=32,
- gradient_accumulation_steps=4, # optional
+ on_policy_ratio=0.5,
+ gradient_accumulation_steps=4,
),
- train_dataset=dataset,
+ train_dataset=dataset, # any HF datasets-style iterable
data_collator=DMaxDataCollator(tokenizer=tokenizer, label_pad_token_id=-100),
)
trainer.train()
```
-### Inference
-
-```python
-from dllm_jax import dmax_generate_spd_fast, dmax_generate_spd_kv_fast
-
-# Fast path (default for short/medium gen)
-output = dmax_generate_spd_fast(
- model,
- input_ids,
- tokenizer=tokenizer,
- gen_length=512,
- block_length=32,
- steps=32,
- threshold=0.5, # reference math eval default
- confidence_stop=0.9, # reference block-level break
- top_k=3, # soft mix aggregates top 3 candidates
- temperature=0.0, # 0.0 = greedy; >0.0 = gumbel sampling
-)
-print(tokenizer.decode(output.generated_tokens[0], skip_special_tokens=True))
-
-# KV-cache path (wins on long generations)
-output = dmax_generate_spd_kv_fast(
- model, input_ids, tokenizer=tokenizer,
- gen_length=2048, block_length=32, steps=32,
- threshold=0.5, top_k=3,
-)
-```
-
-`output.nfe` counts actual forward passes. For `fast` it matches
-reference's `num_forwards`; for `kv_fast` it is `fast_nfe + num_active_blocks`
-(the extra is the post-block hard-write pass that replaces soft K/V with
-hard K/V in the cache — reference's cross-block update).
-
-### CLI entry points
-
-```bash
-# Train
-python scripts/dmax_train.py \
- --model Qwen/Qwen3-0.6B \
- --dataset Zigeng/DMax-LLaDA-2.0-Mini-Math-Trajectories \
- --max-steps 1000
-
-# Generate (pretrained base model)
-python scripts/dmax_generate.py \
- --model Qwen/Qwen3-0.6B \
- --prompt "Solve 37 * 48." \
- --gen-length 256 --block-length 32 --steps 32 \
- --threshold 0.5 --top-k 3 --impl fast
-
-# Generate (from a saved trainer checkpoint)
-python scripts/dmax_generate_checkpoint.py \
- --checkpoint-dir ./out-dmax/checkpoint-1000 \
- --prompt "Solve 37 * 48." \
- --gen-length 256 --impl kv_fast
-```
-
-### TPU multi-host inference
-
-[`scripts/tpu_dmax_infer_checkpoint.py`](scripts/tpu_dmax_infer_checkpoint.py)
-restores a distributed Orbax DCP checkpoint (GCS or local) on every worker,
-reshards across the inference mesh (which may differ from the training
-mesh), and runs `dmax_generate_spd_{fast,kv_fast,legacy}` end-to-end. All
-configuration is via environment variables:
-
-```bash
-gcloud compute tpus tpu-vm ssh $TPU_NAME \
- --zone=$ZONE --worker=all \
- --command="cd ~/dllm-jax && \
- RESUME_DIR=gs://$BUCKET/checkpoints/$RUN_NAME \
- MODEL_NAME=Qwen/Qwen3-8B \
- PROMPT='Once upon a time' \
- GEN_LENGTH=1024 BLOCK_LENGTH=32 STEPS=32 \
- INFER_IMPL=kv_fast TOP_K=3 THRESHOLD=0.5 CONFIDENCE_STOP=0.9 \
- TP=8 \
- python3 scripts/tpu_dmax_infer_checkpoint.py"
-```
-
-Omit `RESUME_STEP` and the script scans `commit_success.txt` files under
-`RESUME_DIR` and picks the latest committed step. Set `RESUME_STEP=` to
-pin a specific step.
+This is the same code path that
+[`scripts/dmax_train.py`](scripts/dmax_train.py) drives. Swap
+`DMaxConfig` / `DMaxTrainer` for `MDLMConfig` / `MDLMTrainer` (or another
+objective's config/trainer pair) to run that objective on a single host.
-#### Environment variables
-| Variable | Meaning | Default |
-|-|-|-|
-| `RESUME_DIR` / `RESUME_FROM` | checkpoint parent directory (contains `checkpoint_/`) | — (required) |
-| `RESUME_STEP` | specific step to restore | latest committed |
-| `MODEL_NAME` | tokenizer + HF config source | `Qwen/Qwen3-8B` |
-| `PROMPT` | input prompt | `Once upon a time` |
-| `GEN_LENGTH` | tokens to generate | `32` |
-| `BLOCK_LENGTH` | DMax block size | `32` |
-| `STEPS` | max denoising steps per block | `8` |
-| `INFER_IMPL` | `fast`, `kv_fast`, or `legacy` | `fast` |
-| `FAST_BUCKET_LENGTH` | compile window for `fast`; ignored by `kv_fast` | `4096` |
-| `THRESHOLD` | left-to-right confidence cutoff — reference math eval `0.5`, code `0.65`, other benchmarks `0.9`–`0.95` | `0.95` |
-| `CONFIDENCE_STOP` | block early-exit confidence | `0.9` |
-| `TOP_K` | soft-mix top-k (reference default `1`; `3` is more coherent on undertrained ckpts) | `1` |
-| `TEMPERATURE` | gumbel-max sampling temperature; `0.0` = greedy | `0.0` |
-| `SEED` | RNG seed (only needed when `TEMPERATURE > 0`) | none |
-| `SUPPRESS_MASK_TOKEN` | set `1` to force-disable mask-token logits during argmax | `0` |
-| `MASK_TOKEN_ID` / `EOS_TOKEN_ID` | overrides for the model's mask / EOS id | tokenizer default |
-| `TP` | tensor-parallel axis size; `fsdp` is derived as `jax.device_count() // TP` | `8` |
-| `RESTORE_OPTIMIZER` | restore optimizer state (useful for resumed training, not inference) | `0` |
-
-#### Measured throughput (TPU v4-32, Qwen3-8B, `TOP_K=3`)
-
-| Config | nfe | generate_seconds |
-|-|-|-|
-| `fast`, GEN=128 | 114 | 32.1 |
-| `kv_fast`, GEN=128 | 119 | 65.9 |
-| `fast`, GEN=1024 | 1010 | 124.7 |
-| `kv_fast`, GEN=1024 | 1043 | 76.6 |
-
-At short generations `kv_fast` pays more XLA compile cost than it saves; at
-`GEN=1024` the block-local forward wins (~1.6× faster than `fast`). `nfe`
-for `kv_fast` is `fast_nfe + num_active_blocks`; the extra forwards are the
-post-block hard-write that overwrites soft K/V with hard K/V in the cache
-(the JAX analogue of reference's cross-block update).
-
-#### Troubleshooting
-
-- **Stale TPU checkout** — `ImportError: cannot import name 'dmax_generate_spd_kv_fast'` or similar means the TPU copy is out of date; re-`scp` `dllm_jax/` and the target script.
-- **Orbax signal-contract hang on restore** — handled by default; the script flips `CHECKPOINT_ORBAX_SIGNAL_FALLBACK=1` when the JAX distributed client isn't initialized. Set the env var to `0` to disable the shim explicitly.
-- **`gs://` / TensorFlow warnings** — harmless; restore uses Orbax's GCS client, not `tf.io.gfile`.
-- **First-run generation is "slow"** — includes model restore + XLA compile. Subsequent runs with identical shapes hit the cached graph.
-- **Output collapses into punctuation with `TOP_K=1`** — SPD feeds the previous step's distribution back as a soft mix; at `TOP_K=1` a low-confidence single token gives a bad signal. Try `TOP_K=3`.
-
-The KV-cache design is documented in [`docs/kv_cache_design.md`](docs/kv_cache_design.md).
-A narrative write-up of the whole end-to-end port — HF checkpoint →
-OPUT training → KV-cached SPD on TPU — is in
-[`docs/porting-dmax-to-tpu.md`](docs/porting-dmax-to-tpu.md).
-
-## Checkpoints
-
-Two save/load paths are supported depending on scale:
-
-- **Single-host trainer checkpoints** (`DMaxTrainer.save_model`, etc.) write
- pickle files (`save_only_model=True`) or Flax training `Checkpoints` to a
- local directory. These pair with `restore_model_checkpoint` and the local
- `scripts/dmax_generate_checkpoint.py` script.
-- **Distributed TPU DCP checkpoints** (Orbax `PyTreeCheckpointHandler` /
- `StandardCheckpointer`) are the durable save path for multi-host v4/v5/v6
- training via `scripts/tpu_v6e_smoke.py`. Checkpoints are sharded across
- workers, written directly to GCS (`gs://${CHECKPOINT_BUCKET_PREFIX}-${region}`),
- and committed with `commit_success.txt` markers. Resume and inference read
- the latest committed step via `latest_committed_gcs_step`.
-
- ```bash
- # Train with DCP checkpoints every 500 steps, keep last 2
- PYTHONPATH=$(pwd) \
- RUN_NAME=my-run CHECKPOINT_STEPS=500 CHECKPOINT_KEEP=2 \
- python3 scripts/tpu_v6e_smoke.py
- ```
-
- Inference from the resulting GCS checkpoints is covered above under
- [TPU multi-host inference](#tpu-multi-host-inference).
-
-## Package Layout
+## Package layout
```
dllm_jax/
-├── models.py # GenericDecoderLM (+ call_cached for KV cache), GenericEncoderLM, EditFlowModel
-├── trainers.py # MDLMTrainer, BD3LMTrainer, DreamTrainer, DMaxTrainer, EditFlowTrainer
-│ # (gradient accumulation, cached LR schedule)
-├── configs.py # ModelArguments, DataArguments, TrainingArguments, MDLMConfig, DMaxConfig, ...
-├── dmax.py # DMax SPD: dmax_generate_spd, dmax_generate_spd_fast, dmax_generate_spd_kv_fast
+├── models.py # GenericDecoderLM (+ call_cached for KV cache),
+│ # GenericEncoderLM, EditFlowModel
+├── trainers.py # MDLM / BD3LM / Dream / DMax / EditFlow trainers
+├── configs.py # ModelArguments, DataArguments, TrainingArguments,
+│ # MDLMConfig, DMaxConfig, …
+├── dmax.py # DMax SPD: dmax_generate_spd[_fast|_kv_fast]
├── schedulers.py # LinearAlpha, CosineAlpha, LinearKappa, CubicKappa, CosineKappa
-├── data.py # DMaxDataCollator, DreamSFTCollator, EditFlowCollator, NoAttentionMaskWrapper, ...
-├── checkpoints.py # restore_model_checkpoint (single-host pickle / Flax Checkpoints)
-├── weights.py # Torch-free safetensors -> NNX weight loader
-└── utils.py # resolve_with_base_env, parse_spec, get_default_logger, ...
+├── data.py # DMaxDataCollator, DreamSFTCollator, EditFlowCollator, …
+├── checkpoints.py # restore_model_checkpoint (single-host)
+├── weights.py # Torch-free safetensors → NNX weight loader
+└── utils.py
scripts/
-├── dmax_train.py # single-host DMax OPUT CLI
-├── dmax_tinystories_train.py # small-scale DMax sanity training
-├── dmax_generate.py # single-host SPD generation CLI (base model)
-├── dmax_generate_checkpoint.py # single-host SPD generation from a saved checkpoint
-├── tpu_dmax_infer_checkpoint.py # multi-host Orbax DCP restore + SPD generation
-├── tpu_v4_32_train_3epoch.py # multi-host DMax training on TPU v4-32
-├── run_tpu_v4_32_3epoch.sh # wrapper for the v4-32 training launch
-└── tpu_v6e_smoke.py # MDLM smoke trainer with DCP checkpointing
+├── tpu_train.py # multi-host training (v4 / v5e / v6e)
+├── tpu_infer.py # multi-host inference from GCS DCP ckpt
+├── dmax_train.py # single-host DMax CLI
+├── dmax_generate.py # single-host SPD generation (base model)
+├── dmax_generate_checkpoint.py # single-host SPD from saved ckpt
+├── inspect_sft_data.py # SFT pipeline data dump
+└── examples/ # tinystories sanity, legacy v4-32 wrappers
docs/
-├── tpu_v4_32_ondemand_inference.md # runbook for TPU inference from GCS checkpoints
-├── kv_cache_design.md # design notes for the KV-cached SPD path
-└── porting-dmax-to-tpu.md # narrative write-up of the whole end-to-end port
-```
-
-## Sharding
-
-| Model Size | Strategy | Mesh Shape | Notes |
-|------------|-------------|----------------------|-------------------------------------------|
-| <= 3B | 1D FSDP | `(ndev,)` fsdp | Direct TPU init, `P("fsdp", None)` |
-| 3B - 8B+ | 2D FSDP+TP | `(ndev/tp, tp)` fsdp,tp | CPU init → shard, `P("fsdp", "tp")` |
-
-For example, TPU v4-32 with `TP=8` uses `(2, 8)`; TPU v5e-64 uses `(8, 8)`.
-Large models additionally need CPU-first init
-(`jax.default_device(jax.devices("cpu")[0])`), gradient checkpointing via
-`jax.remat` on each transformer layer, and Pallas flash attention.
-
-## Gotchas
-
-**AdamW > Adafactor for pretrained-init MDLM.** Adafactor's
-`scale_by_param_block_rms` misbehaves on bidirectional objectives over causal
-LM weights — loss descends then climbs back after ~60 steps. Use
-`optax.adamw(b1=0.9, b2=0.95, wd=0.01)` (the library default).
-
-**Optimizer rebind after split/merge.** If you manually `nnx.split` and
-`nnx.merge` the optimizer in a hand-written TPU script, rebind `model` to
-avoid silent zero-progress:
-
-```python
-opt_gdef, opt_state = nnx.split(optimizer)
-opt_state = jax.tree.map(fsdp_sharding, opt_state)
-optimizer = nnx.merge(opt_gdef, opt_state)
-model = optimizer.model # CRITICAL
+├── install.md, training.md, inference.md # user-facing references
+├── inference-optimization.md, mfu-optimization.md, kv-cache-design.md
+ # deep-dive writeups
```
-The built-in `trainers.py` never splits the optimizer, so it is safe.
-
-**flax 0.10.7 vs 0.12+.** Python 3.10 TPU VMs pin flax to 0.10.7. `models.py`
-uses a `_nnx_list = getattr(nnx, "List", list)` compat shim so the same code
-runs on both. `optimizer.update(grads)` is called positionally (0.10.7 API).
## License
diff --git a/dllm_jax/dmax.py b/dllm_jax/dmax.py
index 5347490..77bdc4b 100644
--- a/dllm_jax/dmax.py
+++ b/dllm_jax/dmax.py
@@ -2,6 +2,7 @@
from __future__ import annotations
+import os
from dataclasses import dataclass
from typing import Any
@@ -12,6 +13,174 @@
from dllm_jax.trainers import resolve_mask_token_id
+def _resolve_kv_cache_dtype(model: Any):
+ """Pick the KV-cache dtype for kv_fast inference.
+
+ Default: the model's compute dtype (typically bf16) — halves cache HBM
+ relative to float32 and keeps ``jax.nn.dot_product_attention`` in bf16.
+ ``INFER_KV_DTYPE=fp32`` restores the previous float32 cache (matches the
+ qk_norm-upcast precision path; more precise, at 2× the cache memory).
+ """
+ choice = os.environ.get("INFER_KV_DTYPE", "bf16").strip().lower()
+ if choice in {"fp32", "float32", "f32"}:
+ return jnp.float32
+ compute_name = getattr(model, "dtype_name", "bfloat16")
+ if compute_name == "float32":
+ return jnp.float32
+ if compute_name == "float16":
+ return jnp.float16
+ return jnp.bfloat16
+
+
+def _infer_mesh_from_model(model: Any):
+ """Return the NamedSharding mesh used by ``model`` params, or None.
+
+ Walks ``nnx.state(model)`` leaves, unwrapping nnx.Variable objects to their
+ underlying jax.Array so ``.sharding.mesh`` is accessible.
+ """
+ try:
+ state = nnx.state(model)
+ except Exception:
+ state = model
+
+ for leaf in jax.tree_util.tree_leaves(state):
+ value = getattr(leaf, "value", leaf)
+ if isinstance(value, jax.Array):
+ sharding = getattr(value, "sharding", None)
+ mesh = getattr(sharding, "mesh", None) if sharding is not None else None
+ if mesh is not None and getattr(mesh, "shape", None):
+ return mesh
+ return None
+
+
+def install_block_causal_splash_for_inference(
+ mesh: Any,
+ num_heads: int,
+ total_length: int,
+ block_length: int,
+ *,
+ splash_block: int | None = None,
+):
+ """Register a splash_attention kernel as ``dllm_jax.models._MASKED_FLASH_ATTN_FN``.
+
+ The kernel targets the block-causal mask shape produced by
+ ``create_block_causal_attention_mask(total_length, block_length)`` — i.e.
+ block N attends fully to all blocks [0, N]. Intended for the fixed-shape
+ ``dmax_generate_spd_fast`` path; no backward pass is compiled.
+
+ Mesh is expected to have axis names ``fsdp`` and ``tp`` (matches the
+ training and inference scripts in this repo); heads shard along ``tp``.
+ """
+ import numpy as np
+ from jax.experimental.pallas.ops.tpu.splash_attention import (
+ splash_attention_mask as _sm,
+ splash_attention_kernel as _sk,
+ )
+ from jax.experimental.shard_map import shard_map
+ from jax.sharding import PartitionSpec as P
+
+ from dllm_jax import models as _models
+
+ tp = int(mesh.shape.get("tp", 1)) if hasattr(mesh, "shape") else 1
+ heads_per_tp = max(1, num_heads // tp)
+
+ idx = np.arange(total_length)
+ block_q = idx[:, None] // block_length
+ block_kv = idx[None, :] // block_length
+ mask_np = (block_q >= block_kv).astype(np.bool_)
+
+ splash_mask = _sm.MultiHeadMask(masks=[_sm.NumpyMask(mask_np)] * heads_per_tp)
+ requested_bs = int(splash_block) if splash_block is not None else int(
+ os.environ.get("INFER_SPLASH_BLOCK", "512")
+ )
+ # Splash on TPU requires block_q/block_kv both divide total_length AND be
+ # multiples of 128 (MXU lane width). Find the largest such value ≤ request.
+ NUM_LANES = 128
+ bs = requested_bs - (requested_bs % NUM_LANES) if requested_bs >= NUM_LANES else NUM_LANES
+ while bs > NUM_LANES and (total_length % bs != 0 or bs % NUM_LANES != 0):
+ bs -= NUM_LANES
+ if bs < NUM_LANES or total_length % bs != 0:
+ raise ValueError(
+ f"splash install: no block size ≤ {requested_bs} that is a multiple "
+ f"of {NUM_LANES} and divides total_length={total_length}"
+ )
+ if bs != requested_bs:
+ print(
+ f"[dmax] INFER_SPLASH_BLOCK={requested_bs} → using block={bs} "
+ f"(largest multiple of {NUM_LANES} dividing total_length={total_length})",
+ flush=True,
+ )
+ block_sizes = _sk.BlockSizes(
+ block_q=bs, block_kv=bs, block_kv_compute=bs,
+ )
+ splash_fn = _sk.make_splash_mha_single_device(
+ mask=splash_mask, block_sizes=block_sizes,
+ )
+
+ def _per_shard(q, k, v, sm_scale):
+ q_scaled = (q * sm_scale).astype(q.dtype)
+ return jax.vmap(splash_fn)(q_scaled, k, v)
+
+ # Inference typically runs batch=1; sharding the batch axis across ``fsdp``
+ # would require batch >= fsdp. Replicate on batch, shard heads on ``tp``.
+ _batch_spec = None
+ in_specs = (P(_batch_spec, "tp", None, None),) * 3
+ out_specs = P(_batch_spec, "tp", None, None)
+
+ def _sharded_masked_flash(q, k, v, sm_scale):
+ return shard_map(
+ lambda q_, k_, v_: _per_shard(q_, k_, v_, sm_scale),
+ mesh=mesh,
+ in_specs=in_specs,
+ out_specs=out_specs,
+ check_rep=False,
+ )(q, k, v)
+
+ _models._MASKED_FLASH_ATTN_FN = _sharded_masked_flash
+ print(
+ f"[dmax] splash installed for block-causal inference: "
+ f"mask {mask_np.shape} heads_per_tp={heads_per_tp} tp={tp} block={bs}",
+ flush=True,
+ )
+ return _sharded_masked_flash
+
+
+def _env_true(name: str) -> bool:
+ return os.environ.get(name, "").strip().lower() in {"1", "true", "yes", "on"}
+
+
+def _try_autoinstall_block_causal_splash(model: Any, total_length: int, block_length: int):
+ """Gated by ``INFER_SPLASH=1``. Installs splash if a mesh can be inferred.
+
+ Returns the installed kernel on success, None on skip/failure. A failure
+ leaves the existing ``_MASKED_FLASH_ATTN_FN`` untouched so the caller
+ silently falls back to dense attention.
+ """
+ if not _env_true("INFER_SPLASH"):
+ return None
+ try:
+ mesh = _infer_mesh_from_model(model)
+ if mesh is None:
+ print(
+ "[dmax] INFER_SPLASH=1 but no mesh could be inferred from model; "
+ "falling back to dense attention",
+ flush=True,
+ )
+ return None
+ num_heads = int(model.spec.num_attention_heads)
+ return install_block_causal_splash_for_inference(
+ mesh, num_heads, total_length, block_length,
+ )
+ except Exception as exc: # pragma: no cover — best-effort runtime install
+ import traceback
+ print(
+ f"[dmax] INFER_SPLASH=1 requested but splash install failed, "
+ f"falling back to dense attention: {exc}\n{traceback.format_exc()}",
+ flush=True,
+ )
+ return None
+
+
@dataclass
class DMaxGenerationConfig:
gen_length: int = 2048
@@ -282,6 +451,13 @@ def dmax_generate_spd_fast(
gen_length = int(gen_length)
batch_size, prompt_length = input_ids.shape
num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
+ # Splash attention wants total_length % 128 == 0 (TPU MXU lane alignment).
+ # With block_length=32 that means num_blocks must be a multiple of 4. Pad
+ # up when INFER_SPLASH=1 — the extra blocks add compute but are correct.
+ if _env_true("INFER_SPLASH"):
+ align_blocks = max(1, 128 // max(1, block_length))
+ if align_blocks > 1 and num_blocks % align_blocks != 0:
+ num_blocks += align_blocks - (num_blocks % align_blocks)
total_length = num_blocks * block_length
new_gen_length = total_length - prompt_length
prefill_blocks = prompt_length // block_length
@@ -538,6 +714,11 @@ def step_cond(step_carry):
attention_mask = create_block_causal_attention_mask(total_length, block_length)
position_ids = jnp.broadcast_to(jnp.arange(total_length)[None, :], (batch_size, total_length))
+ # Opt-in splash kernel for the dense block-causal attention that this path
+ # normally runs through ``jax.nn.dot_product_attention``. See
+ # ``install_block_causal_splash_for_inference`` and the INFER_SPLASH env.
+ _try_autoinstall_block_causal_splash(model, total_length, block_length)
+
@nnx.jit
def generate_fixed_shape(current_model, prompt_ids):
x = jnp.full(
@@ -1098,13 +1279,11 @@ def generate_kv(current_model, prompt_ids):
mask_embedding.astype(jnp.float32), axis=-1, keepdims=True
)
- # K/V cache stored in float32 to match the QK-norm precision path that
- # the non-cached ``fast`` attention uses: for models with ``qk_norm``
- # (e.g. Qwen3) ``nnx.RMSNorm`` upcasts to float32, and
- # ``jax.nn.dot_product_attention`` requires Q, K, V to all share dtype.
- # V is upcast from the model's compute dtype (bf16) to float32 when
- # written, which is exact.
- cache_dtype = jnp.float32
+ # KV cache dtype: default = model compute dtype (bf16) for ~2× HBM
+ # savings and bf16 attention throughput. Set INFER_KV_DTYPE=fp32 to
+ # fall back to float32 — matches the pre-bf16 precision path (qk_norm
+ # upcast to f32 preserved through attention).
+ cache_dtype = _resolve_kv_cache_dtype(current_model)
past_kv = [
(
diff --git a/dllm_jax/models.py b/dllm_jax/models.py
index 7d69cb2..84fbb83 100644
--- a/dllm_jax/models.py
+++ b/dllm_jax/models.py
@@ -11,6 +11,7 @@
import jax
import jax.numpy as jnp
from flax import nnx
+from jax.ad_checkpoint import checkpoint_name
import transformers
# Compatibility: flax >= 0.11 has nnx.List; older versions use plain list()
@@ -20,6 +21,11 @@
# Signature: _FLASH_ATTN_FN(q, k, v, sm_scale) with layout (batch, heads, seq, dim).
_FLASH_ATTN_FN = None
+# Masked flash hook: set by training script for block-diffusion / other sparse masks.
+# Signature: _MASKED_FLASH_ATTN_FN(q, k, v, sm_scale) with same layout — mask baked in.
+# Used when attention_mask is passed (so dense attention would otherwise run).
+_MASKED_FLASH_ATTN_FN = None
+
def get_dtype(name: str):
if name == "float32":
@@ -290,7 +296,9 @@ def __init__(self, spec: ModelSpec, intermediate_size: int, *, gated: bool, rngs
def __call__(self, hidden_states):
if self.gated:
- return self.down_proj(self.act(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
+ gate_up = self.act(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
+ gate_up = checkpoint_name(gate_up, "gate_up")
+ return self.down_proj(gate_up)
return self.fc2(self.act(self.fc1(hidden_states)))
@@ -355,6 +363,9 @@ def _project_qkv(self, hidden_states, position_ids):
k = self.k_norm(k)
cos, sin = build_rope(position_ids, self.rotary_dim, self.rope_theta, q.dtype)
q, k = apply_rope(q, k, cos, sin, self.rotary_dim)
+ q = checkpoint_name(q, "q")
+ k = checkpoint_name(k, "k")
+ v = checkpoint_name(v, "v")
return q, k, v
def _attention(self, q, k, v, attention_mask):
@@ -369,6 +380,11 @@ def _attention(self, q, k, v, attention_mask):
q.transpose(0, 2, 1, 3), k.transpose(0, 2, 1, 3), v.transpose(0, 2, 1, 3),
1.0 / math.sqrt(self.head_dim),
).transpose(0, 2, 1, 3).reshape(batch_size, query_len, self.num_heads * self.head_dim)
+ elif _MASKED_FLASH_ATTN_FN is not None and attention_mask is not None:
+ output = _MASKED_FLASH_ATTN_FN(
+ q.transpose(0, 2, 1, 3), k.transpose(0, 2, 1, 3), v.transpose(0, 2, 1, 3),
+ 1.0 / math.sqrt(self.head_dim),
+ ).transpose(0, 2, 1, 3).reshape(batch_size, query_len, self.num_heads * self.head_dim)
else:
mask = expand_attention_mask(attention_mask, batch_size, query_len, key_len)
output = jax.nn.dot_product_attention(q, k, v, mask=mask).reshape(
diff --git a/dllm_jax/weights.py b/dllm_jax/weights.py
index 69e7c96..0cde3a8 100644
--- a/dllm_jax/weights.py
+++ b/dllm_jax/weights.py
@@ -36,16 +36,17 @@ def _hf_download(model_name: str, verbose: bool = True) -> str:
print(f"[Worker {proc}] Running: {' '.join(cmd)}", flush=True)
result = subprocess.run(cmd, capture_output=True, text=True, env=env)
- if result.returncode != 0:
- if verbose and proc == 0:
- print(f"[Worker {proc}] hf CLI failed (rc={result.returncode}), falling back to snapshot_download", flush=True)
- print(f"[Worker {proc}] stderr: {result.stderr[:200]}", flush=True)
- from huggingface_hub import snapshot_download
- return snapshot_download(
- model_name,
- allow_patterns=["*.safetensors", "*.safetensors.index.json", "config.json"],
- )
- return result.stdout.strip().split("\n")[-1]
+ # Modern ``hf download`` (huggingface_hub>=0.25) decorates the final path
+ # line with a ``path:`` label and can emit progress lines; prefer the
+ # library call to get a clean local_dir path back.
+ from huggingface_hub import snapshot_download
+ if result.returncode != 0 and verbose and proc == 0:
+ print(f"[Worker {proc}] hf CLI failed (rc={result.returncode})", flush=True)
+ print(f"[Worker {proc}] stderr: {result.stderr[:200]}", flush=True)
+ return snapshot_download(
+ model_name,
+ allow_patterns=["*.safetensors", "*.safetensors.index.json", "config.json"],
+ )
def load_pretrained_weights(
diff --git a/docs/inference-optimization.md b/docs/inference-optimization.md
new file mode 100644
index 0000000..e9a8e82
--- /dev/null
+++ b/docs/inference-optimization.md
@@ -0,0 +1,207 @@
+# Inference Optimization: DMax + Qwen3-8B on TPU v5e-64
+
+Summary of an inference tuning pass on `tpu-v5e-64-us` (us-central1-a)
+for `scripts/tpu_infer.py` generating from a
+tinystories-trained DMax checkpoint
+(`dmax-8b-tinystories-forward8k-3epoch-.../checkpoint_41400`).
+
+## Headline
+
+| Config | impl | attn | KV cache | median s | tok/s | nfe | gen tokens | quality |
+|--|--|--|--|--|--|--|--|--|
+| A (pre-patch) | fast | dense | — | 142.33 | 7.2 | 1028 | 1024 | **GARBAGE** |
+| **B (ship)** | **fast** | **splash** | — | **40.86** | **25.1** | 128 | 175 | ✅ coherent, EOS |
+| C | kv_fast | dense (fp32) | fp32 | 52.90 | 19.4 | 95 | 143 | ✅ coherent, EOS |
+| D | kv_fast | dense (fp32) | bf16 | 52.12 | 19.6 | 95 | 143 | ✅ bit-identical to C |
+
+All runs at `Qwen3-8B`, `PROMPT='Once upon a time'`, `GEN_LENGTH=1024`,
+`BLOCK_LENGTH=32`, `STEPS=32`, `TOP_K=3`, `THRESHOLD=0.5`,
+`CONFIDENCE_STOP=0.9`, `TP=8`, greedy (`TEMPERATURE=0`), 1 warmup + 2
+measured runs, median. `tok/s` is `GEN_LENGTH` / median s, not derived from the `gen tokens` column.
+
+**fast+splash (B) is 3.48× faster than the dense-fast baseline AND
+fixes a latent quality bug.** It's the recommended default for this
+model family at these shapes.
+
+**bf16 KV cache (C→D) is a wash: 1.5% faster, byte-identical output.**
+At `B=1 × q=32 × k=1152` the attention isn't bandwidth-bound so the
+cache dtype doesn't matter. The flag is still exposed (`INFER_KV_DTYPE`)
+so you can fall back to fp32 if you need the extra precision.
+
+## The latent bug in fast+dense (Config A)
+
+The `fast` path under `jax.nn.dot_product_attention` with a
+block-causal mask at `total_length=1056` (prompt=4 + gen=1024 rounded
+up to 33×32 blocks) produces degenerate output:
+
+> …Lily was happy to help and she went to the laundry room.
+> When she got there, she saw a big pile of clothes. She started to
+> fold them and put them away. She was very careful and did a
+> her....,,,,,, and,,,,,,,,,,,,,,,,,,,,,,...... *(900+ more tokens of
+> commas and periods, runs until max_nfe is exhausted)*
+
+First ~80 tokens look normal; then the generation collapses into
+punctuation filler and never reaches EOS, running to `nfe=1024`. The
+same model, same prompt, same hyperparams under splash attention
+produces a clean TinyStory. `kv_fast` (C, D) also produces clean
+output. So the bug is specific to the
+`fast + dot_product_attention + total_length=1056` combination.
+
+Likely cause: numerical reduction order in the dense kernel at that
+non-128-aligned sequence length pushes the intermediate denoising
+trajectory into a repetition attractor. Splash's block-sparse kernel
+uses different reduction order AND my integration pads `num_blocks`
+so `total_length=1152` (128-aligned), so splash dodges the trap both
+numerically and by shape.
+
+## Code changes (all on branch `splash_attn`)
+
+### `dllm_jax/dmax.py`
+
+- `install_block_causal_splash_for_inference(mesh, num_heads, total_length, block_length, splash_block=None)`
+ — builds `NumpyMask(block_q >= block_kv)`, wraps as `MultiHeadMask`,
+ constructs a forward-only `make_splash_mha_single_device`, and
+ registers it as `dllm_jax.models._MASKED_FLASH_ATTN_FN` under a
+ `shard_map(mesh, in_specs=(P(None, 'tp', None, None),)*3)`.
+ Auto-picks the largest splash tile ≤ `INFER_SPLASH_BLOCK` that is a
+ multiple of 128 and divides `total_length`.
+- `_infer_mesh_from_model(model)` — walks `nnx.state(model)` leaves and
+ unwraps `nnx.Variable` objects to read `.sharding.mesh`. The naive
+ `jax.tree_util.tree_leaves(model)` returns Variable wrappers, not
+ raw arrays, so the mesh lookup needs this extra hop.
+- `_try_autoinstall_block_causal_splash(model, total_length, block_length)`
+ — gated by `INFER_SPLASH=1`. Called from `dmax_generate_spd_fast`
+ right before the `@nnx.jit` trace.
+- `dmax_generate_spd_fast`: when `INFER_SPLASH=1`, pads `num_blocks`
+ up to a multiple of `128 // block_length` so `total_length % 128 == 0`
+ (TPU MXU lane alignment; assumes `block_length` divides 128). A few
+ extra blocks of compute, correct output.
+- `_resolve_kv_cache_dtype(model)` for `dmax_generate_spd_kv_fast`.
+ Default changed from `jnp.float32` to the model's compute dtype
+ (bf16). `INFER_KV_DTYPE=fp32` reverts.
+
+### `dllm_jax/models.py`
+
+- `checkpoint_name` markers on `q`/`k`/`v` (post-RoPE) in
+ `SelfAttention._project_qkv` and on the `gate*up` product in
+ `DenseMLP.__call__`. Inert unless a named remat policy selects them
+ (see `REMAT_POLICY` in `tpu_train.py`) — added for the training
+ path but harmless for inference.
+
+### `scripts/tpu_infer.py`
+
+- `WARMUP_RUNS` + `MEASURED_RUNS` env vars, median reporting.
+- `TEMPS="0.0,0.3,0.5,..."` env var runs a temperature sweep in a
+ single process (shares the restore cost).
+
+## Env knobs (on `scripts/tpu_infer.py`)
+
+| Var | Default | Effect |
+|--|--|--|
+| `INFER_IMPL` | `fast` | `fast` / `kv_fast` / `legacy`. Use `fast` with splash for best throughput. |
+| `INFER_SPLASH` | `0` | `1` enables splash kernel on the `fast` block-causal mask. Requires sharded model. |
+| `INFER_SPLASH_BLOCK` | `512` | Splash tile size. Auto-adjusts to the largest multiple of 128 that divides `total_length`. |
+| `INFER_KV_DTYPE` | `bf16` | KV-cache dtype for `kv_fast`. `fp32` restores the pre-patch precision path (bf16 showed no measurable speed gain over it at these shapes). |
+| `WARMUP_RUNS` | `0` | Throwaway generates before the measured ones. Each pays compile; helps stabilize measurements across seeds. |
+| `MEASURED_RUNS` | `1` | Number of timed generates; reports median. |
+| `TEMPS` | — | Comma-separated temperatures for a single-process sweep (e.g. `0.0,0.5,1.0`). |
+
+## Temperature sweep (fast+splash, seed=42)
+
+| temp | median s | tok/s | nfe | gen | quality |
+|--|--|--|--|--|--|
+| **0.0 (greedy)** | 41.28 | **24.8** | 128 | 175 | ✅ full coherent story, EOS |
+| 0.3 | 163.19 | 6.3 | 1130 | 1024 | ~80 good tokens, then repeating `, the the to..` |
+| 0.5 | 165.22 | 6.2 | 1146 | 1024 | ~80 good tokens, then repeating `the with. They...` |
+| 0.7 | 164.83 | 6.2 | 1144 | 1024 | ~110 good tokens, then repeating `lesson heron with and..` |
+| 1.0 | 76.70 | 13.4 | 416 | 382 | ~80 good tokens, then gibberish; random EOS mid-way |
+| 1.5 | 166.04 | 6.2 | 1152 | 1024 | ~20 good tokens, then multilingual soup (`官方微信荤`, `훙`, `эту`, `ซึ่งเป็น`) |
+
+### Findings
+
+1. **Only greedy (temp=0) is usable** for this checkpoint on this
+ model. Any stochastic temperature collapses the block-diffusion
+ denoising into repetition after 20–110 tokens.
+2. The `nfe` column reflects whether the block-level convergence
+ checks (`all_confident ≥ 0.9`, `same_as_previous`) fire. Greedy
+ converges fast (≈4 steps/block avg); stochastic sampling defeats
+ both checks and runs until the `nfe` budget is exhausted.
+3. At `temp=1.5` the sampler pulls rare tokens from Qwen3's
+ multilingual tail (Chinese / Korean / Thai / Russian / Spanish).
+4. **Top-p / top-k sampling is not implemented.** Without them, any
+ non-greedy temp is unusable on this checkpoint. The existing
+ `top_k` parameter is a different thing — a soft-embedding mix for
+ the intermediate denoising states, not a sampling-time filter.
+
+### Recommended fix for usable sampling
+
+Add top-p (nucleus) sampling in `dllm_jax/dmax.py:_sample_x0`,
+thread `TOP_P` env var through the inference script. Rough
+implementation (~10 lines):
+
+```python
+if top_p is not None and 0.0 < top_p < 1.0:
+ sorted_logits = jnp.sort(logits_f32, axis=-1)[..., ::-1]
+ sorted_probs = jax.nn.softmax(sorted_logits, axis=-1)
+ cum = jnp.cumsum(sorted_probs, axis=-1)
+ cutoff = jnp.sum(cum < top_p, axis=-1, keepdims=True)
+ threshold = jnp.take_along_axis(sorted_logits, cutoff, axis=-1)
+ logits_f32 = jnp.where(logits_f32 < threshold, -jnp.inf, logits_f32)
+```
+
+Use the pre-truncation `active_probs` for the `THRESHOLD` /
+`CONFIDENCE_STOP` early-exit machinery so nucleus truncation doesn't
+artificially inflate confidence.
+
+## What's next
+
+In decreasing ROI:
+
+1. **Top-p sampling** — unblocks any non-greedy generation. 20–30 min
+ (patch `_sample_x0`, thread `TOP_P` through the generate variants
+ + scripts + README). Required before shipping a user-facing
+ interface with temperature.
+2. **Splash attention in the `kv_fast` prefill** — easy half of the
+ kv_fast splash integration. Prefill shape is static; only the
+ per-block decode has a dynamic mask. Prefill is a sizeable
+ fraction of long-prompt latency. ~30 min.
+3. **Splash attention in the `kv_fast` decode** — requires either
+ pre-compiling N splash kernels (one per block position) and
+ dispatching via `jax.lax.switch`, or expressing the
+ `k_pos < block_end` gate as a runtime segment mask that splash
+ accepts. Benefit uncertain — at `q=32 × k=1152`, splash tiles are
+ underfilled and per-tile launch overhead may dominate. Current
+ benchmark already shows `fast+splash` beats `kv_fast+dense`, so
+ this is low priority unless we scale to very long contexts.
+4. **Share jit compile cache across temperature values** — pass
+ `temperature` as a traced `jax.Array` rather than a closure
+ constant so one compile serves all temps. Saves ~30 s per extra
+ temp in a sweep. Cosmetic for benchmarks, not a user-facing
+ improvement.
+
+## Reproduce
+
+### Single run
+
+```bash
+gcloud compute tpus tpu-vm ssh tpu-v5e-64-us --zone=us-central1-a --worker=all \
+ --command="cd ~/dllm-jax && \
+ PYTHONPATH=~/dllm-jax:\${PYTHONPATH:-} \
+ RESUME_DIR=gs://dllm-jax-us-central1/checkpoints/ \
+ RESUME_STEP=41400 \
+ MODEL_NAME=Qwen/Qwen3-8B PROMPT='Once upon a time' \
+ GEN_LENGTH=1024 BLOCK_LENGTH=32 STEPS=32 \
+ TOP_K=3 THRESHOLD=0.5 CONFIDENCE_STOP=0.9 TP=8 \
+ INFER_IMPL=fast INFER_SPLASH=1 INFER_SPLASH_BLOCK=512 \
+ WARMUP_RUNS=1 MEASURED_RUNS=2 \
+ python3 scripts/tpu_infer.py"
+```
+
+### Temperature sweep
+
+Same as above, but replace the `WARMUP_RUNS=1 MEASURED_RUNS=2 \` line with:
+
+```
+TEMPS=0.0,0.3,0.5,0.7,1.0,1.5 SEED=42 \
+WARMUP_RUNS=0 MEASURED_RUNS=1 \
+```
diff --git a/docs/inference.md b/docs/inference.md
new file mode 100644
index 0000000..d0b2265
--- /dev/null
+++ b/docs/inference.md
@@ -0,0 +1,123 @@
+# Inference
+
+Full reference for [`scripts/tpu_infer.py`](../scripts/tpu_infer.py), the
+multi-host inference entry point. See the [README](../README.md#infer) for
+the quick-start example. For a deep dive on the splash-attention vs dense
+trade-off and end-to-end timing, see
+[`inference-optimization.md`](inference-optimization.md). For the design of
+the KV-cache path, see [`kv-cache-design.md`](kv-cache-design.md).
+
+## Three implementations
+
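+All three implement the same decoding algorithm and are intended to match at the same settings, but the v5e-64 benchmark in [`inference-optimization.md`](inference-optimization.md) shows `fast` and `kv_fast` diverging (different `nfe` and token counts). Pick one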
+based on shape and length:
+
+| `INFER_IMPL=` | What it does | When to use |
+|---|---|---|
+| `fast` (default) | fixed-shape `fori_loop` with step- and block-level early breaks | short/medium gen, especially with `INFER_SPLASH=1` |
+| `kv_fast` | KV-cached, decode only the active block per step | long gen (≥1024 tokens) |
+| `legacy` | Python-loop reference path with host-side breaks | debugging only — slow on TPU |
+
+## Splash attention
+
+`INFER_SPLASH=1` swaps `jax.nn.dot_product_attention` for a Pallas splash
+kernel matched to the block-causal mask. On Qwen3-8B at `GEN_LENGTH=1024`
+it's **3.5× faster** *and* fixes a latent dense-kernel quality bug at
+non-128-aligned sequence lengths.
+
+| Variable | Meaning | Default |
+|---|---|---|
+| `INFER_SPLASH` | enable splash for the `fast` block-causal mask | `0` |
+| `INFER_SPLASH_BLOCK` | splash tile size (auto-adjusts to the largest 128-multiple ≤ this that divides `total_length`) | `512` |
+| `INFER_KV_DTYPE` | KV-cache dtype for `kv_fast`: `bf16` (2× HBM savings) or `fp32` | `bf16` |
+
+## Multi-prompt and temperature sweeps
+
+You can amortize the restore cost across many generations:
+
+```bash
+# Multiple prompts, one per line:
+PROMPTS_FILE=./prompts.txt … python3 scripts/tpu_infer.py
+
+# Temperature sweep with 1 warmup + 2 measured runs each:
+PROMPT='Once upon a time' \
+TEMPS='0.0,0.3,0.5,0.7,1.0' \
+WARMUP_RUNS=1 MEASURED_RUNS=2 \
+… python3 scripts/tpu_infer.py
+```
+
+## Inference env vars
+
+| Variable | Meaning | Default |
+|---|---|---|
+| `RESUME_DIR` / `RESUME_STEP` | which checkpoint to restore | — / latest |
+| `MODEL_NAME` | tokenizer + HF config source | `Qwen/Qwen3-8B` |
+| `PROMPT` | input prompt | `Once upon a time` |
+| `PROMPTS_FILE` | one prompt per line (overrides `PROMPT`; `#` lines = comments) | — |
+| `GEN_LENGTH` | tokens to generate | `32` |
+| `BLOCK_LENGTH` | DMax block size | `32` |
+| `STEPS` | max denoising steps per block | `8` |
+| `INFER_IMPL` | `fast` / `kv_fast` / `legacy` | `fast` |
+| `FAST_BUCKET_LENGTH` | compile window for `fast`; ignored by `kv_fast` | `4096` |
+| `THRESHOLD` | left-to-right confidence cutoff (math `0.5`, code `0.65`, others `0.9`–`0.95`) | `0.95` |
+| `CONFIDENCE_STOP` | block-level early-exit confidence | `0.9` |
+| `TOP_K` | soft-mix top-k (`3` is more coherent on undertrained ckpts) | `1` |
+| `TEMPERATURE` | gumbel-max temperature (`0.0` = greedy) | `0.0` |
+| `TEMPS` | comma-separated sweep, e.g. `0.0,0.3,0.5` | — |
+| `SEED` | RNG seed (only used when `TEMPERATURE > 0`) | none |
+| `WARMUP_RUNS` | throwaway generates before timing | `0` |
+| `MEASURED_RUNS` | timed generates; reports median | `1` |
+| `SUPPRESS_MASK_TOKEN` | force-disable mask logits at argmax | `0` |
+| `MASK_TOKEN_ID` / `EOS_TOKEN_ID` | overrides | tokenizer default |
+| `TP` | inference mesh's TP axis | `8` |
+| `RESTORE_OPTIMIZER` | also restore optimizer state | `0` |
+| `INFER_SPLASH` / `INFER_SPLASH_BLOCK` / `INFER_KV_DTYPE` | see above | `0` / `512` / `bf16` |
+
+## Throughput
+
+TPU v4-32, Qwen3-8B, `TOP_K=3`:
+
+| Config | NFE | seconds |
+|---|---|---|
+| `fast`, GEN=128 | 114 | 32.1 |
+| `kv_fast`, GEN=128 | 119 | 65.9 |
+| `fast`, GEN=1024 | 1010 | 124.7 |
+| `kv_fast`, GEN=1024 | 1043 | 76.6 |
+
+TPU v5e-64, Qwen3-8B DMax, `TOP_K=3`, `GEN=1024` (full writeup in
+[`inference-optimization.md`](inference-optimization.md)):
+
+| Config | seconds | tok/s | quality |
+|---|---|---|---|
+| `fast` (dense attention) | 142.3 | 7.2 | ⚠️ collapses into punctuation |
+| **`fast` + splash** | **40.9** | **25.1** | ✅ coherent |
+| `kv_fast` (fp32 KV) | 52.9 | 19.4 | ✅ coherent |
+| `kv_fast` (bf16 KV) | 52.1 | 19.6 | ✅ bit-identical |
+
+## Single-host CLI
+
+For laptop / single-TPU-host work without GCS:
+
+```bash
+# Generate from a base model (no DMax training needed)
+python scripts/dmax_generate.py \
+ --model Qwen/Qwen3-0.6B \
+ --prompt "Solve 37 * 48." \
+ --gen-length 256 --block-length 32 --steps 32 \
+ --threshold 0.5 --top-k 3 --impl fast
+
+# Generate from a saved single-host trainer checkpoint
+python scripts/dmax_generate_checkpoint.py \
+ --checkpoint-dir ./out-dmax/checkpoint-1000 \
+ --prompt "Solve 37 * 48." \
+ --gen-length 256 --impl kv_fast
+```
+
+## Gotchas
+
+**Output collapses into punctuation with `TOP_K=1`.** SPD feeds the previous
+step's distribution back as a soft mix; at `TOP_K=1` a low-confidence
+single token gives a bad signal. Try `TOP_K=3`.
+
+**First-run generation is slow.** It includes model restore + XLA compile.
+Subsequent runs with identical shapes hit the cached graph and are ~5× faster.
diff --git a/docs/install.md b/docs/install.md
new file mode 100644
index 0000000..8fc00c5
--- /dev/null
+++ b/docs/install.md
@@ -0,0 +1,73 @@
+# Installation
+
+## Prerequisites for multi-host TPU runs
+
+You need three things before any of the multi-host scripts run. Single-host
+CPU/GPU development needs none of them.
+
+**1. A Google Cloud project with TPU access.** Install the [`gcloud` CLI](https://cloud.google.com/sdk/docs/install), then:
+
+```bash
+gcloud auth login
+gcloud config set project YOUR_PROJECT_ID
+```
+
+**2. A regional GCS bucket for checkpoints.** TPU and bucket should share a
+region (cross-region writes are slow):
+
+```bash
+gcloud storage buckets create gs://YOUR_BUCKET_NAME --location=us-east1
+```
+
+By default this repo writes to `gs://${CHECKPOINT_BUCKET_PREFIX}-${region}`
+(prefix `dllm-jax`), so the matching bucket would be `gs://dllm-jax-us-east1`.
+Override via `CHECKPOINT_BUCKET=...` or `CHECKPOINT_DIR=gs://...` if you prefer
+a different layout.
+
+**3. Optional but recommended:** a [Weights & Biases](https://wandb.ai)
+account for loss curves, and a [HuggingFace](https://huggingface.co) account
+if you want gated models or to upload checkpoints. You can run `wandb login`
+and `huggingface-cli login` once per TPU worker and skip the env vars
+entirely.
+
+## TPU VM packaging caveat
+
+Some Python 3.10 TPU VM images ship an older packaging stack where
+`pip install -e '.[tpu]'` can fail with a missing `build_editable` hook,
+or `pip install '.[tpu]'` misreads metadata as `UNKNOWN-0.0.0` without
+installing dependencies. If that happens, skip editable mode and install
+deps explicitly from the synced checkout:
+
+```bash
+python3 -m pip install --user -U 'jax[tpu]' \
+ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \
+ 'flax>=0.10.0,<0.11' orbax-checkpoint 'gcsfs<=2026.2.0' 'fsspec<=2026.2.0' \
+ 'optax>=0.2.0' numpy 'transformers>=4.40.0' safetensors \
+ huggingface_hub datasets wandb
+```
+
+Then run scripts with `PYTHONPATH=/path/to/dllm-jax python3 …`.
+
+## Verified TPU versions
+
+End-to-end validation on TPU v4-32 (`us-central2-b`) with this exact stack.
+Pin to these if `pip install '.[tpu]'` shows version drift:
+
+| Package | Version |
+|---------|---------|
+| Python | 3.10.12 |
+| jax / jaxlib | 0.6.2 |
+| libtpu | 0.0.17 |
+| flax | 0.10.7 |
+| optax | 0.2.8 |
+| orbax-checkpoint | 0.11.34 |
+| transformers | 5.5.3 |
+| safetensors | 0.7.0 |
+| datasets | 4.8.4 |
+| gcsfs / fsspec | 2025.3.2 |
+| huggingface_hub | 1.10.1 |
+| numpy | 2.2.6 |
+
+Newer flax (0.12+) on JAX 0.7+ should also work; the
+`_nnx_list = getattr(nnx, "List", list)` shim in `models.py` handles the
+cross-version difference.
diff --git a/docs/kv_cache_design.md b/docs/kv-cache-design.md
similarity index 99%
rename from docs/kv_cache_design.md
rename to docs/kv-cache-design.md
index fdfb3f5..58c4fb2 100644
--- a/docs/kv_cache_design.md
+++ b/docs/kv-cache-design.md
@@ -222,7 +222,7 @@ and identical `updated_block` at every step (modulo float32 rounding in attentio
1. Add `past_k/past_v/cache_position` plumbing to model with no-op default.
2. Implement `dmax_generate_spd_kv_fast` + post-block hard-write pass.
-3. Add `INFER_IMPL=kv_fast` to `tpu_dmax_infer_checkpoint.py`.
+3. Add `INFER_IMPL=kv_fast` to `tpu_infer.py`.
4. A/B test: `kv_fast` vs `fast` on the same prompt/settings. Compare text and `nfe`.
5. Benchmark: measure `generate_seconds` at `gen_length=128, 1024, 4096`. Expect
3–5× speedup on longer contexts.
diff --git a/docs/mfu-optimization.md b/docs/mfu-optimization.md
new file mode 100644
index 0000000..3881c0d
--- /dev/null
+++ b/docs/mfu-optimization.md
@@ -0,0 +1,201 @@
+# MFU Optimization: Qwen3-8B + DMax OPUT on TPU v5e-64
+
+Summary of an MFU tuning pass on `tpu-v5e-64-eu` (europe-west4-b) for
+`scripts/tpu_train.py` training **Qwen3-8B with full DMax OPUT**
+(on_policy_ratio=0.5 rollout kept, block-diffusion attention mask,
+synthetic data, random init, AdamW, aggressive `jax.remat`).
+
+## Headline
+
+| Config | MFU (attn-adj) | tok/s | step s | vs start |
+|--|--|--|--|--|
+| starting point | 2.9% | 6,570 | 10.0 | 1.0× |
+| **final (EE)** | **14.7%** | **32,834** | **8.0** | **5.1×** |
+
+All comparisons are at `MODEL_NAME=Qwen/Qwen3-8B`, `MAX_LEN=4096`,
+`PEAK_TFLOPS_PER_CHIP=197` (12.6 PFLOPS bf16 peak on v5e-64),
+and preserved training semantics (full OPUT rollout — configs that change
+semantics are reported separately in the ablation section below).
+
+The reported MFU above is under the **original** formula that credited
+one fwd+bwd at `seq=MAX_LEN`. In reality DMax's compiled graph runs
+fwd+bwd at `seq=2·MAX_LEN` and — whenever `DMAX_ON_POLICY_RATIO > 0`
+— also pays a stop_gradient rollout forward at `seq=2·MAX_LEN`. The
+attention term is quadratic, so the true multiplier is **~3.01×** (not
+2×). Under correct accounting the EE measurement is
+**≈ 44% of v5e bf16 peak**, i.e., the compiled graph reaches just under
+half of MXU peak on these shapes. Fixed in `tpu_train.py:319–337`
+(see item 1 under "What's next" below).
+
+## Optimization ladder (full OPUT semantics preserved)
+
+| # | Change | MFU | tok/s | step s | Notes |
+|--|--|--|--|--|--|
+| 0 | Starting — TP=8 hardcoded, dense attn, B=16 L=4096 | 2.9% | 6,570 | 10.0 | baseline |
+| 1 | `TP=2 × FSDP=32` mesh (P) | 7.2% | 16,019 | 8.17 | mesh re-shaping freed HBM; larger matmuls; B=32 fit |
+| 2 | `splash_attention` for block-diffusion mask (Z) | 8.9% | 19,809 | 6.61 | replaced dense `jax.nn.dot_product_attention` on 2L×2L mask |
+| 3 | Scale to B=64 using splash's O(seq) attn memory (BB) | 9.2% | 20,469 | 12.8 | doubled batch — was OOM at B=64 with dense attn |
+| 4 | **Tune splash `BlockSizes` to 512 + `use_fused_bwd_kernel=True` (EE)** | **14.7%** | **32,834** | **8.0** | kernel-level tile + bwd fusion — biggest single win |
+
+### Why each step helped
+
+1. **TP=2 × FSDP=32 mesh** — `tpu_train.py` hard-coded `TP = 8`.
+ On v5e-64's 8×8 torus that forces FSDP=8, so optimizer state shards
+ only 8-way and B must be a multiple of 8. Dropping TP to 2 (with
+ `mesh_utils.create_device_mesh(..., allow_split_physical_axes=True)`
+ because `(32, 2)` doesn't factor onto 8×8) bumps FSDP to 32, which
+ (a) makes the per-chip optimizer/grad/param footprint ~4× smaller
+ and (b) enlarges the matmul size per chip (hidden/TP=2048 vs 512),
+ better amortizing MXU tile overhead.
+
+2. **splash_attention for the block-diffusion mask** — the DMax path
+ was silently **not** using flash attention: `dllm_jax/models.py:367`
+ gated flash on `attention_mask is None`, and DMax always passes a
+ non-None block-diffusion mask. The fallback was a dense
+ `jax.nn.dot_product_attention` over the full 2L×2L mask. Switched
+ to `jax.experimental.pallas.ops.tpu.splash_attention`, with the
+ block-diffusion pattern wrapped as `NumpyMask` inside a
+ `MultiHeadMask([mask] * heads_per_tp)`. Splash runs a block-sparse
+ flash that skips masked regions and never materializes the score
+ matrix.
+
+3. **Scale to B=64 with splash** — dense attention's 2L×2L score
+ materialization was the reason B=64 L=4096 OOMed previously
+ (R: 28.84 GB > 15.75 GB). splash is O(seq), so B=64 fit and the
+ mesh could amortize its fixed costs over 2× the tokens.
+
+4. **Tune splash tile sizes + fused bwd** — the single biggest kernel
+ change. `splash_attention_kernel.BlockSizes` default is
+ `block_q=block_kv=block_kv_compute=128` with a separate `dkv` and
+ `dq` backward pass (`use_fused_bwd_kernel=False`). 128×128 is much
+ smaller than what v5e wants — at the per-chip shapes we run
+ (8192 × 128), each tile barely saturates the MXU and the
+ per-kernel launch overhead dominates. Bumping to 512 (8× more
+ work per launch) and enabling the fused bwd kernel (dkv + dq in a
+ single pass) drops step time 12.8 s → 8.0 s (−37%).
+
+## Knobs added
+
+All exposed as env vars on `scripts/tpu_train.py`:
+
+| Env | Default | Effect |
+|--|--|--|
+| `TP` | 8 | tensor-parallel axis size; 2 or 4 usually better for v5e-64 |
+| `SPLASH_BLOCK` | 512 | splash tile size (was 128 default) |
+| `SPLASH_FUSED_BWD` | 1 | enable fused `dkv`/`dq` backward kernel |
+| `DMAX_ON_POLICY_RATIO` | 0.5 | skips rollout fwd when 0 (see ablation below) |
+| `REMAT_POLICY` | nothing_saveable | `gate_up` / `qkv_gate_up` / `dots_saveable` / `everything_saveable` save more between fwd and bwd; saves recompute at HBM cost. |
+
+## Peak config (EE) — reproduce
+
+```bash
+RUN_NAME=qwen3-8b-dmax-v5e-ee \
+MODEL_NAME=Qwen/Qwen3-8B DATASET=tinystories \
+MAX_LEN=4096 GLOBAL_BATCH=64 \
+TP=2 \
+SPLASH_BLOCK=512 SPLASH_FUSED_BWD=1 \
+DMAX_ENABLE=1 DMAX_ON_POLICY_RATIO=0.5 \
+DMAX_NOISE_LOW=0.75 DMAX_NOISE_HIGH=0.75 DMAX_BLOCK_SIZE=32 \
+PEAK_TFLOPS_PER_CHIP=197 \
+NUM_STEPS=0 NUM_EPOCHS=3 WANDB_LOG=1 \
+python3 scripts/tpu_train.py
+```
+
+## What did NOT help
+
+| Tried | Result |
+|--|--|
+| XLA async-collective flags (`xla_tpu_enable_async_collective_fusion`, etc.) | 0.0 pp — XProf showed comm is only ~14% of step time, not the bottleneck. |
+| Longer seq at same batch (`L=7168`, `L=8192`) at TP=2 | L=7168 was slower (tok/s dropped); L=8192 OOMed (program 13.24 GB > 12.63 GB free — splash mask 16384² too big). |
+| Bigger batch past B=64 at TP=2 (`B=96`, `B=128`) | OOM — program size exceeds HBM. |
+| `OPTIMIZER=adafactor` at B=64 (bf16 factored opt state) | OOMed anyway — activations, not optimizer state, are the blocker at B=64 with splash defaults. |
+| Skipping rollout forward unconditionally in DMax path | Changes OPUT semantics to off-policy-only, not a valid DMax result (see ablation table). |
+
+## XProf breakdown (captured on a pre-tuned splash config)
+
+Sampled steady-state steps 4–6, 4 chips on worker 0:
+
+| Bucket | % of device time |
+|--|--|
+| fusion (matmul + elementwise epilogues) | 77% |
+| while loops (remat recompute) | 7.5% |
+| collective-permute | 6.0% |
+| all-gather (FSDP) | 5.9% |
+| all-reduce | 2.5% |
+| copy / reshape / reduce | <1% |
+
+So communication is ~14% combined (the step is not comm-bound) and MXU
+utilization inside the fusions is the real ceiling. The splash tile tuning in
+step 4 above attacks exactly that — bigger tiles raise per-fusion MXU efficiency.
+
+## Ablation — non-OPUT configs (throughput-only)
+
+These change training semantics and are **not** DMax OPUT results;
+kept for attribution.
+
+| Tag | Config | MFU | tok/s | Why reported |
+|--|--|--|--|--|
+| I | plain MDLM (DMax off), B=8 L=4096 TP=8 | 6.7% | 14,960 | Confirms the 2.9% starting ceiling is really ~6.7% HW × DMax accounting factor. |
+| W | DMax TP=2 B=32 L=4096, `on_policy_ratio=0` + rollout fwd guard | 8.4% | 18,874 | Demonstrates the rollout fwd was ~⅓ of compute — but removing it drops OPUT semantics. |
+| AA | W + splash attention | 10.1% | 22,605 | Same, also not OPUT. |
+
+The W/AA rollout-skip patch sits behind a `DMAX_ON_POLICY_RATIO > 0`
+guard in `loss_fn` — at the default 0.5 it's a no-op and OPUT runs
+normally.
+
+## Code changes (all uncommitted, applied on the TPU)
+
+**`scripts/tpu_train.py`**:
+- Line 87: `TP = int(os.environ.get("TP", "8"))` + split-axes mesh fallback.
+- After flash install: build `_block_diffusion_mask_numpy` once at init,
+ wrap in `MultiHeadMask`, construct splash with tuned `BlockSizes`,
+ register as `dllm_models._MASKED_FLASH_ATTN_FN` (shard_map'd over the
+ fsdp×tp mesh with vmap over batch).
+- `loss_fn`: rollout forward now guarded by `DMAX_ON_POLICY_RATIO > 0`.
+- Profiler hook env: `JAX_PROFILE_DIR`, `JAX_PROFILE_START_STEP`,
+ `JAX_PROFILE_STEPS`.
+
+**`dllm_jax/models.py`**:
+- New `_MASKED_FLASH_ATTN_FN` global. `_attention` now routes to it
+ when `attention_mask is not None`, keeping dense `dot_product_attention`
+ as the final fallback.
+
+## What's next if you want to push past 14.7%
+
+In decreasing ROI:
+
+1. ✅ **Fix the DMax MFU formula** — **implemented** in
+ `tpu_train.py:319–337`. The correct multiplier is **~3.01×**,
+ not 2×: fwd+bwd at 2L gives dense ×2 and attention ×4 (quadratic),
+ and the rollout fwd at 2L (always compiled when
+ `DMAX_ON_POLICY_RATIO > 0`) adds another ~0.67× dense and ~1.33×
+ attention. The new formula credits both passes at their actual
+ sequence lengths and reduces to the pre-DMax formula when
+ `DMAX_ENABLE=0`. Reported MFU on the EE config goes 14.7% → **~44%**
+ (attention-adjusted) without any change in wall-clock.
+2. **Fused SwiGLU Pallas kernel** — merges
+ `down_proj(silu(gate_proj) * up_proj)` into one kernel. Avoids
+ materializing the `(B, 2L, intermediate)` activation in HBM.
+ Mostly reduces the 7.5% remat `while` cost. Estimated +1-3 pp MFU.
+ Multi-day build.
+3. ⚙️ **Looser remat policy** — `REMAT_POLICY` env var now exposes
+ `gate_up`, `qkv_gate_up`, `dots_saveable`, `everything_saveable`
+ alongside the default `nothing_saveable`. `dllm_jax/models.py`
+ marks `q`/`k`/`v` (post-RoPE) and the gate·up product with
+ `checkpoint_name`, so the named policies are free to save them.
+ Not yet measured on v5e-64 — safest starting point is
+ `REMAT_POLICY=gate_up` at **B=32** (B=64 is likely to OOM given
+ saving `gate_up` at seq=2L costs ~24 KB × 2L × layers per sample,
+ ~14 GB per chip at B=64/FSDP=32 bf16).
+4. **Architectural: fold `[noised; clean]` into batch dim** instead of
+ seq dim. Halves attention quadratic cost and the activation
+ footprint. Requires rewriting the block-diffusion mask to run as
+ two separate per-stream attentions. Biggest potential lever but
+ a real rewrite.
+
+## Artefacts
+
+- `/tmp/sweep_results.md` — full sweep table (all configs, incl. OOMs).
+- `/tmp/xprof_summary.md` — XProf trace analysis + top ops.
+- `/tmp/xprof-P/plugins/profile/*/` — raw xplane.pb + trace.json.gz
+ for TensorBoard.
diff --git a/docs/training.md b/docs/training.md
new file mode 100644
index 0000000..ce191f0
--- /dev/null
+++ b/docs/training.md
@@ -0,0 +1,218 @@
+# Training
+
+Full reference for [`scripts/tpu_train.py`](../scripts/tpu_train.py), the
+multi-host training entry point. See the [README](../README.md#train) for the
+quick-start example.
+
+## Datasets
+
+| `DATASET=` | Source | Use |
+|---|---|---|
+| `tinystories` | `roneneldan/TinyStories` (streamed) | small-scale sanity / smoke |
+| `wikipedia` | `wikimedia/wikipedia 20231101.en` (streamed) | pretrain-style packing |
+| `openthoughts` | `open-thoughts/OpenThoughts-114k` (streamed) | SFT (chat-templated) |
+| `parquet` | local `*.parquet` files via `DATASET_PATH=` | bring-your-own |
+| `synthetic` | random ints | regression / shape testing |
+
+`tinystories`, `wikipedia`, and `parquet` use **token-stream packing**:
+documents are tokenized, joined by EOS, and chunked to `MAX_LEN`.
+`openthoughts` is a chat-templated SFT set; by default it also packs the
+full chat-templated text. Set `SFT_TRAIN_ON_ANSWERS_ONLY=1` to mask prompt
+tokens with `-100` and supervise only assistant turns (per-example padded
+batches; some rows >`MAX_LEN` are dropped).
+
+To inspect what your SFT pipeline actually feeds the model:
+
+```bash
+python3 scripts/inspect_sft_data.py --model Qwen/Qwen3-8B --max-len 4096 --rows 2
+```
+
+## DMax / OPUT training
+
+`DMAX_ENABLE=1` switches the script from MDLM-style noising to DMax's
+high-noise, on-policy, block-diffusion objective.
+
+| Variable | Meaning | Default |
+|---|---|---|
+| `DMAX_ENABLE` | turn on DMax (otherwise MDLM-style) | `0` |
+| `DMAX_BLOCK_SIZE` | tokens per noising block | `32` |
+| `DMAX_ON_POLICY_RATIO` | fraction of steps using model's own greedy preds | `0.5` |
+| `DMAX_NOISE_LOW` / `DMAX_NOISE_HIGH` | uniform noise range | `0.75` / `0.75` |
+
+When DMax is on, the script also installs a Pallas **splash attention**
+kernel matched to the block-diffusion mask:
+
+| Variable | Meaning | Default |
+|---|---|---|
+| `SPLASH_BLOCK` | tile size (must be a multiple of 128 dividing `MAX_LEN`) | `512` |
+| `SPLASH_FUSED_BWD` | fused backward kernel | `1` |
+| `DISABLE_SPLASH_ATTN` | fall back to dense attention (debugging) | `0` |
+
+## Optimizer
+
+Both AdamW and Adafactor are wired up:
+
+| Variable | Meaning | Default |
+|---|---|---|
+| `OPTIMIZER` | `adamw` or `adafactor` | `adamw` |
+| `PEAK_LR` | peak learning rate (post-warmup) | `1e-4` |
+| `WARMUP_STEPS` | linear ramp from 0 → `PEAK_LR` | `5` |
+
+> ⚠️ **AdamW is the recommended default** for diffusion-LM training on
+> pretrained init. Adafactor's `scale_by_param_block_rms` misbehaves on
+> bidirectional objectives over causal-LM weights — loss descends and then
+> climbs back to ~12 around step 60 on Qwen3-8B for both MDLM and DMax/OPUT.
+> Use Adafactor only when HBM forces it (e.g. 128k context on a single
+> chip), and budget the divergence risk accordingly.
+
+## Learning-rate schedule
+
+By default the schedule is *constant after warmup*. To decay:
+
+| Variable | Meaning | Default |
+|---|---|---|
+| `LR_SCHEDULE` | `constant` or `cosine` | `constant` |
+| `LR_DECAY_STEPS` | cosine: total decay steps after warmup | `0` |
+| `LR_DECAY_ALPHA` | cosine: final LR = `PEAK_LR * alpha` | `0.1` |
+
+## Mask-token warm start
+
+DMax reuses a token id as the MASK token (default `vocab_size - 1`, an
+*untrained* reserved slot on Qwen3). Without intervention, its row in the
+embedding matrix is essentially random — observed to drift the noised
+forward toward uniform predictions around step 150.
+
+To seed the mask row with the **mean** of all other input/output embedding
+rows:
+
+| Variable | Meaning | Default |
+|---|---|---|
+| `MASK_EMBED_INIT` | `mean` (warm start) or `none` (skip) | `mean` |
+| `MASK_TOKEN_ID` | which row to seed | `vocab_size - 1` |
+
+> ⚠️ On **v5e-64**, set `MASK_EMBED_INIT=none` for now. The pre-shard
+> `.at[].set()` on replicated embed + lm_head leaks ~9.5GB HBM and blocks
+> the training program from loading.
+
+You can also point at a pretrained Qwen3 special token like
+`MASK_TOKEN_ID=151662` (`<|fim_pad|>`) to inherit "fill in the missing
+piece" semantics for free.
+
+## Memory & throughput knobs
+
+| Variable | Meaning | Default |
+|---|---|---|
+| `REMAT_POLICY` | gradient checkpointing policy | `nothing_saveable` |
+| `MAX_LEN` | context length | `16384` |
+| `GLOBAL_BATCH` | global batch (across all chips) | `8` |
+| `PEAK_TFLOPS_PER_CHIP` | for MFU calc; v4=275, v5e=197, v5p=459, v6e=918 | `918` |
+
+The throughput methodology is in [`mfu-optimization.md`](mfu-optimization.md).
+Reference numbers land between ~38–48% MFU after splash + remat tuning.
+
+## Checkpointing
+
+Sharded Orbax DCP checkpoints write to GCS by default:
+
+```
+gs://${CHECKPOINT_BUCKET_PREFIX}-${region}/checkpoints/${RUN_NAME}/
+├── checkpoint_500/
+├── checkpoint_500/commit_success.txt
+├── checkpoint_1000/
+└── …
+```
+
+| Variable | Meaning | Default |
+|---|---|---|
+| `CHECKPOINT_STEPS` | save every N steps | `500` |
+| `CHECKPOINT_KEEP` | retain the last N | `2` |
+| `CHECKPOINT_BUCKET_PREFIX` | regional auto-detect prefix | `dllm-jax` |
+| `CHECKPOINT_BUCKET` | fixed bucket (overrides prefix) | — |
+| `CHECKPOINT_DIR` | full directory (overrides both) | — |
+| `LOCAL_CHECKPOINT_DIR` | fallback if GCS not configured | `/tmp/dllm-jax-checkpoints` |
+| `CHECKPOINT_ON_FINISH` | always write a final checkpoint | `0` |
+
+The script also enables Orbax barrier shims by default
+(`CHECKPOINT_ORBAX_SYNC_DIRS=1`, `CHECKPOINT_ORBAX_SIGNAL_FALLBACK=1`)
+so distributed GCS writes use JAX multi-host barriers — needed on the
+JAX-0.6.x / Orbax-0.11.x TPU VM stack.
+
+## Resume
+
+Resuming is one variable: point at the parent directory and optionally pin
+a step.
+
+```bash
+RESUME_DIR=gs://dllm-jax-us-east1/checkpoints/old-run \
+RESUME_STEP=0 \
+… python3 scripts/tpu_train.py
+```
+
+`RESUME_STEP=0` (default) scans `commit_success.txt` markers and picks the
+latest committed step. Set `RESUME_STEP=<N>` to pin a specific one.
+
+### Three resume modes
+
+| Scenario | Settings |
+|---|---|
+| Continue training, same data, same hardware | `RESUME_DIR=…` (defaults are correct) |
+| **Continue onto a new dataset** (e.g. switch to `openthoughts`) | `RESUME_DIR=… RESUME_RESET_STEP=1` |
+| **Pretrain → SFT** (different optimizer family or lr) | `RESUME_DIR=… RESUME_RESTORE_OPTIMIZER=0 RESUME_RESET_STEP=1` |
+
+| Variable | Meaning | Default |
+|---|---|---|
+| `RESUME_DIR` / `RESUME_FROM` | parent dir containing `checkpoint_<step>/` | — |
+| `RESUME_STEP` | specific step (`0` = latest) | `0` |
+| `RESUME_RESTORE_OPTIMIZER` | `0` = restore weights only, fresh optimizer | `1` |
+| `RESUME_RESET_STEP` | `1` = zero global_step / epoch counters | `0` |
+
+Switching hardware (v4 → v5e, v6e → v4, …) is supported transparently by
+orbax: the script re-shards on restore using the *current* mesh. The only
+thing you need to update is `TP`.
+
+## Sharding and hardware sizing
+
+The 2D mesh is `(fsdp, tp)` with `fsdp = jax.device_count() // TP`.
+
+| Hardware | Chips | Recommended `TP` | Mesh shape | Notes |
+|---|---|---|---|---|
+| TPU v4-32 | 16 | `8` | `(2, 8)` | validated for Qwen3-8B training and inference |
+| TPU v5e-64 | 64 | `2` | `(32, 2)` | `TP=8` here forces FSDP=8 and ~4× optimizer-state HBM; prefer `TP=2` |
+| TPU v6e-256 | 256 | `8` | `(32, 8)` | high MFU with splash + remat |
+| Single-host (≤ 3B params) | any | `1` | `(N,)` 1D FSDP | `P("fsdp", None)` direct TPU init |
+
+For models ≥ 3B you generally want CPU-first init
+(`jax.default_device(jax.devices("cpu")[0])`), gradient checkpointing on
+every transformer layer, and Pallas flash attention. `scripts/tpu_train.py`
+(formerly the v6e smoke script) does all three automatically.
+
+## Gotchas
+
+**Adafactor diverges around step 60.** See the optimizer note above —
+default to AdamW unless HBM forces otherwise.
+
+**Optimizer rebind after split/merge.** If you manually `nnx.split` and
+`nnx.merge` the optimizer in a hand-written script, rebind `model` or
+you'll get silent zero-progress:
+
+```python
+opt_gdef, opt_state = nnx.split(optimizer)
+opt_state = jax.tree.map(fsdp_sharding, opt_state)
+optimizer = nnx.merge(opt_gdef, opt_state)
+model = optimizer.model # CRITICAL — without this, grads land on a stale model
+```
+
+The built-in `trainers.py` and `tpu_train.py` already do this; only
+relevant if you're rolling your own loop.
+
+**flax 0.10.7 vs 0.12+.** Python 3.10 TPU VMs pin flax to 0.10.7. The
+`_nnx_list = getattr(nnx, "List", list)` shim in `models.py` and the
+positional `optimizer.update(grads)` call are deliberate compat choices.
+Don't "modernize" them without testing on the older stack.
+
+**Stale TPU checkout.** `ImportError: cannot import name '...'` usually
+means the TPU copy is out of date. Re-`scp` `dllm_jax/` and the target
+script.
+
+**`gs://` / TensorFlow warnings on restore.** Harmless. Restore uses
+Orbax's GCS client, not `tf.io.gfile`.
diff --git a/scripts/dmax_tinystories_train.py b/scripts/examples/dmax_train_tinystories.py
similarity index 100%
rename from scripts/dmax_tinystories_train.py
rename to scripts/examples/dmax_train_tinystories.py
diff --git a/scripts/run_tpu_v4_32_3epoch.sh b/scripts/examples/run_tpu_train_v4_32_3epoch.sh
similarity index 87%
rename from scripts/run_tpu_v4_32_3epoch.sh
rename to scripts/examples/run_tpu_train_v4_32_3epoch.sh
index d4305e3..63901c5 100755
--- a/scripts/run_tpu_v4_32_3epoch.sh
+++ b/scripts/examples/run_tpu_train_v4_32_3epoch.sh
@@ -11,5 +11,5 @@ export LOG_EVERY="${LOG_EVERY:-50}"
export TOKENIZERS_PARALLELISM=false
export LIBTPU_INIT_ARGS="${LIBTPU_INIT_ARGS:---xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_data_parallel_opt_different_sized_ops=true}"
-cd "$(dirname "$0")/.."
-exec python3 scripts/tpu_v4_32_train_3epoch.py 2>&1
+cd "$(dirname "$0")/../.."
+exec python3 scripts/examples/tpu_train_v4_32_3epoch.py 2>&1
diff --git a/scripts/tpu_v4_32_train_3epoch.py b/scripts/examples/tpu_train_v4_32_3epoch.py
similarity index 100%
rename from scripts/tpu_v4_32_train_3epoch.py
rename to scripts/examples/tpu_train_v4_32_3epoch.py
diff --git a/scripts/inspect_sft_data.py b/scripts/inspect_sft_data.py
new file mode 100644
index 0000000..31a4064
--- /dev/null
+++ b/scripts/inspect_sft_data.py
@@ -0,0 +1,162 @@
+"""Dump SFT data pipeline batches to stdout for inspection.
+
+Mirrors the packing path in ``scripts/tpu_train.py`` (``refill_buffer`` +
+``get_batch``) without any JAX / training code, so we can cheaply check:
+
+- does ``apply_chat_template`` round-trip correctly (decode == human-readable)?
+- does concat-packing produce the expected ``<|im_start|>...<|im_end|>`` structure?
+- does labels == input_ids (full-text SFT) and is the MASK token absent from data?
+- does the EOS token sit between docs, not inside them?
+- are there suspicious id==-1 / id>vocab_size / repeated pathological tokens?
+
+Run locally or on any TPU worker:
+ python3 scripts/inspect_sft_data.py [--max-len 4096] [--rows 2] [--batches 1]
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import sys
+from collections import Counter
+from pathlib import Path
+
+import numpy as np
+
+
+ROOT = Path(__file__).resolve().parent.parent
+if str(ROOT) not in sys.path:
+ sys.path.insert(0, str(ROOT))
+
+
+_ROLE_MAP = {
+ "human": "user",
+ "gpt": "assistant",
+ "system": "system",
+ "user": "user",
+ "assistant": "assistant",
+}
+
+
+def _sft_row_to_messages(row):
+ convs = row.get("conversations")
+ if not convs:
+ return None
+ messages = []
+ system = row.get("system")
+ if system:
+ messages.append({"role": "system", "content": system})
+ for c in convs:
+ if isinstance(c, dict) and "from" in c:
+ messages.append({"role": _ROLE_MAP.get(c["from"], c["from"]), "content": c["value"]})
+ else:
+ messages.append(c)
+ return messages
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--model", default=os.environ.get("MODEL_NAME", "Qwen/Qwen3-8B"))
+ ap.add_argument("--max-len", type=int, default=int(os.environ.get("MAX_LEN", "4096")))
+ ap.add_argument("--batch", type=int, default=int(os.environ.get("GLOBAL_BATCH", "4")))
+ ap.add_argument("--rows", type=int, default=2, help="how many packed rows to decode in full")
+ ap.add_argument("--batches", type=int, default=1, help="how many batches to draw")
+ ap.add_argument("--mask-id", type=int, default=int(os.environ.get("MASK_TOKEN_ID", "151662")))
+ args = ap.parse_args()
+
+ print(f"[inspect] model={args.model} max_len={args.max_len} batch={args.batch} mask_id={args.mask_id}")
+
+ import transformers
+ from datasets import load_dataset
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(args.model)
+ eos_id = tokenizer.eos_token_id
+ print(f"[inspect] eos_id={eos_id} pad_id={tokenizer.pad_token_id} vocab={tokenizer.vocab_size}")
+ # Decode the special MASK id to see what Qwen3 calls 151662
+ try:
+ mask_tok = tokenizer.decode([args.mask_id])
+ print(f"[inspect] mask_id {args.mask_id} -> {mask_tok!r}")
+ except Exception as exc:
+ print(f"[inspect] decoding mask_id failed: {exc}")
+
+ ds = load_dataset("open-thoughts/OpenThoughts-114k", "default", split="train", streaming=True)
+ ds_iter = iter(ds)
+
+ # Replicate refill_buffer + get_batch packing
+ token_buffer: list[int] = []
+ total_rows_consumed = 0
+ total_rows_dropped = 0
+
+ def refill(needed: int):
+ nonlocal total_rows_consumed, total_rows_dropped
+ while len(token_buffer) < needed:
+ try:
+ row = next(ds_iter)
+ except StopIteration:
+ break
+ messages = _sft_row_to_messages(row)
+ total_rows_consumed += 1
+ if messages is None:
+ total_rows_dropped += 1
+ continue
+ out = tokenizer.apply_chat_template(
+ messages, tokenize=True, add_generation_prompt=False, return_dict=True,
+ )
+ token_buffer.extend(out["input_ids"])
+ token_buffer.append(eos_id)
+
+ for b in range(args.batches):
+ print("\n" + "=" * 80)
+ print(f"BATCH {b + 1}/{args.batches}")
+ print("=" * 80)
+ needed = args.batch * args.max_len
+ refill(needed)
+ if not token_buffer:
+ print("[inspect] dataset exhausted before first batch")
+ return
+
+ ids = np.full((args.batch, args.max_len), tokenizer.pad_token_id, dtype=np.int64)
+ for i in range(args.batch):
+ length = min(args.max_len, len(token_buffer))
+ if length > 0:
+ ids[i, :length] = token_buffer[:length]
+ token_buffer = token_buffer[length:]
+ labels = ids.copy() # full-text SFT
+
+ # Batch-level diagnostics
+ counts = Counter(ids.ravel().tolist())
+ top = counts.most_common(10)
+ print(f"[diag] rows consumed so far: {total_rows_consumed} (dropped={total_rows_dropped})")
+ print(f"[diag] ids shape={ids.shape} dtype={ids.dtype}")
+ print(f"[diag] labels == ids: {np.array_equal(ids, labels)}")
+ print(f"[diag] min/max id: {ids.min()} / {ids.max()}")
+ print(f"[diag] mask_id {args.mask_id} occurrences in batch: {(ids == args.mask_id).sum()}")
+ print(f"[diag] eos_id {eos_id} occurrences in batch: {(ids == eos_id).sum()}")
+ print(f"[diag] pad_id {tokenizer.pad_token_id} occurrences in batch: {(ids == tokenizer.pad_token_id).sum()}")
+ print(f"[diag] top-10 tokens in batch:")
+ for tok_id, cnt in top:
+ try:
+ s = tokenizer.decode([tok_id])
+ except Exception:
+ s = "?"
+ print(f" {tok_id:6d} x {cnt:7d} -> {s!r}")
+
+ # Decode head and tail of a few rows
+ for r in range(min(args.rows, args.batch)):
+ print("\n" + "-" * 60)
+ print(f"row {r} — first 400 chars")
+ print("-" * 60)
+ head_text = tokenizer.decode(ids[r, :400], skip_special_tokens=False)
+ print(head_text)
+ print("-" * 60)
+ print(f"row {r} — last 400 chars")
+ print("-" * 60)
+ tail_text = tokenizer.decode(ids[r, -400:], skip_special_tokens=False)
+ print(tail_text)
+ # EOS locations for doc boundaries
+ eos_positions = np.where(ids[r] == eos_id)[0]
+ print(f"[row {r}] eos at positions: {eos_positions[:20].tolist()}{'...' if len(eos_positions) > 20 else ''} (count={len(eos_positions)})")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/tpu_dmax_infer_checkpoint.py b/scripts/tpu_infer.py
similarity index 79%
rename from scripts/tpu_dmax_infer_checkpoint.py
rename to scripts/tpu_infer.py
index db162d1..7cd2263 100644
--- a/scripts/tpu_dmax_infer_checkpoint.py
+++ b/scripts/tpu_infer.py
@@ -364,6 +364,19 @@ def main() -> None:
model_name = os.environ.get("MODEL_NAME", "Qwen/Qwen3-8B")
prompt = os.environ.get("PROMPT", "Once upon a time")
+ # PROMPTS_FILE (one prompt per line) overrides PROMPT and runs generation
+ # for each prompt against the same restored model. Blank lines skipped;
+ # lines starting with # are treated as comments and also skipped.
+ prompts_file = os.environ.get("PROMPTS_FILE", "").strip()
+ if prompts_file:
+ with open(prompts_file) as f:
+ prompts_list = [
+ ln.rstrip("\n") for ln in f if ln.strip() and not ln.lstrip().startswith("#")
+ ]
+ if not prompts_list:
+ raise ValueError(f"PROMPTS_FILE={prompts_file} produced no usable prompts")
+ else:
+ prompts_list = [prompt]
gen_length = int(os.environ.get("GEN_LENGTH", "32"))
block_length = int(os.environ.get("BLOCK_LENGTH", "32"))
steps = int(os.environ.get("STEPS", "8"))
@@ -374,6 +387,8 @@ def main() -> None:
fast_bucket_length = int(os.environ.get("FAST_BUCKET_LENGTH", "4096"))
temperature = float(os.environ.get("TEMPERATURE", "0.0"))
top_k = int(os.environ.get("TOP_K", "1"))
+ warmup_runs = max(0, int(os.environ.get("WARMUP_RUNS", "0")))
+ measured_runs = max(1, int(os.environ.get("MEASURED_RUNS", "1")))
seed_env = os.environ.get("SEED")
seed = int(seed_env) if seed_env else None
dtype_name = os.environ.get("DTYPE", "bfloat16")
@@ -470,9 +485,19 @@ def main() -> None:
restored_epoch_step = int(np.asarray(multihost_utils.process_allgather(restored["epoch_step"], tiled=False))[0])
sync_all("infer-restore-complete")
- input_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
- input_arr = np.asarray([input_ids], dtype=np.int32)
- input_arr = jax.device_put(input_arr, NamedSharding(mesh, P()))
+ # For multi-prompt runs, tokenize all upfront. Pad to the max so every
+ # call has the same input shape and only one compile happens.
+ tokenized = [tokenizer(p, add_special_tokens=False)["input_ids"] for p in prompts_list]
+ max_prompt_len = max(len(ids) for ids in tokenized)
+ pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
+ padded = [
+ (ids + [pad_id] * (max_prompt_len - len(ids))) for ids in tokenized
+ ]
+ prompt_lengths = [len(ids) for ids in tokenized] # original (unpadded) length per prompt
+ input_arr_full = np.asarray(padded, dtype=np.int32) # [num_prompts, max_len]
+ input_ids = tokenized[0] # for the first-prompt log line
+ # Device-place the first prompt's input array for the compile-path unchanged below.
+ input_arr = jax.device_put(input_arr_full[:1], NamedSharding(mesh, P()))
if proc == 0:
print(
@@ -513,22 +538,82 @@ def main() -> None:
generate_kwargs["bucket_length"] = fast_bucket_length
# kv_fast does not use bucket_length (it runs a single compiled while_loop
# over blocks with a pre-allocated KV cache of size ``total_length``).
- output = generate_fn(model, input_arr, **generate_kwargs)
- output.generated_tokens.block_until_ready()
- nfe = int(np.asarray(output.nfe))
- generated_all = np.asarray(multihost_utils.process_allgather(output.generated_tokens, tiled=False))
- sync_all("infer-generation-complete")
+
+ # TEMPS env var lets one process run multiple temperatures sharing the
+ # same restored model. Example: TEMPS="0.0,0.3,0.5,0.7,1.0"
+ temps_env = os.environ.get("TEMPS", "").strip()
+ if temps_env:
+ temps_list = [float(t) for t in temps_env.split(",") if t.strip()]
+ else:
+ temps_list = [temperature]
+
+ def _run_once(tag: str, temp_val: float, seed_val):
+ local_kwargs = dict(generate_kwargs)
+ local_kwargs["temperature"] = temp_val
+ local_kwargs["seed"] = seed_val
+ rt0 = time.time()
+ output = generate_fn(model, input_arr, **local_kwargs)
+ output.generated_tokens.block_until_ready()
+ dt = time.time() - rt0
+ nfe_val = int(np.asarray(output.nfe))
+ if proc == 0:
+ print(f"[bench] {tag} temp={temp_val} generate_seconds={dt:.2f} nfe={nfe_val}", flush=True)
+ sync_all(f"infer-{tag}")
+ return output, dt, nfe_val
+
+ all_results = [] # list of (prompt_idx, prompt_str, temp_result_dict)
+ for p_idx, p_text in enumerate(prompts_list):
+ # Swap input_arr to this prompt's padded tokens (same shape → no recompile).
+ input_arr = jax.device_put(input_arr_full[p_idx:p_idx + 1], NamedSharding(mesh, P()))
+ for t_idx, temp_val in enumerate(temps_list):
+ for i in range(warmup_runs):
+ _run_once(f"warmup[p{p_idx}_temp={temp_val},{i + 1}/{warmup_runs}]", temp_val, seed)
+
+ measured_times = []
+ measured_nfe = []
+ output = None
+ for i in range(measured_runs):
+ output, dt, nfe_val = _run_once(
+ f"measured[p{p_idx}_temp={temp_val},{i + 1}/{measured_runs}]", temp_val, seed,
+ )
+ measured_times.append(dt)
+ measured_nfe.append(nfe_val)
+
+ generated_all = np.asarray(multihost_utils.process_allgather(output.generated_tokens, tiled=False))
+ sync_all(f"infer-generation-complete-p{p_idx}-temp{t_idx}")
+ if proc == 0:
+ generated = generated_all[0, 0].tolist() if generated_all.ndim == 3 else generated_all[0].tolist()
+ median_dt = sorted(measured_times)[len(measured_times) // 2]
+ tok_per_s = gen_length / median_dt if median_dt > 0 else float("nan")
+ all_results.append((
+ p_idx, p_text,
+ {
+ "temperature": temp_val,
+ "seed": seed,
+ "nfe": measured_nfe[-1],
+ "gen_tokens": len(generated),
+ "median_s": median_dt,
+ "tok_per_s": tok_per_s,
+ "text": tokenizer.decode(generated, skip_special_tokens=True),
+ },
+ ))
if proc == 0:
- generated = generated_all[0, 0].tolist() if generated_all.ndim == 3 else generated_all[0].tolist()
- print("=" * 70)
+ print("\n" + "=" * 70)
print(f"checkpoint_step={restored_step} epoch={restored_epoch} epoch_step={restored_epoch_step}")
- print(f"prompt={prompt!r}")
- print("generated:")
- print(tokenizer.decode(generated, skip_special_tokens=True))
- print(f"nfe={nfe} generated_tokens={len(generated)}")
- print(f"restore_plus_generate_seconds={time.time() - t0:.1f}")
- print(f"generate_seconds={time.time() - tg:.1f}")
+ for p_idx, p_text, r in all_results:
+ print("=" * 70)
+ print(f"[prompt {p_idx + 1}/{len(prompts_list)}] {p_text!r}")
+ print("-" * 70)
+ print(
+ f"temp={r['temperature']} seed={r['seed']} "
+ f"nfe={r['nfe']} gen_tokens={r['gen_tokens']} "
+ f"median={r['median_s']:.2f}s tok_per_s={r['tok_per_s']:.1f}"
+ )
+ print("generated:")
+ print(r["text"])
+ print("=" * 70)
+ print(f"restore_plus_all_generate_seconds={time.time() - t0:.1f}")
print("=" * 70, flush=True)
diff --git a/scripts/tpu_v6e_smoke.py b/scripts/tpu_train.py
similarity index 71%
rename from scripts/tpu_v6e_smoke.py
rename to scripts/tpu_train.py
index ed139ed..19502ad 100644
--- a/scripts/tpu_v6e_smoke.py
+++ b/scripts/tpu_train.py
@@ -73,9 +73,13 @@ def load_dotenv(path: str = ".env") -> None:
try:
import orbax.checkpoint as ocp
from orbax.checkpoint import options as ocp_options
+ from orbax.checkpoint import args as ocp_args
+ from orbax.checkpoint import type_handlers as ocp_type_handlers
except ImportError:
ocp = None
ocp_options = None
+ ocp_args = None
+ ocp_type_handlers = None
proc = jax.process_index()
nproc = jax.process_count()
@@ -83,10 +87,13 @@ def load_dotenv(path: str = ".env") -> None:
nlocal = jax.local_device_count()
print(f"[Worker {proc}/{nproc}] devices={ndev} local={nlocal} backend={jax.default_backend()}", flush=True)
-# ── 2D mesh: FSDP (8) × TP (8) ──────────────────────────────
-TP = 8
-DP = ndev // TP # 8 on v6e-64
-devices = mesh_utils.create_device_mesh((DP, TP))
+# ── 2D mesh: FSDP × TP (TP overridable via env, default 8) ──
+TP = int(os.environ.get("TP", "8"))
+# Fail fast with a readable message: a TP that does not divide the device
+# count would floor DP (so DP * TP != ndev), and TP=0 would raise a bare
+# ZeroDivisionError on the next line.
+if TP <= 0 or ndev % TP != 0:
+    raise ValueError(f"TP={TP} must be a positive divisor of the device count ({ndev})")
+DP = ndev // TP
+try:
+    devices = mesh_utils.create_device_mesh((DP, TP))
+except NotImplementedError:
+    # Retry, this time permitting physical axes to be split across the
+    # logical (fsdp, tp) axes.
+    devices = mesh_utils.create_device_mesh((DP, TP), allow_split_physical_axes=True)
mesh = Mesh(devices, axis_names=("fsdp", "tp"))
if proc == 0:
print(f"[Worker {proc}] 2D Mesh: fsdp={DP} × tp={TP}", flush=True)
@@ -141,6 +148,12 @@ def gs_join(base: str, *parts: str) -> str:
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen3-8B")
DATASET = os.environ.get("DATASET", "tinystories").lower()
SYNTHETIC_DATA = DATASET in {"synthetic", "random", "dummy"}
+# SFT datasets use per-example padded batches. Plain pretraining datasets
+# concatenate tokens into a single stream.
+SFT_DATASET = DATASET in {"openthoughts", "open-thoughts", "open-thoughts-114k"}
+# Default: train on the whole chat-templated string (treat it like tinystories).
+# Set SFT_TRAIN_ON_ANSWERS_ONLY=1 to revert to prompt-masked supervision.
+SFT_TRAIN_ON_ANSWERS_ONLY = os.environ.get("SFT_TRAIN_ON_ANSWERS_ONLY", "0").strip().lower() in {"1", "true", "yes", "on"}
DATASET_PATH = os.environ.get("DATASET_PATH")
MODEL_SLUG = path_safe_name(MODEL_NAME.rstrip("/").rsplit("/", 1)[-1]) or "model"
RUN_NAME = os.environ.get("RUN_NAME") or os.environ.get("WANDB_RUN_NAME") or f"{MODEL_SLUG}-{DATASET}-{int(time.time())}"
@@ -201,6 +214,12 @@ def gs_join(base: str, *parts: str) -> str:
DMAX_NOISE_LOW = float(os.environ.get("DMAX_NOISE_LOW", "0.75"))
DMAX_NOISE_HIGH = float(os.environ.get("DMAX_NOISE_HIGH", "0.75"))
DMAX_BLOCK_SIZE = int(os.environ.get("DMAX_BLOCK_SIZE", "32"))
+# ── XProf / JAX profiler ─────────────────────────────────────────────────────
+XPROF_ENABLE = env_flag("XPROF_ENABLE", False)
+XPROF_DIR = os.environ.get("XPROF_DIR", "/tmp/xprof/run")
+XPROF_START_STEP = int(os.environ.get("XPROF_START_STEP", "4"))
+XPROF_STOP_STEP = int(os.environ.get("XPROF_STOP_STEP", "7"))
+_xprof_active = False
if WANDB_LOG and WANDB_MODE != "offline" and proc == 0:
has_wandb_auth = bool(os.environ.get("WANDB_API_KEY")) or Path.home().joinpath(".netrc").exists()
@@ -308,8 +327,25 @@ def estimate_model_params(cfg) -> int:
return int(embedding_params + lm_head_params + layers * (attn_params + mlp_params + norm_params))
EST_PARAMS = estimate_model_params(config)
-TRAIN_FLOPS_PER_TOKEN_DENSE = 6 * EST_PARAMS
-TRAIN_FLOPS_PER_TOKEN_ATTN = 12 * int(config.num_hidden_layers) * MAX_LEN * int(config.hidden_size)
+# DMax stacks [noised; clean] so each sample's compiled fwd+bwd runs at seq=2·L,
+# and when on_policy_ratio>0 the compiled graph also pays a stop_gradient rollout
+# fwd at seq=2·L (the per-sample on_policy_flag only gates the argmax, not execution).
+_seq_clean = MAX_LEN
+_seq_train_fwd = 2 * MAX_LEN if DMAX_ENABLE else MAX_LEN
+_seq_rollout_fwd = 2 * MAX_LEN if (DMAX_ENABLE and DMAX_ON_POLICY_RATIO > 0) else 0
+_layers = int(config.num_hidden_layers)
+_hidden = int(config.hidden_size)
+# 6N/token fwd+bwd + 2N/token rollout fwd (dense equivalent).
+_per_sample_dense = 6 * EST_PARAMS * _seq_train_fwd + 2 * EST_PARAMS * _seq_rollout_fwd
+# 12·L·seq²·H fwd+bwd + 4·L·seq²·H rollout fwd (dense attention reference, even
+# though splash may skip masked blocks — keeps the denominator comparable to the
+# non-DMax baseline).
+_per_sample_attn = (
+ 12 * _layers * _seq_train_fwd * _seq_train_fwd * _hidden
+ + 4 * _layers * _seq_rollout_fwd * _seq_rollout_fwd * _hidden
+)
+TRAIN_FLOPS_PER_TOKEN_DENSE = _per_sample_dense // _seq_clean
+TRAIN_FLOPS_PER_TOKEN_ATTN = _per_sample_attn // _seq_clean
TRAIN_FLOPS_PER_TOKEN_TOTAL = TRAIN_FLOPS_PER_TOKEN_DENSE + TRAIN_FLOPS_PER_TOKEN_ATTN
PEAK_FLOPS = PEAK_TFLOPS_PER_CHIP * 1e12 * ndev
@@ -321,6 +357,8 @@ def estimate_model_params(cfg) -> int:
"train_flops_per_token_attention": TRAIN_FLOPS_PER_TOKEN_ATTN,
"train_flops_per_token_total": TRAIN_FLOPS_PER_TOKEN_TOTAL,
"peak_flops": PEAK_FLOPS,
+ "train_seq_train_fwd": _seq_train_fwd,
+ "train_seq_rollout_fwd": _seq_rollout_fwd,
},
allow_val_change=True,
)
@@ -328,11 +366,14 @@ def estimate_model_params(cfg) -> int:
if proc == 0:
print(f"[Worker {proc}] {MODEL_NAME}: {config.num_hidden_layers}L h={config.hidden_size} "
f"V={config.vocab_size}", flush=True)
+ _dmax_tag = (
+ f" dmax(seq_fwd={_seq_train_fwd},rollout_fwd={_seq_rollout_fwd})" if DMAX_ENABLE else ""
+ )
print(
f"[Worker {proc}] estimated_params={EST_PARAMS/1e9:.2f}B "
f"flops/token dense={TRAIN_FLOPS_PER_TOKEN_DENSE/1e9:.1f}G "
f"attn={TRAIN_FLOPS_PER_TOKEN_ATTN/1e9:.1f}G "
- f"peak={PEAK_FLOPS/1e15:.2f} PFLOP/s",
+ f"peak={PEAK_FLOPS/1e15:.2f} PFLOP/s{_dmax_tag}",
flush=True,
)
if RUN_FULL_EPOCHS:
@@ -360,8 +401,102 @@ def _sharded_flash_attn(q, k, v, sm_scale):
if proc == 0:
print(f"[Worker {proc}] Pallas flash installed (blocks=512)", flush=True)
+# ── Splash attention for DMax block-diffusion mask ─────────
+# The dense-mask fallback in _attention is ~1.5-2s/step for Qwen3-8B at seq=8192;
+# splash_attention_kernel runs block-sparse flash over the block-diffusion pattern.
+# DISABLE_SPLASH_ATTN=1 skips splash install and falls back to dense dot_product_attention
+# with the full block-diffusion mask — slower but a useful isolation for
+# numerics debugging.
+if DMAX_ENABLE and not env_flag("DISABLE_SPLASH_ATTN", False):
+ try:
+ from jax.experimental.pallas.ops.tpu.splash_attention import (
+ splash_attention_mask as _sm,
+ splash_attention_kernel as _sk,
+ )
+
+        def _block_diffusion_mask_numpy(seq_len_l: int, block_size: int) -> np.ndarray:
+            """Return the boolean ``[2L, 2L]`` mask for the stacked sequence.
+
+            Positions ``< seq_len_l`` are the first half, positions
+            ``>= seq_len_l`` the second half; each half is cut into blocks of
+            ``block_size``. Entry ``[q, kv]`` is True when any of these hold:
+              * q and kv are in the same half and the same block;
+              * q is in the first half, kv in the second, and kv's block is
+                strictly earlier than q's;
+              * both are in the second half and kv's block is not later than q's.
+            """
+            positions = np.arange(seq_len_l * 2)
+            in_second = positions >= seq_len_l
+            # Block index relative to the start of the half the position lives in.
+            block_of = np.where(in_second, positions - seq_len_l, positions) // block_size
+            q_block, kv_block = block_of[:, None], block_of[None, :]
+            q_second, kv_second = in_second[:, None], in_second[None, :]
+            same_half_same_block = (q_second == kv_second) & (q_block == kv_block)
+            first_sees_earlier_second = ~q_second & kv_second & (q_block > kv_block)
+            second_block_causal = q_second & kv_second & (q_block >= kv_block)
+            return same_half_same_block | first_sees_earlier_second | second_block_causal
+
+ _num_heads_total = int(config.num_attention_heads)
+ _heads_per_tp = _num_heads_total // TP
+ _mask_np = _block_diffusion_mask_numpy(MAX_LEN, DMAX_BLOCK_SIZE).astype(np.bool_)
+ _splash_mask = _sm.MultiHeadMask(masks=[_sm.NumpyMask(_mask_np)] * _heads_per_tp)
+ _splash_bs = int(os.environ.get("SPLASH_BLOCK", "512"))
+ _splash_fused_bwd = env_flag("SPLASH_FUSED_BWD", True)
+ _bs_kwargs = dict(
+ block_q=_splash_bs, block_kv=_splash_bs, block_kv_compute=_splash_bs,
+ block_q_dkv=_splash_bs, block_kv_dkv=_splash_bs, block_kv_dkv_compute=_splash_bs,
+ use_fused_bwd_kernel=_splash_fused_bwd,
+ )
+ if not _splash_fused_bwd:
+ _bs_kwargs.update(block_q_dq=_splash_bs, block_kv_dq=_splash_bs)
+ _splash_block_sizes = _sk.BlockSizes(**_bs_kwargs)
+ _splash_fn = _sk.make_splash_mha_single_device(mask=_splash_mask, block_sizes=_splash_block_sizes)
+
+        def _splash_per_shard(q, k, v, sm_scale):
+            """Apply the prebuilt splash kernel to one shard's q/k/v.
+
+            ``sm_scale`` is multiplied into ``q`` before the call and the result
+            is cast back to ``q.dtype``; the kernel itself is invoked with only
+            ``(q, k, v)``. ``jax.vmap`` maps ``_splash_fn`` over the leading axis
+            of all three inputs.
+            """
+            # q/k/v per shard: [B, H_local, T, D]. Splash expects [H, T, D] per call.
+            q_scaled = (q * sm_scale).astype(q.dtype)
+            return jax.vmap(_splash_fn)(q_scaled, k, v)
+
+        def _sharded_masked_flash(q, k, v, sm_scale):
+            """Run ``_splash_per_shard`` under ``shard_map`` on the global mesh.
+
+            q, k, v and the output all use the same partition spec: axis 0 is
+            split over the "fsdp" mesh axis, axis 1 over "tp", and the last two
+            axes are left unsplit. ``sm_scale`` is closed over rather than
+            passed through ``shard_map``.
+            """
+            return shard_map(
+                lambda q, k, v: _splash_per_shard(q, k, v, sm_scale),
+                mesh=mesh,
+                in_specs=(P("fsdp", "tp", None, None),) * 3,
+                out_specs=P("fsdp", "tp", None, None),
+                check_rep=False,  # replication checking is explicitly turned off for this call
+            )(q, k, v)
+
+ dllm_models._MASKED_FLASH_ATTN_FN = _sharded_masked_flash
+ if proc == 0:
+ print(
+ f"[Worker {proc}] Splash attention installed for DMax block-diffusion "
+ f"(mask {_mask_np.shape}, heads_per_tp={_heads_per_tp}, "
+ f"block={_splash_bs}, fused_bwd={_splash_fused_bwd})",
+ flush=True,
+ )
+ except Exception as _exc:
+ if proc == 0:
+ print(f"[Worker {proc}] Splash attention setup failed: {_exc}", flush=True)
+
# ── Per-layer remat ────────────────────────────────────────
-remat_policy = jax.checkpoint_policies.nothing_saveable
+# REMAT_POLICY selects what to save between fwd and bwd. Saving more cuts the
+# ~7.5% recompute cost but increases HBM — risky near the B=64 ceiling.
+# nothing_saveable (default) recompute everything
+# gate_up save the MLP gate*up product (biggest matmuls)
+# qkv_gate_up also save q/k/v post-RoPE
+# dots_saveable jax preset: save all dot outputs
+# everything_saveable jax preset: save everything (HBM-hungry)
+_REMAT_POLICY_NAME = os.environ.get("REMAT_POLICY", "nothing_saveable")
+# Name -> zero-argument factory. Factories keep construction lazy, so only the
+# policy that was actually requested is ever built.
+_REMAT_POLICY_FACTORIES = {
+    "nothing_saveable": lambda: jax.checkpoint_policies.nothing_saveable,
+    "gate_up": lambda: jax.checkpoint_policies.save_only_these_names("gate_up"),
+    "qkv_gate_up": lambda: jax.checkpoint_policies.save_only_these_names(
+        "q", "k", "v", "gate_up",
+    ),
+    "dots_saveable": lambda: jax.checkpoint_policies.dots_saveable,
+    "everything_saveable": lambda: jax.checkpoint_policies.everything_saveable,
+}
+_remat_policy_factory = _REMAT_POLICY_FACTORIES.get(_REMAT_POLICY_NAME)
+if _remat_policy_factory is None:
+    raise ValueError(
+        f"Unknown REMAT_POLICY={_REMAT_POLICY_NAME!r}; expected one of "
+        "nothing_saveable, gate_up, qkv_gate_up, dots_saveable, everything_saveable."
+    )
+remat_policy = _remat_policy_factory()
+if proc == 0:
+    print(f"[Worker {proc}] remat_policy={_REMAT_POLICY_NAME}", flush=True)
def _remat_hidden_for_heads(self, input_ids=None, *, inputs_embeds=None, attention_mask=None, position_ids=None):
if inputs_embeds is not None:
@@ -421,6 +556,13 @@ def create_block_diffusion_attention_mask(seq_len: int, block_size: int) -> jnp.
n_loaded = 0
print(f"[Worker {proc}] LOAD_PRETRAINED=0; using random initialization", flush=True)
+# MASK_EMBED_INIT runs POST-shard (see below, right after sharding) so that
+# the .at[].set() peak holds sharded-per-chip copies (~75MB), not replicated
+# copies (~2.4GB × 2 = 4.8GB) which OOM v5e-64's 16GB/chip HBM.
+_mask_init_mode = os.environ.get("MASK_EMBED_INIT", "mean").strip().lower()
+if _mask_init_mode not in {"mean", "mean-embed", "average", "none", "skip", "off", ""}:
+ raise ValueError(f"Unknown MASK_EMBED_INIT={_mask_init_mode!r}")
+
# ── Shard to 2D mesh ───────────────────────────────────────
print(f"[Worker {proc}] Sharding to 2D TPU mesh...", flush=True)
t1 = time.time()
@@ -471,17 +613,59 @@ def shard_tree(tree, label: str):
gc.collect()
print(f"[Worker {proc}] Sharding done in {time.time()-t1:.1f}s (total: {time.time()-t0:.1f}s)", flush=True)
+# ── Mask-token embedding warm start (POST-shard) ───────────
+# DMax reuses ``mask_id`` (default vocab_size-1, optionally <|fim_pad|>=151662,
+# etc.) as the MASK token. Untrained vocab rows collapse the noised forward
+# (observed ~step 150 drift to uniform predictions). Seed the mask row with
+# the MEAN of other input/output embedding rows for a sensible "average
+# semantic" warm start. Runs POST-shard so the ``.at[].set()`` doesn't hold
+# two full ~2.4GB replicas in HBM simultaneously.
+if _mask_init_mode in {"mean", "mean-embed", "average"}:
+    # A blank MASK_TOKEN_ID (e.g. an uncommented but empty ``MASK_TOKEN_ID=``
+    # line in .env) is treated as unset instead of crashing in ``int("")``.
+    _mask_id_env = (os.environ.get("MASK_TOKEN_ID") or "").strip()
+    _mask_id_local = int(_mask_id_env) if _mask_id_env else int(config.vocab_size) - 1
+    # On-device compute to respect multi-host sharding — ``jax.device_get`` on
+    # a cross-host sharded jax.Array raises. XLA handles the sharded sum/slice.
+    emb_var = model.embed_tokens.embedding
+    _emb = emb_var.value
+    # JAX indexing never raises on a bad index (reads are clamped, ``.at[].set``
+    # updates are dropped, negatives wrap), so a wrong id would silently seed
+    # nothing or the wrong row. Validate explicitly.
+    if not 0 <= _mask_id_local < _emb.shape[0]:
+        raise ValueError(
+            f"MASK_TOKEN_ID={_mask_id_local} is outside the embedding table "
+            f"(valid range 0..{_emb.shape[0] - 1})"
+        )
+    # Mean over every row EXCEPT the mask row itself, accumulated in float32.
+    _sum_rows = jnp.sum(_emb.astype(jnp.float32), axis=0)
+    _mean_row = (_sum_rows - _emb[_mask_id_local].astype(jnp.float32)) / (_emb.shape[0] - 1)
+    emb_var.value = _emb.at[_mask_id_local].set(_mean_row.astype(_emb.dtype))
+    del _emb, _sum_rows, _mean_row
+    gc.collect()
+    lm_head = getattr(model, "lm_head", None)
+    if lm_head is not None and hasattr(lm_head, "kernel"):
+        kvar = lm_head.kernel
+        _k = kvar.value
+        if _mask_id_local >= _k.shape[1]:
+            raise ValueError(
+                f"MASK_TOKEN_ID={_mask_id_local} is outside the lm_head kernel "
+                f"(valid range 0..{_k.shape[1] - 1})"
+            )
+        # Same warm start for the output projection: the mask column becomes
+        # the mean of all other columns.
+        _sum_cols = jnp.sum(_k.astype(jnp.float32), axis=1)
+        _mean_col = (_sum_cols - _k[:, _mask_id_local].astype(jnp.float32)) / (_k.shape[1] - 1)
+        kvar.value = _k.at[:, _mask_id_local].set(_mean_col.astype(_k.dtype))
+        del _k, _sum_cols, _mean_col
+        gc.collect()
+    if proc == 0:
+        print(
+            f"[Worker {proc}] MASK_EMBED_INIT=mean (post-shard, on-device): seeded row/col {_mask_id_local}",
+            flush=True,
+        )
+
# ── Optimizer: AdamW (safer than Adafactor for pretrained MDLM) ──
+# LR_SCHEDULE=cosine with LR_DECAY_STEPS=N decays from PEAK_LR to alpha*PEAK_LR
+# (alpha=LR_DECAY_ALPHA, default 0.1) over N steps after warmup. Default schedule
+# is constant-after-warmup (back-compat with prior runs).
+_lr_schedule_kind = os.environ.get("LR_SCHEDULE", "constant").strip().lower()
+_lr_decay_steps = int(os.environ.get("LR_DECAY_STEPS", "0"))
+_lr_decay_alpha = float(os.environ.get("LR_DECAY_ALPHA", "0.1"))
+# Reject typos ("cosin", "linear", ...) instead of silently training at a
+# constant LR. A blank value keeps the constant default.
+if _lr_schedule_kind not in {"", "constant", "cosine"}:
+    raise ValueError(
+        f"Unknown LR_SCHEDULE={_lr_schedule_kind!r}; expected 'constant' or 'cosine'."
+    )
+if _lr_schedule_kind == "cosine" and _lr_decay_steps > 0:
+    _post_warmup = optax.cosine_decay_schedule(
+        init_value=PEAK_LR, decay_steps=_lr_decay_steps, alpha=_lr_decay_alpha,
+    )
+else:
+    if _lr_schedule_kind == "cosine" and proc == 0:
+        # Fallback behaviour is unchanged, but it is no longer silent.
+        print(
+            f"[Worker {proc}] WARNING: LR_SCHEDULE=cosine needs LR_DECAY_STEPS>0 "
+            f"(got {_lr_decay_steps}); using constant LR after warmup",
+            flush=True,
+        )
+    _post_warmup = optax.constant_schedule(PEAK_LR)
if WARMUP_STEPS > 0:
lr_schedule = optax.join_schedules(
- schedules=[
- optax.linear_schedule(0.0, PEAK_LR, WARMUP_STEPS),
- optax.constant_schedule(PEAK_LR),
- ],
+ schedules=[optax.linear_schedule(0.0, PEAK_LR, WARMUP_STEPS), _post_warmup],
boundaries=[WARMUP_STEPS],
)
else:
- lr_schedule = optax.constant_schedule(PEAK_LR)
+ lr_schedule = _post_warmup
if OPTIMIZER == "adafactor":
# Factored optimizer state (~4x less HBM than AdamW) — required to fit
# 128k context on Qwen3-8B per chip. Loss may climb back after ~step 60
@@ -509,7 +693,11 @@ def shard_tree(tree, label: str):
# ── MDLM config + tokenizer + data ─────────────────────────
mdlm_cfg = MDLMConfig(output_dir="/tmp/q3-8b-smoke", max_steps=NUM_STEPS, learning_rate=PEAK_LR)
alpha_sched = LinearAlphaScheduler()
-mask_id = config.vocab_size - 1
+# MASK_TOKEN_ID overrides the default (vocab_size-1, which is a truly-empty
+# reserved slot on Qwen3 with an untrained embedding). Set to a pretrained
+# Qwen3 special token like 151662 (<|fim_pad|>) to inherit useful "fill in
+# the missing piece" semantics for the denoising task.
+# A blank value (e.g. an uncommented but empty ``MASK_TOKEN_ID=`` line in .env)
+# is treated as unset rather than crashing in ``int("")``.
+_mask_token_env = (os.environ.get("MASK_TOKEN_ID") or "").strip()
+mask_id = int(_mask_token_env) if _mask_token_env else int(config.vocab_size - 1)
+if mask_id < 0:
+    raise ValueError(f"MASK_TOKEN_ID must be a non-negative token id, got {mask_id}")
print(f"[Worker {proc}] Loading tokenizer + dataset stream...", flush=True)
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
@@ -527,6 +715,8 @@ def make_dataset_iter():
ds = load_dataset("parquet", data_files=parquet_files, split="train", streaming=True)
elif DATASET == "wikipedia":
ds = load_dataset("wikimedia/wikipedia", "20231101.en", split="train", streaming=True)
+ elif SFT_DATASET:
+ ds = load_dataset("open-thoughts/OpenThoughts-114k", "default", split="train", streaming=True)
else:
ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
return iter(ds)
@@ -542,20 +732,130 @@ def refill_buffer(needed):
exhausted = False
while len(token_buffer) < needed:
try:
- text = next(ds_iter)["text"]
+ row = next(ds_iter)
except StopIteration:
exhausted = True
break
+ if SFT_DATASET:
+ ids = _sft_row_to_tokens(row)
+ if ids is None:
+ continue
+ token_buffer.extend(ids)
+ else:
+ token_buffer.extend(tokenizer.encode(row["text"], add_special_tokens=False))
# Append EOS between docs so the model sees document boundaries.
- token_buffer.extend(tokenizer.encode(text, add_special_tokens=False))
token_buffer.append(eos_id)
return not exhausted
+
+# SFT pipeline:
+# SFT_TRAIN_ON_ANSWERS_ONLY=0 (default, full-text): tokens go through
+# ``_sft_row_to_tokens`` and are concat-packed by ``refill_buffer`` just
+# like tinystories (no row dropping, no truncation — every chat-templated
+# token is trained on).
+# SFT_TRAIN_ON_ANSWERS_ONLY=1 (answer-only): per-example padded batch via
+# ``_sft_batch`` + ``_tokenize_sft_row`` because packing can't preserve
+# the prompt/answer boundary needed for the -100 label mask.
+_ROLE_MAP = {"human": "user", "gpt": "assistant", "system": "system", "user": "user", "assistant": "assistant"}
+_sft_dropped = 0
+_sft_kept = 0
+
+
+def _sft_row_to_messages(row):
+    """Convert a dataset row into a chat-message list, or None if it has no turns.
+
+    An optional truthy ``row["system"]`` becomes a leading system message.
+    Turns that are dicts with a ``"from"`` key are rewritten to
+    ``{"role", "content"}`` (role looked up in ``_ROLE_MAP``, falling back to
+    the raw speaker name; content taken from ``"value"``). Any other turn is
+    passed through untouched.
+    """
+    turns = row.get("conversations")
+    if not turns:
+        return None
+
+    def _normalize(turn):
+        # Anything that is not a "from"/"value" style dict is kept as-is.
+        if not (isinstance(turn, dict) and "from" in turn):
+            return turn
+        speaker = turn["from"]
+        return {"role": _ROLE_MAP.get(speaker, speaker), "content": turn["value"]}
+
+    system_prompt = row.get("system")
+    preamble = [{"role": "system", "content": system_prompt}] if system_prompt else []
+    return preamble + [_normalize(turn) for turn in turns]
+
+
+def _sft_row_to_tokens(row):
+    """Chat-template and tokenize one SFT row for the concat-pack pipeline.
+
+    Returns the row's token ids, or None when the row yields no messages.
+    No length filtering happens here.
+    """
+    chat = _sft_row_to_messages(row)
+    if chat is None:
+        return None
+    encoded = tokenizer.apply_chat_template(
+        chat,
+        tokenize=True,
+        add_generation_prompt=False,
+        return_dict=True,
+    )
+    return encoded["input_ids"]
+
+
+def _tokenize_sft_row(row):
+    """Answer-only path: tokenize a row and mask its prompt out of the labels.
+
+    Returns ``(input_ids, labels)`` where the first ``prompt_len`` labels are
+    -100, or None when the row has no messages or its full token sequence is
+    longer than MAX_LEN.
+    """
+    chat = _sft_row_to_messages(row)
+    if chat is None:
+        return None
+
+    def _encode(msgs, generation_prompt):
+        return tokenizer.apply_chat_template(
+            msgs, tokenize=True, add_generation_prompt=generation_prompt, return_dict=True,
+        )["input_ids"]
+
+    full_ids = _encode(chat, False)
+    if len(full_ids) > MAX_LEN:
+        return None
+    # The prompt is everything before a trailing assistant turn; without one,
+    # the whole conversation counts as prompt.
+    ends_with_answer = len(chat) >= 2 and chat[-1]["role"] == "assistant"
+    prompt_ids = _encode(chat[:-1] if ends_with_answer else chat, True)
+    # Clamp so the masked prefix can never run past the full sequence.
+    prompt_len = min(len(prompt_ids), len(full_ids))
+    labels = [-100] * prompt_len + list(full_ids[prompt_len:])
+    return full_ids, labels
+
+
+def _sft_batch():
+    """Assemble one padded ``(GLOBAL_BATCH, MAX_LEN)`` answer-only SFT batch.
+
+    Pulls rows from ``ds_iter`` until the batch is full, skipping (and
+    counting in ``_sft_dropped``) rows that ``_tokenize_sft_row`` rejects.
+    Returns ``{"input_ids", "labels"}`` as int64 arrays, or None when the
+    stream is exhausted before a single row could be placed.
+
+    Raises:
+        ValueError: if the tokenizer defines neither a pad nor an EOS token id.
+    """
+    global _sft_dropped, _sft_kept
+    # Some tokenizers ship without a pad token; ``np.full(..., None)`` would
+    # raise. Fall back to EOS — padded positions keep label -100, so the fill
+    # value never reaches the loss.
+    pad_id = tokenizer.pad_token_id
+    if pad_id is None:
+        pad_id = tokenizer.eos_token_id
+    if pad_id is None:
+        raise ValueError("Tokenizer has neither pad_token_id nor eos_token_id; cannot pad SFT batches")
+    ids = np.full((GLOBAL_BATCH, MAX_LEN), pad_id, dtype=np.int64)
+    labels = np.full((GLOBAL_BATCH, MAX_LEN), -100, dtype=np.int64)
+    filled = 0
+    while filled < GLOBAL_BATCH:
+        try:
+            row = next(ds_iter)
+        except StopIteration:
+            if filled == 0:
+                return None
+            break
+        tok = _tokenize_sft_row(row)
+        if tok is None:
+            _sft_dropped += 1
+            continue
+        row_ids, row_labels = tok
+        row_len = len(row_ids)
+        ids[filled, :row_len] = row_ids
+        labels[filled, :row_len] = row_labels
+        filled += 1
+        _sft_kept += 1
+    # Pad unfilled rows (end of stream) with a duplicate of row 0 to keep shape.
+    ids[filled:] = ids[0]
+    labels[filled:] = labels[0]
+    return {"input_ids": ids, "labels": labels}
+
+
def get_batch():
global token_buffer
if SYNTHETIC_DATA:
ids = synthetic_rng.integers(0, max(2, mask_id), size=(GLOBAL_BATCH, MAX_LEN), dtype=np.int64)
return {"input_ids": ids, "labels": ids.copy()}
+ # Answer-only needs per-example padded batches (to preserve the -100 mask).
+ # Full-text SFT (default) concat-packs through ``refill_buffer`` below,
+ # reaching full MAX_LEN utilization with no row dropping.
+ if SFT_DATASET and SFT_TRAIN_ON_ANSWERS_ONLY:
+ return _sft_batch()
refill_buffer(GLOBAL_BATCH * MAX_LEN)
if not token_buffer:
return None
@@ -576,7 +876,8 @@ def loss_fn(mdl, batch):
# DMax on-policy rollout: for examples flagged by on_policy_flag, replace
# MASK tokens in the noisy half with the model's own greedy predictions
# (stop_gradient) before the supervised forward.
- if "on_policy_flag" in batch and batch["on_policy_flag"] is not None:
+ # Skip entirely when DMAX_ON_POLICY_RATIO=0 — saves a full 2L forward per step.
+ if DMAX_ON_POLICY_RATIO > 0 and "on_policy_flag" in batch and batch["on_policy_flag"] is not None:
seq_len = batch["input_ids"].shape[1]
flag = batch["on_policy_flag"][:, None] # [B, 1]
rollout_logits = jax.lax.stop_gradient(
@@ -931,7 +1232,7 @@ def save_training_checkpoint(global_step: int, epoch: int, epoch_step: int, *, f
wandb.log({"checkpoint/global_step": global_step}, step=global_step)
def run_training_step(global_step: int, epoch: int, epoch_step: int, total_steps: int | None):
- global rng, total_tokens, last_epoch, last_epoch_step
+ global rng, total_tokens, last_epoch, last_epoch_step, _xprof_active
ts = time.time()
data_t0 = time.time()
debug_step = global_step == 1
@@ -994,6 +1295,14 @@ def run_training_step(global_step: int, epoch: int, epoch_step: int, total_steps
if debug_step:
print(f"[Worker {proc}] step {global_step}: train_step start", flush=True)
+ if XPROF_ENABLE and global_step == XPROF_START_STEP and not _xprof_active:
+ import jax.profiler as _jprof
+ Path(XPROF_DIR).mkdir(parents=True, exist_ok=True)
+ _jprof.start_trace(XPROF_DIR)
+ _xprof_active = True
+ if proc == 0:
+ print(f"[Worker {proc}] xprof start (step {global_step}) -> {XPROF_DIR}", flush=True)
+
train_t0 = time.time()
metrics = train_step(model, optimizer, batch)
loss_val = float(metrics["loss"])
@@ -1046,6 +1355,12 @@ def run_training_step(global_step: int, epoch: int, epoch_step: int, total_steps
step=global_step,
)
save_training_checkpoint(global_step, epoch, epoch_step)
+ if XPROF_ENABLE and _xprof_active and global_step >= XPROF_STOP_STEP:
+ import jax.profiler as _jprof
+ _jprof.stop_trace()
+ _xprof_active = False
+ if proc == 0:
+ print(f"[Worker {proc}] xprof stop (step {global_step}) -> {XPROF_DIR}", flush=True)
return True
# ── Resume from checkpoint ────────────────────────────────
@@ -1095,22 +1410,64 @@ def run_training_step(global_step: int, epoch: int, epoch_step: int, total_steps
print(f"[Worker {proc}] Restoring checkpoint: {ckpt_path}", flush=True)
_resume_rng_placeholder = np.asarray(jax.random.key_data(jax.random.key(0)))
+ _resume_restore_optimizer = env_flag("RESUME_RESTORE_OPTIMIZER", True)
restore_target = {
"model": nnx.state(model),
- "optimizer": nnx.state(optimizer),
"global_step": np.asarray(0, dtype=np.int64),
"epoch": np.asarray(0, dtype=np.int32),
"epoch_step": np.asarray(0, dtype=np.int64),
"total_tokens": np.asarray(0, dtype=np.int64),
"rng": _resume_rng_placeholder,
}
+ if _resume_restore_optimizer:
+ restore_target["optimizer"] = nnx.state(optimizer)
# Use orbax directly for GCS restore (flax's restore_checkpoint can't list GCS dirs)
resume_ckpt = make_orbax_checkpointer()
if proc == 0:
print(f"[Worker {proc}] Starting orbax restore...", flush=True)
with orbax_set_mesh_context_patch():
- restored = resume_ckpt.restore(ckpt_path, target=restore_target)
+ if not _resume_restore_optimizer and ocp_args is not None and ocp_type_handlers is not None:
+ # Model-only restore: use PyTreeRestore(partial_restore=True) so the
+ # checkpoint's optimizer state (different shape / optimizer family)
+ # is simply ignored.
+            def _to_restore_args(x):
+                """Map one leaf of ``restore_target`` to its orbax restore args.
+
+                jax arrays are restored as ``jax.Array`` with the leaf's own
+                dtype, sharding and global shape; NumPy arrays are restored as
+                ``np.ndarray`` with the leaf's dtype; any other leaf gets
+                default ``RestoreArgs``.
+                """
+                if isinstance(x, jax.Array):
+                    return ocp_type_handlers.ArrayRestoreArgs(
+                        restore_type=jax.Array,
+                        dtype=x.dtype,
+                        sharding=x.sharding,
+                        global_shape=x.shape,
+                    )
+                # NOTE(review): jax arrays are already handled above, so this
+                # branch is presumably reached only for NumPy leaves — confirm.
+                if isinstance(x, (jnp.ndarray, np.ndarray)):
+                    arr = np.asarray(x)
+                    return ocp_type_handlers.RestoreArgs(restore_type=np.ndarray, dtype=arr.dtype)
+                return ocp_type_handlers.RestoreArgs()
+
+ restore_args_tree = jax.tree_util.tree_map(
+ _to_restore_args,
+ restore_target,
+ is_leaf=lambda x: isinstance(x, (jax.Array, jnp.ndarray, np.ndarray)),
+ )
+ pytree_ckpt = (
+ ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
+ if ocp_options is None
+ else ocp.Checkpointer(
+ ocp.PyTreeCheckpointHandler(
+ multiprocessing_options=ocp_options.MultiprocessingOptions(primary_host=0)
+ )
+ )
+ )
+ restored = pytree_ckpt.restore(
+ ckpt_path,
+ args=ocp_args.PyTreeRestore(
+ item=restore_target,
+ restore_args=restore_args_tree,
+ partial_restore=True,
+ ),
+ )
+ else:
+ restored = resume_ckpt.restore(ckpt_path, target=restore_target)
if proc == 0:
print(f"[Worker {proc}] Orbax restore done, re-sharding...", flush=True)
@@ -1184,16 +1541,45 @@ def reshard_tree(tree, label: str):
restored_model_state = reshard_tree(restored["model"], "resume-model")
model = nnx.merge(gdef, restored_model_state)
- opt_gdef, _ = nnx.split(optimizer)
- restored_opt_state = reshard_tree(restored["optimizer"], "resume-optimizer")
- optimizer = nnx.merge(opt_gdef, restored_opt_state)
- model = optimizer.model # rebind after merge
-
- resumed_step = int(restored["global_step"])
- resumed_epoch = int(restored["epoch"])
- resumed_epoch_step = int(restored["epoch_step"])
- total_tokens = int(restored["total_tokens"])
- rng = jax.random.wrap_key_data(restored["rng"])
+    if _resume_restore_optimizer:
+        # Full resume: rebuild the optimizer from the checkpointed state, then
+        # point ``model`` at the module instance that optimizer owns.
+        opt_gdef, _ = nnx.split(optimizer)
+        restored_opt_state = reshard_tree(restored["optimizer"], "resume-optimizer")
+        optimizer = nnx.merge(opt_gdef, restored_opt_state)
+        model = optimizer.model  # rebind after merge
+    else:
+        if proc == 0:
+            print(
+                f"[Worker {proc}] RESUME_RESTORE_OPTIMIZER=0: keeping fresh "
+                f"{OPTIMIZER} optimizer state (ignoring checkpoint's optimizer)",
+                flush=True,
+            )
+        # The fresh optimizer still wraps the PRE-restore module. Splitting and
+        # re-merging it and then rebinding ``model = optimizer.model`` would
+        # throw away the weights just restored into ``model``. Instead, write
+        # the restored model state into the module the optimizer trains; the
+        # optimizer's own (fresh) state is left untouched.
+        nnx.update(optimizer.model, restored_model_state)
+        model = optimizer.model  # rebind
+
+ # RESUME_RESET_STEP=1 zeroes the step/epoch counters after loading weights.
+ # Use when continuing onto a different dataset (e.g. pretrained checkpoint →
+ # SFT) so the training budget (NUM_STEPS / NUM_EPOCHS) starts from scratch.
+ _resume_reset_step = env_flag("RESUME_RESET_STEP", False)
+ if _resume_reset_step:
+ if proc == 0:
+ print(
+ f"[Worker {proc}] RESUME_RESET_STEP=1: ignoring checkpoint step "
+ f"{int(restored['global_step'])}/epoch {int(restored['epoch'])} "
+ "and starting training from step 0",
+ flush=True,
+ )
+ resumed_step = 0
+ resumed_epoch = 0
+ resumed_epoch_step = 0
+ total_tokens = 0
+ rng = jax.random.key(42)
+ else:
+ resumed_step = int(restored["global_step"])
+ resumed_epoch = int(restored["epoch"])
+ resumed_epoch_step = int(restored["epoch_step"])
+ total_tokens = int(restored["total_tokens"])
+ rng = jax.random.wrap_key_data(restored["rng"])
del restored, restore_target, _resume_rng_placeholder
gc.collect()
@@ -1239,10 +1625,18 @@ def reshard_tree(tree, label: str):
if proc == 0:
print(f"[Worker {proc}] Finished epoch {epoch}/{NUM_EPOCHS} after {epoch_step - 1} steps", flush=True)
else:
+    # Optional trace window: JAX_PROFILE_DIR enables it, JAX_PROFILE_START_STEP
+    # picks the first traced step and JAX_PROFILE_STEPS how many steps to trace
+    # (0 = keep tracing until the loop ends).
+    _PROFILE_DIR = os.environ.get("JAX_PROFILE_DIR")
+    _PROFILE_STEPS = int(os.environ.get("JAX_PROFILE_STEPS", "0"))
+    # Clamp the start to the first step this run actually executes. Steps begin
+    # at resumed_step + 1, so the default start of 0 (or any start at/before the
+    # resumed step) would never match, while the stop condition still could —
+    # calling stop_trace() on a trace that was never started.
+    _PROFILE_START = max(int(os.environ.get("JAX_PROFILE_START_STEP", "0")), resumed_step + 1)
+    _profile_on = False
for step in range(resumed_step + 1, NUM_STEPS + 1):
global_step = step
+        if _PROFILE_DIR and not _profile_on and step == _PROFILE_START:
+            jax.profiler.start_trace(_PROFILE_DIR)
+            _profile_on = True
if not run_training_step(global_step, epoch=0, epoch_step=step, total_steps=NUM_STEPS):
break
+        # No explicit device sync here: ``step`` is a plain Python int, so
+        # block_until_ready on it waits for nothing.
+        if _profile_on and _PROFILE_STEPS > 0 and step >= _PROFILE_START + _PROFILE_STEPS - 1:
+            jax.profiler.stop_trace()
+            _profile_on = False
+    if _profile_on:
+        # Loop ended (break, NUM_STEPS reached, or JAX_PROFILE_STEPS=0) with the
+        # trace still open — close it so the profile is actually written.
+        jax.profiler.stop_trace()
+        _profile_on = False
tt = time.time() - t_start
if CHECKPOINT_ON_FINISH: