205 changes: 177 additions & 28 deletions .env.example
@@ -1,42 +1,191 @@
# =============================================================================
# dllm-jax environment configuration
# Copy this file to .env and fill in your values.
# dllm-jax environment configuration — full reference
#
# Copy to `.env` (training) or `.env.deploy` (deploy/run scripts) and fill in
# the values you need. EVERY variable below has a sensible default in the
# code; uncomment only what you want to override.
#
# Sections:
# 1. Model & tokenizer
# 2. Dataset
# 3. Training schedule
# 4. Learning-rate schedule
# 5. Optimizer
# 6. Sharding (TP / FSDP)
# 7. DMax (block-diffusion masked LM)
# 8. SFT (chat-template fine-tuning)
# 9. Performance knobs (splash, remat, mask-embed init)
# 10. GCS checkpointing
# 11. Resume from checkpoint
# 12. HuggingFace Hub (auth + checkpoint mirror)
# 13. Weights & Biases
# 14. Profiling (xprof / jax profile)
# 15. Inference (scripts/tpu_dmax_infer_checkpoint.py)
# =============================================================================

# ---- GCS Checkpoints --------------------------------------------------------
# Option A: Regional bucket (auto-detected from TPU zone)
# The script builds the bucket name as: gs://${CHECKPOINT_BUCKET_PREFIX}-${region}
# e.g. gs://dllm-jax-us-east1 or gs://dllm-jax-europe-west4
CHECKPOINT_BUCKET_PREFIX=dllm-jax

# Option B: Fixed (non-regional) bucket name — overrides auto-detection
# CHECKPOINT_BUCKET=my-bucket-name
# ── 1. Model & tokenizer ────────────────────────────────────────────────────
MODEL_NAME=Qwen/Qwen3-8B
# DTYPE=bfloat16 # compute dtype: bfloat16 | float16 | float32
# MASK_TOKEN_ID= # default: vocab_size-1 (e.g. 151935 for Qwen3)
# try 151662 (<|fim_pad|>) for a pretrained warm start
# EOS_TOKEN_ID= # default: tokenizer's eos_token_id

# Option C: Full checkpoint directory — overrides both A and B
# CHECKPOINT_DIR=gs://my-bucket/checkpoints/my-run

# Checkpoint frequency and retention
CHECKPOINT_STEPS=500
CHECKPOINT_KEEP=2
# ── 2. Dataset ──────────────────────────────────────────────────────────────
DATASET=tinystories # tinystories | wikipedia | openthoughts | synthetic | parquet
# DATASET_PATH= # required when DATASET=parquet (file or dir of *.parquet)

# ---- Weights & Biases -------------------------------------------------------
# Enable W&B logging (set to 1 to enable)
WANDB_LOG=0
WANDB_PROJECT=dllm-jax
# WANDB_API_KEY=your-wandb-api-key-here
# Get your API key: https://wandb.ai/authorize
# Or run `wandb login` on each TPU worker instead.

# ---- Model & Training -------------------------------------------------------
MODEL_NAME=Qwen/Qwen3-8B
DATASET=tinystories
# ── 3. Training schedule ────────────────────────────────────────────────────
MAX_LEN=16384
GLOBAL_BATCH=8
NUM_STEPS=20
NUM_STEPS=20 # set to 0 to use NUM_EPOCHS instead
NUM_EPOCHS=0
WARMUP_STEPS=5
PEAK_LR=1e-4
# RUN_NAME= # default: ${MODEL_SLUG}-${DATASET}-${unix_ts}
# LOGGING_STEPS=1
# LOAD_PRETRAINED=1 # 0 = random init (debugging only)


# ── 4. Learning-rate schedule (post-warmup) ─────────────────────────────────
# LR_SCHEDULE=constant # constant | cosine
# LR_DECAY_STEPS=0 # cosine: total decay steps after warmup
# LR_DECAY_ALPHA=0.1 # cosine: final LR = PEAK_LR * alpha
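# Example (a sketch with hypothetical values, not a recommended recipe):
# with PEAK_LR=1e-4, the settings below would decay the LR along a cosine
# curve over 1000 post-warmup steps to a final LR of 1e-4 * 0.1 = 1e-5.
# LR_SCHEDULE=cosine
# LR_DECAY_STEPS=1000
# LR_DECAY_ALPHA=0.1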


# ── 5. Optimizer ────────────────────────────────────────────────────────────
OPTIMIZER=adamw # adamw (recommended) | adafactor
# NOTE: Adafactor has been observed to diverge
# on Qwen3-8B MDLM/DMax at around step 60.
# Default to AdamW for real training.


# ── 6. Sharding ─────────────────────────────────────────────────────────────
# TP determines the tensor-parallel axis size; FSDP is derived as
# fsdp = jax.device_count() // TP
# v4-32 / v6e-256: TP=8 is fine.
# v5e-64: prefer TP=2. TP=8 forces FSDP=8 and bloats per-chip optimizer state ~4x relative to TP=2.
TP=8
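# Worked example of the formula above, assuming jax.device_count() reports
# 64 devices on a v5e-64 slice:
#   TP=8 -> fsdp = 64 // 8 = 8
#   TP=2 -> fsdp = 64 // 2 = 32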


# ── 7. DMax (block-diffusion masked LM) ─────────────────────────────────────
# DMAX_ENABLE=0 # 1 = train block-diffusion / DMax objective
# DMAX_BLOCK_SIZE=32 # tokens per noising block
# DMAX_ON_POLICY_RATIO=0.5 # fraction of steps using on-policy noise
# DMAX_NOISE_LOW=0.75
# DMAX_NOISE_HIGH=0.75


# ── 8. SFT (chat-template fine-tuning) ──────────────────────────────────────
# Used when DATASET=openthoughts (or other chat-templated SFT dataset).
# SFT_TRAIN_ON_ANSWERS_ONLY=0 # 0 = train on full chat-templated text (default,
# packs like pretraining)
# 1 = mask prompt tokens with -100 (per-example
# padded batches)


# ── 9. Performance knobs ────────────────────────────────────────────────────
# SPLASH_BLOCK=512 # splash_attention tile size for training
# SPLASH_FUSED_BWD=1 # use fused backward kernel
# DISABLE_SPLASH_ATTN=0 # 1 = fall back to dense attention (debugging)
# REMAT_POLICY=nothing_saveable # nothing_saveable | dots_saveable | minimal_checkpoint | None
# # see scripts/tpu_v6e_smoke.py for the full list

# Mask-token embedding warm start (DMax). On v5e-64, the .at[].set() update
# leaks ~9.5GB of HBM when it runs before sharding, on replicated weights;
# use 'none' there.
# MASK_EMBED_INIT=mean # mean | none

# PEAK_TFLOPS_PER_CHIP=918 # v6e default; v4=275, v5e=197, v5p=459


# ── 10. GCS checkpointing ───────────────────────────────────────────────────
# Pick ONE of A / B / C below.
# A. Regional bucket (auto-detected from TPU zone): bucket name becomes
# gs://${CHECKPOINT_BUCKET_PREFIX}-${region}
CHECKPOINT_BUCKET_PREFIX=dllm-jax
# B. Fixed bucket — overrides A.
# CHECKPOINT_BUCKET=my-bucket-name
# C. Full directory — overrides A and B.
# CHECKPOINT_DIR=gs://my-bucket/checkpoints/my-run

# CHECKPOINT_SUBDIR=checkpoints # path under bucket
CHECKPOINT_STEPS=500 # save every N optimizer steps
# CHECKPOINT_SECONDS=0 # also save every N seconds (0 = disabled)
CHECKPOINT_KEEP=2 # how many recent checkpoints to retain
# CHECKPOINT_ON_FINISH=0 # 1 = always write a final checkpoint
# CHECKPOINT_USE_GCS=1 # 0 = force local even if CHECKPOINT_BUCKET set
# LOCAL_CHECKPOINT_DIR=/tmp/dllm-jax-checkpoints # used when GCS not configured

# Orbax internals (rarely changed):
# CHECKPOINT_ORBAX_SYNC_DIRS=1
# CHECKPOINT_ORBAX_SIGNAL_FALLBACK=1


# ── 11. Resume from checkpoint ──────────────────────────────────────────────
# RESUME_DIR=gs://your-bucket/checkpoints/old-run
# RESUME_STEP=0 # 0 = latest step in RESUME_DIR
# RESUME_RESTORE_OPTIMIZER=1 # 0 = restore weights only (e.g. pretrain → SFT)
# RESUME_RESET_STEP=0 # 1 = zero global_step/epoch counters after load
# (use when continuing onto a new dataset
# so NUM_STEPS/NUM_EPOCHS budget restarts)
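# Example (a sketch; the bucket path is a placeholder): start SFT from the
# latest pretraining checkpoint, restoring weights only and restarting the
# step budget.
# DATASET=openthoughts
# RESUME_DIR=gs://your-bucket/checkpoints/old-run
# RESUME_STEP=0
# RESUME_RESTORE_OPTIMIZER=0
# RESUME_RESET_STEP=1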


# ── 12. HuggingFace Hub ─────────────────────────────────────────────────────
# HF_TOKEN=hf_... # for private model downloads / checkpoint upload
# HF_CHECKPOINT_REPO=username/repo
# HF_CHECKPOINT_REPO_TYPE=model # model | dataset
# HF_CHECKPOINT_PATH=checkpoints # path within the HF repo
# HF_CHECKPOINT_PRIVATE=1


# ── 13. Weights & Biases ────────────────────────────────────────────────────
WANDB_LOG=1 # always enable on training runs (incl. smoke tests)
WANDB_PROJECT=dllm-jax
# WANDB_RUN_NAME= # default: same as RUN_NAME
# WANDB_MODE=online # online | offline | disabled
# WANDB_API_KEY= # or run `wandb login` once on each TPU worker.
# Get a key: https://wandb.ai/authorize
# NEVER commit a real key. Keep it in .env.deploy
# (which is gitignored) or in ~/.netrc.


# ── 14. Profiling ───────────────────────────────────────────────────────────
# XPROF_ENABLE=0 # capture xprof traces
# XPROF_DIR=/tmp/xprof/run
# XPROF_START_STEP=4
# XPROF_STOP_STEP=7

# JAX programmatic profile (separate from xprof):
# JAX_PROFILE_DIR=
# JAX_PROFILE_START_STEP=0
# JAX_PROFILE_STEPS=0


# ── 15. Inference (scripts/tpu_dmax_infer_checkpoint.py) ────────────────────
# Required for inference: RESUME_DIR + RESUME_STEP (see section 11).
# PROMPT="Once upon a time"
# PROMPTS_FILE= # path to file with one prompt per line
# (overrides PROMPT; '#' lines = comments)
# GEN_LENGTH=32
# BLOCK_LENGTH=32
# STEPS=8 # max denoising steps per block
# THRESHOLD=0.95 # per-token confidence to accept early
# CONFIDENCE_STOP=0.9 # block-level early-exit threshold
# TOP_K=1 # soft-mix top-K for SPD (3 is good for Qwen3)
# TEMPERATURE=0.0 # 0.0 = greedy
# TEMPS= # comma-separated list; runs a sweep in one process,
# e.g. "0.0,0.3,0.5,0.7,1.0"
# SEED= # int; defaults to time-based

# INFER_IMPL=fast # fast | kv_fast | legacy
# INFER_SPLASH=0 # 1 = splash kernel for fast-path block-causal mask
# INFER_SPLASH_BLOCK=512
# INFER_KV_DTYPE=bf16 # bf16 (default, ~2x cache HBM savings) | fp32
# FAST_BUCKET_LENGTH=4096

# ---- HuggingFace Hub (optional checkpoint upload) ---------------------------
# HF_TOKEN=your-hf-token-here
# HF_CHECKPOINT_REPO=your-username/your-repo
# WARMUP_RUNS=0 # throwaway generations before the measured runs
# MEASURED_RUNS=1 # report median over N timed runs
# RESTORE_OPTIMIZER=0 # 1 = restore optimizer state too (rarely useful for inference)
# SUPPRESS_MASK_TOKEN=0 # 1 = force-disable mask-token logits at argmax
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
.env.*
!.env.example
*.lock
__pycache__/
*.egg-info/