Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
115 commits
Select commit Hold shift + click to select a range
a7e1690
fix: support non-monotonic CP block masks
bradhilton Jun 18, 2026
a054d9d
fix: remove uncommitted packing dependency from CP tests
bradhilton Jun 18, 2026
f7aba7c
feat: add TrainerRank and generic tree GDN
bradhilton Jun 20, 2026
5ffdca7
fix: satisfy quality checks
bradhilton Jun 20, 2026
3fc0f09
fix: balance tree GDN CP plans
bradhilton Jun 20, 2026
41c7fbb
feat: add dynamic LoRA slots for TrainerRank
bradhilton Jun 22, 2026
c70f4b6
fix: resolve vllm runtime python for codec tests
bradhilton Jun 22, 2026
94afb0f
style: apply TrainerRank formatting
bradhilton Jun 22, 2026
327ced4
fix: guard dynamic slot optimizer reductions
bradhilton Jun 22, 2026
02e7103
Merge remote-tracking branch 'origin/main' into feat/trainer-rank-gdn…
bradhilton Jun 22, 2026
9dcc9f9
test: add megatron review readiness harness
bradhilton Jun 22, 2026
d78eab7
test: add Austin-focused TrainerRank validation
bradhilton Jun 23, 2026
991aaff
test: harden Austin validation harness
bradhilton Jun 23, 2026
896d344
fix: prune false exact CP mask blocks
bradhilton Jun 23, 2026
fb8b188
fix: refine full exact CP mask blocks
bradhilton Jun 23, 2026
0c8a16b
perf: avoid unnecessary exact CP mask refinement
bradhilton Jun 23, 2026
1ac8edf
fix: keep exact CP refinement for deep trees
bradhilton Jun 23, 2026
19cb41a
test: make CP mask parity slice aware
bradhilton Jun 23, 2026
f464939
fix: broadcast slice-aware mask parity
bradhilton Jun 23, 2026
f640167
perf: vectorize depth-one CP mask refinement
bradhilton Jun 23, 2026
3cf1fe5
perf: keep depth-one CP mask refinement vectorized
bradhilton Jun 23, 2026
63ca045
fix: handle empty CP correctness shards
bradhilton Jun 23, 2026
5fb5924
fix: collect perf metadata on all ranks
bradhilton Jun 23, 2026
f79f520
bench: use CP stage masks for flex timing
bradhilton Jun 23, 2026
2526641
fix: benchmark production sparse flex path
bradhilton Jun 23, 2026
1e95437
fix: mirror production flex stage padding in review bench
bradhilton Jun 23, 2026
3578fa4
fix: default flex review bench to qwen head dim
bradhilton Jun 23, 2026
76ded79
perf: skip exact refinement for homogeneous mask blocks
bradhilton Jun 23, 2026
e40b063
perf: use generic interval block mask
bradhilton Jun 23, 2026
b8e3649
feat: add adaptive trainer rank microbatches
bradhilton Jun 23, 2026
d2942d1
fix: handle uninitialized trainer rank topology
bradhilton Jun 23, 2026
984e5c2
fix: report trainer rank oom call context
bradhilton Jun 23, 2026
e47cae3
bench: add TrainerRank adaptive microbatch perf cases
bradhilton Jun 24, 2026
c2a4c30
fix: keep masked TrainerRank outputs graph-connected
bradhilton Jun 24, 2026
d57a3ba
bench: expose TrainerRank adaptive memory knobs
bradhilton Jun 24, 2026
c6891b8
bench: profile TrainerRank adaptive microbatches
bradhilton Jun 24, 2026
259ade9
bench: split adaptive selector profile
bradhilton Jun 24, 2026
936f040
perf: cache adaptive TrainerRank plans
bradhilton Jun 24, 2026
7ff317a
perf: preflight adaptive TrainerRank candidate plans
bradhilton Jun 24, 2026
41abffc
perf: speed up adaptive packed-token estimates
bradhilton Jun 24, 2026
e594d8a
perf: defer adaptive TrainerRank plan materialization
bradhilton Jun 24, 2026
2dade94
fix: restore quality checks
bradhilton Jun 24, 2026
333bc00
test: add TrainerRank weird-shape fast gate
bradhilton Jun 24, 2026
6581afd
fix: format TrainerRank fast gate
bradhilton Jun 24, 2026
d48e576
refactor: extract TrainerRank adaptive planner
bradhilton Jun 24, 2026
57aa2a4
refactor: trim TrainerRank planning surface
bradhilton Jun 24, 2026
6ef09ac
fix: preserve masked target logprobs for shared rows
bradhilton Jun 24, 2026
500b451
refactor: collapse TrainerRank adaptive cache key
bradhilton Jun 24, 2026
2d1fd56
refactor: collapse TrainerRank memory profile checks
bradhilton Jun 24, 2026
9c55ce2
fix: update TrainerRank adaptive perf profiler
bradhilton Jun 24, 2026
5456424
refactor: collapse TrainerRank memory checks
bradhilton Jun 24, 2026
b54fb7e
refactor: trim TrainerRank forwarding checks
bradhilton Jun 24, 2026
7f3b29e
refactor: simplify TrainerRank memory signature
bradhilton Jun 24, 2026
6cb6385
refactor: keep target oracle helper in dev
bradhilton Jun 24, 2026
714a125
refactor: localize topology oracle logits helper
bradhilton Jun 24, 2026
ee621c9
refactor: use hidden head path for target logprobs
bradhilton Jun 24, 2026
1be2840
refactor: trim TrainerRank head and planner paths
bradhilton Jun 24, 2026
5d032e2
refactor: return full TrainerRank head outputs
bradhilton Jun 24, 2026
5a61de9
refactor: simplify TrainerRank slot grad plumbing
bradhilton Jun 24, 2026
fc802ff
refactor: cache TrainerRank checkpoint slot params
bradhilton Jun 24, 2026
22ec72d
refactor: collapse LoRA parallel layout builders
bradhilton Jun 24, 2026
9e40647
refactor: unify shared-prefix planning
bradhilton Jun 24, 2026
9e2c396
refactor: collapse GDN bucket builders
bradhilton Jun 24, 2026
9ce5b56
refactor: remove stale context parallel helpers
bradhilton Jun 24, 2026
1b02e79
refactor: unify context parallel peer exchange
bradhilton Jun 24, 2026
0d419f4
refactor: centralize adapter export traversal
bradhilton Jun 24, 2026
edb0009
refactor: use generic context parallel pair planning
bradhilton Jun 24, 2026
404e2d3
refactor: remove unused context planner metadata
bradhilton Jun 24, 2026
4399cb9
chore: fix review perf harness
bradhilton Jun 25, 2026
49f40cd
refactor: inline trainer rank microbatch planner
bradhilton Jun 25, 2026
97ee6aa
perf: prune shared-prefix block-mask refinement
bradhilton Jun 25, 2026
33f7dff
refactor: simplify trainer rank planning records
bradhilton Jun 25, 2026
58d980e
perf: skip full-slice minmax in block masks
bradhilton Jun 25, 2026
d5217d7
refactor: simplify context parallel state records
bradhilton Jun 25, 2026
a46fdc8
refactor: simplify gdn tree state lookup
bradhilton Jun 25, 2026
04c43ef
perf: stabilize adaptive trainer windows
bradhilton Jun 25, 2026
b728092
perf: keep adaptive window size stable across tails
bradhilton Jun 25, 2026
f8119f5
perf: balance adaptive trainer windows
bradhilton Jun 25, 2026
5e712b8
chore: trace adaptive trainer perf windows
bradhilton Jun 25, 2026
63b8a19
chore: trace adaptive train-step windows
bradhilton Jun 25, 2026
2bfd1c8
chore: cap review block-mask validation
bradhilton Jun 25, 2026
2b1220a
refactor: simplify trainer rank head helpers
bradhilton Jun 25, 2026
a4ab209
perf: bound adaptive memory profile growth
bradhilton Jun 25, 2026
c4f1f6e
perf: speed shared-prefix common-prefix scan
bradhilton Jun 25, 2026
6958ca1
chore: add review perf thresholds
bradhilton Jun 25, 2026
d896cd0
refactor: simplify block-mask interval refinement
bradhilton Jun 25, 2026
c297a3b
refactor: unify trainer rank local stats kernels
bradhilton Jun 25, 2026
5b1af1a
refactor: simplify lora slot support
bradhilton Jun 25, 2026
a0d4d15
refactor: simplify shared-expert lora wrapper
bradhilton Jun 25, 2026
6eed4cb
refactor: collapse lora moe wrappers
bradhilton Jun 25, 2026
02a9732
refactor: simplify trainer rank lora plumbing
bradhilton Jun 25, 2026
0a924d4
refactor: collapse context parallel dispatch glue
bradhilton Jun 25, 2026
e2fccac
refactor: trim context parallel runtime plans
bradhilton Jun 25, 2026
9b5c5da
refactor: unify trainer rank head projection
bradhilton Jun 25, 2026
e282ddd
refactor: simplify trainer rank planning
bradhilton Jun 25, 2026
58b5e5d
refactor: simplify block mask validation
bradhilton Jun 25, 2026
a67fd86
refactor: remove alternate cp planner strategies
bradhilton Jun 25, 2026
15c01b6
refactor: trim trainer rank gdn surfaces
bradhilton Jun 25, 2026
ca369df
refactor: shrink shared prefix parser metadata
bradhilton Jun 25, 2026
3429e47
refactor: use shared prefix pack in trainer rank
bradhilton Jun 25, 2026
5ccb6f9
refactor: trim trainer rank packing helpers
bradhilton Jun 25, 2026
3893e9b
refactor: share trainer rank grad sync helpers
bradhilton Jun 25, 2026
b42b76f
refactor: unify gdn tree bucket planning
bradhilton Jun 25, 2026
3519caf
refactor: trim adaptive planner cache keys
bradhilton Jun 25, 2026
3b14f35
refactor: inline stale trainer rank helpers
bradhilton Jun 25, 2026
bbdd9d5
test: allow tensor fields in gdn layout fixture
bradhilton Jun 25, 2026
d1a5283
fix: type triton topk launch constants
bradhilton Jun 25, 2026
865e945
test: handle fused expert fc1 lora in oracle
bradhilton Jun 25, 2026
8a827d1
test: canonicalize padded forward trace rows
bradhilton Jun 25, 2026
0a4bb28
fix: run cp tree child buckets recurrently
bradhilton Jun 25, 2026
16e51b8
Revert "fix: run cp tree child buckets recurrently"
bradhilton Jun 25, 2026
fa09126
fix: keep unchained gdn subtrees colocated
bradhilton Jun 25, 2026
db4372c
fix: type triton topk launches
bradhilton Jun 25, 2026
47a74b4
fix: type grouped lora calls
bradhilton Jun 25, 2026
57229cf
fix: scope trainer rank memory checks to model group
bradhilton Jun 25, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .github/workflows/build-gpu-image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ on:
required: true
default: true
type: boolean
prewarm_modal:
description: "Prebuild the pushed image in Modal when auth is configured"
required: true
default: true
type: boolean
prewarm_timeout:
description: "Timeout for GPU node prewarm rollout"
required: true
Expand Down Expand Up @@ -155,11 +160,16 @@ jobs:
PULL_IMAGE_REPO: ${{ inputs.pull_image_repo || 'docker.io/bradhiltonnw/art-gpu' }}
IMAGE_TAG: ${{ inputs.tag }}
NO_CACHE: ${{ inputs.no_cache }}
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
PREWARM_MODAL_INPUT: ${{ inputs.prewarm_modal }}
PREWARM_NODES: ${{ inputs.prewarm_nodes }}
PREWARM_TIMEOUT: ${{ inputs.prewarm_timeout }}
run: |
IMAGE_TAG="${IMAGE_TAG:-latest}"
NO_CACHE="${NO_CACHE:-false}"
export PREWARM_MODAL="${PREWARM_MODAL:-auto}"
PREWARM_MODAL_INPUT="${PREWARM_MODAL_INPUT:-true}"
PREWARM_NODES="${PREWARM_NODES:-true}"
PREWARM_TIMEOUT="${PREWARM_TIMEOUT:-30m}"

Expand All @@ -175,6 +185,10 @@ jobs:
args+=(--no-cache)
fi

if [ "${PREWARM_MODAL_INPUT}" = "false" ]; then
args+=(--no-prewarm-modal)
fi

if [ "${PREWARM_NODES}" != "true" ]; then
args+=(--no-prewarm-nodes)
fi
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ data/cache.db
streaming-chat-completions/
unsloth_compiled_cache/
wandb/
!/typings/wandb/
!/typings/wandb/**
docs/node_modules/
dist/
replays/
Expand Down
114 changes: 114 additions & 0 deletions dev/trainer_rank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from __future__ import annotations

import os

import torch
import torch.distributed as dist
from transformers import AutoTokenizer
import typer

from art.megatron.trainer_rank import AdamParams, ForwardInput, TrainerRank


def main(
model: str = "Qwen/Qwen3-0.6B",
dataset: str = "roneneldan/TinyStories",
split: str = "train",
text_column: str = "text",
samples: int = 16,
steps: int = 1,
lr: float = 5e-5,
layers: int = 2,
max_seq_length: int = 256,
) -> None:
os.environ.setdefault("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "1")
os.environ.setdefault("ART_MEGATRON_CONTEXT_PARALLEL_SIZE", "1")
os.environ.setdefault("ART_MEGATRON_PIPELINE_MODEL_PARALLEL_SIZE", "1")

if not torch.cuda.is_available():
raise RuntimeError("dev/trainer_rank.py requires CUDA")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group(backend="nccl")

try:
from datasets import load_dataset

from art.megatron import train as megatron_train

tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
inputs: list[ForwardInput[torch.Tensor, None, None, None]] = []
for row in load_dataset(dataset, split=split, streaming=True):
text = str(row.get(text_column, "")).strip() # type: ignore[union-attr]
if not text:
continue
token_ids = tokenizer(
text,
add_special_tokens=True,
truncation=True,
max_length=max_seq_length + 1,
return_tensors="pt",
)["input_ids"].reshape(-1)
if int(token_ids.numel()) <= 1:
continue
inputs.append(
ForwardInput(
input_tokens=token_ids[:-1],
target_tokens=token_ids[1:],
)
)
if len(inputs) >= samples:
break
if not inputs:
raise RuntimeError("dataset produced no tokenized training examples")

runtime = megatron_train.build_training_runtime(
model_identifier=model,
provider_configure=lambda provider: setattr(
provider,
"num_layers",
layers,
),
print_env=dist.get_rank() == 0,
)
rank = TrainerRank(runtime)
if dist.get_rank() == 0:
print(
"TrainerRank ready: "
f"dp={megatron_train.ps.get_data_parallel_world_size()} "
f"device={rank.device}",
flush=True,
)

for step in range(steps):
loss_sum = torch.tensor(0.0, device=rank.device)
token_count = torch.tensor(0.0, device=rank.device)
for micro in rank.forward_micro_batches(inputs):
loss = torch.tensor(0.0, device=rank.device)
for output in micro.outputs:
assert output.target_logprobs is not None
loss = loss - output.target_logprobs.sum()
token_count += output.target_logprobs.numel()
if loss.requires_grad:
loss.backward()
loss_sum += loss.detach()

rank.dp_reduce(loss_sum)
rank.dp_reduce(token_count)
scale = 1.0 / max(float(token_count.item()), 1.0)
metrics = rank.optim_step(
params=AdamParams(learning_rate=lr),
scale_grads=scale,
)
metrics["loss"] = float(loss_sum.item() * scale)
metrics["tokens"] = float(token_count.item())
if dist.get_rank() == 0:
print(f"step={step} {metrics}", flush=True)

dist.barrier()
finally:
if dist.is_initialized():
dist.destroy_process_group()


if __name__ == "__main__":
typer.run(main)
25 changes: 25 additions & 0 deletions dev/trainer_rank_fast_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

import subprocess
import sys

FAST_TESTS = (
"tests/unit/test_trainer_rank_validation.py",
"tests/unit/test_trainer_rank_weird_shapes.py",
"tests/unit/test_shared_prefix_packing.py",
"tests/unit/test_shared_prefix_tree.py",
"tests/unit/test_shared_prefix_attention_builder.py",
"tests/unit/test_shared_prefix_grad_parity.py",
)


def main() -> None:
raise SystemExit(
subprocess.call(
[sys.executable, "-m", "pytest", "--tb=short", *FAST_TESTS, *sys.argv[1:]]
)
)


if __name__ == "__main__":
main()
Loading
Loading