diff --git a/.env.example b/.env.example index 5e6f4c0..8f8cd0b 100644 --- a/.env.example +++ b/.env.example @@ -1,42 +1,191 @@ # ============================================================================= -# dllm-jax environment configuration -# Copy this file to .env and fill in your values. +# dllm-jax environment configuration — full reference +# +# Copy to `.env` (training) or `.env.deploy` (deploy/run scripts) and fill in +# the values you need. EVERY variable below has a sensible default in the +# code; uncomment only what you want to override. +# +# Sections: +# 1. Model & tokenizer +# 2. Dataset +# 3. Training schedule +# 4. Learning-rate schedule +# 5. Optimizer +# 6. Sharding (TP / FSDP) +# 7. DMax (block-diffusion masked LM) +# 8. SFT (chat-template fine-tuning) +# 9. Performance knobs (splash, remat, mask-embed init) +# 10. GCS checkpointing +# 11. Resume from checkpoint +# 12. HuggingFace Hub (auth + checkpoint mirror) +# 13. Weights & Biases +# 14. Profiling (xprof / jax profile) +# 15. Inference (scripts/tpu_dmax_infer_checkpoint.py) # ============================================================================= -# ---- GCS Checkpoints -------------------------------------------------------- -# Option A: Regional bucket (auto-detected from TPU zone) -# The script builds the bucket name as: gs://${CHECKPOINT_BUCKET_PREFIX}-${region} -# e.g. gs://dllm-jax-us-east1 or gs://dllm-jax-europe-west4 -CHECKPOINT_BUCKET_PREFIX=dllm-jax -# Option B: Fixed (non-regional) bucket name — overrides auto-detection -# CHECKPOINT_BUCKET=my-bucket-name +# ── 1. Model & tokenizer ──────────────────────────────────────────────────── +MODEL_NAME=Qwen/Qwen3-8B +# DTYPE=bfloat16 # compute dtype: bfloat16 | float16 | float32 +# MASK_TOKEN_ID= # default: vocab_size-1 (e.g. 
151935 for Qwen3) + # try 151662 (<|fim_pad|>) for a pretrained warm start +# EOS_TOKEN_ID= # default: tokenizer's eos_token_id -# Option C: Full checkpoint directory — overrides both A and B -# CHECKPOINT_DIR=gs://my-bucket/checkpoints/my-run -# Checkpoint frequency and retention -CHECKPOINT_STEPS=500 -CHECKPOINT_KEEP=2 +# ── 2. Dataset ────────────────────────────────────────────────────────────── +DATASET=tinystories # tinystories | wikipedia | openthoughts | synthetic | parquet +# DATASET_PATH= # required when DATASET=parquet (file or dir of *.parquet) -# ---- Weights & Biases ------------------------------------------------------- -# Enable W&B logging (set to 1 to enable) -WANDB_LOG=0 -WANDB_PROJECT=dllm-jax -# WANDB_API_KEY=your-wandb-api-key-here -# Get your API key: https://wandb.ai/authorize -# Or run `wandb login` on each TPU worker instead. -# ---- Model & Training ------------------------------------------------------- -MODEL_NAME=Qwen/Qwen3-8B -DATASET=tinystories +# ── 3. Training schedule ──────────────────────────────────────────────────── MAX_LEN=16384 GLOBAL_BATCH=8 -NUM_STEPS=20 +NUM_STEPS=20 # set to 0 to use NUM_EPOCHS instead NUM_EPOCHS=0 WARMUP_STEPS=5 PEAK_LR=1e-4 +# RUN_NAME= # default: ${MODEL_SLUG}-${DATASET}-${unix_ts} +# LOGGING_STEPS=1 +# LOAD_PRETRAINED=1 # 0 = random init (debugging only) + + +# ── 4. Learning-rate schedule (post-warmup) ───────────────────────────────── +# LR_SCHEDULE=constant # constant | cosine +# LR_DECAY_STEPS=0 # cosine: total decay steps after warmup +# LR_DECAY_ALPHA=0.1 # cosine: final LR = PEAK_LR * alpha + + +# ── 5. Optimizer ──────────────────────────────────────────────────────────── +OPTIMIZER=adamw # adamw (recommended) | adafactor + # NOTE: Adafactor has been observed to diverge + # on Qwen3-8B MDLM/DMax around step ~60. + # Default to AdamW for real training. + + +# ── 6. 
Sharding ───────────────────────────────────────────────────────────── +# TP determines the tensor-parallel axis size; FSDP is derived as +# fsdp = jax.device_count() // TP +# v4-32 / v6e-256: TP=8 is fine. +# v5e-64: prefer TP=2 — TP=8 forces FSDP=8 and bloats optimizer state ~4x. +TP=8 + + +# ── 7. DMax (block-diffusion masked LM) ───────────────────────────────────── +# DMAX_ENABLE=0 # 1 = train block-diffusion / DMax objective +# DMAX_BLOCK_SIZE=32 # tokens per noising block +# DMAX_ON_POLICY_RATIO=0.5 # fraction of steps using on-policy noise +# DMAX_NOISE_LOW=0.75 +# DMAX_NOISE_HIGH=0.75 + + +# ── 8. SFT (chat-template fine-tuning) ────────────────────────────────────── +# Used when DATASET=openthoughts (or other chat-templated SFT dataset). +# SFT_TRAIN_ON_ANSWERS_ONLY=0 # 0 = train on full chat-templated text (default, + # packs like pretraining) + # 1 = mask prompt tokens with -100 (per-example + # padded batches) + + +# ── 9. Performance knobs ──────────────────────────────────────────────────── +# SPLASH_BLOCK=512 # splash_attention tile size for training +# SPLASH_FUSED_BWD=1 # use fused backward kernel +# DISABLE_SPLASH_ATTN=0 # 1 = fall back to dense attention (debugging) +# REMAT_POLICY=nothing_saveable # nothing_saveable | dots_saveable | minimal_checkpoint | None +# # see scripts/tpu_v6e_smoke.py for the full list + +# Mask-token embedding warm start (DMax). On v5e-64 the .at[].set() leaks +# ~9.5GB of HBM when run pre-shard on replicated weights → use 'none' there. +# MASK_EMBED_INIT=mean # mean | none + +# PEAK_TFLOPS_PER_CHIP=918 # v6e default; v4=275, v5e=197, v5p=459 + + +# ── 10. GCS checkpointing ─────────────────────────────────────────────────── +# Pick ONE of A / B / C below. +# A. Regional bucket (auto-detected from TPU zone): bucket name becomes +# gs://${CHECKPOINT_BUCKET_PREFIX}-${region} +CHECKPOINT_BUCKET_PREFIX=dllm-jax +# B. Fixed bucket — overrides A. +# CHECKPOINT_BUCKET=my-bucket-name +# C. Full directory — overrides A and B. 
+# CHECKPOINT_DIR=gs://my-bucket/checkpoints/my-run + +# CHECKPOINT_SUBDIR=checkpoints # path under bucket +CHECKPOINT_STEPS=500 # save every N optimizer steps +# CHECKPOINT_SECONDS=0 # also save every N seconds (0 = disabled) +CHECKPOINT_KEEP=2 # how many recent checkpoints to retain +# CHECKPOINT_ON_FINISH=0 # 1 = always write a final checkpoint +# CHECKPOINT_USE_GCS=1 # 0 = force local even if CHECKPOINT_BUCKET set +# LOCAL_CHECKPOINT_DIR=/tmp/dllm-jax-checkpoints # used when GCS not configured + +# Orbax internals (rarely changed): +# CHECKPOINT_ORBAX_SYNC_DIRS=1 +# CHECKPOINT_ORBAX_SIGNAL_FALLBACK=1 + + +# ── 11. Resume from checkpoint ────────────────────────────────────────────── +# RESUME_DIR=gs://your-bucket/checkpoints/old-run +# RESUME_STEP=0 # 0 = latest step in RESUME_DIR +# RESUME_RESTORE_OPTIMIZER=1 # 0 = restore weights only (e.g. pretrain → SFT) +# RESUME_RESET_STEP=0 # 1 = zero global_step/epoch counters after load + # (use when continuing onto a new dataset + # so NUM_STEPS/NUM_EPOCHS budget restarts) + + +# ── 12. HuggingFace Hub ───────────────────────────────────────────────────── +# HF_TOKEN=hf_... # for private model downloads / checkpoint upload +# HF_CHECKPOINT_REPO=username/repo +# HF_CHECKPOINT_REPO_TYPE=model # model | dataset +# HF_CHECKPOINT_PATH=checkpoints # path within the HF repo +# HF_CHECKPOINT_PRIVATE=1 + + +# ── 13. Weights & Biases ──────────────────────────────────────────────────── +WANDB_LOG=1 # always enable on training runs (incl. smokes) +WANDB_PROJECT=dllm-jax +# WANDB_RUN_NAME= # default: same as RUN_NAME +# WANDB_MODE=online # online | offline | disabled +# WANDB_API_KEY= # or run `wandb login` once on each TPU worker. + # Get a key: https://wandb.ai/authorize + # NEVER commit a real key. Keep it in .env.deploy + # (which is gitignored) or in ~/.netrc. + + +# ── 14. 
Profiling ─────────────────────────────────────────────────────────── +# XPROF_ENABLE=0 # capture xprof traces +# XPROF_DIR=/tmp/xprof/run +# XPROF_START_STEP=4 +# XPROF_STOP_STEP=7 + +# JAX programmatic profile (separate from xprof): +# JAX_PROFILE_DIR= +# JAX_PROFILE_START_STEP=0 +# JAX_PROFILE_STEPS=0 + + +# ── 15. Inference (scripts/tpu_dmax_infer_checkpoint.py) ──────────────────── +# Required for inference: RESUME_DIR. RESUME_STEP is optional (see section 11). +# PROMPT="Once upon a time" +# PROMPTS_FILE= # path to file with one prompt per line + # (overrides PROMPT; '#' lines = comments) +# GEN_LENGTH=32 +# BLOCK_LENGTH=32 +# STEPS=8 # max denoising steps per block +# THRESHOLD=0.95 # per-token confidence to accept early +# CONFIDENCE_STOP=0.9 # block-level early-exit threshold +# TOP_K=1 # soft-mix top-K for SPD (3 is good for Qwen3) +# TEMPERATURE=0.0 # 0.0 = greedy +# TEMPS= # comma list runs a sweep in one process, + # e.g. "0.0,0.3,0.5,0.7,1.0" +# SEED= # int; defaults to time-based + +# INFER_IMPL=fast # fast | kv_fast | legacy +# INFER_SPLASH=0 # 1 = splash kernel for fast-path block-causal mask +# INFER_SPLASH_BLOCK=512 +# INFER_KV_DTYPE=bf16 # bf16 (default, ~2x cache HBM savings) | fp32 +# FAST_BUCKET_LENGTH=4096 -# ---- HuggingFace Hub (optional checkpoint upload) --------------------------- -# HF_TOKEN=your-hf-token-here -# HF_CHECKPOINT_REPO=your-username/your-repo +# WARMUP_RUNS=0 # throwaway generates before measured runs +# MEASURED_RUNS=1 # report median over N timed runs +# RESTORE_OPTIMIZER=0 # 1 = restore optimizer state too (rarely useful for inference) +# SUPPRESS_MASK_TOKEN=0 # 1 = force-disable mask-token logits at argmax diff --git a/.gitignore b/.gitignore index 8986761..cf3ec24 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +.env.* +!.env.example *.lock __pycache__/ *.egg-info/ diff --git a/README.md b/README.md index 0cbd3a9..8b16483 100644 --- a/README.md +++ b/README.md @@ -2,195 +2,173 @@ Generated Image April 17, 2026 - 11_54PM 
+A pure-JAX stack for **diffusion language modeling** on TPU. -A standalone **JAX backend for Diffusion Language Modeling (dLLM)** with **zero -PyTorch or CUDA dependency**. Designed for TPU training of MDLM, BD3LM, Dream, -and EditFlow objectives on pretrained HuggingFace checkpoints. +No PyTorch, no CUDA. Loads pretrained HuggingFace checkpoints (Qwen3, Llama, +…) directly into Flax NNX, then trains and decodes them under five +diffusion objectives — **MDLM**, **BD3LM**, **Dream**, **DMax/OPUT**, +**EditFlow** — with multi-host TPU sharding, GCS-resumable Orbax +checkpoints, and a KV-cached fast inference path. -## Features +> 📝 **Background reading:** +> [dLLM into TPU: An End-to-End Diffusion LM Stack in Pure JAX](https://medium.com/@JunbumLee/dllm-into-tpu-an-end-to-end-diffusion-lm-stack-in-pure-jax-5fc33c840ebb) -- **Torch-free weight loading** — `safetensors` + `huggingface_hub` + numpy, no `torch.load` -- **TPU-first** — Pallas flash attention via `shard_map`, 1D FSDP and 2D FSDP+TP -- **Pretrained init** — load Qwen3, Llama, and other HF causal LMs directly into Flax NNX -- **Five training objectives** — MDLM, BD3LM, Dream, DMax/OPUT, EditFlow -- **Clean API** — public exports, no stub boilerplate -- **DMax / OPUT end-to-end** — port of [`czg1225/DMax`](https://github.com/czg1225/DMax) - training + Soft Parallel Decoding inference, with a KV-cached fast path - matching reference's `cache='prefix'` setting -`transformers` is used only for `AutoConfig` / `AutoTokenizer` (works without torch). +## Concepts -## Blog Post +A **diffusion language model (dLLM)** trains a transformer to *denoise* +masked sequences in parallel, instead of predicting tokens one-at-a-time. +At inference it can refine many positions per step, generating text in a +fraction of the autoregressive forward passes. 
-[dLLM into TPU: An End-to-End Diffusion LM Stack in Pure JAX](https://medium.com/@JunbumLee/dllm-into-tpu-an-end-to-end-diffusion-lm-stack-in-pure-jax-5fc33c840ebb) +Five training objectives ship in this repo: -## Prerequisites +| Objective | What it does | Trainer / Generator | +|---|---|---| +| **MDLM** (Sahoo et al.) | Mask-and-denoise with a learnable α schedule | `MDLMTrainer` | +| **BD3LM** | Block-diffusion variant of MDLM | `BD3LMTrainer` | +| **Dream** | SFT-friendly diffusion with prefix conditioning | `DreamTrainer` | +| **DMax / OPUT** | High-noise, on-policy block-diffusion + Soft Parallel Decoding | `DMaxTrainer`, `dmax_generate_spd_*` | +| **EditFlow** | Edit-style flow matching | `EditFlowTrainer` | -### Google Cloud CLI (`gcloud`) +The polish in this repo concentrates on **DMax** — training, inference, +and checkpointing are all wired up for multi-host TPU. The other four +objectives share the same model loading and sharding infrastructure. -```bash -# Install gcloud CLI (see https://cloud.google.com/sdk/docs/install) -# After installation: -gcloud auth login -gcloud config set project YOUR_PROJECT_ID -# Create a GCS bucket for checkpoints (pick your TPU region): -gcloud storage buckets create gs://YOUR_BUCKET_NAME --location=us-east1 +## Installation + +```bash +pip install -e . +pip install -e '.[tpu]' # TPU runtime +pip install -e '.[dev]' # pytest, etc. ``` -### Environment setup +Requires Python ≥ 3.10. End-to-end validation on TPU v4-32 / v5e-64 / v6e-256. -Copy the example env file and fill in your values: +For multi-host TPU runs you also need `gcloud` CLI and a regional GCS +bucket (default `gs://dllm-jax-${region}`). For TPU VM packaging quirks +and the verified pin set, see [`docs/install.md`](docs/install.md). + + +## Configuration + +Every script reads its config from environment variables. 
The single source +of truth is [`.env.example`](.env.example) — 15 sections covering model, +data, training, LR schedule, optimizer, sharding, DMax, SFT, perf, checkpointing, resume, +HuggingFace, W&B, profiling, and inference. ```bash cp .env.example .env +$EDITOR .env ``` -Key variables in `.env`: +Five variables to know on day one: -| Variable | Description | Default | -|----------|-------------|---------| -| `CHECKPOINT_BUCKET_PREFIX` | Prefix for regional buckets (`gs://{prefix}-{region}`) | `dllm-jax` | -| `CHECKPOINT_BUCKET` | Fixed bucket name (overrides regional auto-detection) | — | -| `WANDB_API_KEY` | Weights & Biases API key ([get one here](https://wandb.ai/authorize)) | — | -| `WANDB_LOG` | Enable W&B logging (`1` to enable) | `0` | -| `MODEL_NAME` | HuggingFace model ID | `Qwen/Qwen3-8B` | +| Variable | Meaning | Default | +|---|---|---| +| `MODEL_NAME` | HF model ID for tokenizer + config + initial weights | `Qwen/Qwen3-8B` | +| `DATASET` | `tinystories`, `wikipedia`, `openthoughts`, `synthetic`, `parquet` | `tinystories` | +| `RUN_NAME` | name written under `${CHECKPOINT_DIR}/${RUN_NAME}` and to W&B | auto | +| `WANDB_LOG` | `1` to stream loss/MFU to W&B | `0` | +| `TP` | tensor-parallel axis size | `8` | -See [`.env.example`](.env.example) for the full list. +> 🔐 Never commit `WANDB_API_KEY` or `HF_TOKEN`. `.gitignore` excludes +> `.env`, `.env.local`, and `.env.deploy`. You can also `wandb login` and +> `huggingface-cli login` once per worker and skip the env vars. -> **Tip:** Instead of setting `WANDB_API_KEY` in `.env`, you can run `wandb login` -> on each TPU worker. Never commit credentials to the repository. -## Installation +## Train + +The production training entry point is +[`scripts/tpu_train.py`](scripts/tpu_train.py). It runs on **v4-32, v5e-64, +v6e-256** and handles distributed init, 2D sharding, GCS DCP checkpoints, +W&B, MFU logging, and DMax masking. + +A first run, end-to-end, from any TPU worker: ```bash -pip install -e . 
+gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=all \ + --command="cd ~/dllm-jax && \ + PYTHONPATH=~/dllm-jax \ + RUN_NAME=hello-dmax \ + MODEL_NAME=Qwen/Qwen3-0.6B \ + DATASET=tinystories \ + DMAX_ENABLE=1 \ + MAX_LEN=4096 GLOBAL_BATCH=8 \ + NUM_STEPS=20 WARMUP_STEPS=5 \ + PEAK_LR=1e-4 OPTIMIZER=adamw \ + WANDB_LOG=1 \ + python3 scripts/tpu_train.py" +``` -# TPU runtime -pip install -e '.[tpu]' +In the first ~3 minutes you'll see HF download → weight sharding → +optimizer init → W&B URL → step-by-step loss lines. Successful exit: -# Dev (pytest) -pip install -e '.[dev]' ``` +[Worker 0] step=20 loss=… mfu=… +[Worker 0] training complete in …s +``` + +For the full reference — datasets, DMax/OPUT knobs, optimizer choice (and +why **AdamW**, not Adafactor), LR schedule, mask-token warm start, memory +tuning, checkpointing, three resume modes, sharding tables per hardware, +and gotchas — see [`docs/training.md`](docs/training.md). Throughput +methodology is in [`docs/mfu-optimization.md`](docs/mfu-optimization.md); +runs land at ~38–48% MFU after splash + remat tuning. -Requires Python >= 3.10, `jax >= 0.4.20`, `flax >= 0.10.0`, -`orbax-checkpoint`, `gcsfs <= 2026.2.0`, `optax >= 0.2.0`. -### TPU VM packaging caveat +## Infer -Some Python 3.10 TPU VM images have an older packaging stack. On those hosts, -`pip install -e '.[tpu]'` can fail with a missing `build_editable` hook, and -`pip install '.[tpu]'` can misread the project metadata as `UNKNOWN-0.0.0` -without installing dependencies. In that case, run from the synced checkout -with `PYTHONPATH` and install TPU dependencies explicitly: +The multi-host inference entry point is +[`scripts/tpu_infer.py`](scripts/tpu_infer.py). It restores a sharded DCP +checkpoint, re-shards onto the inference mesh (which can differ from the +training mesh), and runs Soft Parallel Decoding (SPD). 
```bash -python3 -m pip install --user -U 'jax[tpu]' \ - -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \ - 'flax>=0.10.0,<0.11' orbax-checkpoint 'gcsfs<=2026.2.0' 'fsspec<=2026.2.0' \ - 'optax>=0.2.0' numpy 'transformers>=4.40.0' safetensors \ - huggingface_hub datasets wandb +gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=all \ + --command="cd ~/dllm-jax && \ + PYTHONPATH=~/dllm-jax \ + RESUME_DIR=gs://dllm-jax-us-east1/checkpoints/$RUN_NAME \ + MODEL_NAME=Qwen/Qwen3-8B \ + PROMPT='Once upon a time' \ + GEN_LENGTH=1024 BLOCK_LENGTH=32 STEPS=32 \ + INFER_IMPL=fast INFER_SPLASH=1 \ + TOP_K=3 THRESHOLD=0.5 CONFIDENCE_STOP=0.9 \ + TP=8 \ + python3 scripts/tpu_infer.py" ``` -#### Verified TPU versions - -The training and inference paths are validated on TPU v4-32 (`us-central2-b`) -with this stack — pin to these if `pip install '.[tpu]'` surfaces version -drift: - -| Package | Version | -|---------|---------| -| Python | 3.10.12 | -| jax / jaxlib | 0.6.2 | -| libtpu | 0.0.17 | -| flax | 0.10.7 | -| optax | 0.2.8 | -| orbax-checkpoint | 0.11.34 | -| transformers | 5.5.3 | -| safetensors | 0.7.0 | -| datasets | 4.8.4 | -| gcsfs / fsspec | 2025.3.2 | -| huggingface_hub | 1.10.1 | -| numpy | 2.2.6 | - -The 0.10.7 flax version forces the `_nnx_list = getattr(nnx, "List", list)` -compat shim noted under **Gotchas**; newer flax (0.12+) on JAX 0.7+ should -also work but hasn't been re-verified end-to-end on this repo. - -### Regional GCS checkpoints for TPU runs - -`scripts/tpu_v6e_smoke.py` saves sharded Orbax DCP checkpoints every -`CHECKPOINT_STEPS` steps (default: 500). By default it detects the TPU -zone, derives the region, and writes to a matching bucket named -`gs://${CHECKPOINT_BUCKET_PREFIX}-${region}`. For example, a TPU in -`us-east1-d` writes under `gs://dllm-jax-us-east1/checkpoints/${RUN_NAME}`. 
+Three implementations are available — `fast` (default; fixed-shape, +splash-friendly), `kv_fast` (KV-cached, best for ≥1024-token gens), +`legacy` (debug-only Python loop). All three produce byte-identical output +at matching settings. -```bash -PYTHONPATH=/path/to/dllm-jax \ - RUN_NAME=my-run-name \ - CHECKPOINT_STEPS=500 CHECKPOINT_KEEP=2 \ - MODEL_NAME=Qwen/Qwen3-8B DATASET=tinystories \ - MAX_LEN=16384 GLOBAL_BATCH=8 \ - NUM_STEPS=0 NUM_EPOCHS=3 WANDB_LOG=1 \ - python3 scripts/tpu_v6e_smoke.py -``` +`INFER_SPLASH=1` swaps in a Pallas splash kernel matched to the +block-causal mask. On Qwen3-8B at `GEN_LENGTH=1024` it's **3.5× faster** +*and* fixes a latent dense-kernel quality bug at non-128-aligned sequence +lengths. -Override detection with `TPU_REGION=us-east1`, `CHECKPOINT_BUCKET=gs://...`, -or a full `CHECKPOINT_DIR=gs://...`. Hub uploads are only used for local -checkpoint directories; `gs://` is the durable distributed checkpoint target. -On TPU VM images with JAX 0.6.x and Orbax 0.11.x, the script enables -`CHECKPOINT_ORBAX_SYNC_DIRS=1` and `CHECKPOINT_ORBAX_SIGNAL_FALLBACK=1` by -default to use JAX multihost barriers for distributed GCS checkpoint writes. +For the full reference — implementation comparison, splash setup, prompt +sweeps, env var glossary, throughput tables, and the single-host CLI — +see [`docs/inference.md`](docs/inference.md). The deep-dive writeup with +end-to-end timing is [`docs/inference-optimization.md`](docs/inference-optimization.md). 
-## Quick Start + +## Python API (single host) + +For sanity checks on a laptop or single TPU host: ```python from transformers import AutoTokenizer -from dllm_jax import build_model_from_pretrained, MDLMConfig, MDLMTrainer, LinearAlphaScheduler +from dllm_jax import ( + build_model_from_pretrained, + DMaxConfig, DMaxTrainer, DMaxDataCollator, +) model, config = build_model_from_pretrained("Qwen/Qwen3-0.6B", task="llada") tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") -trainer = MDLMTrainer( - model=model, - tokenizer=tokenizer, - args=MDLMConfig(output_dir="./out", max_steps=1000, learning_rate=1e-4), - train_dataset=dataset, - data_collator=collator, - scheduler=LinearAlphaScheduler(), -) -trainer.train() -``` - -## DMax / OPUT - -`dllm_jax` includes a JAX/Flax port of the DMax training and inference -stack from [`czg1225/DMax`](https://github.com/czg1225/DMax): - -- **OPUT training** — fixed high-noise masking, two-stream `[noised | clean]` - layout with block-diffusion attention, per-step on-policy rollout that - replaces masked tokens with the model's own greedy predictions, gradient - accumulation support. -- **Soft Parallel Decoding (SPD) inference** with three implementations - that produce byte-exact identical outputs at matching settings: - - `dmax_generate_spd` — Python-loop reference path with host-side early - breaks. Slow on TPU, useful for debugging. - - `dmax_generate_spd_fast` — fixed-shape fori_loop compiled path; step-level - and block-level early breaks via `jax.lax.while_loop`. Default for - short/medium generations. - - `dmax_generate_spd_kv_fast` — KV-cached variant matching reference's - `cache='prefix'` path. Each step only projects K/V for the active block - and attention runs over the cached prefix. ~1.6× speedup at 1024-token - generation; overhead dominates at 128-token generation. 
-- **Reference knobs** — `top_k` (soft-mix top-k, reference default 1), - `temperature` + gumbel sampling, `seed`, `threshold`, `confidence_stop`, - `suppress_mask_token`, and post-EOS fill matching reference's `early_stop`. - -### Training - -```python -from dllm_jax import DMaxConfig, DMaxTrainer, DMaxDataCollator - trainer = DMaxTrainer( model=model, tokenizer=tokenizer, @@ -198,241 +176,53 @@ trainer = DMaxTrainer( output_dir="./out-dmax", max_steps=1000, learning_rate=2e-6, - noise_range_low=0.75, - noise_range_high=0.75, - on_policy_ratio=0.5, block_size=32, - gradient_accumulation_steps=4, # optional + on_policy_ratio=0.5, + gradient_accumulation_steps=4, ), - train_dataset=dataset, + train_dataset=dataset, # any HF datasets-style iterable data_collator=DMaxDataCollator(tokenizer=tokenizer, label_pad_token_id=-100), ) trainer.train() ``` -### Inference - -```python -from dllm_jax import dmax_generate_spd_fast, dmax_generate_spd_kv_fast - -# Fast path (default for short/medium gen) -output = dmax_generate_spd_fast( - model, - input_ids, - tokenizer=tokenizer, - gen_length=512, - block_length=32, - steps=32, - threshold=0.5, # reference math eval default - confidence_stop=0.9, # reference block-level break - top_k=3, # soft mix aggregates top 3 candidates - temperature=0.0, # 0.0 = greedy; >0.0 = gumbel sampling -) -print(tokenizer.decode(output.generated_tokens[0], skip_special_tokens=True)) - -# KV-cache path (wins on long generations) -output = dmax_generate_spd_kv_fast( - model, input_ids, tokenizer=tokenizer, - gen_length=2048, block_length=32, steps=32, - threshold=0.5, top_k=3, -) -``` - -`output.nfe` counts actual forward passes. For `fast` it matches -reference's `num_forwards`; for `kv_fast` it is `fast_nfe + num_active_blocks` -(the extra is the post-block hard-write pass that replaces soft K/V with -hard K/V in the cache — reference's cross-block update). 
- -### CLI entry points - -```bash -# Train -python scripts/dmax_train.py \ - --model Qwen/Qwen3-0.6B \ - --dataset Zigeng/DMax-LLaDA-2.0-Mini-Math-Trajectories \ - --max-steps 1000 - -# Generate (pretrained base model) -python scripts/dmax_generate.py \ - --model Qwen/Qwen3-0.6B \ - --prompt "Solve 37 * 48." \ - --gen-length 256 --block-length 32 --steps 32 \ - --threshold 0.5 --top-k 3 --impl fast - -# Generate (from a saved trainer checkpoint) -python scripts/dmax_generate_checkpoint.py \ - --checkpoint-dir ./out-dmax/checkpoint-1000 \ - --prompt "Solve 37 * 48." \ - --gen-length 256 --impl kv_fast -``` - -### TPU multi-host inference - -[`scripts/tpu_dmax_infer_checkpoint.py`](scripts/tpu_dmax_infer_checkpoint.py) -restores a distributed Orbax DCP checkpoint (GCS or local) on every worker, -reshards across the inference mesh (which may differ from the training -mesh), and runs `dmax_generate_spd_{fast,kv_fast,legacy}` end-to-end. All -configuration is via environment variables: - -```bash -gcloud compute tpus tpu-vm ssh $TPU_NAME \ - --zone=$ZONE --worker=all \ - --command="cd ~/dllm-jax && \ - RESUME_DIR=gs://$BUCKET/checkpoints/$RUN_NAME \ - MODEL_NAME=Qwen/Qwen3-8B \ - PROMPT='Once upon a time' \ - GEN_LENGTH=1024 BLOCK_LENGTH=32 STEPS=32 \ - INFER_IMPL=kv_fast TOP_K=3 THRESHOLD=0.5 CONFIDENCE_STOP=0.9 \ - TP=8 \ - python3 scripts/tpu_dmax_infer_checkpoint.py" -``` - -Omit `RESUME_STEP` and the script scans `commit_success.txt` files under -`RESUME_DIR` and picks the latest committed step. Set `RESUME_STEP=` to -pin a specific step. +This is the same code path that +[`scripts/dmax_train.py`](scripts/dmax_train.py) drives. Replace +`DMaxConfig` / `DMaxTrainer` with `MDLMConfig` / `MDLMTrainer` (or the +other objectives) for the equivalent single-host MDLM run. 
-#### Environment variables -| Variable | Meaning | Default | -|-|-|-| -| `RESUME_DIR` / `RESUME_FROM` | checkpoint parent directory (contains `checkpoint_/`) | — (required) | -| `RESUME_STEP` | specific step to restore | latest committed | -| `MODEL_NAME` | tokenizer + HF config source | `Qwen/Qwen3-8B` | -| `PROMPT` | input prompt | `Once upon a time` | -| `GEN_LENGTH` | tokens to generate | `32` | -| `BLOCK_LENGTH` | DMax block size | `32` | -| `STEPS` | max denoising steps per block | `8` | -| `INFER_IMPL` | `fast`, `kv_fast`, or `legacy` | `fast` | -| `FAST_BUCKET_LENGTH` | compile window for `fast`; ignored by `kv_fast` | `4096` | -| `THRESHOLD` | left-to-right confidence cutoff — reference math eval `0.5`, code `0.65`, other benchmarks `0.9`–`0.95` | `0.95` | -| `CONFIDENCE_STOP` | block early-exit confidence | `0.9` | -| `TOP_K` | soft-mix top-k (reference default `1`; `3` is more coherent on undertrained ckpts) | `1` | -| `TEMPERATURE` | gumbel-max sampling temperature; `0.0` = greedy | `0.0` | -| `SEED` | RNG seed (only needed when `TEMPERATURE > 0`) | none | -| `SUPPRESS_MASK_TOKEN` | set `1` to force-disable mask-token logits during argmax | `0` | -| `MASK_TOKEN_ID` / `EOS_TOKEN_ID` | overrides for the model's mask / EOS id | tokenizer default | -| `TP` | tensor-parallel axis size; `fsdp` is derived as `jax.device_count() // TP` | `8` | -| `RESTORE_OPTIMIZER` | restore optimizer state (useful for resumed training, not inference) | `0` | - -#### Measured throughput (TPU v4-32, Qwen3-8B, `TOP_K=3`) - -| Config | nfe | generate_seconds | -|-|-|-| -| `fast`, GEN=128 | 114 | 32.1 | -| `kv_fast`, GEN=128 | 119 | 65.9 | -| `fast`, GEN=1024 | 1010 | 124.7 | -| `kv_fast`, GEN=1024 | 1043 | 76.6 | - -At short generations `kv_fast` pays more XLA compile cost than it saves; at -`GEN=1024` the block-local forward wins (~1.6× faster than `fast`). 
`nfe` -for `kv_fast` is `fast_nfe + num_active_blocks`; the extra forwards are the -post-block hard-write that overwrites soft K/V with hard K/V in the cache -(the JAX analogue of reference's cross-block update). - -#### Troubleshooting - -- **Stale TPU checkout** — `ImportError: cannot import name 'dmax_generate_spd_kv_fast'` or similar means the TPU copy is out of date; re-`scp` `dllm_jax/` and the target script. -- **Orbax signal-contract hang on restore** — handled by default; the script flips `CHECKPOINT_ORBAX_SIGNAL_FALLBACK=1` when the JAX distributed client isn't initialized. Set the env var to `0` to disable the shim explicitly. -- **`gs://` / TensorFlow warnings** — harmless; restore uses Orbax's GCS client, not `tf.io.gfile`. -- **First-run generation is "slow"** — includes model restore + XLA compile. Subsequent runs with identical shapes hit the cached graph. -- **Output collapses into punctuation with `TOP_K=1`** — SPD feeds the previous step's distribution back as a soft mix; at `TOP_K=1` a low-confidence single token gives a bad signal. Try `TOP_K=3`. - -The KV-cache design is documented in [`docs/kv_cache_design.md`](docs/kv_cache_design.md). -A narrative write-up of the whole end-to-end port — HF checkpoint → -OPUT training → KV-cached SPD on TPU — is in -[`docs/porting-dmax-to-tpu.md`](docs/porting-dmax-to-tpu.md). - -## Checkpoints - -Two save/load paths are supported depending on scale: - -- **Single-host trainer checkpoints** (`DMaxTrainer.save_model`, etc.) write - pickle files (`save_only_model=True`) or Flax training `Checkpoints` to a - local directory. These pair with `restore_model_checkpoint` and the local - `scripts/dmax_generate_checkpoint.py` script. -- **Distributed TPU DCP checkpoints** (Orbax `PyTreeCheckpointHandler` / - `StandardCheckpointer`) are the durable save path for multi-host v4/v5/v6 - training via `scripts/tpu_v6e_smoke.py`. 
Checkpoints are sharded across - workers, written directly to GCS (`gs://${CHECKPOINT_BUCKET_PREFIX}-${region}`), - and committed with `commit_success.txt` markers. Resume and inference read - the latest committed step via `latest_committed_gcs_step`. - - ```bash - # Train with DCP checkpoints every 500 steps, keep last 2 - PYTHONPATH=$(pwd) \ - RUN_NAME=my-run CHECKPOINT_STEPS=500 CHECKPOINT_KEEP=2 \ - python3 scripts/tpu_v6e_smoke.py - ``` - - Inference from the resulting GCS checkpoints is covered above under - [TPU multi-host inference](#tpu-multi-host-inference). - -## Package Layout +## Package layout ``` dllm_jax/ -├── models.py # GenericDecoderLM (+ call_cached for KV cache), GenericEncoderLM, EditFlowModel -├── trainers.py # MDLMTrainer, BD3LMTrainer, DreamTrainer, DMaxTrainer, EditFlowTrainer -│ # (gradient accumulation, cached LR schedule) -├── configs.py # ModelArguments, DataArguments, TrainingArguments, MDLMConfig, DMaxConfig, ... -├── dmax.py # DMax SPD: dmax_generate_spd, dmax_generate_spd_fast, dmax_generate_spd_kv_fast +├── models.py # GenericDecoderLM (+ call_cached for KV cache), +│ # GenericEncoderLM, EditFlowModel +├── trainers.py # MDLM / BD3LM / Dream / DMax / EditFlow trainers +├── configs.py # ModelArguments, DataArguments, TrainingArguments, +│ # MDLMConfig, DMaxConfig, … +├── dmax.py # DMax SPD: dmax_generate_spd[_fast|_kv_fast] ├── schedulers.py # LinearAlpha, CosineAlpha, LinearKappa, CubicKappa, CosineKappa -├── data.py # DMaxDataCollator, DreamSFTCollator, EditFlowCollator, NoAttentionMaskWrapper, ... -├── checkpoints.py # restore_model_checkpoint (single-host pickle / Flax Checkpoints) -├── weights.py # Torch-free safetensors -> NNX weight loader -└── utils.py # resolve_with_base_env, parse_spec, get_default_logger, ... 
+├── data.py # DMaxDataCollator, DreamSFTCollator, EditFlowCollator, … +├── checkpoints.py # restore_model_checkpoint (single-host) +├── weights.py # Torch-free safetensors → NNX weight loader +└── utils.py scripts/ -├── dmax_train.py # single-host DMax OPUT CLI -├── dmax_tinystories_train.py # small-scale DMax sanity training -├── dmax_generate.py # single-host SPD generation CLI (base model) -├── dmax_generate_checkpoint.py # single-host SPD generation from a saved checkpoint -├── tpu_dmax_infer_checkpoint.py # multi-host Orbax DCP restore + SPD generation -├── tpu_v4_32_train_3epoch.py # multi-host DMax training on TPU v4-32 -├── run_tpu_v4_32_3epoch.sh # wrapper for the v4-32 training launch -└── tpu_v6e_smoke.py # MDLM smoke trainer with DCP checkpointing +├── tpu_train.py # multi-host training (v4 / v5e / v6e) +├── tpu_infer.py # multi-host inference from GCS DCP ckpt +├── dmax_train.py # single-host DMax CLI +├── dmax_generate.py # single-host SPD generation (base model) +├── dmax_generate_checkpoint.py # single-host SPD from saved ckpt +├── inspect_sft_data.py # SFT pipeline data dump +└── examples/ # tinystories sanity, legacy v4-32 wrappers docs/ -├── tpu_v4_32_ondemand_inference.md # runbook for TPU inference from GCS checkpoints -├── kv_cache_design.md # design notes for the KV-cached SPD path -└── porting-dmax-to-tpu.md # narrative write-up of the whole end-to-end port -``` - -## Sharding - -| Model Size | Strategy | Mesh Shape | Notes | -|------------|-------------|----------------------|-------------------------------------------| -| <= 3B | 1D FSDP | `(ndev,)` fsdp | Direct TPU init, `P("fsdp", None)` | -| 3B - 8B+ | 2D FSDP+TP | `(ndev/tp, tp)` fsdp,tp | CPU init → shard, `P("fsdp", "tp")` | - -For example, TPU v4-32 with `TP=8` uses `(2, 8)`; TPU v5e-64 uses `(8, 8)`. 
-Large models additionally need CPU-first init -(`jax.default_device(jax.devices("cpu")[0])`), gradient checkpointing via -`jax.remat` on each transformer layer, and Pallas flash attention. - -## Gotchas - -**AdamW > Adafactor for pretrained-init MDLM.** Adafactor's -`scale_by_param_block_rms` misbehaves on bidirectional objectives over causal -LM weights — loss descends then climbs back after ~60 steps. Use -`optax.adamw(b1=0.9, b2=0.95, wd=0.01)` (the library default). - -**Optimizer rebind after split/merge.** If you manually `nnx.split` and -`nnx.merge` the optimizer in a hand-written TPU script, rebind `model` to -avoid silent zero-progress: - -```python -opt_gdef, opt_state = nnx.split(optimizer) -opt_state = jax.tree.map(fsdp_sharding, opt_state) -optimizer = nnx.merge(opt_gdef, opt_state) -model = optimizer.model # CRITICAL +├── install.md, training.md, inference.md # user-facing references +├── inference-optimization.md, mfu-optimization.md, kv-cache-design.md + # deep-dive writeups ``` -The built-in `trainers.py` never splits the optimizer, so it is safe. - -**flax 0.10.7 vs 0.12+.** Python 3.10 TPU VMs pin flax to 0.10.7. `models.py` -uses a `_nnx_list = getattr(nnx, "List", list)` compat shim so the same code -runs on both. `optimizer.update(grads)` is called positionally (0.10.7 API). ## License diff --git a/dllm_jax/dmax.py b/dllm_jax/dmax.py index 5347490..77bdc4b 100644 --- a/dllm_jax/dmax.py +++ b/dllm_jax/dmax.py @@ -2,6 +2,7 @@ from __future__ import annotations +import os from dataclasses import dataclass from typing import Any @@ -12,6 +13,174 @@ from dllm_jax.trainers import resolve_mask_token_id +def _resolve_kv_cache_dtype(model: Any): + """Pick the KV-cache dtype for kv_fast inference. + + Default: the model's compute dtype (bf16) — cuts cache HBM in half and + lets ``jax.nn.dot_product_attention`` run at bf16 MXU throughput. 
+ ``INFER_KV_DTYPE=fp32`` restores the previous float32 cache (matches the + qk_norm-upcast precision path; slightly more accurate, ~2× slower attn). + """ + choice = os.environ.get("INFER_KV_DTYPE", "bf16").strip().lower() + if choice in {"fp32", "float32", "f32"}: + return jnp.float32 + compute_name = getattr(model, "dtype_name", "bfloat16") + if compute_name == "float32": + return jnp.float32 + if compute_name == "float16": + return jnp.float16 + return jnp.bfloat16 + + +def _infer_mesh_from_model(model: Any): + """Return the NamedSharding mesh used by ``model`` params, or None. + + Walks ``nnx.state(model)`` leaves, unwrapping nnx.Variable objects to their + underlying jax.Array so ``.sharding.mesh`` is accessible. + """ + try: + state = nnx.state(model) + except Exception: + state = model + + for leaf in jax.tree_util.tree_leaves(state): + value = getattr(leaf, "value", leaf) + if isinstance(value, jax.Array): + sharding = getattr(value, "sharding", None) + mesh = getattr(sharding, "mesh", None) if sharding is not None else None + if mesh is not None and getattr(mesh, "shape", None): + return mesh + return None + + +def install_block_causal_splash_for_inference( + mesh: Any, + num_heads: int, + total_length: int, + block_length: int, + *, + splash_block: int | None = None, +): + """Register a splash_attention kernel as ``dllm_jax.models._MASKED_FLASH_ATTN_FN``. + + The kernel targets the block-causal mask shape produced by + ``create_block_causal_attention_mask(total_length, block_length)`` — i.e. + block N attends fully to all blocks [0, N]. Intended for the fixed-shape + ``dmax_generate_spd_fast`` path; no backward pass is compiled. + + Mesh is expected to have axis names ``fsdp`` and ``tp`` (matches the + training and inference scripts in this repo); heads shard along ``tp``. 
+ """ + import numpy as np + from jax.experimental.pallas.ops.tpu.splash_attention import ( + splash_attention_mask as _sm, + splash_attention_kernel as _sk, + ) + from jax.experimental.shard_map import shard_map + from jax.sharding import PartitionSpec as P + + from dllm_jax import models as _models + + tp = int(mesh.shape.get("tp", 1)) if hasattr(mesh, "shape") else 1 + heads_per_tp = max(1, num_heads // tp) + + idx = np.arange(total_length) + block_q = idx[:, None] // block_length + block_kv = idx[None, :] // block_length + mask_np = (block_q >= block_kv).astype(np.bool_) + + splash_mask = _sm.MultiHeadMask(masks=[_sm.NumpyMask(mask_np)] * heads_per_tp) + requested_bs = int(splash_block) if splash_block is not None else int( + os.environ.get("INFER_SPLASH_BLOCK", "512") + ) + # Splash on TPU requires block_q/block_kv both divide total_length AND be + # multiples of 128 (MXU lane width). Find the largest such value ≤ request. + NUM_LANES = 128 + bs = requested_bs - (requested_bs % NUM_LANES) if requested_bs >= NUM_LANES else NUM_LANES + while bs > NUM_LANES and (total_length % bs != 0 or bs % NUM_LANES != 0): + bs -= NUM_LANES + if bs < NUM_LANES or total_length % bs != 0: + raise ValueError( + f"splash install: no block size ≤ {requested_bs} that is a multiple " + f"of {NUM_LANES} and divides total_length={total_length}" + ) + if bs != requested_bs: + print( + f"[dmax] INFER_SPLASH_BLOCK={requested_bs} → using block={bs} " + f"(largest multiple of {NUM_LANES} dividing total_length={total_length})", + flush=True, + ) + block_sizes = _sk.BlockSizes( + block_q=bs, block_kv=bs, block_kv_compute=bs, + ) + splash_fn = _sk.make_splash_mha_single_device( + mask=splash_mask, block_sizes=block_sizes, + ) + + def _per_shard(q, k, v, sm_scale): + q_scaled = (q * sm_scale).astype(q.dtype) + return jax.vmap(splash_fn)(q_scaled, k, v) + + # Inference typically runs batch=1; sharding the batch axis across ``fsdp`` + # would require batch >= fsdp. 
Replicate on batch, shard heads on ``tp``. + _batch_spec = None + in_specs = (P(_batch_spec, "tp", None, None),) * 3 + out_specs = P(_batch_spec, "tp", None, None) + + def _sharded_masked_flash(q, k, v, sm_scale): + return shard_map( + lambda q_, k_, v_: _per_shard(q_, k_, v_, sm_scale), + mesh=mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + )(q, k, v) + + _models._MASKED_FLASH_ATTN_FN = _sharded_masked_flash + print( + f"[dmax] splash installed for block-causal inference: " + f"mask {mask_np.shape} heads_per_tp={heads_per_tp} tp={tp} block={bs}", + flush=True, + ) + return _sharded_masked_flash + + +def _env_true(name: str) -> bool: + return os.environ.get(name, "").strip().lower() in {"1", "true", "yes", "on"} + + +def _try_autoinstall_block_causal_splash(model: Any, total_length: int, block_length: int): + """Gated by ``INFER_SPLASH=1``. Installs splash if a mesh can be inferred. + + Returns the installed kernel on success, None on skip/failure. A failure + leaves the existing ``_MASKED_FLASH_ATTN_FN`` untouched so the caller + silently falls back to dense attention. 
+ """ + if not _env_true("INFER_SPLASH"): + return None + try: + mesh = _infer_mesh_from_model(model) + if mesh is None: + print( + "[dmax] INFER_SPLASH=1 but no mesh could be inferred from model; " + "falling back to dense attention", + flush=True, + ) + return None + num_heads = int(model.spec.num_attention_heads) + return install_block_causal_splash_for_inference( + mesh, num_heads, total_length, block_length, + ) + except Exception as exc: # pragma: no cover — best-effort runtime install + import traceback + print( + f"[dmax] INFER_SPLASH=1 requested but splash install failed, " + f"falling back to dense attention: {exc}\n{traceback.format_exc()}", + flush=True, + ) + return None + + @dataclass class DMaxGenerationConfig: gen_length: int = 2048 @@ -282,6 +451,13 @@ def dmax_generate_spd_fast( gen_length = int(gen_length) batch_size, prompt_length = input_ids.shape num_blocks = (prompt_length + gen_length + block_length - 1) // block_length + # Splash attention wants total_length % 128 == 0 (TPU MXU lane alignment). + # With block_length=32 that means num_blocks must be a multiple of 4. Pad + # up when INFER_SPLASH=1 — the extra blocks add compute but are correct. + if _env_true("INFER_SPLASH"): + align_blocks = max(1, 128 // max(1, block_length)) + if align_blocks > 1 and num_blocks % align_blocks != 0: + num_blocks += align_blocks - (num_blocks % align_blocks) total_length = num_blocks * block_length new_gen_length = total_length - prompt_length prefill_blocks = prompt_length // block_length @@ -538,6 +714,11 @@ def step_cond(step_carry): attention_mask = create_block_causal_attention_mask(total_length, block_length) position_ids = jnp.broadcast_to(jnp.arange(total_length)[None, :], (batch_size, total_length)) + # Opt-in splash kernel for the dense block-causal attention that this path + # normally runs through ``jax.nn.dot_product_attention``. See + # ``install_block_causal_splash_for_inference`` and the INFER_SPLASH env. 
+ _try_autoinstall_block_causal_splash(model, total_length, block_length) + @nnx.jit def generate_fixed_shape(current_model, prompt_ids): x = jnp.full( @@ -1098,13 +1279,11 @@ def generate_kv(current_model, prompt_ids): mask_embedding.astype(jnp.float32), axis=-1, keepdims=True ) - # K/V cache stored in float32 to match the QK-norm precision path that - # the non-cached ``fast`` attention uses: for models with ``qk_norm`` - # (e.g. Qwen3) ``nnx.RMSNorm`` upcasts to float32, and - # ``jax.nn.dot_product_attention`` requires Q, K, V to all share dtype. - # V is upcast from the model's compute dtype (bf16) to float32 when - # written, which is exact. - cache_dtype = jnp.float32 + # KV cache dtype: default = model compute dtype (bf16) for ~2× HBM + # savings and bf16 attention throughput. Set INFER_KV_DTYPE=fp32 to + # fall back to float32 — matches the pre-bf16 precision path (qk_norm + # upcast to f32 preserved through attention). + cache_dtype = _resolve_kv_cache_dtype(current_model) past_kv = [ ( diff --git a/dllm_jax/models.py b/dllm_jax/models.py index 7d69cb2..84fbb83 100644 --- a/dllm_jax/models.py +++ b/dllm_jax/models.py @@ -11,6 +11,7 @@ import jax import jax.numpy as jnp from flax import nnx +from jax.ad_checkpoint import checkpoint_name import transformers # Compatibility: flax >= 0.11 has nnx.List; older versions use plain list() @@ -20,6 +21,11 @@ # Signature: _FLASH_ATTN_FN(q, k, v, sm_scale) with layout (batch, heads, seq, dim). _FLASH_ATTN_FN = None +# Masked flash hook: set by training script for block-diffusion / other sparse masks. +# Signature: _MASKED_FLASH_ATTN_FN(q, k, v, sm_scale) with same layout — mask baked in. +# Used when attention_mask is passed (so dense attention would otherwise run). 
+_MASKED_FLASH_ATTN_FN = None + def get_dtype(name: str): if name == "float32": @@ -290,7 +296,9 @@ def __init__(self, spec: ModelSpec, intermediate_size: int, *, gated: bool, rngs def __call__(self, hidden_states): if self.gated: - return self.down_proj(self.act(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) + gate_up = self.act(self.gate_proj(hidden_states)) * self.up_proj(hidden_states) + gate_up = checkpoint_name(gate_up, "gate_up") + return self.down_proj(gate_up) return self.fc2(self.act(self.fc1(hidden_states))) @@ -355,6 +363,9 @@ def _project_qkv(self, hidden_states, position_ids): k = self.k_norm(k) cos, sin = build_rope(position_ids, self.rotary_dim, self.rope_theta, q.dtype) q, k = apply_rope(q, k, cos, sin, self.rotary_dim) + q = checkpoint_name(q, "q") + k = checkpoint_name(k, "k") + v = checkpoint_name(v, "v") return q, k, v def _attention(self, q, k, v, attention_mask): @@ -369,6 +380,11 @@ def _attention(self, q, k, v, attention_mask): q.transpose(0, 2, 1, 3), k.transpose(0, 2, 1, 3), v.transpose(0, 2, 1, 3), 1.0 / math.sqrt(self.head_dim), ).transpose(0, 2, 1, 3).reshape(batch_size, query_len, self.num_heads * self.head_dim) + elif _MASKED_FLASH_ATTN_FN is not None and attention_mask is not None: + output = _MASKED_FLASH_ATTN_FN( + q.transpose(0, 2, 1, 3), k.transpose(0, 2, 1, 3), v.transpose(0, 2, 1, 3), + 1.0 / math.sqrt(self.head_dim), + ).transpose(0, 2, 1, 3).reshape(batch_size, query_len, self.num_heads * self.head_dim) else: mask = expand_attention_mask(attention_mask, batch_size, query_len, key_len) output = jax.nn.dot_product_attention(q, k, v, mask=mask).reshape( diff --git a/dllm_jax/weights.py b/dllm_jax/weights.py index 69e7c96..0cde3a8 100644 --- a/dllm_jax/weights.py +++ b/dllm_jax/weights.py @@ -36,16 +36,17 @@ def _hf_download(model_name: str, verbose: bool = True) -> str: print(f"[Worker {proc}] Running: {' '.join(cmd)}", flush=True) result = subprocess.run(cmd, capture_output=True, text=True, env=env) - if 
result.returncode != 0: - if verbose and proc == 0: - print(f"[Worker {proc}] hf CLI failed (rc={result.returncode}), falling back to snapshot_download", flush=True) - print(f"[Worker {proc}] stderr: {result.stderr[:200]}", flush=True) - from huggingface_hub import snapshot_download - return snapshot_download( - model_name, - allow_patterns=["*.safetensors", "*.safetensors.index.json", "config.json"], - ) - return result.stdout.strip().split("\n")[-1] + # Modern ``hf download`` (huggingface_hub>=0.25) decorates the final path + # line with a ``path:`` label and can emit progress lines; prefer the + # library call to get a clean local_dir path back. + from huggingface_hub import snapshot_download + if result.returncode != 0 and verbose and proc == 0: + print(f"[Worker {proc}] hf CLI failed (rc={result.returncode})", flush=True) + print(f"[Worker {proc}] stderr: {result.stderr[:200]}", flush=True) + return snapshot_download( + model_name, + allow_patterns=["*.safetensors", "*.safetensors.index.json", "config.json"], + ) def load_pretrained_weights( diff --git a/docs/inference-optimization.md b/docs/inference-optimization.md new file mode 100644 index 0000000..e9a8e82 --- /dev/null +++ b/docs/inference-optimization.md @@ -0,0 +1,207 @@ +# Inference Optimization: DMax + Qwen3-8B on TPU v5e-64 + +Summary of an inference tuning pass on `tpu-v5e-64-us` (us-central1-a) +for `scripts/tpu_infer.py` generating from a +tinystories-trained DMax checkpoint +(`dmax-8b-tinystories-forward8k-3epoch-.../checkpoint_41400`). 
+ +## Headline + +| Config | impl | attn | KV cache | median s | tok/s | nfe | gen tokens | quality | +|--|--|--|--|--|--|--|--|--| +| A (pre-patch) | fast | dense | — | 142.33 | 7.2 | 1028 | 1024 | **GARBAGE** | +| **B (ship)** | **fast** | **splash** | — | **40.86** | **25.1** | 128 | 175 | ✅ coherent, EOS | +| C | kv_fast | dense (fp32) | fp32 | 52.90 | 19.4 | 95 | 143 | ✅ coherent, EOS | +| D | kv_fast | dense (fp32) | bf16 | 52.12 | 19.6 | 95 | 143 | ✅ bit-identical to C | + +All runs at `Qwen3-8B`, `PROMPT='Once upon a time'`, `GEN_LENGTH=1024`, +`BLOCK_LENGTH=32`, `STEPS=32`, `TOP_K=3`, `THRESHOLD=0.5`, +`CONFIDENCE_STOP=0.9`, `TP=8`, greedy (`TEMPERATURE=0`), 1 warmup + 2 +measured runs, median. + +**fast+splash (B) is 3.49× faster than the dense-fast baseline AND +fixes a latent quality bug.** It's the recommended default for this +model family at these shapes. + +**bf16 KV cache (C→D) is a wash: 1.5% faster, byte-identical output.** +At `B=1 × q=32 × k=1152` the attention isn't bandwidth-bound so the +cache dtype doesn't matter. The flag is still exposed (`INFER_KV_DTYPE`) +so you can fall back to fp32 if you need the extra precision. + +## The latent bug in fast+dense (Config A) + +The `fast` path under `jax.nn.dot_product_attention` with a +block-causal mask at `total_length=1056` (prompt=4 + gen=1024 rounded +up to 33×32 blocks) produces degenerate output: + +> …Lily was happy to help and she went to the laundry room. +> When she got there, she saw a big pile of clothes. She started to +> fold them and put them away. She was very careful and did a +> her....,,,,,, and,,,,,,,,,,,,,,,,,,,,,,...... *(900+ more tokens of +> commas and periods, runs out max_nfe)* + +First ~80 tokens look normal; then the generation collapses into +punctuation filler and never reaches EOS, running to `nfe=1024`. The +same model, same prompt, same hyperparams under splash attention +produces a clean TinyStory. `kv_fast` (C, D) also produces clean +output. 
So the bug is specific to the +`fast + dot_product_attention + total_length=1056` combination. + +Likely cause: numerical reduction order in the dense kernel at that +non-128-aligned sequence length pushes the intermediate denoising +trajectory into a repetition attractor. Splash's block-sparse kernel +uses different reduction order AND my integration pads `num_blocks` +so `total_length=1152` (128-aligned), so splash dodges the trap both +numerically and by shape. + +## Code changes (all on branch `splash_attn`) + +### `dllm_jax/dmax.py` + +- `install_block_causal_splash_for_inference(mesh, num_heads, total_length, block_length, splash_block=None)` + — builds `NumpyMask(block_q >= block_kv)`, wraps as `MultiHeadMask`, + constructs a forward-only `make_splash_mha_single_device`, and + registers it as `dllm_jax.models._MASKED_FLASH_ATTN_FN` under a + `shard_map(mesh, in_specs=(P(None, 'tp', None, None),)*3)`. Auto- + picks the largest splash tile that is a multiple of 128 and divides + `total_length`. +- `_infer_mesh_from_model(model)` — walks `nnx.state(model)` leaves and + unwraps `nnx.Variable` objects to read `.sharding.mesh`. The naive + `jax.tree_util.tree_leaves(model)` returns Variable wrappers, not + raw arrays, so the mesh lookup needs this extra hop. +- `_try_autoinstall_block_causal_splash(model, total_length, block_length)` + — gated by `INFER_SPLASH=1`. Called from `dmax_generate_spd_fast` + right before the `@nnx.jit` trace. +- `dmax_generate_spd_fast`: when `INFER_SPLASH=1`, pads `num_blocks` + up to `128 // block_length` multiple so `total_length % 128 == 0` + (TPU MXU lane alignment). A few extra blocks of compute, correct + output. +- `_resolve_kv_cache_dtype(model)` for `dmax_generate_spd_kv_fast`. + Default changed from `jnp.float32` to the model's compute dtype + (bf16). `INFER_KV_DTYPE=fp32` reverts. 
+ +### `dllm_jax/models.py` + +- `checkpoint_name` markers on `q`/`k`/`v` (post-RoPE) in + `SelfAttention._project_qkv` and on the `gate*up` product in + `DenseMLP.__call__`. Inert unless a named remat policy selects them + (see `REMAT_POLICY` in `tpu_train.py`) — added for the training + path but harmless for inference. + +### `scripts/tpu_infer.py` + +- `WARMUP_RUNS` + `MEASURED_RUNS` env vars, median reporting. +- `TEMPS="0.0,0.3,0.5,..."` env var runs a temperature sweep in a + single process (shares the restore cost). + +## Env knobs (on `scripts/tpu_infer.py`) + +| Var | Default | Effect | +|--|--|--| +| `INFER_IMPL` | `fast` | `fast` / `kv_fast` / `legacy`. Use `fast` with splash for best throughput. | +| `INFER_SPLASH` | `0` | `1` enables splash kernel on the `fast` block-causal mask. Requires sharded model. | +| `INFER_SPLASH_BLOCK` | `512` | Splash tile size. Auto-adjusts to the largest multiple of 128 that divides `total_length`. | +| `INFER_KV_DTYPE` | `bf16` | KV-cache dtype for `kv_fast`. `fp32` restores the pre-patch precision path (no measurable speed gain). | +| `WARMUP_RUNS` | `0` | Throwaway generates before the measured ones. Each pays compile; helps stabilize measurements across seeds. | +| `MEASURED_RUNS` | `1` | Number of timed generates; reports median. | +| `TEMPS` | — | Comma-separated temperatures for a single-process sweep (e.g. `0.0,0.5,1.0`). | + +## Temperature sweep (fast+splash, seed=42) + +| temp | median s | tok/s | nfe | gen | quality | +|--|--|--|--|--|--| +| **0.0 (greedy)** | 41.28 | **24.8** | 128 | 175 | ✅ full coherent story, EOS | +| 0.3 | 163.19 | 6.3 | 1130 | 1024 | ~80 good tokens, then repeating `, the the to..` | +| 0.5 | 165.22 | 6.2 | 1146 | 1024 | ~80 good tokens, then repeating `the with. 
They...` | +| 0.7 | 164.83 | 6.2 | 1144 | 1024 | ~110 good tokens, then repeating `lesson heron with and..` | +| 1.0 | 76.70 | 13.4 | 416 | 382 | ~80 good tokens, then gibberish; random EOS mid-way | +| 1.5 | 166.04 | 6.2 | 1152 | 1024 | ~20 good tokens, then multilingual soup (`官方微信荤`, `훙`, `эту`, `ซึ่งเป็น`) | + +### Findings + +1. **Only greedy (temp=0) is usable** for this checkpoint on this + model. Any stochastic temperature collapses the block-diffusion + denoising into repetition after 20–110 tokens. +2. The `nfe` column reflects whether the block-level convergence + checks (`all_confident ≥ 0.9`, `same_as_previous`) fire. Greedy + converges fast (≈4 steps/block avg); stochastic sampling defeats + both checks and runs out max `nfe`. +3. At `temp=1.5` the sampler pulls rare tokens from Qwen3's + multilingual tail (Chinese / Korean / Thai / Russian / Spanish). +4. **Top-p / top-k sampling is not implemented.** Without them, any + non-greedy temp is unusable on this checkpoint. The existing + `top_k` parameter is a different thing — a soft-embedding mix for + the intermediate denoising states, not a sampling-time filter. + +### Recommended fix for usable sampling + +Add top-p (nucleus) sampling in `dllm_jax/dmax.py:_sample_x0`, +thread `TOP_P` env var through the inference script. Rough +implementation (~10 lines): + +```python +if top_p is not None and 0.0 < top_p < 1.0: + sorted_logits = jnp.sort(logits_f32, axis=-1)[..., ::-1] + sorted_probs = jax.nn.softmax(sorted_logits, axis=-1) + cum = jnp.cumsum(sorted_probs, axis=-1) + cutoff = jnp.sum(cum < top_p, axis=-1, keepdims=True) + threshold = jnp.take_along_axis(sorted_logits, cutoff, axis=-1) + logits_f32 = jnp.where(logits_f32 < threshold, -jnp.inf, logits_f32) +``` + +Use the pre-truncation `active_probs` for the `THRESHOLD` / +`CONFIDENCE_STOP` early-exit machinery so nucleus truncation doesn't +artificially inflate confidence. + +## What's next + +In decreasing ROI: + +1. 
**Top-p sampling** — unblocks any non-greedy generation. 20–30 min + (patch `_sample_x0`, thread `TOP_P` through the generate variants + + scripts + README). Required before shipping a user-facing + interface with temperature. +2. **Splash attention in the `kv_fast` prefill** — easy half of the + kv_fast splash integration. Prefill shape is static; only the + per-block decode has a dynamic mask. Prefill is a sizeable + fraction of long-prompt latency. ~30 min. +3. **Splash attention in the `kv_fast` decode** — requires either + pre-compiling N splash kernels (one per block position) and + dispatching via `jax.lax.switch`, or expressing the + `k_pos < block_end` gate as a runtime segment mask that splash + accepts. Benefit uncertain — at `q=32 × k=1152`, splash tiles are + underfilled and per-tile launch overhead may dominate. Current + benchmark already shows `fast+splash` beats `kv_fast+dense`, so + this is low priority unless we scale to very long contexts. +4. **Share jit compile cache across temperature values** — pass + `temperature` as a traced `jax.Array` rather than a closure + constant so one compile serves all temps. Saves ~30 s per extra + temp in a sweep. Cosmetic for benchmarks, not a user-facing + improvement. 
+ +## Reproduce + +### Single run + +```bash +gcloud compute tpus tpu-vm ssh tpu-v5e-64-us --zone=us-central1-a --worker=all \ + --command="cd ~/dllm-jax && \ + PYTHONPATH=~/dllm-jax:\${PYTHONPATH:-} \ + RESUME_DIR=gs://dllm-jax-us-central1/checkpoints/ \ + RESUME_STEP=41400 \ + MODEL_NAME=Qwen/Qwen3-8B PROMPT='Once upon a time' \ + GEN_LENGTH=1024 BLOCK_LENGTH=32 STEPS=32 \ + TOP_K=3 THRESHOLD=0.5 CONFIDENCE_STOP=0.9 TP=8 \ + INFER_IMPL=fast INFER_SPLASH=1 INFER_SPLASH_BLOCK=512 \ + WARMUP_RUNS=1 MEASURED_RUNS=2 \ + python3 scripts/tpu_infer.py" +``` + +### Temperature sweep + +Same as above, but replace the `WARMUP_RUNS=1 MEASURED_RUNS=2 \` line with: + +``` +TEMPS=0.0,0.3,0.5,0.7,1.0,1.5 SEED=42 \ +WARMUP_RUNS=0 MEASURED_RUNS=1 \ +``` diff --git a/docs/inference.md b/docs/inference.md new file mode 100644 index 0000000..d0b2265 --- /dev/null +++ b/docs/inference.md @@ -0,0 +1,123 @@ +# Inference + +Full reference for [`scripts/tpu_infer.py`](../scripts/tpu_infer.py), the +multi-host inference entry point. See the [README](../README.md#infer) for +the quick-start example. For a deep dive on the splash-attention vs dense +trade-off and end-to-end timing, see +[`inference-optimization.md`](inference-optimization.md). For the design of +the KV-cache path, see [`kv-cache-design.md`](kv-cache-design.md). + +## Three implementations + +All three produce **byte-identical output at matching settings.** Pick one +based on shape and length: + +| `INFER_IMPL=` | What it does | When to use | +|---|---|---| +| `fast` (default) | fixed-shape `fori_loop` with step- and block-level early breaks | short/medium gen, especially with `INFER_SPLASH=1` | +| `kv_fast` | KV-cached, decode only the active block per step | long gen (≥1024 tokens) | +| `legacy` | Python-loop reference path with host-side breaks | debugging only — slow on TPU | + +## Splash attention + +`INFER_SPLASH=1` swaps `jax.nn.dot_product_attention` for a Pallas splash +kernel matched to the block-causal mask. 
On Qwen3-8B at `GEN_LENGTH=1024` +it's **3.5× faster** *and* fixes a latent dense-kernel quality bug at +non-128-aligned sequence lengths. + +| Variable | Meaning | Default | +|---|---|---| +| `INFER_SPLASH` | enable splash for the `fast` block-causal mask | `0` | +| `INFER_SPLASH_BLOCK` | splash tile size (auto-rounds to 128-multiple ≤ this) | `512` | +| `INFER_KV_DTYPE` | KV-cache dtype for `kv_fast`: `bf16` (2× HBM savings) or `fp32` | `bf16` | + +## Multi-prompt and temperature sweeps + +You can amortize the restore cost across many generations: + +```bash +# Multiple prompts, one per line: +PROMPTS_FILE=./prompts.txt … python3 scripts/tpu_infer.py + +# Temperature sweep with 1 warmup + 2 measured runs each: +PROMPT='Once upon a time' \ +TEMPS='0.0,0.3,0.5,0.7,1.0' \ +WARMUP_RUNS=1 MEASURED_RUNS=2 \ +… python3 scripts/tpu_infer.py +``` + +## Inference env vars + +| Variable | Meaning | Default | +|---|---|---| +| `RESUME_DIR` / `RESUME_STEP` | which checkpoint to restore | — / latest | +| `MODEL_NAME` | tokenizer + HF config source | `Qwen/Qwen3-8B` | +| `PROMPT` | input prompt | `Once upon a time` | +| `PROMPTS_FILE` | one prompt per line (overrides `PROMPT`; `#` lines = comments) | — | +| `GEN_LENGTH` | tokens to generate | `32` | +| `BLOCK_LENGTH` | DMax block size | `32` | +| `STEPS` | max denoising steps per block | `8` | +| `INFER_IMPL` | `fast` / `kv_fast` / `legacy` | `fast` | +| `FAST_BUCKET_LENGTH` | compile window for `fast`; ignored by `kv_fast` | `4096` | +| `THRESHOLD` | left-to-right confidence cutoff (math `0.5`, code `0.65`, others `0.9`–`0.95`) | `0.95` | +| `CONFIDENCE_STOP` | block-level early-exit confidence | `0.9` | +| `TOP_K` | soft-mix top-k (`3` is more coherent on undertrained ckpts) | `1` | +| `TEMPERATURE` | gumbel-max temperature (`0.0` = greedy) | `0.0` | +| `TEMPS` | comma-separated sweep, e.g. 
`0.0,0.3,0.5` | — | +| `SEED` | RNG seed (only used when `TEMPERATURE > 0`) | none | +| `WARMUP_RUNS` | throwaway generates before timing | `0` | +| `MEASURED_RUNS` | timed generates; reports median | `1` | +| `SUPPRESS_MASK_TOKEN` | force-disable mask logits at argmax | `0` | +| `MASK_TOKEN_ID` / `EOS_TOKEN_ID` | overrides | tokenizer default | +| `TP` | inference mesh's TP axis | `8` | +| `RESTORE_OPTIMIZER` | also restore optimizer state | `0` | +| `INFER_SPLASH` / `INFER_SPLASH_BLOCK` / `INFER_KV_DTYPE` | see above | `0` / `512` / `bf16` | + +## Throughput + +TPU v4-32, Qwen3-8B, `TOP_K=3`: + +| Config | NFE | seconds | +|---|---|---| +| `fast`, GEN=128 | 114 | 32.1 | +| `kv_fast`, GEN=128 | 119 | 65.9 | +| `fast`, GEN=1024 | 1010 | 124.7 | +| `kv_fast`, GEN=1024 | 1043 | 76.6 | + +TPU v5e-64, Qwen3-8B DMax, `TOP_K=3`, `GEN=1024` (full writeup in +[`inference-optimization.md`](inference-optimization.md)): + +| Config | seconds | tok/s | quality | +|---|---|---|---| +| `fast` (dense attention) | 142.3 | 7.2 | ⚠️ collapses into punctuation | +| **`fast` + splash** | **40.9** | **25.1** | ✅ coherent | +| `kv_fast` (fp32 KV) | 52.9 | 19.4 | ✅ coherent | +| `kv_fast` (bf16 KV) | 52.1 | 19.6 | ✅ bit-identical | + +## Single-host CLI + +For laptop / single-TPU-host work without GCS: + +```bash +# Generate from a base model (no DMax training needed) +python scripts/dmax_generate.py \ + --model Qwen/Qwen3-0.6B \ + --prompt "Solve 37 * 48." \ + --gen-length 256 --block-length 32 --steps 32 \ + --threshold 0.5 --top-k 3 --impl fast + +# Generate from a saved single-host trainer checkpoint +python scripts/dmax_generate_checkpoint.py \ + --checkpoint-dir ./out-dmax/checkpoint-1000 \ + --prompt "Solve 37 * 48." \ + --gen-length 256 --impl kv_fast +``` + +## Gotchas + +**Output collapses into punctuation with `TOP_K=1`.** SPD feeds the previous +step's distribution back as a soft mix; at `TOP_K=1` a low-confidence +single token gives a bad signal. Try `TOP_K=3`. 
+ +**First-run generation is slow.** It includes model restore + XLA compile. +Subsequent runs with identical shapes hit the cached graph and are ~5× faster. diff --git a/docs/install.md b/docs/install.md new file mode 100644 index 0000000..8fc00c5 --- /dev/null +++ b/docs/install.md @@ -0,0 +1,73 @@ +# Installation + +## Prerequisites for multi-host TPU runs + +You need three things before any of the multi-host scripts run. Single-host +CPU/GPU development needs none of them. + +**1. A Google Cloud project with TPU access.** Install the [`gcloud` CLI](https://cloud.google.com/sdk/docs/install), then: + +```bash +gcloud auth login +gcloud config set project YOUR_PROJECT_ID +``` + +**2. A regional GCS bucket for checkpoints.** TPU and bucket should share a +region (cross-region writes are slow): + +```bash +gcloud storage buckets create gs://YOUR_BUCKET_NAME --location=us-east1 +``` + +By default this repo writes to `gs://${CHECKPOINT_BUCKET_PREFIX}-${region}` +(prefix `dllm-jax`), so the matching bucket would be `gs://dllm-jax-us-east1`. +Override via `CHECKPOINT_BUCKET=...` or `CHECKPOINT_DIR=gs://...` if you prefer +a different layout. + +**3. Optional but recommended:** a [Weights & Biases](https://wandb.ai) +account for loss curves, and a [HuggingFace](https://huggingface.co) account +if you want gated models or to upload checkpoints. You can run `wandb login` +and `huggingface-cli login` once per TPU worker and skip the env vars +entirely. + +## TPU VM packaging caveat + +Some Python 3.10 TPU VM images ship an older packaging stack where +`pip install -e '.[tpu]'` can fail with a missing `build_editable` hook, +or `pip install '.[tpu]'` misreads metadata as `UNKNOWN-0.0.0` without +installing dependencies. 
If that happens, skip editable mode and install +deps explicitly from the synced checkout: + +```bash +python3 -m pip install --user -U 'jax[tpu]' \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html \ + 'flax>=0.10.0,<0.11' orbax-checkpoint 'gcsfs<=2026.2.0' 'fsspec<=2026.2.0' \ + 'optax>=0.2.0' numpy 'transformers>=4.40.0' safetensors \ + huggingface_hub datasets wandb +``` + +Then run scripts with `PYTHONPATH=/path/to/dllm-jax python3 …`. + +## Verified TPU versions + +End-to-end validation on TPU v4-32 (`us-central2-b`) with this exact stack. +Pin to these if `pip install '.[tpu]'` shows version drift: + +| Package | Version | +|---------|---------| +| Python | 3.10.12 | +| jax / jaxlib | 0.6.2 | +| libtpu | 0.0.17 | +| flax | 0.10.7 | +| optax | 0.2.8 | +| orbax-checkpoint | 0.11.34 | +| transformers | 5.5.3 | +| safetensors | 0.7.0 | +| datasets | 4.8.4 | +| gcsfs / fsspec | 2025.3.2 | +| huggingface_hub | 1.10.1 | +| numpy | 2.2.6 | + +Newer flax (0.12+) on JAX 0.7+ should also work; the +`_nnx_list = getattr(nnx, "List", list)` shim in `models.py` handles the +cross-version difference. diff --git a/docs/kv_cache_design.md b/docs/kv-cache-design.md similarity index 99% rename from docs/kv_cache_design.md rename to docs/kv-cache-design.md index fdfb3f5..58c4fb2 100644 --- a/docs/kv_cache_design.md +++ b/docs/kv-cache-design.md @@ -222,7 +222,7 @@ and identical `updated_block` at every step (modulo float32 rounding in attentio 1. Add `past_k/past_v/cache_position` plumbing to model with no-op default. 2. Implement `dmax_generate_spd_kv_fast` + post-block hard-write pass. -3. Add `INFER_IMPL=kv_fast` to `tpu_dmax_infer_checkpoint.py`. +3. Add `INFER_IMPL=kv_fast` to `tpu_infer.py`. 4. A/B test: `kv_fast` vs `fast` on the same prompt/settings. Compare text and `nfe`. 5. Benchmark: measure `generate_seconds` at `gen_length=128, 1024, 4096`. Expect 3–5× speedup on longer contexts. 
diff --git a/docs/mfu-optimization.md b/docs/mfu-optimization.md new file mode 100644 index 0000000..3881c0d --- /dev/null +++ b/docs/mfu-optimization.md @@ -0,0 +1,201 @@ +# MFU Optimization: Qwen3-8B + DMax OPUT on TPU v5e-64 + +Summary of a MFU tuning pass on `tpu-v5e-64-eu` (europe-west4-b) for +`scripts/tpu_train.py` training **Qwen3-8B with full DMax OPUT** +(on_policy_ratio=0.5 rollout kept, block-diffusion attention mask, +synthetic data, random init, AdamW, aggressive `jax.remat`). + +## Headline + +| Config | MFU (attn-adj) | tok/s | step s | vs start | +|--|--|--|--|--| +| starting point | 2.9% | 6,570 | 10.0 | 1.0× | +| **final (EE)** | **14.7%** | **32,834** | **8.0** | **5.1×** | + +All comparisons are at `MODEL_NAME=Qwen/Qwen3-8B`, `MAX_LEN=4096`, +`PEAK_TFLOPS_PER_CHIP=197` (12.6 PFLOPS bf16 peak on v5e-64), +preserved training semantics (full OPUT rollout — changing semantics +is separated below). + +The reported MFU above is under the **original** formula that credited +one fwd+bwd at `seq=MAX_LEN`. In reality DMax's compiled graph runs +fwd+bwd at `seq=2·MAX_LEN` and — whenever `DMAX_ON_POLICY_RATIO > 0` +— also pays a stop_gradient rollout forward at `seq=2·MAX_LEN`. The +attention term is quadratic, so the true multiplier is **~3.01×** (not +2×). Under correct accounting the EE measurement is +**≈ 44% of v5e bf16 peak**, i.e., the compiled graph reaches just under +half of MXU peak on these shapes. Fixed in `tpu_train.py:319–337` +(see item 1 below). 
+ +## Optimization ladder (full OPUT semantics preserved) + +| # | Change | MFU | tok/s | step s | Notes | +|--|--|--|--|--|--| +| 0 | Starting — TP=8 hardcoded, dense attn, B=16 L=4096 | 2.9% | 6,570 | 10.0 | baseline | +| 1 | `TP=2 × FSDP=32` mesh (P) | 7.2% | 16,019 | 8.17 | mesh re-shaping freed HBM; larger matmuls; B=32 fit | +| 2 | `splash_attention` for block-diffusion mask (Z) | 8.9% | 19,809 | 6.61 | replaced dense `jax.nn.dot_product_attention` on 2L×2L mask | +| 3 | Scale to B=64 using splash's O(seq) attn memory (BB) | 9.2% | 20,469 | 12.8 | doubled batch — was OOM at B=64 with dense attn | +| 4 | **Tune splash `BlockSizes` to 512 + `use_fused_bwd_kernel=True` (EE)** | **14.7%** | **32,834** | **8.0** | kernel-level tile + bwd fusion — biggest single win | + +### Why each step helped + +1. **TP=2 × FSDP=32 mesh** — `tpu_train.py` hard-coded `TP = 8`. + On v5e-64's 8×8 torus that forces FSDP=8, so optimizer state shards + only 8-way and B must be a multiple of 8. Dropping TP to 2 (with + `mesh_utils.create_device_mesh(..., allow_split_physical_axes=True)` + because `(32, 2)` doesn't factor onto 8×8) bumps FSDP to 32, which + (a) makes the per-chip optimizer/grad/param footprint ~4× smaller + and (b) enlarges the matmul size per chip (hidden/TP=2048 vs 512), + better amortizing MXU tile overhead. + +2. **splash_attention for the block-diffusion mask** — the DMax path + was silently **not** using flash attention: `dllm_jax/models.py:367` + gated flash on `attention_mask is None`, and DMax always passes a + non-None block-diffusion mask. The fallback was a dense + `jax.nn.dot_product_attention` over the full 2L×2L mask. Switched + to `jax.experimental.pallas.ops.tpu.splash_attention`, with the + block-diffusion pattern wrapped as `NumpyMask` inside a + `MultiHeadMask([mask] * heads_per_tp)`. Splash runs a block-sparse + flash that skips masked regions and never materializes the score + matrix. + +3. 
**Scale to B=64 with splash** — dense attention's 2L×2L score + materialization was the reason B=64 L=4096 OOMed previously + (R: 28.84 GB > 15.75 GB). splash is O(seq), so B=64 fit and the + mesh could amortize its fixed costs over 2× the tokens. + +4. **Tune splash tile sizes + fused bwd** — the single biggest kernel + change. `splash_attention_kernel.BlockSizes` default is + `block_q=block_kv=block_kv_compute=128` with a separate `dkv` and + `dq` backward pass (`use_fused_bwd_kernel=False`). 128×128 is much + smaller than what v5e wants — at the per-chip shapes we run + (8192 × 128), each tile barely saturates the MXU and the + per-kernel launch overhead dominates. Bumping to 512 (8× more + work per launch) and enabling the fused bwd kernel (dkv + dq in a + single pass) drops step time 12.8 s → 8.0 s (−37%). + +## Knobs added + +All exposed as env vars on `scripts/tpu_train.py`: + +| Env | Default | Effect | +|--|--|--| +| `TP` | 8 | tensor-parallel axis size; 2 or 4 usually better for v5e-64 | +| `SPLASH_BLOCK` | 512 | splash tile size (was 128 default) | +| `SPLASH_FUSED_BWD` | 1 | enable fused `dkv`/`dq` backward kernel | +| `DMAX_ON_POLICY_RATIO` | 0.5 | skips rollout fwd when 0 (see ablation below) | +| `REMAT_POLICY` | nothing_saveable | `gate_up` / `qkv_gate_up` / `dots_saveable` / `everything_saveable` save more between fwd and bwd; saves recompute at HBM cost. | + +## Peak config (EE) — reproduce + +```bash +RUN_NAME=qwen3-8b-dmax-v5e-ee \ +MODEL_NAME=Qwen/Qwen3-8B DATASET=tinystories \ +MAX_LEN=4096 GLOBAL_BATCH=64 \ +TP=2 \ +SPLASH_BLOCK=512 SPLASH_FUSED_BWD=1 \ +DMAX_ENABLE=1 DMAX_ON_POLICY_RATIO=0.5 \ +DMAX_NOISE_LOW=0.75 DMAX_NOISE_HIGH=0.75 DMAX_BLOCK_SIZE=32 \ +PEAK_TFLOPS_PER_CHIP=197 \ +NUM_STEPS=0 NUM_EPOCHS=3 WANDB_LOG=1 \ +python3 scripts/tpu_train.py +``` + +## What did NOT help + +| Tried | Result | +|--|--| +| XLA async-collective flags (`xla_tpu_enable_async_collective_fusion`, etc.) 
| 0.0 pp — XProf showed comm is only ~14% of step time, not the bottleneck. | +| Longer seq at same batch (`L=7168`, `L=8192`) at TP=2 | L=7168 was slower (tok/s dropped); L=8192 OOMed (program 13.24 GB > 12.63 GB free — splash mask 16384² too big). | +| Bigger batch past B=64 at TP=2 (`B=96`, `B=128`) | OOM — program size exceeds HBM. | +| `OPTIMIZER=adafactor` at B=64 (bf16 factored opt state) | OOMed anyway — activations, not optimizer state, are the blocker at B=64 with splash defaults. | +| Skipping rollout forward unconditionally in DMax path | Changes OPUT semantics to off-policy-only, not a valid DMax result (see ablation table). | + +## XProf breakdown (captured on a pre-tuned splash config) + +Sampled steady-state steps 4–6, 4 chips on worker 0: + +| Bucket | % of device time | +|--|--| +| fusion (matmul + elementwise epilogues) | 77% | +| while loops (remat recompute) | 7.5% | +| collective-permute | 6.0% | +| all-gather (FSDP) | 5.9% | +| all-reduce | 2.5% | +| copy / reshape / reduce | <1% | + +So comm is ~14% combined (not bound) and MXU utilization inside the +fusions is the real ceiling. The splash tile tuning in step 4 above +attacks exactly that — bigger tiles raise per-fusion MXU efficiency. + +## Ablation — non-OPUT configs (throughput-only) + +These change training semantics and are **not** DMax OPUT results; +kept for attribution. + +| Tag | Config | MFU | tok/s | Why reported | +|--|--|--|--|--| +| I | plain MDLM (DMax off), B=8 L=4096 TP=8 | 6.7% | 14,960 | Confirms the 2.9% starting ceiling is really ~6.7% HW × DMax accounting factor. | +| W | DMax TP=2 B=32 L=4096, `on_policy_ratio=0` + rollout fwd guard | 8.4% | 18,874 | Demonstrates the rollout fwd was ~⅓ of compute — but removing it drops OPUT semantics. | +| AA | W + splash attention | 10.1% | 22,605 | Same, also not OPUT. | + +The W/AA rollout-skip patch sits behind a `DMAX_ON_POLICY_RATIO > 0` +guard in `loss_fn` — at the default 0.5 it's a no-op and OPUT runs +normally. 
+ +## Code changes (all uncommitted, applied on the TPU) + +**`scripts/tpu_train.py`**: +- Line 87: `TP = int(os.environ.get("TP", "8"))` + split-axes mesh fallback. +- After flash install: build `_block_diffusion_mask_numpy` once at init, + wrap in `MultiHeadMask`, construct splash with tuned `BlockSizes`, + register as `dllm_models._MASKED_FLASH_ATTN_FN` (shard_map'd over the + fsdp×tp mesh with vmap over batch). +- `loss_fn`: rollout forward now guarded by `DMAX_ON_POLICY_RATIO > 0`. +- Profiler hook env: `XPROF_ENABLE`, `XPROF_DIR`, `XPROF_START_STEP`, + `XPROF_STOP_STEP`. + +**`dllm_jax/models.py`**: +- New `_MASKED_FLASH_ATTN_FN` global. `_attention` now routes to it + when `attention_mask is not None`, keeping dense `dot_product_attention` + as the final fallback. + +## What's next if you want to push past 14.7% + +In decreasing ROI: + +1. ✅ **Fix the DMax MFU formula** — **implemented** in + `tpu_train.py:319–337`. The correct multiplier is **~3.01×**, + not 2×: fwd+bwd at 2L gives dense ×2 and attention ×4 (quadratic), + and the rollout fwd at 2L (always compiled when + `DMAX_ON_POLICY_RATIO > 0`) adds another ~0.67× dense and ~1.33× + attention. The new formula credits both passes at their actual + sequence lengths and reduces to the pre-DMax formula when + `DMAX_ENABLE=0`. Reported MFU on the EE config goes 14.7% → **~44%** + (attention-adjusted) without any change in wall-clock. +2. **Fused SwiGLU Pallas kernel** — merges + `down_proj(silu(gate_proj) * up_proj)` into one kernel. Avoids + materializing the `(B, 2L, intermediate)` activation in HBM. + Mostly reduces the 7.5% remat `while` cost. Estimated +1-3 pp MFU. + Multi-day build. +3. ⚙️ **Looser remat policy** — `REMAT_POLICY` env var now exposes + `gate_up`, `qkv_gate_up`, `dots_saveable`, `everything_saveable` + alongside the default `nothing_saveable`. `dllm_jax/models.py` + marks `q`/`k`/`v` (post-RoPE) and the gate·up product with + `checkpoint_name`, so the named policies are free to save them. 
+ Not yet measured on v5e-64 — safest starting point is + `REMAT_POLICY=gate_up` at **B=32** (B=64 is likely to OOM given + saving `gate_up` at seq=2L costs ~24 KB × 2L × layers per sample, + ~14 GB per chip at B=64/FSDP=32 bf16). +4. **Architectural: fold `[noised; clean]` into batch dim** instead of + seq dim. Halves attention quadratic cost and the activation + footprint. Requires rewriting the block-diffusion mask to run as + two separate per-stream attentions. Biggest potential lever but + a real rewrite. + +## Artefacts + +- `/tmp/sweep_results.md` — full sweep table (all configs, incl. OOMs). +- `/tmp/xprof_summary.md` — XProf trace analysis + top ops. +- `/tmp/xprof-P/plugins/profile/*/` — raw xplane.pb + trace.json.gz + for TensorBoard. diff --git a/docs/training.md b/docs/training.md new file mode 100644 index 0000000..ce191f0 --- /dev/null +++ b/docs/training.md @@ -0,0 +1,218 @@ +# Training + +Full reference for [`scripts/tpu_train.py`](../scripts/tpu_train.py), the +multi-host training entry point. See the [README](../README.md#train) for the +quick-start example. + +## Datasets + +| `DATASET=` | Source | Use | +|---|---|---| +| `tinystories` | `roneneldan/TinyStories` (streamed) | small-scale sanity / smoke | +| `wikipedia` | `wikimedia/wikipedia 20231101.en` (streamed) | pretrain-style packing | +| `openthoughts` | `open-thoughts/OpenThoughts-114k` (streamed) | SFT (chat-templated) | +| `parquet` | local `*.parquet` files via `DATASET_PATH=` | bring-your-own | +| `synthetic` | random ints | regression / shape testing | + +`tinystories`, `wikipedia`, and `parquet` use **token-stream packing**: +documents are tokenized, joined by EOS, and chunked to `MAX_LEN`. +`openthoughts` is a chat-templated SFT set; by default it also packs the +full chat-templated text. Set `SFT_TRAIN_ON_ANSWERS_ONLY=1` to mask prompt +tokens with `-100` and supervise only assistant turns (per-example padded +batches; some rows >`MAX_LEN` are dropped). 
+ +To inspect what your SFT pipeline actually feeds the model: + +```bash +python3 scripts/inspect_sft_data.py --model Qwen/Qwen3-8B --max-len 4096 --rows 2 +``` + +## DMax / OPUT training + +`DMAX_ENABLE=1` switches the script from MDLM-style noising to DMax's +high-noise, on-policy, block-diffusion objective. + +| Variable | Meaning | Default | +|---|---|---| +| `DMAX_ENABLE` | turn on DMax (otherwise MDLM-style) | `0` | +| `DMAX_BLOCK_SIZE` | tokens per noising block | `32` | +| `DMAX_ON_POLICY_RATIO` | fraction of steps using model's own greedy preds | `0.5` | +| `DMAX_NOISE_LOW` / `DMAX_NOISE_HIGH` | uniform noise range | `0.75` / `0.75` | + +When DMax is on, the script also installs a Pallas **splash attention** +kernel matched to the block-diffusion mask: + +| Variable | Meaning | Default | +|---|---|---| +| `SPLASH_BLOCK` | tile size (must be a multiple of 128 dividing `MAX_LEN`) | `512` | +| `SPLASH_FUSED_BWD` | fused backward kernel | `1` | +| `DISABLE_SPLASH_ATTN` | fall back to dense attention (debugging) | `0` | + +## Optimizer + +Both AdamW and Adafactor are wired up: + +| Variable | Meaning | Default | +|---|---|---| +| `OPTIMIZER` | `adamw` or `adafactor` | `adamw` | +| `PEAK_LR` | peak learning rate (post-warmup) | `1e-4` | +| `WARMUP_STEPS` | linear ramp from 0 → `PEAK_LR` | `5` | + +> ⚠️ **AdamW is the recommended default** for diffusion-LM training on +> pretrained init. Adafactor's `scale_by_param_block_rms` misbehaves on +> bidirectional objectives over causal-LM weights — loss descends and then +> climbs back to ~12 around step 60 on Qwen3-8B for both MDLM and DMax/OPUT. +> Use Adafactor only when HBM forces it (e.g. 128k context on a single +> chip), and budget the divergence risk accordingly. + +## Learning-rate schedule + +By default the schedule is *constant after warmup*. 
To decay: + +| Variable | Meaning | Default | +|---|---|---| +| `LR_SCHEDULE` | `constant` or `cosine` | `constant` | +| `LR_DECAY_STEPS` | cosine: total decay steps after warmup | `0` | +| `LR_DECAY_ALPHA` | cosine: final LR = `PEAK_LR * alpha` | `0.1` | + +## Mask-token warm start + +DMax reuses a token id as the MASK token (default `vocab_size - 1`, an +*untrained* reserved slot on Qwen3). Without intervention, its row in the +embedding matrix is essentially random — observed to drift the noised +forward toward uniform predictions around step 150. + +To seed the mask row with the **mean** of all other input/output embedding +rows: + +| Variable | Meaning | Default | +|---|---|---| +| `MASK_EMBED_INIT` | `mean` (warm start) or `none` (skip) | `mean` | +| `MASK_TOKEN_ID` | which row to seed | `vocab_size - 1` | + +> ⚠️ On **v5e-64**, set `MASK_EMBED_INIT=none` for now. The pre-shard +> `.at[].set()` on replicated embed + lm_head leaks ~9.5GB HBM and blocks +> the training program from loading. + +You can also point at a pretrained Qwen3 special token like +`MASK_TOKEN_ID=151662` (`<|fim_pad|>`) to inherit "fill in the missing +piece" semantics for free. + +## Memory & throughput knobs + +| Variable | Meaning | Default | +|---|---|---| +| `REMAT_POLICY` | gradient checkpointing policy | `nothing_saveable` | +| `MAX_LEN` | context length | `16384` | +| `GLOBAL_BATCH` | global batch (across all chips) | `8` | +| `PEAK_TFLOPS_PER_CHIP` | for MFU calc; v4=275, v5e=197, v5p=459, v6e=918 | `918` | + +The throughput methodology is in [`mfu-optimization.md`](mfu-optimization.md). +Reference numbers land between ~38–48% MFU after splash + remat tuning. 
+ +## Checkpointing + +Sharded Orbax DCP checkpoints write to GCS by default: + +``` +gs://${CHECKPOINT_BUCKET_PREFIX}-${region}/checkpoints/${RUN_NAME}/ +├── checkpoint_500/ +├── checkpoint_500/commit_success.txt +├── checkpoint_1000/ +└── … +``` + +| Variable | Meaning | Default | +|---|---|---| +| `CHECKPOINT_STEPS` | save every N steps | `500` | +| `CHECKPOINT_KEEP` | retain the last N | `2` | +| `CHECKPOINT_BUCKET_PREFIX` | regional auto-detect prefix | `dllm-jax` | +| `CHECKPOINT_BUCKET` | fixed bucket (overrides prefix) | — | +| `CHECKPOINT_DIR` | full directory (overrides both) | — | +| `LOCAL_CHECKPOINT_DIR` | fallback if GCS not configured | `/tmp/dllm-jax-checkpoints` | +| `CHECKPOINT_ON_FINISH` | always write a final checkpoint | `0` | + +The script also enables Orbax barrier shims by default +(`CHECKPOINT_ORBAX_SYNC_DIRS=1`, `CHECKPOINT_ORBAX_SIGNAL_FALLBACK=1`) +so distributed GCS writes use JAX multi-host barriers — needed on the +JAX-0.6.x / Orbax-0.11.x TPU VM stack. + +## Resume + +Resuming is one variable: point at the parent directory and optionally pin +a step. + +```bash +RESUME_DIR=gs://dllm-jax-us-east1/checkpoints/old-run \ +RESUME_STEP=0 \ +… python3 scripts/tpu_train.py +``` + +`RESUME_STEP=0` (default) scans `commit_success.txt` markers and picks the +latest committed step. Set `RESUME_STEP=<N>` to pin a specific one. + +### Three resume modes + +| Scenario | Settings | +|---|---| +| Continue training, same data, same hardware | `RESUME_DIR=…` (defaults are correct) | +| **Continue onto a new dataset** (e.g. 
switch to `openthoughts`) | `RESUME_DIR=… RESUME_RESET_STEP=1` | +| **Pretrain → SFT** (different optimizer family or lr) | `RESUME_DIR=… RESUME_RESTORE_OPTIMIZER=0 RESUME_RESET_STEP=1` | + +| Variable | Meaning | Default | +|---|---|---| +| `RESUME_DIR` / `RESUME_FROM` | parent dir containing `checkpoint_<step>/` | — | +| `RESUME_STEP` | specific step (`0` = latest) | `0` | +| `RESUME_RESTORE_OPTIMIZER` | `0` = restore weights only, fresh optimizer | `1` | +| `RESUME_RESET_STEP` | `1` = zero global_step / epoch counters | `0` | + +Switching hardware (v4 → v5e, v6e → v4, …) is supported transparently by +orbax: the script re-shards on restore using the *current* mesh. The only +thing you need to update is `TP`. + +## Sharding and hardware sizing + +The 2D mesh is `(fsdp, tp)` with `fsdp = jax.device_count() // TP`. + +| Hardware | Chips | Recommended `TP` | Mesh shape | Notes | +|---|---|---|---|---| +| TPU v4-32 | 16 | `8` | `(2, 8)` | validated for Qwen3-8B training and inference | +| TPU v5e-64 | 64 | `2` | `(32, 2)` | `TP=8` here forces FSDP=8 and ~4× optimizer-state HBM; prefer `TP=2` | +| TPU v6e-256 | 256 | `8` | `(32, 8)` | high MFU with splash + remat | +| Single-host (≤ 3B params) | any | `1` | `(N,)` 1D FSDP | `P("fsdp", None)` direct TPU init | + +For models ≥ 3B you generally want CPU-first init +(`jax.default_device(jax.devices("cpu")[0])`), gradient checkpointing on +every transformer layer, and Pallas flash attention. The training script does +all three automatically. + +## Gotchas + +**Adafactor diverges around step 60.** See the optimizer note above — +default to AdamW unless HBM forces otherwise. 
+ +**Optimizer rebind after split/merge.** If you manually `nnx.split` and +`nnx.merge` the optimizer in a hand-written script, rebind `model` or +you'll get silent zero-progress: + +```python +opt_gdef, opt_state = nnx.split(optimizer) +opt_state = jax.tree.map(fsdp_sharding, opt_state) +optimizer = nnx.merge(opt_gdef, opt_state) +model = optimizer.model # CRITICAL — without this, grads land on a stale model +``` + +The built-in `trainers.py` and `tpu_train.py` already do this; only +relevant if you're rolling your own loop. + +**flax 0.10.7 vs 0.12+.** Python 3.10 TPU VMs pin flax to 0.10.7. The +`_nnx_list = getattr(nnx, "List", list)` shim in `models.py` and the +positional `optimizer.update(grads)` call are deliberate compat choices. +Don't "modernize" them without testing on the older stack. + +**Stale TPU checkout.** `ImportError: cannot import name '...'` usually +means the TPU copy is out of date. Re-`scp` `dllm_jax/` and the target +script. + +**`gs://` / TensorFlow warnings on restore.** Harmless. Restore uses +Orbax's GCS client, not `tf.io.gfile`. 
diff --git a/scripts/dmax_tinystories_train.py b/scripts/examples/dmax_train_tinystories.py similarity index 100% rename from scripts/dmax_tinystories_train.py rename to scripts/examples/dmax_train_tinystories.py diff --git a/scripts/run_tpu_v4_32_3epoch.sh b/scripts/examples/run_tpu_train_v4_32_3epoch.sh similarity index 87% rename from scripts/run_tpu_v4_32_3epoch.sh rename to scripts/examples/run_tpu_train_v4_32_3epoch.sh index d4305e3..63901c5 100755 --- a/scripts/run_tpu_v4_32_3epoch.sh +++ b/scripts/examples/run_tpu_train_v4_32_3epoch.sh @@ -11,5 +11,5 @@ export LOG_EVERY="${LOG_EVERY:-50}" export TOKENIZERS_PARALLELISM=false export LIBTPU_INIT_ARGS="${LIBTPU_INIT_ARGS:---xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_data_parallel_opt_different_sized_ops=true}" -cd "$(dirname "$0")/.." -exec python3 scripts/tpu_v4_32_train_3epoch.py 2>&1 +cd "$(dirname "$0")/../.." +exec python3 scripts/examples/tpu_train_v4_32_3epoch.py 2>&1 diff --git a/scripts/tpu_v4_32_train_3epoch.py b/scripts/examples/tpu_train_v4_32_3epoch.py similarity index 100% rename from scripts/tpu_v4_32_train_3epoch.py rename to scripts/examples/tpu_train_v4_32_3epoch.py diff --git a/scripts/inspect_sft_data.py b/scripts/inspect_sft_data.py new file mode 100644 index 0000000..31a4064 --- /dev/null +++ b/scripts/inspect_sft_data.py @@ -0,0 +1,162 @@ +"""Dump SFT data pipeline batches to stdout for inspection. + +Mirrors the packing path in ``scripts/tpu_train.py`` (``refill_buffer`` + +``get_batch``) without any JAX / training code, so we can cheaply check: + +- does ``apply_chat_template`` round-trip correctly (decode == human-readable)? +- does concat-packing produce the expected ``<|im_start|>...<|im_end|>`` structure? +- does labels == input_ids (full-text SFT) and is the MASK token absent from data? 
+- does the EOS token sit between docs, not inside them? +- are there suspicious id==-1 / id>vocab_size / repeated pathological tokens? + +Run locally or on any TPU worker: + python3 scripts/inspect_sft_data.py [--max-len 4096] [--rows 2] [--batches 1] +""" + +from __future__ import annotations + +import argparse +import os +import sys +from collections import Counter +from pathlib import Path + +import numpy as np + + +ROOT = Path(__file__).resolve().parent.parent +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + + +_ROLE_MAP = { + "human": "user", + "gpt": "assistant", + "system": "system", + "user": "user", + "assistant": "assistant", +} + + +def _sft_row_to_messages(row): + convs = row.get("conversations") + if not convs: + return None + messages = [] + system = row.get("system") + if system: + messages.append({"role": "system", "content": system}) + for c in convs: + if isinstance(c, dict) and "from" in c: + messages.append({"role": _ROLE_MAP.get(c["from"], c["from"]), "content": c["value"]}) + else: + messages.append(c) + return messages + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--model", default=os.environ.get("MODEL_NAME", "Qwen/Qwen3-8B")) + ap.add_argument("--max-len", type=int, default=int(os.environ.get("MAX_LEN", "4096"))) + ap.add_argument("--batch", type=int, default=int(os.environ.get("GLOBAL_BATCH", "4"))) + ap.add_argument("--rows", type=int, default=2, help="how many packed rows to decode in full") + ap.add_argument("--batches", type=int, default=1, help="how many batches to draw") + ap.add_argument("--mask-id", type=int, default=int(os.environ.get("MASK_TOKEN_ID", "151662"))) + args = ap.parse_args() + + print(f"[inspect] model={args.model} max_len={args.max_len} batch={args.batch} mask_id={args.mask_id}") + + import transformers + from datasets import load_dataset + + tokenizer = transformers.AutoTokenizer.from_pretrained(args.model) + eos_id = tokenizer.eos_token_id + print(f"[inspect] eos_id={eos_id} 
pad_id={tokenizer.pad_token_id} vocab={tokenizer.vocab_size}") + # Decode the special MASK id to see what Qwen3 calls 151662 + try: + mask_tok = tokenizer.decode([args.mask_id]) + print(f"[inspect] mask_id {args.mask_id} -> {mask_tok!r}") + except Exception as exc: + print(f"[inspect] decoding mask_id failed: {exc}") + + ds = load_dataset("open-thoughts/OpenThoughts-114k", "default", split="train", streaming=True) + ds_iter = iter(ds) + + # Replicate refill_buffer + get_batch packing + token_buffer: list[int] = [] + total_rows_consumed = 0 + total_rows_dropped = 0 + + def refill(needed: int): + nonlocal total_rows_consumed, total_rows_dropped + while len(token_buffer) < needed: + try: + row = next(ds_iter) + except StopIteration: + break + messages = _sft_row_to_messages(row) + total_rows_consumed += 1 + if messages is None: + total_rows_dropped += 1 + continue + out = tokenizer.apply_chat_template( + messages, tokenize=True, add_generation_prompt=False, return_dict=True, + ) + token_buffer.extend(out["input_ids"]) + token_buffer.append(eos_id) + + for b in range(args.batches): + print("\n" + "=" * 80) + print(f"BATCH {b + 1}/{args.batches}") + print("=" * 80) + needed = args.batch * args.max_len + refill(needed) + if not token_buffer: + print("[inspect] dataset exhausted before first batch") + return + + ids = np.full((args.batch, args.max_len), tokenizer.pad_token_id, dtype=np.int64) + for i in range(args.batch): + length = min(args.max_len, len(token_buffer)) + if length > 0: + ids[i, :length] = token_buffer[:length] + token_buffer = token_buffer[length:] + labels = ids.copy() # full-text SFT + + # Batch-level diagnostics + counts = Counter(ids.ravel().tolist()) + top = counts.most_common(10) + print(f"[diag] rows consumed so far: {total_rows_consumed} (dropped={total_rows_dropped})") + print(f"[diag] ids shape={ids.shape} dtype={ids.dtype}") + print(f"[diag] labels == ids: {np.array_equal(ids, labels)}") + print(f"[diag] min/max id: {ids.min()} / {ids.max()}") 
+ print(f"[diag] mask_id {args.mask_id} occurrences in batch: {(ids == args.mask_id).sum()}") + print(f"[diag] eos_id {eos_id} occurrences in batch: {(ids == eos_id).sum()}") + print(f"[diag] pad_id {tokenizer.pad_token_id} occurrences in batch: {(ids == tokenizer.pad_token_id).sum()}") + print(f"[diag] top-10 tokens in batch:") + for tok_id, cnt in top: + try: + s = tokenizer.decode([tok_id]) + except Exception: + s = "?" + print(f" {tok_id:6d} x {cnt:7d} -> {s!r}") + + # Decode head and tail of a few rows + for r in range(min(args.rows, args.batch)): + print("\n" + "-" * 60) + print(f"row {r} — first 400 chars") + print("-" * 60) + head_text = tokenizer.decode(ids[r, :400], skip_special_tokens=False) + print(head_text) + print("-" * 60) + print(f"row {r} — last 400 chars") + print("-" * 60) + tail_text = tokenizer.decode(ids[r, -400:], skip_special_tokens=False) + print(tail_text) + # EOS locations for doc boundaries + eos_positions = np.where(ids[r] == eos_id)[0] + print(f"[row {r}] eos at positions: {eos_positions[:20].tolist()}{'...' if len(eos_positions) > 20 else ''} (count={len(eos_positions)})") + + +if __name__ == "__main__": + main() diff --git a/scripts/tpu_dmax_infer_checkpoint.py b/scripts/tpu_infer.py similarity index 79% rename from scripts/tpu_dmax_infer_checkpoint.py rename to scripts/tpu_infer.py index db162d1..7cd2263 100644 --- a/scripts/tpu_dmax_infer_checkpoint.py +++ b/scripts/tpu_infer.py @@ -364,6 +364,19 @@ def main() -> None: model_name = os.environ.get("MODEL_NAME", "Qwen/Qwen3-8B") prompt = os.environ.get("PROMPT", "Once upon a time") + # PROMPTS_FILE (one prompt per line) overrides PROMPT and runs generation + # for each prompt against the same restored model. Blank lines skipped; + # lines starting with # are treated as comments and also skipped. 
+ prompts_file = os.environ.get("PROMPTS_FILE", "").strip() + if prompts_file: + with open(prompts_file) as f: + prompts_list = [ + ln.rstrip("\n") for ln in f if ln.strip() and not ln.lstrip().startswith("#") + ] + if not prompts_list: + raise ValueError(f"PROMPTS_FILE={prompts_file} produced no usable prompts") + else: + prompts_list = [prompt] gen_length = int(os.environ.get("GEN_LENGTH", "32")) block_length = int(os.environ.get("BLOCK_LENGTH", "32")) steps = int(os.environ.get("STEPS", "8")) @@ -374,6 +387,8 @@ def main() -> None: fast_bucket_length = int(os.environ.get("FAST_BUCKET_LENGTH", "4096")) temperature = float(os.environ.get("TEMPERATURE", "0.0")) top_k = int(os.environ.get("TOP_K", "1")) + warmup_runs = max(0, int(os.environ.get("WARMUP_RUNS", "0"))) + measured_runs = max(1, int(os.environ.get("MEASURED_RUNS", "1"))) seed_env = os.environ.get("SEED") seed = int(seed_env) if seed_env else None dtype_name = os.environ.get("DTYPE", "bfloat16") @@ -470,9 +485,19 @@ def main() -> None: restored_epoch_step = int(np.asarray(multihost_utils.process_allgather(restored["epoch_step"], tiled=False))[0]) sync_all("infer-restore-complete") - input_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"] - input_arr = np.asarray([input_ids], dtype=np.int32) - input_arr = jax.device_put(input_arr, NamedSharding(mesh, P())) + # For multi-prompt runs, tokenize all upfront. Pad to the max so every + # call has the same input shape and only one compile happens. 
+ tokenized = [tokenizer(p, add_special_tokens=False)["input_ids"] for p in prompts_list] + max_prompt_len = max(len(ids) for ids in tokenized) + pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 + padded = [ + (ids + [pad_id] * (max_prompt_len - len(ids))) for ids in tokenized + ] + prompt_lengths = [len(ids) for ids in tokenized] # original (unpadded) length per prompt + input_arr_full = np.asarray(padded, dtype=np.int32) # [num_prompts, max_len] + input_ids = tokenized[0] # for the first-prompt log line + # Device-place the first prompt's input array for the compile-path unchanged below. + input_arr = jax.device_put(input_arr_full[:1], NamedSharding(mesh, P())) if proc == 0: print( @@ -513,22 +538,82 @@ def main() -> None: generate_kwargs["bucket_length"] = fast_bucket_length # kv_fast does not use bucket_length (it runs a single compiled while_loop # over blocks with a pre-allocated KV cache of size ``total_length``). - output = generate_fn(model, input_arr, **generate_kwargs) - output.generated_tokens.block_until_ready() - nfe = int(np.asarray(output.nfe)) - generated_all = np.asarray(multihost_utils.process_allgather(output.generated_tokens, tiled=False)) - sync_all("infer-generation-complete") + + # TEMPS env var lets one process run multiple temperatures sharing the + # same restored model. 
Example: TEMPS="0.0,0.3,0.5,0.7,1.0" + temps_env = os.environ.get("TEMPS", "").strip() + if temps_env: + temps_list = [float(t) for t in temps_env.split(",") if t.strip()] + else: + temps_list = [temperature] + + def _run_once(tag: str, temp_val: float, seed_val): + local_kwargs = dict(generate_kwargs) + local_kwargs["temperature"] = temp_val + local_kwargs["seed"] = seed_val + rt0 = time.time() + output = generate_fn(model, input_arr, **local_kwargs) + output.generated_tokens.block_until_ready() + dt = time.time() - rt0 + nfe_val = int(np.asarray(output.nfe)) + if proc == 0: + print(f"[bench] {tag} temp={temp_val} generate_seconds={dt:.2f} nfe={nfe_val}", flush=True) + sync_all(f"infer-{tag}") + return output, dt, nfe_val + + all_results = [] # list of (prompt_idx, prompt_str, temp_result_dict) + for p_idx, p_text in enumerate(prompts_list): + # Swap input_arr to this prompt's padded tokens (same shape → no recompile). + input_arr = jax.device_put(input_arr_full[p_idx:p_idx + 1], NamedSharding(mesh, P())) + for t_idx, temp_val in enumerate(temps_list): + for i in range(warmup_runs): + _run_once(f"warmup[p{p_idx}_temp={temp_val},{i + 1}/{warmup_runs}]", temp_val, seed) + + measured_times = [] + measured_nfe = [] + output = None + for i in range(measured_runs): + output, dt, nfe_val = _run_once( + f"measured[p{p_idx}_temp={temp_val},{i + 1}/{measured_runs}]", temp_val, seed, + ) + measured_times.append(dt) + measured_nfe.append(nfe_val) + + generated_all = np.asarray(multihost_utils.process_allgather(output.generated_tokens, tiled=False)) + sync_all(f"infer-generation-complete-p{p_idx}-temp{t_idx}") + if proc == 0: + generated = generated_all[0, 0].tolist() if generated_all.ndim == 3 else generated_all[0].tolist() + median_dt = sorted(measured_times)[len(measured_times) // 2] + tok_per_s = gen_length / median_dt if median_dt > 0 else float("nan") + all_results.append(( + p_idx, p_text, + { + "temperature": temp_val, + "seed": seed, + "nfe": measured_nfe[-1], + 
"gen_tokens": len(generated), + "median_s": median_dt, + "tok_per_s": tok_per_s, + "text": tokenizer.decode(generated, skip_special_tokens=True), + }, + )) if proc == 0: - generated = generated_all[0, 0].tolist() if generated_all.ndim == 3 else generated_all[0].tolist() - print("=" * 70) + print("\n" + "=" * 70) print(f"checkpoint_step={restored_step} epoch={restored_epoch} epoch_step={restored_epoch_step}") - print(f"prompt={prompt!r}") - print("generated:") - print(tokenizer.decode(generated, skip_special_tokens=True)) - print(f"nfe={nfe} generated_tokens={len(generated)}") - print(f"restore_plus_generate_seconds={time.time() - t0:.1f}") - print(f"generate_seconds={time.time() - tg:.1f}") + for p_idx, p_text, r in all_results: + print("=" * 70) + print(f"[prompt {p_idx + 1}/{len(prompts_list)}] {p_text!r}") + print("-" * 70) + print( + f"temp={r['temperature']} seed={r['seed']} " + f"nfe={r['nfe']} gen_tokens={r['gen_tokens']} " + f"median={r['median_s']:.2f}s tok_per_s={r['tok_per_s']:.1f}" + ) + print("generated:") + print(r["text"]) + print("=" * 70) + print(f"restore_plus_all_generate_seconds={time.time() - t0:.1f}") print("=" * 70, flush=True) diff --git a/scripts/tpu_v6e_smoke.py b/scripts/tpu_train.py similarity index 71% rename from scripts/tpu_v6e_smoke.py rename to scripts/tpu_train.py index ed139ed..19502ad 100644 --- a/scripts/tpu_v6e_smoke.py +++ b/scripts/tpu_train.py @@ -73,9 +73,13 @@ def load_dotenv(path: str = ".env") -> None: try: import orbax.checkpoint as ocp from orbax.checkpoint import options as ocp_options + from orbax.checkpoint import args as ocp_args + from orbax.checkpoint import type_handlers as ocp_type_handlers except ImportError: ocp = None ocp_options = None + ocp_args = None + ocp_type_handlers = None proc = jax.process_index() nproc = jax.process_count() @@ -83,10 +87,13 @@ def load_dotenv(path: str = ".env") -> None: nlocal = jax.local_device_count() print(f"[Worker {proc}/{nproc}] devices={ndev} local={nlocal} 
backend={jax.default_backend()}", flush=True) -# ── 2D mesh: FSDP (8) × TP (8) ────────────────────────────── -TP = 8 -DP = ndev // TP # 8 on v6e-64 -devices = mesh_utils.create_device_mesh((DP, TP)) +# ── 2D mesh: FSDP × TP (TP overridable via env, default 8) ── +TP = int(os.environ.get("TP", "8")) +DP = ndev // TP +try: + devices = mesh_utils.create_device_mesh((DP, TP)) +except NotImplementedError: + devices = mesh_utils.create_device_mesh((DP, TP), allow_split_physical_axes=True) mesh = Mesh(devices, axis_names=("fsdp", "tp")) if proc == 0: print(f"[Worker {proc}] 2D Mesh: fsdp={DP} × tp={TP}", flush=True) @@ -141,6 +148,12 @@ def gs_join(base: str, *parts: str) -> str: MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen3-8B") DATASET = os.environ.get("DATASET", "tinystories").lower() SYNTHETIC_DATA = DATASET in {"synthetic", "random", "dummy"} +# SFT datasets use per-example padded batches. Plain pretraining datasets +# concatenate tokens into a single stream. +SFT_DATASET = DATASET in {"openthoughts", "open-thoughts", "open-thoughts-114k"} +# Default: train on the whole chat-templated string (treat it like tinystories). +# Set SFT_TRAIN_ON_ANSWERS_ONLY=1 to revert to prompt-masked supervision. 
+SFT_TRAIN_ON_ANSWERS_ONLY = os.environ.get("SFT_TRAIN_ON_ANSWERS_ONLY", "0").strip().lower() in {"1", "true", "yes", "on"} DATASET_PATH = os.environ.get("DATASET_PATH") MODEL_SLUG = path_safe_name(MODEL_NAME.rstrip("/").rsplit("/", 1)[-1]) or "model" RUN_NAME = os.environ.get("RUN_NAME") or os.environ.get("WANDB_RUN_NAME") or f"{MODEL_SLUG}-{DATASET}-{int(time.time())}" @@ -201,6 +214,12 @@ def gs_join(base: str, *parts: str) -> str: DMAX_NOISE_LOW = float(os.environ.get("DMAX_NOISE_LOW", "0.75")) DMAX_NOISE_HIGH = float(os.environ.get("DMAX_NOISE_HIGH", "0.75")) DMAX_BLOCK_SIZE = int(os.environ.get("DMAX_BLOCK_SIZE", "32")) +# ── XProf / JAX profiler ───────────────────────────────────────────────────── +XPROF_ENABLE = env_flag("XPROF_ENABLE", False) +XPROF_DIR = os.environ.get("XPROF_DIR", "/tmp/xprof/run") +XPROF_START_STEP = int(os.environ.get("XPROF_START_STEP", "4")) +XPROF_STOP_STEP = int(os.environ.get("XPROF_STOP_STEP", "7")) +_xprof_active = False if WANDB_LOG and WANDB_MODE != "offline" and proc == 0: has_wandb_auth = bool(os.environ.get("WANDB_API_KEY")) or Path.home().joinpath(".netrc").exists() @@ -308,8 +327,25 @@ def estimate_model_params(cfg) -> int: return int(embedding_params + lm_head_params + layers * (attn_params + mlp_params + norm_params)) EST_PARAMS = estimate_model_params(config) -TRAIN_FLOPS_PER_TOKEN_DENSE = 6 * EST_PARAMS -TRAIN_FLOPS_PER_TOKEN_ATTN = 12 * int(config.num_hidden_layers) * MAX_LEN * int(config.hidden_size) +# DMax stacks [noised; clean] so each sample's compiled fwd+bwd runs at seq=2·L, +# and when on_policy_ratio>0 the compiled graph also pays a stop_gradient rollout +# fwd at seq=2·L (the per-sample on_policy_flag only gates the argmax, not execution). 
+_seq_clean = MAX_LEN +_seq_train_fwd = 2 * MAX_LEN if DMAX_ENABLE else MAX_LEN +_seq_rollout_fwd = 2 * MAX_LEN if (DMAX_ENABLE and DMAX_ON_POLICY_RATIO > 0) else 0 +_layers = int(config.num_hidden_layers) +_hidden = int(config.hidden_size) +# 6N/token fwd+bwd + 2N/token rollout fwd (dense equivalent). +_per_sample_dense = 6 * EST_PARAMS * _seq_train_fwd + 2 * EST_PARAMS * _seq_rollout_fwd +# 12·L·seq²·H fwd+bwd + 4·L·seq²·H rollout fwd (dense attention reference, even +# though splash may skip masked blocks — keeps the denominator comparable to the +# non-DMax baseline). +_per_sample_attn = ( + 12 * _layers * _seq_train_fwd * _seq_train_fwd * _hidden + + 4 * _layers * _seq_rollout_fwd * _seq_rollout_fwd * _hidden +) +TRAIN_FLOPS_PER_TOKEN_DENSE = _per_sample_dense // _seq_clean +TRAIN_FLOPS_PER_TOKEN_ATTN = _per_sample_attn // _seq_clean TRAIN_FLOPS_PER_TOKEN_TOTAL = TRAIN_FLOPS_PER_TOKEN_DENSE + TRAIN_FLOPS_PER_TOKEN_ATTN PEAK_FLOPS = PEAK_TFLOPS_PER_CHIP * 1e12 * ndev @@ -321,6 +357,8 @@ def estimate_model_params(cfg) -> int: "train_flops_per_token_attention": TRAIN_FLOPS_PER_TOKEN_ATTN, "train_flops_per_token_total": TRAIN_FLOPS_PER_TOKEN_TOTAL, "peak_flops": PEAK_FLOPS, + "train_seq_train_fwd": _seq_train_fwd, + "train_seq_rollout_fwd": _seq_rollout_fwd, }, allow_val_change=True, ) @@ -328,11 +366,14 @@ def estimate_model_params(cfg) -> int: if proc == 0: print(f"[Worker {proc}] {MODEL_NAME}: {config.num_hidden_layers}L h={config.hidden_size} " f"V={config.vocab_size}", flush=True) + _dmax_tag = ( + f" dmax(seq_fwd={_seq_train_fwd},rollout_fwd={_seq_rollout_fwd})" if DMAX_ENABLE else "" + ) print( f"[Worker {proc}] estimated_params={EST_PARAMS/1e9:.2f}B " f"flops/token dense={TRAIN_FLOPS_PER_TOKEN_DENSE/1e9:.1f}G " f"attn={TRAIN_FLOPS_PER_TOKEN_ATTN/1e9:.1f}G " - f"peak={PEAK_FLOPS/1e15:.2f} PFLOP/s", + f"peak={PEAK_FLOPS/1e15:.2f} PFLOP/s{_dmax_tag}", flush=True, ) if RUN_FULL_EPOCHS: @@ -360,8 +401,102 @@ def _sharded_flash_attn(q, k, v, sm_scale): if proc 
== 0: print(f"[Worker {proc}] Pallas flash installed (blocks=512)", flush=True) +# ── Splash attention for DMax block-diffusion mask ───────── +# The dense-mask fallback in _attention is ~1.5-2s/step for Qwen3-8B at seq=8192; +# splash_attention_kernel runs block-sparse flash over the block-diffusion pattern. +# DISABLE_SPLASH_ATTN=1 skips splash install and falls back to dense dot_product_attention +# with the full block-diffusion mask — slower but a useful isolation for +# numerics debugging. +if DMAX_ENABLE and not env_flag("DISABLE_SPLASH_ATTN", False): + try: + from jax.experimental.pallas.ops.tpu.splash_attention import ( + splash_attention_mask as _sm, + splash_attention_kernel as _sk, + ) + + def _block_diffusion_mask_numpy(seq_len_l: int, block_size: int) -> np.ndarray: + two_l = seq_len_l * 2 + q_idx = np.arange(two_l)[:, None] + kv_idx = np.arange(two_l)[None, :] + x0_q = q_idx >= seq_len_l + x0_kv = kv_idx >= seq_len_l + block_q = np.where(x0_q, (q_idx - seq_len_l) // block_size, q_idx // block_size) + block_kv = np.where(x0_kv, (kv_idx - seq_len_l) // block_size, kv_idx // block_size) + bd = (block_q == block_kv) & (x0_q == x0_kv) + off = (block_q > block_kv) & x0_kv & (~x0_q) + bc = (block_q >= block_kv) & x0_kv & x0_q + return bd | off | bc + + _num_heads_total = int(config.num_attention_heads) + _heads_per_tp = _num_heads_total // TP + _mask_np = _block_diffusion_mask_numpy(MAX_LEN, DMAX_BLOCK_SIZE).astype(np.bool_) + _splash_mask = _sm.MultiHeadMask(masks=[_sm.NumpyMask(_mask_np)] * _heads_per_tp) + _splash_bs = int(os.environ.get("SPLASH_BLOCK", "512")) + _splash_fused_bwd = env_flag("SPLASH_FUSED_BWD", True) + _bs_kwargs = dict( + block_q=_splash_bs, block_kv=_splash_bs, block_kv_compute=_splash_bs, + block_q_dkv=_splash_bs, block_kv_dkv=_splash_bs, block_kv_dkv_compute=_splash_bs, + use_fused_bwd_kernel=_splash_fused_bwd, + ) + if not _splash_fused_bwd: + _bs_kwargs.update(block_q_dq=_splash_bs, block_kv_dq=_splash_bs) + _splash_block_sizes = 
_sk.BlockSizes(**_bs_kwargs) + _splash_fn = _sk.make_splash_mha_single_device(mask=_splash_mask, block_sizes=_splash_block_sizes) + + def _splash_per_shard(q, k, v, sm_scale): + # q/k/v per shard: [B, H_local, T, D]. Splash expects [H, T, D] per call. + q_scaled = (q * sm_scale).astype(q.dtype) + return jax.vmap(_splash_fn)(q_scaled, k, v) + + def _sharded_masked_flash(q, k, v, sm_scale): + return shard_map( + lambda q, k, v: _splash_per_shard(q, k, v, sm_scale), + mesh=mesh, + in_specs=(P("fsdp", "tp", None, None),) * 3, + out_specs=P("fsdp", "tp", None, None), + check_rep=False, + )(q, k, v) + + dllm_models._MASKED_FLASH_ATTN_FN = _sharded_masked_flash + if proc == 0: + print( + f"[Worker {proc}] Splash attention installed for DMax block-diffusion " + f"(mask {_mask_np.shape}, heads_per_tp={_heads_per_tp}, " + f"block={_splash_bs}, fused_bwd={_splash_fused_bwd})", + flush=True, + ) + except Exception as _exc: + if proc == 0: + print(f"[Worker {proc}] Splash attention setup failed: {_exc}", flush=True) + # ── Per-layer remat ──────────────────────────────────────── -remat_policy = jax.checkpoint_policies.nothing_saveable +# REMAT_POLICY selects what to save between fwd and bwd. Saving more cuts the +# ~7.5% recompute cost but increases HBM — risky near the B=64 ceiling. 
+# nothing_saveable (default) recompute everything +# gate_up save the MLP gate*up product (biggest matmuls) +# qkv_gate_up also save q/k/v post-RoPE +# dots_saveable jax preset: save all dot outputs +# everything_saveable jax preset: save everything (HBM-hungry) +_REMAT_POLICY_NAME = os.environ.get("REMAT_POLICY", "nothing_saveable") +if _REMAT_POLICY_NAME == "nothing_saveable": + remat_policy = jax.checkpoint_policies.nothing_saveable +elif _REMAT_POLICY_NAME == "gate_up": + remat_policy = jax.checkpoint_policies.save_only_these_names("gate_up") +elif _REMAT_POLICY_NAME == "qkv_gate_up": + remat_policy = jax.checkpoint_policies.save_only_these_names( + "q", "k", "v", "gate_up", + ) +elif _REMAT_POLICY_NAME == "dots_saveable": + remat_policy = jax.checkpoint_policies.dots_saveable +elif _REMAT_POLICY_NAME == "everything_saveable": + remat_policy = jax.checkpoint_policies.everything_saveable +else: + raise ValueError( + f"Unknown REMAT_POLICY={_REMAT_POLICY_NAME!r}; expected one of " + "nothing_saveable, gate_up, qkv_gate_up, dots_saveable, everything_saveable." + ) +if proc == 0: + print(f"[Worker {proc}] remat_policy={_REMAT_POLICY_NAME}", flush=True) def _remat_hidden_for_heads(self, input_ids=None, *, inputs_embeds=None, attention_mask=None, position_ids=None): if inputs_embeds is not None: @@ -421,6 +556,13 @@ def create_block_diffusion_attention_mask(seq_len: int, block_size: int) -> jnp. n_loaded = 0 print(f"[Worker {proc}] LOAD_PRETRAINED=0; using random initialization", flush=True) +# MASK_EMBED_INIT runs POST-shard (see below, right after sharding) so that +# the .at[].set() peak holds sharded-per-chip copies (~75MB), not replicated +# copies (~2.4GB × 2 = 4.8GB) which OOM v5e-64's 16GB/chip HBM. 
+_mask_init_mode = os.environ.get("MASK_EMBED_INIT", "mean").strip().lower() +if _mask_init_mode not in {"mean", "mean-embed", "average", "none", "skip", "off", ""}: + raise ValueError(f"Unknown MASK_EMBED_INIT={_mask_init_mode!r}") + # ── Shard to 2D mesh ─────────────────────────────────────── print(f"[Worker {proc}] Sharding to 2D TPU mesh...", flush=True) t1 = time.time() @@ -471,17 +613,59 @@ def shard_tree(tree, label: str): gc.collect() print(f"[Worker {proc}] Sharding done in {time.time()-t1:.1f}s (total: {time.time()-t0:.1f}s)", flush=True) +# ── Mask-token embedding warm start (POST-shard) ─────────── +# DMax reuses ``mask_id`` (default vocab_size-1, optionally <|fim_pad|>=151662, +# etc.) as the MASK token. Untrained vocab rows collapse the noised forward +# (observed ~step 150 drift to uniform predictions). Seed the mask row with +# the MEAN of other input/output embedding rows for a sensible "average +# semantic" warm start. Runs POST-shard so the ``.at[].set()`` doesn't hold +# two full ~2.4GB replicas in HBM simultaneously. +if _mask_init_mode in {"mean", "mean-embed", "average"}: + _mask_id_local = int(os.environ.get("MASK_TOKEN_ID", str(int(config.vocab_size) - 1))) + # On-device compute to respect multi-host sharding — ``jax.device_get`` on + # a cross-host sharded jax.Array raises. XLA handles the sharded sum/slice. 
+ emb_var = model.embed_tokens.embedding + _emb = emb_var.value + _sum_rows = jnp.sum(_emb.astype(jnp.float32), axis=0) + _mean_row = (_sum_rows - _emb[_mask_id_local].astype(jnp.float32)) / (_emb.shape[0] - 1) + emb_var.value = _emb.at[_mask_id_local].set(_mean_row.astype(_emb.dtype)) + del _emb, _sum_rows, _mean_row + gc.collect() + lm_head = getattr(model, "lm_head", None) + if lm_head is not None and hasattr(lm_head, "kernel"): + kvar = lm_head.kernel + _k = kvar.value + _sum_cols = jnp.sum(_k.astype(jnp.float32), axis=1) + _mean_col = (_sum_cols - _k[:, _mask_id_local].astype(jnp.float32)) / (_k.shape[1] - 1) + kvar.value = _k.at[:, _mask_id_local].set(_mean_col.astype(_k.dtype)) + del _k, _sum_cols, _mean_col + gc.collect() + if proc == 0: + print( + f"[Worker {proc}] MASK_EMBED_INIT=mean (post-shard, on-device): seeded row/col {_mask_id_local}", + flush=True, + ) + # ── Optimizer: AdamW (safer than Adafactor for pretrained MDLM) ── +# LR_SCHEDULE=cosine with LR_DECAY_STEPS=N decays from PEAK_LR to alpha*PEAK_LR +# (alpha=LR_DECAY_ALPHA, default 0.1) over N steps after warmup. Default schedule +# is constant-after-warmup (back-compat with prior runs). 
+_lr_schedule_kind = os.environ.get("LR_SCHEDULE", "constant").strip().lower() +_lr_decay_steps = int(os.environ.get("LR_DECAY_STEPS", "0")) +_lr_decay_alpha = float(os.environ.get("LR_DECAY_ALPHA", "0.1")) +if _lr_schedule_kind == "cosine" and _lr_decay_steps > 0: + _post_warmup = optax.cosine_decay_schedule( + init_value=PEAK_LR, decay_steps=_lr_decay_steps, alpha=_lr_decay_alpha, + ) +else: + _post_warmup = optax.constant_schedule(PEAK_LR) if WARMUP_STEPS > 0: lr_schedule = optax.join_schedules( - schedules=[ - optax.linear_schedule(0.0, PEAK_LR, WARMUP_STEPS), - optax.constant_schedule(PEAK_LR), - ], + schedules=[optax.linear_schedule(0.0, PEAK_LR, WARMUP_STEPS), _post_warmup], boundaries=[WARMUP_STEPS], ) else: - lr_schedule = optax.constant_schedule(PEAK_LR) + lr_schedule = _post_warmup if OPTIMIZER == "adafactor": # Factored optimizer state (~4x less HBM than AdamW) — required to fit # 128k context on Qwen3-8B per chip. Loss may climb back after ~step 60 @@ -509,7 +693,11 @@ def shard_tree(tree, label: str): # ── MDLM config + tokenizer + data ───────────────────────── mdlm_cfg = MDLMConfig(output_dir="/tmp/q3-8b-smoke", max_steps=NUM_STEPS, learning_rate=PEAK_LR) alpha_sched = LinearAlphaScheduler() -mask_id = config.vocab_size - 1 +# MASK_TOKEN_ID overrides the default (vocab_size-1, which is a truly-empty +# reserved slot on Qwen3 with an untrained embedding). Set to a pretrained +# Qwen3 special token like 151662 (<|fim_pad|>) to inherit useful "fill in +# the missing piece" semantics for the denoising task. 
+mask_id = int(os.environ.get("MASK_TOKEN_ID", str(config.vocab_size - 1))) print(f"[Worker {proc}] Loading tokenizer + dataset stream...", flush=True) tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME) @@ -527,6 +715,8 @@ def make_dataset_iter(): ds = load_dataset("parquet", data_files=parquet_files, split="train", streaming=True) elif DATASET == "wikipedia": ds = load_dataset("wikimedia/wikipedia", "20231101.en", split="train", streaming=True) + elif SFT_DATASET: + ds = load_dataset("open-thoughts/OpenThoughts-114k", "default", split="train", streaming=True) else: ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True) return iter(ds) @@ -542,20 +732,130 @@ def refill_buffer(needed): exhausted = False while len(token_buffer) < needed: try: - text = next(ds_iter)["text"] + row = next(ds_iter) except StopIteration: exhausted = True break + if SFT_DATASET: + ids = _sft_row_to_tokens(row) + if ids is None: + continue + token_buffer.extend(ids) + else: + token_buffer.extend(tokenizer.encode(row["text"], add_special_tokens=False)) # Append EOS between docs so the model sees document boundaries. - token_buffer.extend(tokenizer.encode(text, add_special_tokens=False)) token_buffer.append(eos_id) return not exhausted + +# SFT pipeline: +# SFT_TRAIN_ON_ANSWERS_ONLY=0 (default, full-text): tokens go through +# ``_sft_row_to_tokens`` and are concat-packed by ``refill_buffer`` just +# like tinystories (no row dropping, no truncation — every chat-templated +# token is trained on). +# SFT_TRAIN_ON_ANSWERS_ONLY=1 (answer-only): per-example padded batch via +# ``_sft_batch`` + ``_tokenize_sft_row`` because packing can't preserve +# the prompt/answer boundary needed for the -100 label mask. 
+_ROLE_MAP = {"human": "user", "gpt": "assistant", "system": "system", "user": "user", "assistant": "assistant"} +_sft_dropped = 0 +_sft_kept = 0 + + +def _sft_row_to_messages(row): + convs = row.get("conversations") + if not convs: + return None + messages = [] + system = row.get("system") + if system: + messages.append({"role": "system", "content": system}) + for c in convs: + if isinstance(c, dict) and "from" in c: + messages.append({"role": _ROLE_MAP.get(c["from"], c["from"]), "content": c["value"]}) + else: + messages.append(c) + return messages + + +def _sft_row_to_tokens(row): + """Tokenize an SFT row via chat template. Used by the concat-pack pipeline. + Returns token ids (no length filter — packing handles arbitrary lengths).""" + messages = _sft_row_to_messages(row) + if messages is None: + return None + out = tokenizer.apply_chat_template( + messages, tokenize=True, add_generation_prompt=False, return_dict=True, + ) + return out["input_ids"] + + +def _tokenize_sft_row(row): + """Answer-only path: returns (input_ids, labels) of length ≤ MAX_LEN or + None to skip. Only called when SFT_TRAIN_ON_ANSWERS_ONLY=1.""" + messages = _sft_row_to_messages(row) + if messages is None: + return None + full_out = tokenizer.apply_chat_template( + messages, tokenize=True, add_generation_prompt=False, return_dict=True, + ) + full_ids = full_out["input_ids"] + if len(full_ids) > MAX_LEN: + return None + # Prompt = everything except the final assistant turn; assistant supervision + # starts at prompt_len. Works for 1-user-1-assistant (OpenThoughts shape). 
+ if len(messages) >= 2 and messages[-1]["role"] == "assistant": + prompt_msgs = messages[:-1] + else: + prompt_msgs = messages + prompt_out = tokenizer.apply_chat_template( + prompt_msgs, tokenize=True, add_generation_prompt=True, return_dict=True, + ) + prompt_len = min(len(prompt_out["input_ids"]), len(full_ids)) + labels = list(full_ids) + for i in range(prompt_len): + labels[i] = -100 + return full_ids, labels + + +def _sft_batch(): + global _sft_dropped, _sft_kept + ids = np.full((GLOBAL_BATCH, MAX_LEN), tokenizer.pad_token_id, dtype=np.int64) + labels = np.full((GLOBAL_BATCH, MAX_LEN), -100, dtype=np.int64) + filled = 0 + while filled < GLOBAL_BATCH: + try: + row = next(ds_iter) + except StopIteration: + if filled == 0: + return None + break + tok = _tokenize_sft_row(row) + if tok is None: + _sft_dropped += 1 + continue + row_ids, row_labels = tok + L = len(row_ids) + ids[filled, :L] = row_ids + labels[filled, :L] = row_labels + filled += 1 + _sft_kept += 1 + # Pad unfilled rows (end of stream) with a duplicate of row 0 to keep shape. + for i in range(filled, GLOBAL_BATCH): + ids[i] = ids[0] + labels[i] = labels[0] + return {"input_ids": ids, "labels": labels} + + def get_batch(): global token_buffer if SYNTHETIC_DATA: ids = synthetic_rng.integers(0, max(2, mask_id), size=(GLOBAL_BATCH, MAX_LEN), dtype=np.int64) return {"input_ids": ids, "labels": ids.copy()} + # Answer-only needs per-example padded batches (to preserve the -100 mask). + # Full-text SFT (default) concat-packs through ``refill_buffer`` below, + # reaching full MAX_LEN utilization with no row dropping. + if SFT_DATASET and SFT_TRAIN_ON_ANSWERS_ONLY: + return _sft_batch() refill_buffer(GLOBAL_BATCH * MAX_LEN) if not token_buffer: return None @@ -576,7 +876,8 @@ def loss_fn(mdl, batch): # DMax on-policy rollout: for examples flagged by on_policy_flag, replace # MASK tokens in the noisy half with the model's own greedy predictions # (stop_gradient) before the supervised forward. 
- if "on_policy_flag" in batch and batch["on_policy_flag"] is not None: + # Skip entirely when DMAX_ON_POLICY_RATIO=0 — saves a full 2L forward per step. + if DMAX_ON_POLICY_RATIO > 0 and "on_policy_flag" in batch and batch["on_policy_flag"] is not None: seq_len = batch["input_ids"].shape[1] flag = batch["on_policy_flag"][:, None] # [B, 1] rollout_logits = jax.lax.stop_gradient( @@ -931,7 +1232,7 @@ def save_training_checkpoint(global_step: int, epoch: int, epoch_step: int, *, f wandb.log({"checkpoint/global_step": global_step}, step=global_step) def run_training_step(global_step: int, epoch: int, epoch_step: int, total_steps: int | None): - global rng, total_tokens, last_epoch, last_epoch_step + global rng, total_tokens, last_epoch, last_epoch_step, _xprof_active ts = time.time() data_t0 = time.time() debug_step = global_step == 1 @@ -994,6 +1295,14 @@ def run_training_step(global_step: int, epoch: int, epoch_step: int, total_steps if debug_step: print(f"[Worker {proc}] step {global_step}: train_step start", flush=True) + if XPROF_ENABLE and global_step == XPROF_START_STEP and not _xprof_active: + import jax.profiler as _jprof + Path(XPROF_DIR).mkdir(parents=True, exist_ok=True) + _jprof.start_trace(XPROF_DIR) + _xprof_active = True + if proc == 0: + print(f"[Worker {proc}] xprof start (step {global_step}) -> {XPROF_DIR}", flush=True) + train_t0 = time.time() metrics = train_step(model, optimizer, batch) loss_val = float(metrics["loss"]) @@ -1046,6 +1355,12 @@ def run_training_step(global_step: int, epoch: int, epoch_step: int, total_steps step=global_step, ) save_training_checkpoint(global_step, epoch, epoch_step) + if XPROF_ENABLE and _xprof_active and global_step >= XPROF_STOP_STEP: + import jax.profiler as _jprof + _jprof.stop_trace() + _xprof_active = False + if proc == 0: + print(f"[Worker {proc}] xprof stop (step {global_step}) -> {XPROF_DIR}", flush=True) return True # ── Resume from checkpoint ──────────────────────────────── @@ -1095,22 +1410,64 @@ def 
run_training_step(global_step: int, epoch: int, epoch_step: int, total_steps print(f"[Worker {proc}] Restoring checkpoint: {ckpt_path}", flush=True) _resume_rng_placeholder = np.asarray(jax.random.key_data(jax.random.key(0))) + _resume_restore_optimizer = env_flag("RESUME_RESTORE_OPTIMIZER", True) restore_target = { "model": nnx.state(model), - "optimizer": nnx.state(optimizer), "global_step": np.asarray(0, dtype=np.int64), "epoch": np.asarray(0, dtype=np.int32), "epoch_step": np.asarray(0, dtype=np.int64), "total_tokens": np.asarray(0, dtype=np.int64), "rng": _resume_rng_placeholder, } + if _resume_restore_optimizer: + restore_target["optimizer"] = nnx.state(optimizer) # Use orbax directly for GCS restore (flax's restore_checkpoint can't list GCS dirs) resume_ckpt = make_orbax_checkpointer() if proc == 0: print(f"[Worker {proc}] Starting orbax restore...", flush=True) with orbax_set_mesh_context_patch(): - restored = resume_ckpt.restore(ckpt_path, target=restore_target) + if not _resume_restore_optimizer and ocp_args is not None and ocp_type_handlers is not None: + # Model-only restore: use PyTreeRestore(partial_restore=True) so the + # checkpoint's optimizer state (different shape / optimizer family) + # is simply ignored. 
+ def _to_restore_args(x): + if isinstance(x, jax.Array): + return ocp_type_handlers.ArrayRestoreArgs( + restore_type=jax.Array, + dtype=x.dtype, + sharding=x.sharding, + global_shape=x.shape, + ) + if isinstance(x, (jnp.ndarray, np.ndarray)): + arr = np.asarray(x) + return ocp_type_handlers.RestoreArgs(restore_type=np.ndarray, dtype=arr.dtype) + return ocp_type_handlers.RestoreArgs() + + restore_args_tree = jax.tree_util.tree_map( + _to_restore_args, + restore_target, + is_leaf=lambda x: isinstance(x, (jax.Array, jnp.ndarray, np.ndarray)), + ) + pytree_ckpt = ( + ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + if ocp_options is None + else ocp.Checkpointer( + ocp.PyTreeCheckpointHandler( + multiprocessing_options=ocp_options.MultiprocessingOptions(primary_host=0) + ) + ) + ) + restored = pytree_ckpt.restore( + ckpt_path, + args=ocp_args.PyTreeRestore( + item=restore_target, + restore_args=restore_args_tree, + partial_restore=True, + ), + ) + else: + restored = resume_ckpt.restore(ckpt_path, target=restore_target) if proc == 0: print(f"[Worker {proc}] Orbax restore done, re-sharding...", flush=True) @@ -1184,16 +1541,45 @@ def reshard_tree(tree, label: str): restored_model_state = reshard_tree(restored["model"], "resume-model") model = nnx.merge(gdef, restored_model_state) - opt_gdef, _ = nnx.split(optimizer) - restored_opt_state = reshard_tree(restored["optimizer"], "resume-optimizer") - optimizer = nnx.merge(opt_gdef, restored_opt_state) - model = optimizer.model # rebind after merge - - resumed_step = int(restored["global_step"]) - resumed_epoch = int(restored["epoch"]) - resumed_epoch_step = int(restored["epoch_step"]) - total_tokens = int(restored["total_tokens"]) - rng = jax.random.wrap_key_data(restored["rng"]) + if _resume_restore_optimizer: + opt_gdef, _ = nnx.split(optimizer) + restored_opt_state = reshard_tree(restored["optimizer"], "resume-optimizer") + optimizer = nnx.merge(opt_gdef, restored_opt_state) + model = optimizer.model # rebind after merge 
+ else: + if proc == 0: + print( + f"[Worker {proc}] RESUME_RESTORE_OPTIMIZER=0: keeping fresh " + f"{OPTIMIZER} optimizer state (ignoring checkpoint's optimizer)", + flush=True, + ) + opt_gdef, opt_state = nnx.split(optimizer) + optimizer = nnx.merge(opt_gdef, opt_state) + model = optimizer.model # rebind + + # RESUME_RESET_STEP=1 zeroes the step/epoch counters after loading weights. + # Use when continuing onto a different dataset (e.g. pretrained checkpoint → + # SFT) so the training budget (NUM_STEPS / NUM_EPOCHS) starts from scratch. + _resume_reset_step = env_flag("RESUME_RESET_STEP", False) + if _resume_reset_step: + if proc == 0: + print( + f"[Worker {proc}] RESUME_RESET_STEP=1: ignoring checkpoint step " + f"{int(restored['global_step'])}/epoch {int(restored['epoch'])} " + "and starting training from step 0", + flush=True, + ) + resumed_step = 0 + resumed_epoch = 0 + resumed_epoch_step = 0 + total_tokens = 0 + rng = jax.random.key(42) + else: + resumed_step = int(restored["global_step"]) + resumed_epoch = int(restored["epoch"]) + resumed_epoch_step = int(restored["epoch_step"]) + total_tokens = int(restored["total_tokens"]) + rng = jax.random.wrap_key_data(restored["rng"]) del restored, restore_target, _resume_rng_placeholder gc.collect() @@ -1239,10 +1625,18 @@ def reshard_tree(tree, label: str): if proc == 0: print(f"[Worker {proc}] Finished epoch {epoch}/{NUM_EPOCHS} after {epoch_step - 1} steps", flush=True) else: + _PROFILE_DIR = os.environ.get("JAX_PROFILE_DIR") + _PROFILE_START = int(os.environ.get("JAX_PROFILE_START_STEP", "0")) + _PROFILE_STEPS = int(os.environ.get("JAX_PROFILE_STEPS", "0")) for step in range(resumed_step + 1, NUM_STEPS + 1): global_step = step + if _PROFILE_DIR and step == _PROFILE_START: + jax.profiler.start_trace(_PROFILE_DIR) if not run_training_step(global_step, epoch=0, epoch_step=step, total_steps=NUM_STEPS): break + if _PROFILE_DIR and _PROFILE_STEPS and step == _PROFILE_START + _PROFILE_STEPS - 1: + 
jax.block_until_ready(step) + jax.profiler.stop_trace() tt = time.time() - t_start if CHECKPOINT_ON_FINISH: