From a2414ab7b6db18919ea3033cc5fa95a22de3efa4 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 25 Jun 2026 15:25:02 -0600 Subject: [PATCH 01/33] feat: unify shared-prefix megatron execution --- .github/workflows/build-gpu-image.yml | 14 + .gitignore | 2 + dev/megatron_review_perf.py | 868 ++++ pyproject.toml | 2 + scripts/build-gpu-image.sh | 59 +- src/art/megatron/context_parallel/__init__.py | 8 - .../megatron/context_parallel/block_mask.py | 555 ++- src/art/megatron/context_parallel/builder.py | 245 +- src/art/megatron/context_parallel/comm.py | 564 +-- src/art/megatron/context_parallel/executor.py | 41 +- .../megatron/context_parallel/layout_index.py | 7 +- src/art/megatron/context_parallel/runtime.py | 1102 +---- src/art/megatron/context_parallel/types.py | 197 +- src/art/megatron/gdn/__init__.py | 4 - src/art/megatron/gdn/gdn_shared_prefix.py | 4365 +++-------------- src/art/megatron/gdn/layout.py | 105 +- src/art/megatron/gdn/operator.py | 1280 ++--- src/art/megatron/lora.py | 1126 +++-- .../model_support/handlers/default_dense.py | 77 +- .../model_support/handlers/qwen3_5.py | 323 +- src/art/megatron/model_support/spec.py | 1 + src/art/megatron/service.py | 4 - src/art/megatron/setup.sh | 4 + src/art/megatron/shared_prefix_packing.py | 279 ++ src/art/megatron/shared_prefix_state.py | 185 +- src/art/megatron/shared_prefix_tree.py | 234 + src/art/megatron/training/finalize_grads.py | 97 +- src/art/megatron/weights/adapter_export.py | 365 +- src/art/megatron/weights/lora_publish.py | 273 +- .../test_attention_packed_vs_flattened.py | 85 +- .../gdn_shared_prefix/distributed_init.py | 7 + .../gdn_shared_prefix/layout_reference.py | 10 +- .../megatron/gdn_shared_prefix/oracles.py | 86 +- .../gdn_shared_prefix/packed_layout.py | 56 +- .../gdn_shared_prefix/parser_import.py | 1 - .../gdn_shared_prefix/real_gdn_oracle.py | 192 +- .../test_fla_cp_native_recurrent.py | 24 +- .../test_gdn_cp_packed_correctness.py | 332 +- .../test_gdn_cp_train_prepare.py | 16 +- ...en35_full_model_cp1_packed_vs_flattened.py | 294 +- .../test_real_gdn_cp1_packed_vs_flattened.py | 56 +- .../test_real_gdn_native_fla_cp.py | 38 +- .../test_real_gdn_tp_lora.py | 16 +- .../megatron/lora/test_lora_disk_codecs.py | 69 +- .../megatron/model_support/forward_trace.py | 33 + .../megatron/model_support/oracle_worker.py | 16 +- .../test_oracle_harness_invariants.py | 31 + .../test_shared_prefix_attention_builder.py | 586 +++ tests/unit/test_shared_prefix_grad_parity.py | 279 ++ tests/unit/test_shared_prefix_packing.py | 160 + tests/unit/test_shared_prefix_tree.py | 502 ++ typings/wandb/__init__.pyi | 38 + typings/wandb/sdk/__init__.pyi | 5 + typings/wandb/sdk/wandb_run.pyi | 3 + 54 files changed, 7223 insertions(+), 8098 deletions(-) create mode 100644 dev/megatron_review_perf.py create mode 100644 src/art/megatron/shared_prefix_packing.py create mode 100644 src/art/megatron/shared_prefix_tree.py create mode 100644 tests/integration/megatron/gdn_shared_prefix/distributed_init.py create mode 100644 tests/unit/test_shared_prefix_attention_builder.py create mode 100644 tests/unit/test_shared_prefix_grad_parity.py create mode 100644 tests/unit/test_shared_prefix_packing.py create mode 100644 tests/unit/test_shared_prefix_tree.py create mode 100644 typings/wandb/__init__.pyi create mode 100644 typings/wandb/sdk/__init__.pyi create mode 100644 typings/wandb/sdk/wandb_run.pyi diff --git a/.github/workflows/build-gpu-image.yml b/.github/workflows/build-gpu-image.yml index 12dbfad96..cdfc23634 100644 --- a/.github/workflows/build-gpu-image.yml +++ b/.github/workflows/build-gpu-image.yml @@ -30,6 +30,11 @@ on: required: true default: true type: boolean + prewarm_modal: + description: "Prebuild the pushed image in Modal when auth is configured" + required: true + default: true + type: boolean prewarm_timeout: description: "Timeout for GPU node prewarm rollout" required: true @@ -155,11 +160,16 @@ jobs: PULL_IMAGE_REPO: ${{ inputs.pull_image_repo || 'docker.io/bradhiltonnw/art-gpu' }} IMAGE_TAG: ${{ inputs.tag }} NO_CACHE: ${{ inputs.no_cache }} + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + PREWARM_MODAL_INPUT: ${{ inputs.prewarm_modal }} PREWARM_NODES: ${{ inputs.prewarm_nodes }} PREWARM_TIMEOUT: ${{ inputs.prewarm_timeout }} run: | IMAGE_TAG="${IMAGE_TAG:-latest}" NO_CACHE="${NO_CACHE:-false}" + export PREWARM_MODAL="${PREWARM_MODAL:-auto}" + PREWARM_MODAL_INPUT="${PREWARM_MODAL_INPUT:-true}" PREWARM_NODES="${PREWARM_NODES:-true}" PREWARM_TIMEOUT="${PREWARM_TIMEOUT:-30m}" @@ -175,6 +185,10 @@ jobs: args+=(--no-cache) fi + if [ "${PREWARM_MODAL_INPUT}" = "false" ]; then + args+=(--no-prewarm-modal) + fi + if [ "${PREWARM_NODES}" != "true" ]; then args+=(--no-prewarm-nodes) fi diff --git a/.gitignore b/.gitignore index d1f4ebd59..0dfae3afe 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,8 @@ data/cache.db streaming-chat-completions/ unsloth_compiled_cache/ wandb/ +!/typings/wandb/ +!/typings/wandb/** docs/node_modules/ dist/ replays/ diff --git a/dev/megatron_review_perf.py b/dev/megatron_review_perf.py new file mode 100644 index 000000000..4da5a8ad0 --- /dev/null +++ b/dev/megatron_review_perf.py @@ -0,0 +1,868 @@ +from __future__ import annotations + +from collections.abc import Callable, Sequence +from dataclasses import dataclass +import json +from pathlib import Path +import time + +import numpy as np +import torch +from torch.nn.attention.flex_attention import AuxRequest, BlockMask +from torch.nn.attention.flex_attention import create_block_mask as torch_block_mask +import typer + +from art.megatron.context_parallel.block_mask import ( + build_block_mask_from_context, + prepare_block_mask_context, +) +from art.megatron.context_parallel.builder import build_shared_prefix_attention_spec +from art.megatron.context_parallel.executor import _build_stage_execution_spec +from art.megatron.context_parallel.runtime import ( + _RUNTIME_PLAN_CACHE, + get_or_build_runtime_plan, +) +from art.megatron.context_parallel.types import ( + ContextParallelConfig, + FlexMaskSpec, + ParallelTopology, + StageExecutionSpec, + StagePlan, +) +from art.megatron.flex_attn.compiled import ( + normalize_sparse_block_size, + sparse_compiled_flex_attention, +) +from art.megatron.shared_prefix_packing import SharedPrefixPack, pack_shared_prefixes + + +def main( + workload: str = "austin_198k", + max_depth: int = 1, + cp_size: int = 4, + block_size: int = 128, + prefix_families: int = 4, + prefix_len: int = 1024, + mid_prefixes_per_family: int = 1, + mid_prefix_len: int = 0, + branches_per_prefix: int = 8, + completion_len: int = 128, + warmup: int = 3, + repeat: int = 10, + shape_variants: int = 4, + validate_torch: bool = True, + validate_torch_token_cap: int = 32768, + run_flex: bool = True, + flex_token_cap: int = 8192, + flex_heads: int = 2, + flex_head_dim: int = 128, + flex_mask_variants: str = "current,causal_abs_only", + max_block_mask_build_ms: float | None = None, + max_cp_planning_cold_ms: float | None = None, + output_jsonl: Path = Path(".local/trainer_rank_review/block_mask_flex.jsonl"), +) -> None: + if warmup < 0 or repeat < 1: + raise ValueError("warmup must be >= 0 and repeat must be >= 1") + output_jsonl.parent.mkdir(parents=True, exist_ok=True) + + pack = _pack_workload( + workload=workload, + max_depth=max_depth, + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + spec = build_shared_prefix_attention_spec( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + config = ContextParallelConfig(block_size=block_size) + topology = ParallelTopology(cp=cp_size) + base = { + "workload": workload, + "max_depth": max_depth, + "cp_size": cp_size, + "block_size": block_size, + "packed_tokens": int(pack.tokens.numel()), + "logical_tokens": _logical_tokens(pack), + "warmup": warmup, + "repeat": repeat, + "validate_torch": validate_torch, + "validate_torch_token_cap": validate_torch_token_cap, + } + + plan, plan_ms = _bench_cpu( + lambda: _build_cp_plan(pack, spec, topology, config), + warmup=warmup, + repeat=repeat, + before_each=_RUNTIME_PLAN_CACHE.clear, + ) + _write( + output_jsonl, + { + **base, + "case": "cp_planning_cold", + "ms": plan_ms, + **_plan_stats(plan), + }, + ) + + cached_plan, cached_plan_ms = _bench_cpu( + lambda: _build_cp_plan(pack, spec, topology, config), + warmup=warmup, + repeat=repeat, + ) + _write( + output_jsonl, + { + **base, + "case": "cp_planning_cached", + "ms": cached_plan_ms, + **_plan_stats(cached_plan), + }, + ) + + stage_masks, mask_ms = _bench_cpu( + lambda: _build_stage_masks(pack, plan, config), + warmup=warmup, + repeat=repeat, + ) + masks = tuple(mask for mask, _ in stage_masks) + torch_validation_skipped = _torch_validation_skip_reason( + validate_torch=validate_torch, + packed_tokens=int(pack.tokens.numel()), + token_cap=validate_torch_token_cap, + ) + if torch_validation_skipped is None: + for mask, slices in stage_masks: + _assert_matches_torch_block_mask(mask, slices=slices) + _write( + output_jsonl, + { + **base, + "case": "block_mask_build", + "ms": mask_ms, + "torch_validation_skipped": torch_validation_skipped, + **_mask_stats(masks), + }, + ) + _check_threshold("block_mask_build", mask_ms, max_block_mask_build_ms) + _check_threshold("cp_planning_cold", plan_ms, max_cp_planning_cold_ms) + + if run_flex: + for record in _flex_records( + pack, + plan, + config, + warmup=warmup, + repeat=repeat, + token_cap=flex_token_cap, + heads=flex_heads, + head_dim=flex_head_dim, + variants=_csv_values(flex_mask_variants), + ): + _write(output_jsonl, {**base, **record}) + + for variant in range(shape_variants): + variant_pack = _pack_workload( + workload="regular", + max_depth=max_depth, + prefix_families=prefix_families, + prefix_len=prefix_len + variant * 17, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len + variant * 3, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len + variant * 11, + ) + variant_spec = build_shared_prefix_attention_spec( + group_ids=variant_pack.group_ids, + parent_ids=variant_pack.parent_ids, + ) + variant_plan, variant_plan_ms = _bench_cpu( + lambda pack=variant_pack, spec=variant_spec: _build_cp_plan( + pack, + spec, + topology, + config, + ), + warmup=0, + repeat=1, + before_each=_RUNTIME_PLAN_CACHE.clear, + ) + variant_stage_masks, variant_mask_ms = _bench_cpu( + lambda pack=variant_pack, plan=variant_plan: _build_stage_masks( + pack, + plan, + config, + ), + warmup=0, + repeat=1, + ) + variant_masks = tuple(mask for mask, _ in variant_stage_masks) + variant_torch_validation_skipped = _torch_validation_skip_reason( + validate_torch=validate_torch, + packed_tokens=int(variant_pack.tokens.numel()), + token_cap=validate_torch_token_cap, + ) + if variant_torch_validation_skipped is None: + for mask, slices in variant_stage_masks: + _assert_matches_torch_block_mask(mask, slices=slices) + _write( + output_jsonl, + { + **base, + "case": "shape_variant", + "variant": variant, + "variant_packed_tokens": int(variant_pack.tokens.numel()), + "variant_logical_tokens": _logical_tokens(variant_pack), + "cp_planning_ms": variant_plan_ms, + "block_mask_build_ms": variant_mask_ms, + "torch_validation_skipped": variant_torch_validation_skipped, + **_plan_stats(variant_plan), + **_mask_stats(variant_masks), + }, + ) + + print(f"wrote review perf records to {output_jsonl}", flush=True) + + +def _pack_workload( + *, + workload: str, + max_depth: int, + prefix_families: int, + prefix_len: int, + mid_prefixes_per_family: int, + mid_prefix_len: int, + branches_per_prefix: int, + completion_len: int, +) -> SharedPrefixPack: + sequences = ( + _austin_sequences() + if workload == "austin_198k" + else _austin_varied_sequences() + if workload == "austin_varied" + else _regular_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + ) + return pack_shared_prefixes(sequences, max_depth=max_depth) + + +def _austin_sequences() -> tuple[torch.Tensor, ...]: + return tuple( + torch.cat( + ( + _tokens(family * 10_000_019, 5000), + _tokens(family * 10_000_019 + branch * 1009 + 17, 100), + ) + ) + for family in range(30) + for branch in range(16) + ) + + +def _austin_varied_sequences() -> tuple[torch.Tensor, ...]: + sequences: list[torch.Tensor] = [] + for family in range(30): + family_base = family * 10_000_019 + prefix_len = 4500 + ((family * 137) % 1001) + root = _tokens(family_base, prefix_len) + branch_count = 10 + ((family * 7) % 13) + for branch in range(branch_count): + completion_len = 32 + ((family * 19 + branch * 23) % 145) + sequences.append( + torch.cat( + ( + root, + _tokens( + family_base + branch * 1009 + 17, + completion_len, + ), + ) + ) + ) + return tuple(sequences) + + +def _regular_sequences( + *, + prefix_families: int, + prefix_len: int, + mid_prefixes_per_family: int, + mid_prefix_len: int, + branches_per_prefix: int, + completion_len: int, +) -> tuple[torch.Tensor, ...]: + sequences = [] + for family in range(max(1, prefix_families)): + family_base = family * 10_000_019 + root = _tokens(family_base, max(1, prefix_len)) + for mid in range(max(1, mid_prefixes_per_family)): + mid_prefix = _tokens( + family_base + 1_000_003 + mid * 100_003, + max(0, mid_prefix_len), + ) + prefix = torch.cat((root, mid_prefix)) + for branch in range(max(1, branches_per_prefix)): + sequences.append( + torch.cat( + ( + prefix, + _tokens( + family_base + mid * 100_003 + branch * 1009 + 17, + max(1, completion_len), + ), + ) + ) + ) + return tuple(sequences) + + +def _tokens(offset: int, length: int) -> torch.Tensor: + return (torch.arange(length, dtype=torch.long) + offset) % 32_000 + 100 + + +def _build_cp_plan( + pack: SharedPrefixPack, + spec: object, + topology: ParallelTopology, + config: ContextParallelConfig, +) -> object: + return get_or_build_runtime_plan( + spec, + topology=topology, + config=config, + original_seq_len=int(pack.tokens.numel()), + ) + + +def _build_stage_masks( + pack: SharedPrefixPack, + plan: object, + config: ContextParallelConfig, +) -> tuple[tuple[BlockMask, tuple[object, ...]], ...]: + masks = [] + context = prepare_block_mask_context( + group_ids=pack.group_ids[0], + parent_ids=pack.parent_ids[0], + ) + for rank_plan in plan: + for stage in rank_plan.stage_plans: + if stage.mask_metadata is None: + continue + execution_spec = _stage_execution_spec(stage, config) + mask_metadata = execution_spec.mask_metadata or stage.mask_metadata + if mask_metadata is None: + continue + mask = build_block_mask_from_context( + FlexMaskSpec( + q_len=execution_spec.q_len, + k_len=execution_spec.k_len, + block_size=_sparse_block_size(config), + slices=stage.slices, + exact_mask=mask_metadata, + ), + context=context, + device=torch.device("cpu"), + validate=False, + ) + if mask is not None: + masks.append((mask, tuple(stage.slices))) + return tuple(masks) + + +def _flex_records( + pack: SharedPrefixPack, + plan: object, + config: ContextParallelConfig, + *, + warmup: int, + repeat: int, + token_cap: int, + heads: int, + head_dim: int, + variants: Sequence[str], +) -> list[dict[str, object]]: + if not torch.cuda.is_available(): + return [{"case": "flex_attention_fwd_bwd", "skipped": "cuda_unavailable"}] + device = torch.device("cuda") + stage_cases = _build_stage_flex_cases( + pack, + plan, + config, + device=device, + ) + if not stage_cases: + return [{"case": "flex_attention_fwd_bwd", "skipped": "no_stage_masks"}] + largest_stage = max(max(case.q_len, case.k_len) for case in stage_cases) + if int(largest_stage) > int(token_cap): + return [ + { + "case": "flex_attention_fwd_bwd", + "skipped": "stage_tokens_exceed_flex_token_cap", + "flex_token_cap": int(token_cap), + "largest_stage_tokens": int(largest_stage), + } + ] + records: list[dict[str, object]] = [] + base_tensors = _stage_tensors( + stage_cases, + heads=heads, + head_dim=head_dim, + device=device, + ) + for variant in variants: + block_masks = [] + try: + block_masks = [ + _stage_variant_block_mask(case, variant, device=device) + for case in stage_cases + ] + except Exception as exc: + records.append( + { + "case": "flex_attention_fwd_bwd", + "flex_mask_variant": variant, + "compile_error": type(exc).__name__, + "compile_error_message": str(exc).splitlines()[0][:500], + "flex_heads": heads, + "flex_head_dim": head_dim, + } + ) + continue + qkv = [ + ( + q.detach().clone().requires_grad_(True), + k.detach().clone().requires_grad_(True), + v.detach().clone().requires_grad_(True), + ) + for q, k, v in base_tensors + ] + + def step() -> None: + loss = torch.zeros((), device=device, dtype=torch.float32) + for (q, k, v), block_mask in zip(qkv, block_masks, strict=True): + q.grad = None + k.grad = None + v.grad = None + out, _aux = sparse_compiled_flex_attention( + q, + k, + v, + block_mask=block_mask, + scale=float(head_dim) ** -0.5, + enable_gqa=False, + return_aux=AuxRequest(lse=True), + ) + loss = loss + out.float().sum() + loss.backward() + + try: + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + first_started = time.perf_counter() + step() + torch.cuda.synchronize() + first_call_ms = round((time.perf_counter() - first_started) * 1000.0, 3) + ms = _bench_cuda(step, warmup=warmup, repeat=repeat) + except Exception as exc: + torch.cuda.empty_cache() + records.append( + { + "case": "flex_attention_fwd_bwd", + "flex_mask_variant": variant, + "compile_error": type(exc).__name__, + "compile_error_message": str(exc).splitlines()[0][:500], + "flex_heads": heads, + "flex_head_dim": head_dim, + **_stage_flex_stats(stage_cases), + } + ) + continue + records.append( + { + "case": "flex_attention_fwd_bwd", + "flex_mask_variant": variant, + "first_call_ms": first_call_ms, + "ms": ms, + "packed_tok_s": round(int(pack.tokens.numel()) * 1000.0 / ms, 3), + "flex_heads": heads, + "flex_head_dim": head_dim, + **_stage_flex_stats(stage_cases), + "peak_memory_gb": round(torch.cuda.max_memory_allocated() / 1024**3, 3), + } + ) + return records + + +@dataclass(frozen=True) +class _StageFlexCase: + rank: int + stage_index: int + q_len: int + k_len: int + logical_q_len: int + logical_k_len: int + block_mask: BlockMask + q_abs: np.ndarray + k_abs: np.ndarray + + +def _build_stage_flex_cases( + pack: SharedPrefixPack, + plan: object, + config: ContextParallelConfig, + *, + device: torch.device, +) -> tuple[_StageFlexCase, ...]: + cases: list[_StageFlexCase] = [] + context = prepare_block_mask_context( + group_ids=pack.group_ids[0], + parent_ids=pack.parent_ids[0], + ) + for rank_plan in plan: + for stage in rank_plan.stage_plans: + if stage.mask_metadata is None: + continue + execution_spec = _stage_execution_spec(stage, config) + mask_metadata = execution_spec.mask_metadata or stage.mask_metadata + if mask_metadata is None: + continue + mask = build_block_mask_from_context( + FlexMaskSpec( + q_len=execution_spec.q_len, + k_len=execution_spec.k_len, + block_size=_sparse_block_size(config), + slices=stage.slices, + exact_mask=mask_metadata, + ), + context=context, + device=device, + validate=False, + ) + if mask is None: + continue + q_abs = ( + mask_metadata.q_token_indices.detach() + .to(device="cpu", dtype=torch.int64) + .reshape(-1) + .numpy() + ) + k_abs = ( + mask_metadata.k_token_indices.detach() + .to(device="cpu", dtype=torch.int64) + .reshape(-1) + .numpy() + ) + cases.append( + _StageFlexCase( + rank=int(rank_plan.rank), + stage_index=int(stage.stage_index), + q_len=int(execution_spec.q_len), + k_len=int(execution_spec.k_len), + logical_q_len=int(stage.q_len), + logical_k_len=int(stage.k_len), + block_mask=mask, + q_abs=q_abs, + k_abs=k_abs, + ) + ) + return tuple(cases) + + +def _stage_tensors( + cases: Sequence[_StageFlexCase], + *, + heads: int, + head_dim: int, + device: torch.device, +) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]: + generator = torch.Generator(device=device).manual_seed(17) + tensors = [] + for case in cases: + q_shape = (1, int(heads), int(case.q_len), int(head_dim)) + k_shape = (1, int(heads), int(case.k_len), int(head_dim)) + tensors.append( + ( + torch.randn( + q_shape, device=device, dtype=torch.bfloat16, generator=generator + ), + torch.randn( + k_shape, device=device, dtype=torch.bfloat16, generator=generator + ), + torch.randn( + k_shape, device=device, dtype=torch.bfloat16, generator=generator + ), + ) + ) + return tuple(tensors) + + +def _stage_variant_block_mask( + case: _StageFlexCase, + variant: str, + *, + device: torch.device, +) -> BlockMask: + if variant == "current": + return case.block_mask + q_abs = torch.as_tensor(case.q_abs, device=device, dtype=torch.int64) + k_abs = torch.as_tensor(case.k_abs, device=device, dtype=torch.int64) + if variant == "causal_abs_only": + + def mask_mod(batch_idx, head_idx, query_idx, kv_idx): + del batch_idx, head_idx + return q_abs[query_idx] >= k_abs[kv_idx] + + return _replace_block_mask_mod(case.block_mask, mask_mod) + raise ValueError(f"unknown flex_mask_variant {variant!r}") + + +def _stage_flex_stats(cases: Sequence[_StageFlexCase]) -> dict[str, object]: + return { + "flex_stage_count": len(cases), + "flex_stage_q_tokens": sum(case.q_len for case in cases), + "flex_stage_k_tokens": sum(case.k_len for case in cases), + "flex_stage_logical_q_tokens": sum(case.logical_q_len for case in cases), + "flex_stage_logical_k_tokens": sum(case.logical_k_len for case in cases), + "flex_stage_max_q_tokens": max(case.q_len for case in cases), + "flex_stage_max_k_tokens": max(case.k_len for case in cases), + "flex_stage_max_logical_q_tokens": max(case.logical_q_len for case in cases), + "flex_stage_max_logical_k_tokens": max(case.logical_k_len for case in cases), + } + + +def _sparse_block_size(config: ContextParallelConfig) -> tuple[int, int]: + return normalize_sparse_block_size( + config.attention_sparse_block_size or config.block_size + ) + + +def _stage_execution_spec( + stage: StagePlan, + config: ContextParallelConfig, +) -> StageExecutionSpec: + return _build_stage_execution_spec( + stage_plan=stage, + block_size=_sparse_block_size(config), + ) + + +def _replace_block_mask_mod(block_mask: BlockMask, mask_mod: object) -> BlockMask: + return BlockMask( + seq_lengths=block_mask.seq_lengths, + kv_num_blocks=block_mask.kv_num_blocks, + kv_indices=block_mask.kv_indices, + full_kv_num_blocks=block_mask.full_kv_num_blocks, + full_kv_indices=block_mask.full_kv_indices, + q_num_blocks=block_mask.q_num_blocks, + q_indices=block_mask.q_indices, + full_q_num_blocks=block_mask.full_q_num_blocks, + full_q_indices=block_mask.full_q_indices, + BLOCK_SIZE=block_mask.BLOCK_SIZE, + mask_mod=mask_mod, + ) + + +def _bench_cpu( + fn: Callable[[], object], + *, + warmup: int, + repeat: int, + before_each: Callable[[], object] | None = None, +) -> tuple[object, float]: + result = None + for _ in range(warmup): + if before_each is not None: + before_each() + result = fn() + elapsed = [] + for _ in range(repeat): + if before_each is not None: + before_each() + start = time.perf_counter() + result = fn() + elapsed.append((time.perf_counter() - start) * 1000.0) + assert result is not None + return result, round(sum(elapsed) / len(elapsed), 3) + + +def _bench_cuda(fn: Callable[[], object], *, warmup: int, repeat: int) -> float: + torch.cuda.reset_peak_memory_stats() + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + stop = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(repeat): + fn() + stop.record() + torch.cuda.synchronize() + return round(float(start.elapsed_time(stop)) / repeat, 3) + + +def _plan_stats(plan: object) -> dict[str, int]: + stage_count = 0 + remote_stage_count = 0 + mask_stage_count = 0 + for rank_plan in plan: + for stage in rank_plan.stage_plans: + stage_count += 1 + remote_stage_count += int(not stage.is_local_stage) + mask_stage_count += int(stage.mask_metadata is not None) + return { + "rank_count": len(plan), + "stage_count": stage_count, + "remote_stage_count": remote_stage_count, + "mask_stage_count": mask_stage_count, + } + + +def _mask_stats(masks: Sequence[BlockMask]) -> dict[str, int]: + return { + "mask_count": len(masks), + "partial_kv_blocks": sum(_block_count(mask, "kv_num_blocks") for mask in masks), + "full_kv_blocks": sum( + _block_count(mask, "full_kv_num_blocks") for mask in masks + ), + "partial_q_blocks": sum(_block_count(mask, "q_num_blocks") for mask in masks), + "full_q_blocks": sum(_block_count(mask, "full_q_num_blocks") for mask in masks), + } + + +def _block_count(block_mask: BlockMask, name: str) -> int: + counts = getattr(block_mask, name) + return 0 if counts is None else int(counts.sum().item()) + + +def _assert_matches_torch_block_mask( + block_mask: BlockMask, + *, + slices: Sequence[object] = (), +) -> None: + q_len, k_len = block_mask.seq_lengths + reference = torch_block_mask( + _slice_mask_mod(block_mask.mask_mod, slices), + B=int(block_mask.kv_num_blocks.shape[0]), + H=1, + Q_LEN=q_len, + KV_LEN=k_len, + device="cpu", + BLOCK_SIZE=block_mask.BLOCK_SIZE, + ) + for counts_name, indices_name in ( + ("kv_num_blocks", "kv_indices"), + ("full_kv_num_blocks", "full_kv_indices"), + ("q_num_blocks", "q_indices"), + ("full_q_num_blocks", "full_q_indices"), + ): + actual = _block_entries(block_mask, counts_name, indices_name) + expected = _block_entries(reference, counts_name, indices_name) + if actual != expected: + raise AssertionError(f"{counts_name}/{indices_name} mismatch") + + +def _slice_mask_mod(mask_mod: object, slices: Sequence[object]) -> object: + if not slices: + return mask_mod + + def sliced_mask_mod( + batch_idx: torch.Tensor, + head_idx: torch.Tensor, + query_idx: torch.Tensor, + kv_idx: torch.Tensor, + ) -> torch.Tensor: + in_slice = (query_idx < 0) & (kv_idx < 0) + for slice_ in slices: + in_slice |= ( + (query_idx >= int(slice_.q_range.start)) + & (query_idx < int(slice_.q_range.end)) + & (kv_idx >= int(slice_.k_range.start)) + & (kv_idx < int(slice_.k_range.end)) + ) + return in_slice & mask_mod(batch_idx, head_idx, query_idx, kv_idx) + + return sliced_mask_mod + + +def _block_entries( + block_mask: BlockMask, + counts_name: str, + indices_name: str, +) -> set[tuple[int, int, int, int]]: + counts = getattr(block_mask, counts_name) + indices = getattr(block_mask, indices_name) + if counts is None or indices is None: + return set() + entries = set() + for batch_index in range(int(counts.shape[0])): + for head_index in range(int(counts.shape[1])): + for block_index in range(int(counts.shape[2])): + block_count = int(counts[batch_index, head_index, block_index]) + for other_block in indices[ + batch_index, + head_index, + block_index, + :block_count, + ].tolist(): + entries.add( + ( + batch_index, + head_index, + block_index, + int(other_block), + ) + ) + return entries + + +def _logical_tokens(pack: SharedPrefixPack) -> int: + return sum(int(positions.numel()) for positions in pack.positions_by_sequence) + + +def _torch_validation_skip_reason( + *, + validate_torch: bool, + packed_tokens: int, + token_cap: int, +) -> str | None: + if not validate_torch: + return "disabled" + if token_cap > 0 and packed_tokens > token_cap: + return f"packed_tokens>{token_cap}" + return None + + +def _csv_values(value: str) -> tuple[str, ...]: + values = tuple(part.strip() for part in value.split(",") if part.strip()) + if not values: + raise ValueError("CSV option must contain at least one value") + return values + + +def _write(path: Path, payload: dict[str, object]) -> None: + line = json.dumps(payload, sort_keys=True) + with path.open("a", encoding="utf-8") as output: + output.write(line + "\n") + print(line, flush=True) + + +def _check_threshold(name: str, value_ms: float, limit_ms: float | None) -> None: + if limit_ms is not None and float(value_ms) > float(limit_ms): + raise RuntimeError( + f"{name} took {float(value_ms):.3f}ms, exceeding {float(limit_ms):.3f}ms" + ) + + +if __name__ == "__main__": + typer.run(main) diff --git a/pyproject.toml b/pyproject.toml index bfa06e5d1..dd900d326 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -199,6 +199,7 @@ requires-dist = [ [tool.ty.environment] python-version = "3.12" +extra-paths = ["typings"] [tool.ty.rules] # Ignore unused-ignore-comment warnings because they vary depending on whether @@ -229,6 +230,7 @@ allowed-unresolved-imports = [ "peft.**", "pyarrow.**", "torch.**", + "torchvision.**", "torchao.**", "transformers.**", "trl.**", diff --git a/scripts/build-gpu-image.sh b/scripts/build-gpu-image.sh index 909be3afd..e64abd2ce 100755 --- a/scripts/build-gpu-image.sh +++ b/scripts/build-gpu-image.sh @@ -10,10 +10,12 @@ Options: --image-repo REPO Image repository to publish --infra INFRA Kubernetes-backed SkyPilot infra (default: k8s/cks-wb3) --no-cache Disable registry-backed BuildKit cache + --no-prewarm-modal Skip prebuilding the pushed image in Modal --no-prewarm-nodes Skip pre-pulling the pushed image on GPU nodes --pull-image-repo REPO Image repository for cluster pulls/prewarm + --prewarm-modal Require prebuilding the pushed image in Modal --prewarm-timeout DUR Timeout for the prewarm DaemonSet rollout (default: 30m) - --tag TAG Image tag to publish + --tag TAG Image tag to publish (default: latest) --help Show this help EOF } @@ -24,12 +26,13 @@ cluster_name="" infra="${SKY_INFRA:-k8s/cks-wb3}" image_repo="${ART_IMAGE_REPO:-}" pull_image_repo="${ART_PULL_IMAGE_REPO:-}" -image_tag="" +image_tag="${IMAGE_TAG:-latest}" docker_config_path="${DOCKER_CONFIG_PATH:-${HOME}/.docker/config.json}" buildkit_image="${BUILDKIT_IMAGE:-moby/buildkit:v0.29.0-rootless}" buildkit_namespace="${KUBECTL_NAMESPACE:-default}" buildkit_wait_timeout="${BUILDKIT_WAIT_TIMEOUT:-300s}" no_cache="${NO_CACHE:-false}" +prewarm_modal="${PREWARM_MODAL:-auto}" prewarm_nodes="${PREWARM_NODES:-true}" prewarm_namespace="${PREWARM_NAMESPACE:-default}" prewarm_name="${PREWARM_NAME:-art-gpu-image-prewarm}" @@ -65,6 +68,10 @@ while [[ $# -gt 0 ]]; do no_cache=true shift ;; + --no-prewarm-modal) + prewarm_modal=false + shift + ;; --no-prewarm-nodes) prewarm_nodes=false shift @@ -73,6 +80,10 @@ while [[ $# -gt 0 ]]; do pull_image_repo="$2" shift 2 ;; + --prewarm-modal) + prewarm_modal=true + shift + ;; --prewarm-timeout) prewarm_timeout="$2" shift 2 @@ -93,6 +104,14 @@ while [[ $# -gt 0 ]]; do esac done +case "${prewarm_modal}" in + auto|true|false) ;; + *) + echo "PREWARM_MODAL must be one of: auto, true, false" >&2 + exit 1 + ;; +esac + case "${infra}" in k8s/*) kube_context="${infra#k8s/}" @@ -118,10 +137,6 @@ art_sha="$(git -C "${repo_root}" rev-parse HEAD)" art_short_sha="$(git -C "${repo_root}" rev-parse --short=12 HEAD)" timestamp="$(date +%m%d-%H%M%S)" -if [[ -z "${image_tag}" ]]; then - image_tag="skypilot-${art_short_sha}" -fi - if [[ -z "${cluster_name}" ]]; then cluster_name="art-gpu-build-${timestamp}" fi @@ -437,6 +452,38 @@ if [[ -n "${prewarm_refresh_tag_image}" ]]; then esac fi +modal_auth_available=false +if [[ "${prewarm_modal}" != "false" ]]; then + if uv run --with 'modal>=1.5.0' python - <<'PY' >/dev/null 2>&1; then +import modal + +modal.Workspace.from_context().hydrate() +PY + modal_auth_available=true + fi +fi + +if [[ "${prewarm_modal}" == "true" || "${modal_auth_available}" == "true" ]]; then + echo "Prewarming ${image_repo}:${image_tag} in Modal image cache" + MODAL_FORCE_BUILD=1 uv run --with 'modal>=1.5.0' python - "${image_repo}:${image_tag}" <<'PY' +import sys + +import modal + +image = ( + modal.Image.from_registry(sys.argv[1], add_python="3.12") + .apt_install("openssh-server", "sudo", "rsync", "curl", "procps", "patch", "lsof") +) +app = modal.App.lookup("skypilot-modal", create_if_missing=True) +with modal.enable_output(): + image.build(app) +PY +elif [[ "${prewarm_modal}" == "auto" ]]; then + echo "Skipping Modal image prewarm: Modal auth unavailable" +else + echo "Skipping Modal image prewarm" +fi + dump_prewarm_diagnostics() { echo "::group::Prewarm diagnostics" "${kubectl_cmd[@]}" get daemonset -n "${prewarm_namespace}" "${prewarm_name}" -o wide || true diff --git a/src/art/megatron/context_parallel/__init__.py b/src/art/megatron/context_parallel/__init__.py index 995b0c425..fc27c486e 100644 --- a/src/art/megatron/context_parallel/__init__.py +++ b/src/art/megatron/context_parallel/__init__.py @@ -1,20 +1,16 @@ from .builder import build_dense_reference_mask, build_shared_prefix_attention_spec from .layout_index import TokenLayoutIndex -from .runtime import build_context_parallel_token_layout_index from .types import ( ArtContextParallelState, AttnMaskKind, AttnSlice, ContextParallelConfig, - ContextParallelRuntimeKey, - ContextParallelRuntimePlan, DispatchedPackedTensors, FlexMaskSpec, PackedBatchAttentionSpec, PackedRowAttentionSpec, ParallelTopology, PreparedMegatronBatch, - SharedPrefixBuilderConfig, TokenRange, ) @@ -28,13 +24,9 @@ "PackedRowAttentionSpec", "ParallelTopology", "PreparedMegatronBatch", - "SharedPrefixBuilderConfig", "ContextParallelConfig", - "ContextParallelRuntimeKey", - "ContextParallelRuntimePlan", "TokenRange", "TokenLayoutIndex", "build_dense_reference_mask", - "build_context_parallel_token_layout_index", "build_shared_prefix_attention_spec", ] diff --git a/src/art/megatron/context_parallel/block_mask.py b/src/art/megatron/context_parallel/block_mask.py index 91fe2023b..e5ec1eaac 100644 --- a/src/art/megatron/context_parallel/block_mask.py +++ b/src/art/megatron/context_parallel/block_mask.py @@ -1,32 +1,42 @@ from __future__ import annotations +from dataclasses import dataclass + import numpy as np import torch from torch.nn.attention.flex_attention import BlockMask from art.megatron.flex_attn.compiled import normalize_sparse_block_size +from art.megatron.shared_prefix_tree import parse_shared_prefix_row from .types import AttnMaskKind, FlexMaskSpec -_INVALID_Q_GROUP = -(1 << 63) -_INVALID_Q_PARENT = _INVALID_Q_GROUP + 1 -_INVALID_K_GROUP = _INVALID_Q_GROUP + 2 +_INVALID_ABS = -(1 << 63) +_INVALID_ENTER = -1 +_INVALID_EXIT = -1 + +@dataclass(frozen=True, slots=True) +class PreparedBlockMaskContext: + source_len: int + group_enter_np: np.ndarray + group_exit_np: np.ndarray -def _build_exact_mask_mod( + +def _build_interval_mask_mod( *, q_abs: np.ndarray, k_abs: np.ndarray, - q_group: np.ndarray, - q_parent: np.ndarray, - k_group: np.ndarray, + q_enter: np.ndarray, + k_enter: np.ndarray, + k_exit: np.ndarray, device: torch.device, ): q_abs_tensor = torch.as_tensor(q_abs, device=device, dtype=torch.int64) k_abs_tensor = torch.as_tensor(k_abs, device=device, dtype=torch.int64) - q_group_tensor = torch.as_tensor(q_group, device=device, dtype=torch.int64) - q_parent_tensor = torch.as_tensor(q_parent, device=device, dtype=torch.int64) - k_group_tensor = torch.as_tensor(k_group, device=device, dtype=torch.int64) + q_enter_tensor = torch.as_tensor(q_enter, device=device, dtype=torch.int64) + k_enter_tensor = torch.as_tensor(k_enter, device=device, dtype=torch.int64) + k_exit_tensor = torch.as_tensor(k_exit, device=device, dtype=torch.int64) def mask_mod( batch_idx: torch.Tensor, @@ -37,9 +47,13 @@ def mask_mod( del batch_idx, head_idx q_abs_local = q_abs_tensor[query_idx] k_abs_local = k_abs_tensor[kv_idx] - same_group = q_group_tensor[query_idx] == k_group_tensor[kv_idx] - parent_prefix = q_parent_tensor[query_idx] == k_group_tensor[kv_idx] - return (q_abs_local >= k_abs_local) & (same_group | parent_prefix) + q_enter_local = q_enter_tensor[query_idx] + k_enter_local = k_enter_tensor[kv_idx] + k_exit_local = k_exit_tensor[kv_idx] + in_key_subtree = (k_enter_local <= q_enter_local) & ( + q_enter_local < k_exit_local + ) + return (q_abs_local >= k_abs_local) & in_key_subtree return mask_mod @@ -49,10 +63,15 @@ def _dense_blocks_to_ordered( *, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: - counts = torch.from_numpy(blocks.sum(axis=-1).astype(np.int32)) - indices = torch.from_numpy( - np.argsort(-blocks.astype(np.int32), axis=-1, kind="stable").astype(np.int32) - ) + row_indices, column_indices = np.nonzero(blocks) + counts_np = np.bincount(row_indices, minlength=blocks.shape[0]).astype(np.int32) + indices_np = np.zeros(blocks.shape, dtype=np.int32) + if int(row_indices.size) > 0: + starts = np.concatenate(([0], np.cumsum(counts_np[:-1], dtype=np.int64))) + offsets = np.arange(int(row_indices.size), dtype=np.int64) - starts[row_indices] + indices_np[row_indices, offsets] = column_indices + counts = torch.from_numpy(counts_np) + indices = torch.from_numpy(indices_np) return ( counts.view(1, 1, -1).to(device=device), indices.view(1, 1, blocks.shape[0], blocks.shape[1]).to(device=device), @@ -72,72 +91,252 @@ def _select_with_invalid_np( return selected -def _build_q_block_group_state( +def _refine_interval_blocks( *, + partial_blocks: np.ndarray, + full_blocks: np.ndarray, q_abs: np.ndarray, - q_group: np.ndarray, - q_parent: np.ndarray, - q_block: int, - block_idx: int, -) -> tuple[int, dict[int, int], frozenset[int]]: - start = int(block_idx) * q_block - end = min((int(block_idx) + 1) * q_block, int(q_abs.size)) - q = q_abs[start:end] - q_group_block = q_group[start:end] - q_parent_block = q_parent[start:end] - q_min = int(q.min()) if int(q.size) else 0 - max_by_group: dict[int, int] = {} - all_groups: list[int] = [] - for group_value in np.unique(np.concatenate((q_group_block, q_parent_block))): - allowed = (q_group_block == group_value) | (q_parent_block == group_value) - if bool(allowed.any()): - max_by_group[int(group_value)] = int(q[allowed].max()) - if bool(allowed.all()): - all_groups.append(int(group_value)) - return q_min, max_by_group, frozenset(all_groups) - - -def _build_k_block_group_state( - *, k_abs: np.ndarray, - k_group: np.ndarray, + q_enter: np.ndarray, + k_enter: np.ndarray, + k_exit: np.ndarray, + q_block: int, k_block: int, - block_idx: int, -) -> tuple[int, dict[int, int], tuple[int, ...]]: - start = int(block_idx) * k_block - end = min((int(block_idx) + 1) * k_block, int(k_abs.size)) - k = k_abs[start:end] - k_group_block = k_group[start:end] - k_max = int(k.max()) if int(k.size) else 0 - min_by_group: dict[int, int] = {} - for group_value in np.unique(k_group_block): - min_by_group[int(group_value)] = int(k[k_group_block == group_value].min()) - return k_max, min_by_group, tuple(min_by_group) - - -def _exact_block_state( - *, - q_state: tuple[int, dict[int, int], frozenset[int]], - k_state: tuple[int, dict[int, int], tuple[int, ...]], -) -> tuple[bool, bool]: - q_min, q_allowed_max, q_all_allowed = q_state - k_max, k_min, k_groups = k_state - if not any( - q_allowed_max.get(k_group_value, _INVALID_Q_GROUP) >= min_k - for k_group_value, min_k in k_min.items() +) -> None: + if not bool((partial_blocks | full_blocks).any()): + return + + q_abs_blocks = _block_matrix( + q_abs, + block_size=q_block, + block_count=int(partial_blocks.shape[0]), + fill_value=_INVALID_ABS, + ) + q_enter_blocks = _block_matrix( + q_enter, + block_size=q_block, + block_count=int(partial_blocks.shape[0]), + fill_value=_INVALID_ENTER, + ) + k_abs_blocks = _block_matrix( + k_abs, + block_size=k_block, + block_count=int(partial_blocks.shape[1]), + fill_value=_INVALID_ABS, + ) + k_enter_blocks = _block_matrix( + k_enter, + block_size=k_block, + block_count=int(partial_blocks.shape[1]), + fill_value=_INVALID_ENTER, + ) + k_exit_blocks = _block_matrix( + k_exit, + block_size=k_block, + block_count=int(partial_blocks.shape[1]), + fill_value=_INVALID_EXIT, + ) + + q_valid = (q_abs_blocks >= 0) & (q_enter_blocks >= 0) + k_valid = ( + (k_abs_blocks >= 0) & (k_enter_blocks >= 0) & (k_exit_blocks > k_enter_blocks) + ) + q_all_valid = q_valid.all(axis=1) + k_all_valid = k_valid.all(axis=1) + q_min_abs = np.where(q_valid, q_abs_blocks, np.iinfo(np.int64).max).min(axis=1) + q_min_enter = np.where( + q_valid, + q_enter_blocks, + np.iinfo(np.int64).max, + ).min(axis=1) + q_max_enter = np.where(q_valid, q_enter_blocks, _INVALID_ENTER).max(axis=1) + k_max_abs = np.where(k_valid, k_abs_blocks, _INVALID_ABS).max(axis=1) + k_max_enter = np.where(k_valid, k_enter_blocks, _INVALID_ENTER).max(axis=1) + k_min_exit = np.where(k_valid, k_exit_blocks, np.iinfo(np.int64).max).min(axis=1) + safe_full = ( + q_all_valid[:, None] + & k_all_valid[None, :] + & (q_min_abs[:, None] >= k_max_abs[None, :]) + & (k_max_enter[None, :] <= q_min_enter[:, None]) + & (q_max_enter[:, None] < k_min_exit[None, :]) + ) + candidate_blocks = partial_blocks | (full_blocks & ~safe_full) + q_indices, k_indices = np.nonzero(candidate_blocks) + if int(q_indices.size) == 0: + return + + rows = np.arange(int(k_valid.shape[0])) + first_valid_offsets = k_valid.argmax(axis=1) + first_enter = k_enter_blocks[rows, first_valid_offsets] + first_exit = k_exit_blocks[rows, first_valid_offsets] + k_single_interval = k_valid.any(axis=1) & ( + (~k_valid) + | ( + (k_enter_blocks == first_enter[:, None]) + & (k_exit_blocks == first_exit[:, None]) + ) + ).all(axis=1) + + single_pair = k_single_interval[k_indices] + if bool(single_pair.any()): + single_q = q_indices[single_pair] + single_k = k_indices[single_pair] + q_abs_selected = q_abs_blocks[single_q] + q_enter_selected = q_enter_blocks[single_q] + in_subtree = ( + q_valid[single_q] + & (q_enter_selected >= first_enter[single_k, None]) + & (q_enter_selected < first_exit[single_k, None]) + ) + max_abs_in_subtree = np.where( + in_subtree, + q_abs_selected, + _INVALID_ABS, + ).max(axis=1) + k_min_abs = np.where(k_valid, k_abs_blocks, np.iinfo(np.int64).max).min(axis=1) + has_any = max_abs_in_subtree >= k_min_abs[single_k] + + is_full = ( + has_any + & q_all_valid[single_q] + & k_all_valid[single_k] + & (q_min_abs[single_q] >= k_max_abs[single_k]) + & (first_enter[single_k] <= q_min_enter[single_q]) + & (q_max_enter[single_q] < first_exit[single_k]) + ) + partial_blocks[single_q, single_k] = has_any & ~is_full + full_blocks[single_q, single_k] = is_full + + intervals_by_k: dict[int, tuple[tuple[int, int, int], ...]] = {} + + def k_intervals(k_idx: int) -> tuple[tuple[int, int, int], ...]: + cached = intervals_by_k.get(k_idx) + if cached is not None: + return cached + min_abs_by_interval: dict[tuple[int, int], int] = {} + for abs_value, enter_value, exit_value in zip( + k_abs_blocks[k_idx, k_valid[k_idx]], + k_enter_blocks[k_idx, k_valid[k_idx]], + k_exit_blocks[k_idx, k_valid[k_idx]], + strict=True, + ): + key = (int(enter_value), int(exit_value)) + prior = min_abs_by_interval.get(key) + min_abs_by_interval[key] = ( + int(abs_value) if prior is None else min(prior, int(abs_value)) + ) + cached = tuple( + (enter, exit, min_abs) + for (enter, exit), min_abs in min_abs_by_interval.items() + ) + intervals_by_k[k_idx] = cached + return cached + + for q_idx, k_idx in zip( + q_indices[~single_pair], + k_indices[~single_pair], + strict=True, ): - return False, False - if int(q_min) < int(k_max): - return True, False - return True, all(k_group_value in q_all_allowed for k_group_value in k_groups) + q_valid_row = q_valid[q_idx] + intervals = k_intervals(int(k_idx)) + has_any = False + if bool(q_valid_row.any()) and intervals: + q_abs_row = q_abs_blocks[q_idx] + q_enter_row = q_enter_blocks[q_idx] + for enter, exit, min_abs in intervals: + in_subtree = q_valid_row & (q_enter_row >= enter) & (q_enter_row < exit) + if bool(in_subtree.any()) and int(q_abs_row[in_subtree].max()) >= int( + min_abs + ): + has_any = True + break + is_full = ( + has_any + and bool(q_all_valid[q_idx]) + and bool(k_all_valid[k_idx]) + and int(q_min_abs[q_idx]) >= int(k_max_abs[k_idx]) + and int(k_max_enter[k_idx]) <= int(q_min_enter[q_idx]) + and int(q_max_enter[q_idx]) < int(k_min_exit[k_idx]) + ) + partial_blocks[q_idx, k_idx] = has_any and not is_full + full_blocks[q_idx, k_idx] = bool(is_full) + + +def _is_strictly_increasing(values: np.ndarray) -> bool: + return int(values.size) <= 1 or bool(np.all(values[1:] > values[:-1])) + + +def _block_min_max( + values: np.ndarray, + starts: np.ndarray, + ends: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + mins = np.empty(starts.shape, dtype=values.dtype) + maxes = np.empty(starts.shape, dtype=values.dtype) + for index, (start, end) in enumerate(zip(starts, ends, strict=True)): + block = values[int(start) : int(end)] + mins[index] = block.min() + maxes[index] = block.max() + return mins, maxes + + +def _block_matrix( + values: np.ndarray, + *, + block_size: int, + block_count: int, + fill_value: int, +) -> np.ndarray: + padded = np.full(block_count * block_size, fill_value, dtype=np.int64) + padded[: int(values.size)] = values + return padded.reshape(block_count, block_size) + + +def _build_group_interval_arrays( + *, + row_tree, + length: int, +) -> tuple[np.ndarray, np.ndarray]: + enter_by_group: dict[int, int] = {} + exit_by_group: dict[int, int] = {} + segment_by_group = {segment.group_id: segment for segment in row_tree.segments} + children_by_group: dict[int, list[int]] = {} + roots: list[int] = [] + for segment in row_tree.segments: + if segment.ancestors: + children_by_group.setdefault(segment.parent_id, []).append(segment.group_id) + else: + roots.append(segment.group_id) + + next_enter = 0 + + def visit(group_id: int) -> None: + nonlocal next_enter + enter_by_group[group_id] = next_enter + next_enter += 1 + children = children_by_group.get(group_id, []) + children.sort(key=lambda child: segment_by_group[child].start) + for child_group_id in children: + visit(child_group_id) + exit_by_group[group_id] = next_enter + + roots.sort(key=lambda root: segment_by_group[root].start) + for root_group_id in roots: + visit(root_group_id) + + enter_by_token = np.full((length,), _INVALID_ENTER, dtype=np.int64) + exit_by_token = np.full((length,), _INVALID_EXIT, dtype=np.int64) + for segment in row_tree.segments: + enter_by_token[segment.start : segment.end] = enter_by_group[segment.group_id] + exit_by_token[segment.start : segment.end] = exit_by_group[segment.group_id] + return enter_by_token, exit_by_token def _build_sparse_block_mask( spec: FlexMaskSpec, *, device: torch.device, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, + context: PreparedBlockMaskContext, block_size: tuple[int, int], ) -> BlockMask: q_block, k_block = block_size @@ -156,33 +355,29 @@ def _build_sparse_block_mask( ) q_abs = q_abs_tensor.numpy() k_abs = k_abs_tensor.numpy() - flat_group_ids = group_ids.detach().to(device="cpu", dtype=torch.int64).reshape(-1) - flat_parent_ids = ( - parent_ids.detach().to(device="cpu", dtype=torch.int64).reshape(-1) - ) - flat_group_ids_np = flat_group_ids.numpy() - flat_parent_ids_np = flat_parent_ids.numpy() - q_group = _select_with_invalid_np( - flat_group_ids_np, + q_abs_sorted = _is_strictly_increasing(q_abs[q_abs >= 0]) + k_abs_sorted = _is_strictly_increasing(k_abs[k_abs >= 0]) + q_enter = _select_with_invalid_np( + context.group_enter_np, q_abs, - invalid_value=_INVALID_Q_GROUP, + invalid_value=_INVALID_ENTER, ) - q_parent = _select_with_invalid_np( - flat_parent_ids_np, - q_abs, - invalid_value=_INVALID_Q_PARENT, + k_enter = _select_with_invalid_np( + context.group_enter_np, + k_abs, + invalid_value=_INVALID_ENTER, ) - k_group = _select_with_invalid_np( - flat_group_ids_np, + k_exit = _select_with_invalid_np( + context.group_exit_np, k_abs, - invalid_value=_INVALID_K_GROUP, + invalid_value=_INVALID_EXIT, ) - mask_mod = _build_exact_mask_mod( + mask_mod = _build_interval_mask_mod( q_abs=q_abs, k_abs=k_abs, - q_group=q_group, - q_parent=q_parent, - k_group=k_group, + q_enter=q_enter, + k_enter=k_enter, + k_exit=k_exit, device=device, ) if not spec.slices: @@ -208,15 +403,11 @@ def _build_sparse_block_mask( if int(q_block_indices.size) == 0 or int(k_block_indices.size) == 0: continue q_block_start = q_block_indices * q_block - q_block_end = np.minimum( - (q_block_indices + 1) * q_block, - int(spec.q_len), - ) + q_block_end_raw = (q_block_indices + 1) * q_block + q_block_end = np.minimum(q_block_end_raw, int(spec.q_len)) k_block_start = k_block_indices * k_block - k_block_end = np.minimum( - (k_block_indices + 1) * k_block, - int(spec.k_len), - ) + k_block_end_raw = (k_block_indices + 1) * k_block + k_block_end = np.minimum(k_block_end_raw, int(spec.k_len)) q_overlap_start = np.maximum( q_block_start, q_start, @@ -233,12 +424,12 @@ def _build_sparse_block_mask( k_block_end, k_end, ) - q_min = q_abs[q_overlap_start] - q_max = q_abs[q_overlap_end - 1] - k_min = k_abs[k_overlap_start] - k_max = k_abs[k_overlap_end - 1] - q_is_full = (q_overlap_start == q_block_start) & (q_overlap_end == q_block_end) - k_is_full = (k_overlap_start == k_block_start) & (k_overlap_end == k_block_end) + q_is_full = (q_overlap_start == q_block_start) & ( + q_overlap_end == q_block_end_raw + ) + k_is_full = (k_overlap_start == k_block_start) & ( + k_overlap_end == k_block_end_raw + ) covers_block = q_is_full[:, None] & k_is_full[None, :] if slice_.mask_kind == AttnMaskKind.FULL: has_any = np.ones( @@ -246,6 +437,16 @@ def _build_sparse_block_mask( ) is_full = covers_block else: + q_min, q_max = ( + (q_abs[q_overlap_start], q_abs[q_overlap_end - 1]) + if q_abs_sorted + else _block_min_max(q_abs, q_overlap_start, q_overlap_end) + ) + k_min, k_max = ( + (k_abs[k_overlap_start], k_abs[k_overlap_end - 1]) + if k_abs_sorted + else _block_min_max(k_abs, k_overlap_start, k_overlap_end) + ) has_any = q_max[:, None] >= k_min[None, :] is_full = covers_block & (q_min[:, None] >= k_max[None, :]) @@ -255,41 +456,24 @@ def _build_sparse_block_mask( partial_blocks[q_slice, k_slice] |= has_any full_blocks[q_slice, k_slice] |= is_full - ambiguous = (touch_counts > 1) & partial_blocks & ~full_blocks - q_state_cache: dict[int, tuple[int, dict[int, int], frozenset[int]]] = {} - k_state_cache: dict[int, tuple[int, dict[int, int], tuple[int, ...]]] = {} - for q_idx, k_idx in np.argwhere(ambiguous): - q_state = q_state_cache.get(int(q_idx)) - if q_state is None: - q_state = _build_q_block_group_state( - q_abs=q_abs, - q_group=q_group, - q_parent=q_parent, - q_block=q_block, - block_idx=int(q_idx), - ) - q_state_cache[int(q_idx)] = q_state - k_state = k_state_cache.get(int(k_idx)) - if k_state is None: - k_state = _build_k_block_group_state( - k_abs=k_abs, - k_group=k_group, - k_block=k_block, - block_idx=int(k_idx), - ) - k_state_cache[int(k_idx)] = k_state - has_any, is_full = _exact_block_state( - q_state=q_state, - k_state=k_state, - ) - partial_blocks[q_idx, k_idx] = False - full_blocks[q_idx, k_idx] = False - if is_full: - full_blocks[q_idx, k_idx] = True - elif has_any: - partial_blocks[q_idx, k_idx] = True - partial_blocks &= ~full_blocks + needs_refine = full_blocks | ((touch_counts > 1) & partial_blocks) + if bool(needs_refine.any()): + refined_partial = partial_blocks & needs_refine + refined_full = full_blocks & needs_refine + _refine_interval_blocks( + partial_blocks=refined_partial, + full_blocks=refined_full, + q_abs=q_abs, + k_abs=k_abs, + q_enter=q_enter, + k_enter=k_enter, + k_exit=k_exit, + q_block=q_block, + k_block=k_block, + ) + partial_blocks = (partial_blocks & ~needs_refine) | refined_partial + full_blocks = (full_blocks & ~needs_refine) | refined_full kv_num_blocks, kv_indices = _dense_blocks_to_ordered( partial_blocks, device=device, @@ -321,7 +505,44 @@ def _build_sparse_block_mask( ) -def _valid_prefix(indices: torch.Tensor, *, name: str) -> torch.Tensor: +def prepare_block_mask_context( + *, + group_ids: torch.Tensor, + parent_ids: torch.Tensor, +) -> PreparedBlockMaskContext: + if group_ids.ndim != 1 or parent_ids.ndim != 1: + raise RuntimeError( + "Shared-prefix sparse block masks require rank-1 group_ids and parent_ids." + ) + if int(group_ids.numel()) != int(parent_ids.numel()): + raise RuntimeError( + "Shared-prefix sparse block masks require equal group_ids and parent_ids lengths." + ) + flat_group_ids = group_ids.detach().to(device="cpu", dtype=torch.int64).reshape(-1) + flat_parent_ids = ( + parent_ids.detach().to(device="cpu", dtype=torch.int64).reshape(-1) + ) + row_tree = parse_shared_prefix_row( + group_ids=flat_group_ids, + parent_ids=flat_parent_ids, + ) + group_enter_np, group_exit_np = _build_group_interval_arrays( + row_tree=row_tree, + length=int(flat_group_ids.numel()), + ) + return PreparedBlockMaskContext( + source_len=int(flat_group_ids.numel()), + group_enter_np=group_enter_np, + group_exit_np=group_exit_np, + ) + + +def _validate_exact_indices( + indices: torch.Tensor, + *, + name: str, + source_len: int, +) -> int: if indices.ndim != 1: raise RuntimeError(f"{name} exact token indices must be rank 1.") if indices.dtype != torch.int64: @@ -334,52 +555,33 @@ def _valid_prefix(indices: torch.Tensor, *, name: str) -> torch.Tensor: raise RuntimeError( f"{name} exact token indices must use only contiguous tail padding." ) - return indices_cpu[:first_invalid] - return indices_cpu - - -def _validate_exact_indices( - indices: torch.Tensor, - *, - name: str, - source_len: int, -) -> int: - valid = _valid_prefix(indices, name=name) - if int(valid.numel()) == 0: + indices_cpu = indices_cpu[:first_invalid] + if int(indices_cpu.numel()) == 0: return 0 - if bool((valid[1:] <= valid[:-1]).any().item()): - raise RuntimeError(f"{name} exact token indices must be strictly increasing.") - max_index = int(valid[-1].item()) + if int(indices_cpu.unique().numel()) != int(indices_cpu.numel()): + raise RuntimeError(f"{name} exact token indices must not contain duplicates.") + max_index = int(indices_cpu.max().item()) if max_index >= int(source_len): raise RuntimeError( f"{name} exact token index {max_index} exceeds source metadata length {int(source_len)}." ) - return int(valid.numel()) + return int(indices_cpu.numel()) def _validate_supported_mask_spec( spec: FlexMaskSpec, *, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, + source_len: int, ) -> None: - if group_ids.ndim != 1 or parent_ids.ndim != 1: - raise RuntimeError( - "Shared-prefix sparse block masks require rank-1 group_ids and parent_ids." - ) - if int(group_ids.numel()) != int(parent_ids.numel()): - raise RuntimeError( - "Shared-prefix sparse block masks require equal group_ids and parent_ids lengths." - ) q_valid_len = _validate_exact_indices( spec.exact_mask.q_token_indices, name="q", - source_len=int(group_ids.numel()), + source_len=source_len, ) k_valid_len = _validate_exact_indices( spec.exact_mask.k_token_indices, name="k", - source_len=int(group_ids.numel()), + source_len=source_len, ) for slice_ in spec.slices: if int(slice_.row_index) != 0: @@ -404,12 +606,12 @@ def _validate_supported_mask_spec( ) -def build_block_mask( +def build_block_mask_from_context( spec: FlexMaskSpec, *, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, + context: PreparedBlockMaskContext, device: torch.device, + validate: bool = True, ) -> BlockMask | None: if spec.q_len <= 0 or spec.k_len <= 0: return None @@ -423,12 +625,15 @@ def build_block_mask( "Exact stage k-token metadata length mismatch: " f"{int(spec.exact_mask.k_token_indices.numel())} != {int(spec.k_len)}" ) - _validate_supported_mask_spec(spec, group_ids=group_ids, parent_ids=parent_ids) + if validate: + _validate_supported_mask_spec( + spec, + source_len=context.source_len, + ) block_size = normalize_sparse_block_size(spec.block_size) return _build_sparse_block_mask( spec, device=device, - group_ids=group_ids, - parent_ids=parent_ids, + context=context, block_size=block_size, ) diff --git a/src/art/megatron/context_parallel/builder.py b/src/art/megatron/context_parallel/builder.py index 77ac1b623..5396873ab 100644 --- a/src/art/megatron/context_parallel/builder.py +++ b/src/art/megatron/context_parallel/builder.py @@ -2,110 +2,17 @@ import torch +from art.megatron.shared_prefix_tree import parse_shared_prefix_tree + from .types import ( AttnMaskKind, AttnSlice, PackedBatchAttentionSpec, PackedRowAttentionSpec, - SharedPrefixBuilderConfig, TokenRange, ) -def _valid_length( - group_ids: torch.Tensor, - parent_ids: torch.Tensor, - *, - ignore_padding_group_id: int, -) -> int: - valid_mask = group_ids != ignore_padding_group_id - valid_count = int(valid_mask.sum().item()) - if valid_count == 0: - return 0 - if not bool(valid_mask[:valid_count].all().item()): - raise RuntimeError("Padding tokens must be a contiguous tail") - return _infer_terminal_padding_length( - group_ids[:valid_count], - parent_ids[:valid_count], - ) - - -def _infer_terminal_padding_length( - group_row: torch.Tensor, - parent_row: torch.Tensor, -) -> int: - if group_row.numel() == 0: - return 0 - runs = _scan_runs(group_row, parent_row) - if len(runs) < 2: - return int(group_row.numel()) - last_start, _last_end, last_group_id, last_parent_id = runs[-1] - if last_parent_id >= 0: - return int(group_row.numel()) - terminal_pair = (last_group_id, last_parent_id) - if any( - (group_id, parent_id) == terminal_pair - for _start, _end, group_id, parent_id in runs[:-1] - ): - return last_start - return int(group_row.numel()) - - -def _scan_runs( - group_row: torch.Tensor, - parent_row: torch.Tensor, -) -> list[tuple[int, int, int, int]]: - length = int(group_row.numel()) - if length == 0: - return [] - - group_changes = group_row[1:] != group_row[:-1] - parent_changes = parent_row[1:] != parent_row[:-1] - inconsistent_parent = torch.nonzero( - torch.logical_not(group_changes) & parent_changes, - as_tuple=False, - ).flatten() - if int(inconsistent_parent.numel()) > 0: - mismatch_index = int(inconsistent_parent[0].item()) + 1 - prior_boundaries = torch.nonzero( - group_changes[: mismatch_index - 1], - as_tuple=False, - ).flatten() - start = ( - 0 - if int(prior_boundaries.numel()) == 0 - else int(prior_boundaries[-1].item()) + 1 - ) - group_id = int(group_row[start].item()) - raise RuntimeError( - "Found one group run with inconsistent parent ids: " - f"group_id={group_id}, start={start}, end={mismatch_index}" - ) - - run_starts = torch.cat( - ( - torch.zeros(1, dtype=torch.int64, device=group_row.device), - torch.nonzero(group_changes, as_tuple=False).flatten() + 1, - ) - ) - run_ends = torch.cat( - ( - run_starts[1:], - torch.tensor([length], dtype=torch.int64, device=group_row.device), - ) - ) - starts = run_starts.to(device="cpu").tolist() - ends = run_ends.to(device="cpu").tolist() - group_ids = group_row.index_select(0, run_starts).to(device="cpu").tolist() - parent_ids = parent_row.index_select(0, run_starts).to(device="cpu").tolist() - return [ - (int(start), int(end), int(group_id), int(parent_id)) - for start, end, group_id, parent_id in zip( - starts, ends, group_ids, parent_ids, strict=True - ) - ] - - def _sort_and_dedupe_slices(slices: list[AttnSlice]) -> tuple[AttnSlice, ...]: sorted_slices = sorted( slices, @@ -138,23 +45,11 @@ def _sort_and_dedupe_slices(slices: list[AttnSlice]) -> tuple[AttnSlice, ...]: return tuple(deduped) -def _is_prompt_run( - *, - start: int, - group_id: int, - parent_id: int, - ignore_padding_group_id: int, -) -> bool: - return group_id == parent_id or ( - start == 0 and parent_id == ignore_padding_group_id - ) - - def build_shared_prefix_attention_spec( *, group_ids: torch.Tensor, parent_ids: torch.Tensor, - config: SharedPrefixBuilderConfig = SharedPrefixBuilderConfig(), + ignore_padding_group_id: int = -1, ) -> PackedBatchAttentionSpec: if group_ids.shape != parent_ids.shape: raise RuntimeError( @@ -166,127 +61,49 @@ def build_shared_prefix_attention_spec( "group_ids and parent_ids must be rank-2 packed tensors, got " f"{group_ids.ndim}" ) - if int(group_ids.shape[0]) != 1: - raise RuntimeError( - "ART shared-prefix attention spec currently supports exactly one packed sequence, " - f"got batch={int(group_ids.shape[0])}." - ) - rows: list[PackedRowAttentionSpec] = [] - for row_index in range(group_ids.shape[0]): - group_row = group_ids[row_index] - parent_row = parent_ids[row_index] - valid_tokens = _valid_length( - group_row, - parent_row, - ignore_padding_group_id=config.ignore_padding_group_id, - ) - if valid_tokens == 0: + for row in parse_shared_prefix_tree( + group_ids=group_ids, + parent_ids=parent_ids, + ignore_padding_group_id=ignore_padding_group_id, + ): + if row.valid_tokens == 0: rows.append( - PackedRowAttentionSpec(row_index=row_index, valid_tokens=0, slices=()) + PackedRowAttentionSpec( + row_index=row.row_index, valid_tokens=0, slices=() + ) ) continue - group_row = group_row[:valid_tokens] - parent_row = parent_row[:valid_tokens] - runs = _scan_runs(group_row, parent_row) - - group_run_count: dict[int, int] = {} - prompt_by_group_id: dict[int, tuple[tuple[int, int], int]] = {} - completion_ranges_by_prompt: dict[int, list[tuple[int, int]]] = {} - - for start, end, group_id, parent_id in runs: - group_run_count[group_id] = group_run_count.get(group_id, 0) + 1 - if _is_prompt_run( - start=start, - group_id=group_id, - parent_id=parent_id, - ignore_padding_group_id=config.ignore_padding_group_id, - ): - if group_id in prompt_by_group_id: - raise RuntimeError( - f"Prompt group_id {group_id} appears more than once in row {row_index}" - ) - family_index = len(prompt_by_group_id) - prompt_by_group_id[group_id] = ( - (start, end), - family_index, - ) - completion_ranges_by_prompt[group_id] = [] - - if config.require_contiguous_group_runs: - repeated_groups = { - group_id: count - for group_id, count in group_run_count.items() - if count > 1 and group_id != config.ignore_padding_group_id - } - if repeated_groups: - raise RuntimeError( - "Shared-prefix builder requires contiguous group runs per row, " - f"found repeats in row {row_index}: {repeated_groups}" - ) - - for start, end, group_id, parent_id in runs: - if _is_prompt_run( - start=start, - group_id=group_id, - parent_id=parent_id, - ignore_padding_group_id=config.ignore_padding_group_id, - ): - continue - prompt_entry = prompt_by_group_id.get(parent_id) - if prompt_entry is None: - raise RuntimeError( - "Completion run points to a missing prompt run: " - f"row={row_index}, group_id={group_id}, parent_id={parent_id}" - ) - completion_ranges_by_prompt[parent_id].append((start, end)) - + segment_by_group_id = {segment.group_id: segment for segment in row.segments} row_slices: list[AttnSlice] = [] - for prompt_group_id, ( - (prompt_start, prompt_end), - family_index, - ) in prompt_by_group_id.items(): - prompt_range = TokenRange(start=prompt_start, end=prompt_end) - row_slices.append( - AttnSlice( - q_range=prompt_range, - k_range=prompt_range, - mask_kind=AttnMaskKind.CAUSAL, - row_index=row_index, - family_index=family_index, - ) - ) - for completion_start, completion_end in completion_ranges_by_prompt[ - prompt_group_id - ]: - completion_range = TokenRange( - start=completion_start, - end=completion_end, - ) + for segment in row.segments: + q_range = TokenRange(start=segment.start, end=segment.end) + for ancestor_group_id in segment.ancestors: + ancestor = segment_by_group_id[ancestor_group_id] row_slices.append( AttnSlice( - q_range=completion_range, - k_range=prompt_range, + q_range=q_range, + k_range=TokenRange(start=ancestor.start, end=ancestor.end), mask_kind=AttnMaskKind.FULL, - row_index=row_index, - family_index=family_index, + row_index=row.row_index, + family_index=segment.family_index, ) ) - row_slices.append( - AttnSlice( - q_range=completion_range, - k_range=completion_range, - mask_kind=AttnMaskKind.CAUSAL, - row_index=row_index, - family_index=family_index, - ) + row_slices.append( + AttnSlice( + q_range=q_range, + k_range=q_range, + mask_kind=AttnMaskKind.CAUSAL, + row_index=row.row_index, + family_index=segment.family_index, ) + ) rows.append( PackedRowAttentionSpec( - row_index=row_index, - valid_tokens=valid_tokens, + row_index=row.row_index, + valid_tokens=row.valid_tokens, slices=_sort_and_dedupe_slices(row_slices), ) ) diff --git a/src/art/megatron/context_parallel/comm.py b/src/art/megatron/context_parallel/comm.py index c1767a4dc..8ea97067d 100644 --- a/src/art/megatron/context_parallel/comm.py +++ b/src/art/megatron/context_parallel/comm.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass from typing import Any, Protocol, cast @@ -141,37 +142,27 @@ def wait_post_process(self) -> tuple[torch.Tensor, torch.Tensor]: if range_.size() > 0 ) - def _apply_reduce() -> None: - dk_reduce = ( - dk_remote - if dk_remote.dtype == self.dk_local.dtype - else dk_remote.to(dtype=self.dk_local.dtype) - ) - dv_reduce = ( - dv_remote - if dv_remote.dtype == self.dv_local.dtype - else dv_remote.to(dtype=self.dv_local.dtype) - ) - reduce_fn = ( - range_reduce_sum_head_major_ - if self.input_layout == "head_major" - else range_reduce_sum_ - ) - reduce_fn( - dk_reduce, - output_tensor=self.dk_local, - ranges=flattened_ranges, - range_meta_cache=self.range_meta_cache, - ) - reduce_fn( - dv_reduce, - output_tensor=self.dv_local, - ranges=flattened_ranges, - range_meta_cache=self.range_meta_cache, - ) - return - - _apply_reduce() + reduce_fn = ( + range_reduce_sum_head_major_ + if self.input_layout == "head_major" + else range_reduce_sum_ + ) + reduce_fn( + dk_remote + if dk_remote.dtype == self.dk_local.dtype + else dk_remote.to(dtype=self.dk_local.dtype), + output_tensor=self.dk_local, + ranges=flattened_ranges, + range_meta_cache=self.range_meta_cache, + ) + reduce_fn( + dv_remote + if dv_remote.dtype == self.dv_local.dtype + else dv_remote.to(dtype=self.dv_local.dtype), + output_tensor=self.dv_local, + ranges=flattened_ranges, + range_meta_cache=self.range_meta_cache, + ) return self.dk_local, self.dv_local @@ -191,6 +182,60 @@ def _get_stream(self, tensor: torch.Tensor) -> torch.cuda.Stream | None: self._streams[device_index] = stream return stream + def _launch_exchange( + self, + *, + tensor: torch.Tensor, + recv_buffer: torch.Tensor, + total_send_rows: int, + make_send_buffer: Callable[[], torch.Tensor], + output_split_sizes: list[int], + input_split_sizes: list[int], + group: Any, + async_op: bool, + input_layout: str, + ) -> tuple[_Waitable | None, torch.Tensor, torch.cuda.Stream | None]: + stream = self._get_stream(tensor) if async_op else None + send_buffer = ( + tensor.new_empty( + _packed_peer_tensor_shape( + tensor=tensor, + total_rows=0, + input_layout=input_layout, + ) + ) + if total_send_rows <= 0 + else make_send_buffer() + ) + if stream is None: + return ( + _launch_peer_exchange( + recv_buffer=recv_buffer, + send_buffer=send_buffer, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ), + send_buffer, + None, + ) + + current_stream = torch.cuda.current_stream(tensor.device) + stream.wait_stream(current_stream) + send_buffer.record_stream(stream) + recv_buffer.record_stream(stream) + with torch.cuda.stream(stream): + handle = _launch_peer_exchange( + recv_buffer=recv_buffer, + send_buffer=send_buffer, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=True, + ) + return handle, send_buffer, stream + def launch_kv_fetch( self, *, @@ -230,70 +275,23 @@ def launch_kv_fetch( ) input_split_sizes = [split * 2 for split in plan.send_splits] output_split_sizes = [split * 2 for split in plan.recv_splits] - stream = self._get_stream(k_local) if async_op else None - if stream is not None: - current_stream = torch.cuda.current_stream(k_local.device) - if total_send_rows <= 0: - send_buffer = k_local.new_empty( - _packed_peer_tensor_shape( - tensor=k_local, - total_rows=0, - input_layout=input_layout, - ) - ) - else: - send_buffer = _pack_gathered_tensors_per_peer( - left_tensor=k_local, - right_tensor=v_local, - ranges_by_peer=plan.send_ranges_by_peer, - range_meta_cache=range_meta_cache, - input_layout=input_layout, - ) - stream.wait_stream(current_stream) - send_buffer.record_stream(stream) - recv_packed.record_stream(stream) - with torch.cuda.stream(stream): - handle = _launch_peer_exchange( - recv_buffer=recv_packed, - send_buffer=send_buffer, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=True, - ) - else: - if total_send_rows <= 0: - send_buffer = k_local.new_empty( - _packed_peer_tensor_shape( - tensor=k_local, - total_rows=0, - input_layout=input_layout, - ) - ) - handle = _launch_peer_exchange( - recv_buffer=recv_packed, - send_buffer=send_buffer, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) - else: - send_buffer = _pack_gathered_tensors_per_peer( - left_tensor=k_local, - right_tensor=v_local, - ranges_by_peer=plan.send_ranges_by_peer, - range_meta_cache=range_meta_cache, - input_layout=input_layout, - ) - handle = _launch_peer_exchange( - recv_buffer=recv_packed, - send_buffer=send_buffer, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) + handle, send_buffer, stream = self._launch_exchange( + tensor=k_local, + recv_buffer=recv_packed, + total_send_rows=total_send_rows, + make_send_buffer=lambda: _pack_gathered_tensors_per_peer( + left_tensor=k_local, + right_tensor=v_local, + ranges_by_peer=plan.send_ranges_by_peer, + range_meta_cache=range_meta_cache, + input_layout=input_layout, + ), + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + input_layout=input_layout, + ) return KvFetchWork( packed_buffer=recv_packed, recv_splits=plan.recv_splits, @@ -333,89 +331,33 @@ def launch_dkv_reduce( total_send_rows = int(sum(plan.send_splits)) recv_total = int(sum(plan.recv_splits)) - recv_packed = ( - dk_remote.new_empty( - _packed_peer_tensor_shape( - tensor=dk_remote, - total_rows=recv_total, - input_layout=input_layout, - ) - ) - if recv_total > 0 - else dk_remote.new_empty( - _packed_peer_tensor_shape( - tensor=dk_remote, - total_rows=0, - input_layout=input_layout, - ) + recv_packed = dk_remote.new_empty( + _packed_peer_tensor_shape( + tensor=dk_remote, + total_rows=recv_total, + input_layout=input_layout, ) ) input_split_sizes = [split * 2 for split in plan.send_splits] output_split_sizes = [split * 2 for split in plan.recv_splits] - stream = self._get_stream(dk_remote) if async_op else None - if stream is not None: - current_stream = torch.cuda.current_stream(dk_remote.device) - if total_send_rows <= 0: - send_buffer = dk_remote.new_empty( - _packed_peer_tensor_shape( - tensor=dk_remote, - total_rows=0, - input_layout=input_layout, - ) - ) - else: - send_buffer = _pack_split_tensors_by_peer( - left_tensor=dk_remote, - right_tensor=dv_remote, - splits=plan.send_splits, - input_layout=input_layout, - ) - stream.wait_stream(current_stream) - send_buffer.record_stream(stream) - recv_packed.record_stream(stream) - with torch.cuda.stream(stream): - handle = _launch_peer_exchange( - recv_buffer=recv_packed, - send_buffer=send_buffer, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=True, - ) - else: - if total_send_rows <= 0: - send_buffer = dk_remote.new_empty( - _packed_peer_tensor_shape( - tensor=dk_remote, - total_rows=0, - input_layout=input_layout, - ) - ) - handle = _launch_peer_exchange( - recv_buffer=recv_packed, - send_buffer=send_buffer, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) - else: - send_buffer = _pack_split_tensors_by_peer( - left_tensor=dk_remote, - right_tensor=dv_remote, - splits=plan.send_splits, - input_layout=input_layout, - ) - handle = _launch_peer_exchange( - recv_buffer=recv_packed, - send_buffer=send_buffer, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) + handle, send_buffer, stream = self._launch_exchange( + tensor=dk_remote, + recv_buffer=recv_packed, + total_send_rows=total_send_rows, + make_send_buffer=lambda: _pack_split_tensors_by_peer( + left_tensor=dk_remote, + right_tensor=dv_remote, + splits=plan.send_splits, + input_layout=input_layout, + ), + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + input_layout=input_layout, + ) return DkvReduceWork( - packed_buffer=recv_packed if recv_total > 0 else None, + packed_buffer=recv_packed, handle=handle, send_buffer=send_buffer, stream=stream, @@ -449,29 +391,6 @@ def range_gather_per_peer( return torch.cat(chunks, dim=0).contiguous() -def _split_tensor_to_peer( - input_tensor: torch.Tensor, - splits: tuple[int, ...], -) -> torch.Tensor: - if int(sum(splits)) == 0: - return input_tensor.new_empty((0, *input_tensor.shape[1:])) - if int(input_tensor.shape[0]) == int(sum(splits)): - return input_tensor.contiguous() - if len([split for split in splits if split > 0]) > 1: - raise RuntimeError( - f"Expected at most one non-zero send split for dKV reduce, got {splits}" - ) - pieces: list[torch.Tensor] = [] - cursor = 0 - for split in splits: - if split == 0: - pieces.append(input_tensor.new_empty((0, *input_tensor.shape[1:]))) - continue - pieces.append(input_tensor[cursor : cursor + split]) - cursor += split - return torch.cat(pieces, dim=0).contiguous() - - def _pack_gathered_tensors_per_peer( *, left_tensor: torch.Tensor, @@ -480,56 +399,16 @@ def _pack_gathered_tensors_per_peer( range_meta_cache: dict[Any, Any] | None = None, input_layout: str = "token_major", ) -> torch.Tensor: - if input_layout == "head_major": - return _pack_gathered_tensors_per_peer_head_major( - left_tensor=left_tensor, - right_tensor=right_tensor, - ranges_by_peer=ranges_by_peer, - range_meta_cache=range_meta_cache, - ) - if input_layout != "token_major": - raise ValueError(f"Unsupported gathered-pack input layout: {input_layout}") - total_rows = sum( - range_.size() for peer_ranges in ranges_by_peer for range_ in peer_ranges - ) - if total_rows == 0: - return left_tensor.new_empty((0, *left_tensor.shape[1:])) - packed = left_tensor.new_empty((total_rows * 2, *left_tensor.shape[1:])) - cursor = 0 - for peer_ranges in ranges_by_peer: - split = sum(range_.size() for range_ in peer_ranges) - if split <= 0: - continue - range_gather( - left_tensor, - peer_ranges, - output=packed[cursor : cursor + split], - range_meta_cache=range_meta_cache, - ) - range_gather( - right_tensor, - peer_ranges, - output=packed[cursor + split : cursor + split * 2], - range_meta_cache=range_meta_cache, - ) - cursor += split * 2 - return packed - - -def _pack_gathered_tensors_per_peer_head_major( - *, - left_tensor: torch.Tensor, - right_tensor: torch.Tensor, - ranges_by_peer: tuple[tuple[TokenRange, ...], ...], - range_meta_cache: dict[Any, Any] | None = None, -) -> torch.Tensor: + _validate_peer_layout(input_layout, context="gathered-pack input") total_rows = sum( range_.size() for peer_ranges in ranges_by_peer for range_ in peer_ranges ) - if total_rows == 0: - return left_tensor.new_empty((0, left_tensor.shape[0], left_tensor.shape[2])) packed = left_tensor.new_empty( - (total_rows * 2, left_tensor.shape[0], left_tensor.shape[2]) + _packed_peer_tensor_shape( + tensor=left_tensor, + total_rows=total_rows, + input_layout=input_layout, + ) ) cursor = 0 for peer_ranges in ranges_by_peer: @@ -537,18 +416,20 @@ def _pack_gathered_tensors_per_peer_head_major( if split <= 0: continue packed[cursor : cursor + split].copy_( - range_gather_head_major( + _gather_peer_rows( left_tensor, peer_ranges, + input_layout=input_layout, range_meta_cache=range_meta_cache, - ).permute(1, 0, 2) + ) ) packed[cursor + split : cursor + split * 2].copy_( - range_gather_head_major( + _gather_peer_rows( right_tensor, peer_ranges, + input_layout=input_layout, range_meta_cache=range_meta_cache, - ).permute(1, 0, 2) + ) ) cursor += split * 2 return packed @@ -561,79 +442,83 @@ def _pack_split_tensors_by_peer( splits: tuple[int, ...], input_layout: str = "token_major", ) -> torch.Tensor: - if input_layout == "head_major": - return _pack_split_tensors_by_peer_head_major( - left_tensor=left_tensor, - right_tensor=right_tensor, - splits=splits, - ) - if input_layout != "token_major": - raise ValueError(f"Unsupported split-pack input layout: {input_layout}") + _validate_peer_layout(input_layout, context="split-pack input") total_rows = int(sum(splits)) - if total_rows == 0: - return left_tensor.new_empty((0, *left_tensor.shape[1:])) - packed = left_tensor.new_empty((total_rows * 2, *left_tensor.shape[1:])) + packed = left_tensor.new_empty( + _packed_peer_tensor_shape( + tensor=left_tensor, + total_rows=total_rows, + input_layout=input_layout, + ) + ) cursor = 0 for split in splits: if split <= 0: continue packed[cursor * 2 : cursor * 2 + split].copy_( - left_tensor[cursor : cursor + split] + _slice_peer_rows(left_tensor, cursor, cursor + split, layout=input_layout) ) packed[cursor * 2 + split : cursor * 2 + split * 2].copy_( - right_tensor[cursor : cursor + split] + _slice_peer_rows(right_tensor, cursor, cursor + split, layout=input_layout) ) cursor += split - if cursor != int(left_tensor.shape[0]) or cursor != int(right_tensor.shape[0]): + left_rows = _peer_row_count(left_tensor, layout=input_layout) + right_rows = _peer_row_count(right_tensor, layout=input_layout) + if cursor != left_rows or cursor != right_rows: raise RuntimeError( "Packed split consumed the wrong number of rows: " - f"consumed={cursor}, left={int(left_tensor.shape[0])}, right={int(right_tensor.shape[0])}" + f"consumed={cursor}, left={left_rows}, right={right_rows}" ) return packed +def _validate_peer_layout(layout: str, *, context: str) -> None: + if layout not in {"token_major", "head_major"}: + raise ValueError(f"Unsupported {context} layout: {layout}") + + def _packed_peer_tensor_shape( *, tensor: torch.Tensor, total_rows: int, input_layout: str, ) -> tuple[int, ...]: + _validate_peer_layout(input_layout, context="peer tensor input") if input_layout == "head_major": return (total_rows * 2, int(tensor.shape[0]), int(tensor.shape[2])) - if input_layout != "token_major": - raise ValueError(f"Unsupported split-pack input layout: {input_layout}") return (total_rows * 2, *tuple(int(dim) for dim in tensor.shape[1:])) -def _pack_split_tensors_by_peer_head_major( +def _peer_row_count(tensor: torch.Tensor, *, layout: str) -> int: + return int(tensor.shape[1] if layout == "head_major" else tensor.shape[0]) + + +def _slice_peer_rows( + tensor: torch.Tensor, + start: int, + end: int, *, - left_tensor: torch.Tensor, - right_tensor: torch.Tensor, - splits: tuple[int, ...], + layout: str, ) -> torch.Tensor: - total_rows = int(sum(splits)) - if total_rows == 0: - return left_tensor.new_empty((0, left_tensor.shape[0], left_tensor.shape[2])) - packed = left_tensor.new_empty( - (total_rows * 2, left_tensor.shape[0], left_tensor.shape[2]) - ) - cursor = 0 - for split in splits: - if split <= 0: - continue - packed[cursor * 2 : cursor * 2 + split].copy_( - left_tensor[:, cursor : cursor + split].permute(1, 0, 2) - ) - packed[cursor * 2 + split : cursor * 2 + split * 2].copy_( - right_tensor[:, cursor : cursor + split].permute(1, 0, 2) - ) - cursor += split - if cursor != int(left_tensor.shape[1]) or cursor != int(right_tensor.shape[1]): - raise RuntimeError( - "Head-major split pack consumed the wrong number of rows: " - f"consumed={cursor}, left={int(left_tensor.shape[1])}, right={int(right_tensor.shape[1])}" - ) - return packed + if layout == "head_major": + return tensor[:, start:end].movedim(1, 0) + return tensor[start:end] + + +def _gather_peer_rows( + tensor: torch.Tensor, + ranges: tuple[TokenRange, ...], + *, + input_layout: str, + range_meta_cache: dict[Any, Any] | None, +) -> torch.Tensor: + if input_layout == "head_major": + return range_gather_head_major( + tensor, + ranges, + range_meta_cache=range_meta_cache, + ).movedim(1, 0) + return range_gather(tensor, ranges, range_meta_cache=range_meta_cache) def _unpack_packed_tensor_per_peer( @@ -642,15 +527,13 @@ def _unpack_packed_tensor_per_peer( *, output_layout: str = "token_major", ) -> tuple[torch.Tensor, torch.Tensor]: - if output_layout == "head_major": - return _unpack_packed_tensor_per_peer_head_major( + _validate_peer_layout(output_layout, context="packed-tensor output") + if int(packed_tensor.shape[0]) == 0: + empty = _new_unpacked_peer_tensor( packed_tensor, - splits, + total_rows=0, + output_layout=output_layout, ) - if output_layout != "token_major": - raise ValueError(f"Unsupported packed-tensor output layout: {output_layout}") - if int(packed_tensor.shape[0]) == 0: - empty = packed_tensor.new_empty((0, *packed_tensor.shape[1:])) return empty, empty total_rows = 0 cursor = 0 @@ -664,62 +547,59 @@ def _unpack_packed_tensor_per_peer( "Packed tensor unpack consumed the wrong number of rows: " f"consumed={cursor}, input={int(packed_tensor.shape[0])}" ) - left = packed_tensor.new_empty((total_rows, *packed_tensor.shape[1:])) - right = packed_tensor.new_empty((total_rows, *packed_tensor.shape[1:])) + left = _new_unpacked_peer_tensor( + packed_tensor, + total_rows=total_rows, + output_layout=output_layout, + ) + right = _new_unpacked_peer_tensor( + packed_tensor, + total_rows=total_rows, + output_layout=output_layout, + ) in_cursor = 0 out_cursor = 0 for split in splits: if split <= 0: continue - left[out_cursor : out_cursor + split].copy_( - packed_tensor[in_cursor : in_cursor + split] + _copy_from_peer_rows( + left, + out_cursor, + packed_tensor[in_cursor : in_cursor + split], + output_layout=output_layout, ) - right[out_cursor : out_cursor + split].copy_( - packed_tensor[in_cursor + split : in_cursor + split * 2] + _copy_from_peer_rows( + right, + out_cursor, + packed_tensor[in_cursor + split : in_cursor + split * 2], + output_layout=output_layout, ) in_cursor += split * 2 out_cursor += split return left, right -def _unpack_packed_tensor_per_peer_head_major( +def _new_unpacked_peer_tensor( packed_tensor: torch.Tensor, - splits: tuple[int, ...], -) -> tuple[torch.Tensor, torch.Tensor]: - if int(packed_tensor.shape[0]) == 0: - empty = packed_tensor.new_empty( - (packed_tensor.shape[1], 0, packed_tensor.shape[2]) - ) - return empty, empty - total_rows = 0 - cursor = 0 - for split in splits: - if split <= 0: - continue - cursor += split * 2 - total_rows += split - if cursor != int(packed_tensor.shape[0]): - raise RuntimeError( - "Packed tensor unpack consumed the wrong number of rows: " - f"consumed={cursor}, input={int(packed_tensor.shape[0])}" - ) - left = packed_tensor.new_empty( - (packed_tensor.shape[1], total_rows, packed_tensor.shape[2]) - ) - right = packed_tensor.new_empty( - (packed_tensor.shape[1], total_rows, packed_tensor.shape[2]) - ) - in_cursor = 0 - out_cursor = 0 - for split in splits: - if split <= 0: - continue - left[:, out_cursor : out_cursor + split].copy_( - packed_tensor[in_cursor : in_cursor + split].permute(1, 0, 2) - ) - right[:, out_cursor : out_cursor + split].copy_( - packed_tensor[in_cursor + split : in_cursor + split * 2].permute(1, 0, 2) + *, + total_rows: int, + output_layout: str, +) -> torch.Tensor: + if output_layout == "head_major": + return packed_tensor.new_empty( + (packed_tensor.shape[1], total_rows, *packed_tensor.shape[2:]) ) - in_cursor += split * 2 - out_cursor += split - return left, right + return packed_tensor.new_empty((total_rows, *packed_tensor.shape[1:])) + + +def _copy_from_peer_rows( + output: torch.Tensor, + start: int, + rows: torch.Tensor, + *, + output_layout: str, +) -> None: + if output_layout == "head_major": + output[:, start : start + int(rows.shape[0])].copy_(rows.movedim(0, 1)) + else: + output[start : start + int(rows.shape[0])].copy_(rows) diff --git a/src/art/megatron/context_parallel/executor.py b/src/art/megatron/context_parallel/executor.py index e5e219e72..5beaec9f4 100644 --- a/src/art/megatron/context_parallel/executor.py +++ b/src/art/megatron/context_parallel/executor.py @@ -19,7 +19,7 @@ sparse_compiled_flex_attention, ) -from .block_mask import build_block_mask +from .block_mask import build_block_mask_from_context, prepare_block_mask_context from .comm import A2AVCommunicator from .range_ops import ( range_gather_head_major, @@ -684,17 +684,24 @@ def _build_stage_block_mask( raise RuntimeError( f"Stage {stage_plan.stage_index} is missing exact mask metadata" ) - mask = build_block_mask( + block_mask_context = state.execution_cache.block_mask_context + if block_mask_context is None: + block_mask_context = prepare_block_mask_context( + group_ids=state.group_ids, + parent_ids=state.parent_ids, + ) + state.execution_cache.block_mask_context = block_mask_context + mask = build_block_mask_from_context( FlexMaskSpec( q_len=int(execution_spec.q_len), k_len=int(execution_spec.k_len), block_size=resolved_block_size, slices=stage_plan.slices, - exact_mask=mask_metadata.model_dump(mode="python"), + exact_mask=mask_metadata, ), - group_ids=state.group_ids, - parent_ids=state.parent_ids, + context=block_mask_context, device=device, + validate=False, ) cache[cache_key] = mask return mask @@ -774,30 +781,6 @@ def prepare_context_parallel_execution_state( ) -def _causal_slice_pair_count(slice_: AttnSlice) -> int: - q_start = int(slice_.q_range.start) - q_end = int(slice_.q_range.end) - k_start = int(slice_.k_range.start) - k_end = int(slice_.k_range.end) - if q_end <= q_start or k_end <= k_start: - return 0 - - k_len = k_end - k_start - partial_q_start = max(q_start, k_start) - partial_q_end = min(q_end - 1, k_end - 2) - partial = 0 - if partial_q_start <= partial_q_end: - count = partial_q_end - partial_q_start + 1 - partial = count * (partial_q_start + partial_q_end + 2 - 2 * k_start) // 2 - - full_q_start = max(q_start, k_end - 1) - full_q_end = q_end - 1 - full = 0 - if full_q_start <= full_q_end: - full = (full_q_end - full_q_start + 1) * k_len - return int(partial + full) - - def _validate_stage_block_alignment( *, q_len: int, diff --git a/src/art/megatron/context_parallel/layout_index.py b/src/art/megatron/context_parallel/layout_index.py index 99fb2c35b..9f60550a0 100644 --- a/src/art/megatron/context_parallel/layout_index.py +++ b/src/art/megatron/context_parallel/layout_index.py @@ -1,10 +1,9 @@ from __future__ import annotations -from pydantic import BaseModel, ConfigDict +from dataclasses import dataclass -class TokenLayoutIndex(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class TokenLayoutIndex: ownership_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...] token_counts_by_rank: tuple[int, ...] diff --git a/src/art/megatron/context_parallel/runtime.py b/src/art/megatron/context_parallel/runtime.py index c6eb9fddd..b89c42fd2 100644 --- a/src/art/megatron/context_parallel/runtime.py +++ b/src/art/megatron/context_parallel/runtime.py @@ -1,12 +1,11 @@ from __future__ import annotations from bisect import bisect_left, bisect_right +from dataclasses import dataclass, replace import hashlib import json from typing import Any, cast -import warnings -from pydantic import BaseModel, ConfigDict import torch from art.loss import shift_tensor @@ -19,8 +18,6 @@ AttnMaskKind, AttnSlice, ContextParallelConfig, - ContextParallelRuntimeKey, - ContextParallelRuntimePlan, DispatchedPackedTensors, DkvReducePlan, ExactMaskMetadata, @@ -28,17 +25,12 @@ PackedBatchAttentionSpec, PackedRowAttentionSpec, ParallelTopology, - PlannerProvenance, PreparedMegatronBatch, RankRuntimePlan, StagePlan, TokenRange, ) -_PLANNER_RUNTIME_BACKEND = "art_context_parallel" -_PLANNER_BEST_EFFORT_WARNING_KEYS: set[ - tuple[str, str, int, str, str, tuple[int, ...]] -] = set() _CHUNK_MASK_STATS_TORCH_THRESHOLD = 1024 _CP4_SEARCH_PROBE_CANDIDATE_LIMIT = 2 _CP4_SEARCH_PROBE_IMPROVEMENT_MS = 1.0 @@ -48,17 +40,15 @@ StageSliceKey = tuple[int, int, int, int, int, str, int] -class _PlanningBundle(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - +@dataclass(frozen=True) +class _PlanningBundle: spec: PackedBatchAttentionSpec - runtime_key: ContextParallelRuntimeKey - runtime_plan: ContextParallelRuntimePlan + rank_plans: tuple[RankRuntimePlan, ...] gdn_execution_spec: Any | None = None _PLANNING_BUNDLE_CACHE: dict[str, _PlanningBundle] = {} -_RUNTIME_PLAN_CACHE: dict[tuple[str, int], ContextParallelRuntimePlan] = {} +_RUNTIME_PLAN_CACHE: dict[str, tuple[RankRuntimePlan, ...]] = {} _GDN_RANK_PLAN_CACHE: dict[tuple[str, str, int | None, int], Any] = {} @@ -111,174 +101,14 @@ def _planning_bundle_cache_key( { "group_ids": _metadata_tensor_digest(group_ids), "parent_ids": _metadata_tensor_digest(parent_ids), - "topology": topology.model_dump(mode="json"), - "config": config.model_dump(mode="json"), + "topology": _dataclass_payload(topology), + "config": _dataclass_payload(config), "original_seq_len": int(original_seq_len), "build_gdn_execution_spec": bool(build_gdn_execution_spec), } ) -def _rank_plan_cache_key( - *, - planning_key: str, - device: torch.device, - cp_rank: int, -) -> tuple[str, str, int | None, int]: - return (planning_key, device.type, device.index, int(cp_rank)) - - -def _config_for_runtime_cp( - *, - topology: ParallelTopology, - config: ContextParallelConfig, -) -> ContextParallelConfig: - cp_size = max(int(topology.cp), 1) - updates: dict[str, Any] = {} - applied_override = False - for override in config.planner_cp_overrides: - if int(override.cp_size) != cp_size: - continue - override_updates = override.model_dump(mode="python", exclude_none=True) - override_updates.pop("cp_size", None) - updates.update(override_updates) - applied_override = True - if not applied_override: - return config - updates.setdefault("planner_tuned_cp_sizes", (cp_size,)) - return config.model_copy(update=updates) - - -def _normalized_planner_metadata_value(value: str | None) -> str: - if value is None: - return "" - normalized = "".join( - character.lower() if character.isalnum() else " " - for character in str(value).strip() - ) - return " ".join(part for part in normalized.split() if part) - - -def _planner_metadata_matches( - expected: str | None, - actual: str | None, - *, - fuzzy: bool, -) -> bool: - normalized_expected = _normalized_planner_metadata_value(expected) - normalized_actual = _normalized_planner_metadata_value(actual) - if not normalized_expected or not normalized_actual: - return False - if normalized_expected == normalized_actual: - return True - return bool( - fuzzy - and ( - normalized_expected in normalized_actual - or normalized_actual in normalized_expected - ) - ) - - -def _planner_runtime_hardware() -> str | None: - if not torch.cuda.is_available(): - return None - try: - return str(torch.cuda.get_device_name(torch.cuda.current_device())) - except Exception: - return str(torch.cuda.get_device_name(0)) - - -def _planner_best_effort_warning_message(provenance: PlannerProvenance) -> str: - mismatch_reasons: list[str] = [] - if not provenance.backend_match: - mismatch_reasons.append( - f"backend runtime={provenance.runtime_backend!r} tuned={provenance.tuned_backend!r}" - ) - if not provenance.hardware_match: - mismatch_reasons.append( - f"hardware runtime={provenance.runtime_hardware!r} tuned={provenance.tuned_hardware!r}" - ) - if not provenance.cp_size_match: - mismatch_reasons.append( - f"cp_size runtime={int(provenance.runtime_cp_size)} tuned={list(provenance.tuned_cp_sizes)}" - ) - mismatch_text = ( - "; ".join(mismatch_reasons) if mismatch_reasons else "metadata missing" - ) - return ( - "ART context parallel planner coefficients are running in best-effort mode; " - f"{mismatch_text}. The runtime will continue with the configured coefficients." - ) - - -def _planner_provenance( - *, - topology: ParallelTopology, - config: ContextParallelConfig, - warn: bool = True, -) -> PlannerProvenance: - runtime_hardware = _planner_runtime_hardware() - tuned_cp_sizes = tuple( - sorted( - { - int(cp_size) - for cp_size in config.planner_tuned_cp_sizes - if int(cp_size) > 0 - } - ) - ) - provenance = PlannerProvenance( - runtime_backend=_PLANNER_RUNTIME_BACKEND, - runtime_hardware=runtime_hardware, - runtime_cp_size=max(int(topology.cp), 1), - tuned_backend=config.planner_tuned_backend, - tuned_hardware=config.planner_tuned_hardware, - tuned_cp_sizes=tuned_cp_sizes, - backend_match=_planner_metadata_matches( - config.planner_tuned_backend, - _PLANNER_RUNTIME_BACKEND, - fuzzy=False, - ), - hardware_match=_planner_metadata_matches( - config.planner_tuned_hardware, - runtime_hardware, - fuzzy=True, - ), - cp_size_match=bool(tuned_cp_sizes) - and max(int(topology.cp), 1) in tuned_cp_sizes, - using_best_effort=False, - ) - if ( - provenance.backend_match - and provenance.hardware_match - and provenance.cp_size_match - ): - return provenance - - warning_message = _planner_best_effort_warning_message(provenance) - warning_key = ( - _normalized_planner_metadata_value(provenance.runtime_backend), - _normalized_planner_metadata_value(provenance.runtime_hardware), - int(provenance.runtime_cp_size), - _normalized_planner_metadata_value(provenance.tuned_backend), - _normalized_planner_metadata_value(provenance.tuned_hardware), - provenance.tuned_cp_sizes, - ) - warning_emitted = False - if warn and warning_key not in _PLANNER_BEST_EFFORT_WARNING_KEYS: - _PLANNER_BEST_EFFORT_WARNING_KEYS.add(warning_key) - warnings.warn(warning_message, RuntimeWarning, stacklevel=3) - warning_emitted = True - return provenance.model_copy( - update={ - "using_best_effort": True, - "warning_message": warning_message, - "warning_emitted": warning_emitted, - } - ) - - def _normalized_chunk_size( *, valid_tokens: int, @@ -351,7 +181,7 @@ def _search_config_for_chunk_count( return config if all(int(getattr(config, key)) == int(value) for key, value in updates.items()): return config - return config.model_copy(update=updates) + return replace(config, **updates) def _best_improving_move( @@ -387,9 +217,9 @@ def _best_improving_move( candidate = list(current_owners) candidate[chunk_index] = dst_rank candidate_owners = tuple(candidate) - if not _assignment_uses_all_ranks( - candidate_owners, - cp_size=cp_size, + if ( + len(candidate_owners) >= cp_size + and len(set(candidate_owners)) != cp_size ): continue candidate_eval = evaluate_candidate( @@ -410,12 +240,10 @@ def _build_chunk_ranges( valid_tokens: int, chunk_size: int, ) -> tuple[TokenRange, ...]: - ranges: list[TokenRange] = [] - for start in range(0, valid_tokens, chunk_size): - ranges.append( - TokenRange(start=start, end=min(start + chunk_size, valid_tokens)) - ) - return tuple(ranges) + return tuple( + TokenRange(start=start, end=min(start + chunk_size, valid_tokens)) + for start in range(0, valid_tokens, chunk_size) + ) def _indexed_intersections( @@ -445,33 +273,6 @@ def _indexed_intersections( return intersections -def _slice_pair_count( - *, - mask_kind: AttnMaskKind, - q_range: TokenRange, - k_range: TokenRange, -) -> int: - if mask_kind is AttnMaskKind.FULL: - return int(q_range.size()) * int(k_range.size()) - return _causal_piece_pair_count( - q_range=q_range, - k_range=k_range, - ) - - -def _causal_piece_pair_count( - *, - q_range: TokenRange, - k_range: TokenRange, -) -> int: - return _causal_piece_pair_count_from_bounds( - q_start=int(q_range.start), - q_end=int(q_range.end), - k_start=int(k_range.start), - k_end=int(k_range.end), - ) - - def _causal_piece_pair_count_from_bounds( *, q_start: int, @@ -498,91 +299,15 @@ def _causal_piece_pair_count_from_bounds( return int(partial + full) -def _chunk_piece_decomposition( - *, - start: int, - end: int, - chunk_size: int, -) -> tuple[ - int, tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...], int -]: - first = start // chunk_size - last = (end - 1) // chunk_size - piece_starts: list[int] = [] - piece_ends: list[int] = [] - piece_lengths: list[int] = [] - piece_prefix_lengths: list[int] = [] - running_len = 0 - for chunk_index in range(first, last + 1): - piece_start = start if chunk_index == first else chunk_index * chunk_size - piece_end = end if chunk_index == last else (chunk_index + 1) * chunk_size - piece_len = piece_end - piece_start - if piece_len <= 0: - continue - running_len += piece_len - piece_starts.append(piece_start) - piece_ends.append(piece_end) - piece_lengths.append(piece_len) - piece_prefix_lengths.append(running_len) - return ( - first, - tuple(piece_starts), - tuple(piece_ends), - tuple(piece_lengths), - tuple(piece_prefix_lengths), - running_len, - ) - - -def _can_use_shared_prefix_chunk_pair_program( - row_spec: PackedRowAttentionSpec, -) -> bool: - slices = row_spec.slices - index = 0 - while index < len(slices): - prompt_slice = slices[index] - if ( - prompt_slice.family_index is None - or prompt_slice.mask_kind is not AttnMaskKind.CAUSAL - or prompt_slice.q_range != prompt_slice.k_range - ): - return False - prompt_family_index = prompt_slice.family_index - if prompt_family_index is None: - raise RuntimeError("shared-prefix prompt slices must carry family_index") - family_index = int(prompt_family_index) - prompt_start = int(prompt_slice.q_range.start) - prompt_end = int(prompt_slice.q_range.end) - index += 1 - while index < len(slices): - family_value = slices[index].family_index - if family_value is None or int(family_value) != family_index: - break - if index + 1 >= len(slices): - return False - full_slice = slices[index] - causal_slice = slices[index + 1] - if ( - full_slice.family_index != prompt_slice.family_index - or causal_slice.family_index != prompt_slice.family_index - or full_slice.mask_kind is not AttnMaskKind.FULL - or causal_slice.mask_kind is not AttnMaskKind.CAUSAL - or full_slice.q_range != causal_slice.q_range - or causal_slice.q_range != causal_slice.k_range - or int(full_slice.k_range.start) != prompt_start - or int(full_slice.k_range.end) != prompt_end - ): - return False - index += 2 - return True - - -def _build_chunk_pair_program_generic( +def _build_chunk_pair_program( row_spec: PackedRowAttentionSpec, *, - chunk_count: int, - chunk_size: int, + chunk_ranges: tuple[TokenRange, ...], ) -> tuple[torch.Tensor, list[float]]: + chunk_count = len(chunk_ranges) + if chunk_count == 0: + return torch.zeros((0, 0), dtype=torch.int64), [] + chunk_size = int(chunk_ranges[0].size()) pair_rows = [[0 for _ in range(chunk_count)] for _ in range(chunk_count)] q_weights = [0.0 for _ in range(chunk_count)] @@ -681,138 +406,6 @@ def _build_chunk_pair_program_generic( return torch.tensor(pair_rows, dtype=torch.int64), q_weights -def _build_chunk_pair_program( - row_spec: PackedRowAttentionSpec, - *, - chunk_ranges: tuple[TokenRange, ...], -) -> tuple[torch.Tensor, list[float]]: - chunk_count = len(chunk_ranges) - if chunk_count == 0: - return torch.zeros((0, 0), dtype=torch.int64), [] - chunk_size = int(chunk_ranges[0].size()) - if not _can_use_shared_prefix_chunk_pair_program(row_spec): - return _build_chunk_pair_program_generic( - row_spec, - chunk_count=chunk_count, - chunk_size=chunk_size, - ) - - pair_rows = [[0 for _ in range(chunk_count)] for _ in range(chunk_count)] - q_weights = [0.0 for _ in range(chunk_count)] - slices = row_spec.slices - index = 0 - while index < len(slices): - prompt_slice = slices[index] - ( - prompt_first, - prompt_starts, - prompt_ends, - prompt_lengths, - prompt_prefix, - prompt_total, - ) = _chunk_piece_decomposition( - start=int(prompt_slice.q_range.start), - end=int(prompt_slice.q_range.end), - chunk_size=chunk_size, - ) - for offset, q_chunk_index in enumerate( - range(prompt_first, prompt_first + len(prompt_lengths)) - ): - q_piece_len = prompt_lengths[offset] - row = pair_rows[q_chunk_index] - q_total = 0 - if offset > 0: - for k_offset in range(offset): - row[prompt_first + k_offset] += ( - q_piece_len * prompt_lengths[k_offset] - ) - q_total += q_piece_len * prompt_prefix[offset - 1] - pair_count = _causal_piece_pair_count_from_bounds( - q_start=prompt_starts[offset], - q_end=prompt_ends[offset], - k_start=prompt_starts[offset], - k_end=prompt_ends[offset], - ) - if pair_count > 0: - row[q_chunk_index] += pair_count - q_total += pair_count - if q_total > 0: - q_weights[q_chunk_index] += float(q_total) - - prompt_family_index = prompt_slice.family_index - if prompt_family_index is None: - raise RuntimeError("shared-prefix prompt slices must carry family_index") - family_index = int(prompt_family_index) - index += 1 - completion_chunk_indices: list[int] = [] - completion_chunk_totals: list[int] = [] - while index < len(slices): - family_value = slices[index].family_index - if family_value is None or int(family_value) != family_index: - break - full_slice = slices[index] - ( - completion_first, - completion_starts, - completion_ends, - completion_lengths, - completion_prefix, - _, - ) = _chunk_piece_decomposition( - start=int(full_slice.q_range.start), - end=int(full_slice.q_range.end), - chunk_size=chunk_size, - ) - for offset, q_chunk_index in enumerate( - range(completion_first, completion_first + len(completion_lengths)) - ): - q_piece_len = completion_lengths[offset] - if ( - completion_chunk_indices - and completion_chunk_indices[-1] == q_chunk_index - ): - completion_chunk_totals[-1] += q_piece_len - else: - completion_chunk_indices.append(q_chunk_index) - completion_chunk_totals.append(q_piece_len) - - for offset, q_chunk_index in enumerate( - range(completion_first, completion_first + len(completion_lengths)) - ): - q_piece_len = completion_lengths[offset] - row = pair_rows[q_chunk_index] - q_total = 0 - if offset > 0: - for k_offset in range(offset): - row[completion_first + k_offset] += ( - q_piece_len * completion_lengths[k_offset] - ) - q_total += q_piece_len * completion_prefix[offset - 1] - pair_count = _causal_piece_pair_count_from_bounds( - q_start=completion_starts[offset], - q_end=completion_ends[offset], - k_start=completion_starts[offset], - k_end=completion_ends[offset], - ) - if pair_count > 0: - row[q_chunk_index] += pair_count - q_total += pair_count - if q_total > 0: - q_weights[q_chunk_index] += float(q_total) - index += 2 - - for q_chunk_index, total_q_len in zip( - completion_chunk_indices, - completion_chunk_totals, - strict=True, - ): - row = pair_rows[q_chunk_index] - for k_offset, k_piece_len in enumerate(prompt_lengths): - row[prompt_first + k_offset] += total_q_len * k_piece_len - q_weights[q_chunk_index] += float(total_q_len * prompt_total) - return torch.tensor(pair_rows, dtype=torch.int64), q_weights - - def _collect_rank_stage_pieces( row_spec: PackedRowAttentionSpec, *, @@ -962,63 +555,6 @@ def _contiguous_chunk_assignment( return tuple(owners) -def _bucket_chunk_assignment( - *, - q_weights: list[float], - cp_size: int, -) -> tuple[int, ...]: - chunk_count = len(q_weights) - if chunk_count == 0: - return tuple() - if cp_size <= 1: - return tuple(0 for _ in range(chunk_count)) - rank_loads = [0.0 for _ in range(cp_size)] - rank_chunk_counts = [0 for _ in range(cp_size)] - owners = [-1 for _ in range(chunk_count)] - for chunk_index in sorted( - range(chunk_count), - key=lambda index: (-q_weights[index], index), - ): - rank = min( - range(cp_size), - key=lambda candidate: ( - rank_loads[candidate], - rank_chunk_counts[candidate], - candidate, - ), - ) - owners[chunk_index] = rank - rank_loads[rank] += q_weights[chunk_index] - rank_chunk_counts[rank] += 1 - return tuple(int(owner) for owner in owners) - - -def _striped_chunk_assignment( - *, - chunk_count: int, - cp_size: int, - group_size: int, -) -> tuple[int, ...]: - if chunk_count == 0: - return tuple() - if cp_size <= 1: - return tuple(0 for _ in range(chunk_count)) - group_size = max(1, int(group_size)) - return tuple( - ((chunk_index // group_size) % cp_size) for chunk_index in range(chunk_count) - ) - - -def _assignment_uses_all_ranks( - owners: tuple[int, ...], - *, - cp_size: int, -) -> bool: - if len(owners) < cp_size: - return True - return len({int(owner) for owner in owners}) == cp_size - - def _candidate_chunk_indices( *, owners: tuple[int, ...], @@ -1127,31 +663,6 @@ def _chunk_mask_stats( return token_count, range_count -def _merge_chunk_ranges_from_mask( - *, - chunk_ranges: tuple[TokenRange, ...], - chunk_mask: torch.Tensor, -) -> tuple[TokenRange, ...]: - chunk_indices = torch.nonzero(chunk_mask, as_tuple=False).flatten() - if int(chunk_indices.numel()) == 0: - return tuple() - ordered_chunk_indices = chunk_indices.tolist() - first_range = chunk_ranges[int(ordered_chunk_indices[0])] - current_start = int(first_range.start) - current_end = int(first_range.end) - merged: list[TokenRange] = [] - for chunk_index in ordered_chunk_indices[1:]: - range_ = chunk_ranges[int(chunk_index)] - if int(range_.start) <= current_end: - current_end = max(current_end, int(range_.end)) - continue - merged.append(TokenRange(start=current_start, end=current_end)) - current_start = int(range_.start) - current_end = int(range_.end) - merged.append(TokenRange(start=current_start, end=current_end)) - return tuple(merged) - - def _stage_cost_ms( *, pair_count: int, @@ -1521,31 +1032,6 @@ def _evaluate_plan( } -def _evaluate_plan_for_search( - *, - chunk_ranges: tuple[TokenRange, ...], - pair_matrix: list[list[int]] | torch.Tensor, - owners: tuple[int, ...], - wave_assignment: tuple[int, ...], - cp_size: int, - config: ContextParallelConfig, - pair_positive: torch.Tensor | None = None, - chunk_lengths: tuple[int, ...] | None = None, - chunk_lengths_tensor: torch.Tensor | None = None, -) -> dict[str, Any]: - return _evaluate_plan( - chunk_ranges=chunk_ranges, - pair_matrix=pair_matrix, - owners=owners, - wave_assignment=wave_assignment, - cp_size=cp_size, - config=config, - pair_positive=pair_positive, - chunk_lengths=chunk_lengths, - chunk_lengths_tensor=chunk_lengths_tensor, - ) - - def _search_chunk_assignment( *, chunk_ranges: tuple[TokenRange, ...], @@ -1554,7 +1040,6 @@ def _search_chunk_assignment( cp_size: int, config: ContextParallelConfig, ) -> tuple[tuple[int, ...], tuple[int, ...], dict[str, Any]]: - cp_size = int(cp_size) config = _search_config_for_chunk_count( config=config, chunk_count=len(chunk_ranges), @@ -1563,9 +1048,7 @@ def _search_chunk_assignment( 1, min(int(config.planner_max_remote_waves), len(chunk_ranges)) + 1, ) - best_owners: tuple[int, ...] = tuple() - best_waves: tuple[int, ...] = tuple() - best_eval: dict[str, Any] | None = None + best: tuple[tuple[int, ...], tuple[int, ...], dict[str, Any]] | None = None eval_cache: dict[tuple[tuple[int, ...], tuple[int, ...]], dict[str, Any]] = {} pair_counts = torch.as_tensor(pair_matrix, dtype=torch.int64) pair_positive = pair_counts > 0 @@ -1585,7 +1068,7 @@ def _evaluate_candidate( cached = eval_cache.get(cache_key) if cached is not None: return cached - cached = _evaluate_plan_for_search( + cached = _evaluate_plan( chunk_ranges=chunk_ranges, pair_matrix=pair_counts, owners=owners, @@ -1599,86 +1082,35 @@ def _evaluate_candidate( eval_cache[cache_key] = cached return cached - def _best_wave_assignment_for_owners( - owners: tuple[int, ...], - ) -> tuple[tuple[int, ...], dict[str, Any]]: - best_wave_assignment = tuple() - best_eval_local: dict[str, Any] | None = None - for wave_count in wave_count_candidates: - wave_assignment = _wave_assignment( - chunk_count=len(chunk_ranges), - wave_count=wave_count, - ) - candidate_eval = _evaluate_candidate( - owners=owners, - wave_assignment=wave_assignment, - ) - if best_eval_local is None or float(candidate_eval["score"]) + 1e-9 < float( - best_eval_local["score"] - ): - best_wave_assignment = wave_assignment - best_eval_local = candidate_eval - if best_eval_local is None: - raise RuntimeError("Failed to evaluate any wave assignment candidate.") - return best_wave_assignment, best_eval_local - - strategy = str(config.planner_assignment_strategy).strip().lower() - striped_owners = _striped_chunk_assignment( - chunk_count=len(chunk_ranges), - cp_size=cp_size, - group_size=int(config.planner_stripe_group_size), - ) - fixed_owners_by_strategy = { - "contiguous": _contiguous_chunk_assignment( - q_weights=q_weights, cp_size=cp_size - ), - "bucket": _bucket_chunk_assignment(q_weights=q_weights, cp_size=cp_size), - "striped": striped_owners, - } - if strategy in fixed_owners_by_strategy: - owners = fixed_owners_by_strategy[strategy] - best_waves, best_eval = _best_wave_assignment_for_owners(owners) - return owners, best_waves, best_eval - if strategy not in {"search", "search_with_striped_seed"}: - raise ValueError( - "Unsupported planner_assignment_strategy=" - f"{config.planner_assignment_strategy!r}." - ) - contiguous_owners = _contiguous_chunk_assignment( q_weights=q_weights, cp_size=cp_size, ) + if not contiguous_owners: + wave_assignment = _wave_assignment(chunk_count=len(chunk_ranges), wave_count=1) + return ( + contiguous_owners, + wave_assignment, + _evaluate_candidate( + owners=contiguous_owners, + wave_assignment=wave_assignment, + ), + ) + for wave_count in wave_count_candidates: wave_assignment = _wave_assignment( chunk_count=len(chunk_ranges), wave_count=wave_count, ) - initial_candidates = [ - initial_owners - for initial_owners in (contiguous_owners,) - if initial_owners - if _assignment_uses_all_ranks(initial_owners, cp_size=cp_size) - ] - if not initial_candidates: - continue - current_owners = min( - initial_candidates, - key=lambda owners: float( - _evaluate_candidate(owners=owners, wave_assignment=wave_assignment)[ - "score" - ] - ), - ) + current_owners = contiguous_owners current_eval = _evaluate_candidate( owners=current_owners, wave_assignment=wave_assignment, ) - if cp_size >= 8: - search_steps_remaining = 0 - else: - search_steps_remaining = int(config.planner_max_search_steps) + search_steps_remaining = ( + 0 if cp_size >= 8 else int(config.planner_max_search_steps) + ) if cp_size == 4 and search_steps_remaining > 0: probe_move = _best_improving_move( current_owners=current_owners, @@ -1716,27 +1148,13 @@ def _best_wave_assignment_for_owners( break current_owners, current_eval = best_move - if best_eval is None or float(current_eval["score"]) + 1e-9 < float( - best_eval["score"] + if best is None or float(current_eval["score"]) + 1e-9 < float( + best[2]["score"] ): - best_owners = current_owners - best_waves = wave_assignment - best_eval = current_eval - - if best_eval is None: - best_owners = _contiguous_chunk_assignment(q_weights=q_weights, cp_size=cp_size) - best_waves = _wave_assignment(chunk_count=len(chunk_ranges), wave_count=1) - best_eval = _evaluate_candidate( - owners=best_owners, - wave_assignment=best_waves, - ) - return best_owners, best_waves, best_eval - - -def _concatenate_peer_ranges( - ranges_by_peer: list[tuple[TokenRange, ...]] | tuple[tuple[TokenRange, ...], ...], -) -> tuple[tuple[TokenRange, ...], ...]: - return tuple(tuple(ranges) for ranges in ranges_by_peer) + best = (current_owners, wave_assignment, current_eval) + if best is None: + raise RuntimeError("Failed to evaluate any CP planner wave assignment.") + return best def _flatten_ranges_by_peer( @@ -1956,16 +1374,8 @@ def _build_rank_runtime_plan( _remap_subrange(range_, host_local_ranges) for range_ in local_global_k_ranges ), - kv_fetch_plan=KvFetchPlan( - send_splits=tuple(0 for _ in range(cp_size)), - recv_splits=tuple(0 for _ in range(cp_size)), - send_ranges_by_peer=tuple(tuple() for _ in range(cp_size)), - ), - dkv_reduce_plan=DkvReducePlan( - send_splits=tuple(0 for _ in range(cp_size)), - recv_splits=tuple(0 for _ in range(cp_size)), - recv_ranges_by_peer=tuple(tuple() for _ in range(cp_size)), - ), + kv_fetch_plan=None, + dkv_reduce_plan=None, remote_buffer_range=None, block_size=block_size, ) @@ -2063,14 +1473,8 @@ def _build_rank_runtime_plan( token_layout_index=token_layout_index, local_valid_lengths=(local_token_count,), local_row_ranges=local_row_ranges, - local_token_count=local_token_count, stage_plans=tuple(stage_plans), backward_stage_indices=tuple(backward_stage_indices + [0]), - remote_kv_fetch_plan=KvFetchPlan( - send_splits=aggregate_send_splits, - recv_splits=tuple(aggregate_recv_splits), - send_ranges_by_peer=aggregate_send_ranges, - ), remote_dkv_reduce_plan=DkvReducePlan( send_splits=tuple(aggregate_recv_splits), recv_splits=aggregate_send_splits, @@ -2079,58 +1483,6 @@ def _build_rank_runtime_plan( ) -def make_runtime_key( - spec: PackedBatchAttentionSpec, - *, - topology: ParallelTopology, - config: ContextParallelConfig, -) -> ContextParallelRuntimeKey: - if len(spec.rows) != 1: - raise RuntimeError( - "ART context parallel runtime keys expect exactly one packed sequence, " - f"got {len(spec.rows)} rows." - ) - row_signatures = tuple(_row_signature(row) for row in spec.rows) - return ContextParallelRuntimeKey( - topology=topology, - config=config, - row_signatures=row_signatures, - ) - - -def build_context_parallel_token_layout_index( - *, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, - topology: ParallelTopology, - config: ContextParallelConfig, - original_seq_len: int, -) -> TokenLayoutIndex: - """Return the token ownership chosen by the real CP attention planner.""" - - spec = build_shared_prefix_attention_spec( - group_ids=group_ids, parent_ids=parent_ids - ) - if int(topology.cp) <= 1: - valid_tokens = int(spec.rows[0].valid_tokens) if spec.rows else 0 - return TokenLayoutIndex( - ownership_ranges_by_rank=(((0, valid_tokens, 0),) if valid_tokens else (),), - token_counts_by_rank=(valid_tokens,), - ) - runtime_config = _config_for_runtime_cp(topology=topology, config=config) - _row_spec, chunk_ranges, owners, _wave_assignment = _runtime_plan_assignment( - spec, - topology=topology, - config=runtime_config, - ) - del original_seq_len - return _build_runtime_token_layout_index( - chunk_ranges=chunk_ranges, - owners=owners, - cp_size=max(int(topology.cp), 1), - ) - - def prepare_cp_micro( *, micro: PackedTensors, @@ -2172,7 +1524,7 @@ def prepare_cp_micro( ref_logprobs=ref_logprobs, ) if tensors.token_uids is not None: - state = state.model_copy(update={"trace_token_uids": tensors.token_uids}) + state = replace(state, trace_token_uids=tensors.token_uids) if prepare_execution_state: from .executor import prepare_context_parallel_execution_state @@ -2222,12 +1574,11 @@ def prepare_megatron_context_parallel_state( ) group_ids_cpu = _planning_metadata_cpu(micro["group_ids"]) parent_ids_cpu = _planning_metadata_cpu(micro["parent_ids"]) - runtime_config = _config_for_runtime_cp(topology=topology, config=config) planning_key = _planning_bundle_cache_key( group_ids=group_ids_cpu, parent_ids=parent_ids_cpu, topology=topology, - config=runtime_config, + config=config, original_seq_len=int(micro["tokens"].shape[1]), build_gdn_execution_spec=build_gdn_execution_spec, ) @@ -2237,12 +1588,10 @@ def prepare_megatron_context_parallel_state( group_ids=group_ids_cpu, parent_ids=parent_ids_cpu, ) - runtime_key = make_runtime_key(spec, topology=topology, config=runtime_config) runtime_plan = get_or_build_runtime_plan( spec, topology=topology, - config=runtime_config, - runtime_key=runtime_key, + config=config, original_seq_len=int(micro["tokens"].shape[1]), ) gdn_execution_spec = None @@ -2252,18 +1601,15 @@ def prepare_megatron_context_parallel_state( ) gdn_execution_spec = parse_gdn_shared_prefix_segments( - group_ids_cpu, - parent_ids_cpu, - min_completions_per_family=0, + group_ids_cpu, parent_ids_cpu ) bundle = _PlanningBundle( spec=spec, - runtime_key=runtime_key, - runtime_plan=runtime_plan, + rank_plans=runtime_plan, gdn_execution_spec=gdn_execution_spec, ) _cache_put(_PLANNING_BUNDLE_CACHE, planning_key, bundle) - rank_plan = bundle.runtime_plan.rank_plans[int(cp_rank)] + rank_plan = bundle.rank_plans[int(cp_rank)] gdn_execution_plan = None if build_gdn_execution_spec: if bundle.gdn_execution_spec is None: @@ -2271,10 +1617,11 @@ def prepare_megatron_context_parallel_state( gdn_plan_device = ( target_device if target_device is not None else micro["tokens"].device ) - rank_gdn_key = _rank_plan_cache_key( - planning_key=planning_key, - device=gdn_plan_device, - cp_rank=int(cp_rank), + rank_gdn_key = ( + planning_key, + gdn_plan_device.type, + gdn_plan_device.index, + int(cp_rank), ) gdn_execution_plan = _GDN_RANK_PLAN_CACHE.get(rank_gdn_key) if gdn_execution_plan is None: @@ -2290,22 +1637,15 @@ def prepare_megatron_context_parallel_state( attention_token_layout_index=rank_plan.token_layout_index, ) _cache_put(_GDN_RANK_PLAN_CACHE, rank_gdn_key, gdn_execution_plan) - planner_provenance = _planner_provenance( - topology=topology, - config=runtime_config, - warn=int(cp_rank) == 0, - ) pad_multiple = int(topology.tp) if bool(topology.sp) and int(topology.tp) > 1 else 1 state = ArtContextParallelState( - runtime_key=bundle.runtime_key, rank_plan=rank_plan, cp_group=cp_group, - config=runtime_config, + config=config, group_ids=group_ids_cpu[0].contiguous(), parent_ids=parent_ids_cpu[0].contiguous(), gdn_execution_spec=bundle.gdn_execution_spec, gdn_execution_plan=gdn_execution_plan, - planner_provenance=planner_provenance, trace_token_uids=None, ) return state, rank_plan, bundle.spec, pad_multiple @@ -2353,111 +1693,43 @@ def dispatch_megatron_context_parallel_training_tensors( if trace_token_uids else None ) - local_tokens = _dispatch_tensor( - micro["tokens"], - rank_plan=rank_plan, - pad_value=0, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) - local_labels = _dispatch_tensor( - labels, - rank_plan=rank_plan, - pad_value=-100, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) - local_input_pos = _dispatch_tensor( - micro["input_pos"], - rank_plan=rank_plan, - pad_value=0, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) - local_assistant_mask = _dispatch_tensor( - assistant_mask, - rank_plan=rank_plan, - pad_value=False, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ).to(dtype=torch.bool) - local_group_ids = _dispatch_tensor( - shifted_group_ids, - rank_plan=rank_plan, - pad_value=0, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) - local_old_logprobs = _dispatch_tensor( - old_logprobs, - rank_plan=rank_plan, - pad_value=float("nan"), - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) - local_original_logprobs = ( - None - if original_logprobs is None - else _dispatch_tensor( - original_logprobs, - rank_plan=rank_plan, - pad_value=0.0, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) - ) - local_ref_logprobs = ( - None - if ref_logprobs is None - else _dispatch_tensor( - ref_logprobs, + + def dispatch( + tensor: torch.Tensor, + pad_value: int | float | bool, + *, + move_to_target: bool = True, + ) -> torch.Tensor: + local = _dispatch_tensor( + tensor, rank_plan=rank_plan, - pad_value=float("nan"), + pad_value=pad_value, pad_multiple=pad_multiple, dispatch_meta_cache=dispatch_meta_cache, ) - ) - local_advantages = _dispatch_tensor( - advantages, - rank_plan=rank_plan, - pad_value=0.0, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) - local_weights = _dispatch_tensor( - weights, - rank_plan=rank_plan, - pad_value=0.0, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) + return _to_target_device(local, target_device) if move_to_target else local + + def maybe_dispatch( + tensor: torch.Tensor | None, + pad_value: int | float | bool, + ) -> torch.Tensor | None: + return None if tensor is None else dispatch(tensor, pad_value) + local_token_uids = ( - None - if token_uids is None - else _dispatch_tensor( - token_uids, - rank_plan=rank_plan, - pad_value=-1, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) + None if token_uids is None else dispatch(token_uids, -1, move_to_target=False) ) return DispatchedPackedTensors( - tokens=_to_target_device(local_tokens, target_device), - labels=_to_target_device(local_labels, target_device), - input_pos=_to_target_device(local_input_pos, target_device), - assistant_mask=_to_target_device(local_assistant_mask, target_device), - group_ids=_to_target_device(local_group_ids, target_device), - old_logprobs=_to_target_device(local_old_logprobs, target_device), - advantages=_to_target_device(local_advantages, target_device), - weights=_to_target_device(local_weights, target_device), + tokens=dispatch(micro["tokens"], 0), + labels=dispatch(labels, -100), + input_pos=dispatch(micro["input_pos"], 0), + assistant_mask=dispatch(assistant_mask, False).to(dtype=torch.bool), + group_ids=dispatch(shifted_group_ids, 0), + old_logprobs=dispatch(old_logprobs, float("nan")), + advantages=dispatch(advantages, 0.0), + weights=dispatch(weights, 0.0), valid_lengths=rank_plan.local_valid_lengths, - original_logprobs=None - if local_original_logprobs is None - else _to_target_device(local_original_logprobs, target_device), - ref_logprobs=None - if local_ref_logprobs is None - else _to_target_device(local_ref_logprobs, target_device), + original_logprobs=maybe_dispatch(original_logprobs, 0.0), + ref_logprobs=maybe_dispatch(ref_logprobs, float("nan")), loss_all_reduce_group=cp_group, token_uids=None if local_token_uids is None else local_token_uids.contiguous(), ) @@ -2468,12 +1740,13 @@ def get_or_build_runtime_plan( *, topology: ParallelTopology, config: ContextParallelConfig, - runtime_key: ContextParallelRuntimeKey, original_seq_len: int, -) -> ContextParallelRuntimePlan: - key = ( - _json_cache_key(runtime_key.model_dump(mode="json")), - int(original_seq_len), +) -> tuple[RankRuntimePlan, ...]: + key = _runtime_plan_cache_key( + spec, + topology=topology, + config=config, + original_seq_len=original_seq_len, ) cached = _RUNTIME_PLAN_CACHE.get(key) if cached is not None: @@ -2488,25 +1761,6 @@ def get_or_build_runtime_plan( return plan -def get_or_build_rank_runtime_plan( - spec: PackedBatchAttentionSpec, - *, - topology: ParallelTopology, - config: ContextParallelConfig, - runtime_key: ContextParallelRuntimeKey, - original_seq_len: int, - target_rank: int, -) -> RankRuntimePlan: - del runtime_key - return _build_rank_runtime_plan_for_spec( - spec, - topology=topology, - config=config, - original_seq_len=original_seq_len, - target_rank=target_rank, - ) - - def _runtime_plan_assignment( spec: PackedBatchAttentionSpec, *, @@ -2552,45 +1806,13 @@ def _runtime_plan_assignment( return row_spec, chunk_ranges, owners, wave_assignment -def _build_rank_runtime_plan_for_spec( - spec: PackedBatchAttentionSpec, - *, - topology: ParallelTopology, - config: ContextParallelConfig, - original_seq_len: int, - target_rank: int, -) -> RankRuntimePlan: - row_spec, chunk_ranges, owners, wave_assignment = _runtime_plan_assignment( - spec, - topology=topology, - config=config, - ) - cp_size = max(int(topology.cp), 1) - token_layout_index = _build_runtime_token_layout_index( - chunk_ranges=chunk_ranges, - owners=owners, - cp_size=cp_size, - ) - return _build_rank_runtime_plan( - row_spec=row_spec, - chunk_ranges=chunk_ranges, - owners=owners, - wave_assignment=wave_assignment, - token_layout_index=token_layout_index, - cp_size=cp_size, - original_seq_len=original_seq_len, - target_rank=int(target_rank), - block_size=int(config.block_size), - ) - - def _build_runtime_plan( spec: PackedBatchAttentionSpec, *, topology: ParallelTopology, config: ContextParallelConfig, original_seq_len: int, -) -> ContextParallelRuntimePlan: +) -> tuple[RankRuntimePlan, ...]: row_spec, chunk_ranges, owners, wave_assignment = _runtime_plan_assignment( spec, topology=topology, @@ -2602,7 +1824,7 @@ def _build_runtime_plan( owners=owners, cp_size=cp_size, ) - rank_plans = [ + return tuple( _build_rank_runtime_plan( row_spec=row_spec, chunk_ranges=chunk_ranges, @@ -2615,12 +1837,6 @@ def _build_runtime_plan( block_size=int(config.block_size), ) for rank in range(cp_size) - ] - return ContextParallelRuntimePlan( - topology=topology, - config=config, - token_layout_index=token_layout_index, - rank_plans=tuple(rank_plans), ) @@ -2648,11 +1864,42 @@ def _build_runtime_token_layout_index( def _row_signature(row_spec: PackedRowAttentionSpec) -> str: payload = { "valid_tokens": row_spec.valid_tokens, - "slices": [slice_.model_dump(mode="json") for slice_ in row_spec.slices], + "slices": [_attn_slice_payload(slice_) for slice_ in row_spec.slices], } return json.dumps(payload, sort_keys=True) +def _runtime_plan_cache_key( + spec: PackedBatchAttentionSpec, + *, + topology: ParallelTopology, + config: ContextParallelConfig, + original_seq_len: int, +) -> str: + return _json_cache_key( + { + "topology": _dataclass_payload(topology), + "config": _dataclass_payload(config), + "row_signatures": tuple(_row_signature(row) for row in spec.rows), + "original_seq_len": int(original_seq_len), + } + ) + + +def _dataclass_payload(value: Any) -> dict[str, Any]: + return dict(value.__dict__) + + +def _attn_slice_payload(slice_: AttnSlice) -> dict[str, Any]: + return { + "q_range": _dataclass_payload(slice_.q_range), + "k_range": _dataclass_payload(slice_.k_range), + "mask_kind": slice_.mask_kind.value, + "row_index": slice_.row_index, + "family_index": slice_.family_index, + } + + def _range_key(range_: TokenRange) -> tuple[int, int]: return (int(range_.start), int(range_.end)) @@ -2694,101 +1941,6 @@ def _set_stage_token_indices( current_indices.copy_(source_indices) -def _token_costs(row_spec: PackedRowAttentionSpec) -> list[float]: - costs = [0.0] * row_spec.valid_tokens - for slice_ in row_spec.slices: - q_range = slice_.q_range - k_range = slice_.k_range - if slice_.mask_kind is AttnMaskKind.FULL: - cost = float(k_range.size()) - for q_idx in range(q_range.start, q_range.end): - costs[q_idx] += cost - continue - if q_range.size() != k_range.size(): - raise RuntimeError( - "The current planner only supports causal slices with matched q/k sizes, got " - f"{q_range} vs {k_range}" - ) - for q_idx in range(q_range.start, q_range.end): - costs[q_idx] += float(q_idx - q_range.start + 1) - return costs - - -def _split_row_by_cost( - row_spec: PackedRowAttentionSpec, - *, - cp_size: int, - block_size: int, -) -> tuple[TokenRange | None, ...]: - if cp_size == 1: - return (TokenRange(start=0, end=row_spec.valid_tokens),) - if row_spec.valid_tokens == 0: - return tuple(None for _ in range(cp_size)) - - costs = _token_costs(row_spec) - prefix = [0.0] - for cost in costs: - prefix.append(prefix[-1] + cost) - total_cost = prefix[-1] - boundaries = [0] - block_aligned_split = int(block_size) > 1 and row_spec.valid_tokens >= ( - cp_size * int(block_size) - ) - for split_index in range(1, cp_size): - remaining_ranks = cp_size - split_index - min_boundary = boundaries[-1] - max_boundary = row_spec.valid_tokens - remaining_ranks - if max_boundary <= min_boundary: - boundaries.append(min_boundary) - continue - target = ( - total_cost * split_index / cp_size - if total_cost > 0.0 - else row_spec.valid_tokens * split_index / cp_size - ) - best_boundary = min_boundary + 1 - best_error = float("inf") - candidate_boundaries = range(min_boundary + 1, max_boundary + 1) - if block_aligned_split: - aligned_start = ( - (min_boundary + 1 + block_size - 1) // block_size - ) * block_size - aligned_end = (max_boundary // block_size) * block_size - if aligned_start <= aligned_end: - candidate_boundaries = range(aligned_start, aligned_end + 1, block_size) - for boundary in candidate_boundaries: - current = prefix[boundary] if total_cost > 0.0 else float(boundary) - error = abs(current - target) - if error < best_error: - best_error = error - best_boundary = boundary - boundaries.append(best_boundary) - boundaries.append(row_spec.valid_tokens) - - ranges: list[TokenRange | None] = [] - for start, end in zip(boundaries[:-1], boundaries[1:]): - if end <= start: - ranges.append(None) - else: - ranges.append(TokenRange(start=start, end=end)) - return tuple(ranges) - - -def _intersections( - base_range: TokenRange, - owner_ranges: tuple[TokenRange | None, ...], -) -> list[tuple[int, TokenRange]]: - intersections: list[tuple[int, TokenRange]] = [] - for rank, owner_range in enumerate(owner_ranges): - if owner_range is None: - continue - start = max(base_range.start, owner_range.start) - end = min(base_range.end, owner_range.end) - if end > start: - intersections.append((rank, TokenRange(start=start, end=end))) - return intersections - - def _resolve_stage_mask_kind( *, mask_kind: AttnMaskKind, diff --git a/src/art/megatron/context_parallel/types.py b/src/art/megatron/context_parallel/types.py index 5cc874d09..bf52ffddc 100644 --- a/src/art/megatron/context_parallel/types.py +++ b/src/art/megatron/context_parallel/types.py @@ -1,10 +1,11 @@ from __future__ import annotations +from dataclasses import dataclass, field from enum import Enum from typing import Any from megatron.core.packed_seq_params import PackedSeqParams -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict import torch from .layout_index import TokenLayoutIndex @@ -16,22 +17,17 @@ class AttnMaskKind(str, Enum): CAUSAL = "causal" -class TokenRange(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class TokenRange: start: int end: int def size(self) -> int: return self.end - self.start - def is_empty(self) -> bool: - return self.end <= self.start - - -class AttnSlice(BaseModel): - model_config = ConfigDict(frozen=True) +@dataclass(frozen=True) +class AttnSlice: q_range: TokenRange k_range: TokenRange mask_kind: AttnMaskKind @@ -39,68 +35,25 @@ class AttnSlice(BaseModel): family_index: int | None = None -class PackedRowAttentionSpec(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class PackedRowAttentionSpec: row_index: int valid_tokens: int slices: tuple[AttnSlice, ...] -class PackedBatchAttentionSpec(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class PackedBatchAttentionSpec: rows: tuple[PackedRowAttentionSpec, ...] -class SharedPrefixBuilderConfig(BaseModel): - model_config = ConfigDict(frozen=True) - - ignore_padding_group_id: int = -1 - require_contiguous_group_runs: bool = True - - -class PlannerCpOverride(BaseModel): - model_config = ConfigDict(frozen=True) - - cp_size: int - block_size: int | None = None - planner_chunk_size: int | None = None - planner_chunk_budget_base: int | None = None - planner_chunk_budget_per_cp_rank: int | None = None - planner_assignment_strategy: str | None = None - planner_stripe_group_size: int | None = None - planner_max_search_steps: int | None = None - planner_candidate_chunk_limit: int | None = None - planner_max_remote_waves: int | None = None - planner_stage_overhead_ms: float | None = None - planner_comm_stage_overhead_ms: float | None = None - planner_interval_overhead_ms: float | None = None - planner_merge_q_token_ms: float | None = None - planner_fetch_token_ms: float | None = None - planner_reduce_token_ms: float | None = None - planner_local_pair_ms: float | None = None - planner_remote_pair_ms: float | None = None - planner_local_backward_pair_ms: float | None = None - planner_remote_backward_pair_ms: float | None = None - planner_remote_stage_token_floor: int | None = None - planner_remote_stage_pair_floor: int | None = None - planner_remote_stage_underfill_ms: float | None = None - planner_tuned_backend: str | None = None - planner_tuned_hardware: str | None = None - planner_tuned_cp_sizes: tuple[int, ...] | None = None - - -class ContextParallelConfig(BaseModel): - model_config = ConfigDict(frozen=True, extra="forbid") - +@dataclass(frozen=True) +class ContextParallelConfig: block_size: int = 128 attention_sparse_block_size: tuple[int, int] | None = None planner_chunk_size: int = 512 planner_chunk_budget_base: int = 128 planner_chunk_budget_per_cp_rank: int = 16 - planner_assignment_strategy: str = "search" - planner_stripe_group_size: int = 16 planner_max_search_steps: int = 8 planner_candidate_chunk_limit: int = 8 planner_max_remote_waves: int = 4 @@ -117,15 +70,10 @@ class ContextParallelConfig(BaseModel): planner_remote_stage_token_floor: int = 4096 planner_remote_stage_pair_floor: int = 4_000_000 planner_remote_stage_underfill_ms: float = 0.287151 - planner_tuned_backend: str | None = "art_context_parallel" - planner_tuned_hardware: str | None = "NVIDIA H200" - planner_tuned_cp_sizes: tuple[int, ...] = (2,) - planner_cp_overrides: tuple[PlannerCpOverride, ...] = () -class ParallelTopology(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class ParallelTopology: tp: int = 1 cp: int = 1 dp: int = 1 @@ -133,73 +81,50 @@ class ParallelTopology(BaseModel): sp: bool = False -class ContextParallelRuntimeKey(BaseModel): - model_config = ConfigDict(frozen=True) - - topology: ParallelTopology - config: ContextParallelConfig - row_signatures: tuple[str, ...] - - -class KvFetchPlan(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class KvFetchPlan: send_splits: tuple[int, ...] recv_splits: tuple[int, ...] send_ranges_by_peer: tuple[tuple[TokenRange, ...], ...] -class DkvReducePlan(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class DkvReducePlan: send_splits: tuple[int, ...] recv_splits: tuple[int, ...] recv_ranges_by_peer: tuple[tuple[TokenRange, ...], ...] -class StagePlan(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class StagePlan: stage_index: int source_rank: int - source_ranks: tuple[int, ...] = () is_local_stage: bool - wave_index: int | None = None slices: tuple[AttnSlice, ...] - global_q_ranges: tuple[TokenRange, ...] = () - global_k_ranges: tuple[TokenRange, ...] = () owner_local_q_ranges: tuple[TokenRange, ...] owner_local_k_ranges: tuple[TokenRange, ...] - mask_metadata: "ExactMaskMetadata | None" = None - remote_buffer_range: TokenRange | None = None q_len: int k_len: int + source_ranks: tuple[int, ...] = () + wave_index: int | None = None + global_q_ranges: tuple[TokenRange, ...] = () + global_k_ranges: tuple[TokenRange, ...] = () + mask_metadata: "ExactMaskMetadata | None" = None + remote_buffer_range: TokenRange | None = None kv_fetch_plan: KvFetchPlan | None = None dkv_reduce_plan: DkvReducePlan | None = None -class RankRuntimePlan(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class RankRuntimePlan: rank: int original_seq_len: int token_layout_index: TokenLayoutIndex local_valid_lengths: tuple[int, ...] local_row_ranges: tuple[TokenRange | None, ...] - local_token_count: int stage_plans: tuple[StagePlan, ...] - backward_stage_indices: tuple[int, ...] = () - remote_kv_fetch_plan: KvFetchPlan remote_dkv_reduce_plan: DkvReducePlan - - -class ContextParallelRuntimePlan(BaseModel): - model_config = ConfigDict(frozen=True) - - topology: ParallelTopology - config: ContextParallelConfig - token_layout_index: TokenLayoutIndex - rank_plans: tuple[RankRuntimePlan, ...] + backward_stage_indices: tuple[int, ...] = () class DispatchedPackedTensors(ContextParallelLossInputs): @@ -220,47 +145,27 @@ class DispatchedPackedTensors(ContextParallelLossInputs): token_uids: torch.Tensor | None = None -class ContextParallelExecutionCache(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - block_masks: dict[Any, Any] = Field(default_factory=dict) - range_indices: dict[Any, torch.Tensor] = Field(default_factory=dict) - range_meta: dict[Any, tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]] = Field( +@dataclass +class ContextParallelExecutionCache: + block_mask_context: Any | None = None + block_masks: dict[Any, Any] = field(default_factory=dict) + range_indices: dict[Any, torch.Tensor] = field(default_factory=dict) + range_meta: dict[Any, tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]] = field( default_factory=dict ) - stage_execution_specs: dict[Any, "StageExecutionSpec"] = Field(default_factory=dict) + stage_execution_specs: dict[Any, "StageExecutionSpec"] = field(default_factory=dict) -class StageExecutionSpec(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class StageExecutionSpec: q_len: int k_len: int compile_key: str mask_metadata: "ExactMaskMetadata | None" = None -class PlannerProvenance(BaseModel): - model_config = ConfigDict(frozen=True) - - runtime_backend: str - runtime_hardware: str | None = None - runtime_cp_size: int - tuned_backend: str | None = None - tuned_hardware: str | None = None - tuned_cp_sizes: tuple[int, ...] = () - backend_match: bool - hardware_match: bool - cp_size_match: bool - using_best_effort: bool - warning_message: str | None = None - warning_emitted: bool = False - - -class ArtContextParallelState(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - runtime_key: ContextParallelRuntimeKey +@dataclass +class ArtContextParallelState: rank_plan: RankRuntimePlan cp_group: Any config: ContextParallelConfig @@ -272,31 +177,28 @@ class ArtContextParallelState(BaseModel): gdn_input_layout: str | None = None gdn_output_layout: str | None = None gdn_attention_original_shape: tuple[int, int, int] | None = None - gdn_attention_original_shapes: dict[int, tuple[int, int, int]] = Field( + gdn_attention_original_shapes: dict[int, tuple[int, int, int]] = field( default_factory=dict ) gdn_attention_token_uids: torch.Tensor | None = None gdn_active_module: Any | None = None - planner_provenance: PlannerProvenance trace_token_uids: torch.Tensor | None = None - execution_cache: ContextParallelExecutionCache = Field( + execution_cache: ContextParallelExecutionCache = field( default_factory=ContextParallelExecutionCache ) -class PreparedMegatronBatch(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - +@dataclass +class PreparedMegatronBatch: tensors: DispatchedPackedTensors - packed_seq_params: PackedSeqParams | None = None attention_state: Any + packed_seq_params: PackedSeqParams | None = None rank_plan: RankRuntimePlan | None = None pad_multiple: int = 1 -class FlexMaskSpec(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class FlexMaskSpec: q_len: int k_len: int block_size: int | tuple[int, int] @@ -304,9 +206,8 @@ class FlexMaskSpec(BaseModel): exact_mask: "ExactMaskMetadata" -class ExactMaskMetadata(BaseModel): - model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) - +@dataclass(frozen=True) +class ExactMaskMetadata: q_token_indices: torch.Tensor k_token_indices: torch.Tensor cache_key: str diff --git a/src/art/megatron/gdn/__init__.py b/src/art/megatron/gdn/__init__.py index cd3a0873a..1dc629403 100644 --- a/src/art/megatron/gdn/__init__.py +++ b/src/art/megatron/gdn/__init__.py @@ -3,12 +3,10 @@ from .fla_cp import chunk_gated_delta_rule_native_cp from .gdn_shared_prefix import ( GdnPackedExecutionSpec, - GdnPackedFamilySpec, GdnPlannerConfig, GdnRankExecutionPlan, GdnSegmentBucketPlan, GdnSegmentSpec, - build_gdn_cp_segment_schedule, build_gdn_rank_execution_plan, move_gdn_rank_execution_plan_to_device, parse_gdn_shared_prefix_segments, @@ -19,12 +17,10 @@ __all__ = [ "chunk_gated_delta_rule_native_cp", "GdnPackedExecutionSpec", - "GdnPackedFamilySpec", "GdnPlannerConfig", "GdnRankExecutionPlan", "GdnSegmentSpec", "GdnSegmentBucketPlan", - "build_gdn_cp_segment_schedule", "build_gdn_rank_execution_plan", "exchange_rank_tensor_all_to_all", "move_gdn_rank_execution_plan_to_device", diff --git a/src/art/megatron/gdn/gdn_shared_prefix.py b/src/art/megatron/gdn/gdn_shared_prefix.py index 3fb693891..f4bc02ba7 100644 --- a/src/art/megatron/gdn/gdn_shared_prefix.py +++ b/src/art/megatron/gdn/gdn_shared_prefix.py @@ -1,152 +1,76 @@ from __future__ import annotations from bisect import bisect_left -from typing import Any, Literal, TypeVar +from dataclasses import dataclass, replace +from typing import Any, Literal -from pydantic import BaseModel, ConfigDict, Field import torch from art.megatron.context_parallel.layout_index import TokenLayoutIndex +from art.megatron.shared_prefix_tree import parse_shared_prefix_tree GdnSegmentKind = Literal["prefix", "completion"] -GdnSegmentDecisionKey = tuple[int, int, int] # FLA's public chunk_gated_delta_rule hard-codes 64-token WY chunks. FLA_CHUNK_SIZE = 64 -_PydanticModelT = TypeVar("_PydanticModelT", bound=BaseModel) -class GdnSegmentSpec(BaseModel): +@dataclass(frozen=True) +class GdnSegmentSpec: """Contiguous logical GDN segment in one packed row.""" - model_config = ConfigDict(frozen=True) - - row_index: int = Field(ge=0) - family_index: int = Field(ge=0) + row_index: int + family_index: int group_id: int parent_id: int - start: int = Field(ge=0) - end: int = Field(ge=1) + start: int + end: int kind: GdnSegmentKind - child_index: int | None = Field(default=None, ge=0) + child_index: int | None = None @property def length(self) -> int: return self.end - self.start - def linear_indices(self, sequence_length: int) -> tuple[int, ...]: - base = self.row_index * sequence_length - return tuple(range(base + self.start, base + self.end)) - - -class GdnPackedFamilySpec(BaseModel): - """One shared-prefix family plus child completion segments.""" - - model_config = ConfigDict(frozen=True) - - row_index: int = Field(ge=0) - family_index: int = Field(ge=0) - prefix: GdnSegmentSpec - completions: tuple[GdnSegmentSpec, ...] - - @property - def completion_count(self) -> int: - return len(self.completions) - - @property - def token_count(self) -> int: - return self.prefix.length + sum(segment.length for segment in self.completions) - -class GdnPackedExecutionSpec(BaseModel): +@dataclass(frozen=True) +class GdnPackedExecutionSpec: """Parsed shared-prefix GDN execution metadata for a packed batch.""" - model_config = ConfigDict(frozen=True) - - batch_size: int = Field(ge=1) - sequence_length: int = Field(ge=1) + batch_size: int + sequence_length: int valid_lengths: tuple[int, ...] - families: tuple[GdnPackedFamilySpec, ...] + tree_segments: tuple[GdnSegmentSpec, ...] + tree_parent_indices: tuple[int, ...] + tree_depths: tuple[int, ...] @property def family_count(self) -> int: - return len(self.families) - - @property - def completion_count(self) -> int: - return sum(family.completion_count for family in self.families) + return len(self.tree_segments) @property def real_token_count(self) -> int: return sum(self.valid_lengths) - @property - def max_segment_length(self) -> int: - lengths = [ - segment.length - for family in self.families - for segment in (family.prefix, *family.completions) - ] - return max(lengths, default=0) - - def segments(self) -> tuple[GdnSegmentSpec, ...]: - return tuple( - segment - for family in self.families - for segment in (family.prefix, *family.completions) - ) - - -_GDN_SEGMENT_SPEC_FIELDS = frozenset( - { - "row_index", - "family_index", - "group_id", - "parent_id", - "start", - "end", - "kind", - "child_index", - } -) -_GDN_PACKED_FAMILY_SPEC_FIELDS = frozenset( - { - "row_index", - "family_index", - "prefix", - "completions", - } -) - - -def _trusted_pydantic_construct( - model_type: type[_PydanticModelT], - fields_set: frozenset[str], - **values: Any, -) -> _PydanticModelT: - model = model_type.__new__(model_type) - object.__setattr__(model, "__dict__", values) - object.__setattr__(model, "__pydantic_fields_set__", fields_set) - object.__setattr__(model, "__pydantic_extra__", None) - object.__setattr__(model, "__pydantic_private__", None) - return model - -class GdnSegmentBucketPlan(BaseModel): +@dataclass(frozen=True) +class GdnSegmentBucketPlan: """Device-local index tensors for a variable-length GDN segment batch.""" - model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) - - length: int = Field(ge=1) + length: int lengths: torch.Tensor lengths_cpu: torch.Tensor - lengths_by_rank_cpu: torch.Tensor | None = None real_mask: torch.Tensor cu_seqlens: torch.Tensor cu_seqlens_cpu: torch.Tensor row_indices: torch.Tensor position_indices: torch.Tensor family_indices: torch.Tensor - real_token_count_static: int = Field(ge=0) + real_token_count_static: int + lengths_by_rank_cpu: torch.Tensor | None = None + family_indices_cpu: torch.Tensor | None = None + parent_indices: torch.Tensor | None = None + parent_indices_cpu: torch.Tensor | None = None + needs_final_state: bool = True output_mask: torch.Tensor | None = None @property @@ -158,89 +82,54 @@ def real_token_count(self) -> int: return self.real_token_count_static -class GdnParentStateTransferPlan(BaseModel): - """Prefix-state rows transferred from one CP rank to another.""" - - model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) +@dataclass(frozen=True) +class GdnStateExchangePlan: + """Sparse CP exchange for tree parent states needed by remote children.""" - source_rank: int = Field(ge=0) - dest_rank: int = Field(ge=0) - family_indices: tuple[int, ...] - family_indices_tensor: torch.Tensor | None = None + source_family_indices: tuple[int, ...] + dest_family_indices: tuple[int, ...] + exchange: Any + reverse_exchange: Any -class GdnPlannerConfig(BaseModel): +@dataclass(frozen=True) +class GdnPlannerConfig: """Tunable cost coefficients for one packed-row GDN execution plan.""" - model_config = ConfigDict(frozen=True) - - max_padding_ratio: float = Field(default=2.0, gt=1.0) - max_segments_per_batch: int = Field(default=4096, ge=1) - cp_chain_min_tokens_per_rank: int = Field(default=32, ge=1) - cp_chain_min_total_tokens: int = Field(default=32768, ge=1) - cp_chain_min_prefix_only_tokens: int = Field(default=32768, ge=1) - local_fork_launch_penalty_tokens: int = Field(default=256, ge=0) - cp_collective_latency_tokens: int = Field(default=512, ge=0) - parent_state_exchange_penalty_tokens: int = Field(default=16384, ge=0) - layout_cross_rank_token_cost: float = Field(default=6.0, ge=0.0) - rank_idle_token_cost: float = Field(default=1.0, ge=0.0) - empty_rank_penalty_tokens: int = Field(default=65536, ge=0) - max_zero_exchange_load_imbalance: float = Field(default=1.5, ge=1.0) - local_completion_rebalance_min_imbalance: float = Field(default=1.08, ge=1.0) - cp_chain_beam_width: int = Field(default=2, ge=1) - cp_chain_beam_branch_factor: int = Field(default=4, ge=1) - cp_chain_beam_candidate_limit: int = Field(default=16, ge=1) - cp_chain_beam_max_steps: int = Field(default=4, ge=0) - cp_chain_beam_min_score_delta_tokens: float = Field(default=512.0, ge=0.0) - cp_chain_min_score_delta_ms: float = Field(default=0.25, ge=0.0) - planner_local_token_ms: float = Field(default=0.00065, ge=0.0) - planner_chain_token_ms: float = Field(default=0.00055, ge=0.0) - planner_local_bucket_ms: float = Field(default=0.25, ge=0.0) - planner_chain_bucket_ms: float = Field(default=22.0, ge=0.0) - planner_local_segment_ms: float = Field(default=0.010, ge=0.0) - planner_layout_cross_rank_token_ms: float = Field(default=0.00008, ge=0.0) - planner_parent_state_exchange_base_ms: float = Field(default=40.0, ge=0.0) - planner_parent_state_exchange_ms: float = Field(default=0.5, ge=0.0) - planner_empty_rank_ms: float = Field(default=32.0, ge=0.0) - - -class GdnRankExecutionPlan(BaseModel): + max_padding_ratio: float = 2.0 + max_segments_per_batch: int = 4096 + cp_chain_min_tokens_per_rank: int = 32 + cp_chain_min_total_tokens: int = 32768 + cp_chain_min_prefix_only_tokens: int = 32768 + cp_tree_chain_min_total_tokens: int = 8192 + cp_tree_chain_min_prefix_only_tokens: int = 8192 + rank_idle_token_cost: float = 1.0 + max_zero_exchange_load_imbalance: float = 1.5 + planner_local_token_ms: float = 0.00065 + planner_layout_cross_rank_token_ms: float = 0.00008 + planner_empty_rank_ms: float = 32.0 + + +@dataclass(frozen=True) +class GdnRankExecutionPlan: """Rank-local planned execution metadata for shared-prefix GDN.""" - model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) - - cp_rank: int = Field(ge=0) - cp_size: int = Field(ge=1) - batch_size: int = Field(ge=1) - sequence_length: int = Field(ge=0) - packed_batch_size: int | None = Field(default=None, ge=1) - packed_sequence_length: int | None = Field(default=None, ge=1) + cp_rank: int + cp_size: int + batch_size: int + sequence_length: int real_token_mask: torch.Tensor - family_count: int = Field(ge=0) - completion_count: int = Field(ge=0) - local_prefix_buckets: tuple[GdnSegmentBucketPlan, ...] = () - local_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () - ready_local_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () - remote_local_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () - chain_prefix_buckets: tuple[GdnSegmentBucketPlan, ...] = () - chain_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () - prefix_table_is_dense_ordered: bool + packed_batch_size: int | None = None + packed_sequence_length: int | None = None attention_to_gdn: Any | None = None gdn_to_attention: Any | None = None attention_token_ranges: tuple[tuple[int, int, int], ...] = () gdn_token_ranges: tuple[tuple[int, int, int], ...] = () - attention_token_count: int = Field(default=0, ge=0) - gdn_token_count: int = Field(default=0, ge=0) - parent_state_exchange_family_indices: tuple[int, ...] = () - parent_state_transfers: tuple[GdnParentStateTransferPlan, ...] = () - prefix_boundary_buckets: tuple[GdnSegmentBucketPlan, ...] = () - prefix_tail_buckets: tuple[GdnSegmentBucketPlan, ...] = () - completion_with_prefix_tail_buckets: tuple[GdnSegmentBucketPlan, ...] = () - remote_prefix_tail_buckets: tuple[GdnSegmentBucketPlan, ...] = () - remote_completion_with_prefix_tail_buckets: tuple[GdnSegmentBucketPlan, ...] = () - remote_prefix_tail_exchange: Any | None = None - remote_prefix_tail_backward_exchange: Any | None = None - remote_prefix_tail_state_transfers: tuple[GdnParentStateTransferPlan, ...] = () + attention_token_count: int = 0 + gdn_token_count: int = 0 + tree_segment_buckets_by_depth: tuple[tuple[GdnSegmentBucketPlan, ...], ...] = () + tree_chain_buckets_by_depth: tuple[tuple[GdnSegmentBucketPlan, ...], ...] = () + tree_state_exchanges_by_depth: tuple[GdnStateExchangePlan | None, ...] = () @property def attention_token_indices(self) -> tuple[int, ...]: @@ -251,74 +140,12 @@ def gdn_token_indices(self) -> tuple[int, ...]: return _tokens_from_rank_ranges(self.gdn_token_ranges) -class GdnCpSegmentSchedule(BaseModel): - """CPU-side ownership and bucket schedule for one CP GDN plan.""" - - model_config = ConfigDict(frozen=True) - - gdn_token_counts_by_rank: tuple[int, ...] - gdn_token_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...] = () - cross_rank_token_count: int = Field(ge=0) - chain_prefix_buckets: tuple[tuple[GdnSegmentSpec, ...], ...] - chain_completion_buckets: tuple[tuple[GdnSegmentSpec, ...], ...] - local_prefix_segments_by_rank: tuple[tuple[GdnSegmentSpec, ...], ...] - local_completion_segments_by_rank: tuple[tuple[GdnSegmentSpec, ...], ...] - parent_state_exchange_family_indices: tuple[int, ...] = () - parent_state_transfers: tuple[GdnParentStateTransferPlan, ...] = () - - -class _GdnCpSegmentSearchDecision(BaseModel): - model_config = ConfigDict(frozen=True) - - chain_segment_keys: frozenset[GdnSegmentDecisionKey] - co_locate_local_families: bool - score: float - - -class _ExplicitBucketColumn(BaseModel): - model_config = ConfigDict(frozen=True) - - row_index: int - family_index: int - positions: tuple[int, ...] - output_mask: tuple[bool, ...] - - @property - def length(self) -> int: - return len(self.positions) - - -def _explicit_bucket_column( - *, - row_index: int, - family_index: int, - positions: tuple[int, ...], - output_mask: tuple[bool, ...], -) -> _ExplicitBucketColumn: - return _ExplicitBucketColumn.model_construct( - row_index=row_index, - family_index=family_index, - positions=positions, - output_mask=output_mask, - ) - - -class _AttentionLayoutIndex(BaseModel): +@dataclass(frozen=True) +class _AttentionLayoutIndex: """Counting index for CP attention token ownership.""" - model_config = ConfigDict(frozen=True) - token_ranges_by_rank: tuple[tuple[tuple[int, int], ...], ...] token_range_ends_by_rank: tuple[tuple[int, ...], ...] - range_count: int = Field(ge=0) - - -def _layout_cp_size(layout: TokenLayoutIndex) -> int: - return len(layout.token_counts_by_rank) - - -def _layout_token_count(layout: TokenLayoutIndex) -> int: - return sum(int(count) for count in layout.token_counts_by_rank) def _tokens_from_rank_ranges( @@ -327,21 +154,6 @@ def _tokens_from_rank_ranges( return tuple(token for start, end, _ in ranges for token in range(start, end)) -def _token_layout_from_rank_ranges( - ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...], -) -> TokenLayoutIndex: - return TokenLayoutIndex( - ownership_ranges_by_rank=ranges_by_rank, - token_counts_by_rank=tuple( - _ranges_token_count(ranges) for ranges in ranges_by_rank - ), - ) - - -def _ranges_token_count(ranges: tuple[tuple[int, int, int], ...]) -> int: - return sum(int(end) - int(start) for start, end, _ in ranges) - - def build_gdn_rank_execution_plan( spec: GdnPackedExecutionSpec, *, @@ -349,7 +161,6 @@ def build_gdn_rank_execution_plan( cp_rank: int = 0, cp_size: int = 1, attention_token_layout_index: TokenLayoutIndex | None = None, - cp_segment_schedule: GdnCpSegmentSchedule | None = None, planner_config: GdnPlannerConfig | None = None, ) -> GdnRankExecutionPlan: """Build rank-local tensor metadata from a parsed shared-prefix DAG. @@ -368,185 +179,20 @@ def build_gdn_rank_execution_plan( cp_rank=cp_rank, cp_size=cp_size, attention_token_layout_index=attention_token_layout_index, - cp_segment_schedule=cp_segment_schedule, planner_config=planner_config, ) return move_gdn_rank_execution_plan_to_device(cpu_plan, target_device) - if cp_size != 1 or cp_rank != 0: - return _build_cp_rank_execution_plan( - spec, - device=device, - cp_rank=cp_rank, - cp_size=cp_size, - attention_token_layout_index=attention_token_layout_index, - cp_segment_schedule=cp_segment_schedule, - planner_config=planner_config, - ) - ( - prefix_boundary_buckets, - prefix_tail_buckets, - completion_with_prefix_tail_buckets, - ) = _build_chunk_aligned_cp1_bucket_plans( + return _build_tree_rank_execution_plan( spec, device=device, - planner_config=planner_config, - ) - valid_lengths = torch.tensor( - spec.valid_lengths, - device=device, - dtype=torch.long, - ) - positions = torch.arange(spec.sequence_length, device=device, dtype=torch.long) - local_range_list: list[tuple[int, int, int]] = [] - local_position = 0 - for row_index, length in enumerate(spec.valid_lengths): - if length: - start = row_index * spec.sequence_length - local_range_list.append((start, start + length, local_position)) - local_position += length - local_ranges = tuple(local_range_list) - return GdnRankExecutionPlan.model_construct( cp_rank=cp_rank, cp_size=cp_size, - batch_size=spec.batch_size, - sequence_length=spec.sequence_length, - packed_batch_size=spec.batch_size, - packed_sequence_length=spec.sequence_length, - real_token_mask=positions.unsqueeze(0) < valid_lengths.unsqueeze(1), - family_count=spec.family_count, - completion_count=spec.completion_count, - local_prefix_buckets=(), - local_completion_buckets=(), - ready_local_completion_buckets=(), - remote_local_completion_buckets=(), - chain_prefix_buckets=(), - chain_completion_buckets=(), - prefix_table_is_dense_ordered=False, - attention_token_ranges=local_ranges, - gdn_token_ranges=local_ranges, - attention_token_count=spec.real_token_count, - gdn_token_count=spec.real_token_count, - prefix_boundary_buckets=prefix_boundary_buckets, - prefix_tail_buckets=prefix_tail_buckets, - completion_with_prefix_tail_buckets=completion_with_prefix_tail_buckets, - ) - - -def move_gdn_rank_execution_plan_to_device( - plan: GdnRankExecutionPlan, - device: torch.device | str, -) -> GdnRankExecutionPlan: - """Move planner tensors to the execution device after CPU planning.""" - - from art.megatron.gdn.layout import move_cp_exchange_plan_to_device - - return GdnRankExecutionPlan.model_construct( - cp_rank=plan.cp_rank, - cp_size=plan.cp_size, - batch_size=plan.batch_size, - sequence_length=plan.sequence_length, - packed_batch_size=plan.packed_batch_size, - packed_sequence_length=plan.packed_sequence_length, - real_token_mask=_move_planner_tensor(plan.real_token_mask, device), - family_count=plan.family_count, - completion_count=plan.completion_count, - local_prefix_buckets=_move_bucket_plans(plan.local_prefix_buckets, device), - local_completion_buckets=_move_bucket_plans( - plan.local_completion_buckets, device - ), - ready_local_completion_buckets=_move_bucket_plans( - plan.ready_local_completion_buckets, device - ), - remote_local_completion_buckets=_move_bucket_plans( - plan.remote_local_completion_buckets, device - ), - chain_prefix_buckets=_move_bucket_plans(plan.chain_prefix_buckets, device), - chain_completion_buckets=_move_bucket_plans( - plan.chain_completion_buckets, device - ), - prefix_table_is_dense_ordered=plan.prefix_table_is_dense_ordered, - attention_to_gdn=move_cp_exchange_plan_to_device(plan.attention_to_gdn, device), - gdn_to_attention=move_cp_exchange_plan_to_device(plan.gdn_to_attention, device), - attention_token_ranges=plan.attention_token_ranges, - gdn_token_ranges=plan.gdn_token_ranges, - attention_token_count=plan.attention_token_count, - gdn_token_count=plan.gdn_token_count, - parent_state_exchange_family_indices=plan.parent_state_exchange_family_indices, - parent_state_transfers=_move_parent_state_transfers( - plan.parent_state_transfers, device - ), - prefix_boundary_buckets=_move_bucket_plans( - plan.prefix_boundary_buckets, device - ), - prefix_tail_buckets=_move_bucket_plans(plan.prefix_tail_buckets, device), - completion_with_prefix_tail_buckets=_move_bucket_plans( - plan.completion_with_prefix_tail_buckets, device - ), - remote_prefix_tail_buckets=_move_bucket_plans( - plan.remote_prefix_tail_buckets, device - ), - remote_completion_with_prefix_tail_buckets=_move_bucket_plans( - plan.remote_completion_with_prefix_tail_buckets, device - ), - remote_prefix_tail_exchange=move_cp_exchange_plan_to_device( - plan.remote_prefix_tail_exchange, device - ), - remote_prefix_tail_backward_exchange=move_cp_exchange_plan_to_device( - plan.remote_prefix_tail_backward_exchange, device - ), - remote_prefix_tail_state_transfers=_move_parent_state_transfers( - plan.remote_prefix_tail_state_transfers, device - ), - ) - - -def _move_bucket_plans( - buckets: tuple[GdnSegmentBucketPlan, ...], - device: torch.device | str, -) -> tuple[GdnSegmentBucketPlan, ...]: - return tuple( - GdnSegmentBucketPlan.model_construct( - length=bucket.length, - lengths=_move_planner_tensor(bucket.lengths, device), - lengths_cpu=bucket.lengths_cpu, - lengths_by_rank_cpu=bucket.lengths_by_rank_cpu, - real_mask=_move_planner_tensor(bucket.real_mask, device), - cu_seqlens=_move_planner_tensor(bucket.cu_seqlens, device), - cu_seqlens_cpu=bucket.cu_seqlens_cpu, - row_indices=_move_planner_tensor(bucket.row_indices, device), - position_indices=_move_planner_tensor(bucket.position_indices, device), - family_indices=_move_planner_tensor(bucket.family_indices, device), - real_token_count_static=bucket.real_token_count, - output_mask=( - _move_planner_tensor(bucket.output_mask, device) - if bucket.output_mask is not None - else None - ), - ) - for bucket in buckets - ) - - -def _move_parent_state_transfers( - transfers: tuple[GdnParentStateTransferPlan, ...], - device: torch.device | str, -) -> tuple[GdnParentStateTransferPlan, ...]: - return tuple( - GdnParentStateTransferPlan.model_construct( - source_rank=transfer.source_rank, - dest_rank=transfer.dest_rank, - family_indices=transfer.family_indices, - family_indices_tensor=( - _move_planner_tensor(transfer.family_indices_tensor, device) - if transfer.family_indices_tensor is not None - else None - ), - ) - for transfer in transfers + attention_token_layout_index=attention_token_layout_index, + planner_config=planner_config, ) -def _build_local_attention_layout_rank_execution_plan( +def _build_tree_rank_execution_plan( spec: GdnPackedExecutionSpec, *, device: torch.device | str, @@ -554,14 +200,17 @@ def _build_local_attention_layout_rank_execution_plan( cp_size: int, attention_token_layout_index: TokenLayoutIndex | None, planner_config: GdnPlannerConfig, -) -> GdnRankExecutionPlan | None: - if cp_size <= 1 or not spec.families: - return None - if any( - _has_chainable_segment(family, cp_size=cp_size, planner_config=planner_config) - for family in spec.families - ): - return None +) -> GdnRankExecutionPlan: + if cp_size < 1: + raise ValueError(f"cp_size must be >= 1, got {cp_size}") + if cp_rank < 0 or cp_rank >= cp_size: + raise ValueError(f"cp_rank must be in [0, {cp_size}), got {cp_rank}") + if not spec.tree_segments: + raise ValueError("tree GDN planning requires tree segments") + if len(spec.tree_parent_indices) != len(spec.tree_segments): + raise ValueError("tree parent metadata length must match tree segments") + if len(spec.tree_depths) != len(spec.tree_segments): + raise ValueError("tree depth metadata length must match tree segments") from art.megatron.gdn.layout import ( _reverse_exchange_plan, @@ -575,2746 +224,397 @@ def _build_local_attention_layout_rank_execution_plan( planner_config=planner_config, ) attention_layout_index = _build_attention_layout_index_from_token_layout( - source_layout, - max_ranges=max(1, 2 * spec.real_token_count // len(tuple(spec.segments()))), + source_layout ) segment_attention_counts = _segment_attention_rank_counts( spec, cp_size=cp_size, attention_layout_index=attention_layout_index, ) - best = _assign_local_attention_segments( - spec, - cp_size=cp_size, - segment_attention_counts=segment_attention_counts, - co_locate_local_families=False, - planner_config=planner_config, - ) - co_located = _assign_local_attention_segments( - spec, - cp_size=cp_size, - segment_attention_counts=segment_attention_counts, - co_locate_local_families=True, - planner_config=planner_config, - ) - if co_located[4] < best[4]: - best = co_located - ( - prefix_owner_by_family, - completion_owners_by_family, - _, - cross_rank_token_count, - _, - ) = best - - local_prefix_segments: list[GdnSegmentSpec] = [] - local_completion_segments: list[GdnSegmentSpec] = [] - prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [[] for _ in range(cp_size)] - completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - gdn_ranges_by_rank: list[list[tuple[int, int, int]]] = [[] for _ in range(cp_size)] - rank_loads = [0] * cp_size - parent_state_exchange_families: set[int] = set() - parent_state_transfer_families: dict[tuple[int, int], set[int]] = {} - - def append_segment(rank: int, segment: GdnSegmentSpec) -> None: - token_start = _segment_token_start(segment, spec.sequence_length) - position_start = rank_loads[rank] - gdn_ranges_by_rank[rank].append( - (token_start, token_start + segment.length, position_start) - ) - rank_loads[rank] += segment.length - - for family in spec.families: - prefix_owner = prefix_owner_by_family[family.family_index] - if prefix_owner == cp_rank: - local_prefix_segments.append(family.prefix) - prefix_segments_by_rank[prefix_owner].append(family.prefix) - append_segment(prefix_owner, family.prefix) - completion_owners = completion_owners_by_family[family.family_index] - for completion, completion_owner in zip( - family.completions, completion_owners, strict=True - ): - if completion_owner == cp_rank: - local_completion_segments.append(completion) - completion_segments_by_rank[completion_owner].append(completion) - append_segment(completion_owner, completion) - if completion_owner != prefix_owner: - parent_state_exchange_families.add(family.family_index) - parent_state_transfer_families.setdefault( - (prefix_owner, completion_owner), set() - ).add(family.family_index) - - local_token_ranges = tuple(gdn_ranges_by_rank[cp_rank]) - local_token_count = rank_loads[cp_rank] - schedule = GdnCpSegmentSchedule.model_construct( - gdn_token_counts_by_rank=tuple(rank_loads), - gdn_token_ranges_by_rank=tuple(tuple(ranges) for ranges in gdn_ranges_by_rank), - cross_rank_token_count=cross_rank_token_count, - chain_prefix_buckets=(), - chain_completion_buckets=(), - local_prefix_segments_by_rank=tuple( - tuple(segments) for segments in prefix_segments_by_rank - ), - local_completion_segments_by_rank=tuple( - tuple(segments) for segments in completion_segments_by_rank - ), - parent_state_exchange_family_indices=tuple( - sorted(parent_state_exchange_families) - ), - parent_state_transfers=_build_parent_state_transfer_plans( - parent_state_transfer_families - ), - ) - if parent_state_transfer_families: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _build_remote_prefix_tail_plans( - spec, - schedule, - cp_rank=cp_rank, - device=device, - planner_config=planner_config, - ) - else: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _empty_remote_prefix_tail_plans() - attention_to_gdn = build_local_rank_cp_exchange_plan_from_dest_ranges( - source_layout=source_layout, - device=device, - dest_ranges_by_rank=tuple(tuple(ranges) for ranges in gdn_ranges_by_rank), - local_rank=cp_rank, - cross_rank_token_count=cross_rank_token_count, - ) - gdn_to_attention = _reverse_exchange_plan(attention_to_gdn) - local_prefix_family_indices = { - segment.family_index for segment in local_prefix_segments - } - local_prefix_buckets = _batch_segments_by_padded_work( - (), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - chunk_local_completion_segments = tuple( - segment - for segment in local_completion_segments - if segment.family_index in local_prefix_family_indices - ) - plain_local_completion_segments = tuple( - segment - for segment in local_completion_segments - if segment.family_index not in local_prefix_family_indices - and segment.family_index not in remote_prefix_tail_families - ) - ready_completion_segments, remote_completion_segments = ( - _split_ready_and_remote_completion_segments( - plain_local_completion_segments, - local_prefix_segments=(), - chain_prefix_buckets=(), - ) - ) - ready_completion_buckets = _batch_segments_by_padded_work( - ready_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - remote_completion_buckets = _batch_segments_by_padded_work( - remote_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - prefix_family_order = tuple( - segment.family_index for bucket in local_prefix_buckets for segment in bucket - ) - ready_completion_bucket_plans = _build_position_bucket_plans( - ready_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ) - remote_completion_bucket_plans = _build_position_bucket_plans( - remote_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ) - ( - prefix_boundary_buckets, - prefix_tail_buckets, - completion_with_prefix_tail_buckets, - ) = _build_chunk_aligned_position_bucket_plans( - tuple(local_prefix_segments), - chunk_local_completion_segments, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - planner_config=planner_config, - ) - return GdnRankExecutionPlan.model_construct( - cp_rank=cp_rank, - cp_size=cp_size, - batch_size=1, - sequence_length=local_token_count, - packed_batch_size=spec.batch_size, - packed_sequence_length=spec.sequence_length, - real_token_mask=torch.ones( - 1, local_token_count, device=device, dtype=torch.bool - ), - family_count=spec.family_count, - completion_count=spec.completion_count, - local_prefix_buckets=_build_position_bucket_plans( - local_prefix_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - local_completion_buckets=( - ready_completion_bucket_plans + remote_completion_bucket_plans - ), - ready_local_completion_buckets=ready_completion_bucket_plans, - remote_local_completion_buckets=remote_completion_bucket_plans, - chain_prefix_buckets=(), - chain_completion_buckets=(), - prefix_table_is_dense_ordered=( - not local_prefix_segments - and prefix_family_order == tuple(range(spec.family_count)) - ), - attention_to_gdn=attention_to_gdn, - gdn_to_attention=gdn_to_attention, - attention_token_ranges=source_layout.ownership_ranges_by_rank[cp_rank], - gdn_token_ranges=local_token_ranges, - attention_token_count=source_layout.token_counts_by_rank[cp_rank], - gdn_token_count=local_token_count, - parent_state_exchange_family_indices=tuple( - sorted(parent_state_exchange_families - remote_prefix_tail_families) - ), - parent_state_transfers=_filter_parent_state_transfers( - _build_parent_state_transfer_plans(parent_state_transfer_families), - excluded_families=remote_prefix_tail_families, - device=device, - ), - prefix_boundary_buckets=prefix_boundary_buckets, - prefix_tail_buckets=prefix_tail_buckets, - completion_with_prefix_tail_buckets=completion_with_prefix_tail_buckets, - remote_prefix_tail_buckets=remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets=remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange=remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange=remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers=remote_prefix_tail_state_transfers, - ) - -def _assign_local_attention_segments( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - co_locate_local_families: bool, - planner_config: GdnPlannerConfig, -) -> tuple[ - tuple[int, ...], - tuple[tuple[int, ...], ...], - tuple[int, ...], - int, - float, -]: + depth_count = max(spec.tree_depths, default=0) + 1 rank_loads = [0] * cp_size - has_prefix = [False] * cp_size - has_completion = [False] * cp_size - prefix_owner_by_family: list[int] = [] - completion_owners_by_family: list[tuple[int, ...]] = [] - parent_state_exchange_families: set[int] = set() + owner_by_node = [-1] * len(spec.tree_segments) + chained_nodes = [False] * len(spec.tree_segments) + tree_has_children = [False] * len(spec.tree_segments) + for parent_index in spec.tree_parent_indices: + if parent_index >= 0: + tree_has_children[parent_index] = True + gdn_ranges_by_rank: list[list[tuple[int, int, int]]] = [[] for _ in range(cp_size)] + segments_by_rank_depth: list[list[list[GdnSegmentSpec]]] = [ + [[] for _ in range(depth_count)] for _ in range(cp_size) + ] + chain_segments_by_depth: list[list[GdnSegmentSpec]] = [ + [] for _ in range(depth_count) + ] cross_rank_token_count = 0 - def append_owner(rank: int, segment: GdnSegmentSpec) -> None: + children_by_node: list[list[int]] = [[] for _ in spec.tree_segments] + root_indices: list[int] = [] + for node_index, parent_index in enumerate(spec.tree_parent_indices): + if parent_index < 0: + root_indices.append(node_index) + else: + children_by_node[parent_index].append(node_index) + + def subtree_indices(root_index: int) -> tuple[int, ...]: + ordered: list[int] = [] + stack = [root_index] + while stack: + node_index = stack.pop() + ordered.append(node_index) + stack.extend(reversed(children_by_node[node_index])) + return tuple(ordered) + + def assign_local_group(node_indices: tuple[int, ...]) -> None: nonlocal cross_rank_token_count - rank_loads[rank] += segment.length - cross_rank_token_count += ( - segment.length - segment_attention_counts[_segment_key(segment)][rank] - ) - - for family in spec.families: - if co_locate_local_families: - owner = _best_segment_owner( - (family.prefix, *family.completions), - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - prefix_owner_by_family.append(owner) - completion_owners = tuple(owner for _ in family.completions) - completion_owners_by_family.append(completion_owners) - has_prefix[owner] = True - for segment in (family.prefix, *family.completions): - append_owner(owner, segment) - if family.completions: - has_completion[owner] = True - continue - - prefix_owner = _best_segment_owner( - (family.prefix,), + segments = tuple(spec.tree_segments[index] for index in node_indices) + owner = _best_segment_owner( + segments, rank_loads, segment_attention_counts=segment_attention_counts, planner_config=planner_config, ) - prefix_owner_by_family.append(prefix_owner) - has_prefix[prefix_owner] = True - append_owner(prefix_owner, family.prefix) - completion_owners = [] - for completion in family.completions: - owner = _best_segment_owner( - (completion,), - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - completion_owners.append(owner) - has_completion[owner] = True - append_owner(owner, completion) - if owner != prefix_owner: - parent_state_exchange_families.add(family.family_index) - completion_owners_by_family.append(tuple(completion_owners)) - - del has_prefix, has_completion - score = _score_local_segment_assignment( - spec, - cp_size=cp_size, - prefix_owner_by_family=tuple(prefix_owner_by_family), - completion_owners_by_family=tuple(completion_owners_by_family), - rank_loads=tuple(rank_loads), - cross_rank_token_count=cross_rank_token_count, - parent_state_exchange_family_count=len(parent_state_exchange_families), - planner_config=planner_config, - ) - return ( - tuple(prefix_owner_by_family), - tuple(completion_owners_by_family), - tuple(sorted(parent_state_exchange_families)), - cross_rank_token_count, - score, - ) - - -def _score_local_segment_assignment( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - prefix_owner_by_family: tuple[int, ...], - completion_owners_by_family: tuple[tuple[int, ...], ...], - rank_loads: tuple[int, ...], - cross_rank_token_count: int, - parent_state_exchange_family_count: int, - planner_config: GdnPlannerConfig, -) -> float: - local_prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - local_completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - for family in spec.families: - prefix_owner = prefix_owner_by_family[family.family_index] - local_prefix_segments_by_rank[prefix_owner].append(family.prefix) - completion_owners = completion_owners_by_family[family.family_index] - for completion, completion_owner in zip( - family.completions, completion_owners, strict=True - ): - local_completion_segments_by_rank[completion_owner].append(completion) - ( - local_work_by_rank, - local_bucket_count, - local_segment_count, - ) = _estimate_local_rank_kernel_work( - tuple(tuple(segments) for segments in local_prefix_segments_by_rank), - tuple(tuple(segments) for segments in local_completion_segments_by_rank), - planner_config=planner_config, - ) - return _score_cp_segment_stats( - rank_local_work=local_work_by_rank, - rank_chain_work=tuple(0 for _ in range(cp_size)), - rank_real_tokens=rank_loads, - cross_rank_token_count=cross_rank_token_count, - parent_state_exchange_family_count=parent_state_exchange_family_count, - local_bucket_count=local_bucket_count, - local_segment_count=local_segment_count, - chain_bucket_count=0, - planner_config=planner_config, - ) - - -def _can_zero_exchange_colocate_families( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], -) -> bool: - for family in spec.families: - family_rank_counts = [0] * cp_size - for segment in (family.prefix, *family.completions): - segment_counts = segment_attention_counts[_segment_key(segment)] - for rank in range(cp_size): - family_rank_counts[rank] += segment_counts[rank] - if max(family_rank_counts, default=0) != family.token_count: - return False - return True - - -def parse_gdn_shared_prefix_segments( - group_ids: torch.Tensor, - parent_ids: torch.Tensor, - *, - min_completions_per_family: int = 0, -) -> GdnPackedExecutionSpec: - """Parse ART packed shared-prefix metadata into a GDN segment DAG. - - The parser is intentionally strict: GDN state routing depends on prompt-family - boundaries, so malformed metadata should fail before execution can silently - leak recurrent or conv state across siblings or independent families. - """ - - groups = _rank2_long_cpu("group_ids", group_ids) - parents = _rank2_long_cpu("parent_ids", parent_ids) - if tuple(groups.shape) != tuple(parents.shape): - raise ValueError( - "group_ids and parent_ids must have the same shape, got " - f"{tuple(groups.shape)} and {tuple(parents.shape)}" - ) - - batch_size, sequence_length = (int(groups.shape[0]), int(groups.shape[1])) - valid_lengths: list[int] = [] - families: list[GdnPackedFamilySpec] = [] - for row_index in range(batch_size): - row_group_ids = groups[row_index] - row_parent_ids = parents[row_index] - valid_length = _validate_padding_tensor( - row_index, row_group_ids, row_parent_ids - ) - valid_lengths.append(valid_length) - if valid_length == 0: - continue - families.extend( - _parse_row_tensor( - row_index=row_index, - group_ids=row_group_ids, - parent_ids=row_parent_ids, - valid_length=valid_length, - first_family_index=len(families), - min_completions_per_family=min_completions_per_family, - ) - ) - - return GdnPackedExecutionSpec( - batch_size=batch_size, - sequence_length=sequence_length, - valid_lengths=tuple(valid_lengths), - families=tuple(families), - ) - - -def _build_segment_bucket_plans( - segment_buckets: tuple[tuple[GdnSegmentSpec, ...], ...], - *, - device: torch.device | str, -) -> tuple[GdnSegmentBucketPlan, ...]: - return tuple( - _build_segment_bucket_plan(bucket[0].length, bucket, device=device) - for bucket in segment_buckets - ) - - -def _build_chunk_aligned_cp1_bucket_plans( - spec: GdnPackedExecutionSpec, - *, - device: torch.device | str, - planner_config: GdnPlannerConfig, -) -> tuple[ - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], -]: - boundary_segments: list[GdnSegmentSpec] = [] - tail_segments: list[GdnSegmentSpec] = [] - completion_columns: list[_ExplicitBucketColumn] = [] - for family in spec.families: - prefix = family.prefix - boundary_end = _prefix_chunk_boundary_end(prefix) - if boundary_end > prefix.start: - boundary_segments.append( - _segment_with_bounds(prefix, prefix.start, boundary_end) - ) - prefix_tail_positions = tuple(range(boundary_end, prefix.end)) - if prefix_tail_positions and not family.completions: - tail_segments.append(_segment_with_bounds(prefix, boundary_end, prefix.end)) - for child_offset, completion in enumerate(family.completions): - completion_positions = prefix_tail_positions + tuple( - range(completion.start, completion.end) - ) - completion_columns.append( - _explicit_bucket_column( - row_index=completion.row_index, - family_index=completion.family_index, - positions=completion_positions, - output_mask=( - ((child_offset == 0),) * len(prefix_tail_positions) - + (True,) * completion.length - ), - ) - ) - boundary_buckets = _batch_segments_by_padded_work( - tuple(boundary_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - tail_buckets = _batch_segments_by_padded_work( - tuple(tail_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - completion_column_batches = _batch_explicit_bucket_columns( - tuple(completion_columns), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - return ( - _build_segment_bucket_plans(boundary_buckets, device=device), - _build_segment_bucket_plans(tail_buckets, device=device), - _build_explicit_bucket_plans(completion_column_batches, device=device), - ) - - -def _build_chunk_aligned_position_bucket_plans( - prefix_segments: tuple[GdnSegmentSpec, ...], - completion_segments: tuple[GdnSegmentSpec, ...], - local_token_ranges: tuple[tuple[int, int, int], ...], - *, - sequence_length: int, - device: torch.device | str, - planner_config: GdnPlannerConfig, -) -> tuple[ - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], -]: - local_range_ends = tuple(token_end for _, token_end, _ in local_token_ranges) - local_range_positions = { - (token_start, token_end): position_start - for token_start, token_end, position_start in local_token_ranges - } - completions_by_family: dict[int, list[GdnSegmentSpec]] = {} - for completion in completion_segments: - completions_by_family.setdefault(completion.family_index, []).append(completion) - boundary_segments: list[GdnSegmentSpec] = [] - tail_segments: list[GdnSegmentSpec] = [] - completion_columns: list[_ExplicitBucketColumn] = [] - for prefix in prefix_segments: - boundary_end = _prefix_chunk_boundary_end(prefix) - if boundary_end > prefix.start: - boundary_segments.append( - _segment_with_bounds(prefix, prefix.start, boundary_end) - ) - family_completions = tuple(completions_by_family.get(prefix.family_index, ())) - prefix_tail_positions = _local_positions_for_span( - prefix.row_index, - boundary_end, - prefix.end, - sequence_length=sequence_length, - local_token_ranges=local_token_ranges, - local_range_ends=local_range_ends, - local_range_positions=local_range_positions, - ) - if prefix_tail_positions and not family_completions: - tail_segments.append(_segment_with_bounds(prefix, boundary_end, prefix.end)) - for child_offset, completion in enumerate(family_completions): - completion_positions = _local_positions_for_span( - completion.row_index, - completion.start, - completion.end, - sequence_length=sequence_length, - local_token_ranges=local_token_ranges, - local_range_ends=local_range_ends, - local_range_positions=local_range_positions, - ) - positions = prefix_tail_positions + completion_positions - completion_columns.append( - _explicit_bucket_column( - row_index=0, - family_index=completion.family_index, - positions=positions, - output_mask=( - ((child_offset == 0),) * len(prefix_tail_positions) - + (True,) * len(completion_positions) - ), - ) - ) - boundary_buckets = _batch_segments_by_padded_work( - tuple(boundary_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - tail_buckets = _batch_segments_by_padded_work( - tuple(tail_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - completion_column_batches = _batch_explicit_bucket_columns( - tuple(completion_columns), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - return ( - _build_position_bucket_plans( - boundary_buckets, - local_token_ranges, - sequence_length=sequence_length, - device=device, - ), - _build_position_bucket_plans( - tail_buckets, - local_token_ranges, - sequence_length=sequence_length, - device=device, - ), - _build_explicit_bucket_plans(completion_column_batches, device=device), - ) - - -def _build_remote_prefix_tail_plans( - spec: GdnPackedExecutionSpec, - schedule: GdnCpSegmentSchedule, - *, - cp_rank: int, - device: torch.device | str, - planner_config: GdnPlannerConfig, -) -> tuple[ - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], - Any | None, - Any | None, - tuple[GdnParentStateTransferPlan, ...], - frozenset[int], -]: - from art.megatron.gdn.layout import ( - GdnCpExchangePlan, - GdnCpPeerTransfer, - _reverse_exchange_plan, - ) - - family_by_index = {family.family_index: family for family in spec.families} - prefix_owner_by_family = _prefix_owner_by_family(schedule) - source_positions_by_pair: dict[tuple[int, int], list[int]] = {} - dest_positions_by_pair: dict[tuple[int, int], list[int]] = {} - dest_counts = [0 for _ in schedule.gdn_token_counts_by_rank] - state_transfer_families: dict[tuple[int, int], set[int]] = {} - remote_tail_family_indices: set[int] = set() - local_tail_columns: list[_ExplicitBucketColumn] = [] - local_completion_columns: list[_ExplicitBucketColumn] = [] - tail_positions_by_dest_family: dict[tuple[int, int], tuple[int, ...]] = {} - local_tail_column_families: set[int] = set() - rank_ranges = schedule.gdn_token_ranges_by_rank - rank_range_ends = tuple( - tuple(end for _, end, _ in ranges) for ranges in rank_ranges - ) - rank_range_positions = tuple( - { - (token_start, token_end): position_start - for token_start, token_end, position_start in ranges - } - for ranges in rank_ranges - ) - - for dest_rank, completions in enumerate(schedule.local_completion_segments_by_rank): - for completion in completions: - source_rank = prefix_owner_by_family.get(completion.family_index) - if source_rank is None or source_rank == dest_rank: - continue - family = family_by_index[completion.family_index] - boundary_end = _prefix_chunk_boundary_end(family.prefix) - if boundary_end == family.prefix.end: - continue - dest_family = (dest_rank, family.family_index) - dest_positions = tail_positions_by_dest_family.get(dest_family) - if dest_positions is None: - source_positions = _local_positions_for_span( - family.prefix.row_index, - boundary_end, - family.prefix.end, - sequence_length=spec.sequence_length, - local_token_ranges=rank_ranges[source_rank], - local_range_ends=rank_range_ends[source_rank], - local_range_positions=rank_range_positions[source_rank], - ) - if len(source_positions) != family.prefix.end - boundary_end: - raise ValueError( - "remote prefix-tail exchange could not locate all source tokens " - f"for family {family.family_index}" - ) - dest_start = dest_counts[dest_rank] - dest_positions = tuple( - range(dest_start, dest_start + len(source_positions)) - ) - tail_positions_by_dest_family[dest_family] = dest_positions - dest_counts[dest_rank] += len(source_positions) - pair = (source_rank, dest_rank) - source_positions_by_pair.setdefault(pair, []).extend(source_positions) - dest_positions_by_pair.setdefault(pair, []).extend(dest_positions) - state_transfer_families.setdefault(pair, set()).add(family.family_index) - remote_tail_family_indices.add(family.family_index) - - if dest_rank != cp_rank: - continue - completion_positions = _local_positions_for_span( - completion.row_index, - completion.start, - completion.end, - sequence_length=spec.sequence_length, - local_token_ranges=rank_ranges[dest_rank], - local_range_ends=rank_range_ends[dest_rank], - local_range_positions=rank_range_positions[dest_rank], - ) - if len(completion_positions) != completion.length: - raise ValueError( - "remote prefix-tail bucket could not locate all completion tokens " - f"for family {family.family_index}" - ) - remote_base = int(schedule.gdn_token_counts_by_rank[dest_rank]) - if ( - len(dest_positions) > 0 - and family.family_index not in local_tail_column_families - ): - local_tail_column_families.add(family.family_index) - local_tail_columns.append( - _explicit_bucket_column( - row_index=0, - family_index=family.family_index, - positions=tuple(remote_base + pos for pos in dest_positions), - output_mask=(False,) * len(dest_positions), - ) - ) - local_completion_columns.append( - _explicit_bucket_column( - row_index=0, - family_index=family.family_index, - positions=completion_positions, - output_mask=(True,) * len(completion_positions), - ) - ) - - if not source_positions_by_pair: - return (), (), None, None, (), frozenset() - - transfers = tuple( - GdnCpPeerTransfer.model_construct( - source_rank=source_rank, - dest_rank=dest_rank, - token_count=len(source_positions), - source_positions_tensor=_move_planner_tensor( - torch.tensor(source_positions, dtype=torch.long), device - ), - dest_positions_tensor=_move_planner_tensor( - torch.tensor( - dest_positions_by_pair[(source_rank, dest_rank)], - dtype=torch.long, - ), - device, - ), - ) - for (source_rank, dest_rank), source_positions in sorted( - source_positions_by_pair.items() - ) - ) - exchange = GdnCpExchangePlan.model_construct( - cp_size=len(schedule.gdn_token_counts_by_rank), - source_token_counts_by_rank=schedule.gdn_token_counts_by_rank, - dest_token_counts_by_rank=tuple(dest_counts), - transfers=transfers, - cross_rank_token_count_override=sum(dest_counts), - ) - tail_column_batches = _batch_explicit_bucket_columns( - tuple(local_tail_columns), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - completion_column_batches = _batch_explicit_bucket_columns( - tuple(local_completion_columns), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - return ( - _build_explicit_bucket_plans(tail_column_batches, device=device), - _build_explicit_bucket_plans(completion_column_batches, device=device), - exchange, - _reverse_exchange_plan(exchange), - _transfer_plans_to_device( - _build_parent_state_transfer_plans(state_transfer_families), - device=device, - ), - frozenset(remote_tail_family_indices), - ) - - -def _empty_remote_prefix_tail_plans() -> tuple[ - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], - Any | None, - Any | None, - tuple[GdnParentStateTransferPlan, ...], - frozenset[int], -]: - return (), (), None, None, (), frozenset() - - -def _prefix_owner_by_family(schedule: GdnCpSegmentSchedule) -> dict[int, int]: - owners: dict[int, int] = {} - for rank, segments in enumerate(schedule.local_prefix_segments_by_rank): - for segment in segments: - owners[segment.family_index] = rank - return owners - - -def _filter_parent_state_transfers( - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - excluded_families: frozenset[int], - device: torch.device | str, -) -> tuple[GdnParentStateTransferPlan, ...]: - if not excluded_families: - return _transfer_plans_to_device(transfers, device=device) - kept: dict[tuple[int, int], set[int]] = {} - for transfer in transfers: - families = set(transfer.family_indices) - excluded_families - if families: - kept.setdefault((transfer.source_rank, transfer.dest_rank), set()).update( - families - ) - return _transfer_plans_to_device( - _build_parent_state_transfer_plans(kept), device=device - ) - - -def _local_positions_for_span( - row_index: int, - start: int, - end: int, - *, - sequence_length: int, - local_token_ranges: tuple[tuple[int, int, int], ...], - local_range_ends: tuple[int, ...], - local_range_positions: dict[tuple[int, int], int] | None = None, -) -> tuple[int, ...]: - if start == end: - return () - token_start = row_index * sequence_length + start - token_end = row_index * sequence_length + end - if local_range_positions is not None: - position_start = local_range_positions.get((token_start, token_end)) - if position_start is not None: - return tuple(range(position_start, position_start + end - start)) - range_index = bisect_left(local_range_ends, token_start + 1) - if range_index < len(local_token_ranges): - range_start, range_end, position_start = local_token_ranges[range_index] - if range_start <= token_start and token_end <= range_end: - local_start = position_start + token_start - range_start - return tuple(range(local_start, local_start + end - start)) - segment = _trusted_pydantic_construct( - GdnSegmentSpec, - _GDN_SEGMENT_SPEC_FIELDS, - row_index=row_index, - family_index=0, - group_id=0, - parent_id=0, - start=start, - end=end, - kind="prefix", - child_index=None, - ) - return tuple( - int(position) - for position in _local_positions_for_segment( - segment, - sequence_length=sequence_length, - local_token_ranges=local_token_ranges, - local_range_ends=local_range_ends, - ).tolist() - ) - - -def _prefix_chunk_boundary_end(prefix: GdnSegmentSpec) -> int: - aligned_length = (prefix.length // FLA_CHUNK_SIZE) * FLA_CHUNK_SIZE - return prefix.start + aligned_length - - -def _segment_with_bounds( - segment: GdnSegmentSpec, start: int, end: int -) -> GdnSegmentSpec: - return _trusted_pydantic_construct( - GdnSegmentSpec, - _GDN_SEGMENT_SPEC_FIELDS, - row_index=segment.row_index, - family_index=segment.family_index, - group_id=segment.group_id, - parent_id=segment.parent_id, - start=start, - end=end, - kind=segment.kind, - child_index=segment.child_index, - ) - - -def _batch_explicit_bucket_columns( - columns: tuple[_ExplicitBucketColumn, ...], - *, - max_padding_ratio: float = 1.25, - max_segments_per_batch: int = 128, -) -> tuple[tuple[_ExplicitBucketColumn, ...], ...]: - if not columns: - return () - ordered = sorted( - columns, - key=lambda column: (column.length, column.family_index, column.row_index), - ) - batches: list[list[_ExplicitBucketColumn]] = [] - current: list[_ExplicitBucketColumn] = [] - current_tokens = 0 - current_max = 0 - for column in ordered: - next_count = len(current) + 1 - next_tokens = current_tokens + column.length - next_max = max(current_max, column.length) - padded = next_max * next_count - can_extend = not current or ( - next_count <= max_segments_per_batch - and padded <= max_padding_ratio * next_tokens - ) - if not can_extend: - batches.append(current) - current = [] - current_tokens = 0 - current_max = 0 - current.append(column) - current_tokens += column.length - current_max = max(current_max, column.length) - if current: - batches.append(current) - return tuple(tuple(batch) for batch in batches) - - -def _build_explicit_bucket_plans( - bucket_columns: tuple[tuple[_ExplicitBucketColumn, ...], ...], - *, - device: torch.device | str, -) -> tuple[GdnSegmentBucketPlan, ...]: - return tuple( - _build_explicit_bucket_plan(columns, device=device) - for columns in bucket_columns - ) - - -def _build_explicit_bucket_plan( - columns: tuple[_ExplicitBucketColumn, ...], - *, - device: torch.device | str, -) -> GdnSegmentBucketPlan: - max_length = max(column.length for column in columns) - column_count = len(columns) - lengths = [column.length for column in columns] - lengths_cpu = torch.tensor(lengths, dtype=torch.long) - offsets_cpu = torch.arange(max_length, dtype=torch.long).unsqueeze(1) - real_mask_cpu = offsets_cpu < lengths_cpu.unsqueeze(0) - padded_element_count = max_length * column_count - row_indices = [0] * padded_element_count - position_indices = [0] * padded_element_count - output_mask = [False] * padded_element_count - for column_index, column in enumerate(columns): - length = column.length - column_slice = slice(column_index, length * column_count, column_count) - row_indices[column_slice] = [column.row_index] * length - position_indices[column_slice] = column.positions - output_mask[column_slice] = column.output_mask - row_indices_cpu = torch.tensor(row_indices, dtype=torch.long).reshape( - max_length, column_count - ) - position_indices_cpu = torch.tensor(position_indices, dtype=torch.long).reshape( - max_length, column_count - ) - output_mask_cpu = torch.tensor(output_mask, dtype=torch.bool).reshape( - max_length, column_count - ) - family_indices_cpu = torch.tensor( - [column.family_index for column in columns], dtype=torch.long - ) - cu_seqlens_cpu = torch.cat( - [lengths_cpu.new_zeros(1), torch.cumsum(lengths_cpu, dim=0)] - ) - return GdnSegmentBucketPlan.model_construct( - length=max_length, - lengths=_move_planner_tensor(lengths_cpu, device), - lengths_cpu=lengths_cpu, - lengths_by_rank_cpu=None, - real_mask=_move_planner_tensor(real_mask_cpu, device), - cu_seqlens=_move_planner_tensor(cu_seqlens_cpu, device), - cu_seqlens_cpu=cu_seqlens_cpu, - row_indices=_move_planner_tensor(row_indices_cpu, device), - position_indices=_move_planner_tensor(position_indices_cpu, device), - family_indices=_move_planner_tensor(family_indices_cpu, device), - real_token_count_static=int(lengths_cpu.sum().item()), - output_mask=_move_planner_tensor(output_mask_cpu, device), - ) - - -def _attention_source_layout( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_token_layout_index: TokenLayoutIndex | None, - planner_config: GdnPlannerConfig, -) -> TokenLayoutIndex: - if attention_token_layout_index is not None: - if _layout_cp_size(attention_token_layout_index) != cp_size: - raise ValueError( - "attention token layout index cp_size must match GDN cp_size, got " - f"{_layout_cp_size(attention_token_layout_index)} and {cp_size}" - ) - if _layout_token_count(attention_token_layout_index) != spec.real_token_count: - raise ValueError( - "attention token layout index token count must match GDN real token " - f"count, got {_layout_token_count(attention_token_layout_index)} and " - f"{spec.real_token_count}" - ) - return attention_token_layout_index - return _token_layout_from_rank_ranges( - _default_attention_layout_ranges( - spec, - cp_size=cp_size, - planner_config=planner_config, - ) - ) - - -def _build_cp_rank_execution_plan( - spec: GdnPackedExecutionSpec, - *, - device: torch.device | str, - cp_rank: int, - cp_size: int, - attention_token_layout_index: TokenLayoutIndex | None, - cp_segment_schedule: GdnCpSegmentSchedule | None, - planner_config: GdnPlannerConfig, -) -> GdnRankExecutionPlan: - if cp_size < 1: - raise ValueError(f"cp_size must be >= 1, got {cp_size}") - if cp_rank < 0 or cp_rank >= cp_size: - raise ValueError(f"cp_rank must be in [0, {cp_size}), got {cp_rank}") - if ( - attention_token_layout_index is not None - and _layout_cp_size(attention_token_layout_index) != cp_size - ): - raise ValueError( - "attention token layout index cp_size must match GDN cp_size, got " - f"{_layout_cp_size(attention_token_layout_index)} and {cp_size}" - ) - - from art.megatron.gdn.layout import ( - _reverse_exchange_plan, - build_local_rank_cp_exchange_plan_from_dest_ranges, - ) - - has_explicit_attention_layout = attention_token_layout_index is not None - if cp_segment_schedule is None and not has_explicit_attention_layout: - local_family_plan = _build_local_family_rank_execution_plan( - spec, - device=device, - cp_rank=cp_rank, - cp_size=cp_size, - planner_config=planner_config, - ) - if local_family_plan is not None: - return local_family_plan - if cp_segment_schedule is None and has_explicit_attention_layout: - local_layout_plan = _build_local_attention_layout_rank_execution_plan( - spec, - device=device, - cp_rank=cp_rank, - cp_size=cp_size, - attention_token_layout_index=attention_token_layout_index, - planner_config=planner_config, - ) - if local_layout_plan is not None: - return local_layout_plan - - source_layout = _attention_source_layout( - spec, - cp_size=cp_size, - attention_token_layout_index=attention_token_layout_index, - planner_config=planner_config, - ) - if cp_segment_schedule is None: - schedule = _build_cp_segment_schedule( - spec, - cp_size=cp_size, - attention_layout_index=_build_attention_layout_index_from_token_layout( - source_layout, - max_ranges=max( - 1, - (2 * spec.real_token_count) // max(1, len(spec.segments())), - ), - ), - planner_config=planner_config, - ) - else: - schedule = cp_segment_schedule - if len(schedule.gdn_token_counts_by_rank) != cp_size: - raise ValueError(f"CP GDN schedule must contain {cp_size} ranks") - attention_to_gdn = build_local_rank_cp_exchange_plan_from_dest_ranges( - source_layout=source_layout, - device=device, - local_rank=cp_rank, - dest_ranges_by_rank=schedule.gdn_token_ranges_by_rank, - cross_rank_token_count=schedule.cross_rank_token_count, - ) - gdn_to_attention = _reverse_exchange_plan(attention_to_gdn) - local_token_ranges = schedule.gdn_token_ranges_by_rank[cp_rank] - local_gdn_token_count = schedule.gdn_token_counts_by_rank[cp_rank] - if schedule.parent_state_exchange_family_indices: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _build_remote_prefix_tail_plans( - spec, - schedule, - cp_rank=cp_rank, - device=device, - planner_config=planner_config, - ) - else: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _empty_remote_prefix_tail_plans() - - chain_prefix_buckets = tuple( - bucket for bucket in schedule.chain_prefix_buckets if bucket - ) - chain_completion_buckets = tuple( - bucket for bucket in schedule.chain_completion_buckets if bucket - ) - local_prefix_segments = tuple(schedule.local_prefix_segments_by_rank[cp_rank]) - local_prefix_family_indices = { - segment.family_index for segment in local_prefix_segments - } - local_prefix_buckets = _batch_segments_by_padded_work( - () if local_prefix_segments else (), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - local_completion_segments = tuple( - schedule.local_completion_segments_by_rank[cp_rank] - ) - chunk_local_completion_segments = tuple( - segment - for segment in local_completion_segments - if segment.family_index in local_prefix_family_indices - ) - plain_local_completion_segments = tuple( - segment - for segment in local_completion_segments - if segment.family_index not in local_prefix_family_indices - and segment.family_index not in remote_prefix_tail_families - ) - ready_completion_segments, remote_completion_segments = ( - _split_ready_and_remote_completion_segments( - plain_local_completion_segments, - local_prefix_segments=(), - chain_prefix_buckets=chain_prefix_buckets, - ) - ) - ready_local_completion_buckets = _batch_segments_by_padded_work( - ready_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - remote_local_completion_buckets = _batch_segments_by_padded_work( - remote_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - local_completion_buckets = ( - ready_local_completion_buckets + remote_local_completion_buckets - ) - prefix_family_order = tuple( - segment.family_index - for bucket in ( - *chain_prefix_buckets, - *local_prefix_buckets, - ) - for segment in bucket - ) - ( - prefix_boundary_buckets, - prefix_tail_buckets, - completion_with_prefix_tail_buckets, - ) = _build_chunk_aligned_position_bucket_plans( - local_prefix_segments, - chunk_local_completion_segments, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - planner_config=planner_config, - ) - return GdnRankExecutionPlan.model_construct( - cp_rank=cp_rank, - cp_size=cp_size, - batch_size=1, - sequence_length=local_gdn_token_count, - packed_batch_size=spec.batch_size, - packed_sequence_length=spec.sequence_length, - real_token_mask=torch.ones( - 1, local_gdn_token_count, device=device, dtype=torch.bool - ), - family_count=spec.family_count, - completion_count=spec.completion_count, - local_prefix_buckets=_build_position_bucket_plans( - local_prefix_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - local_completion_buckets=_build_position_bucket_plans( - local_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - ready_local_completion_buckets=_build_position_bucket_plans( - ready_local_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - remote_local_completion_buckets=_build_position_bucket_plans( - remote_local_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - chain_prefix_buckets=_build_position_bucket_plans( - chain_prefix_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - token_ranges_by_rank=schedule.gdn_token_ranges_by_rank, - ), - chain_completion_buckets=_build_position_bucket_plans( - chain_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - token_ranges_by_rank=schedule.gdn_token_ranges_by_rank, - ), - prefix_table_is_dense_ordered=( - not local_prefix_segments - and prefix_family_order == tuple(range(spec.family_count)) - ), - attention_to_gdn=attention_to_gdn, - gdn_to_attention=gdn_to_attention, - attention_token_ranges=source_layout.ownership_ranges_by_rank[cp_rank], - gdn_token_ranges=local_token_ranges, - attention_token_count=source_layout.token_counts_by_rank[cp_rank], - gdn_token_count=local_gdn_token_count, - parent_state_exchange_family_indices=( - tuple( - family_index - for family_index in schedule.parent_state_exchange_family_indices - if family_index not in remote_prefix_tail_families - ) - ), - parent_state_transfers=_filter_parent_state_transfers( - schedule.parent_state_transfers, - excluded_families=remote_prefix_tail_families, - device=device, - ), - prefix_boundary_buckets=prefix_boundary_buckets, - prefix_tail_buckets=prefix_tail_buckets, - completion_with_prefix_tail_buckets=completion_with_prefix_tail_buckets, - remote_prefix_tail_buckets=remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets=remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange=remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange=remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers=remote_prefix_tail_state_transfers, - ) - - -def build_gdn_cp_segment_schedule( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_token_layout_index: TokenLayoutIndex | None = None, - planner_config: GdnPlannerConfig | None = None, -) -> GdnCpSegmentSchedule: - planner_config = planner_config or GdnPlannerConfig() - source_layout = _attention_source_layout( - spec, - cp_size=cp_size, - attention_token_layout_index=attention_token_layout_index, - planner_config=planner_config, - ) - return _build_cp_segment_schedule( - spec, - cp_size=cp_size, - attention_layout_index=_build_attention_layout_index_from_token_layout( - source_layout, - max_ranges=max( - 1, (2 * spec.real_token_count) // max(1, len(spec.segments())) - ), - ), - planner_config=planner_config, - ) - - -def _build_cp_segment_schedule( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_layout_index: _AttentionLayoutIndex, - planner_config: GdnPlannerConfig, -) -> GdnCpSegmentSchedule: - segment_attention_counts = _segment_attention_rank_counts( - spec, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - ) - legal_chain_segments = tuple( - segment - for family in spec.families - for segment in (family.prefix, *family.completions) - if ( - _can_chain_prefix_segment( - segment, cp_size=cp_size, planner_config=planner_config - ) - if segment.kind == "prefix" - else _can_chain_segment( - segment, cp_size=cp_size, planner_config=planner_config - ) - ) - ) - decision = _beam_search_cp_segment_schedule_decision( - spec, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - segment_attention_counts=segment_attention_counts, - legal_chain_segments=legal_chain_segments, - planner_config=planner_config, - ) - return _materialize_cp_segment_schedule( - spec, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - segment_attention_counts=segment_attention_counts, - chain_segment_keys=decision.chain_segment_keys, - co_locate_local_families=decision.co_locate_local_families, - planner_config=planner_config, - ) - - -def _beam_search_cp_segment_schedule_decision( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_layout_index: _AttentionLayoutIndex, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - legal_chain_segments: tuple[GdnSegmentSpec, ...], - planner_config: GdnPlannerConfig, -) -> _GdnCpSegmentSearchDecision: - legal_chain_keys = frozenset( - _segment_key(segment) for segment in legal_chain_segments - ) - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]] = {} - chain_cross_rank_tokens_by_key: dict[GdnSegmentDecisionKey, int] = {} - for segment in legal_chain_segments: - key = _segment_key(segment) - ( - chain_rank_counts_by_key[key], - chain_cross_rank_tokens_by_key[key], - ) = _chain_segment_rank_counts_and_cross_rank_tokens( - segment, - spec, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - ) - - score_cache: dict[ - frozenset[GdnSegmentDecisionKey], _GdnCpSegmentSearchDecision - ] = {} - - def decision_for( - chain_segment_keys: frozenset[GdnSegmentDecisionKey], - ) -> _GdnCpSegmentSearchDecision: - cached = score_cache.get(chain_segment_keys) - if cached is not None: - return cached - non_colocated_score = _score_cp_segment_decisions( - spec, - cp_size=cp_size, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - chain_cross_rank_tokens_by_key=chain_cross_rank_tokens_by_key, - chain_segment_keys=chain_segment_keys, - co_locate_local_families=False, - planner_config=planner_config, - ) - colocated_score = _score_cp_segment_decisions( - spec, - cp_size=cp_size, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - chain_cross_rank_tokens_by_key=chain_cross_rank_tokens_by_key, - chain_segment_keys=chain_segment_keys, - co_locate_local_families=True, - planner_config=planner_config, - ) - co_locate = colocated_score < non_colocated_score - decision = _GdnCpSegmentSearchDecision.model_construct( - chain_segment_keys=chain_segment_keys, - co_locate_local_families=co_locate, - score=colocated_score if co_locate else non_colocated_score, - ) - score_cache[chain_segment_keys] = decision - return decision - - best = decision_for(frozenset()) - beam_by_keys = {best.chain_segment_keys: best} - if legal_chain_keys: - all_chain = decision_for(legal_chain_keys) - beam_by_keys[all_chain.chain_segment_keys] = all_chain - if best.score - all_chain.score > planner_config.cp_chain_min_score_delta_ms: - best = all_chain - candidate_groups = _bounded_chain_candidate_groups( - spec, - legal_chain_segments, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - planner_config=planner_config, - ) - beam = _best_cp_segment_search_decisions( - beam_by_keys.values(), - limit=planner_config.cp_chain_beam_width, - ) - stale_steps = 0 - for _ in range(planner_config.cp_chain_beam_max_steps): - if not candidate_groups: - break - expanded: dict[ - frozenset[GdnSegmentDecisionKey], _GdnCpSegmentSearchDecision - ] = {} - for decision in beam: - neighbors = [] - for segment_keys in _chain_beam_neighbor_groups( - decision.chain_segment_keys, - candidate_groups=candidate_groups, - branch_factor=planner_config.cp_chain_beam_branch_factor, - ): - if segment_keys.issubset(decision.chain_segment_keys): - next_keys = decision.chain_segment_keys - segment_keys - else: - next_keys = decision.chain_segment_keys | segment_keys - neighbors.append(decision_for(frozenset(next_keys))) - for neighbor in _best_cp_segment_search_decisions( - neighbors, - limit=planner_config.cp_chain_beam_branch_factor, - ): - expanded[neighbor.chain_segment_keys] = neighbor - if not expanded: - break - beam = _best_cp_segment_search_decisions( - (*beam, *expanded.values()), - limit=planner_config.cp_chain_beam_width, - ) - step_best = beam[0] - if best.score - step_best.score > planner_config.cp_chain_min_score_delta_ms: - best = step_best - stale_steps = 0 - else: - stale_steps += 1 - if stale_steps >= 2: - break - return best - - -def _chain_beam_neighbor_groups( - chain_segment_keys: frozenset[GdnSegmentDecisionKey], - *, - candidate_groups: tuple[frozenset[GdnSegmentDecisionKey], ...], - branch_factor: int, -) -> tuple[frozenset[GdnSegmentDecisionKey], ...]: - selected: list[frozenset[GdnSegmentDecisionKey]] = [] - for group in candidate_groups: - if group and not group.issubset(chain_segment_keys): - selected.append(group) - if len(selected) >= branch_factor: - return tuple(selected) - for group in reversed(candidate_groups): - if group and group.intersection(chain_segment_keys) and group not in selected: - selected.append(group) - if len(selected) >= branch_factor: - break - return tuple(selected) - - -def _best_cp_segment_search_decisions( - decisions: Any, - *, - limit: int, -) -> tuple[_GdnCpSegmentSearchDecision, ...]: - return tuple( - sorted( - decisions, - key=lambda decision: ( - decision.score, - len(decision.chain_segment_keys), - tuple(sorted(decision.chain_segment_keys)), - ), - )[:limit] - ) - - -def _bounded_chain_candidate_groups( - spec: GdnPackedExecutionSpec, - legal_chain_segments: tuple[GdnSegmentSpec, ...], - *, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - planner_config: GdnPlannerConfig, -) -> tuple[frozenset[GdnSegmentDecisionKey], ...]: - legal_key_set = frozenset(_segment_key(segment) for segment in legal_chain_segments) - if not legal_key_set: - return () - prefix_keys = frozenset( - _segment_key(family.prefix) - for family in spec.families - if _segment_key(family.prefix) in legal_key_set - ) - completion_keys = legal_key_set - prefix_keys - groups: list[frozenset[GdnSegmentDecisionKey]] = [] - for group in (legal_key_set, prefix_keys, completion_keys): - if group and group not in groups: - groups.append(group) - for group in _ranked_chain_beam_groups( - spec, - legal_chain_segments, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - planner_config=planner_config, - ): - if group and group not in groups: - groups.append(group) - return tuple(groups[: planner_config.cp_chain_beam_candidate_limit]) - - -def _ranked_chain_beam_groups( - spec: GdnPackedExecutionSpec, - legal_chain_segments: tuple[GdnSegmentSpec, ...], - *, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - planner_config: GdnPlannerConfig, -) -> tuple[frozenset[GdnSegmentDecisionKey], ...]: - if not legal_chain_segments: - return () - priority_by_key = { - _segment_key(segment): _chain_beam_segment_priority( - segment, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - ) - for segment in legal_chain_segments - } - legal_key_set = frozenset(priority_by_key) - groups: set[frozenset[GdnSegmentDecisionKey]] = { - frozenset((key,)) for key in legal_key_set - } - for family in spec.families: - completion_keys = frozenset( - _segment_key(completion) - for completion in family.completions - if _segment_key(completion) in legal_key_set - ) - if len(completion_keys) > 1: - groups.add(completion_keys) - family_keys = completion_keys - prefix_key = _segment_key(family.prefix) - if prefix_key in legal_key_set: - family_keys = family_keys | frozenset((prefix_key,)) - if len(family_keys) > 1: - groups.add(family_keys) - ranked = tuple( - sorted( - groups, - key=lambda group: _chain_beam_group_priority( - group, priority_by_key=priority_by_key - ), - reverse=True, - ) - ) - limit = planner_config.cp_chain_beam_candidate_limit - if len(ranked) <= limit: - return ranked - high_count = (limit + 1) // 2 - low_count = limit - high_count - selected = [*ranked[:high_count]] - for group in ranked[-low_count:]: - if group not in selected: - selected.append(group) - return tuple(selected) - - -def _chain_beam_group_priority( - group: frozenset[GdnSegmentDecisionKey], - *, - priority_by_key: dict[GdnSegmentDecisionKey, tuple[int, int, int, int]], -) -> tuple[int, int, int, int, int]: - priorities = tuple(priority_by_key[key] for key in group) - return ( - sum(priority[0] for priority in priorities), - sum(priority[1] for priority in priorities), - max((priority[2] for priority in priorities), default=0), - sum(priority[3] for priority in priorities), - len(group), - ) - - -def _chain_beam_segment_priority( - segment: GdnSegmentSpec, - *, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], -) -> tuple[int, int, int, int]: - key = _segment_key(segment) - chain_max_load = max(chain_rank_counts_by_key[key], default=0) - best_attention_locality = max(segment_attention_counts[key], default=0) - chain_load_relief = segment.length - chain_max_load - minimum_local_exchange = segment.length - best_attention_locality - return ( - chain_load_relief, - segment.length, - best_attention_locality, - -minimum_local_exchange, - ) - - -def _score_cp_segment_decisions( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - chain_cross_rank_tokens_by_key: dict[GdnSegmentDecisionKey, int], - chain_segment_keys: frozenset[GdnSegmentDecisionKey], - co_locate_local_families: bool, - planner_config: GdnPlannerConfig, -) -> float: - rank_loads = [0] * cp_size - local_prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - local_completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - chain_prefix_segments: list[GdnSegmentSpec] = [] - chain_completion_segments: list[GdnSegmentSpec] = [] - parent_state_exchange_families: set[int] = set() - cross_rank_token_count = 0 - - for family in spec.families: - prefix_key = _segment_key(family.prefix) - chain_prefix = prefix_key in chain_segment_keys - local_completions = tuple( - completion - for completion in family.completions - if _segment_key(completion) not in chain_segment_keys - ) - prefix_owner: int | None = None - if chain_prefix: - chain_prefix_segments.append(family.prefix) - cross_rank_token_count += _add_chain_search_load( - rank_loads, - family.prefix, - chain_rank_counts_by_key=chain_rank_counts_by_key, - chain_cross_rank_tokens_by_key=chain_cross_rank_tokens_by_key, - ) - else: - owner_segments = ( - (family.prefix, *local_completions) - if co_locate_local_families - else (family.prefix,) - ) - prefix_owner = _best_segment_owner( - owner_segments, - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - local_prefix_segments_by_rank[prefix_owner].append(family.prefix) - cross_rank_token_count += _add_local_search_load( - rank_loads, - prefix_owner, - family.prefix, - segment_attention_counts=segment_attention_counts, - ) - for completion in family.completions: - completion_key = _segment_key(completion) - if completion_key in chain_segment_keys: - chain_completion_segments.append(completion) - cross_rank_token_count += _add_chain_search_load( - rank_loads, - completion, - chain_rank_counts_by_key=chain_rank_counts_by_key, - chain_cross_rank_tokens_by_key=chain_cross_rank_tokens_by_key, - ) - if not chain_prefix: - parent_state_exchange_families.add(family.family_index) - continue - if co_locate_local_families and not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "co-located local completion planning lost the prefix owner" - ) - owner = prefix_owner - else: - owner = _best_segment_owner( - (completion,), - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - if not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "local completion planning lost the prefix owner" - ) - if owner != prefix_owner: - parent_state_exchange_families.add(family.family_index) - local_completion_segments_by_rank[owner].append(completion) - cross_rank_token_count += _add_local_search_load( - rank_loads, - owner, - completion, - segment_attention_counts=segment_attention_counts, - ) - ( - local_work_by_rank, - local_bucket_count, - local_segment_count, - ) = _estimate_local_rank_kernel_work( - tuple(tuple(segments) for segments in local_prefix_segments_by_rank), - tuple(tuple(segments) for segments in local_completion_segments_by_rank), - planner_config=planner_config, - ) - chain_work_by_rank, chain_bucket_count = _estimate_chain_rank_kernel_work( - cp_size=cp_size, - chain_prefix_segments=tuple(chain_prefix_segments), - chain_completion_segments=tuple(chain_completion_segments), - chain_rank_counts_by_key=chain_rank_counts_by_key, - planner_config=planner_config, - ) - return _score_cp_segment_stats( - rank_local_work=local_work_by_rank, - rank_chain_work=chain_work_by_rank, - rank_real_tokens=tuple(rank_loads), - cross_rank_token_count=cross_rank_token_count, - parent_state_exchange_family_count=len(parent_state_exchange_families), - local_bucket_count=local_bucket_count, - local_segment_count=local_segment_count, - chain_bucket_count=chain_bucket_count, - planner_config=planner_config, - ) - - -def _estimate_local_rank_kernel_work( - local_prefix_segments_by_rank: tuple[tuple[GdnSegmentSpec, ...], ...], - local_completion_segments_by_rank: tuple[tuple[GdnSegmentSpec, ...], ...], - *, - planner_config: GdnPlannerConfig, -) -> tuple[tuple[int, ...], int, int]: - rank_work: list[int] = [] - rank_bucket_counts: list[int] = [] - rank_segment_counts: list[int] = [] - for prefix_segments, completion_segments in zip( - local_prefix_segments_by_rank, - local_completion_segments_by_rank, - strict=True, - ): - prefix_family_indices = {segment.family_index for segment in prefix_segments} - chunk_local_completion_segments = tuple( - segment - for segment in completion_segments - if segment.family_index in prefix_family_indices - ) - plain_local_completion_segments = tuple( - segment - for segment in completion_segments - if segment.family_index not in prefix_family_indices - ) - chunk_work, chunk_bucket_count = _estimate_chunk_aligned_local_work( - prefix_segments, - chunk_local_completion_segments, - planner_config=planner_config, - ) - completion_work, completion_bucket_count = _padded_work_from_lengths( - tuple(segment.length for segment in plain_local_completion_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - rank_work.append(chunk_work + completion_work) - rank_bucket_counts.append(chunk_bucket_count + completion_bucket_count) - rank_segment_counts.append(len(prefix_segments) + len(completion_segments)) - return ( - tuple(rank_work), - max(rank_bucket_counts, default=0), - max(rank_segment_counts, default=0), - ) - - -def _estimate_chunk_aligned_local_work( - prefix_segments: tuple[GdnSegmentSpec, ...], - completion_segments: tuple[GdnSegmentSpec, ...], - *, - planner_config: GdnPlannerConfig, -) -> tuple[int, int]: - completions_by_family: dict[int, list[GdnSegmentSpec]] = {} - for completion in completion_segments: - completions_by_family.setdefault(completion.family_index, []).append(completion) - boundary_lengths: list[int] = [] - tail_lengths: list[int] = [] - completion_column_lengths: list[int] = [] - for prefix in prefix_segments: - boundary_end = _prefix_chunk_boundary_end(prefix) - boundary_length = boundary_end - prefix.start - if boundary_length > 0: - boundary_lengths.append(boundary_length) - tail_length = prefix.end - boundary_end - family_completions = tuple(completions_by_family.get(prefix.family_index, ())) - if tail_length > 0 and not family_completions: - tail_lengths.append(tail_length) - for completion in family_completions: - completion_column_lengths.append(tail_length + completion.length) - boundary_work, boundary_bucket_count = _padded_work_from_lengths( - tuple(boundary_lengths), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - tail_work, tail_bucket_count = _padded_work_from_lengths( - tuple(tail_lengths), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - completion_work, completion_bucket_count = _padded_work_from_lengths( - tuple(completion_column_lengths), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - return ( - boundary_work + tail_work + completion_work, - boundary_bucket_count + tail_bucket_count + completion_bucket_count, - ) - - -def _estimate_chain_rank_kernel_work( - *, - cp_size: int, - chain_prefix_segments: tuple[GdnSegmentSpec, ...], - chain_completion_segments: tuple[GdnSegmentSpec, ...], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - planner_config: GdnPlannerConfig, -) -> tuple[tuple[int, ...], int]: - rank_work = [0] * cp_size - bucket_count = 0 - for segments in (chain_prefix_segments, chain_completion_segments): - buckets = _batch_segments_by_padded_work( - segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - bucket_count += len(buckets) - for bucket in buckets: - for rank in range(cp_size): - lengths = tuple( - chain_rank_counts_by_key[_segment_key(segment)][rank] - for segment in bucket - ) - rank_work[rank] += max(lengths, default=0) * len(lengths) - return tuple(rank_work), bucket_count - - -def _padded_work_from_lengths( - lengths: tuple[int, ...], - *, - max_padding_ratio: float, - max_segments_per_batch: int, -) -> tuple[int, int]: - if not lengths: - return 0, 0 - ordered = sorted(length for length in lengths if length > 0) - if not ordered: - return 0, 0 - bucket_count = 0 - padded_work = 0 - current_count = 0 - current_tokens = 0 - current_max = 0 - for length in ordered: - next_count = current_count + 1 - next_tokens = current_tokens + length - next_max = max(current_max, length) - next_padded = next_max * next_count - can_extend = current_count == 0 or ( - next_count <= max_segments_per_batch - and next_padded <= max_padding_ratio * next_tokens - ) - if not can_extend: - bucket_count += 1 - padded_work += current_max * current_count - current_count = 0 - current_tokens = 0 - current_max = 0 - current_count += 1 - current_tokens += length - current_max = max(current_max, length) - if current_count: - bucket_count += 1 - padded_work += current_max * current_count - return padded_work, bucket_count - - -def _add_chain_search_load( - rank_loads: list[int], - segment: GdnSegmentSpec, - *, - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - chain_cross_rank_tokens_by_key: dict[GdnSegmentDecisionKey, int], -) -> int: - key = _segment_key(segment) - for rank, token_count in enumerate(chain_rank_counts_by_key[key]): - rank_loads[rank] += token_count - return chain_cross_rank_tokens_by_key[key] - - -def _add_local_search_load( - rank_loads: list[int], - rank: int, - segment: GdnSegmentSpec, - *, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], -) -> int: - rank_loads[rank] += segment.length - return segment.length - segment_attention_counts[_segment_key(segment)][rank] - - -def _chain_segment_rank_counts_and_cross_rank_tokens( - segment: GdnSegmentSpec, - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_layout_index: _AttentionLayoutIndex, -) -> tuple[tuple[int, ...], int]: - token_start = _segment_token_start(segment, spec.sequence_length) - attention_shards = _attention_contiguous_chain_shards( - token_start, - segment.length, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - ) - if attention_shards is not None: - return tuple(len(shard) for shard in attention_shards), 0 - shard_lengths = _fla_aligned_chain_shard_lengths(segment.length, cp_size=cp_size) - cross_rank_tokens = 0 - start = 0 - for rank, shard_length in enumerate(shard_lengths): - end = start + shard_length - shard_start = token_start + start - cross_rank_tokens += shard_length - _attention_overlap_count( - attention_layout_index, - rank, - shard_start, - shard_start + shard_length, - ) - start = end - return shard_lengths, cross_rank_tokens - - -def _materialize_cp_segment_schedule( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_layout_index: _AttentionLayoutIndex, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_segment_keys: frozenset[GdnSegmentDecisionKey], - co_locate_local_families: bool, - planner_config: GdnPlannerConfig, -) -> GdnCpSegmentSchedule: - gdn_ranges_by_rank: list[list[tuple[int, int, int]]] = [[] for _ in range(cp_size)] - rank_loads = [0] * cp_size - local_prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - local_completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - chain_prefix_segments: list[GdnSegmentSpec] = [] - chain_completion_segments: list[GdnSegmentSpec] = [] - parent_state_exchange_families: set[int] = set() - parent_state_transfer_families: dict[tuple[int, int], set[int]] = {} - cross_rank_token_count = 0 - - for family in spec.families: - prefix_key = _segment_key(family.prefix) - chain_prefix = prefix_key in chain_segment_keys - local_completions = tuple( - completion - for completion in family.completions - if _segment_key(completion) not in chain_segment_keys - ) - prefix_owner: int | None = None - if chain_prefix: - chain_prefix_segments.append(family.prefix) - cross_rank_token_count += _append_chain_segment( - gdn_ranges_by_rank, - rank_loads, - family.prefix, - spec, - attention_layout_index=attention_layout_index, - ) - else: - owner_segments = ( - (family.prefix, *local_completions) - if co_locate_local_families - else (family.prefix,) - ) - prefix_owner = _best_segment_owner( - owner_segments, - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - local_prefix_segments_by_rank[prefix_owner].append(family.prefix) - cross_rank_token_count += _append_local_segment( - gdn_ranges_by_rank, - rank_loads, - prefix_owner, - family.prefix, - spec, - segment_attention_counts=segment_attention_counts, - ) - for completion in family.completions: - if _segment_key(completion) in chain_segment_keys: - chain_completion_segments.append(completion) - cross_rank_token_count += _append_chain_segment( - gdn_ranges_by_rank, - rank_loads, - completion, - spec, - attention_layout_index=attention_layout_index, - ) - if not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "local-prefix/chained-completion planning lost the prefix owner" - ) - parent_state_exchange_families.add(family.family_index) - for dest_rank in range(cp_size): - if dest_rank == prefix_owner: - continue - parent_state_transfer_families.setdefault( - (prefix_owner, dest_rank), set() - ).add(family.family_index) - continue - if co_locate_local_families and not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "co-located local completion planning lost the prefix owner" - ) - owner = prefix_owner - else: - owner = _best_segment_owner( - (completion,), - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - if not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "local completion planning lost the prefix owner" - ) - if owner != prefix_owner: - parent_state_exchange_families.add(family.family_index) - parent_state_transfer_families.setdefault( - (prefix_owner, owner), set() - ).add(family.family_index) - local_completion_segments_by_rank[owner].append(completion) - cross_rank_token_count += _append_local_segment( - gdn_ranges_by_rank, + for segment in segments: + owner_by_node[segment.family_index] = owner + segments_by_rank_depth[owner][ + spec.tree_depths[segment.family_index] + ].append(segment) + cross_rank_token_count += _append_local_segment( + gdn_ranges_by_rank, rank_loads, owner, - completion, - spec, - segment_attention_counts=segment_attention_counts, - ) - - return GdnCpSegmentSchedule.model_construct( - gdn_token_counts_by_rank=tuple(rank_loads), - gdn_token_ranges_by_rank=tuple(tuple(ranges) for ranges in gdn_ranges_by_rank), - cross_rank_token_count=cross_rank_token_count, - chain_prefix_buckets=_batch_segments_by_padded_work( - tuple(chain_prefix_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ), - chain_completion_buckets=_batch_segments_by_padded_work( - tuple(chain_completion_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ), - local_prefix_segments_by_rank=tuple( - tuple(segments) for segments in local_prefix_segments_by_rank - ), - local_completion_segments_by_rank=tuple( - tuple(segments) for segments in local_completion_segments_by_rank - ), - parent_state_exchange_family_indices=tuple( - sorted(parent_state_exchange_families) - ), - parent_state_transfers=_build_parent_state_transfer_plans( - parent_state_transfer_families - ), - ) - - -def _build_local_family_rank_execution_plan( - spec: GdnPackedExecutionSpec, - *, - device: torch.device | str, - cp_rank: int, - cp_size: int, - planner_config: GdnPlannerConfig, -) -> GdnRankExecutionPlan | None: - if cp_size <= 1 or not spec.families: - return None - target_rank_load = spec.real_token_count / cp_size - loads = [0] * cp_size - prefix_owner_by_family: list[int] = [] - completion_owners_by_family: list[tuple[int, ...]] = [] - for family in spec.families: - if _has_chainable_segment( - family, cp_size=cp_size, planner_config=planner_config - ): - return None - prefix_locality_limit = max( - planner_config.max_zero_exchange_load_imbalance * target_rank_load, - min(64.0, float(spec.real_token_count)), - ) - if family.prefix.length > prefix_locality_limit: - return None - owner = _least_loaded_rank(loads) - prefix_owner_by_family.append(owner) - completion_owners_by_family.append(tuple(owner for _ in family.completions)) - loads[owner] += family.token_count - - if max(loads, default=0) > ( - planner_config.local_completion_rebalance_min_imbalance * target_rank_load - ): - completion_owners_by_family = list( - _rebalance_local_completion_segments( + segment, spec, - prefix_owner_by_family=tuple(prefix_owner_by_family), - completion_owners_by_family=tuple(completion_owners_by_family), - initial_loads=tuple(loads), + segment_attention_counts=segment_attention_counts, + ) + + subtree_token_counts = [segment.length for segment in spec.tree_segments] + for node_index in reversed(range(len(spec.tree_segments))): + for child_index in children_by_node[node_index]: + subtree_token_counts[node_index] += subtree_token_counts[child_index] + target_rank_load = spec.real_token_count / max(1, cp_size) + max_local_group_tokens = max(1, int(target_rank_load)) + + def assign_tree(root_index: int) -> None: + nonlocal cross_rank_token_count + root = spec.tree_segments[root_index] + if ( + spec.tree_parent_indices[root_index] < 0 + and cp_size > 1 + and _can_chain_tree_segment( + root, + cp_size=cp_size, planner_config=planner_config, ) - ) - rank_assignments = _materialize_local_family_rank_assignments( - spec, - cp_size=cp_size, - prefix_owner_by_family=tuple(prefix_owner_by_family), - completion_owners_by_family=tuple(completion_owners_by_family), - ) - local_token_count, local_token_ranges, prefix_segments, completion_segments = ( - rank_assignments[cp_rank] - ) - parent_state_transfer_families: dict[tuple[int, int], set[int]] = {} - for family in spec.families: - prefix_owner = prefix_owner_by_family[family.family_index] - completion_owners = completion_owners_by_family[family.family_index] - for completion_owner in sorted(set(completion_owners)): - if completion_owner == prefix_owner: - continue - parent_state_transfer_families.setdefault( - (prefix_owner, completion_owner), set() - ).add(family.family_index) - - from art.megatron.gdn.layout import GdnCpExchangePlan, GdnCpPeerTransfer - - token_counts_by_rank = tuple(assignment[0] for assignment in rank_assignments) - identity_exchange = GdnCpExchangePlan.model_construct( - cp_size=cp_size, - source_token_counts_by_rank=token_counts_by_rank, - dest_token_counts_by_rank=token_counts_by_rank, - transfers=tuple( - GdnCpPeerTransfer.model_construct( - source_rank=rank, - dest_rank=rank, - token_count=token_count, - source_positions_tensor=None, - dest_positions_tensor=None, + ): + chained_nodes[root.family_index] = True + chain_segments_by_depth[spec.tree_depths[root.family_index]].append(root) + cross_rank_token_count += _append_chain_segment( + gdn_ranges_by_rank, + rank_loads, + root, + spec, + attention_layout_index=attention_layout_index, ) - for rank, token_count in enumerate(token_counts_by_rank) - if token_count - ), + for child_index in children_by_node[root_index]: + assign_tree(child_index) + return + + if subtree_token_counts[root_index] <= max_local_group_tokens: + assign_local_group(subtree_indices(root_index)) + return + + assign_local_group((root_index,)) + for child_index in children_by_node[root_index]: + assign_tree(child_index) + + for root_index in root_indices: + assign_tree(root_index) + + gdn_ranges_by_rank_by_position = tuple( + tuple(ranges) for ranges in gdn_ranges_by_rank ) - parent_state_exchange_family_indices = tuple( - sorted( - family_index - for family_indices in parent_state_transfer_families.values() - for family_index in family_indices - ) + gdn_ranges_by_rank_by_source = tuple( + tuple(sorted(ranges)) for ranges in gdn_ranges_by_rank ) - schedule = GdnCpSegmentSchedule.model_construct( - gdn_token_counts_by_rank=token_counts_by_rank, - gdn_token_ranges_by_rank=tuple( - assignment[1] for assignment in rank_assignments - ), - cross_rank_token_count=0, - chain_prefix_buckets=(), - chain_completion_buckets=(), - local_prefix_segments_by_rank=tuple( - assignment[2] for assignment in rank_assignments - ), - local_completion_segments_by_rank=tuple( - assignment[3] for assignment in rank_assignments - ), - parent_state_exchange_family_indices=parent_state_exchange_family_indices, - parent_state_transfers=_build_parent_state_transfer_plans( - parent_state_transfer_families - ), + + attention_to_gdn = build_local_rank_cp_exchange_plan_from_dest_ranges( + source_layout=source_layout, + device=device, + local_rank=cp_rank, + dest_ranges_by_rank=gdn_ranges_by_rank_by_position, + cross_rank_token_count=cross_rank_token_count, ) - if parent_state_exchange_family_indices: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _build_remote_prefix_tail_plans( - spec, - schedule, - cp_rank=cp_rank, + local_token_ranges = gdn_ranges_by_rank_by_source[cp_rank] + tree_segment_buckets_by_depth = tuple( + _build_tree_bucket_plans( + tuple(segments_by_rank_depth[cp_rank][depth]), + spec.tree_parent_indices, + tuple(tree_has_children), + local_token_ranges=None if cp_size == 1 else local_token_ranges, + sequence_length=spec.sequence_length, device=device, planner_config=planner_config, ) - else: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _empty_remote_prefix_tail_plans() - local_prefix_family_indices = {segment.family_index for segment in prefix_segments} - chunk_local_completion_segments = tuple( - segment - for segment in completion_segments - if segment.family_index in local_prefix_family_indices - ) - suffix_only_completion_segments = tuple( - segment - for segment in completion_segments - if segment.family_index not in local_prefix_family_indices - and segment.family_index not in remote_prefix_tail_families + for depth in range(depth_count) ) - ready_completion_segments, remote_completion_segments = ( - _split_ready_and_remote_completion_segments( - suffix_only_completion_segments, - local_prefix_segments=(), - chain_prefix_buckets=(), + tree_chain_buckets_by_depth = ( + tuple( + _build_tree_bucket_plans( + tuple(chain_segments_by_depth[depth]), + spec.tree_parent_indices, + tuple(tree_has_children), + local_token_ranges=local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + planner_config=planner_config, + token_ranges_by_rank=tuple( + tuple(ranges) for ranges in gdn_ranges_by_rank_by_source + ), + split_by_final_state=False, + ) + for depth in range(depth_count) ) + if cp_size > 1 + else tuple(() for _ in range(depth_count)) ) - ready_completion_buckets = _batch_segments_by_padded_work( - ready_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - remote_completion_buckets = _batch_segments_by_padded_work( - remote_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - ready_completion_bucket_plans = _build_position_bucket_plans( - ready_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ) - remote_completion_bucket_plans = _build_position_bucket_plans( - remote_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ) - local_completion_bucket_plans = ( - ready_completion_bucket_plans + remote_completion_bucket_plans - ) - ( - prefix_boundary_buckets, - prefix_tail_buckets, - completion_with_prefix_tail_buckets, - ) = _build_chunk_aligned_position_bucket_plans( - prefix_segments, - chunk_local_completion_segments, - local_token_ranges, - sequence_length=spec.sequence_length, + tree_state_exchanges_by_depth = _build_tree_state_exchanges_by_depth( + spec, + owner_by_node=tuple(owner_by_node), + chained_nodes=tuple(chained_nodes), + cp_rank=cp_rank, + cp_size=cp_size, + depth_count=depth_count, device=device, - planner_config=planner_config, ) - return GdnRankExecutionPlan.model_construct( + if cp_size == 1: + valid_lengths = torch.tensor( + spec.valid_lengths, device=device, dtype=torch.long + ) + positions = torch.arange(spec.sequence_length, device=device, dtype=torch.long) + real_token_mask = positions.unsqueeze(0) < valid_lengths.unsqueeze(1) + else: + real_token_mask = torch.ones( + 1, + rank_loads[cp_rank], + device=device, + dtype=torch.bool, + ) + + return GdnRankExecutionPlan( cp_rank=cp_rank, cp_size=cp_size, - batch_size=1, - sequence_length=local_token_count, + batch_size=1 if cp_size > 1 else spec.batch_size, + sequence_length=rank_loads[cp_rank] if cp_size > 1 else spec.sequence_length, packed_batch_size=spec.batch_size, packed_sequence_length=spec.sequence_length, - real_token_mask=torch.ones( - 1, local_token_count, device=device, dtype=torch.bool - ), - family_count=spec.family_count, - completion_count=spec.completion_count, - local_prefix_buckets=(), - local_completion_buckets=local_completion_bucket_plans, - ready_local_completion_buckets=ready_completion_bucket_plans, - remote_local_completion_buckets=remote_completion_bucket_plans, - chain_prefix_buckets=(), - chain_completion_buckets=(), - prefix_table_is_dense_ordered=( - tuple(segment.family_index for segment in prefix_segments) - == tuple(range(spec.family_count)) - ), - attention_to_gdn=identity_exchange, - gdn_to_attention=identity_exchange, - attention_token_ranges=local_token_ranges, - gdn_token_ranges=local_token_ranges, - attention_token_count=local_token_count, - gdn_token_count=local_token_count, - parent_state_exchange_family_indices=tuple( - family_index - for family_index in parent_state_exchange_family_indices - if family_index not in remote_prefix_tail_families - ), - parent_state_transfers=_filter_parent_state_transfers( - _build_parent_state_transfer_plans(parent_state_transfer_families), - excluded_families=remote_prefix_tail_families, - device=device, - ), - prefix_boundary_buckets=prefix_boundary_buckets, - prefix_tail_buckets=prefix_tail_buckets, - completion_with_prefix_tail_buckets=completion_with_prefix_tail_buckets, - remote_prefix_tail_buckets=remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets=remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange=remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange=remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers=remote_prefix_tail_state_transfers, + real_token_mask=real_token_mask, + attention_to_gdn=attention_to_gdn, + gdn_to_attention=_reverse_exchange_plan(attention_to_gdn), + attention_token_ranges=source_layout.ownership_ranges_by_rank[cp_rank], + gdn_token_ranges=gdn_ranges_by_rank_by_position[cp_rank], + attention_token_count=source_layout.token_counts_by_rank[cp_rank], + gdn_token_count=rank_loads[cp_rank], + tree_segment_buckets_by_depth=tree_segment_buckets_by_depth, + tree_chain_buckets_by_depth=tree_chain_buckets_by_depth, + tree_state_exchanges_by_depth=tree_state_exchanges_by_depth, ) -def _rebalance_local_completion_segments( - spec: GdnPackedExecutionSpec, - *, - prefix_owner_by_family: tuple[int, ...], - completion_owners_by_family: tuple[tuple[int, ...], ...], - initial_loads: tuple[int, ...], - planner_config: GdnPlannerConfig, -) -> tuple[tuple[int, ...], ...]: - owners = [list(family_owners) for family_owners in completion_owners_by_family] - loads = list(initial_loads) - remote_owners_by_family = [ - { - owner - for owner in family_owners - if owner != prefix_owner_by_family[family_index] - } - for family_index, family_owners in enumerate(owners) - ] - transfer_count = sum( - len(remote_owners) for remote_owners in remote_owners_by_family - ) +def move_gdn_rank_execution_plan_to_device( + plan: GdnRankExecutionPlan, + device: torch.device | str, +) -> GdnRankExecutionPlan: + """Move planner tensors to the execution device after CPU planning.""" - def score(candidate_loads: list[int], candidate_transfer_count: int) -> float: - max_load = max(candidate_loads, default=0) - idle_tokens = sum(max_load - load for load in candidate_loads) - return ( - max_load - + planner_config.rank_idle_token_cost * idle_tokens - + planner_config.parent_state_exchange_penalty_tokens - * candidate_transfer_count - ) + from art.megatron.gdn.layout import move_cp_exchange_plan_to_device - best_score = score(loads, transfer_count) - while True: - best_move: ( - tuple[int, int, int, tuple[int, ...], list[int], int, float] | None - ) = None - for family in spec.families: - family_owners = owners[family.family_index] - prefix_owner = prefix_owner_by_family[family.family_index] - original_remote_owners = remote_owners_by_family[family.family_index] - for source in sorted(set(family_owners)): - source_children = [ - child_index - for child_index, owner in enumerate(family_owners) - if owner == source - ] - ordered_children = sorted( - source_children, - key=lambda child_index: family.completions[child_index].length, - reverse=True, - ) - for dest in range(len(loads)): - if dest == source: - continue - moved_tokens = 0 - moved_children = [] - for child_index in ordered_children: - moved_tokens += family.completions[child_index].length - moved_children.append(child_index) - candidate_loads = list(loads) - candidate_loads[source] -= moved_tokens - candidate_loads[dest] += moved_tokens - candidate_remote_owners = set(original_remote_owners) - if source != prefix_owner and len(moved_children) == len( - source_children - ): - candidate_remote_owners.discard(source) - if dest != prefix_owner: - candidate_remote_owners.add(dest) - candidate_transfer_count = ( - transfer_count - - len(original_remote_owners) - + len(candidate_remote_owners) - ) - candidate_score = score( - candidate_loads, candidate_transfer_count - ) - if candidate_score >= best_score: - continue - if best_move is None or candidate_score < best_move[-1]: - best_move = ( - family.family_index, - source, - dest, - tuple(moved_children), - candidate_loads, - candidate_transfer_count, - candidate_score, - ) - if best_move is None: - return tuple(tuple(item) for item in owners) - ( - family_index, - _source, - dest, - moved_children, - loads, - transfer_count, - best_score, - ) = best_move - for child_index in moved_children: - owners[family_index][child_index] = dest - prefix_owner = prefix_owner_by_family[family_index] - remote_owners_by_family[family_index] = { - owner for owner in set(owners[family_index]) if owner != prefix_owner - } - - -def _materialize_local_family_rank_assignments( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - prefix_owner_by_family: tuple[int, ...], - completion_owners_by_family: tuple[tuple[int, ...], ...], -) -> tuple[ - tuple[ - int, - tuple[tuple[int, int, int], ...], - tuple[GdnSegmentSpec, ...], - tuple[GdnSegmentSpec, ...], - ], - ..., -]: - token_ranges_by_rank: list[list[tuple[int, int, int]]] = [ - [] for _ in range(cp_size) - ] - token_counts_by_rank = [0] * cp_size - prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [[] for _ in range(cp_size)] - completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - sequence_length = spec.sequence_length - for family in spec.families: - prefix_owner = prefix_owner_by_family[family.family_index] - prefix_segments_by_rank[prefix_owner].append(family.prefix) - prefix_token_start = ( - family.prefix.row_index * sequence_length + family.prefix.start - ) - prefix_position_start = token_counts_by_rank[prefix_owner] - token_ranges_by_rank[prefix_owner].append( - ( - prefix_token_start, - prefix_token_start + family.prefix.length, - prefix_position_start, - ) - ) - token_counts_by_rank[prefix_owner] = ( - prefix_position_start + family.prefix.length - ) - for completion, completion_owner in zip( - family.completions, - completion_owners_by_family[family.family_index], - strict=True, - ): - completion_segments_by_rank[completion_owner].append(completion) - completion_token_start = ( - completion.row_index * sequence_length + completion.start - ) - completion_position_start = token_counts_by_rank[completion_owner] - token_ranges_by_rank[completion_owner].append( - ( - completion_token_start, - completion_token_start + completion.length, - completion_position_start, - ) - ) - token_counts_by_rank[completion_owner] = ( - completion_position_start + completion.length - ) - return tuple( - ( - token_counts_by_rank[rank], - tuple(token_ranges_by_rank[rank]), - tuple(prefix_segments_by_rank[rank]), - tuple(completion_segments_by_rank[rank]), - ) - for rank in range(cp_size) + return replace( + plan, + real_token_mask=_move_planner_tensor(plan.real_token_mask, device), + attention_to_gdn=move_cp_exchange_plan_to_device(plan.attention_to_gdn, device), + gdn_to_attention=move_cp_exchange_plan_to_device(plan.gdn_to_attention, device), + tree_segment_buckets_by_depth=tuple( + _move_bucket_plans(buckets, device) + for buckets in plan.tree_segment_buckets_by_depth + ), + tree_chain_buckets_by_depth=tuple( + _move_bucket_plans(buckets, device) + for buckets in plan.tree_chain_buckets_by_depth + ), + tree_state_exchanges_by_depth=tuple( + _move_state_exchange_plan(exchange, device) + for exchange in plan.tree_state_exchanges_by_depth + ), ) -def _empty_local_family_rank_execution_plan( - spec: GdnPackedExecutionSpec, - *, +def _move_state_exchange_plan( + exchange: GdnStateExchangePlan | None, device: torch.device | str, - cp_rank: int, - cp_size: int, -) -> GdnRankExecutionPlan: - from art.megatron.gdn.layout import GdnCpExchangePlan +) -> GdnStateExchangePlan | None: + if exchange is None: + return None + from art.megatron.gdn.layout import move_cp_exchange_plan_to_device - identity_exchange = GdnCpExchangePlan.model_construct( - cp_size=cp_size, - source_token_counts_by_rank=tuple(0 for _ in range(cp_size)), - dest_token_counts_by_rank=tuple(0 for _ in range(cp_size)), - transfers=(), - ) - return GdnRankExecutionPlan.model_construct( - cp_rank=cp_rank, - cp_size=cp_size, - batch_size=1, - sequence_length=0, - packed_batch_size=spec.batch_size, - packed_sequence_length=spec.sequence_length, - real_token_mask=torch.ones(1, 0, device=device, dtype=torch.bool), - family_count=spec.family_count, - completion_count=spec.completion_count, - local_prefix_buckets=(), - local_completion_buckets=(), - ready_local_completion_buckets=(), - remote_local_completion_buckets=(), - chain_prefix_buckets=(), - chain_completion_buckets=(), - prefix_table_is_dense_ordered=False, - attention_to_gdn=identity_exchange, - gdn_to_attention=identity_exchange, - attention_token_ranges=(), - gdn_token_ranges=(), - attention_token_count=0, - gdn_token_count=0, - parent_state_exchange_family_indices=(), - parent_state_transfers=(), + return replace( + exchange, + exchange=move_cp_exchange_plan_to_device(exchange.exchange, device), + reverse_exchange=move_cp_exchange_plan_to_device( + exchange.reverse_exchange, device + ), ) -def _can_chain_segment( - segment: GdnSegmentSpec, - *, - cp_size: int, - planner_config: GdnPlannerConfig, -) -> bool: - min_tokens = ( - planner_config.cp_chain_min_prefix_only_tokens - if segment.kind == "prefix" - else planner_config.cp_chain_min_total_tokens - ) - if segment.length < min_tokens: - return False - if segment.length < cp_size: - return False - if segment.length // FLA_CHUNK_SIZE < cp_size: - return False - per_rank = segment.length / cp_size - if per_rank < planner_config.cp_chain_min_tokens_per_rank: - return False - return True - - -def _build_parent_state_transfer_plans( - families_by_peer: dict[tuple[int, int], set[int]], -) -> tuple[GdnParentStateTransferPlan, ...]: +def _move_bucket_plans( + buckets: tuple[GdnSegmentBucketPlan, ...], + device: torch.device | str, +) -> tuple[GdnSegmentBucketPlan, ...]: return tuple( - GdnParentStateTransferPlan( - source_rank=source_rank, - dest_rank=dest_rank, - family_indices=tuple(sorted(family_indices)), + replace( + bucket, + lengths=_move_planner_tensor(bucket.lengths, device), + real_mask=_move_planner_tensor(bucket.real_mask, device), + cu_seqlens=_move_planner_tensor(bucket.cu_seqlens, device), + row_indices=_move_planner_tensor(bucket.row_indices, device), + position_indices=_move_planner_tensor(bucket.position_indices, device), + family_indices=_move_planner_tensor(bucket.family_indices, device), + parent_indices=( + _move_planner_tensor(bucket.parent_indices, device) + if bucket.parent_indices is not None + else None + ), + output_mask=( + _move_planner_tensor(bucket.output_mask, device) + if bucket.output_mask is not None + else None + ), ) - for (source_rank, dest_rank), family_indices in sorted(families_by_peer.items()) - if source_rank != dest_rank and family_indices + for bucket in buckets ) -def _split_ready_and_remote_completion_segments( - completion_segments: tuple[GdnSegmentSpec, ...], - *, - local_prefix_segments: tuple[GdnSegmentSpec, ...], - chain_prefix_buckets: tuple[tuple[GdnSegmentSpec, ...], ...], -) -> tuple[tuple[GdnSegmentSpec, ...], tuple[GdnSegmentSpec, ...]]: - ready_family_indices = { - segment.family_index for segment in local_prefix_segments - } | {segment.family_index for bucket in chain_prefix_buckets for segment in bucket} - ready = [] - remote = [] - for segment in completion_segments: - if segment.family_index in ready_family_indices: - ready.append(segment) - else: - remote.append(segment) - return tuple(ready), tuple(remote) +def parse_gdn_shared_prefix_segments( + group_ids: torch.Tensor, + parent_ids: torch.Tensor, +) -> GdnPackedExecutionSpec: + """Parse ART packed shared-prefix metadata into generic GDN tree nodes.""" + groups = _rank2_long_cpu("group_ids", group_ids) + parents = _rank2_long_cpu("parent_ids", parent_ids) + if tuple(groups.shape) != tuple(parents.shape): + raise ValueError( + "group_ids and parent_ids must have the same shape, got " + f"{tuple(groups.shape)} and {tuple(parents.shape)}" + ) -def _transfer_plans_to_device( - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - device: torch.device | str, -) -> tuple[GdnParentStateTransferPlan, ...]: - return tuple( - transfer.model_copy( - update={ - "family_indices_tensor": _move_planner_tensor( - torch.tensor(transfer.family_indices, dtype=torch.long), - device, + batch_size, sequence_length = (int(groups.shape[0]), int(groups.shape[1])) + rows = parse_shared_prefix_tree(group_ids=groups, parent_ids=parents) + tree_segments: list[GdnSegmentSpec] = [] + tree_parent_indices: list[int] = [] + tree_depths: list[int] = [] + valid_lengths: list[int] = [] + node_by_row_group: dict[tuple[int, int], int] = {} + child_counts_by_parent: dict[int, int] = {} + + for row in rows: + valid_lengths.append(row.valid_tokens) + for segment in row.segments: + node_index = len(tree_segments) + is_root = segment.depth == 0 + parent_node_index = ( + -1 if is_root else node_by_row_group[(row.row_index, segment.parent_id)] + ) + child_index = None + if not is_root: + child_index = child_counts_by_parent.get(parent_node_index, 0) + child_counts_by_parent[parent_node_index] = child_index + 1 + tree_segments.append( + GdnSegmentSpec( + row_index=row.row_index, + family_index=node_index, + group_id=segment.group_id, + parent_id=segment.parent_id, + start=segment.start, + end=segment.end, + kind="prefix" if is_root else "completion", + child_index=child_index, ) - } - ) - for transfer in transfers + ) + tree_parent_indices.append(parent_node_index) + tree_depths.append(segment.depth) + node_by_row_group[(row.row_index, segment.group_id)] = node_index + + return GdnPackedExecutionSpec( + batch_size=batch_size, + sequence_length=sequence_length, + valid_lengths=tuple(valid_lengths), + tree_segments=tuple(tree_segments), + tree_parent_indices=tuple(tree_parent_indices), + tree_depths=tuple(tree_depths), ) -def _has_chainable_segment( - family: GdnPackedFamilySpec, +def _attention_source_layout( + spec: GdnPackedExecutionSpec, *, cp_size: int, + attention_token_layout_index: TokenLayoutIndex | None, planner_config: GdnPlannerConfig, -) -> bool: - return _can_chain_prefix_segment( - family.prefix, cp_size=cp_size, planner_config=planner_config - ) or any( - _can_chain_segment(completion, cp_size=cp_size, planner_config=planner_config) - for completion in family.completions +) -> TokenLayoutIndex: + if attention_token_layout_index is not None: + layout_cp_size = len(attention_token_layout_index.token_counts_by_rank) + layout_token_count = sum( + int(count) for count in attention_token_layout_index.token_counts_by_rank + ) + if layout_cp_size != cp_size: + raise ValueError( + "attention token layout index cp_size must match GDN cp_size, got " + f"{layout_cp_size} and {cp_size}" + ) + if layout_token_count != spec.real_token_count: + raise ValueError( + "attention token layout index token count must match GDN real token " + f"count, got {layout_token_count} and {spec.real_token_count}" + ) + return attention_token_layout_index + ranges_by_rank = _default_attention_layout_ranges( + spec, + cp_size=cp_size, + planner_config=planner_config, + ) + return TokenLayoutIndex( + ownership_ranges_by_rank=ranges_by_rank, + token_counts_by_rank=tuple( + sum(int(end) - int(start) for start, end, _ in ranges) + for ranges in ranges_by_rank + ), ) -def _can_chain_prefix_segment( +def _can_chain_tree_segment( segment: GdnSegmentSpec, *, cp_size: int, planner_config: GdnPlannerConfig, ) -> bool: - return _can_chain_segment(segment, cp_size=cp_size, planner_config=planner_config) - - -def _score_cp_segment_stats( - *, - rank_local_work: tuple[int, ...], - rank_chain_work: tuple[int, ...], - rank_real_tokens: tuple[int, ...], - cross_rank_token_count: int, - parent_state_exchange_family_count: int, - local_bucket_count: int, - local_segment_count: int, - chain_bucket_count: int, - planner_config: GdnPlannerConfig, -) -> float: - empty_rank_count = sum(1 for token_count in rank_real_tokens if token_count == 0) - return ( - _rank_kernel_ms( - rank_local_work, - rank_chain_work, - local_token_ms=planner_config.planner_local_token_ms, - chain_token_ms=planner_config.planner_chain_token_ms, + min_total_tokens = ( + min( + planner_config.cp_tree_chain_min_prefix_only_tokens, + planner_config.cp_chain_min_prefix_only_tokens, ) - + planner_config.planner_local_bucket_ms * local_bucket_count - + planner_config.planner_chain_bucket_ms * chain_bucket_count - + planner_config.planner_local_segment_ms * local_segment_count - + planner_config.planner_layout_cross_rank_token_ms * cross_rank_token_count - + ( - planner_config.planner_parent_state_exchange_base_ms - + planner_config.planner_parent_state_exchange_ms - * parent_state_exchange_family_count - if parent_state_exchange_family_count - else 0.0 + if segment.kind == "prefix" + else min( + planner_config.cp_tree_chain_min_total_tokens, + planner_config.cp_chain_min_total_tokens, ) - + planner_config.planner_empty_rank_ms * empty_rank_count ) - - -def _rank_kernel_ms( - rank_local_work: tuple[int, ...], - rank_chain_work: tuple[int, ...], - *, - local_token_ms: float, - chain_token_ms: float, -) -> float: - return max( - ( - local_work * local_token_ms + chain_work * chain_token_ms - for local_work, chain_work in zip( - rank_local_work, rank_chain_work, strict=True - ) - ), - default=0.0, + return ( + segment.length >= min_total_tokens + and segment.length >= cp_size + and segment.length // FLA_CHUNK_SIZE >= cp_size + and segment.length / cp_size >= planner_config.cp_chain_min_tokens_per_rank ) @@ -3336,11 +636,16 @@ def _best_segment_owner( for rank in range(rank_count): counts_by_rank[rank] += segment_counts[rank] on_rank_tokens = tuple(counts_by_rank) - best: tuple[float, int, int, int, int] | None = None + best: tuple[float, float, int, int, int, int] | None = None for rank, tokens in enumerate(on_rank_tokens): projected_loads = list(rank_loads) projected_loads[rank] += segment_length max_load = max(projected_loads, default=0) + target_load = sum(projected_loads) / max(1, len(projected_loads)) + overload = max( + 0.0, + max_load - planner_config.max_zero_exchange_load_imbalance * target_load, + ) idle_tokens = sum(max_load - load for load in projected_loads) cross_rank_tokens = segment_length - int(tokens) empty_rank_count = sum(1 for load in projected_loads if load == 0) @@ -3353,6 +658,7 @@ def _best_segment_owner( + empty_rank_count * planner_config.planner_empty_rank_ms ) candidate = ( + overload, score, max_load, cross_rank_tokens, @@ -3366,23 +672,130 @@ def _best_segment_owner( return best[-1] +def _build_tree_state_exchanges_by_depth( + spec: GdnPackedExecutionSpec, + *, + owner_by_node: tuple[int, ...], + chained_nodes: tuple[bool, ...], + cp_rank: int, + cp_size: int, + depth_count: int, + device: torch.device | str, +) -> tuple[GdnStateExchangePlan | None, ...]: + if cp_size <= 1: + return tuple(None for _ in range(depth_count)) + + from art.megatron.gdn.layout import ( + GdnCpExchangePlan, + _make_peer_transfer, + _reverse_exchange_plan, + ) + + families_by_depth_pair: list[dict[tuple[int, int], set[int]]] = [ + {} for _ in range(depth_count) + ] + for child_index, parent_index in enumerate(spec.tree_parent_indices): + if parent_index < 0 or chained_nodes[parent_index]: + continue + source_rank = owner_by_node[parent_index] + dest_rank = owner_by_node[child_index] + if source_rank < 0 or dest_rank < 0: + raise ValueError("tree state exchange requires every node to have an owner") + if source_rank == dest_rank: + continue + depth = spec.tree_depths[child_index] + families_by_depth_pair[depth].setdefault((source_rank, dest_rank), set()).add( + parent_index + ) + + state_exchanges: list[GdnStateExchangePlan | None] = [] + for pair_families in families_by_depth_pair: + if not pair_families: + state_exchanges.append(None) + continue + source_families_by_rank = [set[int]() for _ in range(cp_size)] + dest_families_by_rank = [set[int]() for _ in range(cp_size)] + for (source_rank, dest_rank), parent_indices in pair_families.items(): + source_families_by_rank[source_rank].update(parent_indices) + dest_families_by_rank[dest_rank].update(parent_indices) + source_families = tuple( + tuple(sorted(families)) for families in source_families_by_rank + ) + dest_families = tuple( + tuple(sorted(families)) for families in dest_families_by_rank + ) + source_positions = ( + {family: index for index, family in enumerate(families)} + for families in source_families + ) + dest_positions = ( + {family: index for index, family in enumerate(families)} + for families in dest_families + ) + source_position_by_rank = tuple(source_positions) + dest_position_by_rank = tuple(dest_positions) + transfers = [] + transfer_count = 0 + for (source_rank, dest_rank), parent_indices in sorted(pair_families.items()): + ordered = tuple(sorted(parent_indices)) + transfer_count += len(ordered) + transfers.append( + _make_peer_transfer( + source_rank=source_rank, + dest_rank=dest_rank, + source_positions=torch.tensor( + [ + source_position_by_rank[source_rank][family] + for family in ordered + ], + dtype=torch.long, + ), + dest_positions=torch.tensor( + [ + dest_position_by_rank[dest_rank][family] + for family in ordered + ], + dtype=torch.long, + ), + source_count=len(source_families[source_rank]), + dest_count=len(dest_families[dest_rank]), + device=device, + ) + ) + exchange = GdnCpExchangePlan( + cp_size=cp_size, + source_token_counts_by_rank=tuple( + len(families) for families in source_families + ), + dest_token_counts_by_rank=tuple( + len(families) for families in dest_families + ), + transfers=tuple(transfers), + cross_rank_token_count_override=transfer_count, + ) + state_exchanges.append( + GdnStateExchangePlan( + source_family_indices=source_families[cp_rank], + dest_family_indices=dest_families[cp_rank], + exchange=exchange, + reverse_exchange=_reverse_exchange_plan(exchange), + ) + ) + return tuple(state_exchanges) + + def _build_attention_layout_index_from_token_layout( layout: TokenLayoutIndex, - *, - max_ranges: int, ) -> _AttentionLayoutIndex: - del max_ranges ranges_by_rank = tuple( tuple(sorted((int(start), int(end)) for start, end, _ in rank_ranges)) for rank_ranges in layout.ownership_ranges_by_rank ) - range_count = sum(len(ranges) for ranges in ranges_by_rank) - return _AttentionLayoutIndex.model_construct( + return _AttentionLayoutIndex( token_ranges_by_rank=ranges_by_rank, token_range_ends_by_rank=tuple( tuple(end for _, end in ranges) for ranges in ranges_by_rank ), - range_count=range_count, ) @@ -3393,7 +806,7 @@ def _segment_attention_rank_counts( attention_layout_index: _AttentionLayoutIndex, ) -> dict[tuple[int, int, int], tuple[int, ...]]: del cp_size - segments = tuple(spec.segments()) + segments = spec.tree_segments if not segments: return {} starts = torch.tensor( @@ -3471,74 +884,25 @@ def should_split_segment(segment: GdnSegmentSpec) -> bool: if segment.length <= planner_config.max_zero_exchange_load_imbalance * ( target_rank_load ): - return False - if segment.kind == "prefix": - return _can_chain_prefix_segment( - segment, cp_size=cp_size, planner_config=planner_config - ) - return _can_chain_segment( - segment, cp_size=cp_size, planner_config=planner_config - ) - - for family in spec.families: - has_split_segment = any( - should_split_segment(segment) - for segment in (family.prefix, *family.completions) - ) - if not has_split_segment: - if _should_co_locate_non_chain_family( - family, - total_real_tokens=spec.real_token_count, - cp_size=cp_size, - planner_config=planner_config, - ): - owner = _least_loaded_rank(loads) - for segment in (family.prefix, *family.completions): - token_start = _segment_token_start(segment, spec.sequence_length) - append_segment(owner, token_start, segment.length) - continue - for segment in (family.prefix, *family.completions): - token_start = _segment_token_start(segment, spec.sequence_length) - owner = _least_loaded_rank(loads) - append_segment(owner, token_start, segment.length) - continue - for segment in (family.prefix, *family.completions): - token_start = _segment_token_start(segment, spec.sequence_length) - if should_split_segment(segment): - _append_split_default_attention_segment( - ranks, loads, token_start, segment.length - ) - continue - owner = _least_loaded_rank(loads) - append_segment(owner, token_start, segment.length) - return tuple(tuple(ranges) for ranges in ranks) - - -def _should_co_locate_non_chain_family( - family: GdnPackedFamilySpec, - *, - total_real_tokens: int, - cp_size: int, - planner_config: GdnPlannerConfig, -) -> bool: - target_rank_load = total_real_tokens / cp_size - return family.token_count <= ( - planner_config.max_zero_exchange_load_imbalance * target_rank_load - ) - + return False + return _can_chain_tree_segment( + segment, cp_size=cp_size, planner_config=planner_config + ) -def _append_split_default_attention_segment( - ranks: list[list[tuple[int, int, int]]], - loads: list[int], - token_start: int, - token_count: int, -) -> None: - cp_size = len(ranks) - for rank in range(cp_size): - start = (token_count * rank) // cp_size - end = (token_count * (rank + 1)) // cp_size - ranks[rank].append((token_start + start, token_start + end, loads[rank])) - loads[rank] += end - start + for segment in spec.tree_segments: + token_start = _segment_token_start(segment, spec.sequence_length) + if should_split_segment(segment): + for rank in range(cp_size): + start = (segment.length * rank) // cp_size + end = (segment.length * (rank + 1)) // cp_size + ranks[rank].append( + (token_start + start, token_start + end, loads[rank]) + ) + loads[rank] += end - start + continue + owner = _least_loaded_rank(loads) + append_segment(owner, token_start, segment.length) + return tuple(tuple(ranges) for ranges in ranks) def _append_chain_segment( @@ -3581,36 +945,16 @@ def _append_chain_segment( ) rank_loads[rank] += shard_length if attention_layout_index is not None: - cross_rank_tokens += shard_length - _attention_overlap_count( - attention_layout_index, - rank, + cross_rank_tokens += shard_length - _range_overlap_count( shard_start, shard_start + shard_length, + attention_layout_index.token_ranges_by_rank[rank], + attention_layout_index.token_range_ends_by_rank[rank], ) start = end return cross_rank_tokens -def _chain_rank_token_indices( - segment: GdnSegmentSpec, - spec: GdnPackedExecutionSpec, - *, - cp_rank: int, - cp_size: int, -) -> range: - token_start = _segment_token_start(segment, spec.sequence_length) - lengths = _fla_aligned_chain_shard_lengths(segment.length, cp_size=cp_size) - start = sum(lengths[:cp_rank]) - end = start + lengths[cp_rank] - if start >= end: - raise ValueError( - "CP chain planning requires non-empty shards; " - f"segment={segment.kind}:{segment.family_index} " - f"length={segment.length} cp_size={cp_size}" - ) - return range(token_start + start, token_start + end) - - def _fla_aligned_chain_shard_lengths(length: int, *, cp_size: int) -> tuple[int, ...]: full_chunks = int(length) // FLA_CHUNK_SIZE if full_chunks < int(cp_size): @@ -3641,15 +985,14 @@ def _attention_contiguous_chain_shards( shards: list[range] = [] cursor = token_start for rank in range(cp_size): - overlap = _attention_single_contiguous_overlap( - attention_layout_index, - rank, + overlaps = _range_overlaps( token_start, segment_end, + attention_layout_index.token_ranges_by_rank[rank], ) - if overlap is None: + if len(overlaps) != 1: return None - start, end = overlap + start, end = overlaps[0] if start != cursor or end <= start: return None shards.append(range(start, end)) @@ -3661,18 +1004,6 @@ def _attention_contiguous_chain_shards( return tuple(shards) -def _attention_single_contiguous_overlap( - index: _AttentionLayoutIndex, - rank: int, - start: int, - end: int, -) -> tuple[int, int] | None: - overlaps = _range_overlaps(start, end, index.token_ranges_by_rank[rank]) - if len(overlaps) != 1: - return None - return overlaps[0] - - def _append_local_segment( gdn_ranges_by_rank: list[list[tuple[int, int, int]]], rank_loads: list[int], @@ -3695,162 +1026,154 @@ def _least_loaded_rank(rank_loads: list[int]) -> int: return min(range(len(rank_loads)), key=lambda rank: (rank_loads[rank], rank)) -def _owner_rank( - local_prefix_segments_by_rank: list[list[GdnSegmentSpec]], - prefix: GdnSegmentSpec, -) -> int: - for rank, segments in enumerate(local_prefix_segments_by_rank): - if prefix in segments: - return rank - raise RuntimeError("local prefix owner was not recorded") - - -def _build_position_bucket_plans( - segment_buckets: tuple[tuple[GdnSegmentSpec, ...], ...], - local_token_ranges: tuple[tuple[int, int, int], ...], +def _build_tree_bucket_plans( + segments: tuple[GdnSegmentSpec, ...], + tree_parent_indices: tuple[int, ...], + tree_has_children: tuple[bool, ...], *, + local_token_ranges: tuple[tuple[int, int, int], ...] | None, sequence_length: int, device: torch.device | str, + planner_config: GdnPlannerConfig, token_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...] | None = None, + split_by_final_state: bool = True, ) -> tuple[GdnSegmentBucketPlan, ...]: + segment_buckets = ( + _batch_tree_segments_by_padded_work( + segments, + tree_has_children, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + if split_by_final_state + else _batch_segments_by_padded_work( + segments, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + ) return tuple( - _build_position_bucket_plan( + _bucket_with_tree_parent_indices( + ( + _build_segment_bucket_plan(bucket, device=device) + if local_token_ranges is None + else _build_position_bucket_plan( + bucket, + local_token_ranges, + sequence_length=sequence_length, + device=device, + token_ranges_by_rank=token_ranges_by_rank, + ) + ), bucket, - local_token_ranges, - sequence_length=sequence_length, + tree_parent_indices, + tree_has_children, device=device, - token_ranges_by_rank=token_ranges_by_rank, ) for bucket in segment_buckets ) -def _build_position_bucket_plan( +def _bucket_with_tree_parent_indices( + plan: GdnSegmentBucketPlan, segments: tuple[GdnSegmentSpec, ...], - local_token_ranges: tuple[tuple[int, int, int], ...], + tree_parent_indices: tuple[int, ...], + tree_has_children: tuple[bool, ...], *, - sequence_length: int, device: torch.device | str, - token_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...] | None = None, ) -> GdnSegmentBucketPlan: - exact_plan = _build_exact_range_position_bucket_plan( - segments, - local_token_ranges, - sequence_length=sequence_length, - device=device, - token_ranges_by_rank=token_ranges_by_rank, - ) - if exact_plan is not None: - return exact_plan - local_positions_by_segment = [] - lengths = [] - local_range_ends = tuple(token_end for _, token_end, _ in local_token_ranges) - for segment in segments: - positions = _local_positions_for_segment( - segment, - sequence_length=sequence_length, - local_token_ranges=local_token_ranges, - local_range_ends=local_range_ends, - ) - length = int(positions.numel()) - if not length: - raise ValueError( - "planned GDN bucket contains a segment with no local tokens; " - f"family={segment.family_index} kind={segment.kind}" - ) - local_positions_by_segment.append(positions) - lengths.append(length) - max_length = max(lengths) - lengths_cpu = torch.tensor(lengths, dtype=torch.long) - offsets_cpu = torch.arange(max_length, dtype=torch.long).unsqueeze(1) - real_mask_cpu = offsets_cpu < lengths_cpu.unsqueeze(0) - position_indices_cpu = torch.zeros(max_length, len(segments), dtype=torch.long) - for column, positions in enumerate(local_positions_by_segment): - position_indices_cpu[: int(positions.numel()), column] = positions - cu_seqlens_cpu = torch.cat( - [lengths_cpu.new_zeros(1), torch.cumsum(lengths_cpu, dim=0)] - ) - lengths_by_rank_cpu = _bucket_lengths_by_rank_cpu( - segments, - token_ranges_by_rank, - sequence_length=sequence_length, - ) - row_indices_cpu = torch.zeros(max_length, len(segments), dtype=torch.long) - family_indices_cpu = torch.tensor( - [segment.family_index for segment in segments], + parent_indices = torch.tensor( + [tree_parent_indices[segment.family_index] for segment in segments], dtype=torch.long, ) - return GdnSegmentBucketPlan.model_construct( - length=max_length, - lengths=_move_planner_tensor(lengths_cpu, device), - lengths_cpu=lengths_cpu, - lengths_by_rank_cpu=lengths_by_rank_cpu, - real_mask=_move_planner_tensor(real_mask_cpu, device), - cu_seqlens=_move_planner_tensor(cu_seqlens_cpu, device), - cu_seqlens_cpu=cu_seqlens_cpu, - row_indices=_move_planner_tensor(row_indices_cpu, device), - position_indices=_move_planner_tensor(position_indices_cpu, device), - family_indices=_move_planner_tensor(family_indices_cpu, device), - real_token_count_static=sum(lengths), + return replace( + plan, + parent_indices=_move_planner_tensor(parent_indices, device), + parent_indices_cpu=parent_indices, + needs_final_state=any( + tree_has_children[segment.family_index] for segment in segments + ), ) -def _build_exact_range_position_bucket_plan( +def _build_position_bucket_plan( segments: tuple[GdnSegmentSpec, ...], local_token_ranges: tuple[tuple[int, int, int], ...], *, sequence_length: int, device: torch.device | str, token_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...] | None = None, -) -> GdnSegmentBucketPlan | None: +) -> GdnSegmentBucketPlan: range_positions = { (start, end): position for start, end, position in local_token_ranges } - starts = [] - lengths = [] + starts: list[int] = [] + lengths: list[int] = [] for segment in segments: token_start = _segment_token_start(segment, sequence_length) token_end = token_start + segment.length position_start = range_positions.get((token_start, token_end)) if position_start is None: - return None + break starts.append(position_start) lengths.append(segment.length) - max_length = max(lengths) - starts_cpu = torch.tensor(starts, dtype=torch.long) - lengths_cpu = torch.tensor(lengths, dtype=torch.long) - offsets_cpu = torch.arange(max_length, dtype=torch.long).unsqueeze(1) - real_mask_cpu = offsets_cpu < lengths_cpu.unsqueeze(0) - position_indices_cpu = torch.where( - real_mask_cpu, - starts_cpu.unsqueeze(0) + offsets_cpu, - torch.zeros_like(offsets_cpu), - ) - cu_seqlens_cpu = torch.cat( - [lengths_cpu.new_zeros(1), torch.cumsum(lengths_cpu, dim=0)] - ) - lengths_by_rank_cpu = _bucket_lengths_by_rank_cpu( - segments, - token_ranges_by_rank, - sequence_length=sequence_length, - ) - row_indices_cpu = torch.zeros(max_length, len(segments), dtype=torch.long) - family_indices_cpu = torch.tensor( - [segment.family_index for segment in segments], + else: + starts_cpu = torch.tensor(starts, dtype=torch.long) + lengths_cpu = torch.tensor(lengths, dtype=torch.long) + offsets_cpu = torch.arange(max(lengths), dtype=torch.long).unsqueeze(1) + position_indices_cpu = torch.where( + offsets_cpu < lengths_cpu.unsqueeze(0), + starts_cpu.unsqueeze(0) + offsets_cpu, + torch.zeros_like(offsets_cpu), + ) + return _build_bucket_plan( + segments, + lengths_cpu=lengths_cpu, + row_indices_cpu=torch.zeros_like(position_indices_cpu), + position_indices_cpu=position_indices_cpu, + lengths_by_rank_cpu=_bucket_lengths_by_rank_cpu( + segments, + token_ranges_by_rank, + sequence_length=sequence_length, + ), + device=device, + ) + + local_positions_by_segment: list[torch.Tensor] = [] + local_range_ends = tuple(token_end for _, token_end, _ in local_token_ranges) + for segment in segments: + positions = _local_positions_for_segment( + segment, + sequence_length=sequence_length, + local_token_ranges=local_token_ranges, + local_range_ends=local_range_ends, + ) + if not int(positions.numel()): + raise ValueError( + "planned GDN bucket contains a segment with no local tokens; " + f"family={segment.family_index} kind={segment.kind}" + ) + local_positions_by_segment.append(positions) + + lengths_cpu = torch.tensor( + [int(positions.numel()) for positions in local_positions_by_segment], dtype=torch.long, ) - return GdnSegmentBucketPlan.model_construct( - length=max_length, - lengths=_move_planner_tensor(lengths_cpu, device), + max_length = int(lengths_cpu.max().item()) + position_indices_cpu = torch.zeros(max_length, len(segments), dtype=torch.long) + for column, positions in enumerate(local_positions_by_segment): + position_indices_cpu[: int(positions.numel()), column] = positions + return _build_bucket_plan( + segments, lengths_cpu=lengths_cpu, - lengths_by_rank_cpu=lengths_by_rank_cpu, - real_mask=_move_planner_tensor(real_mask_cpu, device), - cu_seqlens=_move_planner_tensor(cu_seqlens_cpu, device), - cu_seqlens_cpu=cu_seqlens_cpu, - row_indices=_move_planner_tensor(row_indices_cpu, device), - position_indices=_move_planner_tensor(position_indices_cpu, device), - family_indices=_move_planner_tensor(family_indices_cpu, device), - real_token_count_static=sum(lengths), + row_indices_cpu=torch.zeros_like(position_indices_cpu), + position_indices_cpu=position_indices_cpu, + lengths_by_rank_cpu=_bucket_lengths_by_rank_cpu( + segments, + token_ranges_by_rank, + sequence_length=sequence_length, + ), + device=device, ) @@ -3927,41 +1250,86 @@ def _batch_segments_by_padded_work( return tuple(tuple(batch) for batch in batches) +def _batch_tree_segments_by_padded_work( + segments: tuple[GdnSegmentSpec, ...], + tree_has_children: tuple[bool, ...], + *, + max_padding_ratio: float = 1.25, + max_segments_per_batch: int = 128, +) -> tuple[tuple[GdnSegmentSpec, ...], ...]: + stateful = tuple( + segment for segment in segments if tree_has_children[segment.family_index] + ) + stateless = tuple( + segment for segment in segments if not tree_has_children[segment.family_index] + ) + return ( + *_batch_segments_by_padded_work( + stateful, + max_padding_ratio=max_padding_ratio, + max_segments_per_batch=max_segments_per_batch, + ), + *_batch_segments_by_padded_work( + stateless, + max_padding_ratio=max_padding_ratio, + max_segments_per_batch=max_segments_per_batch, + ), + ) + + def _build_segment_bucket_plan( - length: int, segments: tuple[GdnSegmentSpec, ...], *, device: torch.device | str + segments: tuple[GdnSegmentSpec, ...], *, device: torch.device | str ) -> GdnSegmentBucketPlan: - max_length = max(segment.length for segment in segments) lengths_cpu = torch.tensor( [segment.length for segment in segments], dtype=torch.long ) + max_length = int(lengths_cpu.max().item()) starts_cpu = torch.tensor([segment.start for segment in segments], dtype=torch.long) rows_cpu = torch.tensor( [segment.row_index for segment in segments], dtype=torch.long ) offsets_cpu = torch.arange(max_length, dtype=torch.long).unsqueeze(1) + return _build_bucket_plan( + segments, + lengths_cpu=lengths_cpu, + row_indices_cpu=rows_cpu.unsqueeze(0).expand(max_length, -1).contiguous(), + position_indices_cpu=starts_cpu.unsqueeze(0) + offsets_cpu, + device=device, + ) + + +def _build_bucket_plan( + segments: tuple[GdnSegmentSpec, ...], + *, + lengths_cpu: torch.Tensor, + row_indices_cpu: torch.Tensor, + position_indices_cpu: torch.Tensor, + device: torch.device | str, + lengths_by_rank_cpu: torch.Tensor | None = None, +) -> GdnSegmentBucketPlan: + max_length = int(lengths_cpu.max().item()) + offsets_cpu = torch.arange(max_length, dtype=torch.long).unsqueeze(1) real_mask_cpu = offsets_cpu < lengths_cpu.unsqueeze(0) - positions_cpu = starts_cpu.unsqueeze(0) + offsets_cpu + cu_seqlens_cpu = torch.cat( + [lengths_cpu.new_zeros(1), torch.cumsum(lengths_cpu, dim=0)] + ) family_indices_cpu = torch.tensor( [segment.family_index for segment in segments], dtype=torch.long, ) - cu_seqlens_cpu = torch.cat( - [lengths_cpu.new_zeros(1), torch.cumsum(lengths_cpu, dim=0)] - ) - return GdnSegmentBucketPlan.model_construct( + return GdnSegmentBucketPlan( length=max_length, lengths=_move_planner_tensor(lengths_cpu, device), lengths_cpu=lengths_cpu, - lengths_by_rank_cpu=None, + lengths_by_rank_cpu=lengths_by_rank_cpu, real_mask=_move_planner_tensor(real_mask_cpu, device), cu_seqlens=_move_planner_tensor(cu_seqlens_cpu, device), cu_seqlens_cpu=cu_seqlens_cpu, - row_indices=_move_planner_tensor( - rows_cpu.unsqueeze(0).expand(max_length, -1).contiguous(), device - ), - position_indices=_move_planner_tensor(positions_cpu, device), + row_indices=_move_planner_tensor(row_indices_cpu, device), + position_indices=_move_planner_tensor(position_indices_cpu, device), family_indices=_move_planner_tensor(family_indices_cpu, device), - real_token_count_static=sum(segment.length for segment in segments), + family_indices_cpu=family_indices_cpu, + real_token_count_static=int(lengths_cpu.sum().item()), ) @@ -3969,20 +1337,6 @@ def _segment_token_start(segment: GdnSegmentSpec, sequence_length: int) -> int: return segment.row_index * sequence_length + segment.start -def _attention_overlap_count( - index: _AttentionLayoutIndex, - rank: int, - start: int, - end: int, -) -> int: - return _range_overlap_count( - start, - end, - index.token_ranges_by_rank[rank], - index.token_range_ends_by_rank[rank], - ) - - def _range_overlap_count( start: int, end: int, @@ -4012,27 +1366,6 @@ def _range_overlaps( return overlaps -def _local_token_ranges( - local_gdn_tokens: tuple[int, ...], -) -> tuple[tuple[int, int, int], ...]: - if not local_gdn_tokens: - return () - ranges = [] - token_start = local_gdn_tokens[0] - token_end = token_start + 1 - position_start = 0 - for position, token in enumerate(local_gdn_tokens[1:], start=1): - if token == token_end: - token_end += 1 - continue - ranges.append((token_start, token_end, position_start)) - token_start = token - token_end = token + 1 - position_start = position - ranges.append((token_start, token_end, position_start)) - return tuple(ranges) - - def _local_positions_for_segment( segment: GdnSegmentSpec, *, @@ -4079,285 +1412,3 @@ def _rank2_long_cpu(name: str, tensor: torch.Tensor) -> torch.Tensor: ): raise TypeError(f"{name} must contain integer ids, got dtype={tensor.dtype}") return tensor.detach().to(device="cpu", dtype=torch.long) - - -def _validate_padding_tensor( - row_index: int, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, -) -> int: - padding_positions = torch.nonzero(group_ids == -1, as_tuple=False) - valid_length = ( - int(padding_positions[0].item()) - if int(padding_positions.numel()) > 0 - else int(group_ids.numel()) - ) - if valid_length == 0: - if bool(torch.any(parent_ids != -1).item()): - raise ValueError(f"row {row_index}: padding parent_ids must be -1") - return 0 - if bool(torch.any(group_ids[valid_length:] != -1).item()): - raise ValueError( - f"row {row_index}: valid tokens must be contiguous before padding" - ) - if bool(torch.any(parent_ids[:valid_length] == -1).item()): - raise ValueError( - f"row {row_index}: valid tokens must have non-padding parent_ids" - ) - if bool(torch.any(parent_ids[valid_length:] != -1).item()): - raise ValueError(f"row {row_index}: padding parent_ids must be -1") - return valid_length - - -def _validate_padding( - row_index: int, - group_ids: list[int], - parent_ids: list[int], -) -> int: - valid_length = 0 - for group_id in group_ids: - if group_id == -1: - break - valid_length += 1 - if valid_length == 0: - if any(parent_id != -1 for parent_id in parent_ids): - raise ValueError(f"row {row_index}: padding parent_ids must be -1") - return 0 - if any(group_id != -1 for group_id in group_ids[valid_length:]): - raise ValueError( - f"row {row_index}: valid tokens must be contiguous before padding" - ) - if any(parent_id == -1 for parent_id in parent_ids[:valid_length]): - raise ValueError( - f"row {row_index}: valid tokens must have non-padding parent_ids" - ) - if any(parent_id != -1 for parent_id in parent_ids[valid_length:]): - raise ValueError(f"row {row_index}: padding parent_ids must be -1") - return valid_length - - -def _parse_row_tensor( - *, - row_index: int, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, - valid_length: int, - first_family_index: int, - min_completions_per_family: int, -) -> list[GdnPackedFamilySpec]: - valid_groups = group_ids[:valid_length] - valid_parents = parent_ids[:valid_length] - if valid_length > 1: - same_group = valid_groups[1:] == valid_groups[:-1] - parent_changed = same_group & (valid_parents[1:] != valid_parents[:-1]) - if bool(torch.any(parent_changed).item()): - position = int(torch.nonzero(parent_changed, as_tuple=False)[0].item()) + 1 - group_id = int(valid_groups[position].item()) - previous_parent = int(valid_parents[position - 1].item()) - current_parent = int(valid_parents[position].item()) - raise ValueError( - f"row {row_index}: group {group_id} changes parent from " - f"{previous_parent} to {current_parent}" - ) - boundaries = torch.nonzero(~same_group, as_tuple=False).flatten() + 1 - starts_tensor = torch.cat( - (valid_groups.new_zeros(1), boundaries.to(valid_groups.dtype)) - ) - ends_tensor = torch.cat( - ( - boundaries.to(valid_groups.dtype), - valid_groups.new_tensor([valid_length]), - ) - ) - else: - starts_tensor = valid_groups.new_zeros(1) - ends_tensor = valid_groups.new_tensor([valid_length]) - - starts = tuple(int(value) for value in starts_tensor.tolist()) - ends = tuple(int(value) for value in ends_tensor.tolist()) - segment_group_ids = tuple(int(valid_groups[start].item()) for start in starts) - segment_parent_ids = tuple(int(valid_parents[start].item()) for start in starts) - families: list[GdnPackedFamilySpec] = [] - seen_groups: set[int] = set() - segment_cursor = 0 - while segment_cursor < len(starts): - group_id = segment_group_ids[segment_cursor] - parent_id = segment_parent_ids[segment_cursor] - start = starts[segment_cursor] - end = ends[segment_cursor] - if group_id in seen_groups: - raise ValueError(f"row {row_index}: group_id {group_id} is non-contiguous") - if group_id != parent_id: - raise ValueError( - f"row {row_index}: completion group {group_id} appears before " - f"its prefix parent {parent_id}" - ) - seen_groups.add(group_id) - family_index = first_family_index + len(families) - prefix = _trusted_pydantic_construct( - GdnSegmentSpec, - _GDN_SEGMENT_SPEC_FIELDS, - row_index=row_index, - family_index=family_index, - group_id=group_id, - parent_id=parent_id, - start=start, - end=end, - kind="prefix", - child_index=None, - ) - segment_cursor += 1 - completions: list[GdnSegmentSpec] = [] - while segment_cursor < len(starts): - child_group_id = segment_group_ids[segment_cursor] - child_parent_id = segment_parent_ids[segment_cursor] - child_start = starts[segment_cursor] - child_end = ends[segment_cursor] - if child_group_id == child_parent_id: - break - if child_parent_id != group_id: - raise ValueError( - f"row {row_index}: completion group {child_group_id} has " - f"parent {child_parent_id}, expected active prefix {group_id}" - ) - if child_group_id in seen_groups: - raise ValueError( - f"row {row_index}: group_id {child_group_id} is non-contiguous" - ) - seen_groups.add(child_group_id) - completions.append( - _trusted_pydantic_construct( - GdnSegmentSpec, - _GDN_SEGMENT_SPEC_FIELDS, - row_index=row_index, - family_index=family_index, - group_id=child_group_id, - parent_id=child_parent_id, - start=child_start, - end=child_end, - kind="completion", - child_index=len(completions), - ) - ) - segment_cursor += 1 - if len(completions) < min_completions_per_family: - raise ValueError( - f"row {row_index}: prefix group {group_id} has {len(completions)} " - f"completion(s), expected at least {min_completions_per_family}" - ) - families.append( - _trusted_pydantic_construct( - GdnPackedFamilySpec, - _GDN_PACKED_FAMILY_SPEC_FIELDS, - row_index=row_index, - family_index=family_index, - prefix=prefix, - completions=tuple(completions), - ) - ) - return families - - -def _parse_row( - *, - row_index: int, - group_ids: list[int], - parent_ids: list[int], - valid_length: int, - first_family_index: int, - min_completions_per_family: int, -) -> list[GdnPackedFamilySpec]: - families: list[GdnPackedFamilySpec] = [] - seen_groups: set[int] = set() - cursor = 0 - while cursor < valid_length: - group_id, parent_id, start, end = _read_segment( - row_index, group_ids, parent_ids, valid_length, cursor - ) - if group_id in seen_groups: - raise ValueError(f"row {row_index}: group_id {group_id} is non-contiguous") - if group_id != parent_id: - raise ValueError( - f"row {row_index}: completion group {group_id} appears before " - f"its prefix parent {parent_id}" - ) - seen_groups.add(group_id) - family_index = first_family_index + len(families) - prefix = GdnSegmentSpec( - row_index=row_index, - family_index=family_index, - group_id=group_id, - parent_id=parent_id, - start=start, - end=end, - kind="prefix", - ) - cursor = end - completions: list[GdnSegmentSpec] = [] - while cursor < valid_length: - child_group_id, child_parent_id, child_start, child_end = _read_segment( - row_index, group_ids, parent_ids, valid_length, cursor - ) - if child_group_id == child_parent_id: - break - if child_parent_id != group_id: - raise ValueError( - f"row {row_index}: completion group {child_group_id} has " - f"parent {child_parent_id}, expected active prefix {group_id}" - ) - if child_group_id in seen_groups: - raise ValueError( - f"row {row_index}: group_id {child_group_id} is non-contiguous" - ) - seen_groups.add(child_group_id) - completions.append( - GdnSegmentSpec( - row_index=row_index, - family_index=family_index, - group_id=child_group_id, - parent_id=child_parent_id, - start=child_start, - end=child_end, - kind="completion", - child_index=len(completions), - ) - ) - cursor = child_end - if len(completions) < min_completions_per_family: - raise ValueError( - f"row {row_index}: prefix group {group_id} has {len(completions)} " - f"completion(s), expected at least {min_completions_per_family}" - ) - families.append( - GdnPackedFamilySpec( - row_index=row_index, - family_index=family_index, - prefix=prefix, - completions=tuple(completions), - ) - ) - return families - - -def _read_segment( - row_index: int, - group_ids: list[int], - parent_ids: list[int], - valid_length: int, - cursor: int, -) -> tuple[int, int, int, int]: - group_id = int(group_ids[cursor]) - parent_id = int(parent_ids[cursor]) - if group_id < 0 or parent_id < 0: - raise ValueError(f"row {row_index}: segment ids must be non-negative") - start = cursor - cursor += 1 - while cursor < valid_length and int(group_ids[cursor]) == group_id: - current_parent = int(parent_ids[cursor]) - if current_parent != parent_id: - raise ValueError( - f"row {row_index}: group {group_id} changes parent from " - f"{parent_id} to {current_parent}" - ) - cursor += 1 - return group_id, parent_id, start, cursor diff --git a/src/art/megatron/gdn/layout.py b/src/art/megatron/gdn/layout.py index c3469a451..7119218f6 100644 --- a/src/art/megatron/gdn/layout.py +++ b/src/art/megatron/gdn/layout.py @@ -1,9 +1,9 @@ from __future__ import annotations from collections.abc import Sequence +from dataclasses import dataclass, replace from typing import Any -from pydantic import BaseModel, ConfigDict, Field, model_validator import torch from torch import Tensor from torch.distributed import ( @@ -20,47 +20,47 @@ from art.megatron.context_parallel.layout_index import TokenLayoutIndex -class GdnCpPeerTransfer(BaseModel): +@dataclass(frozen=True) +class GdnCpPeerTransfer: """Token rows sent from one source rank to one destination rank.""" - model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) - - source_rank: int = Field(ge=0) - dest_rank: int = Field(ge=0) - token_count: int = Field(ge=0) + source_rank: int + dest_rank: int + token_count: int + source_positions_cpu: tuple[int, ...] | None = None + dest_positions_cpu: tuple[int, ...] | None = None source_positions_tensor: Tensor | None = None dest_positions_tensor: Tensor | None = None - @model_validator(mode="after") - def _same_lengths(self) -> "GdnCpPeerTransfer": + def __post_init__(self) -> None: lengths = {int(self.token_count)} + if self.source_positions_cpu is not None: + lengths.add(len(self.source_positions_cpu)) + if self.dest_positions_cpu is not None: + lengths.add(len(self.dest_positions_cpu)) if self.source_positions_tensor is not None: lengths.add(int(self.source_positions_tensor.numel())) if self.dest_positions_tensor is not None: lengths.add(int(self.dest_positions_tensor.numel())) if len(lengths) != 1: raise ValueError("token, source, and destination position counts differ") - return self -class GdnCpExchangePlan(BaseModel): +@dataclass(frozen=True) +class GdnCpExchangePlan: """Permutation/all-to-all metadata between two distributed token layouts.""" - model_config = ConfigDict(frozen=True) - - cp_size: int = Field(ge=1) + cp_size: int source_token_counts_by_rank: tuple[int, ...] dest_token_counts_by_rank: tuple[int, ...] transfers: tuple[GdnCpPeerTransfer, ...] - cross_rank_token_count_override: int | None = Field(default=None, ge=0) + cross_rank_token_count_override: int | None = None - @model_validator(mode="after") - def _rank_counts(self) -> "GdnCpExchangePlan": + def __post_init__(self) -> None: if len(self.source_token_counts_by_rank) != self.cp_size: raise ValueError("source token count length must equal cp_size") if len(self.dest_token_counts_by_rank) != self.cp_size: raise ValueError("destination token count length must equal cp_size") - return self @property def cross_rank_token_count(self) -> int: @@ -73,11 +73,10 @@ def cross_rank_token_count(self) -> int: ) -class GdnSpExchangePlan(BaseModel): +@dataclass(frozen=True) +class GdnSpExchangePlan: """Sequence-parallel view of an existing CP exchange plan.""" - model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) - plan: GdnCpExchangePlan rank: int @@ -206,7 +205,7 @@ def build_local_rank_cp_exchange_plan_from_dest_ranges( device=device, ) ) - return GdnCpExchangePlan.model_construct( + return GdnCpExchangePlan( cp_size=cp_size, source_token_counts_by_rank=source_layout.token_counts_by_rank, dest_token_counts_by_rank=dest_counts, @@ -238,23 +237,33 @@ def _make_peer_transfer( source_count=source_count, dest_count=dest_count, ): + source_cpu = None + dest_cpu = None source_tensor = None dest_tensor = None else: + source_cpu = _tensor_positions_tuple(source_positions) + dest_cpu = _tensor_positions_tuple(dest_positions) target = torch.device(device) if device is not None else torch.device("cpu") source_tensor = source_positions.to( device=target, dtype=torch.long ).contiguous() dest_tensor = dest_positions.to(device=target, dtype=torch.long).contiguous() - return GdnCpPeerTransfer.model_construct( + return GdnCpPeerTransfer( source_rank=source_rank, dest_rank=dest_rank, token_count=token_count, + source_positions_cpu=source_cpu, + dest_positions_cpu=dest_cpu, source_positions_tensor=source_tensor, dest_positions_tensor=dest_tensor, ) +def _tensor_positions_tuple(tensor: Tensor) -> tuple[int, ...]: + return tuple(int(value) for value in tensor.detach().cpu().tolist()) + + def _is_full_identity_transfer( *, source_rank: int, @@ -277,16 +286,18 @@ def _is_full_identity_transfer( def _reverse_exchange_plan(plan: GdnCpExchangePlan) -> GdnCpExchangePlan: - return GdnCpExchangePlan.model_construct( + return GdnCpExchangePlan( cp_size=plan.cp_size, source_token_counts_by_rank=_dest_counts_by_rank(plan), dest_token_counts_by_rank=_source_counts_by_rank(plan), cross_rank_token_count_override=plan.cross_rank_token_count_override, transfers=tuple( - GdnCpPeerTransfer.model_construct( + GdnCpPeerTransfer( source_rank=transfer.dest_rank, dest_rank=transfer.source_rank, token_count=_transfer_token_count(transfer), + source_positions_cpu=transfer.dest_positions_cpu, + dest_positions_cpu=transfer.source_positions_cpu, source_positions_tensor=transfer.dest_positions_tensor, dest_positions_tensor=transfer.source_positions_tensor, ) @@ -485,15 +496,11 @@ def move_cp_exchange_plan_to_device( if plan is None: return None target = torch.device(device) - return GdnCpExchangePlan.model_construct( - cp_size=plan.cp_size, - source_token_counts_by_rank=_source_counts_by_rank(plan), - dest_token_counts_by_rank=_dest_counts_by_rank(plan), + return replace( + plan, transfers=tuple( - GdnCpPeerTransfer.model_construct( - source_rank=transfer.source_rank, - dest_rank=transfer.dest_rank, - token_count=transfer.token_count, + replace( + transfer, source_positions_tensor=_move_optional_index_tensor( transfer.source_positions_tensor, target ), @@ -503,7 +510,6 @@ def move_cp_exchange_plan_to_device( ) for transfer in plan.transfers ), - cross_rank_token_count_override=plan.cross_rank_token_count_override, ) @@ -532,7 +538,7 @@ def shard_cp_exchange_plan_for_sequence_parallel( """ if tp_size <= 1: - return GdnSpExchangePlan.model_construct(plan=plan, rank=cp_rank) + return GdnSpExchangePlan(plan=plan, rank=cp_rank) _check_rank(plan, cp_rank) if tp_rank < 0 or tp_rank >= tp_size: raise ValueError(f"tp_rank must be in [0, {tp_size}), got {tp_rank}") @@ -603,7 +609,7 @@ def shard_cp_exchange_plan_for_sequence_parallel( # A CP-local reorder can still move rows between TP ranks, and local CP plans do # not contain enough global TP information for every rank to independently # prove that no peer exchange is needed. - sp_plan = GdnCpExchangePlan.model_construct( + sp_plan = GdnCpExchangePlan( cp_size=world_size, source_token_counts_by_rank=source_counts, dest_token_counts_by_rank=dest_counts, @@ -612,7 +618,7 @@ def shard_cp_exchange_plan_for_sequence_parallel( ), cross_rank_token_count_override=1, ) - return GdnSpExchangePlan.model_construct(plan=sp_plan, rank=composite_rank) + return GdnSpExchangePlan(plan=sp_plan, rank=composite_rank) def recv_split_sizes_for_rank(plan: GdnCpExchangePlan, rank: int) -> tuple[int, ...]: @@ -750,10 +756,15 @@ def _is_implicit_full_identity_transfer( ) -def _transfer_positions_tuple(tensor: Tensor | None) -> tuple[int, ...]: +def _transfer_positions_tuple( + positions: tuple[int, ...] | None, + tensor: Tensor | None, +) -> tuple[int, ...]: + if positions is not None: + return positions if tensor is None: return () - return tuple(int(value) for value in tensor.detach().cpu().tolist()) + return _tensor_positions_tuple(tensor) def _transfer_index_tensor( @@ -895,17 +906,6 @@ def _exchange_rank_tensor_local( ) -def _copy_rank_self_transfers( - local_tensor: Tensor, - plan: GdnCpExchangePlan, - *, - rank: int, -) -> Tensor: - return _init_rank_exchange_output( - local_tensor, plan, rank=rank, accumulate=False, zero_init=False - ) - - def _init_rank_exchange_output( local_tensor: Tensor, plan: GdnCpExchangePlan, @@ -1028,7 +1028,10 @@ def _transfer_dest_positions_for_duplicate_check( dest_count=_dest_count_for_rank(plan, transfer.dest_rank), ): return tuple(range(token_count)) - positions = _transfer_positions_tuple(transfer.dest_positions_tensor) + positions = _transfer_positions_tuple( + transfer.dest_positions_cpu, + transfer.dest_positions_tensor, + ) if len(positions) != token_count: raise ValueError("GDN CP transfer destination positions must match token_count") return positions diff --git a/src/art/megatron/gdn/operator.py b/src/art/megatron/gdn/operator.py index e8a122f5c..3dbed83d9 100644 --- a/src/art/megatron/gdn/operator.py +++ b/src/art/megatron/gdn/operator.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import MethodType -from typing import Any, Callable, Literal, NamedTuple, Sequence, cast +from typing import Any, Callable, Iterable, Literal, NamedTuple, Sequence, cast import torch from torch import Tensor @@ -12,9 +12,9 @@ from .fla_cp import chunk_gated_delta_rule_native_cp from .gdn_shared_prefix import ( GdnPackedExecutionSpec, - GdnParentStateTransferPlan, GdnRankExecutionPlan, GdnSegmentBucketPlan, + GdnStateExchangePlan, build_gdn_rank_execution_plan, parse_gdn_shared_prefix_segments, ) @@ -301,6 +301,7 @@ def _empty_safe_norm_forward( return original_forward(input_, *args, **kwargs) +@torch.compiler.disable def _shared_prefix_forward( self: Any, hidden_states: Tensor, @@ -463,9 +464,7 @@ def run_gdn_layer( ) if execution_spec is None and execution_plan is None: - execution_spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + execution_spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) if ( execution_spec is not None and requested_cp_size == 1 @@ -510,136 +509,396 @@ def run_gdn_layer( ) if input_layout != "attention" or output_layout != "attention": raise ValueError("GDN layout controls require a CP execution plan") - return _run_planned_prefixes_and_completions(gdn, hidden_states, execution_plan) + return _run_tree_prefixes(gdn, hidden_states, execution_plan) -def _run_planned_prefixes_and_completions( +def _run_tree_prefixes( gdn: Any, hidden_states: Tensor, plan: GdnRankExecutionPlan, ) -> tuple[Tensor, Tensor | None]: - if _has_chunk_aligned_local_plan(plan): - return _run_chunk_aligned_prefixes_and_completions(gdn, hidden_states, plan) - raise ValueError( - "shared-prefix GDN requires a chunk-aligned execution plan; " - "prefix/completion bucket execution has been removed" + qkv, gate, beta, recurrent_g = _project_gdn_inputs(gdn, hidden_states) + gate = gate.clone() + recurrent_output = torch.zeros_like(gate) + recurrent_output, _cp_dependency = _run_tree_depth_buckets( + gdn, + qkv, + beta, + recurrent_g, + recurrent_output, + plan, + state_reference=hidden_states, ) + return _project_gdn_output(gdn, recurrent_output, gate, plan) -def _has_chunk_aligned_local_plan(plan: GdnRankExecutionPlan) -> bool: - return bool( - plan.prefix_boundary_buckets - or plan.prefix_tail_buckets - or plan.completion_with_prefix_tail_buckets +def _run_tree_depth_buckets( + gdn: Any, + qkv: Tensor, + beta: Tensor, + recurrent_g: Tensor, + recurrent_output: Tensor, + plan: GdnRankExecutionPlan, + *, + state_reference: Tensor, + group: Any | None = None, + cp_dependency: Tensor | None = None, +) -> tuple[Tensor, Tensor | None]: + state_cache = _TreeStateChunkCache( + device=state_reference.device, ) + for depth, buckets in enumerate(plan.tree_segment_buckets_by_depth): + if depth < len(plan.tree_state_exchanges_by_depth): + cp_dependency = state_cache.exchange_remote_parent_states( + gdn, + plan.tree_state_exchanges_by_depth[depth], + state_reference=state_reference, + rank=plan.cp_rank, + group=group, + cp_dependency=cp_dependency, + ) + if depth < len(plan.tree_chain_buckets_by_depth): + for bucket in plan.tree_chain_buckets_by_depth[depth]: + recurrent_output, cp_dependency = _run_tree_bucket( + gdn, + qkv, + beta, + recurrent_g, + recurrent_output, + state_cache, + bucket, + state_reference=state_reference, + group=group, + cp_dependency=cp_dependency, + recurrent_cp=True, + scale_parent_state_gradient=1.0 / plan.cp_size, + ) + + for bucket in buckets: + recurrent_output, cp_dependency = _run_tree_bucket( + gdn, + qkv, + beta, + recurrent_g, + recurrent_output, + state_cache, + bucket, + state_reference=state_reference, + cp_dependency=cp_dependency, + ) + + return recurrent_output, cp_dependency -def _run_chunk_aligned_prefixes_and_completions( + +def _run_tree_bucket( gdn: Any, - hidden_states: Tensor, - plan: GdnRankExecutionPlan, + qkv: Tensor, + beta: Tensor, + recurrent_g: Tensor, + recurrent_output: Tensor, + state_cache: "_TreeStateChunkCache", + bucket: GdnSegmentBucketPlan, + *, + state_reference: Tensor, + group: Any | None = None, + cp_dependency: Tensor | None = None, + recurrent_cp: bool = False, + scale_parent_state_gradient: float | None = None, ) -> tuple[Tensor, Tensor | None]: - qkv, gate, beta, recurrent_g = _project_gdn_inputs(gdn, hidden_states) - gate = gate.clone() - recurrent_output = torch.zeros_like(gate) - boundary_family_chunks: list[Tensor] = [] - boundary_conv_chunks: list[Tensor] = [] - boundary_rec_chunks: list[Tensor] = [] - - for bucket in plan.prefix_boundary_buckets: - prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket + parent_conv, parent_rec = state_cache.parent_states( + gdn, + bucket, + state_reference=state_reference, + ) + if _bucket_has_parent_state(bucket): + parent_conv, parent_rec = _couple_parent_states(parent_conv, parent_rec) + if scale_parent_state_gradient is not None: + parent_conv = _scale_state_gradient( + parent_conv, + scale_parent_state_gradient, + ) + parent_rec = _scale_state_gradient(parent_rec, scale_parent_state_gradient) + segment_qkv, segment_beta, segment_g = _gather_bucket_streams( + qkv, + beta, + recurrent_g, + bucket, + ) + if cp_dependency is not None: + segment_qkv = _add_autograd_dependency(segment_qkv, cp_dependency) + segment_beta = _add_autograd_dependency(segment_beta, cp_dependency) + segment_g = _add_autograd_dependency(segment_g, cp_dependency) + parent_conv = _add_autograd_dependency(parent_conv, cp_dependency) + parent_rec = _add_autograd_dependency(parent_rec, cp_dependency) + segment_out, segment_conv, segment_rec = run_gdn_bucket( + bucket, + (segment_qkv, segment_beta, segment_g), + (parent_conv, parent_rec), + gdn=gdn, + group=group, + recurrent_cp=recurrent_cp, + output_final_state=bucket.needs_final_state or recurrent_cp, + ) + if bucket.needs_final_state and (segment_conv is None or segment_rec is None): + raise RuntimeError("tree GDN execution must return final states") + if ( + bucket.needs_final_state + and segment_conv is not None + and segment_rec is not None + ): + cp_dependency = _make_autograd_dependency( + segment_out, segment_conv, segment_rec ) - zero_conv = _zero_conv_state( - gdn, hidden_states, batch_size=bucket.segment_count + else: + cp_dependency = _make_autograd_dependency(segment_out) + recurrent_output = _scatter_bucket_recurrent_output( + recurrent_output, + bucket, + segment_out, + ) + if bucket.needs_final_state: + state_cache.append( + bucket, + cast(Tensor, segment_conv), + cast(Tensor, segment_rec), + ) + return recurrent_output, cp_dependency + + +class _TreeStateChunkCache: + def __init__(self, *, device: torch.device) -> None: + self._device = device + self._conv_chunks: list[Tensor] = [] + self._rec_chunks: list[Tensor] = [] + self._source_by_family: list[tuple[int, int] | None] = [] + + def append(self, bucket: GdnSegmentBucketPlan, conv: Tensor, rec: Tensor) -> None: + self.append_families(_bucket_family_indices_cpu(bucket), conv, rec) + + def append_families( + self, family_indices: Sequence[int], conv: Tensor, rec: Tensor + ) -> None: + if len(family_indices) == 0: + return + if int(conv.shape[0]) != len(family_indices): + raise ValueError( + "tree GDN state cache conv batch must match family count, got " + f"{tuple(conv.shape)} and {len(family_indices)} families" + ) + if int(rec.shape[0]) != len(family_indices): + raise ValueError( + "tree GDN state cache recurrent batch must match family count, got " + f"{tuple(rec.shape)} and {len(family_indices)} families" + ) + chunk_index = len(self._conv_chunks) + self._conv_chunks.append(conv) + self._rec_chunks.append(rec) + max_family = max(int(index) for index in family_indices) + if max_family >= len(self._source_by_family): + self._source_by_family.extend( + None for _ in range(max_family + 1 - len(self._source_by_family)) + ) + for source_row, family_index in enumerate(family_indices): + self._source_by_family[int(family_index)] = (chunk_index, source_row) + + def exchange_remote_parent_states( + self, + gdn: Any, + exchange: GdnStateExchangePlan | None, + *, + state_reference: Tensor, + rank: int, + group: Any | None, + cp_dependency: Tensor | None, + ) -> Tensor | None: + if exchange is None: + return cp_dependency + from .layout import exchange_rank_tensor_all_to_all + + source_conv, source_rec = self.states_for_families( + gdn, + exchange.source_family_indices, + state_reference=state_reference, + ) + if cp_dependency is not None: + source_conv = _add_autograd_dependency(source_conv, cp_dependency) + source_rec = _add_autograd_dependency(source_rec, cp_dependency) + remote_conv = exchange_rank_tensor_all_to_all( + source_conv, + exchange.exchange, + rank=rank, + group=group, + backward_plan=exchange.reverse_exchange, ) - zero_rec = _zero_recurrent_state( - gdn, hidden_states, batch_size=bucket.segment_count + remote_rec = exchange_rank_tensor_all_to_all( + source_rec, + exchange.exchange, + rank=rank, + group=group, + backward_plan=exchange.reverse_exchange, ) - prefix_out, prefix_conv, prefix_rec = run_gdn_bucket( - bucket, - (prefix_qkv, prefix_beta, prefix_g), - (zero_conv, zero_rec), - gdn=gdn, - output_final_state=True, - ) - if prefix_conv is None or prefix_rec is None: - raise RuntimeError("prefix boundary GDN execution must return final states") - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, prefix_out - ) - boundary_family_chunks.append(bucket.family_indices) - boundary_conv_chunks.append(prefix_conv) - boundary_rec_chunks.append(prefix_rec) - - boundary_conv_table = _materialize_indexed_family_state_table( - plan=plan, - family_chunks=boundary_family_chunks, - state_chunks=boundary_conv_chunks, - zero_state=_zero_conv_state(gdn, hidden_states, batch_size=plan.family_count), - ) - boundary_rec_table = _materialize_indexed_family_state_table( - plan=plan, - family_chunks=boundary_family_chunks, - state_chunks=boundary_rec_chunks, - zero_state=_zero_recurrent_state( - gdn, hidden_states, batch_size=plan.family_count - ), - ) - - tail_family_chunks: list[Tensor] = [] - tail_conv_chunks: list[Tensor] = [] - tail_rec_chunks: list[Tensor] = [] - for bucket in plan.prefix_tail_buckets: - tail_qkv, tail_beta, tail_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - tail_conv = boundary_conv_table.index_select(0, bucket.family_indices) - tail_rec = boundary_rec_table.index_select(0, bucket.family_indices) - tail_out, tail_conv, tail_rec = run_gdn_bucket( - bucket, - (tail_qkv, tail_beta, tail_g), - (tail_conv, tail_rec), - gdn=gdn, - output_final_state=True, + self.append_families(exchange.dest_family_indices, remote_conv, remote_rec) + dependency = _make_zero_autograd_dependency( + source_conv, source_rec, remote_conv, remote_rec ) - if tail_conv is None or tail_rec is None: - raise RuntimeError("prefix tail GDN execution must return final states") - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, tail_out + return dependency if cp_dependency is None else dependency + cp_dependency + + def states_for_families( + self, + gdn: Any, + family_indices: Sequence[int], + *, + state_reference: Tensor, + ) -> tuple[Tensor, Tensor]: + if len(family_indices) == 0: + conv = _zero_conv_state(gdn, state_reference, batch_size=0) + rec = _zero_recurrent_state(gdn, state_reference, batch_size=0) + return conv.requires_grad_(True), rec.requires_grad_(True) + return self._mixed_parent_states( + gdn, + tuple(int(index) for index in family_indices), + state_reference=state_reference, + batch_size=len(family_indices), + roots_allowed=False, ) - tail_family_chunks.append(bucket.family_indices) - tail_conv_chunks.append(tail_conv) - tail_rec_chunks.append(tail_rec) - prefix_conv_table = _replace_indexed_family_states( - boundary_conv_table, - family_chunks=tail_family_chunks, - state_chunks=tail_conv_chunks, - ) - prefix_rec_table = _replace_indexed_family_states( - boundary_rec_table, - family_chunks=tail_family_chunks, - state_chunks=tail_rec_chunks, - ) + def parent_states( + self, + gdn: Any, + bucket: GdnSegmentBucketPlan, + *, + state_reference: Tensor, + ) -> tuple[Tensor, Tensor]: + parent_indices = bucket.parent_indices + if parent_indices is None: + raise RuntimeError("tree GDN bucket is missing parent indices") + parent_indices_cpu = _bucket_parent_indices_cpu(bucket) + batch_size = bucket.segment_count + if all(parent_index < 0 for parent_index in parent_indices_cpu): + return ( + _zero_conv_state(gdn, state_reference, batch_size=batch_size), + _zero_recurrent_state(gdn, state_reference, batch_size=batch_size), + ) - for bucket in plan.completion_with_prefix_tail_buckets: - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, + return self._mixed_parent_states( + gdn, + parent_indices_cpu, + state_reference=state_reference, + batch_size=batch_size, ) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out + + def _mixed_parent_states( + self, + gdn: Any, + parent_indices_cpu: tuple[int, ...], + *, + state_reference: Tensor, + batch_size: int, + roots_allowed: bool = True, + ) -> tuple[Tensor, Tensor]: + sources_by_chunk: dict[int, list[tuple[int, int]]] = {} + missing_parents: list[int] = [] + for dest_row, parent_index in enumerate(parent_indices_cpu): + if parent_index < 0: + if roots_allowed: + continue + missing_parents.append(parent_index) + continue + source = ( + self._source_by_family[parent_index] + if parent_index < len(self._source_by_family) + else None + ) + if source is None: + missing_parents.append(parent_index) + continue + chunk_index, source_row = source + sources_by_chunk.setdefault(chunk_index, []).append((dest_row, source_row)) + if missing_parents: + raise RuntimeError( + "tree GDN append-only execution is missing parent state for " + f"families {tuple(missing_parents)}" + ) + + single_source_chunk = next(iter(sources_by_chunk.values())) + if len(sources_by_chunk) == 1 and len(single_source_chunk) == batch_size: + chunk_index, pairs = next(iter(sources_by_chunk.items())) + return ( + _select_state_rows(self._conv_chunks[chunk_index], pairs), + _select_state_rows(self._rec_chunks[chunk_index], pairs), + ) + + conv = _zero_conv_state(gdn, state_reference, batch_size=batch_size) + rec = _zero_recurrent_state(gdn, state_reference, batch_size=batch_size) + for chunk_index, pairs in sources_by_chunk.items(): + dest_rows = _long_tensor( + (dest_row for dest_row, _ in pairs), + device=self._device, + ) + source_rows = _long_tensor( + (source_row for _, source_row in pairs), + device=self._device, + ) + conv = conv.index_copy( + 0, + dest_rows, + self._conv_chunks[chunk_index].index_select(0, source_rows), + ) + rec = rec.index_copy( + 0, + dest_rows, + self._rec_chunks[chunk_index].index_select(0, source_rows), + ) + return conv, rec + + +def _select_state_rows(chunk: Tensor, pairs: Sequence[tuple[int, int]]) -> Tensor: + source_rows = tuple(source_row for _, source_row in pairs) + if len(set(source_rows)) == 1: + return chunk.narrow(0, source_rows[0], 1).expand( + len(source_rows), + *tuple(chunk.shape[1:]), ) - return _project_gdn_output(gdn, recurrent_output, gate, plan) + first_row = source_rows[0] + if source_rows == tuple(range(first_row, first_row + len(source_rows))): + return chunk.narrow(0, first_row, len(source_rows)) + return chunk.index_select( + 0, + _long_tensor(source_rows, device=chunk.device), + ) + + +def _bucket_family_indices_cpu(bucket: GdnSegmentBucketPlan) -> tuple[int, ...]: + family_indices = bucket.family_indices_cpu + if family_indices is None: + family_indices = bucket.family_indices.detach().cpu() + return tuple(int(index) for index in family_indices.tolist()) + + +def _bucket_parent_indices_cpu(bucket: GdnSegmentBucketPlan) -> tuple[int, ...]: + parent_indices = bucket.parent_indices + if parent_indices is None: + raise RuntimeError("tree GDN bucket is missing parent indices") + parent_indices_cpu = bucket.parent_indices_cpu + if parent_indices_cpu is None: + parent_indices_cpu = parent_indices.detach().cpu() + return tuple(int(index) for index in parent_indices_cpu.tolist()) + + +def _long_tensor(values: Iterable[int], *, device: torch.device) -> Tensor: + return torch.tensor(tuple(values), dtype=torch.long, device=device) + + +def _bucket_has_parent_state(bucket: GdnSegmentBucketPlan) -> bool: + return any(parent_index >= 0 for parent_index in _bucket_parent_indices_cpu(bucket)) + + +def _bucket_has_uniform_lengths(bucket: GdnSegmentBucketPlan) -> bool: + lengths_cpu = bucket.lengths_cpu + if lengths_cpu is None: + lengths_cpu = bucket.lengths.detach().cpu() + return all(int(length) == int(bucket.length) for length in lengths_cpu.tolist()) def _run_cp_planned_prefixes_and_completions( @@ -679,385 +938,21 @@ def _run_cp_planned_prefixes_and_completions( if empty_gdn_rank else _empty_autograd_dependency(qkv) ) - qkv_with_remote_tail = qkv - beta_with_remote_tail = beta - recurrent_g_with_remote_tail = recurrent_g - if plan.remote_prefix_tail_exchange is not None: - remote_qkv, remote_beta, remote_g = _exchange_remote_prefix_tail_streams( - qkv, - beta, - recurrent_g, - plan=plan, - group=group, - ) - qkv_with_remote_tail = torch.cat([qkv, remote_qkv.unsqueeze(0)], dim=1) - beta_with_remote_tail = torch.cat([beta, remote_beta.unsqueeze(0)], dim=1) - recurrent_g_with_remote_tail = torch.cat( - [recurrent_g, remote_g.unsqueeze(0)], dim=1 - ) - cp_dependency = cp_dependency + _make_zero_autograd_dependency( - remote_qkv, remote_beta, remote_g - ) + if not plan.tree_segment_buckets_by_depth: + raise ValueError("CP shared-prefix GDN requires a tree execution plan") gate = gate.clone() recurrent_output = torch.zeros_like(gate) - prefix_family_chunks: list[Tensor] = [] - prefix_conv_chunks: list[Tensor] = [] - prefix_rec_chunks: list[Tensor] = [] - - for bucket in plan.chain_prefix_buckets: - prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - zero_conv = _zero_conv_state(gdn, qkv, batch_size=bucket.segment_count) - zero_rec = _zero_recurrent_state(gdn, qkv, batch_size=bucket.segment_count) - prefix_out, prefix_conv, prefix_rec = run_gdn_bucket( - bucket, - (prefix_qkv, prefix_beta, prefix_g), - (zero_conv, zero_rec), - gdn=gdn, - group=group, - recurrent_cp=True, - output_final_state=True, - ) - if prefix_conv is None or prefix_rec is None: - raise RuntimeError("CP prefix GDN execution must return final states") - prefix_out = _add_autograd_dependency(prefix_out, cp_dependency) - prefix_conv = _add_autograd_dependency(prefix_conv, cp_dependency) - prefix_rec = _add_autograd_dependency(prefix_rec, cp_dependency) - cp_dependency = _make_autograd_dependency(prefix_out, prefix_conv, prefix_rec) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, prefix_out - ) - prefix_family_chunks.append(bucket.family_indices) - prefix_conv_chunks.append(prefix_conv) - prefix_rec_chunks.append(prefix_rec) - - boundary_family_chunks: list[Tensor] = [] - boundary_conv_chunks: list[Tensor] = [] - boundary_rec_chunks: list[Tensor] = [] - for bucket in plan.prefix_boundary_buckets: - prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - zero_conv = _zero_conv_state(gdn, qkv, batch_size=bucket.segment_count) - zero_rec = _zero_recurrent_state(gdn, qkv, batch_size=bucket.segment_count) - prefix_out, prefix_conv, prefix_rec = run_gdn_bucket( - bucket, - (prefix_qkv, prefix_beta, prefix_g), - (zero_conv, zero_rec), - gdn=gdn, - output_final_state=True, - ) - if prefix_conv is None or prefix_rec is None: - raise RuntimeError("local prefix GDN execution must return final states") - prefix_out = _add_autograd_dependency(prefix_out, cp_dependency) - prefix_conv = _add_autograd_dependency(prefix_conv, cp_dependency) - prefix_rec = _add_autograd_dependency(prefix_rec, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, prefix_out - ) - boundary_family_chunks.append(bucket.family_indices) - boundary_conv_chunks.append(prefix_conv) - boundary_rec_chunks.append(prefix_rec) - prefix_family_chunks.append(bucket.family_indices) - prefix_conv_chunks.append(prefix_conv) - prefix_rec_chunks.append(prefix_rec) - - if ( - plan.prefix_tail_buckets - or plan.remote_prefix_tail_buckets - or plan.completion_with_prefix_tail_buckets - or plan.remote_completion_with_prefix_tail_buckets - or plan.remote_prefix_tail_state_transfers - ): - boundary_conv_table = _materialize_indexed_family_state_table( - plan=plan, - family_chunks=boundary_family_chunks, - state_chunks=boundary_conv_chunks, - zero_state=_zero_conv_state(gdn, qkv, batch_size=plan.family_count), - ) - boundary_rec_table = _materialize_indexed_family_state_table( - plan=plan, - family_chunks=boundary_family_chunks, - state_chunks=boundary_rec_chunks, - zero_state=_zero_recurrent_state(gdn, qkv, batch_size=plan.family_count), - ) - remote_boundary_conv_table = boundary_conv_table - remote_boundary_rec_table = boundary_rec_table - if plan.remote_prefix_tail_state_transfers: - ( - remote_boundary_conv_table, - remote_boundary_rec_table, - remote_boundary_dependency, - ) = _exchange_parent_state_rows( - boundary_conv_table, - boundary_rec_table, - transfers=plan.remote_prefix_tail_state_transfers, - group=group, - ) - cp_dependency = cp_dependency + remote_boundary_dependency - tail_family_chunks: list[Tensor] = [] - tail_conv_chunks: list[Tensor] = [] - tail_rec_chunks: list[Tensor] = [] - for bucket in plan.prefix_tail_buckets: - tail_qkv, tail_beta, tail_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - tail_conv = boundary_conv_table.index_select(0, bucket.family_indices) - tail_rec = boundary_rec_table.index_select(0, bucket.family_indices) - tail_out, tail_conv, tail_rec = run_gdn_bucket( - bucket, - (tail_qkv, tail_beta, tail_g), - (tail_conv, tail_rec), - gdn=gdn, - output_final_state=True, - ) - if tail_conv is None or tail_rec is None: - raise RuntimeError("local prefix tail GDN execution must return states") - tail_out = _add_autograd_dependency(tail_out, cp_dependency) - tail_conv = _add_autograd_dependency(tail_conv, cp_dependency) - tail_rec = _add_autograd_dependency(tail_rec, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, tail_out - ) - tail_family_chunks.append(bucket.family_indices) - tail_conv_chunks.append(tail_conv) - tail_rec_chunks.append(tail_rec) - prefix_family_chunks.append(bucket.family_indices) - prefix_conv_chunks.append(tail_conv) - prefix_rec_chunks.append(tail_rec) - for bucket in plan.remote_prefix_tail_buckets: - tail_qkv, tail_beta, tail_g = _gather_bucket_streams( - qkv_with_remote_tail, - beta_with_remote_tail, - recurrent_g_with_remote_tail, - bucket, - ) - tail_conv = remote_boundary_conv_table.index_select( - 0, bucket.family_indices - ) - tail_rec = remote_boundary_rec_table.index_select(0, bucket.family_indices) - tail_out, tail_conv, tail_rec = run_gdn_bucket( - bucket, - (tail_qkv, tail_beta, tail_g), - (tail_conv, tail_rec), - gdn=gdn, - output_final_state=True, - ) - if tail_conv is None or tail_rec is None: - raise RuntimeError( - "remote prefix tail GDN execution must return states" - ) - tail_out = _add_autograd_dependency(tail_out, cp_dependency) - tail_conv = _add_autograd_dependency(tail_conv, cp_dependency) - tail_rec = _add_autograd_dependency(tail_rec, cp_dependency) - tail_family_chunks.append(bucket.family_indices) - tail_conv_chunks.append(tail_conv) - tail_rec_chunks.append(tail_rec) - prefix_family_chunks.append(bucket.family_indices) - prefix_conv_chunks.append(tail_conv) - prefix_rec_chunks.append(tail_rec) - prefix_conv_table = _replace_indexed_family_states( - boundary_conv_table, - family_chunks=tail_family_chunks, - state_chunks=tail_conv_chunks, - ) - prefix_rec_table = _replace_indexed_family_states( - boundary_rec_table, - family_chunks=tail_family_chunks, - state_chunks=tail_rec_chunks, - ) - for bucket in plan.completion_with_prefix_tail_buckets: - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, completion_rec - ) - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - for bucket in plan.remote_completion_with_prefix_tail_buckets: - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, completion_rec - ) - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, - beta, - recurrent_g, - bucket, - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - - for bucket in plan.local_prefix_buckets: - prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - zero_conv = _zero_conv_state(gdn, qkv, batch_size=bucket.segment_count) - zero_rec = _zero_recurrent_state(gdn, qkv, batch_size=bucket.segment_count) - prefix_out, prefix_conv, prefix_rec = run_gdn_bucket( - bucket, - (prefix_qkv, prefix_beta, prefix_g), - (zero_conv, zero_rec), - gdn=gdn, - output_final_state=True, - ) - if prefix_conv is None or prefix_rec is None: - raise RuntimeError("local prefix GDN execution must return final states") - prefix_out = _add_autograd_dependency(prefix_out, cp_dependency) - prefix_conv = _add_autograd_dependency(prefix_conv, cp_dependency) - prefix_rec = _add_autograd_dependency(prefix_rec, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, prefix_out - ) - prefix_family_chunks.append(bucket.family_indices) - prefix_conv_chunks.append(prefix_conv) - prefix_rec_chunks.append(prefix_rec) - - if not prefix_conv_chunks and not plan.parent_state_exchange_family_indices: - projected, out_bias = _project_cp_gdn_output( - gdn, - recurrent_output, - gate, - plan, - group=group, - output_layout=output_layout, - ) - projected = _add_autograd_dependency(projected, cp_dependency) - return projected, out_bias - - prefix_conv_table = _materialize_ordered_family_state_table( - family_chunks=prefix_family_chunks, - state_chunks=prefix_conv_chunks, - zero_state=_zero_conv_state(gdn, qkv, batch_size=plan.family_count), - ) - prefix_rec_table = _materialize_ordered_family_state_table( - family_chunks=prefix_family_chunks, - state_chunks=prefix_rec_chunks, - zero_state=_zero_recurrent_state(gdn, qkv, batch_size=plan.family_count), - ) - parent_state_exchanged = False - if plan.chain_completion_buckets and plan.parent_state_exchange_family_indices: - if not plan.parent_state_transfers: - raise ValueError("CP parent-state exchange requires planned transfers") - prefix_conv_table, prefix_rec_table, exchange_dependency = ( - _exchange_parent_state_rows( - prefix_conv_table, - prefix_rec_table, - transfers=plan.parent_state_transfers, - group=group, - ) - ) - cp_dependency = cp_dependency + exchange_dependency - parent_state_exchanged = True - for bucket in plan.chain_completion_buckets: - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, completion_rec - ) - completion_conv = _scale_state_gradient(completion_conv, 1.0 / plan.cp_size) - completion_rec = _scale_state_gradient(completion_rec, 1.0 / plan.cp_size) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - group=group, - recurrent_cp=True, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - cp_dependency = _make_autograd_dependency(completion_out) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - - ready_completion_buckets = ( - plan.ready_local_completion_buckets - if plan.ready_local_completion_buckets or plan.remote_local_completion_buckets - else plan.local_completion_buckets + recurrent_output, cp_dependency = _run_tree_depth_buckets( + gdn, + qkv, + beta, + recurrent_g, + recurrent_output, + plan, + state_reference=qkv, + group=group, + cp_dependency=cp_dependency, ) - for bucket in ready_completion_buckets: - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, completion_rec - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - - if plan.parent_state_exchange_family_indices and not parent_state_exchanged: - if not plan.parent_state_transfers: - raise ValueError("CP parent-state exchange requires planned transfers") - prefix_conv_table, prefix_rec_table, exchange_dependency = ( - _exchange_parent_state_rows( - prefix_conv_table, - prefix_rec_table, - transfers=plan.parent_state_transfers, - group=group, - ) - ) - cp_dependency = cp_dependency + exchange_dependency - - for bucket in plan.remote_local_completion_buckets: - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, completion_rec - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - projected, out_bias = _project_cp_gdn_output( gdn, recurrent_output, @@ -1065,8 +960,8 @@ def _run_cp_planned_prefixes_and_completions( plan, group=group, output_layout=output_layout, + dependency=cp_dependency, ) - projected = _add_autograd_dependency(projected, cp_dependency) return projected, out_bias @@ -1659,12 +1554,6 @@ def _local_layout_token_count_for_hidden( return (real_count + _tp_world_size(projection) - 1) // _tp_world_size(projection) -def _attention_original_shape_from_plan( - hidden_states: Tensor, plan: GdnRankExecutionPlan -) -> tuple[int, int, int]: - return (int(plan.attention_token_count), 1, int(hidden_states.shape[-1])) - - def _restore_hidden_from_cp_flat( flat: Tensor, original_shape: tuple[int, int, int] ) -> Tensor: @@ -1922,6 +1811,7 @@ def _project_cp_gdn_output( *, group: Any, output_layout: Literal["attention", "gdn"], + dependency: Tensor | None = None, ) -> tuple[Tensor, Tensor | None]: batch_size, seq_len, _, _ = recurrent_output.shape token_uids = ( @@ -1933,6 +1823,8 @@ def _project_cp_gdn_output( norm_out = _apply_gated_rms_norm(gdn, recurrent_output, gate) norm_out = norm_out.reshape(batch_size, seq_len, _local_value_dim(gdn)) norm_out = norm_out.transpose(0, 1).contiguous() + if dependency is not None: + norm_out = _add_autograd_dependency(norm_out, dependency) if token_uids is not None: token_uids = _replicated_layout_token_uids(plan, "gdn", hidden_states=norm_out) _attach_trace_token_uids(norm_out, token_uids) @@ -2271,6 +2163,36 @@ def _local_value_dim(gdn: Any) -> int: return _local_value_heads(gdn) * int(gdn.value_head_dim) +def _prepare_dense_recurrent_inputs( + qkv: Tensor, + beta: Tensor, + recurrent_g: Tensor, + *, + key_heads: int, + value_heads: int, + key_dim: int, + value_dim: int, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + key_channels = int(key_heads) * int(key_dim) + value_channels = int(value_heads) * int(value_dim) + query = qkv[..., :key_channels].reshape(*qkv.shape[:2], key_heads, key_dim) + key = qkv[..., key_channels : 2 * key_channels].reshape( + *qkv.shape[:2], + key_heads, + key_dim, + ) + value = qkv[..., 2 * key_channels : 2 * key_channels + value_channels].reshape( + *qkv.shape[:2], + value_heads, + value_dim, + ) + repeat = int(value_heads) // int(key_heads) + if repeat != 1: + query = query.repeat_interleave(repeat, dim=2) + key = key.repeat_interleave(repeat, dim=2) + return query, key, value, beta, recurrent_g + + def _scatter_bucket_recurrent_output( output: Tensor, bucket: GdnSegmentBucketPlan, bucket_output: Tensor ) -> Tensor: @@ -2289,269 +2211,6 @@ def _bucket_output_mask(bucket: GdnSegmentBucketPlan) -> Tensor: return bucket.real_mask if output_mask is None else output_mask -def _materialize_indexed_family_state_table( - *, - plan: GdnRankExecutionPlan, - family_chunks: list[Tensor], - state_chunks: list[Tensor], - zero_state: Tensor, -) -> Tensor: - table = zero_state.detach() - if not state_chunks: - return table.requires_grad_(True) - values = torch.cat(state_chunks, dim=0) - family_indices = torch.cat(family_chunks, dim=0) - return table.index_copy(0, family_indices, values) - - -def _materialize_ordered_family_state_table( - *, - family_chunks: list[Tensor], - state_chunks: list[Tensor], - zero_state: Tensor, -) -> Tensor: - if len(family_chunks) != len(state_chunks): - raise RuntimeError("family and state chunk counts must match") - table = zero_state.detach().requires_grad_(True) - for family_indices, states in zip(family_chunks, state_chunks, strict=True): - table = table.index_copy(0, family_indices, states) - return table - - -def _replace_indexed_family_states( - table: Tensor, - *, - family_chunks: list[Tensor], - state_chunks: list[Tensor], -) -> Tensor: - if not state_chunks: - return table - return table.index_copy( - 0, - torch.cat(family_chunks, dim=0), - torch.cat(state_chunks, dim=0), - ) - - -def _exchange_parent_state_rows( - conv_table: Tensor, - rec_table: Tensor, - *, - transfers: tuple[GdnParentStateTransferPlan, ...], - group: Any, -) -> tuple[Tensor, Tensor, Tensor]: - if not transfers: - return conv_table, rec_table, _empty_autograd_dependency(conv_table) - conv_table, rec_table = _ParentStateExchange.apply( - conv_table, rec_table, transfers, group - ) - return conv_table, rec_table, _make_autograd_dependency(conv_table, rec_table) - - -def _exchange_remote_prefix_tail_streams( - qkv: Tensor, - beta: Tensor, - recurrent_g: Tensor, - *, - plan: GdnRankExecutionPlan, - group: Any, -) -> tuple[Tensor, Tensor, Tensor]: - from .layout import exchange_rank_tensor_all_to_all - - if plan.remote_prefix_tail_exchange is None: - return ( - qkv.new_empty((0, int(qkv.shape[-1]))), - beta.new_empty((0, int(beta.shape[-1]))), - recurrent_g.new_empty((0, int(recurrent_g.shape[-1]))), - ) - if plan.remote_prefix_tail_backward_exchange is None: - raise ValueError("remote prefix-tail exchange requires a backward plan") - qkv_flat = qkv.reshape(-1, int(qkv.shape[-1])) - beta_flat = beta.reshape(-1, int(beta.shape[-1])) - g_flat = recurrent_g.reshape(-1, int(recurrent_g.shape[-1])) - kwargs = { - "plan": plan.remote_prefix_tail_exchange, - "rank": plan.cp_rank, - "group": group, - "backward_plan": plan.remote_prefix_tail_backward_exchange, - } - return ( - exchange_rank_tensor_all_to_all(qkv_flat, **kwargs), - exchange_rank_tensor_all_to_all(beta_flat, **kwargs), - exchange_rank_tensor_all_to_all(g_flat, **kwargs), - ) - - -class _ParentStateExchange(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - conv_table: Tensor, - rec_table: Tensor, - transfers: tuple[GdnParentStateTransferPlan, ...], - group: Any, - ) -> tuple[Tensor, Tensor]: - ctx.group = group - ctx.transfers = transfers - ctx.save_for_backward(conv_table, rec_table) - return ( - _exchange_parent_state_tensor_forward( - conv_table, - transfers, - group=group, - ), - _exchange_parent_state_tensor_forward( - rec_table, - transfers, - group=group, - ), - ) - - @staticmethod - def backward( - ctx: Any, *grad_outputs: Tensor | None - ) -> tuple[Tensor | None, Tensor | None, None, None]: - grad_conv, grad_rec = grad_outputs - conv_ref, rec_ref = ctx.saved_tensors - return ( - _exchange_parent_state_tensor_backward( - _zero_if_none(grad_conv, conv_ref), - ctx.transfers, - group=ctx.group, - ), - _exchange_parent_state_tensor_backward( - _zero_if_none(grad_rec, rec_ref), - ctx.transfers, - group=ctx.group, - ), - None, - None, - ) - - -def _exchange_parent_state_tensor_forward( - table: Tensor, - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - group: Any, -) -> Tensor: - rank = torch.distributed.get_rank(group) # ty: ignore[possibly-missing-attribute] - output = table.clone() - recvs = _exchange_parent_state_rows_all_to_all( - table, transfers, rank=rank, reverse=False, group=group - ) - for transfer, rows in recvs: - index = _parent_state_index_tensor(transfer, device=table.device) - output.index_copy_(0, index, rows) - return output - - -def _exchange_parent_state_tensor_backward( - grad_output: Tensor, - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - group: Any, -) -> Tensor: - rank = torch.distributed.get_rank(group) # ty: ignore[possibly-missing-attribute] - grad_input = grad_output.clone() - for transfer in transfers: - if transfer.dest_rank != rank: - continue - index = _parent_state_index_tensor(transfer, device=grad_output.device) - grad_input.index_fill_(0, index, 0) - recvs = _exchange_parent_state_rows_all_to_all( - grad_output, transfers, rank=rank, reverse=True, group=group - ) - for transfer, rows in recvs: - index = _parent_state_index_tensor(transfer, device=grad_output.device) - grad_input.index_add_(0, index, rows) - return grad_input - - -def _zero_if_none(grad: Tensor | None, reference: Tensor) -> Tensor: - if grad is None: - return reference.new_zeros(reference.shape) - return grad.contiguous() - - -def _exchange_parent_state_rows_all_to_all( - table: Tensor, - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - rank: int, - reverse: bool, - group: Any, -) -> list[tuple[GdnParentStateTransferPlan, Tensor]]: - world_size = torch.distributed.get_world_size(group) # ty: ignore[possibly-missing-attribute] - send_counts = [0 for _ in range(world_size)] - recv_counts = [0 for _ in range(world_size)] - send_pieces: list[Tensor] = [] - for peer_rank in range(world_size): - for transfer in transfers: - send_rank = transfer.dest_rank if reverse else transfer.source_rank - recv_rank = transfer.source_rank if reverse else transfer.dest_rank - if send_rank == recv_rank: - continue - row_count = len(transfer.family_indices) - if rank == send_rank and peer_rank == recv_rank: - index = _parent_state_index_tensor(transfer, device=table.device) - send_pieces.append(table.index_select(0, index).contiguous()) - send_counts[peer_rank] += row_count - if rank == recv_rank and peer_rank == send_rank: - recv_counts[peer_rank] += row_count - - trailing_shape = tuple(table.shape[1:]) - send_buffer = ( - torch.cat(send_pieces, dim=0) - if send_pieces - else table.new_empty((0, *trailing_shape)) - ) - recv_buffer = table.new_empty((sum(recv_counts), *trailing_shape)) - work = torch.distributed.all_to_all_single( # ty: ignore[possibly-missing-attribute] - recv_buffer, - send_buffer, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=group, - async_op=True, - ) - work.wait() - - recvs: list[tuple[GdnParentStateTransferPlan, Tensor]] = [] - offset = 0 - for peer_rank, count in enumerate(recv_counts): - peer_end = offset + count - for transfer in transfers: - send_rank = transfer.dest_rank if reverse else transfer.source_rank - recv_rank = transfer.source_rank if reverse else transfer.dest_rank - if send_rank == recv_rank: - continue - if rank != recv_rank or peer_rank != send_rank: - continue - rows = len(transfer.family_indices) - recvs.append((transfer, recv_buffer[offset : offset + rows])) - offset += rows - if offset != peer_end: - raise RuntimeError( - "parent-state exchange unpack mismatch: " - f"rank={rank} peer={peer_rank} consumed={offset} expected={peer_end}" - ) - return recvs - - -def _parent_state_index_tensor( - transfer: GdnParentStateTransferPlan, - *, - device: torch.device, -) -> Tensor: - if ( - transfer.family_indices_tensor is not None - and transfer.family_indices_tensor.device == device - ): - return transfer.family_indices_tensor - return torch.tensor(transfer.family_indices, device=device, dtype=torch.long) - - def run_gdn_bucket( bucket: GdnSegmentBucketPlan, projected_streams: tuple[Tensor, Tensor, Tensor], @@ -2597,14 +2256,17 @@ def run_gdn_bucket( conv_output_final_state = output_final_state chain_conv_final: Tensor | None = None + chain_gradient_dependency: Tensor | None = None if recurrent_cp: - conv_initial, chain_conv_final = _chain_conv_initial_and_final( - qkv, - bucket.cu_seqlens_cpu, - bucket.lengths_by_rank_cpu, - conv_initial, - group=group, - output_final_state=output_final_state, + conv_initial, chain_conv_final, chain_gradient_dependency = ( + _chain_conv_initial_and_final( + qkv, + bucket.cu_seqlens_cpu, + bucket.lengths_by_rank_cpu, + conv_initial, + group=group, + output_final_state=output_final_state, + ) ) conv_output_final_state = False @@ -2618,15 +2280,31 @@ def run_gdn_bucket( if recurrent_cp: conv_final = chain_conv_final - query, key, value, beta, recurrent_g = _prepare_packed_recurrent_inputs_fused( - qkv, - beta, - recurrent_g, - key_heads=_local_key_heads(gdn), - value_heads=_local_value_heads(gdn), - key_dim=int(gdn.key_head_dim), - value_dim=int(gdn.value_head_dim), - ) + dense_local_bucket = not recurrent_cp and _bucket_has_uniform_lengths(bucket) + if dense_local_bucket: + query, key, value, beta, recurrent_g = _prepare_dense_recurrent_inputs( + qkv.reshape(batch_size, int(bucket.length), int(qkv.shape[-1])), + beta.reshape(batch_size, int(bucket.length), int(beta.shape[-1])), + recurrent_g.reshape( + batch_size, + int(bucket.length), + int(recurrent_g.shape[-1]), + ), + key_heads=_local_key_heads(gdn), + value_heads=_local_value_heads(gdn), + key_dim=int(gdn.key_head_dim), + value_dim=int(gdn.value_head_dim), + ) + else: + query, key, value, beta, recurrent_g = _prepare_packed_recurrent_inputs_fused( + qkv, + beta, + recurrent_g, + key_heads=_local_key_heads(gdn), + value_heads=_local_value_heads(gdn), + key_dim=int(gdn.key_head_dim), + value_dim=int(gdn.value_head_dim), + ) if gdn.use_qk_l2norm: query = _l2norm(query.contiguous()) key = _l2norm(key.contiguous()) @@ -2657,8 +2335,27 @@ def run_gdn_bucket( initial_state=recurrent_initial, output_final_state=output_final_state, use_qk_l2norm_in_kernel=False, - cu_seqlens=bucket.cu_seqlens, - ) + cu_seqlens=None if dense_local_bucket else bucket.cu_seqlens, + ) + if dense_local_bucket: + recurrent_out = recurrent_out.reshape( + 1, + token_count, + int(recurrent_out.shape[-2]), + int(recurrent_out.shape[-1]), + ) + if chain_gradient_dependency is not None: + recurrent_out = _add_autograd_dependency( + recurrent_out, + chain_gradient_dependency, + ) + if conv_final is not None: + conv_final = _add_autograd_dependency(conv_final, chain_gradient_dependency) + if recurrent_final is not None: + recurrent_final = _add_autograd_dependency( + recurrent_final, + chain_gradient_dependency, + ) return recurrent_out, conv_final, recurrent_final @@ -2670,15 +2367,22 @@ def _chain_conv_initial_and_final( *, group: Any, output_final_state: bool, -) -> tuple[Tensor, Tensor | None]: +) -> tuple[Tensor, Tensor | None, Tensor]: if group is None: raise ValueError("CP chain conv state requires a process group") if not dist.is_available() or not dist.is_initialized(): # ty: ignore[possibly-missing-attribute] raise RuntimeError("torch.distributed must be initialized for CP chain conv") - parent_initial = _AllReduceGradient.apply(parent_initial, group) + parent_initial, gradient_dependency = _AllReduceGradient.apply( + parent_initial, + group, + ) tail_width = int(parent_initial.shape[-1]) if tail_width <= 0: - return parent_initial, parent_initial if output_final_state else None + return ( + parent_initial, + parent_initial if output_final_state else None, + gradient_dependency, + ) if lengths_by_rank_cpu is None: raise ValueError("CP chain conv requires static all-rank bucket lengths") if cu_seqlens_cpu.device.type != "cpu" or lengths_by_rank_cpu.device.type != "cpu": @@ -2705,7 +2409,7 @@ def _chain_conv_initial_and_final( if output_final_state else None ) - return conv_initial, conv_final + return conv_initial, conv_final, gradient_dependency def _local_packed_conv_tail( @@ -2782,14 +2486,20 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> tuple[Tensor, None]: class _AllReduceGradient(torch.autograd.Function): @staticmethod - def forward(ctx: Any, tensor: Tensor, group: Any) -> Tensor: + def forward(ctx: Any, tensor: Tensor, group: Any) -> tuple[Tensor, Tensor]: ctx.group = group - return tensor + ctx.save_for_backward(tensor) + return tensor, tensor.new_zeros(()) @staticmethod - def backward(ctx: Any, *grad_outputs: Tensor) -> tuple[Tensor, None]: - (grad_output,) = grad_outputs - grad_input = grad_output.contiguous() + def backward(ctx: Any, *grad_outputs: Tensor | None) -> tuple[Tensor, None]: + grad_output, _grad_dependency = grad_outputs + (reference,) = ctx.saved_tensors + grad_input = ( + reference.new_zeros(reference.shape) + if grad_output is None + else grad_output.contiguous() + ) dist.all_reduce( # ty: ignore[possibly-missing-attribute] grad_input, op=dist.ReduceOp.SUM, # ty: ignore[possibly-missing-attribute] diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index 4cea46b2a..62e7c8435 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -1,9 +1,14 @@ -from collections.abc import Sequence +from collections.abc import Iterator, Sequence +from contextlib import contextmanager +import contextvars +from dataclasses import dataclass, replace +import functools +import importlib import json import math import os import re -from typing import Any, Literal, NamedTuple, cast +from typing import Any, Callable, Literal, NamedTuple, TypeVar, cast from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.core import parallel_state as ps @@ -22,9 +27,7 @@ ) from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.moe.experts import TEGroupedMLP -from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.transformer_layer import TransformerLayer -from pydantic import BaseModel, ConfigDict import torch from .kernels.cute_grouped_lora_quack import ( @@ -42,6 +45,8 @@ ShardDomain = Literal["tp", "expert_tp"] GradSyncDomain = Literal["tp_default", "expert_tp"] GradSyncOp = Literal["none", "sum", "avg"] +LoraSlotKind = Literal["checkpoint", "lora"] +_F = TypeVar("_F", bound=Callable[..., Any]) TP_DEFAULT_GRAD_SYNC_DOMAIN: GradSyncDomain = "tp_default" EXPERT_TP_GRAD_SYNC_DOMAIN: GradSyncDomain = "expert_tp" @@ -50,11 +55,114 @@ GRAD_SYNC_OP_AVG: GradSyncOp = "avg" -class LoRAParallelSpec(BaseModel): - # This spec only describes TP / expert-TP behavior. - # DP/CP vs expert-DP behavior is selected separately via `allreduce`. - model_config = ConfigDict(frozen=True) +@dataclass(frozen=True) +class LoRASlotRef: + kind: LoraSlotKind + name: str | None + +_CURRENT_LORA_SLOT: contextvars.ContextVar[LoRASlotRef | None] = contextvars.ContextVar( + "art_megatron_current_lora_slot", default=None +) + + +@contextmanager +def use_lora_slot(ref: LoRASlotRef | None) -> Iterator[None]: + token = _CURRENT_LORA_SLOT.set(ref) + try: + yield + finally: + _CURRENT_LORA_SLOT.reset(token) + + +def _with_captured_lora_slot(function: _F) -> _F: + context = _CURRENT_LORA_SLOT.get() + + @functools.wraps(function) + def wrapped(*args: Any, **kwargs: Any) -> Any: + token = _CURRENT_LORA_SLOT.set(context) + try: + return function(*args, **kwargs) + finally: + _CURRENT_LORA_SLOT.reset(token) + + return cast(_F, wrapped) + + +def _patch_function_once(module: Any, name: str, wrapper: Callable[[_F], _F]) -> None: + original = getattr(module, name, None) + if original is None or getattr(original, "_art_lora_slot_context_patch", False): + return + patched = wrapper(original) + setattr(patched, "_art_lora_slot_context_patch", True) + setattr(module, name, patched) + + +def install_lora_checkpoint_context_hooks() -> None: + """Preserve the selected dynamic LoRA slot across activation recompute.""" + + def wrap_checkpoint(original: _F, function_index: int) -> _F: + @functools.wraps(original) + def checkpoint(*args: Any, **kwargs: Any) -> Any: + if len(args) > function_index: + args = ( + *args[:function_index], + _with_captured_lora_slot(args[function_index]), + *args[function_index + 1 :], + ) + elif "function" in kwargs: + kwargs = { + **kwargs, + "function": _with_captured_lora_slot(kwargs["function"]), + } + elif "forward_func" in kwargs: + kwargs = { + **kwargs, + "forward_func": _with_captured_lora_slot(kwargs["forward_func"]), + } + else: + raise TypeError("checkpoint wrapper could not find callable argument") + return original(*args, **kwargs) + + return cast(_F, checkpoint) + + def patch(target: str, name: str, function_index: int) -> None: + try: + module_name, _, attr_path = target.partition(":") + target_obj = importlib.import_module(module_name) + for attr in attr_path.split(".") if attr_path else (): + target_obj = getattr(target_obj, attr, None) + if target_obj is None: + return + _patch_function_once( + target_obj, + name, + lambda original: wrap_checkpoint(original, function_index), + ) + except Exception: + pass + + for target, name, function_index in ( + ("torch.utils.checkpoint", "checkpoint", 0), + ("megatron.core.tensor_parallel", "checkpoint", 0), + ("megatron.core.tensor_parallel.random", "checkpoint", 0), + ( + "megatron.core.tensor_parallel.random:CheckpointWithoutOutput", + "checkpoint", + 1, + ), + ("megatron.core.transformer.transformer_block", "te_checkpoint", 0), + ("transformer_engine.pytorch.distributed", "checkpoint", 0), + ): + patch(target, name, function_index) + + +install_lora_checkpoint_context_hooks() + + +@dataclass(frozen=True) +class LoRAParallelSpec: + # This only describes TP / expert-TP; DP/CP vs expert-DP is selected by `allreduce`. shard_domain: ShardDomain = "tp" sharded: bool = False shard_dim: int | None = None @@ -72,10 +180,7 @@ class LoraShardMeta(NamedTuple): @property def numel(self) -> int: - total = 1 - for dim in self.shape: - total *= dim - return total + return math.prod(self.shape) class _LoraPublishTemplate(NamedTuple): @@ -123,14 +228,6 @@ def _get_shard_rank(domain: ShardDomain) -> int: return group.rank() -def _get_shard_group(domain: ShardDomain) -> Any | None: - if not _distributed_initialized(): - return None - if domain == "tp": - return ps.get_tensor_model_parallel_group() - return ps.get_expert_tensor_parallel_group(check_initialized=False) - - def _dtype_name(dtype: torch.dtype) -> str: return str(dtype).removeprefix("torch.") @@ -242,13 +339,8 @@ def _set_lora_parallel_metadata( setattr(param, "lora_tp_shard_dim", parallel_spec.shard_dim) setattr(param, "grad_sync_domain", parallel_spec.grad_sync_domain) setattr(param, "grad_sync_op", parallel_spec.grad_sync_op) - # Megatron DDP routing flag: - # - allreduce=True: sync with regular DP/CP replicas. - # - allreduce=False: sync with expert-DP replicas. - # TP / expert-TP replica handling is controlled by grad_sync_* metadata. setattr(param, "allreduce", allreduce) - # Megatron's native TP finalize path consumes this attr. setattr( param, "average_gradients_across_tp_domain", @@ -259,16 +351,12 @@ def _set_lora_parallel_metadata( ), ) - # Megatron optimizer and checkpoint logic rely on tensor model-parallel metadata - # to distinguish true shards from TP-duplicate params. if parallel_spec.sharded: shard_dim = parallel_spec.shard_dim if shard_dim is None: raise ValueError("LoRAParallelSpec.shard_dim must be set when sharded=True") setattr(param, "tensor_model_parallel", True) setattr(param, "partition_dim", _normalize_axis(shard_dim, param.ndim)) - # stride > 1 means the dim is split into blocks and each tp rank holds a shard of the block - # this might happen for fused e.g. gate_(up|proj), but loras are individual per module setattr(param, "partition_stride", 1) else: setattr(param, "tensor_model_parallel", False) @@ -307,6 +395,59 @@ def _exported_shard_dim(param: torch.nn.Parameter) -> int: return 1 - axis +def _copy_lora_param_metadata( + source: torch.nn.Parameter, + target: torch.nn.Parameter, +) -> None: + for name in ( + "lora_shard_domain", + "lora_tp_sharded", + "lora_tp_replicated", + "lora_tp_shard_dim", + "grad_sync_domain", + "grad_sync_op", + "allreduce", + "average_gradients_across_tp_domain", + "tensor_model_parallel", + "partition_dim", + "partition_stride", + "lora_tp_shard_strategy", + "lora_tp_component_sizes", + ): + if hasattr(source, name): + setattr(target, name, getattr(source, name)) + setattr(target, "_art_dynamic_lora_slot", True) + + +class LoRASlot(torch.nn.Module): + def __init__( + self, + *, + ref: LoRASlotRef, + a_t: torch.Tensor, + b_t: torch.Tensor, + alpha: float, + a_template: torch.nn.Parameter, + b_template: torch.nn.Parameter, + requires_grad: bool, + ) -> None: + super().__init__() + self.ref = ref + self.alpha = float(alpha) + self.A_T = torch.nn.Parameter(a_t.detach().clone(), requires_grad=requires_grad) + self.B_T = torch.nn.Parameter(b_t.detach().clone(), requires_grad=requires_grad) + _copy_lora_param_metadata(a_template, self.A_T) + _copy_lora_param_metadata(b_template, self.B_T) + + @property + def rank(self) -> int: + return int(self.A_T.shape[-1]) + + @property + def scale(self) -> float: + return self.alpha / self.rank + + class LoRA(torch.nn.Module): def __init__( self, @@ -327,7 +468,12 @@ def __init__( "adapter_model_prefix must contain the '{expert}' format placeholder if num_local_experts > 1" ) self.adapter_model_prefix = adapter_model_prefix + self.alpha = float(alpha) + self.in_features = int(in_features) + self.out_features = int(out_features) self.scale = alpha / rank + self._slot_modules = torch.nn.ModuleDict() + self._slot_keys: dict[LoRASlotRef, str] = {} self.A_T = torch.nn.Parameter( torch.zeros( num_local_experts, in_features, rank, dtype=dtype, device=device @@ -362,7 +508,11 @@ def _broadcast_if_replicated(self, param: torch.nn.Parameter) -> None: world_size = _get_shard_world_size(domain) if world_size <= 1: return - group = _get_shard_group(domain) + group = ( + ps.get_tensor_model_parallel_group() + if domain == "tp" + else ps.get_expert_tensor_parallel_group(check_initialized=False) + ) if group is None: raise RuntimeError( f"{self.adapter_model_prefix}: missing process group for replicated parameter domain={domain}" @@ -395,43 +545,104 @@ def _expected_weight_keys(self, suffix: str) -> list[str]: ] return [f"{self.adapter_model_prefix}.{suffix}.weight"] + def load_lora_slot( + self, + ref: LoRASlotRef, + adapter_model: dict[str, torch.Tensor], + *, + alpha: float = LORA_ALPHA, + requires_grad: bool, + ) -> bool: + if ref.name is None: + raise ValueError("base-model slot refs do not own LoRA tensors") + weights = self._adapter_weights(adapter_model, require=False) + if weights is None: + return False + a_t = self._localized_weight(weights[0], into=self.A_T) + b_t = self._localized_weight(weights[1], into=self.B_T) + slot_key = self._slot_keys.get(ref) + if slot_key is None: + slot_key = f"slot_{len(self._slot_keys)}" + self._slot_keys[ref] = slot_key + elif self._has_live_slot_grads(ref): + raise RuntimeError( + f"Cannot overwrite live LoRA slot {ref.kind}:{ref.name} for " + f"{self.adapter_model_prefix}; clear grads/backward graph first." + ) + self._slot_modules[slot_key] = LoRASlot( + ref=ref, + a_t=a_t, + b_t=b_t, + alpha=alpha, + a_template=self.A_T, + b_template=self.B_T, + requires_grad=requires_grad, + ) + return True + + def lora_slot_params(self, ref: LoRASlotRef) -> list[torch.nn.Parameter]: + slot = self._slot(ref) + if slot is None: + return [] + return [slot.A_T, slot.B_T] + + def _slot(self, ref: LoRASlotRef) -> LoRASlot | None: + key = self._slot_keys.get(ref) + if key is None: + return None + return cast(LoRASlot, self._slot_modules[key]) + + def _has_live_slot_grads(self, ref: LoRASlotRef) -> bool: + slot = self._slot(ref) + return slot is not None and any( + param.grad is not None for param in (slot.A_T, slot.B_T) + ) + def load_lora(self, adapter_model: dict[str, torch.Tensor]) -> None: - missing_keys = [ + weights = self._adapter_weights(adapter_model, require=True) + assert weights is not None + self._load_weight(weights[0], into=self.A_T) + self._load_weight(weights[1], into=self.B_T) + + def _adapter_weights( + self, + adapter_model: dict[str, torch.Tensor], + *, + require: bool, + ) -> tuple[torch.Tensor, torch.Tensor] | None: + all_keys = [ key for suffix in ("lora_A", "lora_B") for key in self._expected_weight_keys(suffix) - if key not in adapter_model ] - if missing_keys: + missing = [key for key in all_keys if key not in adapter_model] + if len(missing) == len(all_keys) and not require: + return None + if missing: + state = "Missing" if require else "Incomplete" raise KeyError( - f"Missing LoRA adapter keys for {self.adapter_model_prefix}: {sorted(missing_keys)}" + f"{state} LoRA adapter keys for {self.adapter_model_prefix}: " + f"{sorted(missing)}" ) - self.load_weights( - adapter_model, - suffix="lora_A", - into=self.A_T, - ) - self.load_weights( - adapter_model, - suffix="lora_B", - into=self.B_T, + return ( + self._adapter_weight(adapter_model, suffix="lora_A"), + self._adapter_weight(adapter_model, suffix="lora_B"), ) - def load_weights( + def _adapter_weight( self, adapter_model: dict[str, torch.Tensor], *, suffix: str, - into: torch.nn.Parameter, - ) -> None: + ) -> torch.Tensor: keys = self._expected_weight_keys(suffix) if self.num_local_experts > 1: - weight = torch.stack([adapter_model[key].T for key in keys]) - else: - weight = adapter_model[keys[0]].T - self.load_weight(weight, into=into) + return torch.stack([adapter_model[key].T for key in keys]) + return adapter_model[keys[0]].T - def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None: + def _localized_weight( + self, weight: torch.Tensor, *, into: torch.nn.Parameter + ) -> torch.Tensor: domain = into.lora_shard_domain # ty: ignore[unresolved-attribute] if into.lora_tp_sharded: # ty: ignore[unresolved-attribute] axis = into.lora_tp_shard_dim # ty: ignore[unresolved-attribute] @@ -470,11 +681,10 @@ def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None raise ValueError( f"{self.adapter_model_prefix}: unsupported shard strategy={strategy}" ) - elif tuple(weight.shape) != tuple(into.shape): - raise ValueError( - f"{self.adapter_model_prefix}: unsharded load shape mismatch, got {tuple(weight.shape)} " - f"expected {tuple(into.shape)}" - ) + return weight.contiguous() + + def _load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None: + weight = self._localized_weight(weight, into=into) if tuple(weight.shape) != tuple(into.shape): raise ValueError( f"{self.adapter_model_prefix}: sharded load shape mismatch, got {tuple(weight.shape)} " @@ -575,9 +785,26 @@ def sharded_lora_grad_dict(self) -> dict[str, torch.Tensor]: grads[key] = local_grad.T return grads + def active_lora_tensors( + self, + ) -> tuple[torch.Tensor, torch.Tensor, float] | None: + ref = _CURRENT_LORA_SLOT.get() + if ref is None: + return self.A_T, self.B_T, self.scale + if ref.name is None: + return None + slot = self._slot(ref) + if slot is None: + return None + return slot.A_T, slot.B_T, slot.scale + def forward( self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor | None = None ) -> torch.Tensor: + active = self.active_lora_tensors() + if active is None: + return x.new_zeros((*x.shape[:-1], self.out_features)) + a_t, b_t, scale = active if tokens_per_expert is not None: assert self.num_local_experts > 1, ( "tokens_per_expert is only supported if num_local_experts > 1" @@ -586,12 +813,10 @@ def forward( if isinstance(bsz, list): bsz = torch.tensor(bsz, dtype=torch.int64, device="cpu") if x.shape[0] == 0: - return x.new_zeros((x.shape[0], self.B_T.shape[-1])) - return quack_grouped_lora(x, self.A_T, self.B_T, bsz, scale=self.scale) - out = (x @ self.A_T) @ self.B_T - if self.scale == 1.0: - return out - return out * self.scale + return x.new_zeros((*x.shape[:-1], self.out_features)) + return quack_grouped_lora(x, a_t, b_t, bsz, scale=scale) + out = (x @ a_t) @ b_t + return out if scale == 1.0 else out * scale class LoRAPublishPlanner: @@ -667,52 +892,47 @@ def _metadata_for_template( template: _LoraPublishTemplate, adapter_model: dict[str, torch.Tensor], ) -> list[LoraShardMeta]: - if template.num_local_experts > 1: - return self._expert_metadata_for_template(template, adapter_model) - return self._dense_metadata_for_template(template, adapter_model) - - def _dense_metadata_for_template( - self, - template: _LoraPublishTemplate, - adapter_model: dict[str, torch.Tensor], - ) -> list[LoraShardMeta]: - tp_ranks = self._dense_tp_ranks() shard_ranks = range(template.shard_world_size) if template.sharded else (0,) + if template.num_local_experts <= 1: + tp_ranks = ( + _process_group_ranks(ps.get_tensor_model_parallel_group()) + if _distributed_initialized() + else (0,) + ) + owners = [ + ( + f"{template.adapter_model_prefix}.{template.suffix}", + tp_ranks[shard_rank], + shard_rank, + ) + for shard_rank in shard_ranks + ] + else: + ep_world_size = 1 + if _distributed_initialized(): + ep_world_size = ps.get_expert_model_parallel_world_size() + owners = [ + ( + f"{template.adapter_model_prefix.format(expert=expert)}.{template.suffix}", + self._expert_owner_rank(ep_rank, shard_rank), + shard_rank, + ) + for ep_rank in range(ep_world_size) + for local_expert in range(template.num_local_experts) + for expert in [ep_rank * template.num_local_experts + local_expert] + for shard_rank in shard_ranks + ] return [ self._make_metadata( template, - key=f"{template.adapter_model_prefix}.{template.suffix}", - owner_rank=tp_ranks[shard_rank], + key=key, + owner_rank=owner_rank, shard_rank=shard_rank, adapter_model=adapter_model, ) - for shard_rank in shard_ranks + for key, owner_rank, shard_rank in owners ] - def _expert_metadata_for_template( - self, - template: _LoraPublishTemplate, - adapter_model: dict[str, torch.Tensor], - ) -> list[LoraShardMeta]: - ep_world_size = self._expert_model_world_size() - shard_ranks = range(template.shard_world_size) if template.sharded else (0,) - metadata: list[LoraShardMeta] = [] - for ep_rank in range(ep_world_size): - for local_expert in range(template.num_local_experts): - expert = ep_rank * template.num_local_experts + local_expert - key = f"{template.adapter_model_prefix.format(expert=expert)}.{template.suffix}" - for shard_rank in shard_ranks: - metadata.append( - self._make_metadata( - template, - key=key, - owner_rank=self._expert_owner_rank(ep_rank, shard_rank), - shard_rank=shard_rank, - adapter_model=adapter_model, - ) - ) - return metadata - @staticmethod def _make_metadata( template: _LoraPublishTemplate, @@ -722,6 +942,18 @@ def _make_metadata( shard_rank: int, adapter_model: dict[str, torch.Tensor], ) -> LoraShardMeta: + manifest: dict[str, Any] = { + "sharded": template.sharded, + "shard_world_size": template.shard_world_size if template.sharded else 1, + "shard_rank": shard_rank if template.sharded else 0, + } + if template.sharded: + manifest["export_shard_dim"] = template.export_shard_dim + manifest["export_shard_strategy"] = ( + template.export_shard_strategy or "uniform" + ) + if template.component_sizes: + manifest["component_sizes"] = list(template.component_sizes) return LoraShardMeta( key=key, owner_rank=owner_rank, @@ -731,22 +963,10 @@ def _make_metadata( if key in adapter_model else template.dtype_name ), - manifest=_publish_manifest(template, shard_rank=shard_rank), + manifest=manifest, block=_block_for_key(key), ) - @staticmethod - def _dense_tp_ranks() -> tuple[int, ...]: - if not _distributed_initialized(): - return (0,) - return _process_group_ranks(ps.get_tensor_model_parallel_group()) - - @staticmethod - def _expert_model_world_size() -> int: - if not _distributed_initialized(): - return 1 - return ps.get_expert_model_parallel_world_size() - @staticmethod def _expert_owner_rank(ep_rank: int, shard_rank: int) -> int: if not _distributed_initialized(): @@ -793,24 +1013,6 @@ def _exported_param_shape(module: LoRA, param: torch.nn.Parameter) -> tuple[int, return tuple(int(dim) for dim in param.T.shape) -def _publish_manifest( - template: _LoraPublishTemplate, - *, - shard_rank: int, -) -> dict[str, Any]: - manifest: dict[str, Any] = { - "sharded": template.sharded, - "shard_world_size": template.shard_world_size if template.sharded else 1, - "shard_rank": shard_rank if template.sharded else 0, - } - if template.sharded: - manifest["export_shard_dim"] = template.export_shard_dim - manifest["export_shard_strategy"] = template.export_shard_strategy or "uniform" - if template.component_sizes: - manifest["component_sizes"] = list(template.component_sizes) - return manifest - - @torch.compiler.disable def _expert_grouped_lora_forward( lora: LoRA, @@ -834,15 +1036,110 @@ def _expert_grouped_lora_dual_forward( counts = torch.tensor(counts, dtype=torch.int64, device="cpu") if x.shape[0] == 0: return x.new_zeros((x.shape[0], module.linear_fc1.out_features)) + gate = module.gate_lora.active_lora_tensors() + up = module.up_lora.active_lora_tensors() + if gate is None or up is None: + return torch.cat( + [ + module.gate_lora(x, tokens_per_expert=counts), + module.up_lora(x, tokens_per_expert=counts), + ], + dim=-1, + ) + gate_a_t, gate_b_t, gate_scale = gate + up_a_t, up_b_t, up_scale = up return quack_grouped_lora_dual( x, - module.gate_lora.A_T, - module.gate_lora.B_T, - module.up_lora.A_T, - module.up_lora.B_T, + gate_a_t, + gate_b_t, + up_a_t, + up_b_t, counts, - scale_gate=module.gate_lora.scale, - scale_up=module.up_lora.scale, + scale_gate=gate_scale, + scale_up=up_scale, + ) + + +def _parallel_lora( + *, + adapter_model_prefix: str, + linear: Any, + out_features: int, + rank: int, + alpha: float, + layout: Literal["column", "row"], + shard_domain: ShardDomain = "tp", + grad_sync_domain: GradSyncDomain = TP_DEFAULT_GRAD_SYNC_DOMAIN, + allreduce: bool = True, + num_local_experts: int = 1, +) -> LoRA: + weight = getattr(linear, "weight0", None) + if weight is None: + weight = getattr(linear, "weight", None) + assert isinstance(weight, torch.Tensor) + row_layout = layout == "row" + a_parallel_spec = LoRAParallelSpec( + shard_domain=shard_domain, + sharded=row_layout, + shard_dim=-2 if row_layout else None, + grad_sync_domain=grad_sync_domain, + grad_sync_op=GRAD_SYNC_OP_NONE if row_layout else GRAD_SYNC_OP_SUM, + ) + b_parallel_spec = replace( + a_parallel_spec, + sharded=not row_layout, + shard_dim=None if row_layout else -1, + grad_sync_domain=grad_sync_domain, + grad_sync_op=GRAD_SYNC_OP_SUM if row_layout else GRAD_SYNC_OP_NONE, + ) + return LoRA( + adapter_model_prefix=adapter_model_prefix, + in_features=linear.in_features, + out_features=out_features, + rank=rank, + alpha=alpha, + dtype=weight.dtype, + device=weight.device, + num_local_experts=num_local_experts, + a_parallel_spec=a_parallel_spec, + b_parallel_spec=b_parallel_spec, + allreduce=allreduce, + ) + + +def _parallel_lora_pair( + *, + adapter_model_prefix: str, + linear: Any, + out_features: int, + rank: int, + alpha: float, + layout: Literal["column", "row"], + suffixes: tuple[str, str], + num_local_experts: int = 1, +) -> tuple[LoRA, LoRA]: + expert_parallel = num_local_experts > 1 + return cast( + tuple[LoRA, LoRA], + tuple( + _parallel_lora( + adapter_model_prefix=f"{adapter_model_prefix}.{suffix}", + linear=linear, + out_features=out_features, + rank=rank, + alpha=alpha, + layout=layout, + shard_domain="expert_tp" if expert_parallel else "tp", + grad_sync_domain=( + EXPERT_TP_GRAD_SYNC_DOMAIN + if expert_parallel + else TP_DEFAULT_GRAD_SYNC_DOMAIN + ), + allreduce=not expert_parallel, + num_local_experts=num_local_experts, + ) + for suffix in suffixes + ), ) @@ -860,33 +1157,13 @@ def __init__( self.provider = provider self.linear_proj = linear_proj self.reduce_output = reduce_output - assert isinstance(linear_proj.weight, torch.Tensor) - a_parallel_spec = LoRAParallelSpec( - shard_domain="tp", - sharded=True, - shard_dim=-2, - grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_NONE, # only need DP-type reductions - ) - b_parallel_spec = a_parallel_spec.model_copy( - update={ - "sharded": False, - "shard_dim": None, - "grad_sync_op": GRAD_SYNC_OP_SUM, # sum replicated TP contributions - } - ) - self.lora = LoRA( + self.lora = _parallel_lora( adapter_model_prefix=adapter_model_prefix, - in_features=linear_proj.in_features, + linear=linear_proj, out_features=linear_proj.out_features, rank=rank, alpha=alpha, - dtype=linear_proj.weight.dtype, - device=linear_proj.weight.device, - a_parallel_spec=a_parallel_spec, - b_parallel_spec=b_parallel_spec, - # Non-expert LoRA params use Megatron's dense DP/CP gradient buckets. - allreduce=True, + layout="row", ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -957,64 +1234,29 @@ def __init__( self.provider.num_attention_heads // self.provider.num_query_groups ) self.hidden_size_per_attention_head = self.provider.kv_channels - self.q_proj_lora = self._build_qkv_lora( + self.q_proj_lora = _parallel_lora( adapter_model_prefix=f"{adapter_model_prefix}.q_proj", - linear_qkv=linear_qkv, + linear=linear_qkv, + out_features=q_and_gate_out_features_per_rank, rank=rank, alpha=alpha, - out_features=q_and_gate_out_features_per_rank, + layout="column", ) - self.k_proj_lora = self._build_qkv_lora( + self.k_proj_lora = _parallel_lora( adapter_model_prefix=f"{adapter_model_prefix}.k_proj", - linear_qkv=linear_qkv, + linear=linear_qkv, + out_features=kv_out_features_per_rank, rank=rank, alpha=alpha, - out_features=kv_out_features_per_rank, + layout="column", ) - self.v_proj_lora = self._build_qkv_lora( + self.v_proj_lora = _parallel_lora( adapter_model_prefix=f"{adapter_model_prefix}.v_proj", - linear_qkv=linear_qkv, - rank=rank, - alpha=alpha, + linear=linear_qkv, out_features=kv_out_features_per_rank, - ) - - @staticmethod - def _build_qkv_lora( - *, - adapter_model_prefix: str, - linear_qkv: TELayerNormColumnParallelLinear, - rank: int, - alpha: float, - out_features: int, - ) -> LoRA: - assert isinstance(linear_qkv.weight, torch.Tensor) - a_parallel_spec = LoRAParallelSpec( - shard_domain="tp", - sharded=False, - shard_dim=None, - grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_SUM, # sum replicated TP contributions - ) - b_parallel_spec = a_parallel_spec.model_copy( - update={ - "sharded": True, - "shard_dim": -1, - "grad_sync_op": GRAD_SYNC_OP_NONE, # only need DP-type reductions - } - ) - return LoRA( - adapter_model_prefix=adapter_model_prefix, - in_features=linear_qkv.in_features, - out_features=out_features, rank=rank, alpha=alpha, - dtype=linear_qkv.weight.dtype, - device=linear_qkv.weight.device, - a_parallel_spec=a_parallel_spec, - b_parallel_spec=b_parallel_spec, - # Non-expert LoRA params use Megatron's dense DP/CP gradient buckets. - allreduce=True, + layout="column", ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -1080,13 +1322,13 @@ def __init__( z_out_features_per_partition = ( gated_delta_net.v_dim // ps.get_tensor_model_parallel_world_size() ) - assert isinstance(in_proj.weight, torch.Tensor) - self.qkv_lora = self._build_in_proj_lora( + self.qkv_lora = _parallel_lora( adapter_model_prefix=f"{adapter_model_prefix}.in_proj_qkv", - in_proj=in_proj, + linear=in_proj, + out_features=qkv_out_features_per_partition, rank=rank, alpha=alpha, - out_features=qkv_out_features_per_partition, + layout="column", ) _set_lora_shard_strategy_metadata( self.qkv_lora.B_T, @@ -1097,49 +1339,13 @@ def __init__( gated_delta_net.v_dim, ), ) - self.z_lora = self._build_in_proj_lora( + self.z_lora = _parallel_lora( adapter_model_prefix=f"{adapter_model_prefix}.in_proj_z", - in_proj=in_proj, - rank=rank, - alpha=alpha, + linear=in_proj, out_features=z_out_features_per_partition, - ) - - @staticmethod - def _build_in_proj_lora( - *, - adapter_model_prefix: str, - in_proj: TELayerNormColumnParallelLinear, - rank: int, - alpha: float, - out_features: int, - ) -> LoRA: - assert isinstance(in_proj.weight, torch.Tensor) - a_parallel_spec = LoRAParallelSpec( - shard_domain="tp", - sharded=False, - shard_dim=None, - grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_SUM, - ) - b_parallel_spec = a_parallel_spec.model_copy( - update={ - "sharded": True, - "shard_dim": -1, - "grad_sync_op": GRAD_SYNC_OP_NONE, - } - ) - return LoRA( - adapter_model_prefix=adapter_model_prefix, - in_features=in_proj.in_features, - out_features=out_features, rank=rank, alpha=alpha, - dtype=in_proj.weight.dtype, - device=in_proj.weight.device, - a_parallel_spec=a_parallel_spec, - b_parallel_spec=b_parallel_spec, - allreduce=True, + layout="column", ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -1169,133 +1375,62 @@ def __init__( rank: int, alpha: float, num_local_experts: int, + fused_gate_up: bool = False, ) -> None: super().__init__() - assert linear_fc1 is not None self.linear_fc1 = linear_fc1 - self.gate_lora = self._build_fc1_lora( - adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.gate_proj", - linear_fc1=linear_fc1, - rank=rank, - alpha=alpha, - num_local_experts=num_local_experts, - ) - self.up_lora = self._build_fc1_lora( - adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.up_proj", - linear_fc1=linear_fc1, - rank=rank, - alpha=alpha, - num_local_experts=num_local_experts, - ) - self.uses_direct_quack_grouped_lora_dual = True - - @staticmethod - def _build_fc1_lora( - *, - adapter_model_prefix: str, - linear_fc1: TEColumnParallelGroupedLinear, - rank: int, - alpha: float, - num_local_experts: int, - ) -> LoRA: - assert linear_fc1 is not None - assert isinstance(linear_fc1.weight0, torch.Tensor) - a_parallel_spec = LoRAParallelSpec( - shard_domain="expert_tp", - sharded=False, - shard_dim=None, - grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_SUM, # we handle this with extended finalize_grads - ) - b_parallel_spec = a_parallel_spec.model_copy( - update={ - "sharded": True, - "shard_dim": -1, - "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN, - "grad_sync_op": GRAD_SYNC_OP_NONE, # only need DP-type reductions - } - ) - return LoRA( - adapter_model_prefix=adapter_model_prefix, - in_features=linear_fc1.in_features, - out_features=linear_fc1.out_features // 2, - rank=rank, - alpha=alpha, - dtype=linear_fc1.weight0.dtype, - device=linear_fc1.weight0.device, - num_local_experts=num_local_experts, - a_parallel_spec=a_parallel_spec, - b_parallel_spec=b_parallel_spec, - # Expert LoRA params use Megatron's expert-DP gradient buckets. - allreduce=False, - ) - - def forward( - self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor | None]: - base_out, bias_out = self.linear_fc1(x, tokens_per_expert) - adapter_out = _expert_grouped_lora_dual_forward(self, x, tokens_per_expert) - return base_out + adapter_out, bias_out - - -class MLPExpertsLinearFC1FusedLoRA(torch.nn.Module): - def __init__( - self, - adapter_model_prefix: str, - linear_fc1: TEColumnParallelGroupedLinear, - rank: int, - alpha: float, - num_local_experts: int, - ) -> None: - super().__init__() - assert linear_fc1 is not None - assert isinstance(linear_fc1.weight0, torch.Tensor) - self.linear_fc1 = linear_fc1 - a_parallel_spec = LoRAParallelSpec( - shard_domain="expert_tp", - sharded=False, - shard_dim=None, - grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_SUM, - ) - b_parallel_spec = a_parallel_spec.model_copy( - update={ - "sharded": True, - "shard_dim": -1, - "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN, - "grad_sync_op": GRAD_SYNC_OP_NONE, - } - ) - self.lora = LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.gate_up_proj", - in_features=linear_fc1.in_features, - out_features=linear_fc1.out_features, - rank=rank, - alpha=alpha, - dtype=linear_fc1.weight0.dtype, - device=linear_fc1.weight0.device, - num_local_experts=num_local_experts, - a_parallel_spec=a_parallel_spec, - b_parallel_spec=b_parallel_spec, - allreduce=False, - ) - gate_out_features = linear_fc1.out_features // 2 - expert_tp_world_size = _get_shard_world_size("expert_tp") - _set_lora_shard_strategy_metadata( - self.lora.B_T, - strategy="componentwise", - component_sizes=( - gate_out_features * expert_tp_world_size, - gate_out_features * expert_tp_world_size, - ), - ) + self.fused_gate_up = bool(fused_gate_up) + if self.fused_gate_up: + self.lora = _parallel_lora( + adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.gate_up_proj", + linear=linear_fc1, + out_features=linear_fc1.out_features, + rank=rank, + alpha=alpha, + layout="column", + shard_domain="expert_tp", + grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, + allreduce=False, + num_local_experts=num_local_experts, + ) + gate_out_features = linear_fc1.out_features // 2 + expert_tp_world_size = _get_shard_world_size("expert_tp") + _set_lora_shard_strategy_metadata( + self.lora.B_T, + strategy="componentwise", + component_sizes=( + gate_out_features * expert_tp_world_size, + gate_out_features * expert_tp_world_size, + ), + ) + else: + self.gate_lora, self.up_lora = _parallel_lora_pair( + adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}", + linear=linear_fc1, + out_features=linear_fc1.out_features // 2, + rank=rank, + alpha=alpha, + layout="column", + suffixes=("gate_proj", "up_proj"), + num_local_experts=num_local_experts, + ) def forward( self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor | None]: - base_out, bias_out = self.linear_fc1(x, tokens_per_expert) - adapter_out = _expert_grouped_lora_forward( - self.lora, x, tokens_per_expert, self.linear_fc1.out_features + base_out, bias_out = cast( + Callable[ + [torch.Tensor, list[int] | torch.Tensor], + tuple[torch.Tensor, torch.Tensor | None], + ], + self.linear_fc1, + )(x, tokens_per_expert) + adapter_out = ( + _expert_grouped_lora_forward( + self.lora, x, tokens_per_expert, self.linear_fc1.out_features + ) + if self.fused_gate_up + else _expert_grouped_lora_dual_forward(self, x, tokens_per_expert) ) return base_out + adapter_out, bias_out @@ -1310,43 +1445,30 @@ def __init__( num_local_experts: int, ) -> None: super().__init__() - assert linear_fc2 is not None - assert isinstance(linear_fc2.weight0, torch.Tensor) self.linear_fc2 = linear_fc2 - a_parallel_spec = LoRAParallelSpec( - shard_domain="expert_tp", - sharded=True, - shard_dim=-2, - grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_NONE, # only need DP-type reductions - ) - b_parallel_spec = a_parallel_spec.model_copy( - update={ - "sharded": False, - "shard_dim": None, - "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN, - "grad_sync_op": GRAD_SYNC_OP_SUM, # we handle this with extended finalize_grads - } - ) - self.lora = LoRA( + self.lora = _parallel_lora( adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.down_proj", - in_features=linear_fc2.in_features, + linear=linear_fc2, out_features=linear_fc2.out_features, rank=rank, alpha=alpha, - dtype=linear_fc2.weight0.dtype, - device=linear_fc2.weight0.device, - num_local_experts=num_local_experts, - a_parallel_spec=a_parallel_spec, - b_parallel_spec=b_parallel_spec, - # Expert LoRA params use Megatron's expert-DP gradient buckets. + layout="row", + shard_domain="expert_tp", + grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, allreduce=False, + num_local_experts=num_local_experts, ) def forward( self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor | None]: - base_out, bias_out = self.linear_fc2(x, tokens_per_expert) + base_out, bias_out = cast( + Callable[ + [torch.Tensor, list[int] | torch.Tensor], + tuple[torch.Tensor, torch.Tensor | None], + ], + self.linear_fc2, + )(x, tokens_per_expert) adapter_out = _expert_grouped_lora_forward( self.lora, x, tokens_per_expert, self.linear_fc2.out_features ) @@ -1368,53 +1490,14 @@ def __init__( linear_fc1.return_layernorm_output = True linear_fc1.return_layernorm_output_gathered = True self.linear_fc1 = linear_fc1 - self.gate_lora = self._build_fc1_lora( - adapter_model_prefix=f"{adapter_model_prefix}.gate_proj", - linear_fc1=linear_fc1, - rank=rank, - alpha=alpha, - ) - self.up_lora = self._build_fc1_lora( - adapter_model_prefix=f"{adapter_model_prefix}.up_proj", - linear_fc1=linear_fc1, - rank=rank, - alpha=alpha, - ) - - @staticmethod - def _build_fc1_lora( - *, - adapter_model_prefix: str, - linear_fc1: TEColumnParallelLinear | TELayerNormColumnParallelLinear, - rank: int, - alpha: float, - ) -> LoRA: - assert isinstance(linear_fc1.weight, torch.Tensor) - a_parallel_spec = LoRAParallelSpec( - shard_domain="tp", - sharded=False, - shard_dim=None, - grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_SUM, - ) - b_parallel_spec = a_parallel_spec.model_copy( - update={ - "sharded": True, - "shard_dim": -1, - "grad_sync_op": GRAD_SYNC_OP_NONE, - } - ) - return LoRA( + self.gate_lora, self.up_lora = _parallel_lora_pair( adapter_model_prefix=adapter_model_prefix, - in_features=linear_fc1.in_features, + linear=linear_fc1, out_features=linear_fc1.out_features // 2, rank=rank, alpha=alpha, - dtype=linear_fc1.weight.dtype, - device=linear_fc1.weight.device, - a_parallel_spec=a_parallel_spec, - b_parallel_spec=b_parallel_spec, - allreduce=True, + layout="column", + suffixes=("gate_proj", "up_proj"), ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -1438,29 +1521,6 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: return base_out + adapter_out, bias_out -class SharedExpertsLinearFC2LoRA(torch.nn.Module): - def __init__( - self, - adapter_model_prefix: str, - linear_fc2: TERowParallelLinear, - rank: int, - alpha: float, - provider: GPTModelProvider, - ) -> None: - super().__init__() - self.row_parallel_lora = SelfAttentionLinearProjLoRA( - adapter_model_prefix=f"{adapter_model_prefix}.down_proj", - linear_proj=linear_fc2, - rank=rank, - alpha=alpha, - provider=provider, - reduce_output=not _linear_disables_tensor_parallel_comm(linear_fc2), - ) - - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: - return self.row_parallel_lora(x) - - def _unwrap_attr( value: Any, attr_name: str, @@ -1568,8 +1628,14 @@ def wrap_grouped_moe_experts( target_modules: set[str], rank: int, alpha: int, + fused_gate_up: bool = False, ) -> None: - if _targets_include(target_modules, "gate_proj", "up_proj"): + wrap_fc1 = ( + _targets_include(target_modules, "experts") + if fused_gate_up + else _targets_include(target_modules, "gate_proj", "up_proj") + ) + if wrap_fc1: mlp_experts_linear_fc1 = _unwrap_attr( experts.linear_fc1, "linear_fc1", @@ -1581,58 +1647,27 @@ def wrap_grouped_moe_experts( rank=rank, alpha=alpha, num_local_experts=experts.num_local_experts, + fused_gate_up=fused_gate_up, ) - if _targets_include(target_modules, "down_proj"): - mlp_experts_linear_fc2 = _unwrap_attr( - experts.linear_fc2, - "linear_fc2", - TERowParallelGroupedLinear, # type: ignore[arg-type] - ) - experts.linear_fc2 = MLPExpertsLinearFC2LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", - linear_fc2=mlp_experts_linear_fc2, - rank=rank, - alpha=alpha, - num_local_experts=experts.num_local_experts, - ) - - -def wrap_grouped_moe_experts_3d( - experts: TEGroupedMLP, - *, - adapter_model_prefix: str, - target_modules: set[str], - rank: int, - alpha: int, -) -> None: - if _targets_include(target_modules, "experts"): - mlp_experts_linear_fc1 = _unwrap_attr( - experts.linear_fc1, - "linear_fc1", - TEColumnParallelGroupedLinear, # type: ignore[arg-type] - ) - experts.linear_fc1 = MLPExpertsLinearFC1FusedLoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", - linear_fc1=mlp_experts_linear_fc1, - rank=rank, - alpha=alpha, - num_local_experts=experts.num_local_experts, - ) - mlp_experts_linear_fc2 = _unwrap_attr( + wrap_fc2 = ( + wrap_fc1 if fused_gate_up else _targets_include(target_modules, "down_proj") + ) + if wrap_fc2: + linear_fc2 = _unwrap_attr( experts.linear_fc2, "linear_fc2", TERowParallelGroupedLinear, # type: ignore[arg-type] ) experts.linear_fc2 = MLPExpertsLinearFC2LoRA( adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", - linear_fc2=mlp_experts_linear_fc2, + linear_fc2=linear_fc2, rank=rank, alpha=alpha, num_local_experts=experts.num_local_experts, ) -def wrap_dense_mlp( +def wrap_split_mlp_lora( mlp: Any, *, adapter_model_prefix: str, @@ -1642,65 +1677,30 @@ def wrap_dense_mlp( alpha: int, ) -> None: if _targets_include(target_modules, "gate_proj", "up_proj"): - mlp_linear_fc1 = _unwrap_attr( + linear_fc1 = _unwrap_attr( mlp.linear_fc1, "linear_fc1", (TEColumnParallelLinear, TELayerNormColumnParallelLinear), ) mlp.linear_fc1 = SharedExpertsLinearFC1LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp", - linear_fc1=mlp_linear_fc1, + adapter_model_prefix=adapter_model_prefix, + linear_fc1=linear_fc1, rank=rank, alpha=alpha, ) if _targets_include(target_modules, "down_proj"): - mlp_linear_fc2 = _unwrap_attr( + linear_fc2 = _unwrap_attr( mlp.linear_fc2, "linear_fc2", TERowParallelLinear, ) - mlp.linear_fc2 = SharedExpertsLinearFC2LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp", - linear_fc2=mlp_linear_fc2, - rank=rank, - alpha=alpha, - provider=provider, - ) - - -def wrap_shared_experts_mlp( - shared_experts: SharedExpertMLP, - *, - adapter_model_prefix: str, - provider: GPTModelProvider, - target_modules: set[str], - rank: int, - alpha: int, -) -> None: - if _targets_include(target_modules, "gate_proj", "up_proj"): - shared_experts_linear_fc1 = _unwrap_attr( - shared_experts.linear_fc1, - "linear_fc1", - (TEColumnParallelLinear, TELayerNormColumnParallelLinear), - ) - shared_experts.linear_fc1 = SharedExpertsLinearFC1LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.shared_expert", - linear_fc1=shared_experts_linear_fc1, - rank=rank, - alpha=alpha, - ) - if _targets_include(target_modules, "down_proj"): - shared_experts_linear_fc2 = _unwrap_attr( - shared_experts.linear_fc2, - "linear_fc2", - TERowParallelLinear, - ) - shared_experts.linear_fc2 = SharedExpertsLinearFC2LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.shared_expert", - linear_fc2=shared_experts_linear_fc2, + mlp.linear_fc2 = SelfAttentionLinearProjLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.down_proj", + linear_proj=linear_fc2, rank=rank, alpha=alpha, provider=provider, + reduce_output=not _linear_disables_tensor_parallel_comm(linear_fc2), ) @@ -1721,3 +1721,43 @@ def apply_lora_adapters( alpha=LORA_ALPHA, ) return list(model) + + +def load_lora_slot_into_model( + model: Sequence[torch.nn.Module], + ref: LoRASlotRef, + adapter_model: dict[str, torch.Tensor], + *, + alpha: float = LORA_ALPHA, + requires_grad: bool, +) -> int: + loaded = 0 + for chunk in model: + for module in chunk.modules(): + if isinstance(module, LoRA) and module.load_lora_slot( + ref, + adapter_model, + alpha=alpha, + requires_grad=requires_grad, + ): + loaded += 1 + if loaded == 0 and ref.name is not None: + raise RuntimeError(f"LoRA slot {ref.kind}:{ref.name} loaded no adapter sites") + return loaded + + +def iter_lora_slot_parameters( + model: Sequence[torch.nn.Module], + ref: LoRASlotRef, +) -> Iterator[torch.nn.Parameter]: + seen: set[int] = set() + for chunk in model: + for module in chunk.modules(): + if not isinstance(module, LoRA): + continue + for param in module.lora_slot_params(ref): + param_id = id(param) + if param_id in seen: + continue + seen.add(param_id) + yield param diff --git a/src/art/megatron/model_support/handlers/default_dense.py b/src/art/megatron/model_support/handlers/default_dense.py index bd79332ae..d3f7d2416 100644 --- a/src/art/megatron/model_support/handlers/default_dense.py +++ b/src/art/megatron/model_support/handlers/default_dense.py @@ -137,7 +137,7 @@ def apply_lora_adapters( from art.megatron.lora import ( _adapter_model_prefix, - wrap_dense_mlp, + wrap_split_mlp_lora, wrap_standard_self_attention, ) @@ -146,18 +146,19 @@ def apply_lora_adapters( for module in chunk.modules(): if not isinstance(module, TransformerLayer): continue + adapter_model_prefix = _adapter_model_prefix(module) wrap_standard_self_attention( module.self_attention, - adapter_model_prefix=_adapter_model_prefix(module), + adapter_model_prefix=adapter_model_prefix, provider=provider, target_modules=target_set, rank=rank, alpha=alpha, ) _require_dense_mlp(module) - wrap_dense_mlp( + wrap_split_mlp_lora( module.mlp, - adapter_model_prefix=_adapter_model_prefix(module), + adapter_model_prefix=f"{adapter_model_prefix}.mlp", provider=provider, target_modules=target_set, rank=rank, @@ -168,32 +169,9 @@ def build_adapter_weights_by_base( self, model_chunks: Sequence[Any], ) -> dict[str, list[Any]]: - from megatron.core.transformer.transformer_layer import TransformerLayer + from art.megatron.weights import adapter_export - from art.megatron.weights.adapter_export import ( - add_dense_mlp_adapter_weights, - add_standard_self_attention_adapter_weights, - layer_base_prefix, - ) - - adapter_weights_by_base: dict[str, list[Any]] = {} - for chunk in model_chunks: - for module_name, module in chunk.named_modules(): - if not isinstance(module, TransformerLayer): - continue - layer_prefix = layer_base_prefix(module, module_name=module_name) - _require_dense_mlp(module) - add_standard_self_attention_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - self_attention=module.self_attention, - ) - add_dense_mlp_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - mlp=module.mlp, - ) - return adapter_weights_by_base + return adapter_export.build_transformer_layer_adapter_weights(model_chunks) def compile_workaround_config( self, @@ -236,7 +214,7 @@ def apply_lora_adapters( from art.megatron.lora import ( _adapter_model_prefix, wrap_grouped_moe_experts, - wrap_shared_experts_mlp, + wrap_split_mlp_lora, wrap_standard_self_attention, ) @@ -263,9 +241,9 @@ def apply_lora_adapters( ) shared_experts = getattr(module.mlp, "shared_experts", None) if shared_experts is not None: - wrap_shared_experts_mlp( + wrap_split_mlp_lora( shared_experts, - adapter_model_prefix=adapter_model_prefix, + adapter_model_prefix=f"{adapter_model_prefix}.mlp.shared_expert", provider=provider, target_modules=target_set, rank=rank, @@ -276,40 +254,13 @@ def build_adapter_weights_by_base( self, model_chunks: Sequence[Any], ) -> dict[str, list[Any]]: - from megatron.core.transformer.transformer_layer import TransformerLayer + from art.megatron.weights import adapter_export - from art.megatron.weights.adapter_export import ( - add_grouped_moe_adapter_weights, - add_shared_experts_adapter_weights, - add_standard_self_attention_adapter_weights, - layer_base_prefix, + return adapter_export.build_transformer_layer_adapter_weights( + model_chunks, + grouped_moe=True, ) - adapter_weights_by_base: dict[str, list[Any]] = {} - for chunk in model_chunks: - for module_name, module in chunk.named_modules(): - if not isinstance(module, TransformerLayer): - continue - layer_prefix = layer_base_prefix(module, module_name=module_name) - add_standard_self_attention_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - self_attention=module.self_attention, - ) - add_grouped_moe_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - experts=_require_moe_experts(module), - ) - shared_experts = getattr(module.mlp, "shared_experts", None) - if shared_experts is not None: - add_shared_experts_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - shared_experts=shared_experts, - ) - return adapter_weights_by_base - def _require_dense_mlp(module: Any) -> None: if getattr(module.mlp, "experts", None) is not None: diff --git a/src/art/megatron/model_support/handlers/qwen3_5.py b/src/art/megatron/model_support/handlers/qwen3_5.py index ad200499a..3d4ea98d8 100644 --- a/src/art/megatron/model_support/handlers/qwen3_5.py +++ b/src/art/megatron/model_support/handlers/qwen3_5.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Callable from copy import copy from functools import lru_cache import re @@ -50,6 +51,9 @@ r"^(?P.*\.mlp\.experts)\.(?P\d+)\." r"(?Pgate_proj|up_proj|down_proj)\.(?Plora_[AB])\.weight$" ) +_ART_MOE_MODULES = ("gate_up_proj", "down_proj") +_VLLM_EXPERT_MODULES = ("gate_proj", "up_proj", "down_proj") +_LORA_NAMES = ("lora_A", "lora_B") class Qwen35BaseHandler(DefaultDenseHandler): @@ -80,7 +84,7 @@ def to_vllm_lora_tensors( *, adapter_config: dict[str, Any], ) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: - if _group_art_moe_tensors(tensors): + if _group_expert_lora_tensors(tensors, _ART_MOE_EXPERT_KEY_RE): raise TypeError("Dense Qwen3.5 handler received MoE LoRA tensors") transformed: dict[str, torch.Tensor] = {} for key, tensor in tensors.items(): @@ -313,44 +317,14 @@ def build_adapter_weights_by_base( self, model_chunks: Sequence[Any], ) -> dict[str, list[Any]]: - from megatron.core.ssm.gated_delta_net import GatedDeltaNet - from megatron.core.transformer.attention import SelfAttention - from megatron.core.transformer.transformer_layer import TransformerLayer - - from art.megatron.lora import _is_language_transformer_layer_name - from art.megatron.weights.adapter_export import ( - add_gated_delta_net_adapter_weights, - add_standard_self_attention_adapter_weights, - layer_base_prefix, - ) + from art.megatron.weights import adapter_export _ensure_bridge_qwen35_adapter_name_map() - adapter_weights_by_base: dict[str, list[Any]] = {} - for chunk in model_chunks: - for module_name, module in chunk.named_modules(): - if not isinstance(module, TransformerLayer): - continue - if not _is_language_transformer_layer_name(module_name): - continue - layer_prefix = layer_base_prefix(module, module_name=module_name) - if isinstance(module.self_attention, SelfAttention): - add_standard_self_attention_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - self_attention=module.self_attention, - ) - elif isinstance(module.self_attention, GatedDeltaNet): - add_gated_delta_net_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - self_attention=module.self_attention, - ) - self._add_mlp_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - module=module, - ) - return adapter_weights_by_base + return adapter_export.build_transformer_layer_adapter_weights( + model_chunks, + grouped_moe=self.is_moe, + language_layers_only=True, + ) def _wrap_mlp_lora( self, @@ -362,34 +336,18 @@ def _wrap_mlp_lora( rank: int, alpha: int, ) -> None: - from art.megatron.lora import wrap_dense_mlp + from art.megatron.lora import wrap_split_mlp_lora _require_dense_mlp(module) - wrap_dense_mlp( + wrap_split_mlp_lora( module.mlp, - adapter_model_prefix=adapter_model_prefix, + adapter_model_prefix=f"{adapter_model_prefix}.mlp", provider=provider, target_modules=target_modules, rank=rank, alpha=alpha, ) - def _add_mlp_adapter_weights( - self, - adapter_weights_by_base: dict[str, list[Any]], - *, - layer_prefix: str, - module: Any, - ) -> None: - from art.megatron.weights.adapter_export import add_dense_mlp_adapter_weights - - _require_dense_mlp(module) - add_dense_mlp_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - mlp=module.mlp, - ) - def get_forward_kwargs(self, model: Any, **kwargs: Any) -> dict[str, Any]: unwrapped = model while hasattr(unwrapped, "module"): @@ -483,54 +441,27 @@ def _wrap_mlp_lora( rank: int, alpha: int, ) -> None: - from art.megatron.lora import ( - wrap_grouped_moe_experts_3d, - wrap_shared_experts_mlp, - ) + from art.megatron.lora import wrap_grouped_moe_experts, wrap_split_mlp_lora - wrap_grouped_moe_experts_3d( + wrap_grouped_moe_experts( _require_moe_experts(module), adapter_model_prefix=adapter_model_prefix, target_modules=target_modules, rank=rank, alpha=alpha, + fused_gate_up=True, ) shared_experts = getattr(module.mlp, "shared_experts", None) if shared_experts is not None: - wrap_shared_experts_mlp( + wrap_split_mlp_lora( shared_experts, - adapter_model_prefix=adapter_model_prefix, + adapter_model_prefix=f"{adapter_model_prefix}.mlp.shared_expert", provider=provider, target_modules=target_modules, rank=rank, alpha=alpha, ) - def _add_mlp_adapter_weights( - self, - adapter_weights_by_base: dict[str, list[Any]], - *, - layer_prefix: str, - module: Any, - ) -> None: - from art.megatron.weights.adapter_export import ( - add_grouped_moe_adapter_weights, - add_shared_experts_adapter_weights, - ) - - add_grouped_moe_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - experts=_require_moe_experts(module), - ) - shared_experts = getattr(module.mlp, "shared_experts", None) - if shared_experts is not None: - add_shared_experts_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - shared_experts=shared_experts, - ) - def compile_workaround_config( self, provider: Any, @@ -573,10 +504,6 @@ def _from_vllm_key(key: str) -> str: ) -def _is_lora_weight_key(key: str) -> bool: - return key.endswith((".lora_A.weight", ".lora_B.weight")) - - def _is_self_attn_q_proj_lora_b(key: str) -> bool: return key.endswith(".self_attn.q_proj.lora_B.weight") @@ -725,12 +652,16 @@ def _vllm_moe_config( return config -def _group_art_moe_tensors( +type _ExpertLoraGroups = dict[str, dict[int, dict[str, dict[str, torch.Tensor]]]] + + +def _group_expert_lora_tensors( tensors: dict[str, torch.Tensor], -) -> dict[str, dict[int, dict[str, dict[str, torch.Tensor]]]]: - grouped: dict[str, dict[int, dict[str, dict[str, torch.Tensor]]]] = {} + pattern: re.Pattern[str], +) -> _ExpertLoraGroups: + grouped: _ExpertLoraGroups = {} for key, tensor in tensors.items(): - match = _ART_MOE_EXPERT_KEY_RE.match(key) + match = pattern.match(key) if match is None: continue grouped.setdefault(match.group("prefix"), {}).setdefault( @@ -740,27 +671,57 @@ def _group_art_moe_tensors( return grouped +def _expert_lora_key(prefix: str, expert: int, module: str, lora_name: str) -> str: + return f"{prefix}.{expert}.{module}.{lora_name}.weight" + + +def _convert_remaining_lora_tensors( + transformed: dict[str, torch.Tensor], + tensors: dict[str, torch.Tensor], + *, + used_keys: set[str], + convert: Callable[ + [str, torch.Tensor], + tuple[str, torch.Tensor], + ], + reject_fused_moe: bool = False, +) -> None: + for key, tensor in tensors.items(): + if key in used_keys: + continue + if reject_fused_moe and _VLLM_MOE_KEY_RE.match(key) is not None: + raise RuntimeError( + "Mixed fused and per-expert Qwen3.5 vLLM MoE LoRA tensors" + ) + converted_key, converted = convert(key, tensor) + if converted_key in transformed: + raise RuntimeError( + f"Duplicate Qwen3.5 LoRA tensor after conversion: {converted_key}" + ) + transformed[converted_key] = converted + + def _to_vllm_lora_tensors( tensors: dict[str, torch.Tensor], *, adapter_config: dict[str, Any], ) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: - grouped = _group_art_moe_tensors(tensors) + grouped = _group_expert_lora_tensors(tensors, _ART_MOE_EXPERT_KEY_RE) has_shared_experts = _has_shared_expert_lora_tensors(tensors) transformed: dict[str, torch.Tensor] = {} + convert = lambda key, tensor: _to_vllm_lora_tensor( + key, + tensor, + adapter_config=adapter_config, + ) if not grouped: has_fused_experts = any(_VLLM_MOE_KEY_RE.match(key) for key in tensors) - for key, tensor in tensors.items(): - vllm_key, tensor = _to_vllm_lora_tensor( - key, - tensor, - adapter_config=adapter_config, - ) - if vllm_key in transformed: - raise RuntimeError( - f"Duplicate Qwen3.5 LoRA tensor after conversion: {vllm_key}" - ) - transformed[vllm_key] = tensor + _convert_remaining_lora_tensors( + transformed, + tensors, + used_keys=set(), + convert=convert, + ) return transformed, ( _vllm_moe_config( adapter_config, @@ -772,17 +733,21 @@ def _to_vllm_lora_tensors( used_keys: set[str] = set() for prefix, experts in grouped.items(): vllm_prefix = _to_vllm_key(prefix) - gate_up_a: list[torch.Tensor] = [] - gate_up_b: list[torch.Tensor] = [] - down_a: list[torch.Tensor] = [] - down_b: list[torch.Tensor] = [] + blocks = { + ("gate_up_proj", "lora_A"): [], + ("gate_up_proj", "lora_B"): [], + ("down_proj", "lora_A"): [], + ("down_proj", "lora_B"): [], + } for expert in sorted(experts): modules = experts[expert] try: - gate_up_a_tensor = modules["gate_up_proj"]["lora_A"] gate_up_b_tensor = modules["gate_up_proj"]["lora_B"] - d_a = modules["down_proj"]["lora_A"] - d_b = modules["down_proj"]["lora_B"] + expert_tensors = { + (module_name, lora_name): modules[module_name][lora_name] + for module_name in _ART_MOE_MODULES + for lora_name in _LORA_NAMES + } except KeyError as exc: raise RuntimeError( f"Incomplete Qwen3.5 MoE LoRA block for {prefix}.{expert}" @@ -792,34 +757,29 @@ def _to_vllm_lora_tensors( f"{prefix}.{expert}: gate/up lora_B rows " f"{gate_up_b_tensor.shape[0]} are not even" ) - gate_up_a.append(gate_up_a_tensor.contiguous()) - gate_up_b.append(gate_up_b_tensor.contiguous()) - down_a.append(d_a.contiguous()) - down_b.append(d_b.contiguous()) - for module_name in ("gate_up_proj", "down_proj"): - for lora_name in ("lora_A", "lora_B"): - used_keys.add(f"{prefix}.{expert}.{module_name}.{lora_name}.weight") + for slot, tensor in expert_tensors.items(): + blocks[slot].append(tensor.contiguous()) + used_keys.add(_expert_lora_key(prefix, expert, *slot)) transformed[f"{vllm_prefix}.base_layer.lora_A.weight"] = torch.cat( - gate_up_a, + blocks[("gate_up_proj", "lora_A")], dim=0, ).contiguous() transformed[f"{vllm_prefix}.base_layer.lora_B.weight"] = _pack_vllm_3d_lora_b( - gate_up_b + blocks[("gate_up_proj", "lora_B")] ) transformed[f"{vllm_prefix}.lora_A.weight"] = torch.cat( - down_a, + blocks[("down_proj", "lora_A")], dim=0, ).contiguous() - transformed[f"{vllm_prefix}.lora_B.weight"] = _pack_vllm_3d_lora_b(down_b) - for key, tensor in tensors.items(): - if key in used_keys: - continue - vllm_key, tensor = _to_vllm_lora_tensor( - key, - tensor, - adapter_config=adapter_config, + transformed[f"{vllm_prefix}.lora_B.weight"] = _pack_vllm_3d_lora_b( + blocks[("down_proj", "lora_B")] ) - transformed[vllm_key] = tensor + _convert_remaining_lora_tensors( + transformed, + tensors, + used_keys=used_keys, + convert=convert, + ) return transformed, _vllm_moe_config( adapter_config, has_shared_experts=has_shared_experts, @@ -831,15 +791,12 @@ def _from_vllm_lora_tensors( *, adapter_config: dict[str, Any], ) -> dict[str, torch.Tensor]: - expert_grouped: dict[str, dict[int, dict[str, dict[str, torch.Tensor]]]] = {} - for key, tensor in tensors.items(): - match = _VLLM_MOE_EXPERT_KEY_RE.match(key) - if match is None: - continue - expert_grouped.setdefault(match.group("prefix"), {}).setdefault( - int(match.group("expert")), - {}, - ).setdefault(match.group("module"), {})[match.group("lora")] = tensor + convert = lambda key, tensor: _from_vllm_lora_tensor( + key, + tensor, + adapter_config=adapter_config, + ) + expert_grouped = _group_expert_lora_tensors(tensors, _VLLM_MOE_EXPERT_KEY_RE) if expert_grouped: transformed: dict[str, torch.Tensor] = {} used_keys: set[str] = set() @@ -847,16 +804,17 @@ def _from_vllm_lora_tensors( art_prefix = _from_vllm_key(prefix) for expert, modules in experts.items(): try: - gate_a = modules["gate_proj"]["lora_A"] - gate_b = modules["gate_proj"]["lora_B"] - up_a = modules["up_proj"]["lora_A"] - up_b = modules["up_proj"]["lora_B"] - down_a = modules["down_proj"]["lora_A"] - down_b = modules["down_proj"]["lora_B"] + expert_tensors = { + (module_name, lora_name): modules[module_name][lora_name] + for module_name in _VLLM_EXPERT_MODULES + for lora_name in _LORA_NAMES + } except KeyError as exc: raise RuntimeError( f"Incomplete Qwen3.5 vLLM MoE LoRA block for {prefix}.{expert}" ) from exc + gate_a = expert_tensors[("gate_proj", "lora_A")] + up_a = expert_tensors[("up_proj", "lora_A")] if not torch.equal(gate_a, up_a): raise RuntimeError( "Qwen3.5 Megatron gate_up_proj requires gate/up " @@ -866,32 +824,29 @@ def _from_vllm_lora_tensors( _clone(gate_a) ) transformed[f"{art_prefix}.{expert}.gate_up_proj.lora_B.weight"] = ( - torch.cat([gate_b, up_b], dim=0).contiguous() + torch.cat( + [ + expert_tensors[("gate_proj", "lora_B")], + expert_tensors[("up_proj", "lora_B")], + ], + dim=0, + ).contiguous() ) transformed[f"{art_prefix}.{expert}.down_proj.lora_A.weight"] = _clone( - down_a + expert_tensors[("down_proj", "lora_A")] ) transformed[f"{art_prefix}.{expert}.down_proj.lora_B.weight"] = _clone( - down_b - ) - for module_name in ("gate_proj", "up_proj", "down_proj"): - for lora_name in ("lora_A", "lora_B"): - used_keys.add( - f"{prefix}.{expert}.{module_name}.{lora_name}.weight" - ) - for key, tensor in tensors.items(): - if key in used_keys: - continue - if _VLLM_MOE_KEY_RE.match(key) is not None: - raise RuntimeError( - "Mixed fused and per-expert Qwen3.5 vLLM MoE LoRA tensors" + expert_tensors[("down_proj", "lora_B")] ) - art_key, tensor = _from_vllm_lora_tensor( - key, - tensor, - adapter_config=adapter_config, - ) - transformed[art_key] = tensor + for slot in expert_tensors: + used_keys.add(_expert_lora_key(prefix, expert, *slot)) + _convert_remaining_lora_tensors( + transformed, + tensors, + used_keys=used_keys, + convert=convert, + reject_fused_moe=True, + ) return transformed grouped: dict[str, dict[str, torch.Tensor]] = {} @@ -905,13 +860,12 @@ def _from_vllm_lora_tensors( grouped.setdefault(match.group("prefix"), {})[slot] = tensor if not grouped: transformed: dict[str, torch.Tensor] = {} - for key, tensor in tensors.items(): - art_key, tensor = _from_vllm_lora_tensor( - key, - tensor, - adapter_config=adapter_config, - ) - transformed[art_key] = tensor + _convert_remaining_lora_tensors( + transformed, + tensors, + used_keys=set(), + convert=convert, + ) return transformed rank = int(adapter_config["r"]) @@ -970,15 +924,12 @@ def _from_vllm_lora_tensors( f"{prefix}.lora_B.weight", } ) - for key, tensor in tensors.items(): - if key in used_keys: - continue - art_key, tensor = _from_vllm_lora_tensor( - key, - tensor, - adapter_config=adapter_config, - ) - transformed[art_key] = tensor + _convert_remaining_lora_tensors( + transformed, + tensors, + used_keys=used_keys, + convert=convert, + ) return transformed diff --git a/src/art/megatron/model_support/spec.py b/src/art/megatron/model_support/spec.py index 15c6f8d96..92c1368a2 100644 --- a/src/art/megatron/model_support/spec.py +++ b/src/art/megatron/model_support/spec.py @@ -75,6 +75,7 @@ class ModelSupportSpec(BaseModel): class ModelSupportHandler(Protocol): key: str is_moe: bool + build_gdn_execution_spec: bool native_vllm_lora_status: NativeVllmLoraStatus def identity_lora_model_config(self, base_config: Any) -> Any: ... diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index 884188a8d..f8cc0d311 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -241,10 +241,6 @@ def rollout_weights_mode(self) -> Literal["lora", "merged"]: def _vllm_base_url(self) -> str: return self._vllm_runtime.base_url - @property - def _vllm_host(self) -> str: - return self._vllm_runtime.host - @property def _vllm_port(self) -> int: return self._vllm_runtime.port diff --git a/src/art/megatron/setup.sh b/src/art/megatron/setup.sh index 6d3a5548c..3e5a1cb51 100755 --- a/src/art/megatron/setup.sh +++ b/src/art/megatron/setup.sh @@ -36,3 +36,7 @@ if [ -x "${HOME}/.local/bin/uv" ]; then uv_bin="${HOME}/.local/bin/uv" fi "${uv_bin}" sync --extra backend --extra megatron --frozen --active + +if [ "${INSTALL_VLLM_RUNTIME:-true}" = "true" ]; then + "${uv_bin}" sync --project vllm_runtime --frozen --no-dev +fi diff --git a/src/art/megatron/shared_prefix_packing.py b/src/art/megatron/shared_prefix_packing.py new file mode 100644 index 000000000..9e92f359c --- /dev/null +++ b/src/art/megatron/shared_prefix_packing.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True) +class SharedPrefixPack: + tokens: torch.Tensor + group_ids: torch.Tensor + parent_ids: torch.Tensor + position_ids: torch.Tensor + positions_by_sequence: tuple[torch.Tensor, ...] + + +@dataclass(frozen=True) +class _PrefixSegment: + sequence_indices: tuple[int, ...] + start: int + end: int + group_id: int + parent_id: int + + +def pack_shared_prefixes( + sequences: Iterable[torch.Tensor], + *, + max_depth: int, +) -> SharedPrefixPack: + """Pack token sequences by storing shared prefixes once. + + This is the small packing step that lets `TrainerRank.dp_rank_forward()` run one + model pass over a compact prefix tree instead of replaying the same prompt + tokens for every request. Think of each input sequence as a path through a + tree: when several paths start with the same tokens, this function writes + that shared segment once, then writes each branch after it. + + Args: + sequences: 1-D token tensors to pack. + max_depth: How many nested shared-prefix levels to emit. `0` disables + prefix sharing and writes each sequence as its own root segment. `1` + shares the first common segment in each branch; larger values allow + branches to contain shared sub-branches. + + Returns: + `tokens` is the compact model input, shaped `[1, packed_length]`. + `group_ids` and `parent_ids` describe the prefix tree to shared-prefix + attention. Positions in the same emitted segment share a group, and each + group points at the parent segment it continues from. Root groups point + to themselves. + `position_ids` keeps each token's original sequence position for + positional embeddings/rotary attention. + `positions_by_sequence` is the reverse index used after the model call + to unpack logits, logprobs, or hidden states back into one tensor per + original request. + + The implementation is a tiny radix-tree walk. It finds the longest prefix + shared by the active sequences, emits that segment once, then partitions the + remaining sequences by their next token while preserving first-seen order. + Single sequences, empty branches, and branches past `max_depth` are emitted + as ordinary unshared tails. + """ + if max_depth < 0: + raise ValueError("max_depth must be >= 0") + + tensors = tuple(_sequence_tensor(sequence) for sequence in sequences) + if not tensors: + return _empty_pack() + + device = tensors[0].device + rows = tuple(tensor.detach().cpu().tolist() for tensor in tensors) + segments = _prefix_segments(rows, max_depth=max_depth) + if not segments: + return _empty_pack(len(tensors), device=device) + + token_chunks: list[torch.Tensor] = [] + group_chunks: list[torch.Tensor] = [] + parent_chunks: list[torch.Tensor] = [] + position_chunks: list[torch.Tensor] = [] + positions_by_sequence: list[list[torch.Tensor]] = [[] for _ in tensors] + cursor = 0 + + for planned in segments: + segment = tensors[planned.sequence_indices[0]][planned.start : planned.end] + packed_positions = torch.arange(cursor, cursor + len(segment), device=device) + token_chunks.append(segment) + group_chunks.append(torch.full_like(segment, planned.group_id)) + parent_chunks.append(torch.full_like(segment, planned.parent_id)) + position_chunks.append(torch.arange(planned.start, planned.end, device=device)) + for sequence_index in planned.sequence_indices: + positions_by_sequence[sequence_index].append(packed_positions) + cursor += len(segment) + + return SharedPrefixPack( + tokens=torch.cat(token_chunks).unsqueeze(0), + group_ids=torch.cat(group_chunks).unsqueeze(0), + parent_ids=torch.cat(parent_chunks).unsqueeze(0), + position_ids=torch.cat(position_chunks).unsqueeze(0), + positions_by_sequence=tuple( + torch.cat(chunks) + if chunks + else torch.empty(0, dtype=torch.long, device=device) + for chunks in positions_by_sequence + ), + ) + + +def estimate_shared_prefix_packed_tokens( + sequences: Iterable[torch.Tensor], + *, + max_depth: int, +) -> int | None: + """Return the exact packed token count without building a packed batch. + + The estimator intentionally only handles CPU tensors. For CUDA tensors, many + tiny prefix probes would launch many tiny kernels, so callers should fall + back to full packing instead. + """ + if max_depth < 0: + raise ValueError("max_depth must be >= 0") + + rows: list[list[int]] = [] + for sequence in sequences: + tensor = _sequence_tensor(sequence) + if tensor.device.type != "cpu": + return None + rows.append(tensor.tolist()) + + return sum( + segment.end - segment.start + for segment in _prefix_segments(tuple(rows), max_depth=max_depth) + ) + + +def _local_position_pairs( + local_global_positions: torch.Tensor, + item_positions: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + flat = local_global_positions.reshape(-1).to(device=item_positions.device) + local_positions = torch.nonzero(flat >= 0, as_tuple=False).reshape(-1) + global_positions = flat.index_select(0, local_positions) + source_offsets, local_offsets = _matching_offsets(item_positions, global_positions) + return ( + local_positions.index_select(0, local_offsets).to("cpu"), + source_offsets.to("cpu"), + ) + + +def _prefix_segments( + rows: tuple[list[int], ...], + *, + max_depth: int, +) -> tuple[_PrefixSegment, ...]: + lengths = tuple(len(row) for row in rows) + segments: list[_PrefixSegment] = [] + next_group_id = 1 + + def emit( + indices: tuple[int, ...], + start: int, + end: int, + parent_group_id: int | None, + ) -> int: + nonlocal next_group_id + group_id = next_group_id + next_group_id += 1 + segments.append( + _PrefixSegment( + sequence_indices=indices, + start=start, + end=end, + group_id=group_id, + parent_id=group_id if parent_group_id is None else parent_group_id, + ) + ) + return group_id + + def shared_end(indices: tuple[int, ...], start: int) -> int: + end = min(lengths[index] for index in indices) + low = high = rows[indices[0]] + for index in indices[1:]: + row = rows[index] + if row < low: + low = row + elif row > high: + high = row + while start < end: + if low[start] != high[start]: + break + start += 1 + return start + + def branch_groups(indices: tuple[int, ...], start: int) -> list[tuple[int, ...]]: + groups: dict[int, list[int]] = {} + order: list[int] = [] + for index in indices: + token = rows[index][start] + if token not in groups: + groups[token] = [] + order.append(token) + groups[token].append(index) + return [tuple(groups[token]) for token in order] + + def walk( + indices: tuple[int, ...], + start: int, + parent_group_id: int | None, + depth: int, + ) -> None: + active = tuple(index for index in indices if lengths[index] > start) + if not active: + return + if ( + max_depth == 0 + or len(active) == 1 + or (parent_group_id is not None and depth >= max_depth) + ): + for index in active: + emit((index,), start, lengths[index], parent_group_id) + return + + end = shared_end(active, start) + if end > start: + walk(active, end, emit(active, start, end, parent_group_id), depth + 1) + return + for group in branch_groups(active, start): + walk(group, start, parent_group_id, depth) + + walk(tuple(range(len(rows))), 0, None, 0) + return tuple(segments) + + +def _empty_pack( + sequence_count: int = 0, + *, + device: torch.device | None = None, +) -> SharedPrefixPack: + flat = torch.empty(0, dtype=torch.long, device=device) + row = flat.unsqueeze(0) + return SharedPrefixPack( + tokens=row, + group_ids=row, + parent_ids=row, + position_ids=row, + positions_by_sequence=tuple(flat for _ in range(sequence_count)), + ) + + +def _sequence_tensor(tensor: torch.Tensor) -> torch.Tensor: + if tensor.ndim != 1: + raise ValueError( + f"pack_shared_prefixes expects 1-D tensors, got {tuple(tensor.shape)}" + ) + return tensor.detach().to(dtype=torch.long).contiguous() + + +def _matching_offsets( + positions: torch.Tensor, + chunk_rows: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + if int(positions.numel()) == 0 or int(chunk_rows.numel()) == 0: + empty = torch.empty(0, dtype=torch.long, device=positions.device) + return empty, empty + sorted_rows, order = chunk_rows.sort() + indices = torch.searchsorted(sorted_rows, positions) + in_bounds = indices < int(sorted_rows.numel()) + source_offsets = torch.arange( + int(positions.numel()), + device=positions.device, + dtype=torch.long, + )[in_bounds] + found = indices[in_bounds] + keep = sorted_rows.index_select(0, found) == positions.index_select( + 0, + source_offsets, + ) + return source_offsets[keep], order.index_select(0, found[keep]) diff --git a/src/art/megatron/shared_prefix_state.py b/src/art/megatron/shared_prefix_state.py index 7bbda4624..f3c1565b1 100644 --- a/src/art/megatron/shared_prefix_state.py +++ b/src/art/megatron/shared_prefix_state.py @@ -2,6 +2,7 @@ from __future__ import annotations +from dataclasses import replace import gc from typing import Any @@ -10,7 +11,10 @@ from torch import Tensor from torch.nn.attention.flex_attention import BlockMask -from art.megatron.context_parallel.block_mask import build_block_mask +from art.megatron.context_parallel.block_mask import ( + build_block_mask_from_context, + prepare_block_mask_context, +) from art.megatron.context_parallel.builder import build_shared_prefix_attention_spec from art.megatron.context_parallel.layout_index import TokenLayoutIndex from art.megatron.context_parallel.types import ( @@ -75,14 +79,11 @@ def create_shared_prefix_state( attention_value_head_dim=attention_value_head_dim, ), ) - cp_rank, cp_size, cp_group = _gdn_cp_rank_size_group() - gdn_execution_spec = _build_gdn_execution_spec_once( - group_ids_cpu, - parent_ids_cpu, - build=build_gdn_execution_spec, - cp_rank=cp_rank, - cp_size=cp_size, - cp_group=cp_group, + cp_rank, cp_size = _gdn_cp_rank_size() + gdn_execution_spec = ( + parse_gdn_shared_prefix_segments(group_ids_cpu, parent_ids_cpu) + if build_gdn_execution_spec + else None ) return SharedPrefixAttentionState( block_mask=block_mask, @@ -94,7 +95,6 @@ def create_shared_prefix_state( device=device, cp_rank=cp_rank, cp_size=cp_size, - cp_group=cp_group, attention_token_layout_index=attention_token_layout_index, ), ) @@ -118,53 +118,105 @@ def _build_sparse_shared_prefix_block_mask( group_ids=group_ids_cpu, parent_ids=parent_ids_cpu, ) - row_spec = batch_spec.rows[0] seq_len = int(group_ids_cpu.shape[1]) - slices = _full_row_slices_with_padding( - row_slices=row_spec.slices, - valid_tokens=int(row_spec.valid_tokens), - seq_len=seq_len, - ) - if not slices: + row_masks = [] + token_indices = torch.arange(seq_len, dtype=torch.int64) + for row_spec in batch_spec.rows: + row_index = int(row_spec.row_index) + slices = tuple(replace(slice_, row_index=0) for slice_ in row_spec.slices) + if int(row_spec.valid_tokens) < seq_len: + padding_range = TokenRange(start=int(row_spec.valid_tokens), end=seq_len) + slices = ( + *slices, + AttnSlice( + q_range=padding_range, + k_range=padding_range, + mask_kind=AttnMaskKind.CAUSAL, + row_index=0, + family_index=None, + ), + ) + if not slices: + row_masks.append( + _empty_block_mask(seq_len=seq_len, block_size=block_size, device=device) + ) + continue + row_masks.append( + build_block_mask_from_context( + FlexMaskSpec( + q_len=seq_len, + k_len=seq_len, + block_size=block_size, + slices=slices, + exact_mask=ExactMaskMetadata( + q_token_indices=token_indices, + k_token_indices=token_indices, + cache_key=f"identity:{seq_len}", + ), + ), + context=prepare_block_mask_context( + group_ids=group_ids_cpu[row_index], + parent_ids=parent_ids_cpu[row_index], + ), + device=device, + ) + ) + if not row_masks: return _empty_block_mask(seq_len=seq_len, block_size=block_size, device=device) - return build_block_mask( - FlexMaskSpec( - q_len=seq_len, - k_len=seq_len, - block_size=block_size, - slices=slices, - exact_mask=ExactMaskMetadata( - q_token_indices=torch.arange(seq_len, dtype=torch.int64), - k_token_indices=torch.arange(seq_len, dtype=torch.int64), - cache_key=f"identity:{seq_len}", - ), - ), - group_ids=group_ids_cpu[0], - parent_ids=parent_ids_cpu[0], - device=device, + return _stack_row_block_masks( + row_masks, + seq_len=seq_len, + block_size=block_size, ) -def _full_row_slices_with_padding( +def _stack_optional_block_tensors( + masks: list[BlockMask], + name: str, +) -> Tensor | None: + tensors = [getattr(mask, name) for mask in masks] + if any(tensor is None for tensor in tensors): + return None + return torch.cat(tensors, dim=0) + + +def _stack_row_block_masks( + masks: list[BlockMask], *, - row_slices: tuple[AttnSlice, ...], - valid_tokens: int, seq_len: int, -) -> tuple[AttnSlice, ...]: - if valid_tokens >= seq_len: - return row_slices - padding_range = TokenRange(start=int(valid_tokens), end=int(seq_len)) - if padding_range.is_empty(): - return row_slices - return ( - *row_slices, - AttnSlice( - q_range=padding_range, - k_range=padding_range, - mask_kind=AttnMaskKind.CAUSAL, - row_index=0, - family_index=None, - ), + block_size: tuple[int, int], +) -> BlockMask: + if len(masks) == 1: + return masks[0] + row_mask_mods = tuple(mask.mask_mod for mask in masks) + + def mask_mod( + batch_idx: Tensor, + head_idx: Tensor, + query_idx: Tensor, + kv_idx: Tensor, + ) -> Tensor: + result = torch.zeros_like(query_idx, dtype=torch.bool) + for row_index, row_mask_mod in enumerate(row_mask_mods): + result = torch.where( + batch_idx == row_index, + row_mask_mod(batch_idx, head_idx, query_idx, kv_idx), + result, + ) + return result + + return BlockMask( + seq_lengths=(int(seq_len), int(seq_len)), + kv_num_blocks=torch.cat([mask.kv_num_blocks for mask in masks], dim=0), + kv_indices=torch.cat([mask.kv_indices for mask in masks], dim=0), + full_kv_num_blocks=_stack_optional_block_tensors(masks, "full_kv_num_blocks"), + full_kv_indices=_stack_optional_block_tensors(masks, "full_kv_indices"), + q_num_blocks=_stack_optional_block_tensors(masks, "q_num_blocks"), + q_indices=_stack_optional_block_tensors(masks, "q_indices"), + full_q_num_blocks=_stack_optional_block_tensors(masks, "full_q_num_blocks"), + full_q_indices=_stack_optional_block_tensors(masks, "full_q_indices"), + BLOCK_SIZE=block_size, + mask_mod=mask_mod, ) @@ -223,45 +275,17 @@ def _shared_prefix_block_size( ) -def _build_gdn_execution_spec_once( - group_ids: Tensor, - parent_ids: Tensor, - *, - build: bool, - cp_rank: int, - cp_size: int, - cp_group: Any | None, -) -> GdnPackedExecutionSpec | None: - if not build: - return None - if cp_size == 1: - return parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) - if ( - not torch.distributed.is_available() or not torch.distributed.is_initialized() # ty: ignore[possibly-missing-attribute] - ): - return parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) - return parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) - - def _build_gdn_execution_plan_once( spec: GdnPackedExecutionSpec | None, *, device: torch.device, cp_rank: int, cp_size: int, - cp_group: Any | None, attention_token_layout_index: TokenLayoutIndex | None, ) -> GdnRankExecutionPlan | None: if spec is None: return None planner_device = torch.device("cpu") if device.type == "cuda" else device - del cp_group gc_was_enabled = gc.isenabled() if gc_was_enabled: gc.disable() @@ -279,7 +303,7 @@ def _build_gdn_execution_plan_once( return move_gdn_rank_execution_plan_to_device(plan, device) -def _gdn_cp_rank_size_group() -> tuple[int, int, Any | None]: +def _gdn_cp_rank_size() -> tuple[int, int]: try: from megatron.core import parallel_state as ps @@ -287,8 +311,7 @@ def _gdn_cp_rank_size_group() -> tuple[int, int, Any | None]: return ( int(ps.get_context_parallel_rank()), int(ps.get_context_parallel_world_size()), - ps.get_context_parallel_group(), ) except Exception: pass - return 0, 1, None + return 0, 1 diff --git a/src/art/megatron/shared_prefix_tree.py b/src/art/megatron/shared_prefix_tree.py new file mode 100644 index 000000000..63cdb0f07 --- /dev/null +++ b/src/art/megatron/shared_prefix_tree.py @@ -0,0 +1,234 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True, slots=True) +class SharedPrefixSegment: + group_id: int + parent_id: int + start: int + end: int + family_index: int + ancestors: tuple[int, ...] + + @property + def depth(self) -> int: + return len(self.ancestors) + + +@dataclass(frozen=True, slots=True) +class SharedPrefixRowTree: + row_index: int + valid_tokens: int + segments: tuple[SharedPrefixSegment, ...] + + +def parse_shared_prefix_tree( + *, + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + ignore_padding_group_id: int = -1, +) -> tuple[SharedPrefixRowTree, ...]: + if group_ids.shape != parent_ids.shape: + raise RuntimeError( + "group_ids and parent_ids must share shape, got " + f"{tuple(group_ids.shape)} vs {tuple(parent_ids.shape)}" + ) + if group_ids.ndim != 2: + raise RuntimeError( + "group_ids and parent_ids must be rank-2 packed tensors, got " + f"{group_ids.ndim}" + ) + return tuple( + parse_shared_prefix_row( + group_ids=group_ids[row_index], + parent_ids=parent_ids[row_index], + row_index=row_index, + ignore_padding_group_id=ignore_padding_group_id, + ) + for row_index in range(int(group_ids.shape[0])) + ) + + +def parse_shared_prefix_row( + *, + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + row_index: int = 0, + ignore_padding_group_id: int = -1, +) -> SharedPrefixRowTree: + if group_ids.shape != parent_ids.shape: + raise RuntimeError( + "group_ids and parent_ids must share shape, got " + f"{tuple(group_ids.shape)} vs {tuple(parent_ids.shape)}" + ) + if group_ids.ndim != 1: + raise RuntimeError( + f"group_ids and parent_ids must be rank-1 row tensors, got {group_ids.ndim}" + ) + + valid_tokens = _valid_length( + group_ids, + parent_ids, + ignore_padding_group_id=ignore_padding_group_id, + ) + if valid_tokens == 0: + return SharedPrefixRowTree(row_index=row_index, valid_tokens=0, segments=()) + + runs = _scan_runs(group_ids[:valid_tokens], parent_ids[:valid_tokens]) + first_segment_by_group: dict[int, SharedPrefixSegment] = {} + family_by_group: dict[int, int] = {} + ancestors_by_group: dict[int, tuple[int, ...]] = {} + segments: list[SharedPrefixSegment] = [] + next_family_index = 0 + + seen_groups: set[int] = set() + repeated_groups: dict[int, int] = {} + for _start, _end, group_id, _parent_id in runs: + if group_id in seen_groups and group_id != ignore_padding_group_id: + repeated_groups[group_id] = repeated_groups.get(group_id, 1) + 1 + seen_groups.add(group_id) + if repeated_groups: + raise RuntimeError( + "Shared-prefix metadata requires contiguous group runs per row, " + f"found repeats in row {row_index}: {repeated_groups}" + ) + + for start, end, group_id, parent_id in runs: + is_root = group_id == parent_id or ( + start == 0 and parent_id == ignore_padding_group_id + ) + if is_root: + family_index = next_family_index + next_family_index += 1 + ancestors: tuple[int, ...] = () + else: + parent_segment = first_segment_by_group.get(parent_id) + if parent_segment is None: + raise RuntimeError( + "Shared-prefix run points to a missing parent run: " + f"row={row_index}, group_id={group_id}, parent_id={parent_id}" + ) + if int(parent_segment.end) > int(start): + raise RuntimeError( + "Shared-prefix parent run must end before its child starts: " + f"row={row_index}, group_id={group_id}, parent_id={parent_id}" + ) + family_index = family_by_group[parent_id] + ancestors = (*ancestors_by_group[parent_id], parent_id) + + segment = SharedPrefixSegment( + group_id=group_id, + parent_id=parent_id, + start=start, + end=end, + family_index=family_index, + ancestors=ancestors, + ) + first_segment_by_group[group_id] = segment + family_by_group[group_id] = family_index + ancestors_by_group[group_id] = ancestors + segments.append(segment) + + return SharedPrefixRowTree( + row_index=row_index, + valid_tokens=valid_tokens, + segments=tuple(segments), + ) + + +def _valid_length( + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + *, + ignore_padding_group_id: int, +) -> int: + valid_mask = group_ids != ignore_padding_group_id + valid_count = int(valid_mask.sum().item()) + if valid_count == 0: + return 0 + if not bool(valid_mask[:valid_count].all().item()): + raise RuntimeError("Padding tokens must be a contiguous tail") + return _infer_terminal_padding_length( + group_ids[:valid_count], + parent_ids[:valid_count], + ) + + +def _infer_terminal_padding_length( + group_row: torch.Tensor, + parent_row: torch.Tensor, +) -> int: + if group_row.numel() == 0: + return 0 + runs = _scan_runs(group_row, parent_row) + if len(runs) < 2: + return int(group_row.numel()) + last_start, _last_end, last_group_id, last_parent_id = runs[-1] + if last_parent_id >= 0: + return int(group_row.numel()) + terminal_pair = (last_group_id, last_parent_id) + if any( + (group_id, parent_id) == terminal_pair + for _start, _end, group_id, parent_id in runs[:-1] + ): + return last_start + return int(group_row.numel()) + + +def _scan_runs( + group_row: torch.Tensor, + parent_row: torch.Tensor, +) -> list[tuple[int, int, int, int]]: + length = int(group_row.numel()) + if length == 0: + return [] + + group_changes = group_row[1:] != group_row[:-1] + parent_changes = parent_row[1:] != parent_row[:-1] + inconsistent_parent = torch.nonzero( + torch.logical_not(group_changes) & parent_changes, + as_tuple=False, + ).flatten() + if int(inconsistent_parent.numel()) > 0: + mismatch_index = int(inconsistent_parent[0].item()) + 1 + prior_boundaries = torch.nonzero( + group_changes[: mismatch_index - 1], + as_tuple=False, + ).flatten() + start = ( + 0 + if int(prior_boundaries.numel()) == 0 + else int(prior_boundaries[-1].item()) + 1 + ) + group_id = int(group_row[start].item()) + raise RuntimeError( + "Found one group run with inconsistent parent ids: " + f"group_id={group_id}, start={start}, end={mismatch_index}" + ) + + run_starts = torch.cat( + ( + torch.zeros(1, dtype=torch.int64, device=group_row.device), + torch.nonzero(group_changes, as_tuple=False).flatten() + 1, + ) + ) + run_ends = torch.cat( + ( + run_starts[1:], + torch.tensor([length], dtype=torch.int64, device=group_row.device), + ) + ) + starts = run_starts.to(device="cpu").tolist() + ends = run_ends.to(device="cpu").tolist() + group_ids = group_row.index_select(0, run_starts).to(device="cpu").tolist() + parent_ids = parent_row.index_select(0, run_starts).to(device="cpu").tolist() + return [ + (int(start), int(end), int(group_id), int(parent_id)) + for start, end, group_id, parent_id in zip( + starts, ends, group_ids, parent_ids, strict=True + ) + ] diff --git a/src/art/megatron/training/finalize_grads.py b/src/art/megatron/training/finalize_grads.py index cde0e7b06..e00cd8218 100644 --- a/src/art/megatron/training/finalize_grads.py +++ b/src/art/megatron/training/finalize_grads.py @@ -1,4 +1,3 @@ -from collections import defaultdict from collections.abc import Iterable from typing import Any, Literal, cast @@ -16,7 +15,6 @@ GRAD_SYNC_OP_NONE: GradSyncOp = "none" GRAD_SYNC_OP_SUM: GradSyncOp = "sum" GRAD_SYNC_OP_AVG: GradSyncOp = "avg" -VALID_DOMAINS = (TP_DEFAULT_GRAD_SYNC_DOMAIN, EXPERT_TP_GRAD_SYNC_DOMAIN) VALID_SYNC_OPS = (GRAD_SYNC_OP_NONE, GRAD_SYNC_OP_SUM, GRAD_SYNC_OP_AVG) @@ -28,6 +26,8 @@ def _iter_named_trainable_parameters( for name, param in model_chunk.named_parameters(): if not param.requires_grad: continue + if getattr(param, "_art_dynamic_lora_slot", False): + continue param_id = id(param) if param_id in seen: continue @@ -60,6 +60,48 @@ def _resolve_reduce_op(op: GradSyncOp) -> Any: raise RuntimeError(f"Unknown grad sync op: {op}") +def tensor_parallel_grad_sync( + param: torch.nn.Parameter, + *, + name: str, +) -> tuple[Any, Any] | None: + domain: GradSyncDomain = getattr( + param, "grad_sync_domain", TP_DEFAULT_GRAD_SYNC_DOMAIN + ) + group = _resolve_domain_group(domain) + if group is None: + return None + op: GradSyncOp = getattr(param, "grad_sync_op", GRAD_SYNC_OP_NONE) + if op not in VALID_SYNC_OPS: + raise RuntimeError(f"{name}: unsupported grad_sync_op={op}") + if op == GRAD_SYNC_OP_NONE: + return None + return group, _resolve_reduce_op(op) + + +def coalesced_all_reduce( + grads: list[torch.Tensor], + *, + group: Any, + op: Any, +) -> None: + coalesced = _flatten_dense_tensors(grads) + reduced = ( + coalesced.float() + if torch.is_floating_point(coalesced) and coalesced.dtype != torch.float32 + else coalesced + ) + torch.distributed.all_reduce( # ty: ignore[possibly-missing-attribute] + reduced, + op=op, + group=group, + ) + if reduced is not coalesced: + reduced = reduced.to(dtype=coalesced.dtype) + for grad, synced in zip(grads, _unflatten_dense_tensors(reduced, grads)): + grad.copy_(synced) + + def flush_param_grads_to_main_grads(model_chunks: Iterable[torch.nn.Module]) -> None: """Fallback for direct jobs when DDP post-hooks leave grads in param.grad. @@ -100,57 +142,30 @@ def finalize_model_grads_extended( ) buckets: dict[ - tuple[GradSyncDomain, GradSyncOp, torch.dtype, torch.device], - list[tuple[str, torch.Tensor]], - ] = defaultdict(list) + tuple[int, str, torch.dtype, torch.device], + tuple[Any, Any, list[torch.Tensor]], + ] = {} for name, param in _iter_named_trainable_parameters(model): - domain: GradSyncDomain = getattr( - param, "grad_sync_domain", TP_DEFAULT_GRAD_SYNC_DOMAIN - ) - if _resolve_domain_group(domain) is None: - continue - - op: GradSyncOp = getattr(param, "grad_sync_op", GRAD_SYNC_OP_NONE) - if op not in VALID_SYNC_OPS: - raise RuntimeError(f"{name}: unsupported grad_sync_op={op}") - if op == GRAD_SYNC_OP_NONE: + sync = tensor_parallel_grad_sync(param, name=name) + if sync is None: continue if not hasattr(param, "main_grad"): raise RuntimeError( - f"{name}: expected main_grad for domain={domain} reduce_op={op}, but attribute is missing" + f"{name}: expected main_grad for tensor-parallel grad sync, but attribute is missing" ) grad = param.main_grad if grad is None: raise RuntimeError( - f"{name}: expected non-None main_grad for domain={domain} reduce_op={op}" + f"{name}: expected non-None main_grad for tensor-parallel grad sync" ) local_grad = cast( # local part of dtensor torch.Tensor, grad._local_tensor if hasattr(grad, "_local_tensor") else grad ) - buckets[(domain, op, local_grad.dtype, local_grad.device)].append( - (name, local_grad) - ) + group, reduce_op = sync + key = (id(group), str(reduce_op), local_grad.dtype, local_grad.device) + buckets.setdefault(key, (group, reduce_op, []))[2].append(local_grad) - for (domain, op, _dtype, _device), entries in buckets.items(): - group = _resolve_domain_group( - domain - ) # already checked if the domain is one we are handling - - grads = [grad for _name, grad in entries] - coalesced = _flatten_dense_tensors(grads) - reduced = ( - coalesced.float() - if torch.is_floating_point(coalesced) and coalesced.dtype != torch.float32 - else coalesced - ) - torch.distributed.all_reduce( # ty: ignore[possibly-missing-attribute] - reduced, - op=_resolve_reduce_op(op), - group=group, - ) - if reduced is not coalesced: - reduced = reduced.to(dtype=coalesced.dtype) - for grad, synced in zip(grads, _unflatten_dense_tensors(reduced, grads)): - grad.copy_(synced) + for group, op, grads in buckets.values(): + coalesced_all_reduce(grads, group=group, op=op) diff --git a/src/art/megatron/weights/adapter_export.py b/src/art/megatron/weights/adapter_export.py index cce081188..76c545bda 100644 --- a/src/art/megatron/weights/adapter_export.py +++ b/src/art/megatron/weights/adapter_export.py @@ -1,3 +1,4 @@ +from collections.abc import Callable, Sequence import math from typing import Any @@ -9,31 +10,15 @@ from art.megatron.lora import ( GatedDeltaNetInProjLoRA, LoRA, - MLPExpertsLinearFC1FusedLoRA, MLPExpertsLinearFC1LoRA, MLPExpertsLinearFC2LoRA, SelfAttentionLinearProjLoRA, SelfAttentionLinearQKVLoRA, SharedExpertsLinearFC1LoRA, - SharedExpertsLinearFC2LoRA, ) from art.megatron.weights.param_name_canonicalization import canonical_art_param_name -def layer_base_prefix( - module: TransformerLayer, - *, - module_name: str | None = None, -) -> str: - if module_name is not None: - canonical_name = canonical_art_param_name(module_name) - if canonical_name.startswith( - ("decoder.layers.", "language_model.decoder.layers.") - ): - return canonical_name - return f"language_model.decoder.layers.{module.layer_number - 1}" - - def _adapter_alpha_dim(lora: LoRA) -> tuple[int, int]: dim = int(lora.A_T.shape[-1]) alpha = float(lora.scale) * dim @@ -51,12 +36,6 @@ def _adapter_tensors( return a_t.transpose(-1, -2).contiguous(), b_t.transpose(-1, -2).contiguous() -def _adapter_param_prefix(base_prefix: str, adapter_key: str | None) -> str: - if adapter_key is None: - return f"{base_prefix}.adapter" - return f"{base_prefix}.adapter.{adapter_key}" - - def _adapter_weight( *, base_prefix: str, @@ -66,7 +45,8 @@ def _adapter_weight( linear_in: torch.Tensor, linear_out: torch.Tensor, ) -> AdapterWeight: - param_prefix = _adapter_param_prefix(base_prefix, adapter_key) + adapter_suffix = "" if adapter_key is None else f".{adapter_key}" + param_prefix = f"{base_prefix}.adapter{adapter_suffix}" return AdapterWeight( global_base_prefix=base_prefix, adapter_key=adapter_key, @@ -162,83 +142,174 @@ def _fused_pair_adapter_weight( ) -def add_standard_self_attention_adapter_weights( +def _set_adapter_weights( + out: dict[str, list[Any]], + base_prefix: str, + *weights: AdapterWeight, + weight_suffix: str = ".weight", +) -> None: + out[f"{base_prefix}{weight_suffix}"] = list(weights) + + +def _set_expert_adapter_weights( + out: dict[str, list[Any]], + base_prefix: str, + lora: LoRA, + build_weight: Callable[[int], AdapterWeight], +) -> None: + for local_expert_idx in range(lora.num_local_experts): + global_expert_idx = local_expert_idx + lora._expert_offset + _set_adapter_weights( + out, + base_prefix, + build_weight(local_expert_idx), + weight_suffix=f".weight{global_expert_idx}", + ) + + +def _set_lora_weights( + out: dict[str, list[Any]], + base_prefix: str, + *items: tuple[LoRA, str | None], +) -> None: + _set_adapter_weights( + out, + base_prefix, + *( + _simple_adapter_weight(base_prefix, lora, adapter_key=adapter_key) + for lora, adapter_key in items + ), + ) + + +def build_transformer_layer_adapter_weights( + model_chunks: Sequence[Any], + grouped_moe: bool = False, + language_layers_only: bool = False, +) -> dict[str, list[Any]]: + layer_filter = None + if language_layers_only: + from art.megatron.lora import ( + _is_language_transformer_layer_name as layer_filter, + ) + + add_mlp_adapter_weights = ( + _add_moe_mlp_adapter_weights_for_layer + if grouped_moe + else _add_dense_mlp_adapter_weights_for_layer + ) + adapter_weights_by_base: dict[str, list[Any]] = {} + for chunk in model_chunks: + for module_name, module in chunk.named_modules(): + if not isinstance(module, TransformerLayer): + continue + if layer_filter is not None and not layer_filter(module_name): + continue + canonical_name = canonical_art_param_name(module_name) + layer_prefix = ( + canonical_name + if canonical_name.startswith( + ("decoder.layers.", "language_model.decoder.layers.") + ) + else f"language_model.decoder.layers.{module.layer_number - 1}" + ) + add_self_attention_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + self_attention=module.self_attention, + ) + add_mlp_adapter_weights(adapter_weights_by_base, layer_prefix, module) + return adapter_weights_by_base + + +def add_self_attention_adapter_weights( adapter_weights_by_base: dict[str, list[Any]], *, layer_prefix: str, self_attention: Any, ) -> None: - linear_proj = getattr(self_attention, "linear_proj", None) - if isinstance(linear_proj, SelfAttentionLinearProjLoRA): - base_prefix = f"{layer_prefix}.self_attention.linear_proj" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight(base_prefix, linear_proj.lora) - ] + for attr in ("linear_proj", "out_proj"): + linear_proj = getattr(self_attention, attr, None) + if isinstance(linear_proj, SelfAttentionLinearProjLoRA): + base_prefix = f"{layer_prefix}.self_attention.{attr}" + _set_lora_weights( + adapter_weights_by_base, + base_prefix, + (linear_proj.lora, None), + ) linear_qkv = getattr(self_attention, "linear_qkv", None) if isinstance(linear_qkv, SelfAttentionLinearQKVLoRA): base_prefix = f"{layer_prefix}.self_attention.linear_qkv" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _set_lora_weights( + adapter_weights_by_base, + base_prefix, + (linear_qkv.q_proj_lora, "adapter_q"), + (linear_qkv.k_proj_lora, "adapter_k"), + (linear_qkv.v_proj_lora, "adapter_v"), + ) + + in_proj = getattr(self_attention, "in_proj", None) + if isinstance(in_proj, GatedDeltaNetInProjLoRA): + base_prefix = f"{layer_prefix}.self_attention.in_proj" + input_dim = int(in_proj.qkv_lora.A_T.shape[-2]) + output_dim = int(in_proj.num_value_heads_per_partition) + _set_adapter_weights( + adapter_weights_by_base, + base_prefix, _simple_adapter_weight( - base_prefix, - linear_qkv.q_proj_lora, - adapter_key="adapter_q", + base_prefix, in_proj.qkv_lora, adapter_key="adapter_qkv" ), _simple_adapter_weight( - base_prefix, - linear_qkv.k_proj_lora, - adapter_key="adapter_k", + base_prefix, in_proj.z_lora, adapter_key="adapter_z" ), - _simple_adapter_weight( - base_prefix, - linear_qkv.v_proj_lora, - adapter_key="adapter_v", + *( + _zero_adapter_weight( + base_prefix=base_prefix, + adapter_key=adapter_key, + input_dim=input_dim, + output_dim=output_dim, + like=in_proj.qkv_lora.B_T, + ) + for adapter_key in ("adapter_b", "adapter_a") ), - ] + ) -def add_gated_delta_net_adapter_weights( +def _add_dense_mlp_adapter_weights_for_layer( adapter_weights_by_base: dict[str, list[Any]], - *, layer_prefix: str, - self_attention: Any, + module: Any, ) -> None: - out_proj = getattr(self_attention, "out_proj", None) - if isinstance(out_proj, SelfAttentionLinearProjLoRA): - base_prefix = f"{layer_prefix}.self_attention.out_proj" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight(base_prefix, out_proj.lora) - ] + from art.megatron.model_support.handlers.default_dense import _require_dense_mlp + + _require_dense_mlp(module) + add_split_mlp_adapter_weights( + adapter_weights_by_base, + f"{layer_prefix}.mlp", + module.mlp, + ) - in_proj = getattr(self_attention, "in_proj", None) - if isinstance(in_proj, GatedDeltaNetInProjLoRA): - base_prefix = f"{layer_prefix}.self_attention.in_proj" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight( - base_prefix, - in_proj.qkv_lora, - adapter_key="adapter_qkv", - ), - _simple_adapter_weight( - base_prefix, - in_proj.z_lora, - adapter_key="adapter_z", - ), - _zero_adapter_weight( - base_prefix=base_prefix, - adapter_key="adapter_b", - input_dim=int(in_proj.qkv_lora.A_T.shape[-2]), - output_dim=int(in_proj.num_value_heads_per_partition), - like=in_proj.qkv_lora.B_T, - ), - _zero_adapter_weight( - base_prefix=base_prefix, - adapter_key="adapter_a", - input_dim=int(in_proj.qkv_lora.A_T.shape[-2]), - output_dim=int(in_proj.num_value_heads_per_partition), - like=in_proj.qkv_lora.B_T, - ), - ] + +def _add_moe_mlp_adapter_weights_for_layer( + adapter_weights_by_base: dict[str, list[Any]], + layer_prefix: str, + module: Any, +) -> None: + from art.megatron.model_support.handlers.default_dense import _require_moe_experts + + add_grouped_moe_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + experts=_require_moe_experts(module), + ) + shared_experts = getattr(module.mlp, "shared_experts", None) + if shared_experts is not None: + add_split_mlp_adapter_weights( + adapter_weights_by_base, + f"{layer_prefix}.mlp.shared_experts", + shared_experts, + ) def add_grouped_moe_adapter_weights( @@ -248,100 +319,66 @@ def add_grouped_moe_adapter_weights( experts: Any, ) -> None: linear_fc1 = getattr(experts, "linear_fc1", None) - if isinstance(linear_fc1, MLPExpertsLinearFC1FusedLoRA): - base_prefix = f"{layer_prefix}.mlp.experts.linear_fc1" - for local_expert_idx in range(linear_fc1.lora.num_local_experts): - global_expert_idx = local_expert_idx + linear_fc1.lora._expert_offset - adapter_weights_by_base[f"{base_prefix}.weight{global_expert_idx}"] = [ - _simple_adapter_weight( - base_prefix, - linear_fc1.lora, - expert_idx=local_expert_idx, - ) - ] - elif isinstance(linear_fc1, MLPExpertsLinearFC1LoRA): - base_prefix = f"{layer_prefix}.mlp.experts.linear_fc1" - for local_expert_idx in range(linear_fc1.gate_lora.num_local_experts): - global_expert_idx = local_expert_idx + linear_fc1.gate_lora._expert_offset - adapter_weights_by_base[f"{base_prefix}.weight{global_expert_idx}"] = [ - _fused_pair_adapter_weight( - base_prefix, - linear_fc1.gate_lora, - linear_fc1.up_lora, - first_expert_idx=local_expert_idx, - second_expert_idx=local_expert_idx, - ) - ] + base_prefix = f"{layer_prefix}.mlp.experts.linear_fc1" + if isinstance(linear_fc1, MLPExpertsLinearFC1LoRA): + if linear_fc1.fused_gate_up: + lora = linear_fc1.lora + build_weight = lambda local_expert_idx: _simple_adapter_weight( + base_prefix, + linear_fc1.lora, + expert_idx=local_expert_idx, + ) + else: + lora = linear_fc1.gate_lora + build_weight = lambda local_expert_idx: _fused_pair_adapter_weight( + base_prefix, + linear_fc1.gate_lora, + linear_fc1.up_lora, + first_expert_idx=local_expert_idx, + second_expert_idx=local_expert_idx, + ) + _set_expert_adapter_weights( + adapter_weights_by_base, + base_prefix, + lora, + build_weight, + ) linear_fc2 = getattr(experts, "linear_fc2", None) if isinstance(linear_fc2, MLPExpertsLinearFC2LoRA): base_prefix = f"{layer_prefix}.mlp.experts.linear_fc2" - for local_expert_idx in range(linear_fc2.lora.num_local_experts): - global_expert_idx = local_expert_idx + linear_fc2.lora._expert_offset - adapter_weights_by_base[f"{base_prefix}.weight{global_expert_idx}"] = [ - _simple_adapter_weight( - base_prefix, - linear_fc2.lora, - expert_idx=local_expert_idx, - ) - ] + _set_expert_adapter_weights( + adapter_weights_by_base, + base_prefix, + linear_fc2.lora, + lambda local_expert_idx: _simple_adapter_weight( + base_prefix, + linear_fc2.lora, + expert_idx=local_expert_idx, + ), + ) -def add_dense_mlp_adapter_weights( +def add_split_mlp_adapter_weights( adapter_weights_by_base: dict[str, list[Any]], - *, - layer_prefix: str, + base_prefix: str, mlp: Any, ) -> None: linear_fc1 = getattr(mlp, "linear_fc1", None) if isinstance(linear_fc1, SharedExpertsLinearFC1LoRA): - base_prefix = f"{layer_prefix}.mlp.linear_fc1" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight( - base_prefix, - linear_fc1.gate_lora, - adapter_key="adapter_gate", - ), - _simple_adapter_weight( - base_prefix, - linear_fc1.up_lora, - adapter_key="adapter_up", - ), - ] + fc1_prefix = f"{base_prefix}.linear_fc1" + _set_lora_weights( + adapter_weights_by_base, + fc1_prefix, + (linear_fc1.gate_lora, "adapter_gate"), + (linear_fc1.up_lora, "adapter_up"), + ) linear_fc2 = getattr(mlp, "linear_fc2", None) - if isinstance(linear_fc2, SharedExpertsLinearFC2LoRA): - base_prefix = f"{layer_prefix}.mlp.linear_fc2" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight(base_prefix, linear_fc2.row_parallel_lora.lora) - ] - - -def add_shared_experts_adapter_weights( - adapter_weights_by_base: dict[str, list[Any]], - *, - layer_prefix: str, - shared_experts: Any, -) -> None: - linear_fc1 = getattr(shared_experts, "linear_fc1", None) - if isinstance(linear_fc1, SharedExpertsLinearFC1LoRA): - base_prefix = f"{layer_prefix}.mlp.shared_experts.linear_fc1" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight( - base_prefix, - linear_fc1.gate_lora, - adapter_key="adapter_gate", - ), - _simple_adapter_weight( - base_prefix, - linear_fc1.up_lora, - adapter_key="adapter_up", - ), - ] - - linear_fc2 = getattr(shared_experts, "linear_fc2", None) - if isinstance(linear_fc2, SharedExpertsLinearFC2LoRA): - base_prefix = f"{layer_prefix}.mlp.shared_experts.linear_fc2" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight(base_prefix, linear_fc2.row_parallel_lora.lora) - ] + if isinstance(linear_fc2, SelfAttentionLinearProjLoRA): + fc2_prefix = f"{base_prefix}.linear_fc2" + _set_adapter_weights( + adapter_weights_by_base, + fc2_prefix, + _simple_adapter_weight(fc2_prefix, linear_fc2.lora), + ) diff --git a/src/art/megatron/weights/lora_publish.py b/src/art/megatron/weights/lora_publish.py index f4fd02a0a..930b45c58 100644 --- a/src/art/megatron/weights/lora_publish.py +++ b/src/art/megatron/weights/lora_publish.py @@ -1,16 +1,22 @@ from collections.abc import Iterable, Sequence -import re from typing import Any, NamedTuple import torch -from art.megatron.lora import LoRAPublishPlanner, LoraShardMeta +from art.megatron.lora import ( + LoRA, + LoRAPublishPlanner, + LoraShardMeta, + _block_for_key, + _dtype_name, +) +from art.megatron.lora import ( + _distributed_initialized as _distributed_ready, +) from art.megatron.model_support.lora_disk import save_vllm_lora_tensors from art.megatron.model_support.spec import ExpertPackedLoraGroup, ExpertPackedLoraSlot from art.megatron.training.model_chunks import ModelChunks -_LAYER_BLOCK_RE = re.compile(r"^(?P.*\.layers\.\d+)\.") - class PackedExpertShardMeta(NamedTuple): key: str @@ -58,35 +64,17 @@ def finish(self) -> None: self._events.clear() -def iter_lora_modules(model_chunks: ModelChunks) -> Iterable[Any]: +def iter_lora_modules(model_chunks: ModelChunks) -> Iterable[LoRA]: for chunk in model_chunks: for module in chunk.modules(): - yield module - - -def _dtype_name(dtype: torch.dtype) -> str: - return str(dtype).removeprefix("torch.") + if isinstance(module, LoRA): + yield module def _dtype_from_name(name: str) -> torch.dtype: - dtype = getattr(torch, name, None) - if not isinstance(dtype, torch.dtype): - raise RuntimeError(f"Unsupported LoRA tensor dtype={name!r}") - return dtype - - -def _block_for_key(key: str) -> str: - match = _LAYER_BLOCK_RE.match(key) - if match is not None: - return match.group("block") - return "__global__" - - -def _expert_prefix_projection(adapter_model_prefix: str) -> tuple[str, str] | None: - group_prefix, separator, projection = adapter_model_prefix.partition(".{expert}.") - if not separator: - return None - return group_prefix, projection + if isinstance(dtype := getattr(torch, name, None), torch.dtype): + return dtype + raise RuntimeError(f"Unsupported LoRA tensor dtype={name!r}") def _packed_expert_slot( @@ -94,10 +82,9 @@ def _packed_expert_slot( suffix: str, groups: Sequence[ExpertPackedLoraGroup], ) -> tuple[str, ExpertPackedLoraSlot] | None: - parts = _expert_prefix_projection(adapter_model_prefix) - if parts is None: + group_prefix, separator, projection = adapter_model_prefix.partition(".{expert}.") + if not separator: return None - group_prefix, projection = parts lora_name = suffix.removesuffix(".weight") for group in groups: if not group_prefix.endswith(group.art_group_suffix): @@ -109,23 +96,15 @@ def _packed_expert_slot( def _uses_packed_expert_publish( - module: Any, + module: LoRA, groups: Sequence[ExpertPackedLoraGroup], ) -> bool: - if int(getattr(module, "num_local_experts", 1)) <= 1: - return False - if not hasattr(module, "_lora_params"): + if module.num_local_experts <= 1: return False - adapter_model_prefix = getattr(module, "adapter_model_prefix", "") - if not isinstance(adapter_model_prefix, str): - return False - lora_suffixes = [ - suffix - for suffix, _param in module._lora_params() # type: ignore[attr-defined] - ] - return bool(lora_suffixes) and all( - _packed_expert_slot(adapter_model_prefix, suffix, groups) is not None - for suffix in lora_suffixes + params = tuple(module._lora_params()) + return bool(params) and all( + _packed_expert_slot(module.adapter_model_prefix, suffix, groups) is not None + for suffix, _param in params ) @@ -141,15 +120,12 @@ def collect_local_lora_entries( for module in iter_lora_modules(model_chunks): if _uses_packed_expert_publish(module, packed_expert_groups): continue - if hasattr(module, "sharded_lora_state_dict"): - module_state: dict[str, torch.Tensor] = module.sharded_lora_state_dict() # type: ignore[attr-defined] - for key, value in module_state.items(): - target_dtype = ( - adapter_model[key].dtype if key in adapter_model else value.dtype - ) - local_tensors[key] = value.to(target_dtype).contiguous() - if hasattr(module, "sharded_lora_manifest"): - local_manifest.update(module.sharded_lora_manifest()) # type: ignore[attr-defined] + for key, value in module.sharded_lora_state_dict().items(): + target_dtype = ( + adapter_model[key].dtype if key in adapter_model else value.dtype + ) + local_tensors[key] = value.to(target_dtype).contiguous() + local_manifest.update(module.sharded_lora_manifest()) if set(local_tensors) != set(local_manifest): raise RuntimeError( @@ -171,18 +147,6 @@ def collect_local_lora_entries( return local_tensors, metadata -def _target_dtype_for_lora_param( - module: Any, - adapter_model: dict[str, torch.Tensor], - suffix: str, - fallback: torch.dtype, -) -> torch.dtype: - keys = module._expected_weight_keys(suffix.removesuffix(".weight")) # type: ignore[attr-defined] - return ( - adapter_model[keys[0]].dtype if keys and keys[0] in adapter_model else fallback - ) - - def collect_local_packed_expert_entries( model_chunks: ModelChunks, adapter_model: dict[str, torch.Tensor], @@ -195,25 +159,24 @@ def collect_local_packed_expert_entries( for module in iter_lora_modules(model_chunks): if not _uses_packed_expert_publish(module, packed_expert_groups): continue - adapter_model_prefix = module.adapter_model_prefix # type: ignore[attr-defined] - expert_start = int(module._expert_offset) # type: ignore[attr-defined] - expert_count = int(module.num_local_experts) # type: ignore[attr-defined] - for suffix, param in module._lora_params(): # type: ignore[attr-defined] + expert_start = int(module._expert_offset) + expert_count = int(module.num_local_experts) + for suffix, param in module._lora_params(): slot_match = _packed_expert_slot( - adapter_model_prefix, + module.adapter_model_prefix, suffix, packed_expert_groups, ) - if slot_match is None or not module._should_export_parameter(param): # type: ignore[attr-defined] + if slot_match is None or not module._should_export_parameter(param): continue group_prefix, slot = slot_match key = f"{group_prefix}.{slot.output_suffix}" tensor = param.data.transpose(1, 2).contiguous() - target_dtype = _target_dtype_for_lora_param( - module, - adapter_model, - suffix, - tensor.dtype, + source_keys = module._expected_weight_keys(suffix.removesuffix(".weight")) + target_dtype = ( + adapter_model[source_keys[0]].dtype + if source_keys and source_keys[0] in adapter_model + else tensor.dtype ) tensor = tensor.to(target_dtype).contiguous() if key in local_tensors: @@ -225,7 +188,7 @@ def collect_local_packed_expert_entries( owner_rank=owner_rank, shape=tuple(int(dim) for dim in tensor.shape), dtype_name=_dtype_name(tensor.dtype), - manifest=module._manifest_for_param(param), # type: ignore[attr-defined] + manifest=module._manifest_for_param(param), expert_start=expert_start, expert_count=expert_count, pack_layout=slot.pack_layout, @@ -252,7 +215,12 @@ def _global_packed_expert_metadata( continue group_prefix, slot = slot_match shard_ranks = range(template.shard_world_size) if template.sharded else (0,) - for ep_rank in range(planner._expert_model_world_size()): + ep_world_size = 1 + if _distributed_ready(): + from megatron.core import parallel_state as ps + + ep_world_size = ps.get_expert_model_parallel_world_size() + for ep_rank in range(ep_world_size): expert_start = ep_rank * template.num_local_experts expert_key = ( f"{template.adapter_model_prefix.format(expert=expert_start)}." @@ -350,70 +318,66 @@ def _merge_sharded_tensor( return torch.cat(tuple(ordered_shards), dim=axis).contiguous() -def merge_sharded_adapter_entries( - entries_by_key: dict[str, list[tuple[dict[str, Any], torch.Tensor]]], -) -> dict[str, torch.Tensor]: - adapter_model: dict[str, torch.Tensor] = {} - for key, key_entries in entries_by_key.items(): - first_manifest = key_entries[0][0] - sharded = bool(first_manifest["sharded"]) - shard_world_size = int(first_manifest["shard_world_size"]) - for manifest_entry, _tensor in key_entries: - if bool(manifest_entry["sharded"]) != sharded: - raise RuntimeError(f"Inconsistent sharded flag for key={key}") - if int(manifest_entry["shard_world_size"]) != shard_world_size: - raise RuntimeError(f"Inconsistent shard world size for key={key}") - - if not sharded: - if len(key_entries) != 1: - raise RuntimeError( - f"Replicated key={key} expected 1 shard, got {len(key_entries)}" - ) - adapter_model[key] = key_entries[0][1] - continue - - shard_rank_to_tensor: dict[int, torch.Tensor] = {} - for manifest_entry, shard_tensor in key_entries: - shard_rank = int(manifest_entry["shard_rank"]) - if shard_rank in shard_rank_to_tensor: - raise RuntimeError(f"Duplicate shard_rank={shard_rank} for key={key}") - shard_rank_to_tensor[shard_rank] = shard_tensor +def _merge_manifest_entries( + key: str, + key_entries: Sequence[tuple[dict[str, Any], torch.Tensor]], + *, + manifest: dict[str, Any] | None = None, +) -> torch.Tensor: + first_manifest = key_entries[0][0] + sharded = bool(first_manifest["sharded"]) + shard_world_size = int(first_manifest["shard_world_size"]) + for entry_manifest, _tensor in key_entries: + if bool(entry_manifest["sharded"]) != sharded: + raise RuntimeError(f"Inconsistent sharded flag for key={key}") + if int(entry_manifest["shard_world_size"]) != shard_world_size: + raise RuntimeError(f"Inconsistent shard world size for key={key}") - expected_shard_ranks = set(range(shard_world_size)) - if set(shard_rank_to_tensor) != expected_shard_ranks: + if not sharded: + if len(key_entries) != 1: raise RuntimeError( - f"Shard rank coverage mismatch for key={key}: " - f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor)}" + f"Replicated key={key} expected 1 shard, got {len(key_entries)}" ) + return key_entries[0][1] - ordered_shards = [ - shard_rank_to_tensor[shard_rank] for shard_rank in range(shard_world_size) - ] - adapter_model[key] = _merge_sharded_tensor( - key, - ordered_shards=ordered_shards, - manifest=first_manifest, + shard_rank_to_tensor: dict[int, torch.Tensor] = {} + for entry_manifest, shard_tensor in key_entries: + shard_rank = int(entry_manifest["shard_rank"]) + if shard_rank in shard_rank_to_tensor: + raise RuntimeError(f"Duplicate shard_rank={shard_rank} for key={key}") + shard_rank_to_tensor[shard_rank] = shard_tensor + + expected_shard_ranks = set(range(shard_world_size)) + if set(shard_rank_to_tensor) != expected_shard_ranks: + raise RuntimeError( + f"Shard rank coverage mismatch for key={key}: " + f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor)}" ) - return adapter_model + return _merge_sharded_tensor( + key, + ordered_shards=[ + shard_rank_to_tensor[shard_rank] for shard_rank in range(shard_world_size) + ], + manifest=first_manifest if manifest is None else manifest, + ) -def _distributed_ready() -> bool: - is_initialized = getattr(torch.distributed, "is_initialized", None) - return ( - torch.distributed.is_available() - and callable(is_initialized) - and bool(is_initialized()) - ) +def merge_sharded_adapter_entries( + entries_by_key: dict[str, list[tuple[dict[str, Any], torch.Tensor]]], +) -> dict[str, torch.Tensor]: + return { + key: _merge_manifest_entries(key, key_entries) + for key, key_entries in entries_by_key.items() + } def _rank_and_device() -> tuple[int, torch.device]: - if _distributed_ready(): - rank = torch.distributed.get_rank() # type: ignore[possibly-missing-attribute] - else: - rank = 0 - if torch.cuda.is_available(): - return rank, torch.device("cuda", torch.cuda.current_device()) - return rank, torch.device("cpu") + return ( + torch.distributed.get_rank() if _distributed_ready() else 0, # type: ignore[possibly-missing-attribute] + torch.device("cuda", torch.cuda.current_device()) + if torch.cuda.is_available() + else torch.device("cpu"), + ) def _metadata_by_owner_dtype( @@ -514,45 +478,10 @@ def _merge_packed_expert_block( key: str, key_entries: list[tuple[dict[str, Any], torch.Tensor]], ) -> torch.Tensor: - first_manifest = key_entries[0][0] - sharded = bool(first_manifest["sharded"]) - shard_world_size = int(first_manifest["shard_world_size"]) - if not sharded: - if len(key_entries) != 1: - raise RuntimeError( - f"Replicated packed key={key} expected 1 shard, got {len(key_entries)}" - ) - return key_entries[0][1] - - shard_rank_to_tensor: dict[int, torch.Tensor] = {} - for manifest_entry, shard_tensor in key_entries: - if bool(manifest_entry["sharded"]) != sharded: - raise RuntimeError(f"Inconsistent sharded flag for packed key={key}") - if int(manifest_entry["shard_world_size"]) != shard_world_size: - raise RuntimeError(f"Inconsistent shard world size for packed key={key}") - shard_rank = int(manifest_entry["shard_rank"]) - if shard_rank in shard_rank_to_tensor: - raise RuntimeError( - f"Duplicate shard_rank={shard_rank} for packed key={key}" - ) - shard_rank_to_tensor[shard_rank] = shard_tensor - - expected_shard_ranks = set(range(shard_world_size)) - if set(shard_rank_to_tensor) != expected_shard_ranks: - raise RuntimeError( - f"Shard rank coverage mismatch for packed key={key}: " - f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor)}" - ) - - manifest = dict(first_manifest) - manifest["export_shard_dim"] = int(manifest["export_shard_dim"]) + 1 - return _merge_sharded_tensor( - key, - ordered_shards=[ - shard_rank_to_tensor[shard_rank] for shard_rank in range(shard_world_size) - ], - manifest=manifest, - ) + manifest = dict(key_entries[0][0]) + if bool(manifest["sharded"]): + manifest["export_shard_dim"] = int(manifest["export_shard_dim"]) + 1 + return _merge_manifest_entries(key, key_entries, manifest=manifest) def _pack_merged_expert_blocks( diff --git a/tests/integration/megatron/cp_attn/test_attention_packed_vs_flattened.py b/tests/integration/megatron/cp_attn/test_attention_packed_vs_flattened.py index 3d3d51d4c..5b6e39390 100644 --- a/tests/integration/megatron/cp_attn/test_attention_packed_vs_flattened.py +++ b/tests/integration/megatron/cp_attn/test_attention_packed_vs_flattened.py @@ -60,9 +60,7 @@ def test_shared_prefix_attention_matches_flattened_grad_accumulation() -> None: tensors = build_phase0_packed_tensors(case) group_ids = tensors["group_ids"].cuda() parent_ids = tensors["parent_ids"].cuda() - spec = parse_gdn_shared_prefix_segments( - group_ids.cpu(), parent_ids.cpu(), min_completions_per_family=1 - ) + spec = parse_gdn_shared_prefix_segments(group_ids.cpu(), parent_ids.cpu()) q, k, v = _attention_inputs(group_ids.shape, seed=20260425) q_ref = q.detach().clone().requires_grad_(True) k_ref = k.detach().clone().requires_grad_(True) @@ -82,40 +80,19 @@ def test_shared_prefix_attention_matches_flattened_grad_accumulation() -> None: ref_out = torch.zeros_like(packed_out) ref_loss = q_ref.new_zeros(()) - for family in spec.families: - prefix = family.prefix - prefix_grad_used = False - for completion in family.completions: - indices = torch.tensor( - [ - *range(prefix.start, prefix.end), - *range(completion.start, completion.end), - ], - device=q.device, - dtype=torch.long, - ) - row = family.row_index - q_slice = q_ref[row : row + 1].index_select(2, indices) - k_slice = k_ref[row : row + 1].index_select(2, indices) - v_slice = v_ref[row : row + 1].index_select(2, indices) - flat_out = _dense_causal_attention(q_slice, k_slice, v_slice) - - ref_out[row, :, completion.start : completion.end] = flat_out[ - 0, :, prefix.length : - ] - flat_grad = torch.zeros_like(flat_out) - flat_grad[0, :, prefix.length :] = output_grad[ - row, :, completion.start : completion.end - ] - if not prefix_grad_used: - ref_out[row, :, prefix.start : prefix.end] = flat_out[ - 0, :, : prefix.length - ] - flat_grad[0, :, : prefix.length] = output_grad[ - row, :, prefix.start : prefix.end - ] - prefix_grad_used = True - ref_loss = ref_loss + (flat_out * flat_grad).sum() + for segment_index, segment in enumerate(spec.tree_segments): + indices, output_slice = _segment_context_positions(spec, segment_index) + index_tensor = torch.tensor(indices, device=q.device, dtype=torch.long) + row = segment.row_index + q_slice = q_ref[row : row + 1].index_select(2, index_tensor) + k_slice = k_ref[row : row + 1].index_select(2, index_tensor) + v_slice = v_ref[row : row + 1].index_select(2, index_tensor) + flat_out = _dense_causal_attention(q_slice, k_slice, v_slice) + + ref_out[row, :, segment.start : segment.end] = flat_out[0, :, output_slice] + flat_grad = torch.zeros_like(flat_out) + flat_grad[0, :, output_slice] = output_grad[row, :, segment.start : segment.end] + ref_loss = ref_loss + (flat_out * flat_grad).sum() ref_loss.backward() real_mask = _real_token_mask(spec, q.shape, device=q.device) @@ -142,9 +119,7 @@ def test_physical_causal_attention_leaks_across_siblings() -> None: tensors = build_phase0_packed_tensors(case) group_ids = tensors["group_ids"].cuda() parent_ids = tensors["parent_ids"].cuda() - spec = parse_gdn_shared_prefix_segments( - group_ids.cpu(), parent_ids.cpu(), min_completions_per_family=1 - ) + spec = parse_gdn_shared_prefix_segments(group_ids.cpu(), parent_ids.cpu()) q, k, v = _attention_inputs(group_ids.shape, seed=20260427) attention_state = create_shared_prefix_state(group_ids, parent_ids) packed_out = FlexAttentionWrapper()( @@ -225,11 +200,27 @@ def _completion_token_mask( spec: Any, shape: torch.Size, *, device: torch.device ) -> torch.Tensor: mask = torch.zeros(shape, device=device, dtype=torch.bool) - for family in spec.families: - for completion in family.completions: - mask[ - family.row_index, - :, - completion.start : completion.end, - ] = True + for index, segment in enumerate(spec.tree_segments): + if spec.tree_parent_indices[index] >= 0: + mask[segment.row_index, :, segment.start : segment.end] = True return mask + + +def _segment_context_positions( + spec: Any, segment_index: int +) -> tuple[list[int], slice]: + path = [] + cursor = segment_index + while cursor >= 0: + path.append(cursor) + cursor = spec.tree_parent_indices[cursor] + path.reverse() + positions = [ + position + for index in path + for position in range( + spec.tree_segments[index].start, spec.tree_segments[index].end + ) + ] + segment_length = spec.tree_segments[segment_index].length + return positions, slice(len(positions) - segment_length, len(positions)) diff --git a/tests/integration/megatron/gdn_shared_prefix/distributed_init.py b/tests/integration/megatron/gdn_shared_prefix/distributed_init.py new file mode 100644 index 000000000..b9b4075c8 --- /dev/null +++ b/tests/integration/megatron/gdn_shared_prefix/distributed_init.py @@ -0,0 +1,7 @@ +from pathlib import Path + + +def file_init_method(tmp_path: Path, name: str) -> str: + path = tmp_path / f"{name}.dist" + path.unlink(missing_ok=True) + return f"file://{path}" diff --git a/tests/integration/megatron/gdn_shared_prefix/layout_reference.py b/tests/integration/megatron/gdn_shared_prefix/layout_reference.py index 7369eaef7..8cff82405 100644 --- a/tests/integration/megatron/gdn_shared_prefix/layout_reference.py +++ b/tests/integration/megatron/gdn_shared_prefix/layout_reference.py @@ -19,7 +19,7 @@ class TestGdnCpLayoutPlan(BaseModel): - model_config = ConfigDict(frozen=True) + model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) batch_size: int = Field(ge=1) sequence_length: int = Field(ge=1) @@ -39,9 +39,7 @@ def build_test_gdn_cp_layout_plan( gdn_token_ranges_by_rank: Sequence[Sequence[tuple[int, int, int]]] | None = None, device: torch.device | str | None = None, ) -> TestGdnCpLayoutPlan: - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) gdn_ranges = ( _normalize_rank_ranges(gdn_token_ranges_by_rank, cp_size=cp_size) if gdn_token_ranges_by_rank is not None @@ -89,7 +87,7 @@ def _build_full_exchange_plan( ) for transfer in local_plan.transfers: transfers.setdefault((transfer.source_rank, transfer.dest_rank), transfer) - return GdnCpExchangePlan.model_construct( + return GdnCpExchangePlan( cp_size=len(source_layout.token_counts_by_rank), source_token_counts_by_rank=source_layout.token_counts_by_rank, dest_token_counts_by_rank=tuple( @@ -133,7 +131,7 @@ def _split_gdn_token_ranges_by_rank( _segment_token_start(segment, spec.sequence_length), _segment_token_start(segment, spec.sequence_length) + segment.length, ) - for segment in spec.segments() + for segment in spec.tree_segments ), cp_size=cp_size, ) diff --git a/tests/integration/megatron/gdn_shared_prefix/oracles.py b/tests/integration/megatron/gdn_shared_prefix/oracles.py index 3d3f9ae12..019ec74e7 100644 --- a/tests/integration/megatron/gdn_shared_prefix/oracles.py +++ b/tests/integration/megatron/gdn_shared_prefix/oracles.py @@ -7,6 +7,8 @@ from torch import Tensor import torch.nn.functional as F +from art.megatron.gdn.gdn_shared_prefix import GdnPackedExecutionSpec, GdnSegmentSpec + from .metrics import ( mean_abs_pct, parameter_grad_mean_abs_pct_with_name, @@ -107,27 +109,27 @@ def run_toy_packed( group_ids: Tensor, parent_ids: Tensor, ) -> Tensor: - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=1 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) output = torch.zeros_like(hidden) - for family in spec.families: - row = family.row_index - prefix_hidden = hidden[row, family.prefix.start : family.prefix.end] - prefix_out, prefix_conv, prefix_rec = module.forward_segment( - prefix_hidden, - conv_initial=module.zero_conv_state(hidden), - recurrent_initial=module.zero_recurrent_state(hidden), + conv_states: list[Tensor] = [] + rec_states: list[Tensor] = [] + for segment_index, segment in enumerate(spec.tree_segments): + row = segment.row_index + parent_index = spec.tree_parent_indices[segment_index] + if parent_index < 0: + conv_initial = module.zero_conv_state(hidden) + rec_initial = module.zero_recurrent_state(hidden) + else: + conv_initial = conv_states[parent_index] + rec_initial = rec_states[parent_index] + segment_out, conv_final, rec_final = module.forward_segment( + hidden[row, segment.start : segment.end], + conv_initial=conv_initial, + recurrent_initial=rec_initial, ) - output[row, family.prefix.start : family.prefix.end] = prefix_out - for completion in family.completions: - suffix_hidden = hidden[row, completion.start : completion.end] - suffix_out, _, _ = module.forward_segment( - suffix_hidden, - conv_initial=prefix_conv, - recurrent_initial=prefix_rec, - ) - output[row, completion.start : completion.end] = suffix_out + output[row, segment.start : segment.end] = segment_out + conv_states.append(conv_final) + rec_states.append(rec_final) return output @@ -138,30 +140,36 @@ def run_toy_flattened_reference( group_ids: Tensor, parent_ids: Tensor, ) -> Tensor: - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=1 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) output = torch.zeros_like(hidden) - for family in spec.families: - row = family.row_index - prefix_hidden = hidden[row, family.prefix.start : family.prefix.end] - prefix_len = family.prefix.length - for child_index, completion in enumerate(family.completions): - suffix_hidden = hidden[row, completion.start : completion.end] - flattened = torch.cat([prefix_hidden, suffix_hidden], dim=0) - flat_out, _, _ = module.forward_segment( - flattened, - conv_initial=module.zero_conv_state(hidden), - recurrent_initial=module.zero_recurrent_state(hidden), - ) - if child_index == 0: - output[row, family.prefix.start : family.prefix.end] = flat_out[ - :prefix_len - ] - output[row, completion.start : completion.end] = flat_out[prefix_len:] + for segment_index, segment in enumerate(spec.tree_segments): + path = _segment_path(spec, segment_index) + flattened = torch.cat( + [hidden[node.row_index, node.start : node.end] for node in path], + dim=0, + ) + flat_out, _, _ = module.forward_segment( + flattened, + conv_initial=module.zero_conv_state(hidden), + recurrent_initial=module.zero_recurrent_state(hidden), + ) + segment_len = segment.length + output[segment.row_index, segment.start : segment.end] = flat_out[-segment_len:] return output +def _segment_path( + spec: GdnPackedExecutionSpec, + segment_index: int, +) -> tuple[GdnSegmentSpec, ...]: + indices = [] + cursor = segment_index + while cursor >= 0: + indices.append(cursor) + cursor = spec.tree_parent_indices[cursor] + return tuple(spec.tree_segments[index] for index in reversed(indices)) + + def run_toy_physical_stream( module: ToyStatefulGdn, hidden: Tensor, diff --git a/tests/integration/megatron/gdn_shared_prefix/packed_layout.py b/tests/integration/megatron/gdn_shared_prefix/packed_layout.py index 45a41ff58..fa1b00d05 100644 --- a/tests/integration/megatron/gdn_shared_prefix/packed_layout.py +++ b/tests/integration/megatron/gdn_shared_prefix/packed_layout.py @@ -137,19 +137,23 @@ def summarize_case( conv_width: int, cp_sizes: tuple[int, ...] = (2, 4, 8), ) -> GdnCaseSummary: - spec = parse_gdn_shared_prefix_segments( - tensors["group_ids"], tensors["parent_ids"], min_completions_per_family=1 - ) + spec = parse_gdn_shared_prefix_segments(tensors["group_ids"], tensors["parent_ids"]) suffix_lengths = [ - segment.length for family in spec.families for segment in family.completions + segment.length + for index, segment in enumerate(spec.tree_segments) + if spec.tree_parent_indices[index] >= 0 ] boundary = _boundary_flags(spec, cp_sizes) return GdnCaseSummary( name=case.name, total_tokens=spec.real_token_count, family_count=spec.family_count, - completion_count=spec.completion_count, - max_segment_length=spec.max_segment_length, + completion_count=sum( + 1 for parent_index in spec.tree_parent_indices if parent_index >= 0 + ), + max_segment_length=max( + (segment.length for segment in spec.tree_segments), default=0 + ), suffix_shorter_than_conv=any(length < conv_width for length in suffix_lengths), suffix_equal_to_conv=any(length == conv_width for length in suffix_lengths), suffix_longer_than_conv=any(length > conv_width for length in suffix_lengths), @@ -227,19 +231,49 @@ def _boundary_flags( boundaries = {shard * rank for rank in range(1, cp_size)} if shard * (cp_size - 1) >= spec.real_token_count: flags["empty_trailing_rank"] = True - for family in spec.families: - family_start = _segment_real_start(family.prefix, spec, real_index) - family_end = _segment_real_end(family.completions[-1], spec, real_index) + for root in _root_segments(spec): + descendants = _descendant_segments(spec, root.family_index) + family_segments = (root, *descendants) + family_start = min( + _segment_real_start(segment, spec, real_index) + for segment in family_segments + ) + family_end = max( + _segment_real_end(segment, spec, real_index) + for segment in family_segments + ) if family_start in boundaries or family_end in boundaries: flags["family_boundary_at_partition"] = True - if _crosses_boundary(family.prefix, spec, real_index, boundaries): + if _crosses_boundary(root, spec, real_index, boundaries): flags["cp_boundary_prefix"] = True - for completion in family.completions: + for completion in descendants: if _crosses_boundary(completion, spec, real_index, boundaries): flags["cp_boundary_suffix"] = True return flags +def _root_segments(spec: GdnPackedExecutionSpec) -> tuple[Any, ...]: + return tuple( + segment + for index, segment in enumerate(spec.tree_segments) + if spec.tree_parent_indices[index] < 0 + ) + + +def _descendant_segments( + spec: GdnPackedExecutionSpec, root_index: int +) -> tuple[Any, ...]: + descendants = [] + for index, segment in enumerate(spec.tree_segments): + parent = spec.tree_parent_indices[index] + while parent >= 0: + if parent == root_index: + descendants.append(segment) + break + parent = spec.tree_parent_indices[parent] + return tuple(descendants) + + def _segment_real_start( segment: Any, spec: GdnPackedExecutionSpec, real_index: dict[int, int] ) -> int: diff --git a/tests/integration/megatron/gdn_shared_prefix/parser_import.py b/tests/integration/megatron/gdn_shared_prefix/parser_import.py index ce184d96e..3a473ebf3 100644 --- a/tests/integration/megatron/gdn_shared_prefix/parser_import.py +++ b/tests/integration/megatron/gdn_shared_prefix/parser_import.py @@ -24,6 +24,5 @@ def _load_parser_module() -> ModuleType: _MODULE = _load_parser_module() GdnPackedExecutionSpec: Any = _MODULE.GdnPackedExecutionSpec -build_gdn_cp_segment_schedule: Any = _MODULE.build_gdn_cp_segment_schedule build_gdn_rank_execution_plan: Any = _MODULE.build_gdn_rank_execution_plan parse_gdn_shared_prefix_segments: Any = _MODULE.parse_gdn_shared_prefix_segments diff --git a/tests/integration/megatron/gdn_shared_prefix/real_gdn_oracle.py b/tests/integration/megatron/gdn_shared_prefix/real_gdn_oracle.py index e69fef22b..38fb01889 100644 --- a/tests/integration/megatron/gdn_shared_prefix/real_gdn_oracle.py +++ b/tests/integration/megatron/gdn_shared_prefix/real_gdn_oracle.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Literal +from typing import Any, Literal, NamedTuple from pydantic import BaseModel, ConfigDict import torch @@ -61,6 +61,57 @@ class GdnChainBoundaryDebug(BaseModel): ] +class _TreeFamily(NamedTuple): + row_index: int + family_index: int + prefix: Any + completions: tuple[Any, ...] + segment_indices: tuple[int, ...] + parent_indices: tuple[int, ...] + + @property + def token_count(self) -> int: + return self.prefix.length + sum(segment.length for segment in self.completions) + + +def _segment_path(spec: Any, segment_index: int) -> tuple[Any, ...]: + path = [] + cursor = segment_index + while cursor >= 0: + path.append(cursor) + cursor = spec.tree_parent_indices[cursor] + return tuple(spec.tree_segments[index] for index in reversed(path)) + + +def _tree_families(spec: Any) -> tuple[_TreeFamily, ...]: + families = [] + for root_index, root in enumerate(spec.tree_segments): + if spec.tree_parent_indices[root_index] >= 0: + continue + segment_indices = [root_index] + for index in range(root_index + 1, len(spec.tree_segments)): + parent = spec.tree_parent_indices[index] + while parent >= 0: + if parent == root_index: + segment_indices.append(index) + break + parent = spec.tree_parent_indices[parent] + segments = tuple(spec.tree_segments[index] for index in segment_indices) + families.append( + _TreeFamily( + row_index=root.row_index, + family_index=root_index, + prefix=root, + completions=segments[1:], + segment_indices=tuple(segment_indices), + parent_indices=tuple( + spec.tree_parent_indices[index] for index in segment_indices + ), + ) + ) + return tuple(families) + + def compare_real_gdn_cp1_to_flattened( *, packed_gdn: Any, @@ -296,35 +347,34 @@ def run_real_gdn_flattened_reference( parent_ids: Tensor, execution_spec: Any | None = None, ) -> Tensor: - spec = execution_spec or parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=1 - ) + spec = execution_spec or parse_gdn_shared_prefix_segments(group_ids, parent_ids) output = torch.zeros_like(hidden_states) - for family in spec.families: - row = family.row_index - prefix_hidden = hidden_states[ - family.prefix.start : family.prefix.end, row : row + 1, : - ] - prefix_len = family.prefix.length - for child_index, completion in enumerate(family.completions): - suffix_hidden = hidden_states[ - completion.start : completion.end, row : row + 1, : - ] - flat_hidden = torch.cat([prefix_hidden, suffix_hidden], dim=0) - flat_out, _, _, _ = _run_gdn_segment( - gdn, - flat_hidden, - conv_initial=_zero_conv_state(gdn, hidden_states, row), - recurrent_initial=_zero_recurrent_state(gdn, hidden_states, row), - output_final_state=False, - ) - if child_index == 0: - output[family.prefix.start : family.prefix.end, row : row + 1, :] = ( - flat_out[:prefix_len] - ) - output[completion.start : completion.end, row : row + 1, :] = flat_out[ - prefix_len: - ] + for segment_index, segment in enumerate(spec.tree_segments): + flat_hidden = torch.cat( + [ + hidden_states[ + node.start : node.end, + node.row_index : node.row_index + 1, + :, + ] + for node in _segment_path(spec, segment_index) + ], + dim=0, + ) + flat_out, _, _, _ = _run_gdn_segment( + gdn, + flat_hidden, + conv_initial=_zero_conv_state(gdn, hidden_states, segment.row_index), + recurrent_initial=_zero_recurrent_state( + gdn, hidden_states, segment.row_index + ), + output_final_state=False, + ) + output[ + segment.start : segment.end, + segment.row_index : segment.row_index + 1, + :, + ] = flat_out[-segment.length :] return output @@ -359,9 +409,7 @@ def run_real_gdn_local_fork_reference( cp_size: int, attention_token_layout_index: TokenLayoutIndex | None = None, ) -> Tensor: - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) gdn_token_indices_by_rank = _split_gdn_families_by_rank(spec, cp_size=cp_size) gdn_token_ranges_by_rank = _rank_ranges_from_tokens_by_rank( gdn_token_indices_by_rank @@ -414,12 +462,12 @@ def _split_gdn_families_by_rank( raise ValueError(f"cp_size must be >= 1, got {cp_size}") ranks: list[list[int]] = [[] for _ in range(cp_size)] loads = [0] * cp_size - for family in spec.families: + for family in _tree_families(spec): rank = min(range(cp_size), key=lambda index: (loads[index], index)) family_tokens = tuple( token for segment in (family.prefix, *family.completions) - for token in segment.linear_indices(spec.sequence_length) + for token in _segment_linear_indices(segment, spec.sequence_length) ) ranks[rank].extend(family_tokens) loads[rank] += len(family_tokens) @@ -471,6 +519,11 @@ def _simulate_all_to_all_single( return tuple(outputs) +def _segment_linear_indices(segment: Any, sequence_length: int) -> range: + base = int(segment.row_index) * int(sequence_length) + return range(base + int(segment.start), base + int(segment.end)) + + def _transfer_positions(tensor: Tensor | None, *, count: int) -> tuple[int, ...]: if tensor is None: return tuple(range(count)) @@ -523,11 +576,9 @@ def run_real_gdn_suffix_only_chain_reference( mutation: GdnChainMutation | None = None, boundary_debug: list[GdnChainBoundaryDebug] | None = None, ) -> Tensor: - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) output = torch.zeros_like(hidden_states) - for family in spec.families: + for family in _tree_families(spec): row = family.row_index zero_conv = _zero_conv_state(gdn, hidden_states, batch_size=1) zero_rec = _zero_recurrent_state(gdn, hidden_states, batch_size=1) @@ -575,11 +626,9 @@ def run_real_gdn_chunk_native_reference( group_ids: Tensor, parent_ids: Tensor, ) -> Tensor: - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) output = torch.zeros_like(hidden_states) - for family in spec.families: + for family in _tree_families(spec): _scatter_family_output( output, family, @@ -597,13 +646,11 @@ def run_real_gdn_mixed_cp_reference( cp_size: int, local_fork_max_tokens: int, ) -> Tensor: - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) output = torch.zeros_like(hidden_states) local_count = 0 chain_count = 0 - for family in spec.families: + for family in _tree_families(spec): if family.token_count <= local_fork_max_tokens: local_count += 1 _scatter_family_output( @@ -753,14 +800,21 @@ def _family_group_tensors( ) -> tuple[Tensor, Tensor]: group_ids = [] parent_ids = [] - prefix_group_id = 0 - group_ids.extend([prefix_group_id] * family.prefix.length) - parent_ids.extend([prefix_group_id] * family.prefix.length) - next_group_id = 1 - for completion in family.completions: - group_ids.extend([next_group_id] * completion.length) - parent_ids.extend([prefix_group_id] * completion.length) - next_group_id += 1 + local_group_by_global: dict[int, int] = {} + for local_group_id, (segment, global_index, parent_index) in enumerate( + zip( + (family.prefix, *family.completions), + family.segment_indices, + family.parent_indices, + strict=True, + ) + ): + local_group_by_global[global_index] = local_group_id + local_parent_id = ( + local_group_id if parent_index < 0 else local_group_by_global[parent_index] + ) + group_ids.extend([local_group_id] * segment.length) + parent_ids.extend([local_parent_id] * segment.length) return ( torch.tensor([group_ids], device=device, dtype=torch.long), torch.tensor([parent_ids], device=device, dtype=torch.long), @@ -883,12 +937,12 @@ def _local_fork_group_tensors( ) parent_ids = torch.full_like(group_ids, -1) next_group_id = 0 - for family in spec.families: + for family in _tree_families(spec): family_segments = (family.prefix, *family.completions) family_tokens = tuple( token_index for segment in family_segments - for token_index in segment.linear_indices(spec.sequence_length) + for token_index in _segment_linear_indices(segment, spec.sequence_length) ) token_is_local = tuple( token_index in local_position for token_index in family_tokens @@ -898,19 +952,23 @@ def _local_fork_group_tensors( if not all(token_is_local): raise ValueError("local-fork execution requires whole prompt families") - prefix_group_id = next_group_id - next_group_id += 1 - for token_index in family.prefix.linear_indices(spec.sequence_length): - position = local_position[token_index] - group_ids[position] = prefix_group_id - parent_ids[position] = prefix_group_id - for completion in family.completions: - child_group_id = next_group_id + group_by_segment_index: dict[int, int] = {} + for segment, global_index, parent_index in zip( + family_segments, + family.segment_indices, + family.parent_indices, + strict=True, + ): + group_id = next_group_id next_group_id += 1 - for token_index in completion.linear_indices(spec.sequence_length): + group_by_segment_index[global_index] = group_id + parent_group_id = ( + group_id if parent_index < 0 else group_by_segment_index[parent_index] + ) + for token_index in _segment_linear_indices(segment, spec.sequence_length): position = local_position[token_index] - group_ids[position] = child_group_id - parent_ids[position] = prefix_group_id + group_ids[position] = group_id + parent_ids[position] = parent_group_id if torch.any(group_ids == -1): raise RuntimeError("local-fork metadata left unassigned token rows") return group_ids.unsqueeze(0), parent_ids.unsqueeze(0) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_fla_cp_native_recurrent.py b/tests/integration/megatron/gdn_shared_prefix/test_fla_cp_native_recurrent.py index bcf3a0cfb..6f5eefc17 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_fla_cp_native_recurrent.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_fla_cp_native_recurrent.py @@ -1,7 +1,6 @@ from __future__ import annotations from pathlib import Path -import socket from typing import Any, cast import pytest @@ -20,6 +19,7 @@ chunk_gated_delta_rule_native_cp, ) +from .distributed_init import file_init_method # noqa: E402 from .metrics import GDN_CORRECTNESS_DTYPE, assert_mean_abs_pct # noqa: E402 _CP_SIZES = ( @@ -43,10 +43,10 @@ def test_native_fla_cp_recurrent_matches_single_rank( cp_size: int, tmp_path: Path ) -> None: - port = _find_free_port() + init_method = file_init_method(tmp_path, f"native_fla_recurrent_cp{cp_size}") mp.spawn( _native_fla_cp_worker, - args=(cp_size, port, str(tmp_path)), + args=(cp_size, init_method, str(tmp_path)), nprocs=cp_size, join=True, ) @@ -62,10 +62,10 @@ def test_native_fla_cp_recurrent_matches_single_rank( def test_native_fla_cp_recurrent_varlen_multichain_matches_single_rank( cp_size: int, tmp_path: Path ) -> None: - port = _find_free_port() + init_method = file_init_method(tmp_path, f"native_fla_varlen_cp{cp_size}") mp.spawn( _native_fla_cp_varlen_multichain_worker, - args=(cp_size, port, str(tmp_path)), + args=(cp_size, init_method, str(tmp_path)), nprocs=cp_size, join=True, ) @@ -119,13 +119,13 @@ def test_native_fla_summary_affine_debug_matches_final_state() -> None: def _native_fla_cp_worker( rank: int, cp_size: int, - port: int, + init_method: str, output_dir: str, ) -> None: torch.cuda.set_device(rank) init_process_group( backend="nccl", - init_method=f"tcp://127.0.0.1:{port}", + init_method=init_method, rank=rank, world_size=cp_size, ) @@ -201,13 +201,13 @@ def _native_fla_cp_worker( def _native_fla_cp_varlen_multichain_worker( rank: int, cp_size: int, - port: int, + init_method: str, output_dir: str, ) -> None: torch.cuda.set_device(rank) init_process_group( backend="nccl", - init_method=f"tcp://127.0.0.1:{port}", + init_method=init_method, rank=rank, world_size=cp_size, ) @@ -519,9 +519,3 @@ def _cat_varlen_slices( def _assert_grad_close(left: torch.Tensor, right_grad: torch.Tensor, name: str) -> None: assert left.grad is not None, name assert_mean_abs_pct(right_grad, left.grad, name) - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py index 2151b41e1..ac14e8df8 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py @@ -2,7 +2,6 @@ from collections.abc import Callable from pathlib import Path -import socket from typing import Any import pytest @@ -21,6 +20,7 @@ parse_gdn_shared_prefix_segments, ) from art.megatron.gdn.operator import run_gdn_layer # noqa: E402 +from art.megatron.shared_prefix_packing import pack_shared_prefixes # noqa: E402 from .cases import ( # noqa: E402 GdnFamilyShape, @@ -29,6 +29,7 @@ default_phase0_cases, ) from .distributed_grad import all_reduce_parameter_grads_coalesced # noqa: E402 +from .distributed_init import file_init_method # noqa: E402 from .metrics import ( # noqa: E402 GDN_CORRECTNESS_DTYPE, REAL_GDN_GRAD_MEAN_ABS_PCT_THRESHOLD, @@ -50,10 +51,10 @@ def test_gdn_cp_packed_matches_cp1_oracle_all_edge_cases( cp_size: int, tmp_path: Path ) -> None: _skip_without_gpus(cp_size) - port = _find_free_port() + init_method = file_init_method(tmp_path, f"cp1_oracle_cp{cp_size}") mp.spawn( _cp1_oracle_worker, - args=(cp_size, port, str(tmp_path), False), + args=(cp_size, init_method, str(tmp_path), False), nprocs=cp_size, join=True, ) @@ -66,10 +67,10 @@ def test_gdn_cp_packed_sibling_order_matches_cp1_oracle( cp_size: int, tmp_path: Path ) -> None: _skip_without_gpus(cp_size) - port = _find_free_port() + init_method = file_init_method(tmp_path, f"cp1_oracle_sibling_cp{cp_size}") mp.spawn( _cp1_oracle_worker, - args=(cp_size, port, str(tmp_path), True), + args=(cp_size, init_method, str(tmp_path), True), nprocs=cp_size, join=True, ) @@ -77,17 +78,45 @@ def test_gdn_cp_packed_sibling_order_matches_cp1_oracle( assert (tmp_path / f"cp1_oracle_sibling_rank_{rank}.ok").read_text() == "ok\n" +@pytest.mark.parametrize("cp_size", (2, 4)) +def test_gdn_cp_tree_chain_matches_cp1_oracle(cp_size: int, tmp_path: Path) -> None: + _skip_without_gpus(cp_size) + init_method = file_init_method(tmp_path, f"tree_chain_cp{cp_size}") + mp.spawn( + _tree_chain_oracle_worker, + args=(cp_size, init_method, str(tmp_path)), + nprocs=cp_size, + join=True, + ) + for rank in range(cp_size): + assert (tmp_path / f"tree_chain_rank_{rank}.ok").read_text() == "ok\n" + + +def test_gdn_cp_tree_fuzz_matches_cp1_oracle(tmp_path: Path) -> None: + cp_size = 4 + _skip_without_gpus(cp_size) + init_method = file_init_method(tmp_path, "tree_fuzz_cp4") + mp.spawn( + _tree_fuzz_oracle_worker, + args=(cp_size, init_method, str(tmp_path)), + nprocs=cp_size, + join=True, + ) + for rank in range(cp_size): + assert (tmp_path / f"tree_fuzz_rank_{rank}.ok").read_text() == "ok\n" + + def _cp1_oracle_worker( rank: int, cp_size: int, - port: int, + init_method: str, output_dir: str, sibling_only: bool, ) -> None: torch.cuda.set_device(rank) init_process_group( backend="nccl", - init_method=f"tcp://127.0.0.1:{port}", + init_method=init_method, rank=rank, world_size=cp_size, ) @@ -126,6 +155,86 @@ def _cp1_oracle_worker( destroy_process_group() +def _tree_chain_oracle_worker( + rank: int, + cp_size: int, + init_method: str, + output_dir: str, +) -> None: + torch.cuda.set_device(rank) + init_process_group( + backend="nccl", + init_method=init_method, + rank=rank, + world_size=cp_size, + ) + try: + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=cp_size, + expert_model_parallel_size=1, + ) + ref_gdn, cp_gdn = _make_matching_gdn_pair(cp_size=cp_size) + _assert_tree_pack_matches_cp1( + "tree_chain", + ref_gdn, + cp_gdn, + _tree_chain_pack(), + rank=rank, + cp_size=cp_size, + seed=9090, + planner_config=_tree_chain_planner_config(), + require_chain=True, + ) + Path(output_dir, f"tree_chain_rank_{rank}.ok").write_text("ok\n") + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + destroy_process_group() + + +def _tree_fuzz_oracle_worker( + rank: int, + cp_size: int, + init_method: str, + output_dir: str, +) -> None: + torch.cuda.set_device(rank) + init_process_group( + backend="nccl", + init_method=init_method, + rank=rank, + world_size=cp_size, + ) + try: + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=cp_size, + expert_model_parallel_size=1, + ) + ref_gdn, cp_gdn = _make_matching_gdn_pair(cp_size=cp_size) + for case_index, (name, pack) in enumerate(_tree_fuzz_packs()): + _assert_tree_pack_matches_cp1( + name, + ref_gdn, + cp_gdn, + pack, + rank=rank, + cp_size=cp_size, + seed=9190 + case_index, + planner_config=_tree_fuzz_planner_config(), + require_chain=False, + ) + torch.distributed.barrier() + Path(output_dir, f"tree_fuzz_rank_{rank}.ok").write_text("ok\n") + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + destroy_process_group() + + def _assert_case_matches_cp1( ref_gdn: torch.nn.Module, cp_gdn: torch.nn.Module, @@ -141,9 +250,7 @@ def _assert_case_matches_cp1( tensors = build_phase0_packed_tensors(case) group_ids = tensors["group_ids"].cuda() parent_ids = tensors["parent_ids"].cuda() - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) plan = build_gdn_rank_execution_plan( spec, device=group_ids.device, @@ -212,6 +319,81 @@ def _assert_case_matches_cp1( ) +def _assert_tree_pack_matches_cp1( + name: str, + ref_gdn: torch.nn.Module, + cp_gdn: torch.nn.Module, + pack: Any, + *, + rank: int, + cp_size: int, + seed: int, + planner_config: GdnPlannerConfig, + require_chain: bool, +) -> None: + zero_parameter_grads(ref_gdn) + zero_parameter_grads(cp_gdn) + group_ids = pack.group_ids.cuda() + parent_ids = pack.parent_ids.cuda() + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) + plan = build_gdn_rank_execution_plan( + spec, + device=group_ids.device, + cp_rank=rank, + cp_size=cp_size, + planner_config=planner_config, + ) + if require_chain: + assert any(plan.tree_chain_buckets_by_depth) + hidden, output_grad = _tree_hidden_and_grad(spec.real_token_count, seed=seed) + ref_hidden = hidden.clone().detach().requires_grad_(True) + ref_out, _ = run_gdn_layer( + ref_gdn, + ref_hidden, + group_ids=group_ids, + parent_ids=parent_ids, + ) + ref_loss = (ref_out * output_grad).sum() + ref_loss.backward() + + flat_hidden = hidden.transpose(0, 1).reshape(-1, hidden.shape[-1]) + flat_grad = output_grad.transpose(0, 1).reshape(-1, output_grad.shape[-1]) + local_index = torch.tensor( + plan.attention_token_indices, device=hidden.device, dtype=torch.long + ) + local_hidden = ( + flat_hidden.index_select(0, local_index) + .unsqueeze(1) + .contiguous() + .detach() + .requires_grad_(True) + ) + local_output_grad = flat_grad.index_select(0, local_index).unsqueeze(1).contiguous() + cp_out, _ = run_gdn_layer( + cp_gdn, + local_hidden, + group_ids=group_ids, + parent_ids=parent_ids, + execution_spec=spec, + execution_plan=plan, + cp_group=torch.distributed.group.WORLD, + ) + cp_loss = (cp_out * local_output_grad).sum() + cp_loss.backward() + _assert_cp_matches_reference( + name, + ref_gdn, + cp_gdn, + ref_hidden, + ref_out, + ref_loss.detach(), + local_hidden, + cp_out, + cp_loss.detach(), + local_index, + ) + + def _assert_sibling_order_matches_cp1( ref_gdn: torch.nn.Module, cp_gdn: torch.nn.Module, @@ -233,9 +415,7 @@ def _assert_sibling_order_matches_cp1( swapped_parent_ids[0, 5:9] = 0 swapped_group_ids[0, 9:12] = 2 swapped_parent_ids[0, 9:12] = 0 - spec = parse_gdn_shared_prefix_segments( - swapped_group_ids, swapped_parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(swapped_group_ids, swapped_parent_ids) plan = build_gdn_rank_execution_plan( spec, device=group_ids.device, @@ -377,6 +557,126 @@ def _hidden_and_grad( return hidden, grad +def _tree_hidden_and_grad( + sequence_length: int, *, seed: int +) -> tuple[torch.Tensor, torch.Tensor]: + generator = torch.Generator(device="cuda").manual_seed(seed) + hidden = torch.randn( + sequence_length, + 1, + 64, + device="cuda", + dtype=GDN_CORRECTNESS_DTYPE, + generator=generator, + ) + grad = torch.randn( + hidden.shape, + device="cuda", + dtype=GDN_CORRECTNESS_DTYPE, + generator=generator, + ) + torch.distributed.broadcast(hidden, src=0) + torch.distributed.broadcast(grad, src=0) + return hidden, grad + + +def _tree_chain_pack(): + long_root = torch.arange(11, 267) + short_root = torch.arange(1001, 1097) + long_mid = torch.arange(2001, 2641) + other_mid = torch.arange(3001, 3065) + return pack_shared_prefixes( + ( + torch.cat((long_root, torch.tensor([301]))), + torch.cat((long_root, torch.tensor([302]))), + torch.cat((short_root, long_mid, torch.tensor([401]))), + torch.cat((short_root, long_mid, torch.tensor([402]))), + torch.cat((short_root, other_mid, torch.tensor([403]))), + ), + max_depth=2, + ) + + +def _tree_chain_planner_config() -> GdnPlannerConfig: + return GdnPlannerConfig( + cp_chain_min_tokens_per_rank=16, + cp_chain_min_total_tokens=128, + cp_chain_min_prefix_only_tokens=128, + max_padding_ratio=4.0, + ) + + +def _tree_fuzz_planner_config() -> GdnPlannerConfig: + return GdnPlannerConfig( + cp_chain_min_tokens_per_rank=1, + cp_chain_min_total_tokens=64, + cp_chain_min_prefix_only_tokens=64, + cp_tree_chain_min_total_tokens=64, + cp_tree_chain_min_prefix_only_tokens=64, + max_padding_ratio=4.0, + ) + + +def _tree_fuzz_packs() -> tuple[tuple[str, Any], ...]: + return ( + ( + "tree_fuzz_duplicates", + pack_shared_prefixes(_duplicate_tree_sequences(), max_depth=4), + ), + ( + "tree_fuzz_ragged_depth4", + pack_shared_prefixes(_random_tree_sequences(13, max_depth=4), max_depth=4), + ), + ( + "tree_fuzz_mixed_tiny_long", + pack_shared_prefixes(_random_tree_sequences(29, max_depth=5), max_depth=5), + ), + ) + + +def _duplicate_tree_sequences() -> tuple[torch.Tensor, ...]: + root = torch.arange(11, 331) + mid_a = torch.arange(1001, 1261) + mid_b = torch.arange(2001, 2065) + leaf_a = torch.arange(3001, 3013) + leaf_b = torch.arange(4001, 4017) + first = torch.cat((root, mid_a, leaf_a)) + second = torch.cat((root, mid_a, leaf_b)) + third = torch.cat((root, mid_b, torch.tensor([91, 92, 93]))) + return (first, first, second, third, third) + + +def _random_tree_sequences(seed: int, *, max_depth: int) -> tuple[torch.Tensor, ...]: + generator = torch.Generator().manual_seed(seed) + next_token = 1 + + def randint(low: int, high: int) -> int: + return int(torch.randint(low, high + 1, (), generator=generator).item()) + + def tokens(length: int) -> torch.Tensor: + nonlocal next_token + out = torch.arange(next_token, next_token + length) + next_token += length + 997 + return out + + def segment_length(depth: int) -> int: + choices = (1, 3, 17, 64, 129, 257, 384 if depth == 0 else 96) + return choices[randint(0, len(choices) - 1)] + + def walk(prefix: torch.Tensor, depth: int) -> list[torch.Tensor]: + here = torch.cat((prefix, tokens(segment_length(depth)))) + if depth + 1 >= max_depth: + return [ + torch.cat((here, tokens(randint(1, 17)))) for _ in range(randint(2, 4)) + ] + leaves: list[torch.Tensor] = [] + for _ in range(randint(2, 3)): + leaves.extend(walk(here, depth + 1)) + return leaves + + return tuple(walk(torch.empty(0, dtype=torch.long), 0)) + + def _packed_correctness_cases() -> tuple[GdnPhase0Case, ...]: return ( *default_phase0_cases(conv_width=2), @@ -450,9 +750,3 @@ def _swap_siblings(tensor: torch.Tensor) -> torch.Tensor: def _skip_without_gpus(cp_size: int) -> None: if not torch.cuda.is_available() or torch.cuda.device_count() < cp_size: pytest.skip(f"Need {cp_size} CUDA devices for CP{cp_size} packed GDN.") - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_train_prepare.py b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_train_prepare.py index e0d2e831f..6ef5a8890 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_train_prepare.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_train_prepare.py @@ -1,7 +1,6 @@ from __future__ import annotations from pathlib import Path -import socket from typing import Any, cast import pytest @@ -24,6 +23,7 @@ from art.preprocessing.pack import PackedTensors # noqa: E402 from .cases import default_phase0_cases # noqa: E402 +from .distributed_init import file_init_method # noqa: E402 from .packed_layout import build_phase0_packed_tensors # noqa: E402 @@ -31,10 +31,10 @@ def test_gdn_cp_training_batch_carries_prebuilt_rank_plan(tmp_path: Path) -> Non cp_size = 2 if not torch.cuda.is_available() or torch.cuda.device_count() < cp_size: pytest.skip(f"requires {cp_size} CUDA devices") - port = _find_free_port() + init_method = file_init_method(tmp_path, "gdn_cp_train_prepare") mp.spawn( _worker, - args=(cp_size, port, str(tmp_path)), + args=(cp_size, init_method, str(tmp_path)), nprocs=cp_size, join=True, ) @@ -42,11 +42,11 @@ def test_gdn_cp_training_batch_carries_prebuilt_rank_plan(tmp_path: Path) -> Non assert (tmp_path / f"rank_{rank}.ok").read_text() == "ok\n" -def _worker(rank: int, cp_size: int, port: int, output_dir: str) -> None: +def _worker(rank: int, cp_size: int, init_method: str, output_dir: str) -> None: torch.cuda.set_device(rank) init_process_group( backend="nccl", - init_method=f"tcp://127.0.0.1:{port}", + init_method=init_method, rank=rank, world_size=cp_size, ) @@ -101,12 +101,6 @@ def _worker(rank: int, cp_size: int, port: int, output_dir: str) -> None: destroy_process_group() -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) - - def test_main_loss_matches_shifted_dispatched_loss_inputs() -> None: packed = cast( Any, diff --git a/tests/integration/megatron/gdn_shared_prefix/test_qwen35_full_model_cp1_packed_vs_flattened.py b/tests/integration/megatron/gdn_shared_prefix/test_qwen35_full_model_cp1_packed_vs_flattened.py index 19f33970c..b8e61537d 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_qwen35_full_model_cp1_packed_vs_flattened.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_qwen35_full_model_cp1_packed_vs_flattened.py @@ -2,7 +2,8 @@ from collections.abc import Iterator from contextlib import ExitStack, contextmanager -import socket +from pathlib import Path +import tempfile from typing import Any import pytest @@ -31,6 +32,7 @@ _apply_test_flex_inner_fp32_patch, ) from .cases import default_phase0_cases +from .distributed_init import file_init_method from .metrics import ( GDN_CORRECTNESS_DTYPE, MEAN_ABS_PCT_THRESHOLD, @@ -96,66 +98,66 @@ def test_qwen35_full_model_cp1_matches_flattened_grad_accumulation() -> None: flat_loss_sum: torch.Tensor | None = None logits_mean_abs_pct = 0.0 - spec = parse_gdn_shared_prefix_segments( - group_ids.cpu(), parent_ids.cpu(), min_completions_per_family=1 - ) - for family in spec.families: - row = family.row_index - prefix = family.prefix - for completion in family.completions: - ref_tokens = torch.cat( - [ - tokens[row : row + 1, prefix.start : prefix.end], - tokens[row : row + 1, completion.start : completion.end], - ], - dim=1, - ) - ref_pos = torch.cat( - [ - input_pos[row : row + 1, prefix.start : prefix.end], - input_pos[row : row + 1, completion.start : completion.end], - ], - dim=1, - ) - ref_assistant_mask = torch.cat( - [ - torch.zeros( - (1, prefix.length), dtype=torch.bool, device=device - ), - assistant_mask[ - row : row + 1, completion.start : completion.end - ], - ], - dim=1, - ) - ref_group_ids = torch.zeros_like(ref_tokens) - ref_parent_ids = torch.zeros_like(ref_tokens) - ref_logits, ref_loss = _run_model_loss( - flat_model, - tokens=ref_tokens, - input_pos=ref_pos, - group_ids=ref_group_ids, - parent_ids=ref_parent_ids, - assistant_mask=ref_assistant_mask, - ) - ref_loss.backward() - flat_loss_sum = ( - ref_loss.detach() - if flat_loss_sum is None - else flat_loss_sum + ref_loss.detach() - ) + spec = parse_gdn_shared_prefix_segments(group_ids.cpu(), parent_ids.cpu()) + for segment_index, completion in enumerate(spec.tree_segments): + if spec.tree_parent_indices[segment_index] < 0: + continue + row = completion.row_index + path = _segment_path(spec, segment_index) + completion_offset = sum(segment.length for segment in path[:-1]) + ref_tokens = torch.cat( + [ + tokens[row : row + 1, segment.start : segment.end] + for segment in path + ], + dim=1, + ) + ref_pos = torch.cat( + [ + input_pos[row : row + 1, segment.start : segment.end] + for segment in path + ], + dim=1, + ) + ref_assistant_mask = torch.cat( + [ + torch.zeros( + (1, completion_offset), + dtype=torch.bool, + device=device, + ), + assistant_mask[row : row + 1, completion.start : completion.end], + ], + dim=1, + ) + ref_group_ids = torch.zeros_like(ref_tokens) + ref_parent_ids = torch.zeros_like(ref_tokens) + ref_logits, ref_loss = _run_model_loss( + flat_model, + tokens=ref_tokens, + input_pos=ref_pos, + group_ids=ref_group_ids, + parent_ids=ref_parent_ids, + assistant_mask=ref_assistant_mask, + ) + ref_loss.backward() + flat_loss_sum = ( + ref_loss.detach() + if flat_loss_sum is None + else flat_loss_sum + ref_loss.detach() + ) - if completion.length > 1: - packed_slice = packed_logits[ - row : row + 1, completion.start : completion.end - 1 - ] - ref_slice = ref_logits[ - :, prefix.length : prefix.length + completion.length - 1 - ] - logits_mean_abs_pct = max( - logits_mean_abs_pct, - mean_abs_pct(ref_slice, packed_slice), - ) + if completion.length > 1: + packed_slice = packed_logits[ + row : row + 1, completion.start : completion.end - 1 + ] + ref_slice = ref_logits[ + :, completion_offset : completion_offset + completion.length - 1 + ] + logits_mean_abs_pct = max( + logits_mean_abs_pct, + mean_abs_pct(ref_slice, packed_slice), + ) assert flat_loss_sum is not None grad_name, grad_pct = parameter_grad_mean_abs_pct_with_name( @@ -214,70 +216,64 @@ def _assert_logits_vjp_equivalence( flat_loss_sum: torch.Tensor | None = None logits_mean_abs_pct = 0.0 - spec = parse_gdn_shared_prefix_segments( - group_ids.cpu(), parent_ids.cpu(), min_completions_per_family=1 - ) - for family in spec.families: - row = family.row_index - prefix = family.prefix - for completion in family.completions: - ref_tokens = torch.cat( - [ - tokens[row : row + 1, prefix.start : prefix.end], - tokens[row : row + 1, completion.start : completion.end], - ], - dim=1, - ) - ref_pos = torch.cat( - [ - input_pos[row : row + 1, prefix.start : prefix.end], - input_pos[row : row + 1, completion.start : completion.end], - ], - dim=1, - ) - ref_logits = _run_model_logits( - flat_model, - tokens=ref_tokens, - input_pos=ref_pos, - group_ids=torch.zeros_like(ref_tokens), - parent_ids=torch.zeros_like(ref_tokens), - ) - ref_output_grad = torch.zeros_like(ref_logits) - ref_output_mask = torch.zeros( - ref_logits.shape[:2], - device=ref_logits.device, - dtype=torch.bool, - ) - if completion.length > 1: - ref_output_grad[ - :, prefix.length : prefix.length + completion.length - 1 - ] = output_grad[row : row + 1, completion.start : completion.end - 1] - ref_output_mask[ - :, prefix.length : prefix.length + completion.length - 1 - ] = True - ref_loss = stable_output_mse_loss( - ref_logits, - ref_output_grad, - mask=ref_output_mask.unsqueeze(-1), - denominator=loss_denominator, - ) - ref_loss.backward() - flat_loss_sum = ( - ref_loss.detach() - if flat_loss_sum is None - else flat_loss_sum + ref_loss.detach() + spec = parse_gdn_shared_prefix_segments(group_ids.cpu(), parent_ids.cpu()) + for segment_index, completion in enumerate(spec.tree_segments): + if spec.tree_parent_indices[segment_index] < 0: + continue + row = completion.row_index + path = _segment_path(spec, segment_index) + completion_offset = sum(segment.length for segment in path[:-1]) + ref_tokens = torch.cat( + [tokens[row : row + 1, segment.start : segment.end] for segment in path], + dim=1, + ) + ref_pos = torch.cat( + [input_pos[row : row + 1, segment.start : segment.end] for segment in path], + dim=1, + ) + ref_logits = _run_model_logits( + flat_model, + tokens=ref_tokens, + input_pos=ref_pos, + group_ids=torch.zeros_like(ref_tokens), + parent_ids=torch.zeros_like(ref_tokens), + ) + ref_output_grad = torch.zeros_like(ref_logits) + ref_output_mask = torch.zeros( + ref_logits.shape[:2], + device=ref_logits.device, + dtype=torch.bool, + ) + if completion.length > 1: + ref_output_grad[ + :, completion_offset : completion_offset + completion.length - 1 + ] = output_grad[row : row + 1, completion.start : completion.end - 1] + ref_output_mask[ + :, completion_offset : completion_offset + completion.length - 1 + ] = True + ref_loss = stable_output_mse_loss( + ref_logits, + ref_output_grad, + mask=ref_output_mask.unsqueeze(-1), + denominator=loss_denominator, + ) + ref_loss.backward() + flat_loss_sum = ( + ref_loss.detach() + if flat_loss_sum is None + else flat_loss_sum + ref_loss.detach() + ) + if completion.length > 1: + packed_slice = packed_logits[ + row : row + 1, completion.start : completion.end - 1 + ] + ref_slice = ref_logits[ + :, completion_offset : completion_offset + completion.length - 1 + ] + logits_mean_abs_pct = max( + logits_mean_abs_pct, + mean_abs_pct(ref_slice, packed_slice), ) - if completion.length > 1: - packed_slice = packed_logits[ - row : row + 1, completion.start : completion.end - 1 - ] - ref_slice = ref_logits[ - :, prefix.length : prefix.length + completion.length - 1 - ] - logits_mean_abs_pct = max( - logits_mean_abs_pct, - mean_abs_pct(ref_slice, packed_slice), - ) assert flat_loss_sum is not None grad_name, grad_pct = parameter_grad_mean_abs_pct_with_name( @@ -359,6 +355,15 @@ def _run_model_logits( return logits +def _segment_path(spec: Any, segment_index: int) -> tuple[Any, ...]: + indices = [] + cursor = segment_index + while cursor >= 0: + indices.append(cursor) + cursor = spec.tree_parent_indices[cursor] + return tuple(spec.tree_segments[index] for index in reversed(indices)) + + def _make_matching_models() -> tuple[torch.nn.Module, torch.nn.Module]: model_parallel_cuda_manual_seed(1234) packed = _make_model() @@ -424,28 +429,23 @@ def _single_rank_model_parallel() -> Iterator[None]: if is_initialized(): pytest.skip("torch.distributed is already initialized in this process.") torch.cuda.set_device(0) - init_process_group( - backend="nccl", - init_method=f"tcp://127.0.0.1:{_find_free_port()}", - rank=0, - world_size=1, - ) - try: - ps.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - context_parallel_size=1, - expert_model_parallel_size=1, + with tempfile.TemporaryDirectory(prefix="art_dist_") as tmp: + init_process_group( + backend="nccl", + init_method=file_init_method(Path(tmp), "qwen35_full_model_cp1"), + rank=0, + world_size=1, ) - yield - finally: - if getattr(ps, "model_parallel_is_initialized", lambda: False)(): - ps.destroy_model_parallel() - if is_initialized(): - destroy_process_group() - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) + try: + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + expert_model_parallel_size=1, + ) + yield + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + if is_initialized(): + destroy_process_group() diff --git a/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_cp1_packed_vs_flattened.py b/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_cp1_packed_vs_flattened.py index de6933582..f026b90c9 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_cp1_packed_vs_flattened.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_cp1_packed_vs_flattened.py @@ -2,7 +2,8 @@ from collections.abc import Iterator from contextlib import contextmanager -import socket +from pathlib import Path +import tempfile import pytest @@ -18,13 +19,13 @@ from megatron.core.ssm.gated_delta_net import GatedDeltaNet from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from torch.distributed import ( - DistNetworkError, destroy_process_group, init_process_group, is_initialized, ) from .cases import default_phase0_cases +from .distributed_init import file_init_method from .metrics import ( GDN_CORRECTNESS_DTYPE, MEAN_ABS_PCT_MISMATCH_THRESHOLD, @@ -232,44 +233,23 @@ def _single_rank_model_parallel() -> Iterator[None]: if is_initialized(): pytest.skip("torch.distributed is already initialized in this process.") torch.cuda.set_device(0) - _init_single_rank_process_group() - try: - ps.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - context_parallel_size=1, - expert_model_parallel_size=1, + with tempfile.TemporaryDirectory(prefix="art_dist_") as tmp: + init_process_group( + backend="nccl", + init_method=file_init_method(Path(tmp), "single_rank"), + rank=0, + world_size=1, ) - yield - finally: - if getattr(ps, "model_parallel_is_initialized", lambda: False)(): - ps.destroy_model_parallel() - if is_initialized(): - destroy_process_group() - - -def _init_single_rank_process_group() -> None: - last_error: DistNetworkError | None = None - for _ in range(16): try: - init_process_group( - backend="nccl", - init_method=f"tcp://127.0.0.1:{_find_free_port()}", - rank=0, - world_size=1, + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + expert_model_parallel_size=1, ) - return - except DistNetworkError as error: - if "EADDRINUSE" not in str(error): - raise - last_error = error + yield + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() if is_initialized(): destroy_process_group() - if last_error is not None: - raise last_error - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_native_fla_cp.py b/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_native_fla_cp.py index e0d164c56..7a173ae8f 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_native_fla_cp.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_native_fla_cp.py @@ -1,7 +1,6 @@ from __future__ import annotations from pathlib import Path -import socket from typing import cast import pytest @@ -37,6 +36,7 @@ ) from .cases import GdnFamilyShape, GdnPackedRowShape, GdnPhase0Case # noqa: E402 +from .distributed_init import file_init_method # noqa: E402 from .metrics import ( # noqa: E402 GDN_CORRECTNESS_DTYPE, MEAN_ABS_PCT_THRESHOLD, @@ -70,10 +70,10 @@ def test_real_qwen35_gdn_native_fla_cp_prepared_varlen_batch_matches_single_rank( cp_size: int, tmp_path: Path ) -> None: - port = _find_free_port() + init_method = file_init_method(tmp_path, f"native_gdn_prepared_cp{cp_size}") mp.spawn( _native_gdn_cp_prepared_varlen_worker, - args=(cp_size, port, str(tmp_path)), + args=(cp_size, init_method, str(tmp_path)), nprocs=cp_size, join=True, ) @@ -89,10 +89,10 @@ def test_real_qwen35_gdn_native_fla_cp_prepared_varlen_batch_matches_single_rank def test_real_qwen35_gdn_native_cp_packed_layer_matches_cp1( cp_size: int, tmp_path: Path ) -> None: - port = _find_free_port() + init_method = file_init_method(tmp_path, f"native_gdn_packed_cp{cp_size}") mp.spawn( _native_gdn_cp_packed_layer_worker, - args=(cp_size, port, str(tmp_path)), + args=(cp_size, init_method, str(tmp_path)), nprocs=cp_size, join=True, ) @@ -103,13 +103,13 @@ def test_real_qwen35_gdn_native_cp_packed_layer_matches_cp1( def _native_gdn_cp_packed_layer_worker( rank: int, cp_size: int, - port: int, + init_method: str, output_dir: str, ) -> None: torch.cuda.set_device(rank) init_process_group( backend="nccl", - init_method=f"tcp://127.0.0.1:{port}", + init_method=init_method, rank=rank, world_size=cp_size, ) @@ -127,9 +127,7 @@ def _native_gdn_cp_packed_layer_worker( tensors = build_phase0_packed_tensors(case) group_ids = tensors["group_ids"].cuda() parent_ids = tensors["parent_ids"].cuda() - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) plan = build_gdn_rank_execution_plan( spec, device=group_ids.device, @@ -139,17 +137,9 @@ def _native_gdn_cp_packed_layer_worker( cp_chain_min_tokens_per_rank=16, cp_chain_min_total_tokens=128, cp_chain_min_prefix_only_tokens=128, - # This test is the native chain correctness guard, so force the - # planner onto chain prefix and completion buckets. - planner_chain_bucket_ms=0.0, - planner_chain_token_ms=0.0, - planner_local_bucket_ms=1.0, - planner_local_token_ms=1.0, - cp_chain_min_score_delta_ms=0.0, ), ) - assert plan.chain_prefix_buckets - assert plan.chain_completion_buckets + assert any(plan.tree_chain_buckets_by_depth) hidden, output_grad = _packed_hidden_and_grad(case, cp_size) ref_hidden = hidden.clone().detach().requires_grad_(True) ref_out, _ = run_gdn_layer( @@ -215,13 +205,13 @@ def _native_gdn_cp_packed_layer_worker( def _native_gdn_cp_prepared_varlen_worker( rank: int, cp_size: int, - port: int, + init_method: str, output_dir: str, ) -> None: torch.cuda.set_device(rank) init_process_group( backend="nccl", - init_method=f"tcp://127.0.0.1:{port}", + init_method=init_method, rank=rank, world_size=cp_size, ) @@ -599,9 +589,3 @@ def _all_reduce_parameter_grads(module: torch.nn.Module) -> None: main_grad = getattr(parameter, "main_grad", None) if main_grad is not None: torch.distributed.all_reduce(main_grad) - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_tp_lora.py b/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_tp_lora.py index c4bd99abc..62e217daf 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_tp_lora.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_tp_lora.py @@ -1,7 +1,6 @@ from __future__ import annotations from pathlib import Path -import socket import pytest @@ -26,6 +25,7 @@ from art.megatron.model_support.handlers import QWEN3_5_MOE_HANDLER # noqa: E402 from .cases import GdnPhase0Case, default_phase0_cases # noqa: E402 +from .distributed_init import file_init_method # noqa: E402 from .metrics import GDN_CORRECTNESS_DTYPE, assert_real_gdn_metrics # noqa: E402 from .packed_layout import build_phase0_packed_tensors # noqa: E402 from .real_gdn_oracle import ( # noqa: E402 @@ -68,10 +68,10 @@ def test_real_qwen35_gdn_lora_gradients_match_flattened() -> None: reason="At least two CUDA devices are required for TP2 GDN coverage.", ) def test_real_qwen35_gdn_tp2_gradients_match_flattened(tmp_path: Path) -> None: - port = _find_free_port() + init_method = file_init_method(tmp_path, "real_gdn_tp2_lora") mp.spawn( _tp2_worker, - args=(port, str(tmp_path)), + args=(init_method, str(tmp_path)), nprocs=2, join=True, ) @@ -79,11 +79,11 @@ def test_real_qwen35_gdn_tp2_gradients_match_flattened(tmp_path: Path) -> None: assert (tmp_path / f"rank_{rank}.ok").read_text() == "ok\n" -def _tp2_worker(rank: int, port: int, output_dir: str) -> None: +def _tp2_worker(rank: int, init_method: str, output_dir: str) -> None: torch.cuda.set_device(rank) init_process_group( backend="nccl", - init_method=f"tcp://127.0.0.1:{port}", + init_method=init_method, rank=rank, world_size=2, ) @@ -229,9 +229,3 @@ def _gdn_lora_grad_names(gdn: torch.nn.Module) -> tuple[str, ...]: and parameter.grad is not None and bool(parameter.grad.abs().max().item() > 0) ) - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) diff --git a/tests/integration/megatron/lora/test_lora_disk_codecs.py b/tests/integration/megatron/lora/test_lora_disk_codecs.py index b14cd2a4c..7bb3e1b94 100644 --- a/tests/integration/megatron/lora/test_lora_disk_codecs.py +++ b/tests/integration/megatron/lora/test_lora_disk_codecs.py @@ -1,12 +1,17 @@ import json +import os from pathlib import Path +import shutil import subprocess import sys from typing import Any, cast +import pytest from safetensors.torch import load_file, save_file import torch +pytest.importorskip("megatron.bridge.models.gpt_provider") + from art.megatron import lora as lora_module from art.megatron.lora import LoRA, LoRAParallelSpec, LoRAPublishPlanner from art.megatron.model_support.handlers import ( @@ -29,6 +34,66 @@ REPO_ROOT = Path(__file__).parents[4] VLLM_PYTHON = REPO_ROOT / "vllm_runtime/.venv/bin/python" +_VLLM_RUNTIME_UNAVAILABLE_REASON: str | None | object = object() + + +def _vllm_python_cmd() -> list[str]: + override = os.environ.get("ART_TEST_VLLM_PYTHON") + if override: + return [override] + if VLLM_PYTHON.exists(): + return [str(VLLM_PYTHON)] + uv = shutil.which("uv") + if uv is None: + raise RuntimeError( + f"{VLLM_PYTHON} does not exist and uv is not available to run " + "the locked vLLM runtime project" + ) + return [ + uv, + "run", + "--project", + str(REPO_ROOT / "vllm_runtime"), + "--frozen", + "--no-dev", + "python", + ] + + +def _vllm_runtime_unavailable_reason() -> str | None: + global _VLLM_RUNTIME_UNAVAILABLE_REASON + if isinstance(_VLLM_RUNTIME_UNAVAILABLE_REASON, str): + return _VLLM_RUNTIME_UNAVAILABLE_REASON + if _VLLM_RUNTIME_UNAVAILABLE_REASON is None: + return None + try: + subprocess.run( + [ + *_vllm_python_cmd(), + "-c", + "import vllm; from vllm.lora.lora_model import LoRAModel", + ], + check=True, + text=True, + capture_output=True, + timeout=120, + ) + except Exception as exc: + _VLLM_RUNTIME_UNAVAILABLE_REASON = ( + "Stock vLLM loader runtime is unavailable. Run " + "`uv sync --project vllm_runtime --frozen --no-dev`, or set " + "`ART_TEST_VLLM_PYTHON` to a Python environment with vLLM installed. " + f"Original error: {exc}" + ) + return _VLLM_RUNTIME_UNAVAILABLE_REASON + _VLLM_RUNTIME_UNAVAILABLE_REASON = None + return None + + +def test_stock_vllm_loader_runtime_is_available() -> None: + reason = _vllm_runtime_unavailable_reason() + if reason is not None: + pytest.fail(reason) def _config(base_model: str, rank: int = 2, alpha: int = 4) -> dict: @@ -116,6 +181,8 @@ def _assert_stock_vllm_loads( expected_modules: set[str], mapper: str = "none", ) -> list[str]: + if reason := _vllm_runtime_unavailable_reason(): + pytest.skip(reason) script = r""" import json import sys @@ -142,7 +209,7 @@ def _assert_stock_vllm_loads( """ result = subprocess.run( [ - str(VLLM_PYTHON), + *_vllm_python_cmd(), "-c", script, str(path), diff --git a/tests/integration/megatron/model_support/forward_trace.py b/tests/integration/megatron/model_support/forward_trace.py index 289b8b7a6..30731cbdd 100644 --- a/tests/integration/megatron/model_support/forward_trace.py +++ b/tests/integration/megatron/model_support/forward_trace.py @@ -1118,6 +1118,7 @@ def _canonicalize_row_aligned_value( def _canonicalize_call_row_token_order(cls, call: dict[str, Any]) -> None: """Canonicalizes all row-aligned call tensors to global token order.""" cls._align_exact_zero_padding_row_token_uids(call) + cls._drop_exact_zero_padding_rows(call) row_token_uids = call.get("row_token_uids") if not isinstance(row_token_uids, torch.Tensor) or row_token_uids.ndim != 1: return @@ -1138,6 +1139,38 @@ def _canonicalize_call_row_token_order(cls, call: dict[str, Any]) -> None: ) call["row_token_uids"] = row_token_uids.index_select(0, order).contiguous() + @classmethod + def _drop_exact_zero_padding_rows(cls, call: dict[str, Any]) -> None: + """Removes traced sequence-padding rows before comparing compact CP traces.""" + row_token_uids = call.get("row_token_uids") + tensor = call.get("primary_output") + if ( + not isinstance(row_token_uids, torch.Tensor) + or row_token_uids.ndim != 1 + or not isinstance(tensor, torch.Tensor) + or tensor.ndim == 0 + or int(tensor.shape[0]) != int(row_token_uids.numel()) + ): + return + row_count = int(row_token_uids.numel()) + padding_rows = row_token_uids < 0 + if row_count == 0 or not bool(padding_rows.any().item()): + return + flat = tensor.detach().reshape(row_count, -1) + if not bool((flat[padding_rows] == 0).all().item()): + return + valid_rows = torch.nonzero(~padding_rows, as_tuple=False).reshape(-1) + original_call = dict(call) + for key, value in original_call.items(): + if key == "row_token_uids": + continue + call[key] = cls._slice_row_aligned_value( + value, + row_indices=valid_rows, + total_rows=row_count, + ) + call["row_token_uids"] = row_token_uids.index_select(0, valid_rows).contiguous() + @staticmethod def _align_exact_zero_padding_row_token_uids(call: dict[str, Any]) -> None: """Moves padding UID markers onto exact-zero sequence-parallel pad rows.""" diff --git a/tests/integration/megatron/model_support/oracle_worker.py b/tests/integration/megatron/model_support/oracle_worker.py index 86ad2fb0e..50141edde 100644 --- a/tests/integration/megatron/model_support/oracle_worker.py +++ b/tests/integration/megatron/model_support/oracle_worker.py @@ -1117,12 +1117,16 @@ def _reference_forward( def _reference_fc1_forward(self: Any, x: torch.Tensor, tokens_per_expert: Any): base_out, bias_out = self.linear_fc1(x, tokens_per_expert) - adapter_out = torch.cat( - ( - self.gate_lora(x, tokens_per_expert), - self.up_lora(x, tokens_per_expert), - ), - dim=1, + adapter_out = ( + self.lora(x, tokens_per_expert) + if self.fused_gate_up + else torch.cat( + ( + self.gate_lora(x, tokens_per_expert), + self.up_lora(x, tokens_per_expert), + ), + dim=1, + ) ) return base_out + adapter_out, bias_out diff --git a/tests/integration/megatron/model_support/test_oracle_harness_invariants.py b/tests/integration/megatron/model_support/test_oracle_harness_invariants.py index 5a45bc03a..043736553 100644 --- a/tests/integration/megatron/model_support/test_oracle_harness_invariants.py +++ b/tests/integration/megatron/model_support/test_oracle_harness_invariants.py @@ -314,6 +314,37 @@ def test_forward_trace_canonicalizes_row_outputs_by_token_uid() -> None: ) +def test_forward_trace_drops_exact_zero_padding_rows() -> None: + trace: dict[str, list[dict[str, Any]]] = { + "chunk0.module.decoder.layers.0.self_attention.out_proj": [ + { + "primary_output": torch.tensor( + [[0.0, 0.0], [30.0, 31.0], [10.0, 11.0], [20.0, 21.0]] + ), + "output": { + "hidden": torch.tensor( + [[0.0, 0.0], [3.0, 3.1], [1.0, 1.1], [2.0, 2.1]] + ) + }, + "row_token_uids": torch.tensor([-1, 3, 1, 2]), + } + ] + } + + ForwardTraceCapture.canonicalize_trace(trace) + + call = trace["chunk0.module.decoder.layers.0.self_attention.out_proj"][0] + assert torch.equal(call["row_token_uids"], torch.tensor([1, 2, 3])) + assert torch.equal( + call["primary_output"], + torch.tensor([[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]]), + ) + assert torch.equal( + call["output"]["hidden"], + torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]), + ) + + def test_forward_trace_expands_attention_output_uids_for_out_norm_heads() -> None: trace: dict[str, list[dict[str, Any]]] = { "chunk0.module.decoder.layers.0.self_attention": [ diff --git a/tests/unit/test_shared_prefix_attention_builder.py b/tests/unit/test_shared_prefix_attention_builder.py new file mode 100644 index 000000000..34992bac9 --- /dev/null +++ b/tests/unit/test_shared_prefix_attention_builder.py @@ -0,0 +1,586 @@ +from __future__ import annotations + +import pytest +import torch +from torch.nn.attention.flex_attention import BlockMask +from torch.nn.attention.flex_attention import create_block_mask as torch_block_mask + +pytest.importorskip("megatron.core.packed_seq_params") + +from art.megatron.context_parallel.block_mask import ( + build_block_mask_from_context, + prepare_block_mask_context, +) +from art.megatron.context_parallel.builder import ( + build_dense_reference_mask, + build_shared_prefix_attention_spec, +) +from art.megatron.context_parallel.runtime import get_or_build_runtime_plan +from art.megatron.context_parallel.types import ( + AttnMaskKind, + AttnSlice, + ContextParallelConfig, + ExactMaskMetadata, + FlexMaskSpec, + ParallelTopology, + TokenRange, +) +from art.megatron.shared_prefix_packing import SharedPrefixPack, pack_shared_prefixes +from art.megatron.shared_prefix_state import create_shared_prefix_state + + +def build_block_mask( + spec: FlexMaskSpec, + *, + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + device: torch.device, +) -> BlockMask | None: + return build_block_mask_from_context( + spec, + context=prepare_block_mask_context( + group_ids=group_ids, + parent_ids=parent_ids, + ), + device=device, + ) + + +def test_shared_prefix_attention_spec_supports_branching_completions() -> None: + group_ids, parent_ids = _branching_prefix_inputs() + + spec = build_shared_prefix_attention_spec( + group_ids=group_ids, + parent_ids=parent_ids, + ) + dense = build_dense_reference_mask(row_spec=spec.rows[0]) + + assert dense.int().tolist() == [ + [1, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 1, 0, 0], + [1, 1, 1, 0, 0, 1, 0], + [1, 1, 1, 0, 0, 1, 1], + ] + + +def test_shared_prefix_attention_spec_matches_tree_reference() -> None: + group_ids, parent_ids = _branching_prefix_inputs() + + spec = build_shared_prefix_attention_spec( + group_ids=group_ids, + parent_ids=parent_ids, + ) + dense = build_dense_reference_mask(row_spec=spec.rows[0]) + + assert dense.equal(_reference_tree_mask(group_ids[0], parent_ids[0])) + + +def test_shared_prefix_can_build_context_parallel_layout() -> None: + group_ids, parent_ids = _branching_prefix_inputs() + spec = build_shared_prefix_attention_spec( + group_ids=group_ids, + parent_ids=parent_ids, + ) + + plan = get_or_build_runtime_plan( + spec, + topology=ParallelTopology(cp=2), + config=ContextParallelConfig(planner_chunk_size=2, planner_max_search_steps=1), + original_seq_len=int(group_ids.numel()), + ) + + assert sum(plan[rank].local_valid_lengths[0] for rank in range(2)) == int( + group_ids.numel() + ) + + +def test_sparse_block_mask_exact_predicate_matches_dense_reference() -> None: + group_ids, parent_ids = _branching_prefix_inputs() + spec = build_shared_prefix_attention_spec( + group_ids=group_ids, + parent_ids=parent_ids, + ) + row = spec.rows[0] + token_indices = torch.arange(row.valid_tokens, dtype=torch.long) + block_mask = build_block_mask( + FlexMaskSpec( + q_len=row.valid_tokens, + k_len=row.valid_tokens, + block_size=(2, 2), + slices=row.slices, + exact_mask=ExactMaskMetadata( + q_token_indices=token_indices, + k_token_indices=token_indices, + cache_key="depth-two", + ), + ), + group_ids=group_ids[0], + parent_ids=parent_ids[0], + device=torch.device("cpu"), + ) + + assert block_mask is not None + q_indices = torch.arange(row.valid_tokens)[:, None] + k_indices = torch.arange(row.valid_tokens)[None, :] + actual = block_mask.mask_mod( + torch.zeros_like(q_indices), + torch.zeros_like(q_indices), + q_indices, + k_indices, + ) + + assert actual.equal(build_dense_reference_mask(row_spec=row)) + + +@pytest.mark.parametrize( + ("name", "pack"), + ( + ( + "no-sharing", + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3]), + torch.tensor([4, 5]), + torch.tensor([6, 7, 8, 9]), + ), + max_depth=0, + ), + ), + ( + "depth-one", + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 2, 6]), + ), + max_depth=1, + ), + ), + ( + "depth-three", + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4, 8]), + torch.tensor([1, 2, 3, 4, 9]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6]), + ), + max_depth=3, + ), + ), + ), +) +def test_sparse_block_mask_matches_torch_block_metadata( + name: str, + pack: SharedPrefixPack, +) -> None: + del name + spec = build_shared_prefix_attention_spec( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + row = spec.rows[0] + token_indices = torch.arange(row.valid_tokens, dtype=torch.long) + block_mask = build_block_mask( + FlexMaskSpec( + q_len=row.valid_tokens, + k_len=row.valid_tokens, + block_size=(2, 2), + slices=row.slices, + exact_mask=ExactMaskMetadata( + q_token_indices=token_indices, + k_token_indices=token_indices, + cache_key="torch-parity", + ), + ), + group_ids=pack.group_ids[0], + parent_ids=pack.parent_ids[0], + device=torch.device("cpu"), + ) + + assert block_mask is not None + _assert_matches_torch_block_mask(block_mask) + + +def test_sparse_block_mask_prunes_exact_blocks_rejected_by_group_tree() -> None: + group_ids = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2], dtype=torch.long) + parent_ids = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2], dtype=torch.long) + block_mask = build_block_mask( + FlexMaskSpec( + q_len=4, + k_len=4, + block_size=(2, 2), + slices=( + AttnSlice( + q_range=TokenRange(start=0, end=4), + k_range=TokenRange(start=0, end=4), + mask_kind=AttnMaskKind.CAUSAL, + row_index=0, + ), + ), + exact_mask=ExactMaskMetadata( + q_token_indices=torch.tensor([4, 5, 6, 7], dtype=torch.long), + k_token_indices=torch.tensor([0, 1, 2, 3], dtype=torch.long), + cache_key="all-false-cross-family", + ), + ), + group_ids=group_ids, + parent_ids=parent_ids, + device=torch.device("cpu"), + ) + + assert block_mask is not None + assert int(block_mask.kv_num_blocks.sum().item()) == 0 + assert int(block_mask.full_kv_num_blocks.sum().item()) == 0 + _assert_matches_torch_block_mask(block_mask) + + +def test_shared_prefix_state_builds_batched_block_mask() -> None: + group_ids = torch.tensor( + [ + [1, 1, 2, 2, -1], + [10, 11, 11, -1, -1], + ], + dtype=torch.long, + ) + parent_ids = torch.tensor( + [ + [1, 1, 1, 1, -1], + [10, 10, 10, -1, -1], + ], + dtype=torch.long, + ) + + state = create_shared_prefix_state( + group_ids=group_ids, + parent_ids=parent_ids, + target_device=torch.device("cpu"), + ) + seq_len = int(group_ids.shape[1]) + batch_idx = torch.arange(2)[:, None, None].expand(2, seq_len, seq_len) + query_idx = torch.arange(seq_len)[None, :, None].expand(2, seq_len, seq_len) + kv_idx = torch.arange(seq_len)[None, None, :].expand(2, seq_len, seq_len) + actual = state.block_mask.mask_mod( + batch_idx, + torch.zeros_like(batch_idx), + query_idx, + kv_idx, + ) + spec = build_shared_prefix_attention_spec( + group_ids=group_ids, + parent_ids=parent_ids, + ) + assert int(state.block_mask.kv_num_blocks.shape[0]) == 2 + for row_index, row_spec in enumerate(spec.rows): + valid_tokens = int(row_spec.valid_tokens) + assert actual[ + row_index, + :valid_tokens, + :valid_tokens, + ].equal(build_dense_reference_mask(row_spec=row_spec)) + _assert_matches_torch_block_mask(state.block_mask, batch_size=2) + + +def test_context_parallel_stage_masks_match_dense_nested_tree() -> None: + _assert_context_parallel_stage_masks_match_dense( + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4, 8]), + torch.tensor([1, 2, 3, 4, 9]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6]), + ), + max_depth=3, + ), + require_remote_stage=True, + ) + _assert_context_parallel_stage_masks_match_dense( + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3]), + torch.tensor([4, 5, 6]), + torch.tensor([7, 8]), + torch.tensor([9, 10, 11, 12]), + ), + max_depth=3, + ), + require_remote_stage=False, + ) + + +def _assert_context_parallel_stage_masks_match_dense( + pack: SharedPrefixPack, + *, + require_remote_stage: bool, +) -> None: + spec = build_shared_prefix_attention_spec( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + row = spec.rows[0] + dense = build_dense_reference_mask(row_spec=row) + topology = ParallelTopology(cp=2) + config = ContextParallelConfig( + block_size=2, + planner_chunk_size=2, + planner_max_search_steps=1, + planner_remote_stage_token_floor=1, + planner_remote_stage_pair_floor=1, + ) + plan = get_or_build_runtime_plan( + spec, + topology=topology, + config=config, + original_seq_len=int(pack.tokens.numel()), + ) + + checked_stages = 0 + checked_remote_stages = 0 + for rank_plan in plan: + for stage in rank_plan.stage_plans: + if stage.mask_metadata is None: + continue + block_mask = build_block_mask( + FlexMaskSpec( + q_len=stage.q_len, + k_len=stage.k_len, + block_size=(2, 2), + slices=stage.slices, + exact_mask=stage.mask_metadata, + ), + group_ids=pack.group_ids[0], + parent_ids=pack.parent_ids[0], + device=torch.device("cpu"), + ) + assert block_mask is not None + q_offsets = torch.arange(stage.q_len)[:, None] + k_offsets = torch.arange(stage.k_len)[None, :] + actual = block_mask.mask_mod( + torch.zeros_like(q_offsets), + torch.zeros_like(q_offsets), + q_offsets, + k_offsets, + ) + q_tokens = stage.mask_metadata.q_token_indices + k_tokens = stage.mask_metadata.k_token_indices + expected = ( + dense[q_tokens.clamp_min(0)[:, None], k_tokens.clamp_min(0)[None, :]] + & (q_tokens[:, None] >= 0) + & (k_tokens[None, :] >= 0) + ) + + assert actual.equal(expected) + assert _effective_block_mask(block_mask).equal(expected) + _assert_matches_torch_block_mask(block_mask) + checked_stages += 1 + checked_remote_stages += int(not stage.is_local_stage) + + assert checked_stages + if require_remote_stage: + assert checked_remote_stages + + +def _effective_block_mask(block_mask: BlockMask) -> torch.Tensor: + q_len, k_len = block_mask.seq_lengths + q_block, k_block = block_mask.BLOCK_SIZE + effective = torch.zeros((q_len, k_len), dtype=torch.bool) + _fill_full_blocks(effective, block_mask, q_block=q_block, k_block=k_block) + _fill_partial_blocks(effective, block_mask, q_block=q_block, k_block=k_block) + return effective + + +def _fill_full_blocks( + effective: torch.Tensor, + block_mask: BlockMask, + *, + q_block: int, + k_block: int, +) -> None: + if block_mask.full_kv_num_blocks is None or block_mask.full_kv_indices is None: + return + for q_block_index in range(int(block_mask.full_kv_num_blocks.shape[-1])): + q_slice = slice(q_block_index * q_block, (q_block_index + 1) * q_block) + block_count = int(block_mask.full_kv_num_blocks[0, 0, q_block_index]) + for k_block_index in block_mask.full_kv_indices[ + 0, 0, q_block_index, :block_count + ].tolist(): + k_slice = slice( + int(k_block_index) * k_block, + (int(k_block_index) + 1) * k_block, + ) + effective[q_slice, k_slice] = True + + +def _fill_partial_blocks( + effective: torch.Tensor, + block_mask: BlockMask, + *, + q_block: int, + k_block: int, +) -> None: + for q_block_index in range(int(block_mask.kv_num_blocks.shape[-1])): + q_offsets = torch.arange( + q_block_index * q_block, + min((q_block_index + 1) * q_block, effective.shape[0]), + )[:, None] + block_count = int(block_mask.kv_num_blocks[0, 0, q_block_index]) + for k_block_index in block_mask.kv_indices[ + 0, 0, q_block_index, :block_count + ].tolist(): + k_offsets = torch.arange( + int(k_block_index) * k_block, + min((int(k_block_index) + 1) * k_block, effective.shape[1]), + )[None, :] + effective[q_offsets, k_offsets] |= block_mask.mask_mod( + torch.zeros_like(q_offsets), + torch.zeros_like(q_offsets), + q_offsets, + k_offsets, + ) + + +def test_sparse_block_mask_supports_non_monotonic_remote_k_indices() -> None: + q_token_indices = torch.tensor([4, 5, 6, 7], dtype=torch.long) + k_token_indices = torch.tensor([0, 1, 6, 2, 3, 4], dtype=torch.long) + block_mask = build_block_mask( + FlexMaskSpec( + q_len=int(q_token_indices.numel()), + k_len=int(k_token_indices.numel()), + block_size=(2, 2), + slices=( + AttnSlice( + q_range=TokenRange(start=0, end=int(q_token_indices.numel())), + k_range=TokenRange(start=0, end=int(k_token_indices.numel())), + mask_kind=AttnMaskKind.CAUSAL, + row_index=0, + ), + ), + exact_mask=ExactMaskMetadata( + q_token_indices=q_token_indices, + k_token_indices=k_token_indices, + cache_key="non-monotonic-k", + ), + ), + group_ids=torch.ones(8, dtype=torch.long), + parent_ids=torch.ones(8, dtype=torch.long), + device=torch.device("cpu"), + ) + + assert block_mask is not None + q_indices = torch.arange(q_token_indices.numel())[:, None] + k_indices = torch.arange(k_token_indices.numel())[None, :] + + actual = block_mask.mask_mod( + torch.zeros_like(q_indices), + torch.zeros_like(q_indices), + q_indices, + k_indices, + ) + + assert actual.equal(q_token_indices[:, None] >= k_token_indices[None, :]) + _assert_matches_torch_block_mask(block_mask) + + +def _assert_matches_torch_block_mask( + block_mask: BlockMask, + *, + batch_size: int = 1, +) -> None: + q_len, k_len = block_mask.seq_lengths + reference = torch_block_mask( + block_mask.mask_mod, + B=batch_size, + H=1, + Q_LEN=q_len, + KV_LEN=k_len, + device="cpu", + BLOCK_SIZE=block_mask.BLOCK_SIZE, + ) + assert _effective_block_mask(block_mask).equal(_effective_block_mask(reference)) + for counts_name, indices_name in ( + ("kv_num_blocks", "kv_indices"), + ("full_kv_num_blocks", "full_kv_indices"), + ("q_num_blocks", "q_indices"), + ("full_q_num_blocks", "full_q_indices"), + ): + assert _block_entries(block_mask, counts_name, indices_name) == _block_entries( + reference, + counts_name, + indices_name, + ) + + +def _block_entries( + block_mask: BlockMask, + counts_name: str, + indices_name: str, +) -> set[tuple[int, int, int, int]]: + counts = getattr(block_mask, counts_name) + indices = getattr(block_mask, indices_name) + if counts is None or indices is None: + return set() + entries = set() + for batch_index in range(int(counts.shape[0])): + for head_index in range(int(counts.shape[1])): + for block_index in range(int(counts.shape[2])): + block_count = int(counts[batch_index, head_index, block_index]) + for other_block in indices[ + batch_index, + head_index, + block_index, + :block_count, + ].tolist(): + entries.add( + ( + batch_index, + head_index, + block_index, + int(other_block), + ) + ) + return entries + + +def _branching_prefix_inputs() -> tuple[torch.Tensor, torch.Tensor]: + return ( + torch.tensor([[1, 1, 1, 2, 3, 4, 4]], dtype=torch.long), + torch.tensor([[1, 1, 1, 1, 1, 1, 1]], dtype=torch.long), + ) + + +def _reference_tree_mask( + group_ids: torch.Tensor, parent_ids: torch.Tensor +) -> torch.Tensor: + group_list = [int(value) for value in group_ids.tolist()] + parent_by_group: dict[int, int | None] = {} + for group_id, parent_id in zip(group_list, parent_ids.tolist(), strict=True): + group_id = int(group_id) + parent_id = int(parent_id) + if group_id not in parent_by_group: + parent_by_group[group_id] = None if parent_id == group_id else parent_id + + ancestors_by_group = { + group_id: _ancestors(group_id, parent_by_group) for group_id in parent_by_group + } + dense = torch.zeros((len(group_list), len(group_list)), dtype=torch.bool) + for q_pos, q_group in enumerate(group_list): + allowed_groups = ancestors_by_group[q_group] | {q_group} + for k_pos, k_group in enumerate(group_list): + dense[q_pos, k_pos] = k_pos <= q_pos and k_group in allowed_groups + return dense + + +def _ancestors( + group_id: int, + parent_by_group: dict[int, int | None], +) -> set[int]: + ancestors: set[int] = set() + cursor = parent_by_group[group_id] + while cursor is not None and cursor not in ancestors: + ancestors.add(cursor) + cursor = parent_by_group.get(cursor) + return ancestors diff --git a/tests/unit/test_shared_prefix_grad_parity.py b/tests/unit/test_shared_prefix_grad_parity.py new file mode 100644 index 000000000..5b812782b --- /dev/null +++ b/tests/unit/test_shared_prefix_grad_parity.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +from copy import deepcopy + +import pytest +import torch +from torch import nn +import torch.nn.functional as F + +from art.megatron.shared_prefix_packing import SharedPrefixPack, pack_shared_prefixes + + +class _ToyCausalLM(nn.Module): + def __init__(self) -> None: + super().__init__() + self.token_embedding = nn.Embedding(32, 8, dtype=torch.float64) + self.position_embedding = nn.Embedding(8, 8, dtype=torch.float64) + self.mix = nn.Linear(8, 8, bias=False, dtype=torch.float64) + self.output = nn.Linear(8, 32, bias=False, dtype=torch.float64) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + causal_mask: torch.Tensor, + ) -> torch.Tensor: + states = self.token_embedding(input_ids) + self.position_embedding(position_ids) + context = causal_mask.to(states.dtype) @ states + return self.output(torch.tanh(self.mix(context))) + + +@pytest.mark.parametrize("max_depth", (1, 2, 3)) +@pytest.mark.parametrize("multi_target", (False, True)) +def test_shared_prefix_ce_parameter_grads_match_independent_sequences( + *, + max_depth: int, + multi_target: bool, +) -> None: + input_ids = _input_ids() + target_ids = tuple( + _targets(tokens, multi_target=multi_target) for tokens in input_ids + ) + pack = pack_shared_prefixes(input_ids, max_depth=max_depth) + + assert int(pack.tokens.numel()) < sum(len(row) for row in input_ids) + + torch.manual_seed(20260518) + naive_model = _ToyCausalLM() + packed_model = deepcopy(naive_model) + + naive_loss = torch.stack( + [ + _sequence_ce_loss(naive_model, tokens, labels) + for tokens, labels in zip(input_ids, target_ids, strict=True) + ] + ).sum() + packed_loss = _packed_ce_loss(packed_model, pack, target_ids) + + torch.testing.assert_close(packed_loss, naive_loss, rtol=1e-12, atol=1e-12) + naive_loss.backward() + packed_loss.backward() + + for (name, naive_param), packed_param in zip( + naive_model.named_parameters(), + packed_model.parameters(), + strict=True, + ): + assert naive_param.grad is not None, name + assert packed_param.grad is not None, name + torch.testing.assert_close( + packed_param.grad, + naive_param.grad, + rtol=1e-10, + atol=1e-10, + msg=lambda msg, name=name: f"{name} grad mismatch:\n{msg}", + ) + + +@pytest.mark.parametrize("max_depth", (1, 2, 3)) +def test_same_layout_mutation_preserves_forward_outputs(max_depth: int) -> None: + pack = pack_shared_prefixes(_input_ids(), max_depth=max_depth) + torch.manual_seed(20260518) + model = _ToyCausalLM() + logits = _packed_logits(model, pack) + + for positions in pack.positions_by_sequence: + mutated_logits = _packed_logits(model, _mutated_pack(pack, keep=positions)) + torch.testing.assert_close( + mutated_logits.index_select(0, positions), + logits.index_select(0, positions), + rtol=0.0, + atol=0.0, + ) + + +@pytest.mark.parametrize("max_depth", (1, 2, 3)) +@pytest.mark.parametrize("sequence_index", (0, 2, 4)) +def test_same_layout_mutation_preserves_target_loss_grads( + max_depth: int, + sequence_index: int, +) -> None: + input_ids = _input_ids() + target_ids = tuple(_targets(tokens, multi_target=True) for tokens in input_ids) + pack = pack_shared_prefixes(input_ids, max_depth=max_depth) + mutated = _mutated_pack(pack, keep=pack.positions_by_sequence[sequence_index]) + + torch.manual_seed(20260518) + base_model = _ToyCausalLM() + mutated_model = deepcopy(base_model) + + base_loss = _packed_sequence_ce_loss(base_model, pack, target_ids, sequence_index) + mutated_loss = _packed_sequence_ce_loss( + mutated_model, + mutated, + target_ids, + sequence_index, + ) + + torch.testing.assert_close(mutated_loss, base_loss, rtol=0.0, atol=0.0) + base_loss.backward() + mutated_loss.backward() + _assert_matching_grads(mutated_model, base_model) + + +def _input_ids() -> tuple[torch.Tensor, ...]: + return ( + torch.tensor([1, 2, 3, 4, 5]), + torch.tensor([1, 2, 3, 4, 6]), + torch.tensor([1, 2, 3, 7]), + torch.tensor([1, 2, 8]), + torch.tensor([9, 10, 11]), + ) + + +def _targets(tokens: torch.Tensor, *, multi_target: bool) -> torch.Tensor: + labels = (tokens * 3 + 5) % 31 + if not multi_target: + return labels + alternate = (tokens * 5 + 7) % 31 + stacked = torch.stack((labels, alternate), dim=1) + if int(stacked.numel()) > 2: + stacked[1, 1] = -100 + return stacked + + +def _sequence_ce_loss( + model: _ToyCausalLM, + input_ids: torch.Tensor, + target_ids: torch.Tensor, +) -> torch.Tensor: + seq_len = int(input_ids.numel()) + logits = model( + input_ids, + torch.arange(seq_len), + torch.ones((seq_len, seq_len), dtype=torch.bool).tril(), + ) + return _target_ce_loss(logits, target_ids) + + +def _packed_ce_loss( + model: _ToyCausalLM, + pack: SharedPrefixPack, + target_ids: tuple[torch.Tensor, ...], +) -> torch.Tensor: + logits = _packed_logits(model, pack) + losses = [ + _target_ce_loss(logits.index_select(0, positions), labels) + for positions, labels in zip( + pack.positions_by_sequence, + target_ids, + strict=True, + ) + ] + return torch.stack(losses).sum() + + +def _packed_sequence_ce_loss( + model: _ToyCausalLM, + pack: SharedPrefixPack, + target_ids: tuple[torch.Tensor, ...], + sequence_index: int, +) -> torch.Tensor: + return _target_ce_loss( + _packed_logits(model, pack).index_select( + 0, + pack.positions_by_sequence[sequence_index], + ), + target_ids[sequence_index], + ) + + +def _packed_logits(model: _ToyCausalLM, pack: SharedPrefixPack) -> torch.Tensor: + return model( + pack.tokens.reshape(-1), + pack.position_ids.reshape(-1), + _shared_prefix_causal_mask(pack), + ) + + +def _target_ce_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + if labels.ndim == 1: + return F.cross_entropy(logits, labels, ignore_index=-100, reduction="sum") + expanded = logits.unsqueeze(1).expand(-1, int(labels.shape[1]), -1) + return F.cross_entropy( + expanded.reshape(-1, int(logits.shape[-1])), + labels.reshape(-1), + ignore_index=-100, + reduction="sum", + ) + + +def _mutated_pack(pack: SharedPrefixPack, *, keep: torch.Tensor) -> SharedPrefixPack: + tokens = pack.tokens.clone() + mutate = torch.ones(int(tokens.shape[1]), dtype=torch.bool) + mutate[keep] = False + replacement = torch.arange(int(tokens.shape[1]), dtype=tokens.dtype) + 17 + tokens[0, mutate] = replacement[mutate] % 31 + return SharedPrefixPack( + tokens=tokens, + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + position_ids=pack.position_ids, + positions_by_sequence=pack.positions_by_sequence, + ) + + +def _assert_matching_grads(actual_model: nn.Module, expected_model: nn.Module) -> None: + for (name, expected_param), actual_param in zip( + expected_model.named_parameters(), + actual_model.parameters(), + strict=True, + ): + assert expected_param.grad is not None, name + assert actual_param.grad is not None, name + torch.testing.assert_close( + actual_param.grad, + expected_param.grad, + rtol=1e-10, + atol=1e-10, + msg=lambda msg, name=name: f"{name} grad mismatch:\n{msg}", + ) + + +def _shared_prefix_causal_mask(pack: SharedPrefixPack) -> torch.Tensor: + group_ids = pack.group_ids.reshape(-1).tolist() + parent_ids = pack.parent_ids.reshape(-1).tolist() + position_ids = pack.position_ids.reshape(-1).tolist() + parent_by_group: dict[int, int] = {} + for group_id, parent_id in zip(group_ids, parent_ids, strict=True): + previous = parent_by_group.setdefault(group_id, parent_id) + assert previous == parent_id + + ancestors = { + group_id: _ancestor_groups(group_id, parent_by_group) + for group_id in parent_by_group + } + mask = torch.zeros((len(group_ids), len(group_ids)), dtype=torch.bool) + for query_index, query_group in enumerate(group_ids): + query_ancestors = ancestors[query_group] + query_position = position_ids[query_index] + for key_index, key_group in enumerate(group_ids): + if ( + key_group in query_ancestors + and position_ids[key_index] <= query_position + ): + mask[query_index, key_index] = True + return mask + + +def _ancestor_groups(group_id: int, parent_by_group: dict[int, int]) -> set[int]: + ancestors = {group_id} + parent_id = parent_by_group[group_id] + while parent_id != group_id: + if parent_id in ancestors: + raise AssertionError("shared-prefix group parents contain a cycle") + ancestors.add(parent_id) + group_id = parent_id + parent_id = parent_by_group[group_id] + return ancestors diff --git a/tests/unit/test_shared_prefix_packing.py b/tests/unit/test_shared_prefix_packing.py new file mode 100644 index 000000000..c32fb21a9 --- /dev/null +++ b/tests/unit/test_shared_prefix_packing.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import pytest +import torch + +from art.megatron.shared_prefix_packing import ( + _local_position_pairs, + estimate_shared_prefix_packed_tokens, + pack_shared_prefixes, +) + + +def test_pack_shared_prefixes_support_depth_one() -> None: + inputs = ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 5]), + torch.tensor([9]), + ) + + pack = pack_shared_prefixes(inputs, max_depth=1) + + assert pack.tokens.tolist() == [[1, 2, 3, 4, 5, 9]] + assert pack.group_ids.tolist() == [[1, 1, 2, 2, 3, 4]] + assert pack.parent_ids.tolist() == [[1, 1, 1, 1, 1, 4]] + assert pack.position_ids.tolist() == [[0, 1, 2, 3, 2, 0]] + assert [positions.tolist() for positions in pack.positions_by_sequence] == [ + [0, 1, 2, 3], + [0, 1, 4], + [5], + ] + + +def test_pack_shared_prefixes_support_zero_depth_without_sharing() -> None: + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2]), + torch.tensor([1, 3]), + torch.tensor([4]), + ), + max_depth=0, + ) + + assert pack.tokens.tolist() == [[1, 2, 1, 3, 4]] + assert pack.group_ids.tolist() == [[1, 1, 2, 2, 3]] + assert pack.parent_ids.tolist() == [[1, 1, 2, 2, 3]] + assert pack.position_ids.tolist() == [[0, 1, 0, 1, 0]] + assert [positions.tolist() for positions in pack.positions_by_sequence] == [ + [0, 1], + [2, 3], + [4], + ] + + +def test_pack_shared_prefixes_support_deeper_trees() -> None: + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6, 7]), + ), + max_depth=2, + ) + + assert pack.tokens.tolist() == [[1, 2, 3, 4, 5, 6, 7]] + assert pack.group_ids.tolist() == [[1, 2, 2, 3, 4, 5, 5]] + assert pack.parent_ids.tolist() == [[1, 1, 1, 2, 2, 1, 1]] + assert pack.position_ids.tolist() == [[0, 1, 2, 3, 3, 1, 2]] + assert [positions.tolist() for positions in pack.positions_by_sequence] == [ + [0, 1, 2, 3], + [0, 1, 2, 4], + [0, 5, 6], + ] + + +def test_packing_preserves_first_seen_branch_order() -> None: + pack = pack_shared_prefixes( + (torch.tensor([9]), torch.tensor([1])), + max_depth=1, + ) + + assert pack.tokens.tolist() == [[9, 1]] + assert [positions.tolist() for positions in pack.positions_by_sequence] == [ + [0], + [1], + ] + + +def test_packing_handles_empty_sequences() -> None: + pack = pack_shared_prefixes( + (torch.empty(0, dtype=torch.long), torch.empty(0, dtype=torch.long)), + max_depth=1, + ) + + assert pack.tokens.tolist() == [[]] + assert pack.group_ids.tolist() == [[]] + assert pack.parent_ids.tolist() == [[]] + assert [positions.tolist() for positions in pack.positions_by_sequence] == [[], []] + + +def test_packed_token_estimator_matches_real_packing() -> None: + cases = [ + (torch.tensor([1, 2, 3]), torch.tensor([1, 2, 4]), torch.tensor([5])), + ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 2, 6, 7]), + torch.tensor([1, 8]), + ), + ( + torch.tensor([9, 1, 2]), + torch.tensor([9, 1, 3]), + torch.tensor([9, 4, 5]), + torch.tensor([6, 7]), + torch.tensor([], dtype=torch.long), + ), + ] + + for inputs in cases: + for depth in range(5): + pack = pack_shared_prefixes(inputs, max_depth=depth) + + assert estimate_shared_prefix_packed_tokens(inputs, max_depth=depth) == int( + pack.tokens.numel() + ) + + +def test_packed_token_estimator_matches_randomized_packing() -> None: + generator = torch.Generator().manual_seed(123) + inputs = [] + for family in range(5): + prefix = torch.randint(1, 100, (4,), generator=generator) + for branch in range(4): + middle = torch.tensor([family, branch]) + suffix = torch.randint(1, 100, (3,), generator=generator) + inputs.append(torch.cat((prefix, middle, suffix))) + + for depth in range(5): + pack = pack_shared_prefixes(inputs, max_depth=depth) + + assert estimate_shared_prefix_packed_tokens(inputs, max_depth=depth) == int( + pack.tokens.numel() + ) + + +def test_packing_rejects_non_1d_sequences() -> None: + with pytest.raises(ValueError, match="expects 1-D tensors"): + pack_shared_prefixes((torch.tensor([[1, 2], [3, 4]]),), max_depth=1) + + +def test_local_position_pairs_preserve_requested_order_without_dense_match() -> None: + local_global_positions = torch.tensor([[2, -1, 0, 4, 1]]) + item_positions = torch.tensor([0, 1, 2, 3, 4]) + + local_positions, source_positions = _local_position_pairs( + local_global_positions, + item_positions, + ) + + assert local_positions.tolist() == [2, 4, 0, 3] + assert source_positions.tolist() == [0, 1, 2, 4] diff --git a/tests/unit/test_shared_prefix_tree.py b/tests/unit/test_shared_prefix_tree.py new file mode 100644 index 000000000..ce95c4fe1 --- /dev/null +++ b/tests/unit/test_shared_prefix_tree.py @@ -0,0 +1,502 @@ +from __future__ import annotations + +import pytest +import torch + +from art.megatron.shared_prefix_packing import pack_shared_prefixes +from art.megatron.shared_prefix_tree import parse_shared_prefix_row + + +def test_parse_shared_prefix_row_tracks_ancestors_and_depth() -> None: + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4, 8]), + torch.tensor([1, 2, 3, 4, 9]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6]), + ), + max_depth=3, + ) + + tree = parse_shared_prefix_row( + group_ids=pack.group_ids[0], + parent_ids=pack.parent_ids[0], + ) + + assert tree.valid_tokens == int(pack.tokens.numel()) + assert max(segment.depth for segment in tree.segments) == 3 + assert [(segment.group_id, segment.ancestors) for segment in tree.segments] == [ + (1, ()), + (2, (1,)), + (3, (1, 2)), + (4, (1, 2, 3)), + (5, (1, 2, 3)), + (6, (1, 2)), + (7, (1,)), + ] + + +def test_parse_shared_prefix_row_rejects_missing_parent() -> None: + with pytest.raises(RuntimeError, match="missing parent"): + parse_shared_prefix_row( + group_ids=torch.tensor([1, 2]), + parent_ids=torch.tensor([1, 3]), + ) + + +def test_parse_shared_prefix_row_rejects_non_contiguous_group() -> None: + with pytest.raises(RuntimeError, match="contiguous group runs"): + parse_shared_prefix_row( + group_ids=torch.tensor([1, 2, 1]), + parent_ids=torch.tensor([1, 1, 1]), + ) + + +def test_gdn_tree_parser_accepts_nested_tree() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + GdnPlannerConfig, + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6]), + ), + max_depth=2, + ) + + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plan = build_gdn_rank_execution_plan(spec, device="cpu") + + assert spec.tree_parent_indices == (-1, 0, 1, 1, 0) + assert spec.tree_depths == (0, 1, 2, 2, 1) + assert [ + sum(bucket.segment_count for bucket in buckets) + for buckets in plan.tree_segment_buckets_by_depth + ] == [1, 2, 2] + + +def test_gdn_tree_parser_accepts_zero_depth_roots() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2]), + torch.tensor([1, 3]), + torch.tensor([4]), + ), + max_depth=0, + ) + + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plan = build_gdn_rank_execution_plan(spec, device="cpu") + + assert spec.tree_parent_indices == (-1, -1, -1) + assert spec.tree_depths == (0, 0, 0) + assert [bucket.segment_count for bucket in plan.tree_segment_buckets_by_depth[0]] + assert not hasattr(plan, "local_prefix_buckets") + assert not hasattr(plan, "chain_completion_buckets") + assert not hasattr(plan, "prefix_boundary_buckets") + assert all( + not bucket.needs_final_state for bucket in plan.tree_segment_buckets_by_depth[0] + ) + + +def test_gdn_tree_planner_splits_leaf_and_internal_final_state_buckets() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + GdnPlannerConfig, + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4, 7]), + torch.tensor([1, 2, 3, 4, 8]), + torch.tensor([1, 2, 5, 6]), + ), + max_depth=2, + ) + + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plan = build_gdn_rank_execution_plan( + spec, + device="cpu", + planner_config=GdnPlannerConfig(max_padding_ratio=4.0), + ) + tree_has_children = _tree_has_children(spec) + + depth_one_buckets = plan.tree_segment_buckets_by_depth[1] + assert any(bucket.needs_final_state for bucket in depth_one_buckets) + assert any(not bucket.needs_final_state for bucket in depth_one_buckets) + for bucket in depth_one_buckets: + expected = { + tree_has_children[family_index] + for family_index in bucket.family_indices.tolist() + } + assert expected == {bucket.needs_final_state} + + +def test_gdn_tree_cp_plan_chains_long_nodes() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + GdnPlannerConfig, + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + root = torch.arange(1, 321) + mid = torch.arange(1001, 1321) + other = torch.arange(2001, 2321) + pack = pack_shared_prefixes( + ( + torch.cat((root, mid, torch.tensor([11]))), + torch.cat((root, mid, torch.tensor([12]))), + torch.cat((root, other, torch.tensor([13]))), + ), + max_depth=3, + ) + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + config = _chain_every_legal_segment_config() + plans = tuple( + build_gdn_rank_execution_plan( + spec, + device="cpu", + cp_rank=rank, + cp_size=4, + planner_config=config, + ) + for rank in range(4) + ) + + assert _covered_token_indices(plans) == set(range(spec.real_token_count)) + assert any(plans[0].tree_chain_buckets_by_depth[0]) + assert not any( + bucket + for plan in plans + for depth_buckets in plan.tree_chain_buckets_by_depth[1:] + for bucket in depth_buckets + ) + _assert_remote_parent_state_transfers_cover(spec, plans) + for plan in plans: + assert sum(plan.gdn_token_count for plan in plans) == spec.real_token_count + for depth_buckets in plan.tree_chain_buckets_by_depth: + for bucket in depth_buckets: + assert bucket.lengths_by_rank_cpu is not None + assert tuple(bucket.lengths_by_rank_cpu.shape)[0] == 4 + assert bucket.parent_indices is not None + + +def test_gdn_tree_cp_plan_exchanges_remote_parent_states() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + root = torch.arange(1, 17) + mid = torch.arange(1001, 1321) + pack = pack_shared_prefixes( + ( + torch.cat((root, mid, torch.tensor([11]))), + torch.cat((root, mid, torch.tensor([12]))), + torch.cat((root, torch.tensor([99]))), + ), + max_depth=2, + ) + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plans = tuple( + build_gdn_rank_execution_plan( + spec, + device="cpu", + cp_rank=rank, + cp_size=4, + planner_config=_chain_every_legal_segment_config(), + ) + for rank in range(4) + ) + assert _covered_token_indices(plans) == set(range(spec.real_token_count)) + assert not any( + bucket + for plan in plans + for depth_buckets in plan.tree_chain_buckets_by_depth[1:] + for bucket in depth_buckets + ) + assert _remote_parent_state_transfer_count(plans) > 0 + _assert_remote_parent_state_transfers_cover(spec, plans) + + +def test_gdn_tree_cp_randomized_plans_cover_each_token_once() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + config = _chain_every_legal_segment_config() + for seed in range(8): + pack = pack_shared_prefixes( + _random_tree_sequences(seed), + max_depth=4, + ) + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plans = tuple( + build_gdn_rank_execution_plan( + spec, + device="cpu", + cp_rank=rank, + cp_size=4, + planner_config=config, + ) + for rank in range(4) + ) + + assert _covered_token_indices(plans) == set(range(spec.real_token_count)) + assert sum(plan.gdn_token_count for plan in plans) == spec.real_token_count + for plan in plans: + for depth_buckets in ( + *plan.tree_segment_buckets_by_depth, + *plan.tree_chain_buckets_by_depth, + ): + for bucket in depth_buckets: + assert bucket.parent_indices is not None + assert int(bucket.real_token_count) > 0 + + +def test_gdn_tree_cp_randomized_plans_pass_health_checks() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + GdnPlannerConfig, + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + config = GdnPlannerConfig( + cp_chain_min_tokens_per_rank=1, + cp_chain_min_total_tokens=64, + cp_chain_min_prefix_only_tokens=64, + cp_tree_chain_min_total_tokens=64, + cp_tree_chain_min_prefix_only_tokens=64, + max_padding_ratio=4.0, + ) + for seed in range(16): + pack = pack_shared_prefixes( + _random_tree_sequences(seed + 100, max_depth=5), + max_depth=5, + ) + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plans = tuple( + build_gdn_rank_execution_plan( + spec, + device="cpu", + cp_rank=rank, + cp_size=4, + planner_config=config, + ) + for rank in range(4) + ) + + _assert_tree_plan_health( + spec, plans, max_padding_ratio=config.max_padding_ratio + ) + + +def _chain_every_legal_segment_config(): + from art.megatron.gdn.gdn_shared_prefix import GdnPlannerConfig + + return GdnPlannerConfig( + cp_chain_min_tokens_per_rank=1, + cp_chain_min_total_tokens=1, + cp_chain_min_prefix_only_tokens=1, + max_padding_ratio=4.0, + ) + + +def _covered_token_indices(plans) -> set[int]: + return { + token + for plan in plans + for start, end, _position in plan.gdn_token_ranges + for token in range(start, end) + } + + +def _local_owner_by_family(plans) -> dict[int, int]: + owner_by_family = {} + for rank, plan in enumerate(plans): + for depth_buckets in plan.tree_segment_buckets_by_depth: + for bucket in depth_buckets: + for family_index in bucket.family_indices.tolist(): + previous = owner_by_family.setdefault(int(family_index), rank) + assert previous == rank + return owner_by_family + + +def _assert_remote_parent_state_transfers_cover(spec, plans) -> None: + owner_by_family = _local_owner_by_family(plans) + for family_index, parent_index in enumerate(spec.tree_parent_indices): + if parent_index < 0 or parent_index not in owner_by_family: + continue + source_rank = owner_by_family[parent_index] + dest_rank = owner_by_family[family_index] + if source_rank == dest_rank: + continue + depth = spec.tree_depths[family_index] + source_exchange = plans[source_rank].tree_state_exchanges_by_depth[depth] + dest_exchange = plans[dest_rank].tree_state_exchanges_by_depth[depth] + assert source_exchange is not None + assert dest_exchange is not None + assert parent_index in source_exchange.source_family_indices + assert parent_index in dest_exchange.dest_family_indices + matching = [ + transfer + for transfer in dest_exchange.exchange.transfers + if transfer.source_rank == source_rank and transfer.dest_rank == dest_rank + ] + assert matching + + +def _remote_parent_state_transfer_count(plans) -> int: + return sum( + exchange.exchange.cross_rank_token_count + for plan in plans + for exchange in plan.tree_state_exchanges_by_depth + if exchange is not None + ) // len(plans) + + +def _tree_has_children(spec) -> list[bool]: + has_children = [False] * spec.family_count + for parent_index in spec.tree_parent_indices: + if parent_index >= 0: + has_children[parent_index] = True + return has_children + + +def _assert_tree_plan_health(spec, plans, *, max_padding_ratio: float) -> None: + tree_has_children = _tree_has_children(spec) + token_counts = [0] * int(spec.real_token_count) + for plan in plans: + range_tokens = sum( + end - start for start, end, _position in plan.gdn_token_ranges + ) + assert range_tokens == int(plan.gdn_token_count) + assert len(plan.attention_token_indices) == int(plan.attention_token_count) + + bucket_tokens = 0 + for depth_buckets in plan.tree_segment_buckets_by_depth: + for bucket in depth_buckets: + bucket_tokens += int(bucket.real_token_count) + assert bucket.parent_indices is not None + assert int(bucket.parent_indices.numel()) == int(bucket.segment_count) + assert int(bucket.real_token_count) > 0 + padding_ratio = ( + bucket.length * bucket.segment_count / bucket.real_token_count + ) + assert padding_ratio <= max_padding_ratio + bucket_state_flags = { + tree_has_children[family_index] + for family_index in bucket.family_indices.tolist() + } + assert bucket_state_flags == {bucket.needs_final_state} + for family_index, parent_index in zip( + bucket.family_indices.tolist(), + bucket.parent_indices.tolist(), + strict=True, + ): + assert spec.tree_parent_indices[family_index] == parent_index + + for depth_buckets in plan.tree_chain_buckets_by_depth: + for bucket in depth_buckets: + bucket_tokens += int(bucket.real_token_count) + assert bucket.parent_indices is not None + assert int(bucket.parent_indices.numel()) == int(bucket.segment_count) + assert int(bucket.real_token_count) > 0 + padding_ratio = ( + bucket.length * bucket.segment_count / bucket.real_token_count + ) + assert padding_ratio <= max_padding_ratio + bucket_state_flags = { + tree_has_children[family_index] + for family_index in bucket.family_indices.tolist() + } + if bucket.needs_final_state: + assert any(bucket_state_flags) + else: + assert bucket_state_flags == {False} + for family_index, parent_index in zip( + bucket.family_indices.tolist(), + bucket.parent_indices.tolist(), + strict=True, + ): + assert spec.tree_parent_indices[family_index] == parent_index + assert bucket_tokens == int(plan.gdn_token_count) + + for start, end, _position in plan.gdn_token_ranges: + for token_index in range(start, end): + token_counts[token_index] += 1 + + _assert_remote_parent_state_transfers_cover(spec, plans) + assert token_counts == [1] * int(spec.real_token_count) + rank_tokens = [int(plan.gdn_token_count) for plan in plans] + assert max(rank_tokens) - min(rank_tokens) <= max(256, spec.real_token_count // 3) + + +def _random_tree_sequences( + seed: int, *, max_depth: int = 4 +) -> tuple[torch.Tensor, ...]: + generator = torch.Generator().manual_seed(seed) + next_token = 1 + + def tokens(length: int) -> torch.Tensor: + nonlocal next_token + out = torch.arange(next_token, next_token + length) + next_token += length + return out + + def randint(low: int, high: int) -> int: + return int(torch.randint(low, high + 1, (), generator=generator).item()) + + def walk(prefix: torch.Tensor, depth: int) -> list[torch.Tensor]: + segment_length = [1, 3, 17, 64, 129, 257][randint(0, 5)] + here = torch.cat((prefix, tokens(segment_length))) + if depth + 1 >= max_depth: + return [ + torch.cat((here, tokens(randint(1, 9)))) for _ in range(randint(2, 4)) + ] + leaves: list[torch.Tensor] = [] + for _ in range(randint(2, 3)): + leaves.extend(walk(here, depth + 1)) + return leaves + + return tuple(walk(torch.empty(0, dtype=torch.long), 0)) diff --git a/typings/wandb/__init__.pyi b/typings/wandb/__init__.pyi new file mode 100644 index 000000000..09d1c8d16 --- /dev/null +++ b/typings/wandb/__init__.pyi @@ -0,0 +1,38 @@ +from typing import Any + +class Settings: + def __init__(self, **kwargs: Any) -> None: ... + +class Artifact: + aliases: list[str] + metadata: dict[str, Any] + def __init__(self, name: str, type: str, **kwargs: Any) -> None: ... + def add_dir(self, local_path: str, **kwargs: Any) -> None: ... + def add_file(self, local_path: str, **kwargs: Any) -> None: ... + def download(self, **kwargs: Any) -> str: ... + def save(self) -> None: ... + def wait(self) -> Artifact: ... + +class Run: + entity: str + project: str + name: str + config: Any + _is_finished: bool + def finish(self, *args: Any, **kwargs: Any) -> None: ... + def define_metric(self, *args: Any, **kwargs: Any) -> None: ... + def log(self, *args: Any, **kwargs: Any) -> None: ... + def log_artifact(self, *args: Any, **kwargs: Any) -> Artifact: ... + +class Api: + default_entity: str + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def artifact(self, *args: Any, **kwargs: Any) -> Artifact: ... + def artifacts(self, *args: Any, **kwargs: Any) -> list[Artifact]: ... + def run(self, *args: Any, **kwargs: Any) -> Run: ... + +def init(*args: Any, **kwargs: Any) -> Run: ... +def login(*args: Any, **kwargs: Any) -> Any: ... + +class errors: + class CommError(Exception): ... diff --git a/typings/wandb/sdk/__init__.pyi b/typings/wandb/sdk/__init__.pyi new file mode 100644 index 000000000..1ce9ecf99 --- /dev/null +++ b/typings/wandb/sdk/__init__.pyi @@ -0,0 +1,5 @@ +from typing import Any + +__all__: list[str] + +def __getattr__(name: str) -> Any: ... diff --git a/typings/wandb/sdk/wandb_run.pyi b/typings/wandb/sdk/wandb_run.pyi new file mode 100644 index 000000000..416ae101e --- /dev/null +++ b/typings/wandb/sdk/wandb_run.pyi @@ -0,0 +1,3 @@ +from wandb import Run + +__all__ = ["Run"] From 4afa197ab667da057ca11127dcfdd6de29705616 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 25 Jun 2026 18:57:49 -0600 Subject: [PATCH 02/33] fix sparse block mask metadata parity --- .../megatron/context_parallel/block_mask.py | 151 +++++++++++++++++- .../test_shared_prefix_attention_builder.py | 56 ++++++- 2 files changed, 200 insertions(+), 7 deletions(-) diff --git a/src/art/megatron/context_parallel/block_mask.py b/src/art/megatron/context_parallel/block_mask.py index 60508c48f..a4b0bb15f 100644 --- a/src/art/megatron/context_parallel/block_mask.py +++ b/src/art/megatron/context_parallel/block_mask.py @@ -306,6 +306,112 @@ def k_intervals(k_idx: int) -> tuple[tuple[int, int, int], ...]: full_blocks[q_idx, k_idx] = bool(is_full) +def _refine_sliding_interval_blocks( + *, + partial_blocks: np.ndarray, + full_blocks: np.ndarray, + q_abs: np.ndarray, + k_abs: np.ndarray, + q_enter: np.ndarray, + k_enter: np.ndarray, + k_exit: np.ndarray, + q_pos: np.ndarray, + k_pos: np.ndarray, + q_block: int, + k_block: int, + sliding_window: int, +) -> None: + candidates = partial_blocks | full_blocks + if not bool(candidates.any()): + return + + q_abs_blocks = _block_matrix( + q_abs, + block_size=q_block, + block_count=int(partial_blocks.shape[0]), + fill_value=_INVALID_ABS, + ) + q_enter_blocks = _block_matrix( + q_enter, + block_size=q_block, + block_count=int(partial_blocks.shape[0]), + fill_value=_INVALID_ENTER, + ) + q_pos_blocks = _block_matrix( + q_pos, + block_size=q_block, + block_count=int(partial_blocks.shape[0]), + fill_value=_INVALID_POS, + ) + k_abs_blocks = _block_matrix( + k_abs, + block_size=k_block, + block_count=int(partial_blocks.shape[1]), + fill_value=_INVALID_ABS, + ) + k_enter_blocks = _block_matrix( + k_enter, + block_size=k_block, + block_count=int(partial_blocks.shape[1]), + fill_value=_INVALID_ENTER, + ) + k_exit_blocks = _block_matrix( + k_exit, + block_size=k_block, + block_count=int(partial_blocks.shape[1]), + fill_value=_INVALID_EXIT, + ) + k_pos_blocks = _block_matrix( + k_pos, + block_size=k_block, + block_count=int(partial_blocks.shape[1]), + fill_value=_INVALID_POS, + ) + + q_valid = ( + (q_abs_blocks >= 0) & (q_enter_blocks >= 0) & (q_pos_blocks != _INVALID_POS) + ) + k_valid = ( + (k_abs_blocks >= 0) + & (k_enter_blocks >= 0) + & (k_exit_blocks > k_enter_blocks) + & (k_pos_blocks != _INVALID_POS) + ) + + q_indices, k_indices = np.nonzero(candidates) + partial_blocks[q_indices, k_indices] = False + full_blocks[q_indices, k_indices] = False + for q_idx, k_idx in zip(q_indices, k_indices, strict=True): + q_valid_row = q_valid[q_idx] + k_valid_row = k_valid[k_idx] + if not bool(q_valid_row.any()) or not bool(k_valid_row.any()): + continue + + q_abs_row = q_abs_blocks[q_idx][:, None] + q_enter_row = q_enter_blocks[q_idx][:, None] + q_pos_row = q_pos_blocks[q_idx][:, None] + k_abs_row = k_abs_blocks[k_idx][None, :] + k_enter_row = k_enter_blocks[k_idx][None, :] + k_exit_row = k_exit_blocks[k_idx][None, :] + k_pos_row = k_pos_blocks[k_idx][None, :] + delta = q_pos_row - k_pos_row + allowed = ( + q_valid_row[:, None] + & k_valid_row[None, :] + & (q_abs_row >= k_abs_row) + & (k_enter_row <= q_enter_row) + & (q_enter_row < k_exit_row) + & (delta >= 0) + & (delta < int(sliding_window)) + ) + if not bool(allowed.any()): + continue + if bool(q_valid_row.all()) and bool(k_valid_row.all()) and bool(allowed.all()): + full_blocks[q_idx, k_idx] = True + else: + partial_blocks[q_idx, k_idx] = True + + def _is_strictly_increasing(values: np.ndarray) -> bool: return int(values.size) <= 1 or bool(np.all(values[1:] > values[:-1])) @@ -373,6 +479,9 @@ def visit(group_id: int) -> None: for segment in row_tree.segments: enter_by_token[segment.start : segment.end] = enter_by_group[segment.group_id] exit_by_token[segment.start : segment.end] = exit_by_group[segment.group_id] + if int(row_tree.valid_tokens) < int(length): + enter_by_token[int(row_tree.valid_tokens) :] = next_enter + exit_by_token[int(row_tree.valid_tokens) :] = next_enter + 1 return enter_by_token, exit_by_token @@ -497,6 +606,27 @@ def _build_sparse_block_mask( k_overlap_end == k_block_end_raw ) covers_block = q_is_full[:, None] & k_is_full[None, :] + if sliding_window is None: + window_has_any = None + window_is_full = None + else: + assert q_pos is not None and k_pos is not None + q_pos_min, q_pos_max = _block_min_max( + q_pos, + q_overlap_start, + q_overlap_end, + ) + k_pos_min, k_pos_max = _block_min_max( + k_pos, + k_overlap_start, + k_overlap_end, + ) + window_has_any = (q_pos_max[:, None] >= k_pos_min[None, :]) & ( + q_pos_min[:, None] - k_pos_max[None, :] < int(sliding_window) + ) + window_is_full = (q_pos_min[:, None] >= k_pos_max[None, :]) & ( + q_pos_max[:, None] - k_pos_min[None, :] < int(sliding_window) + ) if slice_.mask_kind == AttnMaskKind.FULL: has_any = np.ones( (int(q_block_indices.size), int(k_block_indices.size)), dtype=bool @@ -515,6 +645,10 @@ def _build_sparse_block_mask( ) has_any = q_max[:, None] >= k_min[None, :] is_full = covers_block & (q_min[:, None] >= k_max[None, :]) + if sliding_window is not None: + assert window_has_any is not None and window_is_full is not None + has_any &= window_has_any + is_full &= window_is_full q_slice = slice(int(q_block_indices[0]), int(q_block_indices[-1]) + 1) k_slice = slice(int(k_block_indices[0]), int(k_block_indices[-1]) + 1) @@ -541,8 +675,21 @@ def _build_sparse_block_mask( partial_blocks = (partial_blocks & ~needs_refine) | refined_partial full_blocks = (full_blocks & ~needs_refine) | refined_full if sliding_window is not None: - partial_blocks |= full_blocks - full_blocks = np.zeros_like(full_blocks) + assert q_pos is not None and k_pos is not None + _refine_sliding_interval_blocks( + partial_blocks=partial_blocks, + full_blocks=full_blocks, + q_abs=q_abs, + k_abs=k_abs, + q_enter=q_enter, + k_enter=k_enter, + k_exit=k_exit, + q_pos=q_pos, + k_pos=k_pos, + q_block=q_block, + k_block=k_block, + sliding_window=int(sliding_window), + ) kv_num_blocks, kv_indices = _dense_blocks_to_ordered( partial_blocks, device=device, diff --git a/tests/unit/test_shared_prefix_attention_builder.py b/tests/unit/test_shared_prefix_attention_builder.py index e4cd05aa2..e44ab01e0 100644 --- a/tests/unit/test_shared_prefix_attention_builder.py +++ b/tests/unit/test_shared_prefix_attention_builder.py @@ -34,6 +34,8 @@ def build_block_mask( *, group_ids: torch.Tensor, parent_ids: torch.Tensor, + input_pos: torch.Tensor | None = None, + sliding_window: int | None = None, device: torch.device, ) -> BlockMask | None: return build_block_mask_from_context( @@ -41,7 +43,9 @@ def build_block_mask( context=prepare_block_mask_context( group_ids=group_ids, parent_ids=parent_ids, + input_pos=input_pos, ), + sliding_window=sliding_window, device=device, ) @@ -278,14 +282,56 @@ def test_shared_prefix_state_builds_batched_block_mask() -> None: assert int(state.block_mask.kv_num_blocks.shape[0]) == 2 for row_index, row_spec in enumerate(spec.rows): valid_tokens = int(row_spec.valid_tokens) - assert actual[ - row_index, - :valid_tokens, - :valid_tokens, - ].equal(build_dense_reference_mask(row_spec=row_spec)) + expected = torch.zeros((seq_len, seq_len), dtype=torch.bool) + expected[:valid_tokens, :valid_tokens] = build_dense_reference_mask( + row_spec=row_spec + ) + if valid_tokens < seq_len: + padding_len = seq_len - valid_tokens + expected[valid_tokens:, valid_tokens:] = torch.tril( + torch.ones((padding_len, padding_len), dtype=torch.bool) + ) + assert actual[row_index].equal(expected) _assert_matches_torch_block_mask(state.block_mask, batch_size=2) +def test_sparse_block_mask_matches_torch_for_sliding_window_metadata() -> None: + seq_len = 128 + group_ids = torch.ones(seq_len, dtype=torch.long) + parent_ids = torch.ones(seq_len, dtype=torch.long) + input_pos = torch.arange(seq_len, dtype=torch.long) + block_mask = build_block_mask( + FlexMaskSpec( + q_len=seq_len, + k_len=seq_len, + block_size=(8, 8), + slices=( + AttnSlice( + q_range=TokenRange(start=0, end=seq_len), + k_range=TokenRange(start=0, end=seq_len), + mask_kind=AttnMaskKind.CAUSAL, + row_index=0, + ), + ), + exact_mask=ExactMaskMetadata( + q_token_indices=torch.arange(seq_len, dtype=torch.long), + k_token_indices=torch.arange(seq_len, dtype=torch.long), + cache_key="sliding-window", + ), + ), + group_ids=group_ids, + parent_ids=parent_ids, + input_pos=input_pos, + sliding_window=8, + device=torch.device("cpu"), + ) + + assert block_mask is not None + _assert_matches_torch_block_mask(block_mask) + full_causal_blocks = (seq_len // 8) * (seq_len // 8 + 1) // 2 + assert int(block_mask.kv_num_blocks.sum().item()) < full_causal_blocks + + def test_context_parallel_stage_masks_match_dense_nested_tree() -> None: _assert_context_parallel_stage_masks_match_dense( pack_shared_prefixes( From 48bdd1519ef6abc51a9cf6ebdfbbec5b12be8322 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 25 Jun 2026 19:20:38 -0600 Subject: [PATCH 03/33] use native wandb sdk types --- pyproject.toml | 1 - src/art/model.py | 40 +++++++++-------- src/art/serverless/backend.py | 33 ++++++-------- .../log_constant_metrics_wandb.py | 5 +-- src/art/utils/deployment/wandb.py | 11 +++-- src/art/utils/record_provenance.py | 12 ++--- src/art/utils/wandb_sdk.py | 45 +++++++++++++++++++ tests/integration/test_provenance.py | 4 +- tests/integration/test_push_and_fork.py | 4 +- tests/unit/test_metric_routing.py | 43 +++++++++--------- typings/wandb/__init__.pyi | 38 ---------------- typings/wandb/sdk/__init__.pyi | 5 --- typings/wandb/sdk/wandb_run.pyi | 3 -- 13 files changed, 119 insertions(+), 125 deletions(-) create mode 100644 src/art/utils/wandb_sdk.py delete mode 100644 typings/wandb/__init__.pyi delete mode 100644 typings/wandb/sdk/__init__.pyi delete mode 100644 typings/wandb/sdk/wandb_run.pyi diff --git a/pyproject.toml b/pyproject.toml index f94de980b..6e6261e24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -249,7 +249,6 @@ requires-dist = [ [tool.ty.environment] python-version = "3.12" -extra-paths = ["typings"] [tool.ty.rules] # Ignore unused-ignore-comment warnings because they vary depending on whether diff --git a/src/art/model.py b/src/art/model.py index 597aa30a5..c92f14f36 100644 --- a/src/art/model.py +++ b/src/art/model.py @@ -29,6 +29,7 @@ from .preprocessing.vllm_tokens import attach_vllm_token_metadata_to_choice from .trajectories import Trajectory, TrajectoryGroup from .types import SFTMetricLoggingConfig, TrainSFTConfig +from .utils import wandb_sdk from .utils.trajectory_logging import write_trajectory_groups_parquet if TYPE_CHECKING: @@ -617,28 +618,31 @@ def _sync_wandb_config( def _get_wandb_run(self) -> Optional["Run"]: """Get or create the wandb run for this model.""" - import wandb - if "WANDB_API_KEY" not in os.environ: return None if self._wandb_run is None or self._wandb_run._is_finished: - run = wandb.init( - project=self.project, - name=self.name, - id=self.name, - config=self._wandb_config or None, - resume="allow", - reinit="create_new", - settings=wandb.Settings( - x_stats_open_metrics_endpoints={ - "vllm": "http://localhost:8000/metrics", - }, - x_stats_open_metrics_filters=( - "vllm.vllm:num_requests_waiting", - "vllm.vllm:num_requests_running", + try: + run = wandb_sdk.init( + project=self.project, + name=self.name, + id=self.name, + config=self._wandb_config or None, + resume="allow", + reinit="create_new", + settings=wandb_sdk.settings( + x_stats_open_metrics_endpoints={ + "vllm": "http://localhost:8000/metrics", + }, + x_stats_open_metrics_filters=( + "vllm.vllm:num_requests_waiting", + "vllm.vllm:num_requests_running", + ), ), - ), - ) + ) + except ModuleNotFoundError as e: + if e.name == "wandb" or (e.name or "").startswith("wandb."): + return None + raise self._wandb_run = run object.__setattr__( self, diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index f6a797a87..67659c90f 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -28,15 +28,16 @@ TrainConfig, TrainSFTConfig, ) +from ..utils import wandb_sdk from ..utils.record_provenance import record_provenance if TYPE_CHECKING: - import wandb + from wandb.sdk.artifacts.artifact import Artifact from ..model import Model, TrainableModel -def _extract_step_from_wandb_artifact(artifact: "wandb.Artifact") -> int | None: +def _extract_step_from_wandb_artifact(artifact: "Artifact") -> int | None: """Extract step number from a W&B artifact's aliases.""" for alias in artifact.aliases: if alias.startswith("step"): @@ -541,15 +542,13 @@ async def _train_sft( import tempfile import uuid - import wandb - from ..utils.sft import resolve_sft_batch_size assert model.id is not None, "Model ID is required" # Get the user's default entity from W&B if not set if model.entity is None: - api = wandb.Api(api_key=self._client.api_key) + api = wandb_sdk.api(api_key=self._client.api_key) model.entity = api.default_entity # Generate unique artifact name to avoid race conditions in distributed systems @@ -592,17 +591,17 @@ async def _train_sft( # Upload the file to W&B as a dataset artifact # Use the model's canonical run_id from database, or fall back to model name - run = wandb.init( + run = wandb_sdk.init( name=model.name, id=model.run_id or model.name, # Use stored run_id to match the canonical wandb run entity=model.entity, project=model.project, resume="allow", # Resume if this run already exists - settings=wandb.Settings(api_key=self._client.api_key), + settings=wandb_sdk.settings(api_key=self._client.api_key), ) try: - artifact = wandb.Artifact( + artifact = wandb_sdk.artifact( artifact_name, type="dataset", metadata={ @@ -735,12 +734,10 @@ async def _experimental_pull_model_checkpoint( import os import tempfile - import wandb - assert model.id is not None, "Model ID is required" # If entity is not set, use the user's default entity from W&B - api = wandb.Api(api_key=self._client.api_key) # ty:ignore[possibly-missing-attribute] + api = wandb_sdk.api(api_key=self._client.api_key) if model.entity is None: model.entity = api.default_entity if verbose: @@ -905,8 +902,6 @@ async def _experimental_fork_checkpoint( import os import tempfile - import wandb - from_project = from_project or model.project if from_s3_bucket is not None: @@ -962,7 +957,7 @@ async def _experimental_fork_checkpoint( selected_step = target_step else: # Pull from W&B artifacts - api = wandb.Api(api_key=self._client.api_key) # ty:ignore[possibly-missing-attribute] + api = wandb_sdk.api(api_key=self._client.api_key) from_entity = model.entity or api.default_entity # Iterate all artifact versions to find the best step. @@ -1012,17 +1007,17 @@ async def _experimental_fork_checkpoint( if verbose: print(f"Uploading forked checkpoint as W&B artifact for {model.name}...") - wandb.login(key=self._client.api_key) # ty:ignore[possibly-missing-attribute] - run = wandb.init( + wandb_sdk.login(key=self._client.api_key) + run = wandb_sdk.init( project=model.project, entity=model.entity, job_type="checkpoint-fork", name=f"fork-{from_model}-to-{model.name}", - settings=wandb.Settings(silent=True), + settings=wandb_sdk.settings(silent=True), ) assert run is not None - dest_artifact = wandb.Artifact(name=model.name, type="lora") + dest_artifact = wandb_sdk.artifact(name=model.name, type="lora") dest_artifact.add_dir(checkpoint_dir) aliases = ["latest"] if selected_step is not None: @@ -1031,7 +1026,7 @@ async def _experimental_fork_checkpoint( run.finish() # Copy provenance from the source model's W&B run to the destination model - api = wandb.Api(api_key=self._client.api_key) # ty:ignore[possibly-missing-attribute] + api = wandb_sdk.api(api_key=self._client.api_key) try: source_run = api.run(f"{model.entity}/{from_project}/{from_model}") source_provenance = source_run.config.get("wandb.provenance") diff --git a/src/art/utils/benchmarking/log_constant_metrics_wandb.py b/src/art/utils/benchmarking/log_constant_metrics_wandb.py index ada248105..6da9b07a8 100644 --- a/src/art/utils/benchmarking/log_constant_metrics_wandb.py +++ b/src/art/utils/benchmarking/log_constant_metrics_wandb.py @@ -1,8 +1,7 @@ """Utilities for logging constant baseline metrics to Weights & Biases.""" -import wandb - import art +from art.utils import wandb_sdk async def log_constant_metrics_wandb( @@ -30,7 +29,7 @@ async def log_constant_metrics_wandb( Example: `{"train": {"loss": 0.5}, "val": {"loss": 0.4, "accuracy": 0.8}}` """ - run = wandb.init( + run = wandb_sdk.init( project=model.project, name=logged_run_name if logged_run_name else model.name, reinit="create_new", diff --git a/src/art/utils/deployment/wandb.py b/src/art/utils/deployment/wandb.py index 9ddf778e8..8add1dd4f 100644 --- a/src/art/utils/deployment/wandb.py +++ b/src/art/utils/deployment/wandb.py @@ -5,6 +5,7 @@ from art.errors import UnsupportedBaseModelDeploymentError +from .. import wandb_sdk from .common import DeploymentConfig if TYPE_CHECKING: @@ -52,8 +53,6 @@ def deploy_wandb( Returns: The model name for inference: wandb-artifact:///{entity}/{project}/{name}:step{step} """ - import wandb - if model.base_model not in WANDB_SUPPORTED_BASE_MODELS: raise UnsupportedBaseModelDeploymentError( message=f"Base model {model.base_model} is not supported for serverless LoRA deployment by W&B. Supported models: {WANDB_SUPPORTED_BASE_MODELS}" @@ -64,23 +63,23 @@ def deploy_wandb( # Get the user's default entity from W&B if not set if model.entity is None: - api = wandb.Api() + api = wandb_sdk.api() model.entity = api.default_entity if verbose: print(f"Uploading checkpoint from {checkpoint_path} to W&B...") - run = wandb.init( + run = wandb_sdk.init( name=model.name + " (deployment)", entity=model.entity, project=model.project, - settings=wandb.Settings(api_key=os.environ["WANDB_API_KEY"]), + settings=wandb_sdk.settings(api_key=os.environ["WANDB_API_KEY"]), ) try: metadata: dict[str, object] = {"wandb.base_model": model.base_model} if config is not None: metadata["wandb.provenance"] = config.provenance - artifact = wandb.Artifact( + artifact = wandb_sdk.artifact( model.name, type="lora", metadata=metadata, diff --git a/src/art/utils/record_provenance.py b/src/art/utils/record_provenance.py index 84a8bce7f..9b202be35 100644 --- a/src/art/utils/record_provenance.py +++ b/src/art/utils/record_provenance.py @@ -3,18 +3,18 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - import wandb + from wandb.sdk.wandb_run import Run +from . import wandb_sdk -def record_provenance(run: wandb.Run, provenance: str) -> None: - """Record provenance on the latest artifact version's metadata.""" - import wandb as wandb_module - api = wandb_module.Api() +def record_provenance(run: Run, provenance: str) -> None: + """Record provenance on the latest artifact version's metadata.""" + api = wandb_sdk.api() artifact_path = f"{run.entity}/{run.project}/{run.name}:latest" try: artifact = api.artifact(artifact_path, type="lora") - except wandb_module.errors.CommError: + except wandb_sdk.comm_error_type(): return # No artifact exists yet existing = artifact.metadata.get("wandb.provenance") diff --git a/src/art/utils/wandb_sdk.py b/src/art/utils/wandb_sdk.py new file mode 100644 index 000000000..0c015c42a --- /dev/null +++ b/src/art/utils/wandb_sdk.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from wandb.apis.public import Api + from wandb.sdk.artifacts.artifact import Artifact + from wandb.sdk.wandb_run import Run + from wandb.sdk.wandb_settings import Settings + + +def api(*args: Any, **kwargs: Any) -> Api: + from wandb.apis.public import Api + + return Api(*args, **kwargs) + + +def artifact(*args: Any, **kwargs: Any) -> Artifact: + from wandb.sdk.artifacts.artifact import Artifact + + return Artifact(*args, **kwargs) + + +def init(*args: Any, **kwargs: Any) -> Run: + from wandb.sdk.wandb_init import init as wandb_init + + return wandb_init(*args, **kwargs) + + +def login(*args: Any, **kwargs: Any) -> Any: + from wandb.sdk.wandb_login import login as wandb_login + + return wandb_login(*args, **kwargs) + + +def settings(*args: Any, **kwargs: Any) -> Settings: + from wandb.sdk.wandb_settings import Settings + + return Settings(*args, **kwargs) + + +def comm_error_type() -> type[Exception]: + from wandb.errors import CommError + + return CommError diff --git a/tests/integration/test_provenance.py b/tests/integration/test_provenance.py index 187fcad88..459e2297d 100644 --- a/tests/integration/test_provenance.py +++ b/tests/integration/test_provenance.py @@ -4,10 +4,10 @@ from datetime import datetime from dotenv import load_dotenv -import wandb import art from art.serverless.backend import ServerlessBackend +from art.utils import wandb_sdk load_dotenv() @@ -41,7 +41,7 @@ def get_latest_artifact_provenance( entity: str, project: str, name: str ) -> list[str] | None: """Fetch provenance from the latest W&B artifact's metadata.""" - api = wandb.Api() + api = wandb_sdk.api() artifact = api.artifact(f"{entity}/{project}/{name}:latest", type="lora") return artifact.metadata.get("wandb.provenance") diff --git a/tests/integration/test_push_and_fork.py b/tests/integration/test_push_and_fork.py index 1f590f92d..401e281d8 100644 --- a/tests/integration/test_push_and_fork.py +++ b/tests/integration/test_push_and_fork.py @@ -154,9 +154,9 @@ async def test_fork_checkpoint_from_wandb(): # Verify the forked checkpoint matches model A's checkpoint. # Pull both via W&B directly (the fork uploaded the artifact # with a step{N} alias matching the source step). - import wandb + from art.utils import wandb_sdk - api = wandb.Api(api_key=backend._client.api_key) # ty:ignore[possibly-missing-attribute] + api = wandb_sdk.api(api_key=backend._client.api_key) with tempfile.TemporaryDirectory() as tmpdir: dir_a = os.path.join(tmpdir, "a") dir_b = os.path.join(tmpdir, "b") diff --git a/tests/unit/test_metric_routing.py b/tests/unit/test_metric_routing.py index 529cdf14a..6be608d4f 100644 --- a/tests/unit/test_metric_routing.py +++ b/tests/unit/test_metric_routing.py @@ -1,7 +1,6 @@ import json import os from pathlib import Path -import types from unittest.mock import MagicMock, patch import pytest @@ -46,13 +45,13 @@ def test_get_wandb_run_registers_taxonomy_sections(self, tmp_path: Path) -> None fake_run = MagicMock() fake_run._is_finished = False - fake_wandb = types.SimpleNamespace() - fake_wandb.init = MagicMock(return_value=fake_run) - fake_wandb.define_metric = MagicMock() - fake_wandb.Settings = lambda **kwargs: kwargs + fake_init = MagicMock(return_value=fake_run) with patch.dict(os.environ, {"WANDB_API_KEY": "test-key"}, clear=False): - with patch.dict("sys.modules", {"wandb": fake_wandb}): + with ( + patch("art.model.wandb_sdk.init", fake_init), + patch("art.model.wandb_sdk.settings", lambda **kwargs: kwargs), + ): model = Model( name="test-model", project="test-project", @@ -88,13 +87,13 @@ def test_log_metrics_defines_nested_cost_keys_with_training_step( fake_run._is_finished = False fake_run.config = MagicMock() - fake_wandb = types.SimpleNamespace() - fake_wandb.init = MagicMock(return_value=fake_run) - fake_wandb.define_metric = MagicMock() - fake_wandb.Settings = lambda **kwargs: kwargs + fake_init = MagicMock(return_value=fake_run) with patch.dict(os.environ, {"WANDB_API_KEY": "test-key"}, clear=False): - with patch.dict("sys.modules", {"wandb": fake_wandb}): + with ( + patch("art.model.wandb_sdk.init", fake_init), + patch("art.model.wandb_sdk.settings", lambda **kwargs: kwargs), + ): model = Model( name="test-model", project="test-project", @@ -134,10 +133,7 @@ def test_update_wandb_config_seeds_wandb_init(self, tmp_path: Path) -> None: fake_run._is_finished = False fake_run.config = MagicMock() - fake_wandb = types.SimpleNamespace() - fake_wandb.init = MagicMock(return_value=fake_run) - fake_wandb.define_metric = MagicMock() - fake_wandb.Settings = lambda **kwargs: kwargs + fake_init = MagicMock(return_value=fake_run) payload = { "experiment": {"learning_rate": 1e-5, "batch_size": 4}, @@ -145,7 +141,10 @@ def test_update_wandb_config_seeds_wandb_init(self, tmp_path: Path) -> None: } with patch.dict(os.environ, {"WANDB_API_KEY": "test-key"}, clear=False): - with patch.dict("sys.modules", {"wandb": fake_wandb}): + with ( + patch("art.model.wandb_sdk.init", fake_init), + patch("art.model.wandb_sdk.settings", lambda **kwargs: kwargs), + ): model = Model( name="test-model", project="test-project", @@ -155,7 +154,7 @@ def test_update_wandb_config_seeds_wandb_init(self, tmp_path: Path) -> None: run = model._get_wandb_run() assert run is fake_run - init_kwargs = fake_wandb.init.call_args.kwargs + init_kwargs = fake_init.call_args.kwargs assert init_kwargs["config"] == payload assert "allow_val_change" not in init_kwargs fake_run.config.update.assert_called_once_with(payload) @@ -165,13 +164,13 @@ def test_update_wandb_config_updates_active_run(self, tmp_path: Path) -> None: fake_run._is_finished = False fake_run.config = MagicMock() - fake_wandb = types.SimpleNamespace() - fake_wandb.init = MagicMock(return_value=fake_run) - fake_wandb.define_metric = MagicMock() - fake_wandb.Settings = lambda **kwargs: kwargs + fake_init = MagicMock(return_value=fake_run) with patch.dict(os.environ, {"WANDB_API_KEY": "test-key"}, clear=False): - with patch.dict("sys.modules", {"wandb": fake_wandb}): + with ( + patch("art.model.wandb_sdk.init", fake_init), + patch("art.model.wandb_sdk.settings", lambda **kwargs: kwargs), + ): model = Model( name="test-model", project="test-project", diff --git a/typings/wandb/__init__.pyi b/typings/wandb/__init__.pyi deleted file mode 100644 index 09d1c8d16..000000000 --- a/typings/wandb/__init__.pyi +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Any - -class Settings: - def __init__(self, **kwargs: Any) -> None: ... - -class Artifact: - aliases: list[str] - metadata: dict[str, Any] - def __init__(self, name: str, type: str, **kwargs: Any) -> None: ... - def add_dir(self, local_path: str, **kwargs: Any) -> None: ... - def add_file(self, local_path: str, **kwargs: Any) -> None: ... - def download(self, **kwargs: Any) -> str: ... - def save(self) -> None: ... - def wait(self) -> Artifact: ... - -class Run: - entity: str - project: str - name: str - config: Any - _is_finished: bool - def finish(self, *args: Any, **kwargs: Any) -> None: ... - def define_metric(self, *args: Any, **kwargs: Any) -> None: ... - def log(self, *args: Any, **kwargs: Any) -> None: ... - def log_artifact(self, *args: Any, **kwargs: Any) -> Artifact: ... - -class Api: - default_entity: str - def __init__(self, *args: Any, **kwargs: Any) -> None: ... - def artifact(self, *args: Any, **kwargs: Any) -> Artifact: ... - def artifacts(self, *args: Any, **kwargs: Any) -> list[Artifact]: ... - def run(self, *args: Any, **kwargs: Any) -> Run: ... - -def init(*args: Any, **kwargs: Any) -> Run: ... -def login(*args: Any, **kwargs: Any) -> Any: ... - -class errors: - class CommError(Exception): ... diff --git a/typings/wandb/sdk/__init__.pyi b/typings/wandb/sdk/__init__.pyi deleted file mode 100644 index 1ce9ecf99..000000000 --- a/typings/wandb/sdk/__init__.pyi +++ /dev/null @@ -1,5 +0,0 @@ -from typing import Any - -__all__: list[str] - -def __getattr__(name: str) -> Any: ... diff --git a/typings/wandb/sdk/wandb_run.pyi b/typings/wandb/sdk/wandb_run.pyi deleted file mode 100644 index 416ae101e..000000000 --- a/typings/wandb/sdk/wandb_run.pyi +++ /dev/null @@ -1,3 +0,0 @@ -from wandb import Run - -__all__ = ["Run"] From 46a9d61047a6488715ce9aed5478242bfd79c455 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Thu, 25 Jun 2026 19:39:41 -0600 Subject: [PATCH 04/33] fix wandb boundary test mock --- .../unit/test_serverless_pipeline_trainer_compat.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/unit/test_serverless_pipeline_trainer_compat.py b/tests/unit/test_serverless_pipeline_trainer_compat.py index 25dd1a415..66a63f848 100644 --- a/tests/unit/test_serverless_pipeline_trainer_compat.py +++ b/tests/unit/test_serverless_pipeline_trainer_compat.py @@ -1,4 +1,3 @@ -import sys from types import SimpleNamespace from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -250,12 +249,6 @@ def log_artifact(self, artifact): def finish(self) -> None: pass - fake_wandb = SimpleNamespace( - Artifact=FakeArtifact, - init=MagicMock(return_value=FakeRun()), - Settings=lambda **kwargs: kwargs, - ) - trajectory = Trajectory( messages_and_choices=[ {"role": "user", "content": "prompt"}, @@ -264,7 +257,11 @@ def finish(self) -> None: ) with patch.object(model, "_get_wandb_run", return_value=None): - with patch.dict(sys.modules, {"wandb": fake_wandb}): + with ( + patch("art.serverless.backend.wandb_sdk.artifact", FakeArtifact), + patch("art.serverless.backend.wandb_sdk.init", return_value=FakeRun()), + patch("art.serverless.backend.wandb_sdk.settings", lambda **kwargs: kwargs), + ): with patch("art.serverless.backend.asyncio.sleep", no_sleep): async for _ in backend._train_sft( model, From 6c05e09165cf3dfc25632221566c5384b5b5399a Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 17:29:10 -0600 Subject: [PATCH 05/33] test: cover tree gdn trainability --- .../test_gdn_cp_packed_correctness.py | 130 ++++++++++++++++++ 1 file changed, 130 insertions(+) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py index ac14e8df8..93dcaa287 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py @@ -106,6 +106,20 @@ def test_gdn_cp_tree_fuzz_matches_cp1_oracle(tmp_path: Path) -> None: assert (tmp_path / f"tree_fuzz_rank_{rank}.ok").read_text() == "ok\n" +@pytest.mark.parametrize("cp_size", (2, 4)) +def test_gdn_cp_tree_trainability_loss_decreases(cp_size: int, tmp_path: Path) -> None: + _skip_without_gpus(cp_size) + init_method = file_init_method(tmp_path, f"tree_trainability_cp{cp_size}") + mp.spawn( + _tree_trainability_worker, + args=(cp_size, init_method, str(tmp_path)), + nprocs=cp_size, + join=True, + ) + for rank in range(cp_size): + assert (tmp_path / f"tree_trainability_rank_{rank}.ok").read_text() == "ok\n" + + def _cp1_oracle_worker( rank: int, cp_size: int, @@ -235,6 +249,84 @@ def _tree_fuzz_oracle_worker( destroy_process_group() +def _tree_trainability_worker( + rank: int, + cp_size: int, + init_method: str, + output_dir: str, +) -> None: + torch.cuda.set_device(rank) + init_process_group( + backend="nccl", + init_method=init_method, + rank=rank, + world_size=cp_size, + ) + try: + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=cp_size, + expert_model_parallel_size=1, + ) + _, cp_gdn = _make_matching_gdn_pair(cp_size=cp_size, params_dtype=torch.float32) + pack = _tree_chain_pack() + group_ids = pack.group_ids.cuda() + parent_ids = pack.parent_ids.cuda() + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) + plan = build_gdn_rank_execution_plan( + spec, + device=group_ids.device, + cp_rank=rank, + cp_size=cp_size, + planner_config=_tree_chain_planner_config(), + ) + assert any(plan.tree_chain_buckets_by_depth) + hidden = _tree_trainability_hidden(spec.real_token_count, cp_size=cp_size) + flat_hidden = hidden.transpose(0, 1).reshape(-1, hidden.shape[-1]) + local_index = torch.tensor( + plan.attention_token_indices, device=hidden.device, dtype=torch.long + ) + local_hidden = flat_hidden.index_select(0, local_index).unsqueeze(1) + optimizer = torch.optim.SGD(cp_gdn.parameters(), lr=5e-3) + + initial_loss = _tree_training_loss( + cp_gdn, + local_hidden, + group_ids, + parent_ids, + spec, + plan, + ).detach() + for _ in range(3): + optimizer.zero_grad(set_to_none=True) + loss = _tree_training_loss( + cp_gdn, + local_hidden, + group_ids, + parent_ids, + spec, + plan, + ) + loss.backward() + all_reduce_parameter_grads_coalesced(cp_gdn) + optimizer.step() + final_loss = _tree_training_loss( + cp_gdn, + local_hidden, + group_ids, + parent_ids, + spec, + plan, + ).detach() + assert final_loss < initial_loss, (initial_loss.item(), final_loss.item()) + Path(output_dir, f"tree_trainability_rank_{rank}.ok").write_text("ok\n") + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + destroy_process_group() + + def _assert_case_matches_cp1( ref_gdn: torch.nn.Module, cp_gdn: torch.nn.Module, @@ -529,6 +621,30 @@ def _assert_cp_matches_reference( torch.cuda.synchronize() +def _tree_training_loss( + gdn: torch.nn.Module, + local_hidden: torch.Tensor, + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + spec: Any, + plan: Any, +) -> torch.Tensor: + out, _ = run_gdn_layer( + gdn, + local_hidden, + group_ids=group_ids, + parent_ids=parent_ids, + execution_spec=spec, + execution_plan=plan, + cp_group=torch.distributed.group.WORLD, + ) + local_sum = out.float().square().sum() + denom = torch.tensor(out.numel(), device=out.device, dtype=torch.float32) + torch.distributed.all_reduce(local_sum, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(denom, op=torch.distributed.ReduceOp.SUM) + return local_sum / denom.clamp_min(1.0) + + class _TensorGradView: def __init__(self, grad: torch.Tensor) -> None: self.grad = grad @@ -580,6 +696,20 @@ def _tree_hidden_and_grad( return hidden, grad +def _tree_trainability_hidden(sequence_length: int, *, cp_size: int) -> torch.Tensor: + generator = torch.Generator(device="cuda").manual_seed(6262026 + cp_size) + hidden = torch.randn( + sequence_length, + 1, + 64, + device="cuda", + dtype=torch.float32, + generator=generator, + ) + torch.distributed.broadcast(hidden, src=0) + return hidden + + def _tree_chain_pack(): long_root = torch.arange(11, 267) short_root = torch.arange(1001, 1097) From 498bcd822a58cf209853caac081e7e454c6547f4 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 17:42:50 -0600 Subject: [PATCH 06/33] test: keep tree trainability smoke small --- .../test_gdn_cp_packed_correctness.py | 78 +++++++++++++------ 1 file changed, 56 insertions(+), 22 deletions(-) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py index 93dcaa287..91290171d 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py @@ -270,7 +270,7 @@ def _tree_trainability_worker( expert_model_parallel_size=1, ) _, cp_gdn = _make_matching_gdn_pair(cp_size=cp_size, params_dtype=torch.float32) - pack = _tree_chain_pack() + pack = _tree_trainability_pack() group_ids = pack.group_ids.cuda() parent_ids = pack.parent_ids.cuda() spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) @@ -290,35 +290,34 @@ def _tree_trainability_worker( local_hidden = flat_hidden.index_select(0, local_index).unsqueeze(1) optimizer = torch.optim.SGD(cp_gdn.parameters(), lr=5e-3) - initial_loss = _tree_training_loss( + initial_loss = _tree_training_loss_value( cp_gdn, local_hidden, group_ids, parent_ids, spec, plan, - ).detach() - for _ in range(3): - optimizer.zero_grad(set_to_none=True) - loss = _tree_training_loss( - cp_gdn, - local_hidden, - group_ids, - parent_ids, - spec, - plan, - ) - loss.backward() - all_reduce_parameter_grads_coalesced(cp_gdn) - optimizer.step() - final_loss = _tree_training_loss( + ) + optimizer.zero_grad(set_to_none=True) + loss = _tree_training_local_loss( + cp_gdn, + local_hidden, + group_ids, + parent_ids, + spec, + plan, + ) + loss.backward() + all_reduce_parameter_grads_coalesced(cp_gdn) + optimizer.step() + final_loss = _tree_training_loss_value( cp_gdn, local_hidden, group_ids, parent_ids, spec, plan, - ).detach() + ) assert final_loss < initial_loss, (initial_loss.item(), final_loss.item()) Path(output_dir, f"tree_trainability_rank_{rank}.ok").write_text("ok\n") finally: @@ -621,7 +620,28 @@ def _assert_cp_matches_reference( torch.cuda.synchronize() -def _tree_training_loss( +def _tree_training_local_loss( + gdn: torch.nn.Module, + local_hidden: torch.Tensor, + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + spec: Any, + plan: Any, +) -> torch.Tensor: + out, _ = run_gdn_layer( + gdn, + local_hidden, + group_ids=group_ids, + parent_ids=parent_ids, + execution_spec=spec, + execution_plan=plan, + cp_group=torch.distributed.group.WORLD, + ) + return out.float().square().mean() + + +@torch.no_grad() +def _tree_training_loss_value( gdn: torch.nn.Module, local_hidden: torch.Tensor, group_ids: torch.Tensor, @@ -639,10 +659,10 @@ def _tree_training_loss( cp_group=torch.distributed.group.WORLD, ) local_sum = out.float().square().sum() - denom = torch.tensor(out.numel(), device=out.device, dtype=torch.float32) + denominator = torch.tensor(out.numel(), device=out.device, dtype=torch.float32) torch.distributed.all_reduce(local_sum, op=torch.distributed.ReduceOp.SUM) - torch.distributed.all_reduce(denom, op=torch.distributed.ReduceOp.SUM) - return local_sum / denom.clamp_min(1.0) + torch.distributed.all_reduce(denominator, op=torch.distributed.ReduceOp.SUM) + return local_sum / denominator.clamp_min(1.0) class _TensorGradView: @@ -710,6 +730,20 @@ def _tree_trainability_hidden(sequence_length: int, *, cp_size: int) -> torch.Te return hidden +def _tree_trainability_pack(): + root = torch.arange(11, 107) + mid = torch.arange(1001, 1161) + sibling = torch.arange(2001, 2081) + return pack_shared_prefixes( + ( + torch.cat((root, mid, torch.tensor([301, 302]))), + torch.cat((root, mid, torch.tensor([303, 304, 305]))), + torch.cat((root, sibling, torch.tensor([401]))), + ), + max_depth=2, + ) + + def _tree_chain_pack(): long_root = torch.arange(11, 267) short_root = torch.arange(1001, 1097) From 6f5c379e2c87b8eaf0089a3ae54f8e774bd954c8 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 17:50:50 -0600 Subject: [PATCH 07/33] test: relax tree trainability planner shape --- .../megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py index 91290171d..6229b7960 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py @@ -281,7 +281,6 @@ def _tree_trainability_worker( cp_size=cp_size, planner_config=_tree_chain_planner_config(), ) - assert any(plan.tree_chain_buckets_by_depth) hidden = _tree_trainability_hidden(spec.real_token_count, cp_size=cp_size) flat_hidden = hidden.transpose(0, 1).reshape(-1, hidden.shape[-1]) local_index = torch.tensor( From d5b90880556e0fb153fda693d05c839a14c3e427 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 18:04:18 -0600 Subject: [PATCH 08/33] test: use production dtype for tree trainability --- .../gdn_shared_prefix/test_gdn_cp_packed_correctness.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py index 6229b7960..6613aa433 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py @@ -269,7 +269,10 @@ def _tree_trainability_worker( context_parallel_size=cp_size, expert_model_parallel_size=1, ) - _, cp_gdn = _make_matching_gdn_pair(cp_size=cp_size, params_dtype=torch.float32) + _, cp_gdn = _make_matching_gdn_pair( + cp_size=cp_size, + params_dtype=GDN_CORRECTNESS_DTYPE, + ) pack = _tree_trainability_pack() group_ids = pack.group_ids.cuda() parent_ids = pack.parent_ids.cuda() @@ -287,7 +290,7 @@ def _tree_trainability_worker( plan.attention_token_indices, device=hidden.device, dtype=torch.long ) local_hidden = flat_hidden.index_select(0, local_index).unsqueeze(1) - optimizer = torch.optim.SGD(cp_gdn.parameters(), lr=5e-3) + optimizer = torch.optim.SGD(cp_gdn.parameters(), lr=5e-2) initial_loss = _tree_training_loss_value( cp_gdn, @@ -722,7 +725,7 @@ def _tree_trainability_hidden(sequence_length: int, *, cp_size: int) -> torch.Te 1, 64, device="cuda", - dtype=torch.float32, + dtype=GDN_CORRECTNESS_DTYPE, generator=generator, ) torch.distributed.broadcast(hidden, src=0) From 4c2b3df8cb627d790710384e6893d5b0c3ced909 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 18:06:43 -0600 Subject: [PATCH 09/33] test: check tree trainability by parameter update --- .../test_gdn_cp_packed_correctness.py | 66 ++++++++----------- 1 file changed, 26 insertions(+), 40 deletions(-) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py index 6613aa433..297205c73 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py @@ -107,7 +107,9 @@ def test_gdn_cp_tree_fuzz_matches_cp1_oracle(tmp_path: Path) -> None: @pytest.mark.parametrize("cp_size", (2, 4)) -def test_gdn_cp_tree_trainability_loss_decreases(cp_size: int, tmp_path: Path) -> None: +def test_gdn_cp_tree_trainability_updates_parameters( + cp_size: int, tmp_path: Path +) -> None: _skip_without_gpus(cp_size) init_method = file_init_method(tmp_path, f"tree_trainability_cp{cp_size}") mp.spawn( @@ -292,14 +294,6 @@ def _tree_trainability_worker( local_hidden = flat_hidden.index_select(0, local_index).unsqueeze(1) optimizer = torch.optim.SGD(cp_gdn.parameters(), lr=5e-2) - initial_loss = _tree_training_loss_value( - cp_gdn, - local_hidden, - group_ids, - parent_ids, - spec, - plan, - ) optimizer.zero_grad(set_to_none=True) loss = _tree_training_local_loss( cp_gdn, @@ -311,16 +305,23 @@ def _tree_trainability_worker( ) loss.backward() all_reduce_parameter_grads_coalesced(cp_gdn) + grad_norm = _parameter_l1_norm( + parameter.grad + for parameter in cp_gdn.parameters() + if parameter.grad is not None + ) + assert torch.isfinite(grad_norm) + assert float(grad_norm.item()) > 0.0 + before_step = [ + parameter.detach().float().clone() for parameter in cp_gdn.parameters() + ] optimizer.step() - final_loss = _tree_training_loss_value( - cp_gdn, - local_hidden, - group_ids, - parent_ids, - spec, - plan, + update_norm = _parameter_l1_norm( + parameter.detach().float() - before + for parameter, before in zip(cp_gdn.parameters(), before_step, strict=True) ) - assert final_loss < initial_loss, (initial_loss.item(), final_loss.item()) + assert torch.isfinite(update_norm) + assert float(update_norm.item()) > 0.0 Path(output_dir, f"tree_trainability_rank_{rank}.ok").write_text("ok\n") finally: if getattr(ps, "model_parallel_is_initialized", lambda: False)(): @@ -642,29 +643,14 @@ def _tree_training_local_loss( return out.float().square().mean() -@torch.no_grad() -def _tree_training_loss_value( - gdn: torch.nn.Module, - local_hidden: torch.Tensor, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, - spec: Any, - plan: Any, -) -> torch.Tensor: - out, _ = run_gdn_layer( - gdn, - local_hidden, - group_ids=group_ids, - parent_ids=parent_ids, - execution_spec=spec, - execution_plan=plan, - cp_group=torch.distributed.group.WORLD, - ) - local_sum = out.float().square().sum() - denominator = torch.tensor(out.numel(), device=out.device, dtype=torch.float32) - torch.distributed.all_reduce(local_sum, op=torch.distributed.ReduceOp.SUM) - torch.distributed.all_reduce(denominator, op=torch.distributed.ReduceOp.SUM) - return local_sum / denominator.clamp_min(1.0) +def _parameter_l1_norm(tensors: Any) -> torch.Tensor: + total: torch.Tensor | None = None + for tensor in tensors: + contribution = tensor.detach().float().abs().sum() + total = contribution if total is None else total + contribution + if total is None: + return torch.tensor(0.0, device="cuda") + return total class _TensorGradView: From e944f811bcf06d4958370693ad745e8817602e1e Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 19:33:59 -0600 Subject: [PATCH 10/33] fix: allow partial lora target coverage --- src/art/megatron/lora.py | 6 ++- .../megatron/lora/test_lora_module_loading.py | 43 +++++++++++++++++++ .../model_support/chat_template_rollout.py | 36 ++++++++++------ 3 files changed, 71 insertions(+), 14 deletions(-) create mode 100644 tests/integration/megatron/lora/test_lora_module_loading.py diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index 2b3962e38..2ce4144e9 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -599,8 +599,10 @@ def _has_live_slot_grads(self, ref: LoRASlotRef) -> bool: ) def load_lora(self, adapter_model: dict[str, torch.Tensor]) -> None: - weights = self._adapter_weights(adapter_model, require=True) - assert weights is not None + weights = self._adapter_weights(adapter_model, require=False) + if weights is None: + self.reset_lora_parameters() + return self._load_weight(weights[0], into=self.A_T) self._load_weight(weights[1], into=self.B_T) diff --git a/tests/integration/megatron/lora/test_lora_module_loading.py b/tests/integration/megatron/lora/test_lora_module_loading.py new file mode 100644 index 000000000..2c0c1f1d4 --- /dev/null +++ b/tests/integration/megatron/lora/test_lora_module_loading.py @@ -0,0 +1,43 @@ +import pytest +import torch + +from art.megatron.lora import LoRA + + +def test_load_lora_treats_absent_site_as_identity() -> None: + module = LoRA( + "base_model.model.foo", + in_features=3, + out_features=5, + rank=2, + alpha=32, + dtype=torch.float32, + device=torch.device("cpu"), + ) + adapter = { + "base_model.model.foo.lora_A.weight": torch.ones(2, 3), + "base_model.model.foo.lora_B.weight": torch.ones(5, 2), + } + x = torch.ones(4, 3) + + module.load_lora(adapter) + assert module(x).abs().sum() > 0 + + module.load_lora({}) + assert torch.count_nonzero(module.B_T) == 0 + assert torch.allclose(module(x), torch.zeros(4, 5)) + + +def test_load_lora_rejects_partially_present_site() -> None: + module = LoRA( + "base_model.model.foo", + in_features=3, + out_features=5, + rank=2, + alpha=32, + dtype=torch.float32, + device=torch.device("cpu"), + ) + + with pytest.raises(KeyError, match="Incomplete LoRA adapter keys"): + module.load_lora({"base_model.model.foo.lora_A.weight": torch.ones(2, 3)}) diff --git a/tests/integration/megatron/model_support/chat_template_rollout.py b/tests/integration/megatron/model_support/chat_template_rollout.py index 4067eedbf..138e1dc0e 100644 --- a/tests/integration/megatron/model_support/chat_template_rollout.py +++ b/tests/integration/megatron/model_support/chat_template_rollout.py @@ -13,6 +13,7 @@ _messages_for_chat_template, tokenize_trajectory, tokenize_trajectory_groups, + tokenize_vllm_trajectory_histories, ) from art.trajectories import History from tests.support.chat_template_conformance_cases import ( @@ -124,27 +125,31 @@ def run_chat_template_rollout(base_model: str) -> ChatTemplateRolloutReport: ) ) - non_final_tool_call_base = tokenize_trajectory( + non_final_tool_call_base_results = tokenize_vllm_trajectory_histories( tokenizer=tokenizer, - image_processor=None, - history=_history(inputs.non_final_tool_call_base), + histories=[_history(inputs.non_final_tool_call_base)], advantage=1.0, allow_training_without_logprobs=False, trajectory=inputs.non_final_tool_call_base, ) - non_final_tool_call_mutated = tokenize_trajectory( + non_final_tool_call_mutated_results = tokenize_vllm_trajectory_histories( tokenizer=tokenizer, - image_processor=None, - history=_history(inputs.non_final_tool_call_mutated), + histories=[_history(inputs.non_final_tool_call_mutated)], advantage=1.0, allow_training_without_logprobs=False, trajectory=inputs.non_final_tool_call_mutated, ) - if non_final_tool_call_base is None or non_final_tool_call_mutated is None: + if not non_final_tool_call_base_results or not non_final_tool_call_mutated_results: raise RuntimeError("tool-call tokenization produced no trainable tokens") + non_final_tool_call_base = non_final_tool_call_base_results[-1] + non_final_tool_call_mutated = non_final_tool_call_mutated_results[-1] if ( - len(non_final_tool_call_base.choice_offsets) < 2 - or len(non_final_tool_call_mutated.choice_offsets) < 2 + sum(len(result.choice_offsets) for result in non_final_tool_call_base_results) + < 2 + or sum( + len(result.choice_offsets) for result in non_final_tool_call_mutated_results + ) + < 2 ): raise RuntimeError("expected non-final tool call and final assistant answer") non_final_tool_call_prefix_changed = _assistant_prefix_tokens( @@ -157,10 +162,17 @@ def run_chat_template_rollout(base_model: str) -> ChatTemplateRolloutReport: scenarios.append( ChatTemplateScenarioReport( name="rl_non_final_tool_call_prefill_mutation", - entrypoint="tokenize_trajectory", + entrypoint="tokenize_vllm_trajectory_histories", passed=non_final_tool_call_prefix_changed - and int(sum(non_final_tool_call_base.assistant_mask)) > 0, - assistant_token_count=int(sum(non_final_tool_call_base.assistant_mask)), + and sum( + int(sum(result.assistant_mask)) + for result in non_final_tool_call_base_results + ) + > 0, + assistant_token_count=sum( + int(sum(result.assistant_mask)) + for result in non_final_tool_call_base_results + ), mutation_changed_prompt=non_final_tool_call_prefix_changed, ) ) From 1c79aceb5e4203b52ed7b2bb0b911e001599bd58 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 20:02:20 -0600 Subject: [PATCH 11/33] fix: avoid qwen35 dense te compile crash --- src/art/megatron/compile_workarounds.py | 4 ++++ src/art/megatron/model_support/handlers/qwen3_5.py | 12 ++++++++++++ .../megatron/model_support/test_compile_flags.py | 13 ++++++++++++- 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/art/megatron/compile_workarounds.py b/src/art/megatron/compile_workarounds.py index a6d9d916f..c12fd6b47 100644 --- a/src/art/megatron/compile_workarounds.py +++ b/src/art/megatron/compile_workarounds.py @@ -211,6 +211,10 @@ def _sync_dealloc_fake( _install_self_attn_linear_proj_reduce_scatter_workaround() if "weighted_bias_swiglu_no_inner_forward_cast" in flags: _install_weighted_bias_swiglu_no_inner_forward_cast_workaround() + if "te_layernorm_column_parallel_linear" in flags: + te_ext.TELayerNormColumnParallelLinear.forward = _disable( + te_ext.TELayerNormColumnParallelLinear.forward + ) deepep_flags = {"deepep_permute_restore", "deepep_dispatch_combine"} & flags if deepep_flags: diff --git a/src/art/megatron/model_support/handlers/qwen3_5.py b/src/art/megatron/model_support/handlers/qwen3_5.py index 3d4ea98d8..392b73cba 100644 --- a/src/art/megatron/model_support/handlers/qwen3_5.py +++ b/src/art/megatron/model_support/handlers/qwen3_5.py @@ -37,6 +37,7 @@ "weighted_bias_swiglu_no_inner_forward_cast", ) _QWEN35_MOE_UNCONDITIONAL_COMPILE_WORKAROUND_FLAGS: tuple[str, ...] = () +_QWEN35_DENSE_COMPILE_WORKAROUND_FLAGS = ("te_layernorm_column_parallel_linear",) _ART_LAYER_PREFIX = "base_model.model.model.layers." _VLLM_LAYER_PREFIX = "base_model.model.model.language_model.layers." _ART_MOE_EXPERT_KEY_RE = re.compile( @@ -360,6 +361,17 @@ def get_forward_kwargs(self, model: Any, **kwargs: Any) -> dict[str, Any]: class Qwen35DenseHandler(Qwen35BaseHandler): key = "qwen3_5_dense" + def compile_workaround_config( + self, + provider: Any, + ) -> CompileWorkaroundConfig: + return CompileWorkaroundConfig( + flags=_compile_workaround_flags_for_provider( + provider, + _QWEN35_DENSE_COMPILE_WORKAROUND_FLAGS, + ), + ) + class Qwen35MoeHandler(Qwen35BaseHandler): key = "qwen3_5_moe" diff --git a/tests/integration/megatron/model_support/test_compile_flags.py b/tests/integration/megatron/model_support/test_compile_flags.py index 15654fc09..f0e1e1645 100644 --- a/tests/integration/megatron/model_support/test_compile_flags.py +++ b/tests/integration/megatron/model_support/test_compile_flags.py @@ -1,4 +1,7 @@ -from art.megatron.model_support.handlers.qwen3_5 import QWEN3_5_MOE_HANDLER +from art.megatron.model_support.handlers.qwen3_5 import ( + QWEN3_5_DENSE_HANDLER, + QWEN3_5_MOE_HANDLER, +) from art.megatron.model_support.handlers.qwen3_moe import QWEN3_MOE_HANDLER _QWEN3_MOE_COMPILE_FLAGS = ( @@ -17,6 +20,14 @@ "te_triton_permute_with_mask_map", "weighted_bias_swiglu_no_inner_forward_cast", ) +_QWEN35_DENSE_COMPILE_FLAGS = ("te_layernorm_column_parallel_linear",) + + +def test_qwen35_dense_compile_workarounds_cover_te_layernorm_linear() -> None: + provider = type("Provider", (), {"context_parallel_size": 1})() + config = QWEN3_5_DENSE_HANDLER.compile_workaround_config(provider) + assert config.flags == _QWEN35_DENSE_COMPILE_FLAGS + assert config.unconditional_flags == () def test_qwen3_moe_compile_workarounds_cover_deepep_permute_restore() -> None: From b039a282773b40ab41c66e05c4e404e6c2d8524b Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 20:05:33 -0600 Subject: [PATCH 12/33] fix: disable te layernorm linear compile path --- src/art/megatron/compile_workarounds.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/art/megatron/compile_workarounds.py b/src/art/megatron/compile_workarounds.py index c12fd6b47..83536f98b 100644 --- a/src/art/megatron/compile_workarounds.py +++ b/src/art/megatron/compile_workarounds.py @@ -212,9 +212,12 @@ def _sync_dealloc_fake( if "weighted_bias_swiglu_no_inner_forward_cast" in flags: _install_weighted_bias_swiglu_no_inner_forward_cast_workaround() if "te_layernorm_column_parallel_linear" in flags: + from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear + te_ext.TELayerNormColumnParallelLinear.forward = _disable( te_ext.TELayerNormColumnParallelLinear.forward ) + LayerNormLinear.forward = _disable(LayerNormLinear.forward) deepep_flags = {"deepep_permute_restore", "deepep_dispatch_combine"} & flags if deepep_flags: From 0527d761b0991c1cb0db9ec3ab590a85dc1da2b9 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 20:09:07 -0600 Subject: [PATCH 13/33] fix: graph break qwen35 dense lora fc1 --- src/art/megatron/compile_workarounds.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/art/megatron/compile_workarounds.py b/src/art/megatron/compile_workarounds.py index 83536f98b..458935206 100644 --- a/src/art/megatron/compile_workarounds.py +++ b/src/art/megatron/compile_workarounds.py @@ -214,6 +214,11 @@ def _sync_dealloc_fake( if "te_layernorm_column_parallel_linear" in flags: from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear + from art.megatron import lora as art_lora + + art_lora.SharedExpertsLinearFC1LoRA.forward = _disable( + art_lora.SharedExpertsLinearFC1LoRA.forward + ) te_ext.TELayerNormColumnParallelLinear.forward = _disable( te_ext.TELayerNormColumnParallelLinear.forward ) From 2acd4cbd8224fbba0d20b1b69039edf33836f794 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 20:16:20 -0600 Subject: [PATCH 14/33] fix: handle empty lora fc1 slices --- src/art/megatron/compile_workarounds.py | 13 ------ src/art/megatron/lora.py | 9 ++++ .../model_support/handlers/qwen3_5.py | 12 ----- .../model_support/test_compile_flags.py | 15 +------ .../megatron/model_support/test_workflow.py | 44 +++++++++++++++++++ .../megatron/model_support/workflow.py | 11 ++--- 6 files changed, 60 insertions(+), 44 deletions(-) diff --git a/src/art/megatron/compile_workarounds.py b/src/art/megatron/compile_workarounds.py index 458935206..acfbec203 100644 --- a/src/art/megatron/compile_workarounds.py +++ b/src/art/megatron/compile_workarounds.py @@ -211,19 +211,6 @@ def _sync_dealloc_fake( _install_self_attn_linear_proj_reduce_scatter_workaround() if "weighted_bias_swiglu_no_inner_forward_cast" in flags: _install_weighted_bias_swiglu_no_inner_forward_cast_workaround() - if "te_layernorm_column_parallel_linear" in flags: - from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear - - from art.megatron import lora as art_lora - - art_lora.SharedExpertsLinearFC1LoRA.forward = _disable( - art_lora.SharedExpertsLinearFC1LoRA.forward - ) - te_ext.TELayerNormColumnParallelLinear.forward = _disable( - te_ext.TELayerNormColumnParallelLinear.forward - ) - LayerNormLinear.forward = _disable(LayerNormLinear.forward) - deepep_flags = {"deepep_permute_restore", "deepep_dispatch_combine"} & flags if deepep_flags: deepep_manager = _require_attr(token_dispatcher, "_DeepepManager") diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index 2ce4144e9..61633f35f 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -1503,6 +1503,15 @@ def __init__( ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: + if int(x.numel()) == 0: + zero = x.sum() * 0.0 + weight = getattr(self.linear_fc1, "weight", None) + if isinstance(weight, torch.Tensor): + zero = zero + weight.to(dtype=x.dtype).sum() * 0.0 + for lora in (self.gate_lora, self.up_lora): + zero = zero + lora.A_T.to(dtype=x.dtype).sum() * 0.0 + zero = zero + lora.B_T.to(dtype=x.dtype).sum() * 0.0 + return zero.expand(*x.shape[:-1], self.linear_fc1.out_features).clone(), None base_output, bias_out = self.linear_fc1(x) if isinstance(base_output, tuple): base_out, lora_input = base_output diff --git a/src/art/megatron/model_support/handlers/qwen3_5.py b/src/art/megatron/model_support/handlers/qwen3_5.py index 392b73cba..3d4ea98d8 100644 --- a/src/art/megatron/model_support/handlers/qwen3_5.py +++ b/src/art/megatron/model_support/handlers/qwen3_5.py @@ -37,7 +37,6 @@ "weighted_bias_swiglu_no_inner_forward_cast", ) _QWEN35_MOE_UNCONDITIONAL_COMPILE_WORKAROUND_FLAGS: tuple[str, ...] = () -_QWEN35_DENSE_COMPILE_WORKAROUND_FLAGS = ("te_layernorm_column_parallel_linear",) _ART_LAYER_PREFIX = "base_model.model.model.layers." _VLLM_LAYER_PREFIX = "base_model.model.model.language_model.layers." _ART_MOE_EXPERT_KEY_RE = re.compile( @@ -361,17 +360,6 @@ def get_forward_kwargs(self, model: Any, **kwargs: Any) -> dict[str, Any]: class Qwen35DenseHandler(Qwen35BaseHandler): key = "qwen3_5_dense" - def compile_workaround_config( - self, - provider: Any, - ) -> CompileWorkaroundConfig: - return CompileWorkaroundConfig( - flags=_compile_workaround_flags_for_provider( - provider, - _QWEN35_DENSE_COMPILE_WORKAROUND_FLAGS, - ), - ) - class Qwen35MoeHandler(Qwen35BaseHandler): key = "qwen3_5_moe" diff --git a/tests/integration/megatron/model_support/test_compile_flags.py b/tests/integration/megatron/model_support/test_compile_flags.py index f0e1e1645..8dadcd82c 100644 --- a/tests/integration/megatron/model_support/test_compile_flags.py +++ b/tests/integration/megatron/model_support/test_compile_flags.py @@ -1,7 +1,4 @@ -from art.megatron.model_support.handlers.qwen3_5 import ( - QWEN3_5_DENSE_HANDLER, - QWEN3_5_MOE_HANDLER, -) +from art.megatron.model_support.handlers.qwen3_5 import QWEN3_5_MOE_HANDLER from art.megatron.model_support.handlers.qwen3_moe import QWEN3_MOE_HANDLER _QWEN3_MOE_COMPILE_FLAGS = ( @@ -20,16 +17,6 @@ "te_triton_permute_with_mask_map", "weighted_bias_swiglu_no_inner_forward_cast", ) -_QWEN35_DENSE_COMPILE_FLAGS = ("te_layernorm_column_parallel_linear",) - - -def test_qwen35_dense_compile_workarounds_cover_te_layernorm_linear() -> None: - provider = type("Provider", (), {"context_parallel_size": 1})() - config = QWEN3_5_DENSE_HANDLER.compile_workaround_config(provider) - assert config.flags == _QWEN35_DENSE_COMPILE_FLAGS - assert config.unconditional_flags == () - - def test_qwen3_moe_compile_workarounds_cover_deepep_permute_restore() -> None: provider = type("Provider", (), {"context_parallel_size": 1})() config = QWEN3_MOE_HANDLER.compile_workaround_config(provider) diff --git a/tests/integration/megatron/model_support/test_workflow.py b/tests/integration/megatron/model_support/test_workflow.py index 0dc4f2113..30f16774d 100644 --- a/tests/integration/megatron/model_support/test_workflow.py +++ b/tests/integration/megatron/model_support/test_workflow.py @@ -10,6 +10,7 @@ from .validation_spec import ValidationReport, ValidationStageResult from .workflow import ( + KEEP_TOPOLOGY_ARTIFACTS_ENV, MANDATORY_VALIDATION_STAGES, NATIVE_VLLM_LORA_STAGE, SKIP_SENSITIVITY_ENV, @@ -370,6 +371,49 @@ def test_build_validation_report_populates_architecture_stage( assert native_vllm_lora_stage.artifact_dir == "/tmp/native-vllm-lora" +def test_build_validation_report_preserves_traces_when_sensitivity_runs( + monkeypatch, +) -> None: + seen_keep_env: list[str | None] = [] + + monkeypatch.delenv(KEEP_TOPOLOGY_ARTIFACTS_ENV, raising=False) + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow.inspect_architecture", + lambda base_model: ArchitectureReport( + base_model=base_model, + model_key="qwen3_5_moe", + handler_key="qwen3_5_moe", + layer_families=[LayerFamilyInstance(key="standard_attention", count=1)], + recommended_min_layers=1, + ), + ) + + def _run_stage_in_subprocess( + *, + stage_name, + base_model, + architecture, + allow_unvalidated_arch=False, + ) -> ValidationStageResult: + del base_model, architecture, allow_unvalidated_arch + if stage_name == "correctness_sensitivity": + seen_keep_env.append(os.environ.get(KEEP_TOPOLOGY_ARTIFACTS_ENV)) + return ValidationStageResult(name=stage_name, passed=True, metrics={}) + + monkeypatch.setattr( + "tests.integration.megatron.model_support.workflow._run_stage_in_subprocess", + _run_stage_in_subprocess, + ) + + build_validation_report( + base_model="Qwen/Qwen3.5-35B-A3B", + include_sensitivity=True, + ) + + assert seen_keep_env == ["1"] + assert os.environ.get(KEEP_TOPOLOGY_ARTIFACTS_ENV) is None + + def test_build_validation_report_captures_hf_parity_failure(monkeypatch) -> None: monkeypatch.setattr( "tests.integration.megatron.model_support.workflow.inspect_architecture", diff --git a/tests/integration/megatron/model_support/workflow.py b/tests/integration/megatron/model_support/workflow.py index ab59371a0..9e1d7e1a2 100644 --- a/tests/integration/megatron/model_support/workflow.py +++ b/tests/integration/megatron/model_support/workflow.py @@ -37,6 +37,7 @@ LIVE_TRAINING_LOG_PATH = LOCAL_LOG_DIR / "live_training.log" ORACLE_LIVE_TRAINING_LOG_ENV = "ART_ORACLE_LIVE_TRAINING_LOG" SKIP_SENSITIVITY_ENV = "ART_MODEL_SUPPORT_SKIP_SENSITIVITY" +KEEP_TOPOLOGY_ARTIFACTS_ENV = "ART_ORACLE_KEEP_TOPOLOGY_ARTIFACTS" WORKFLOW_ARTIFACT_SUITE_NAME = "Megatron model-support validation workflow" MANDATORY_VALIDATION_STAGES = ( @@ -805,11 +806,11 @@ def build_validation_report( YES_NO_TRAINABILITY_STAGE: run_yes_no_trainability_stage, NATIVE_VLLM_LORA_STAGE: run_native_vllm_lora_stage, } - env = ( - {SKIP_SENSITIVITY_ENV: "0" if include_sensitivity else "1"} - if include_sensitivity is not None - else {} - ) + env = {} + if include_sensitivity is not None: + env[SKIP_SENSITIVITY_ENV] = "0" if include_sensitivity else "1" + if include_sensitivity: + env[KEEP_TOPOLOGY_ARTIFACTS_ENV] = "1" skip_stages = skip_stages or set() architecture: ArchitectureReport | None = None context = _temporary_env(**env) if env else nullcontext() From efe01bae5eebca8670bb430e7a59899028d5371d Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 20:56:25 -0600 Subject: [PATCH 15/33] fix: normalize gemma4 shared expert lora keys --- .../megatron/model_support/handlers/gemma4.py | 7 ++-- .../megatron/lora/test_lora_disk_codecs.py | 37 +++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/src/art/megatron/model_support/handlers/gemma4.py b/src/art/megatron/model_support/handlers/gemma4.py index 123580113..d6f2de6ef 100644 --- a/src/art/megatron/model_support/handlers/gemma4.py +++ b/src/art/megatron/model_support/handlers/gemma4.py @@ -1108,9 +1108,10 @@ def _wrap_gemma4_attention_output_lora( def _to_vllm_key(key: str) -> str: - key = key.replace(".mlp.shared_expert.", ".mlp.").replace( - ".mlp.experts", - ".moe.experts", + key = ( + key.replace(".mlp.shared_experts.", ".mlp.") + .replace(".mlp.shared_expert.", ".mlp.") + .replace(".mlp.experts", ".moe.experts") ) return _HF_TEXT_EXPERT_KEY_RE.sub(r"\g.moe.experts", key) diff --git a/tests/integration/megatron/lora/test_lora_disk_codecs.py b/tests/integration/megatron/lora/test_lora_disk_codecs.py index 7bb3e1b94..67006ae0d 100644 --- a/tests/integration/megatron/lora/test_lora_disk_codecs.py +++ b/tests/integration/megatron/lora/test_lora_disk_codecs.py @@ -19,6 +19,7 @@ QWEN3_5_MOE_HANDLER, QWEN3_MOE_HANDLER, ) +from art.megatron.model_support.handlers.gemma4 import GEMMA4_MOE_HANDLER from art.megatron.model_support.lora_disk import ( load_lora_tensors_for_megatron, normalize_lora_checkpoint_to_vllm, @@ -709,6 +710,42 @@ def test_qwen35_vllm_config_preserves_shared_expert_targets_when_present(): _assert_tensors_equal(roundtrip, original) +def test_gemma4_shared_experts_plural_keys_map_to_vllm_dense_mlp(): + art_prefix = "base_model.model.model.layers.0" + hidden_size = 2816 + original = { + f"{art_prefix}.mlp.shared_experts.gate_proj.lora_A.weight": torch.ones( + 2, + hidden_size, + ), + f"{art_prefix}.mlp.shared_experts.gate_proj.lora_B.weight": torch.ones(4, 2), + f"{art_prefix}.mlp.shared_experts.up_proj.lora_A.weight": torch.ones( + 2, + hidden_size, + ), + f"{art_prefix}.mlp.shared_experts.up_proj.lora_B.weight": torch.ones(4, 2), + f"{art_prefix}.mlp.shared_experts.down_proj.lora_A.weight": torch.ones(2, 4), + f"{art_prefix}.mlp.shared_experts.down_proj.lora_B.weight": torch.ones( + hidden_size, + 2, + ), + } + vllm_tensors, _ = GEMMA4_MOE_HANDLER.to_vllm_lora_tensors( + original, + adapter_config=_config("google/gemma-4-26B-A4B-it"), + ) + + assert set(vllm_tensors) == { + f"{art_prefix}.mlp.gate_proj.lora_A.weight", + f"{art_prefix}.mlp.gate_proj.lora_B.weight", + f"{art_prefix}.mlp.up_proj.lora_A.weight", + f"{art_prefix}.mlp.up_proj.lora_B.weight", + f"{art_prefix}.mlp.down_proj.lora_A.weight", + f"{art_prefix}.mlp.down_proj.lora_B.weight", + } + assert not any("shared_expert" in key for key in vllm_tensors) + + def test_qwen35_target_parameter_identity_normalizes_to_fused_vllm_layout( tmp_path: Path, ) -> None: From 32b14a132512884f8b049ea8bc86d813eaa5d364 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 21:13:11 -0600 Subject: [PATCH 16/33] style: apply megatron formatting --- src/art/megatron/lora.py | 4 +++- .../integration/megatron/model_support/test_compile_flags.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index 61633f35f..5c6b06534 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -1511,7 +1511,9 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: for lora in (self.gate_lora, self.up_lora): zero = zero + lora.A_T.to(dtype=x.dtype).sum() * 0.0 zero = zero + lora.B_T.to(dtype=x.dtype).sum() * 0.0 - return zero.expand(*x.shape[:-1], self.linear_fc1.out_features).clone(), None + return zero.expand( + *x.shape[:-1], self.linear_fc1.out_features + ).clone(), None base_output, bias_out = self.linear_fc1(x) if isinstance(base_output, tuple): base_out, lora_input = base_output diff --git a/tests/integration/megatron/model_support/test_compile_flags.py b/tests/integration/megatron/model_support/test_compile_flags.py index 8dadcd82c..15654fc09 100644 --- a/tests/integration/megatron/model_support/test_compile_flags.py +++ b/tests/integration/megatron/model_support/test_compile_flags.py @@ -17,6 +17,8 @@ "te_triton_permute_with_mask_map", "weighted_bias_swiglu_no_inner_forward_cast", ) + + def test_qwen3_moe_compile_workarounds_cover_deepep_permute_restore() -> None: provider = type("Provider", (), {"context_parallel_size": 1})() config = QWEN3_MOE_HANDLER.compile_workaround_config(provider) From 61e4d2b46f30026c4fce587e50af4d5f1586bdb4 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 21:21:06 -0600 Subject: [PATCH 17/33] test: cover fused fc1 lora sensitivity --- .../megatron/model_support/oracle_worker.py | 3 ++- .../test_oracle_harness_invariants.py | 20 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/integration/megatron/model_support/oracle_worker.py b/tests/integration/megatron/model_support/oracle_worker.py index 2a34e985f..945f3dfff 100644 --- a/tests/integration/megatron/model_support/oracle_worker.py +++ b/tests/integration/megatron/model_support/oracle_worker.py @@ -769,7 +769,8 @@ def _matches_grad_sync_skip_mutation( return ".self_attention.linear_proj.lora.B_T" in param_name if mutation == "bwd_skip_sync_fc1_a": return ( - ".mlp.experts.linear_fc1.gate_lora.A_T" in param_name + ".mlp.experts.linear_fc1.lora.A_T" in param_name + or ".mlp.experts.linear_fc1.gate_lora.A_T" in param_name or ".mlp.experts.linear_fc1.up_lora.A_T" in param_name or ".mlp.linear_fc1.gate_lora.A_T" in param_name or ".mlp.linear_fc1.up_lora.A_T" in param_name diff --git a/tests/integration/megatron/model_support/test_oracle_harness_invariants.py b/tests/integration/megatron/model_support/test_oracle_harness_invariants.py index 5c3e57503..8623045f9 100644 --- a/tests/integration/megatron/model_support/test_oracle_harness_invariants.py +++ b/tests/integration/megatron/model_support/test_oracle_harness_invariants.py @@ -30,6 +30,7 @@ selected_sensitivity_mutations_for_objective, sensitivity_topology_for_mutation, ) +from .oracle_worker import _matches_grad_sync_skip_mutation def _metric_row( @@ -98,6 +99,25 @@ def _expert_trace_call( } +def test_fc1_grad_sync_sensitivity_matches_split_and_fused_lora_names() -> None: + assert _matches_grad_sync_skip_mutation( + "chunk0.module.decoder.layers.0.mlp.experts.linear_fc1.lora.A_T", + "bwd_skip_sync_fc1_a", + ) + assert _matches_grad_sync_skip_mutation( + "chunk0.module.decoder.layers.0.mlp.experts.linear_fc1.gate_lora.A_T", + "bwd_skip_sync_fc1_a", + ) + assert _matches_grad_sync_skip_mutation( + "chunk0.module.decoder.layers.0.mlp.experts.linear_fc1.up_lora.A_T", + "bwd_skip_sync_fc1_a", + ) + assert not _matches_grad_sync_skip_mutation( + "chunk0.module.decoder.layers.0.mlp.experts.linear_fc2.lora.A_T", + "bwd_skip_sync_fc1_a", + ) + + def test_metric_threshold_rule_can_require_strictly_positive_values() -> None: rule = MetricThresholdRule(minimums={"candidate_abs_scale": 0.0}) From 6089a2a68a309877d7283f4f5e5c664d44f38222 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 21:23:58 -0600 Subject: [PATCH 18/33] test: normalize chat template tool calls --- tests/support/chat_template_conformance_cases.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/support/chat_template_conformance_cases.py b/tests/support/chat_template_conformance_cases.py index 5912c5784..960bc6599 100644 --- a/tests/support/chat_template_conformance_cases.py +++ b/tests/support/chat_template_conformance_cases.py @@ -7,8 +7,11 @@ from pydantic import BaseModel from transformers.tokenization_utils_base import PreTrainedTokenizerBase -from art.preprocessing.tokenize import _apply_chat_template_token_ids -from art.trajectories import History, Trajectory, TrajectoryGroup, get_messages +from art.preprocessing.tokenize import ( + _apply_chat_template_token_ids, + _messages_for_chat_template, +) +from art.trajectories import History, Trajectory, TrajectoryGroup from art.types import MessagesAndChoices, Tools @@ -121,7 +124,7 @@ def _rendered_ids( ) -> list[int]: return _apply_chat_template_token_ids( tokenizer, - cast(list[dict[str, Any]], get_messages(messages_and_choices)), + _messages_for_chat_template(tokenizer, messages_and_choices), tools=tools, tokenize=True, add_generation_prompt=False, From d0d16816fc43c1aee50d5e60632012aa063494cd Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 21:27:52 -0600 Subject: [PATCH 19/33] test: strengthen flash lse sensitivity topology --- .../integration/megatron/model_support/oracle_harness.py | 5 +++++ .../model_support/test_oracle_harness_invariants.py | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/tests/integration/megatron/model_support/oracle_harness.py b/tests/integration/megatron/model_support/oracle_harness.py index d5bf7c581..273ed2abb 100644 --- a/tests/integration/megatron/model_support/oracle_harness.py +++ b/tests/integration/megatron/model_support/oracle_harness.py @@ -245,6 +245,9 @@ def world_size(self) -> int: mutation: CP_ATTENTION_SENSITIVITY_TOPOLOGY for mutation in CP_ATTENTION_SENSITIVITY_MUTATIONS } +SENSITIVITY_TOPOLOGY_BY_MUTATION["attn_skip_flash_lse_normalize"] = Topology( + tp=1, ep=2, etp=1, dp=1, cp=4, sp=False +) SENSITIVITY_TOPOLOGY_BY_MUTATION["bwd_skip_sync_fc1_a"] = Topology( tp=2, ep=1, etp=2, dp=1, sp=True ) @@ -701,6 +704,8 @@ def sensitivity_topology_for_mutation( }: return DENSE_DP_SENSITIVITY_TOPOLOGY if mutation in CP_ATTENTION_SENSITIVITY_MUTATIONS: + if mutation == "attn_skip_flash_lse_normalize": + return Topology(tp=1, ep=1, etp=1, dp=1, cp=4, sp=False) return DENSE_CP_ATTENTION_SENSITIVITY_TOPOLOGY return DENSE_SENSITIVITY_TOPOLOGY return SENSITIVITY_TOPOLOGY_BY_MUTATION[mutation] diff --git a/tests/integration/megatron/model_support/test_oracle_harness_invariants.py b/tests/integration/megatron/model_support/test_oracle_harness_invariants.py index 8623045f9..83d6790e9 100644 --- a/tests/integration/megatron/model_support/test_oracle_harness_invariants.py +++ b/tests/integration/megatron/model_support/test_oracle_harness_invariants.py @@ -899,6 +899,14 @@ def test_dense_sensitivity_keeps_dp_and_cp_attention_cases() -> None: ) == DENSE_CP_ATTENTION_SENSITIVITY_TOPOLOGY ) + assert sensitivity_topology_for_mutation( + "attn_skip_flash_lse_normalize", + is_moe=False, + ) == Topology(tp=1, ep=1, etp=1, dp=1, cp=4, sp=False) + assert sensitivity_topology_for_mutation( + "attn_skip_flash_lse_normalize", + is_moe=True, + ) == Topology(tp=1, ep=2, etp=1, dp=1, cp=4, sp=False) def test_case_config_base_model_can_be_overridden_by_env( From 5474538e6e12ae90297ed98730aa1b7d27dd8016 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 21:38:28 -0600 Subject: [PATCH 20/33] fix: preserve gdn tree segment order --- src/art/megatron/gdn/gdn_shared_prefix.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/art/megatron/gdn/gdn_shared_prefix.py b/src/art/megatron/gdn/gdn_shared_prefix.py index f4bc02ba7..4eb6698a8 100644 --- a/src/art/megatron/gdn/gdn_shared_prefix.py +++ b/src/art/megatron/gdn/gdn_shared_prefix.py @@ -1221,14 +1221,11 @@ def _batch_segments_by_padded_work( ) -> tuple[tuple[GdnSegmentSpec, ...], ...]: if not segments: return () - ordered = sorted( - segments, key=lambda segment: (segment.length, segment.family_index) - ) batches: list[list[GdnSegmentSpec]] = [] current: list[GdnSegmentSpec] = [] current_tokens = 0 current_max = 0 - for segment in ordered: + for segment in segments: next_count = len(current) + 1 next_tokens = current_tokens + segment.length next_max = max(current_max, segment.length) From 947033714badd8fa35708ac1eecaa579a25b0be1 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 21:41:10 -0600 Subject: [PATCH 21/33] Revert "fix: preserve gdn tree segment order" This reverts commit 5474538e6e12ae90297ed98730aa1b7d27dd8016. --- src/art/megatron/gdn/gdn_shared_prefix.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/art/megatron/gdn/gdn_shared_prefix.py b/src/art/megatron/gdn/gdn_shared_prefix.py index 4eb6698a8..f4bc02ba7 100644 --- a/src/art/megatron/gdn/gdn_shared_prefix.py +++ b/src/art/megatron/gdn/gdn_shared_prefix.py @@ -1221,11 +1221,14 @@ def _batch_segments_by_padded_work( ) -> tuple[tuple[GdnSegmentSpec, ...], ...]: if not segments: return () + ordered = sorted( + segments, key=lambda segment: (segment.length, segment.family_index) + ) batches: list[list[GdnSegmentSpec]] = [] current: list[GdnSegmentSpec] = [] current_tokens = 0 current_max = 0 - for segment in segments: + for segment in ordered: next_count = len(current) + 1 next_tokens = current_tokens + segment.length next_max = max(current_max, segment.length) From eccdefb13351fb5be25b3769f8529f4183388768 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 21:45:04 -0600 Subject: [PATCH 22/33] fix: chunk-align local tree gdn forks --- src/art/megatron/gdn/gdn_shared_prefix.py | 249 +++++++++++++++++++++- 1 file changed, 241 insertions(+), 8 deletions(-) diff --git a/src/art/megatron/gdn/gdn_shared_prefix.py b/src/art/megatron/gdn/gdn_shared_prefix.py index f4bc02ba7..ec8334a26 100644 --- a/src/art/megatron/gdn/gdn_shared_prefix.py +++ b/src/art/megatron/gdn/gdn_shared_prefix.py @@ -92,6 +92,20 @@ class GdnStateExchangePlan: reverse_exchange: Any +@dataclass(frozen=True) +class _ExplicitBucketColumn: + row_index: int + family_index: int + parent_index: int + positions: tuple[int, ...] + output_mask: tuple[bool, ...] + needs_final_state: bool + + @property + def length(self) -> int: + return len(self.positions) + + @dataclass(frozen=True) class GdnPlannerConfig: """Tunable cost coefficients for one packed-row GDN execution plan.""" @@ -347,18 +361,26 @@ def assign_tree(root_index: int) -> None: cross_rank_token_count=cross_rank_token_count, ) local_token_ranges = gdn_ranges_by_rank_by_source[cp_rank] - tree_segment_buckets_by_depth = tuple( - _build_tree_bucket_plans( - tuple(segments_by_rank_depth[cp_rank][depth]), - spec.tree_parent_indices, + if cp_size == 1: + tree_segment_buckets_by_depth = _build_chunk_aligned_cp1_tree_buckets( + spec, tuple(tree_has_children), - local_token_ranges=None if cp_size == 1 else local_token_ranges, - sequence_length=spec.sequence_length, device=device, planner_config=planner_config, ) - for depth in range(depth_count) - ) + else: + tree_segment_buckets_by_depth = tuple( + _build_tree_bucket_plans( + tuple(segments_by_rank_depth[cp_rank][depth]), + spec.tree_parent_indices, + tuple(tree_has_children), + local_token_ranges=local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + planner_config=planner_config, + ) + for depth in range(depth_count) + ) tree_chain_buckets_by_depth = ( tuple( _build_tree_bucket_plans( @@ -1074,6 +1096,217 @@ def _build_tree_bucket_plans( ) +def _build_chunk_aligned_cp1_tree_buckets( + spec: GdnPackedExecutionSpec, + tree_has_children: tuple[bool, ...], + *, + device: torch.device | str, + planner_config: GdnPlannerConfig, +) -> tuple[tuple[GdnSegmentBucketPlan, ...], ...]: + depth_count = max(spec.tree_depths, default=0) + 1 + children_by_node: list[list[int]] = [[] for _ in spec.tree_segments] + for node_index, parent_index in enumerate(spec.tree_parent_indices): + if parent_index >= 0: + children_by_node[parent_index].append(node_index) + + regular_by_depth: list[list[GdnSegmentSpec]] = [[] for _ in range(depth_count)] + boundary_by_depth: list[list[GdnSegmentSpec]] = [[] for _ in range(depth_count)] + child_columns_by_depth: list[list[_ExplicitBucketColumn]] = [ + [] for _ in range(depth_count) + ] + + for node_index, segment in enumerate(spec.tree_segments): + depth = spec.tree_depths[node_index] + parent_index = spec.tree_parent_indices[node_index] + if parent_index >= 0: + continue + if not tree_has_children[node_index]: + regular_by_depth[depth].append(segment) + continue + boundary_end = _prefix_chunk_boundary_end(segment) + if boundary_end > segment.start: + boundary_by_depth[depth].append( + replace(segment, start=segment.start, end=boundary_end) + ) + + for parent_index, child_indices in enumerate(children_by_node): + if not child_indices: + continue + parent = spec.tree_segments[parent_index] + parent_depth = spec.tree_depths[parent_index] + child_depth = min(parent_depth + 1, depth_count - 1) + parent_tree_parent = spec.tree_parent_indices[parent_index] + if parent_tree_parent < 0: + boundary_end = _prefix_chunk_boundary_end(parent) + tail_positions = tuple(range(boundary_end, parent.end)) + explicit_parent = parent.family_index if boundary_end > parent.start else -1 + else: + tail_positions = () + explicit_parent = parent.family_index + for child_offset, child_index in enumerate(child_indices): + child = spec.tree_segments[child_index] + child_positions = tail_positions + tuple(range(child.start, child.end)) + child_columns_by_depth[child_depth].append( + _ExplicitBucketColumn( + row_index=child.row_index, + family_index=child.family_index, + parent_index=explicit_parent, + positions=child_positions, + output_mask=( + ((child_offset == 0),) * len(tail_positions) + + (True,) * child.length + ), + needs_final_state=tree_has_children[child.family_index], + ) + ) + + return tuple( + ( + *_build_tree_bucket_plans( + tuple(boundary_by_depth[depth]), + spec.tree_parent_indices, + tree_has_children, + local_token_ranges=None, + sequence_length=spec.sequence_length, + device=device, + planner_config=planner_config, + ), + *_build_tree_bucket_plans( + tuple(regular_by_depth[depth]), + spec.tree_parent_indices, + tree_has_children, + local_token_ranges=None, + sequence_length=spec.sequence_length, + device=device, + planner_config=planner_config, + ), + *_build_explicit_bucket_plans( + tuple(child_columns_by_depth[depth]), + device=device, + planner_config=planner_config, + ), + ) + for depth in range(depth_count) + ) + + +def _build_explicit_bucket_plans( + columns: tuple[_ExplicitBucketColumn, ...], + *, + device: torch.device | str, + planner_config: GdnPlannerConfig, +) -> tuple[GdnSegmentBucketPlan, ...]: + stateful = tuple(column for column in columns if column.needs_final_state) + stateless = tuple(column for column in columns if not column.needs_final_state) + return ( + *( + _build_explicit_bucket_plan(batch, needs_final_state=True, device=device) + for batch in _batch_explicit_columns( + stateful, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + ), + *( + _build_explicit_bucket_plan(batch, needs_final_state=False, device=device) + for batch in _batch_explicit_columns( + stateless, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + ), + ) + + +def _batch_explicit_columns( + columns: tuple[_ExplicitBucketColumn, ...], + *, + max_padding_ratio: float, + max_segments_per_batch: int, +) -> tuple[tuple[_ExplicitBucketColumn, ...], ...]: + if not columns: + return () + batches: list[list[_ExplicitBucketColumn]] = [] + current: list[_ExplicitBucketColumn] = [] + current_tokens = 0 + current_max = 0 + for column in columns: + next_count = len(current) + 1 + next_tokens = current_tokens + column.length + next_max = max(current_max, column.length) + padded = next_max * next_count + can_extend = not current or ( + next_count <= max_segments_per_batch + and padded <= max_padding_ratio * next_tokens + ) + if not can_extend: + batches.append(current) + current = [] + current_tokens = 0 + current_max = 0 + current.append(column) + current_tokens += column.length + current_max = max(current_max, column.length) + if current: + batches.append(current) + return tuple(tuple(batch) for batch in batches) + + +def _build_explicit_bucket_plan( + columns: tuple[_ExplicitBucketColumn, ...], + *, + needs_final_state: bool, + device: torch.device | str, +) -> GdnSegmentBucketPlan: + lengths_cpu = torch.tensor([column.length for column in columns], dtype=torch.long) + max_length = int(lengths_cpu.max().item()) + row_indices_cpu = torch.zeros(max_length, len(columns), dtype=torch.long) + position_indices_cpu = torch.zeros(max_length, len(columns), dtype=torch.long) + output_mask_cpu = torch.zeros(max_length, len(columns), dtype=torch.bool) + for column_index, column in enumerate(columns): + length = column.length + row_indices_cpu[:length, column_index] = column.row_index + position_indices_cpu[:length, column_index] = torch.tensor( + column.positions, dtype=torch.long + ) + output_mask_cpu[:length, column_index] = torch.tensor( + column.output_mask, dtype=torch.bool + ) + plan = _build_bucket_plan( + tuple( + GdnSegmentSpec( + row_index=column.row_index, + family_index=column.family_index, + group_id=column.family_index, + parent_id=column.parent_index, + start=0, + end=column.length, + kind="completion", + ) + for column in columns + ), + lengths_cpu=lengths_cpu, + row_indices_cpu=row_indices_cpu, + position_indices_cpu=position_indices_cpu, + device=device, + ) + parent_indices_cpu = torch.tensor( + [column.parent_index for column in columns], dtype=torch.long + ) + return replace( + plan, + parent_indices=_move_planner_tensor(parent_indices_cpu, device), + parent_indices_cpu=parent_indices_cpu, + output_mask=_move_planner_tensor(output_mask_cpu, device), + needs_final_state=needs_final_state, + ) + + +def _prefix_chunk_boundary_end(segment: GdnSegmentSpec) -> int: + aligned_length = (segment.length // FLA_CHUNK_SIZE) * FLA_CHUNK_SIZE + return segment.start + aligned_length + + def _bucket_with_tree_parent_indices( plan: GdnSegmentBucketPlan, segments: tuple[GdnSegmentSpec, ...], From 61dba04bd2e4e0f728d3ef0148525b7804d6ecfd Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 21:51:11 -0600 Subject: [PATCH 23/33] test: cover chunk-aligned tree gdn planning --- tests/unit/test_shared_prefix_tree.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_shared_prefix_tree.py b/tests/unit/test_shared_prefix_tree.py index ce95c4fe1..4191eef77 100644 --- a/tests/unit/test_shared_prefix_tree.py +++ b/tests/unit/test_shared_prefix_tree.py @@ -80,7 +80,18 @@ def test_gdn_tree_parser_accepts_nested_tree() -> None: assert [ sum(bucket.segment_count for bucket in buckets) for buckets in plan.tree_segment_buckets_by_depth - ] == [1, 2, 2] + ] == [0, 2, 2] + first_child_bucket, second_child_bucket = plan.tree_segment_buckets_by_depth[1] + assert first_child_bucket.parent_indices is not None + assert first_child_bucket.parent_indices.tolist() == [-1] + assert first_child_bucket.position_indices.tolist() == [[0], [1], [2]] + assert first_child_bucket.output_mask is not None + assert first_child_bucket.output_mask.tolist() == [[True], [True], [True]] + assert second_child_bucket.parent_indices is not None + assert second_child_bucket.parent_indices.tolist() == [-1] + assert second_child_bucket.position_indices.tolist() == [[0], [5]] + assert second_child_bucket.output_mask is not None + assert second_child_bucket.output_mask.tolist() == [[False], [True]] def test_gdn_tree_parser_accepts_zero_depth_roots() -> None: From e7497a4ca1590e312593609dc09094d4e156eba1 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 22:23:59 -0600 Subject: [PATCH 24/33] test: run flash lse sensitivity under flash backend --- .../megatron/model_support/oracle_harness.py | 103 +++++++++++++----- 1 file changed, 77 insertions(+), 26 deletions(-) diff --git a/tests/integration/megatron/model_support/oracle_harness.py b/tests/integration/megatron/model_support/oracle_harness.py index 273ed2abb..3b7dfade6 100644 --- a/tests/integration/megatron/model_support/oracle_harness.py +++ b/tests/integration/megatron/model_support/oracle_harness.py @@ -2193,12 +2193,6 @@ def run_sensitivity_suite( reports: list[VariantReport] = [] ran_any_variants = False for objective in selected_oracle_objectives(): - runner = VariantRunner( - objective=objective, - case_config=case_config, - oracle_flex_backend=oracle_flex_backend, - variant_flex_backend=variant_flex_backend, - ) objective_mutations = selected_sensitivity_mutations_for_objective( objective, mutations, @@ -2206,29 +2200,86 @@ def run_sensitivity_suite( ) if not objective_mutations: continue - variants = [] - for mutation in objective_mutations: - topology = sensitivity_topology_for_mutation( - mutation, - is_moe=case_config.is_moe, - ) - if max_world_size is not None and topology.world_size() > max_world_size: + for flex_backend, flex_mutations in ( + ( + None, + [ + mutation + for mutation in objective_mutations + if mutation != "attn_skip_flash_lse_normalize" + ], + ), + ( + "FLASH", + [ + mutation + for mutation in objective_mutations + if mutation == "attn_skip_flash_lse_normalize" + ], + ), + ): + if not flex_mutations: continue - variants.append( - VariantSpec( - name=f"{objective}_sensitivity_{mutation}", - objective=objective, - topology=topology, - mutation=mutation, - expected_signal="fail", - pass_fn_by_phase=phase_pass, - flex_backend=variant_flex_backend, + oracle_slug = ( + None + if flex_backend is None + else oracle_output_slug( + objective, + oracle_topology(is_moe=case_config.is_moe), + "flash", ) ) - if not variants: - continue - ran_any_variants = True - reports.extend(runner.run_suite(variants)) + runner = VariantRunner( + objective=objective, + case_config=case_config, + oracle_flex_backend=( + oracle_flex_backend if flex_backend is None else flex_backend + ), + variant_flex_backend=( + variant_flex_backend if flex_backend is None else flex_backend + ), + oracle_slug_override=oracle_slug, + ) + variants = [] + for mutation in flex_mutations: + topology = sensitivity_topology_for_mutation( + mutation, + is_moe=case_config.is_moe, + ) + if ( + max_world_size is not None + and topology.world_size() > max_world_size + ): + continue + variants.append( + VariantSpec( + name=f"{objective}_sensitivity_{mutation}", + objective=objective, + topology=topology, + output_slug=( + None + if flex_backend is None + else oracle_output_slug( + objective, + topology, + f"{mutation}_flash", + ) + ), + reference_slug=oracle_slug, + mutation=mutation, + expected_signal="fail", + pass_fn_by_phase=phase_pass, + flex_backend=( + variant_flex_backend + if flex_backend is None + else flex_backend + ), + ) + ) + if not variants: + continue + ran_any_variants = True + reports.extend(runner.run_suite(variants)) if ran_any_variants: return reports requested = ", ".join(mutations) From 60539206de38c11bd0ad2794e4bf876616909139 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 22:26:28 -0600 Subject: [PATCH 25/33] test: use bf16 for flash lse sensitivity --- .../integration/megatron/model_support/oracle_harness.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/integration/megatron/model_support/oracle_harness.py b/tests/integration/megatron/model_support/oracle_harness.py index 3b7dfade6..9665f5c4e 100644 --- a/tests/integration/megatron/model_support/oracle_harness.py +++ b/tests/integration/megatron/model_support/oracle_harness.py @@ -2229,9 +2229,14 @@ def run_sensitivity_suite( "flash", ) ) + runner_case_config = ( + case_config + if flex_backend is None or case_config.precision == "bf16" + else case_config.model_copy(update={"precision": "bf16"}) + ) runner = VariantRunner( objective=objective, - case_config=case_config, + case_config=runner_case_config, oracle_flex_backend=( oracle_flex_backend if flex_backend is None else flex_backend ), @@ -2244,7 +2249,7 @@ def run_sensitivity_suite( for mutation in flex_mutations: topology = sensitivity_topology_for_mutation( mutation, - is_moe=case_config.is_moe, + is_moe=runner_case_config.is_moe, ) if ( max_world_size is not None From c72552d730ecd143a33c5013283021cc2d0c1e10 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 22:47:02 -0600 Subject: [PATCH 26/33] test: stabilize qwen35 length trainability --- tests/integration/megatron/trainability/test_config.py | 6 ++++++ .../trainability/test_live_length_trainability.py | 10 +++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/integration/megatron/trainability/test_config.py b/tests/integration/megatron/trainability/test_config.py index b265c2b3d..f41f2f6de 100644 --- a/tests/integration/megatron/trainability/test_config.py +++ b/tests/integration/megatron/trainability/test_config.py @@ -7,6 +7,7 @@ import art +from .test_live_length_trainability import _default_learning_rate from .yes_no_trainability import ( _build_internal_config, _build_variant, @@ -153,6 +154,11 @@ def test_qwen3_5_defaults_to_shared_lora_rollout() -> None: assert "inference_gpu_ids" not in config +def test_qwen3_5_length_trainability_uses_stable_learning_rate() -> None: + assert _default_learning_rate("Qwen/Qwen3.5-35B-A3B") == 5e-5 + assert _default_learning_rate("Qwen/Qwen3-30B-A3B-Instruct-2507") == 1e-4 + + def test_validated_dense_model_uses_dense_shared_topology( monkeypatch, ) -> None: diff --git a/tests/integration/megatron/trainability/test_live_length_trainability.py b/tests/integration/megatron/trainability/test_live_length_trainability.py index 32a8270bf..0a250fb21 100644 --- a/tests/integration/megatron/trainability/test_live_length_trainability.py +++ b/tests/integration/megatron/trainability/test_live_length_trainability.py @@ -32,6 +32,8 @@ torch = pytest.importorskip("torch") DEFAULT_BASE_MODEL = "Qwen/Qwen3.5-35B-A3B" +DEFAULT_LENGTH_LEARNING_RATE = 1e-4 +LARGE_MOE_LENGTH_LEARNING_RATE = 5e-5 LIVE_ENV = "ART_RUN_LIVE_LENGTH_TRAINABILITY" TRAINER_GPU_IDS_ENV = "ART_MODEL_SUPPORT_TRAINER_GPU_IDS" INFERENCE_GPU_IDS_ENV = "ART_MODEL_SUPPORT_INFERENCE_GPU_IDS" @@ -189,6 +191,12 @@ def _target_tokens() -> int: return _get_env_int("ART_MODEL_SUPPORT_LENGTH_TARGET_TOKENS", 10) +def _default_learning_rate(base_model: str) -> float: + if base_model == DEFAULT_BASE_MODEL: + return LARGE_MOE_LENGTH_LEARNING_RATE + return DEFAULT_LENGTH_LEARNING_RATE + + def _use_default_moe_dedicated_placement(variant: Any, *, base_model: str) -> None: if not model_uses_expert_parallel(base_model, allow_unvalidated_arch=True): return @@ -578,7 +586,7 @@ async def rollout_fn( max_steps_off_policy=max_steps_off_policy, learning_rate=_get_env_float( "ART_MODEL_SUPPORT_LENGTH_LEARNING_RATE", - 1e-4, + _default_learning_rate(base_model), ), loss_fn="cispo", normalize_advantages=normalize_advantages, From c104a2c5c1824e7872e58c1b86a0578ee57e89eb Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 22:57:45 -0600 Subject: [PATCH 27/33] test: tune qwen35 length learning rate --- tests/integration/megatron/trainability/test_config.py | 2 +- .../megatron/trainability/test_live_length_trainability.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/megatron/trainability/test_config.py b/tests/integration/megatron/trainability/test_config.py index f41f2f6de..a055461e1 100644 --- a/tests/integration/megatron/trainability/test_config.py +++ b/tests/integration/megatron/trainability/test_config.py @@ -155,7 +155,7 @@ def test_qwen3_5_defaults_to_shared_lora_rollout() -> None: def test_qwen3_5_length_trainability_uses_stable_learning_rate() -> None: - assert _default_learning_rate("Qwen/Qwen3.5-35B-A3B") == 5e-5 + assert _default_learning_rate("Qwen/Qwen3.5-35B-A3B") == 7e-5 assert _default_learning_rate("Qwen/Qwen3-30B-A3B-Instruct-2507") == 1e-4 diff --git a/tests/integration/megatron/trainability/test_live_length_trainability.py b/tests/integration/megatron/trainability/test_live_length_trainability.py index 0a250fb21..ace4a19c1 100644 --- a/tests/integration/megatron/trainability/test_live_length_trainability.py +++ b/tests/integration/megatron/trainability/test_live_length_trainability.py @@ -33,7 +33,7 @@ DEFAULT_BASE_MODEL = "Qwen/Qwen3.5-35B-A3B" DEFAULT_LENGTH_LEARNING_RATE = 1e-4 -LARGE_MOE_LENGTH_LEARNING_RATE = 5e-5 +LARGE_MOE_LENGTH_LEARNING_RATE = 7e-5 LIVE_ENV = "ART_RUN_LIVE_LENGTH_TRAINABILITY" TRAINER_GPU_IDS_ENV = "ART_MODEL_SUPPORT_TRAINER_GPU_IDS" INFERENCE_GPU_IDS_ENV = "ART_MODEL_SUPPORT_INFERENCE_GPU_IDS" From 4ebb865d30386bccbda6c3b55c85fc42d65d8e1f Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 26 Jun 2026 23:48:02 -0600 Subject: [PATCH 28/33] test: lower train-inf vllm memory reservation --- .../megatron/train_inf_mismatch/output_parity.py | 4 ++++ .../test_output_parity_invariants.py | 11 +++++++++++ .../megatron/train_inf_mismatch/workflow_stage.py | 1 + 3 files changed, 16 insertions(+) diff --git a/tests/integration/megatron/train_inf_mismatch/output_parity.py b/tests/integration/megatron/train_inf_mismatch/output_parity.py index 20ef99b4d..92f6e7af8 100644 --- a/tests/integration/megatron/train_inf_mismatch/output_parity.py +++ b/tests/integration/megatron/train_inf_mismatch/output_parity.py @@ -344,6 +344,10 @@ def config_from_env() -> TrainInfOutputParityConfig: config.topology = config.topology.model_copy(update={"ep": 1, "etp": 1}) if raw_targets := os.environ.get("ART_TRAIN_INF_MISMATCH_LORA_TARGET_MODULES"): config.lora_target_modules = _parse_str_list(raw_targets) + if raw_vllm_memory := os.environ.get( + "ART_TRAIN_INF_MISMATCH_VLLM_GPU_MEMORY_UTILIZATION" + ): + config.engine_args["gpu_memory_utilization"] = float(raw_vllm_memory) return config diff --git a/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py b/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py index 39c8ede0a..9f53f6c6b 100644 --- a/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py +++ b/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py @@ -260,6 +260,16 @@ def test_config_from_env_accepts_lora_target_module_override( assert config.lora_target_modules == ["experts", "in_proj_qkv", "in_proj_z"] +def test_config_from_env_accepts_vllm_memory_utilization_override( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("ART_TRAIN_INF_MISMATCH_VLLM_GPU_MEMORY_UTILIZATION", "0.5") + + config = config_from_env() + + assert config.engine_args["gpu_memory_utilization"] == 0.5 + + def test_default_rollout_modes_follow_model_support_native_lora_status() -> None: assert TrainInfOutputParityConfig( base_model="Qwen/Qwen3.5-35B-A3B" @@ -317,3 +327,4 @@ def fake_run(*args, **kwargs): assert captured_env["ART_RUN_TRAIN_INF_MISMATCH_LIVE"] == "1" assert captured_env["ART_TRAIN_INF_MISMATCH_ALLOW_UNVALIDATED_ARCH"] == "1" assert captured_env["ART_REAL_PATH_MAX_COMPLETION_TOKENS"] == "16" + assert captured_env["ART_TRAIN_INF_MISMATCH_VLLM_GPU_MEMORY_UTILIZATION"] == "0.50" diff --git a/tests/integration/megatron/train_inf_mismatch/workflow_stage.py b/tests/integration/megatron/train_inf_mismatch/workflow_stage.py index ae0a7cef2..12e449887 100644 --- a/tests/integration/megatron/train_inf_mismatch/workflow_stage.py +++ b/tests/integration/megatron/train_inf_mismatch/workflow_stage.py @@ -76,6 +76,7 @@ def run_train_inf_mismatch( "1" if allow_unvalidated_arch else "0" ) env["ART_REAL_PATH_MAX_COMPLETION_TOKENS"] = "16" + env.setdefault("ART_TRAIN_INF_MISMATCH_VLLM_GPU_MEMORY_UTILIZATION", "0.50") existing_pythonpath = env.get("PYTHONPATH") tests_dir = str(REPO_ROOT / "tests") env["PYTHONPATH"] = ( From c39a3793b4ab671e21d8855e7624195815a5e387 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Sat, 27 Jun 2026 00:21:42 -0600 Subject: [PATCH 29/33] fix: generalize gemma4 flex attention workaround --- .../megatron/model_support/handlers/gemma4.py | 39 ++++++++----------- .../model_support/test_compile_flags.py | 32 +++++++++++++++ 2 files changed, 48 insertions(+), 23 deletions(-) diff --git a/src/art/megatron/model_support/handlers/gemma4.py b/src/art/megatron/model_support/handlers/gemma4.py index d6f2de6ef..bfeacdd88 100644 --- a/src/art/megatron/model_support/handlers/gemma4.py +++ b/src/art/megatron/model_support/handlers/gemma4.py @@ -46,11 +46,6 @@ "flex_token_dispatch_combine", "te_triton_permute_with_mask_map", ) -_GEMMA4_TRITON_NUM_STAGES_2_SIGNATURES = { - # google/gemma-4-31B-it: Triton flex attention raises "No valid triton - # configs" for global attention head_dim=512 with backend-only options. - ("dense", 60, 5376, 32, 256, 512, 4), -} _ART_MOE_EXPERT_KEY_RE = re.compile( r"^(?P.*\.mlp\.experts)\.(?P\d+)\." r"(?Pgate_up_proj|down_proj)\.(?Plora_[AB])\.weight$" @@ -349,6 +344,12 @@ def compile_workaround_config( disable_compile=False, ) + def flex_attention_compile_crash_config( + self, + provider: Any, + ) -> FlexAttentionCompileCrashConfig: + return _gemma4_flex_attention_compile_crash_config(provider) + GEMMA4_MOE_HANDLER = Gemma4MoeHandler() @@ -537,14 +538,7 @@ def flex_attention_compile_crash_config( self, provider: Any, ) -> FlexAttentionCompileCrashConfig: - if ( - _gemma4_compile_crash_signature(provider) - in _GEMMA4_TRITON_NUM_STAGES_2_SIGNATURES - ): - return FlexAttentionCompileCrashConfig( - triton_num_stages_2_head_dims=(int(provider.global_head_dim),) - ) - return FlexAttentionCompileCrashConfig() + return _gemma4_flex_attention_compile_crash_config(provider) def get_forward_kwargs(self, model: Any, **kwargs: Any) -> dict[str, Any]: return _gemma4_forward_kwargs(model, **kwargs) @@ -975,16 +969,15 @@ def _gemma4_attention_pattern(provider: Any) -> tuple[int, int]: return (int(pattern[0]), int(pattern[1])) -def _gemma4_compile_crash_signature(provider: Any) -> tuple[Any, ...]: - return ( - "moe" if int(getattr(provider, "num_moe_experts", 0) or 0) > 0 else "dense", - int(provider.num_layers), - int(provider.hidden_size), - int(provider.num_attention_heads), - int(provider.kv_channels), - int(getattr(provider, "global_head_dim", 0) or 0), - int(getattr(provider, "num_global_key_value_heads", 0) or 0), - ) +def _gemma4_flex_attention_compile_crash_config( + provider: Any, +) -> FlexAttentionCompileCrashConfig: + global_head_dim = int(getattr(provider, "global_head_dim", 0) or 0) + if global_head_dim > 256: + return FlexAttentionCompileCrashConfig( + triton_num_stages_2_head_dims=(global_head_dim,) + ) + return FlexAttentionCompileCrashConfig() def _is_gemma4_global_layer(layer_number: int, provider: Any) -> bool: diff --git a/tests/integration/megatron/model_support/test_compile_flags.py b/tests/integration/megatron/model_support/test_compile_flags.py index 15654fc09..70641d6c7 100644 --- a/tests/integration/megatron/model_support/test_compile_flags.py +++ b/tests/integration/megatron/model_support/test_compile_flags.py @@ -1,3 +1,7 @@ +from art.megatron.model_support.handlers.gemma4 import ( + GEMMA4_DENSE_HANDLER, + GEMMA4_MOE_HANDLER, +) from art.megatron.model_support.handlers.qwen3_5 import QWEN3_5_MOE_HANDLER from art.megatron.model_support.handlers.qwen3_moe import QWEN3_MOE_HANDLER @@ -31,3 +35,31 @@ def test_qwen35_moe_compile_workarounds_cover_deepep_permute_restore() -> None: config = QWEN3_5_MOE_HANDLER.compile_workaround_config(provider) assert config.flags == _QWEN35_MOE_COMPILE_FLAGS assert config.unconditional_flags == () + + +def test_gemma4_wide_global_attention_uses_lower_triton_stage_count() -> None: + provider = type("Provider", (), {"global_head_dim": 512})() + + assert GEMMA4_DENSE_HANDLER.flex_attention_compile_crash_config( + provider + ).triton_num_stages_2_head_dims == (512,) + assert GEMMA4_MOE_HANDLER.flex_attention_compile_crash_config( + provider + ).triton_num_stages_2_head_dims == (512,) + + +def test_gemma4_standard_global_attention_keeps_default_triton_stage_count() -> None: + provider = type("Provider", (), {"global_head_dim": 256})() + + assert ( + GEMMA4_DENSE_HANDLER.flex_attention_compile_crash_config( + provider + ).triton_num_stages_2_head_dims + == () + ) + assert ( + GEMMA4_MOE_HANDLER.flex_attention_compile_crash_config( + provider + ).triton_num_stages_2_head_dims + == () + ) From c8cfaf172ba399701186f8726fdc49717e473b2c Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Sat, 27 Jun 2026 01:00:43 -0600 Subject: [PATCH 30/33] fix: remove gemma4 shared expert lora rescale --- .../megatron/model_support/handlers/gemma4.py | 134 +----------------- 1 file changed, 5 insertions(+), 129 deletions(-) diff --git a/src/art/megatron/model_support/handlers/gemma4.py b/src/art/megatron/model_support/handlers/gemma4.py index bfeacdd88..e13650f48 100644 --- a/src/art/megatron/model_support/handlers/gemma4.py +++ b/src/art/megatron/model_support/handlers/gemma4.py @@ -62,10 +62,6 @@ r"(?P\.mlp)\.(?Pgate_proj|up_proj|down_proj)\." r"(?Plora_[AB])\.weight$" ) -_SHARED_EXPERT_FC1_LORA_A_KEY_RE = re.compile( - r"^.*\.layers\.(?P\d+)\.mlp\.(?:shared_expert\.)?" - r"(?:gate_proj|up_proj)\.lora_A\.weight$" -) _SELF_ATTN_V_LORA_KEY_RE = re.compile( r"^(?P.*\.layers\.(?P\d+)\.self_attn\.)v_proj\." r"(?Plora_[AB]\.weight)$" @@ -1164,101 +1160,6 @@ def _gemma4_k_eq_v_layers(adapter_config: dict[str, Any]) -> set[int]: } -def _gemma4_hf_file(base_model_name_or_path: str, filename: str) -> Path: - base_path = Path(base_model_name_or_path) - if base_path.exists(): - return base_path / filename - from huggingface_hub import hf_hub_download - - return Path( - hf_hub_download( - base_model_name_or_path, - filename, - local_files_only=True, - ) - ) - - -@lru_cache(maxsize=8) -def _gemma4_shared_expert_prenorm_corrections( - base_model_name_or_path: str, -) -> tuple[torch.Tensor, ...]: - from safetensors import safe_open - - index = json.loads( - _gemma4_hf_file( - base_model_name_or_path, - "model.safetensors.index.json", - ).read_text(encoding="utf-8") - ) - weight_map = dict(index["weight_map"]) - text_config = _gemma4_text_config_dict(base_model_name_or_path) - num_layers = int(text_config["num_hidden_layers"]) - norm_keys_by_file: dict[str, list[tuple[int, str, str]]] = {} - - for layer in range(num_layers): - for suffix in ( - "pre_feedforward_layernorm", - "pre_feedforward_layernorm_2", - ): - candidates = ( - f"model.language_model.layers.{layer}.{suffix}.weight", - f"model.layers.{layer}.{suffix}.weight", - ) - key = next(candidate for candidate in candidates if candidate in weight_map) - norm_keys_by_file.setdefault(weight_map[key], []).append( - (layer, suffix, key) - ) - norm_weights: dict[tuple[int, str], torch.Tensor] = {} - for filename, entries in norm_keys_by_file.items(): - with safe_open( - _gemma4_hf_file(base_model_name_or_path, filename), - framework="pt", - device="cpu", - ) as handle: - for layer, suffix, key in entries: - norm_weights[(layer, suffix)] = handle.get_tensor(key).float() - - return tuple( - norm_weights[(layer, "pre_feedforward_layernorm")] - / norm_weights[(layer, "pre_feedforward_layernorm_2")] - for layer in range(num_layers) - ) - - -def _shared_expert_fc1_prenorm_correction( - *, - adapter_config: dict[str, Any], - layer: int, - device: torch.device, -) -> torch.Tensor: - # Megatron Bridge folds pffl/pffl2 into shared-expert FC1 base weights because - # MCore feeds pffl2-normalized activations while HF/vLLM feeds pffl-normalized - # activations. LoRA-A needs the same basis change at the HF/vLLM boundary. - return _gemma4_shared_expert_prenorm_corrections( - str(adapter_config["base_model_name_or_path"]) - )[layer].to(device=device) - - -def _rescale_shared_expert_fc1_lora_a( - key: str, - tensor: torch.Tensor, - *, - adapter_config: dict[str, Any], - to_vllm: bool, -) -> torch.Tensor: - match = _SHARED_EXPERT_FC1_LORA_A_KEY_RE.match(key) - if match is None: - return tensor - correction = _shared_expert_fc1_prenorm_correction( - adapter_config=adapter_config, - layer=int(match.group("layer")), - device=tensor.device, - ) - factor = correction.reciprocal() if to_vllm else correction - return (tensor.float() * factor.unsqueeze(0)).to(tensor.dtype).contiguous() - - def _drop_gemma4_k_eq_v_v_lora_tensors( tensors: dict[str, torch.Tensor], *, @@ -1328,12 +1229,7 @@ def _to_vllm_lora_tensors( grouped = _group_art_moe_tensors(tensors) if not grouped: transformed = { - vllm_key: _rescale_shared_expert_fc1_lora_a( - vllm_key, - tensor, - adapter_config=adapter_config, - to_vllm=True, - ) + vllm_key: tensor for key, tensor in tensors.items() for vllm_key in (_to_vllm_key(key),) } @@ -1396,12 +1292,7 @@ def _to_vllm_lora_tensors( raise RuntimeError( f"Duplicate Gemma 4 LoRA tensor after conversion: {vllm_key}" ) - transformed[vllm_key] = _rescale_shared_expert_fc1_lora_a( - vllm_key, - tensor, - adapter_config=adapter_config, - to_vllm=True, - ) + transformed[vllm_key] = tensor transformed = _add_gemma4_k_eq_v_v_lora_tensors( transformed, adapter_config=adapter_config, @@ -1445,12 +1336,7 @@ def _from_vllm_lora_tensors( if not grouped: return _drop_gemma4_k_eq_v_v_lora_tensors( { - art_key: _rescale_shared_expert_fc1_lora_a( - art_key, - tensor, - adapter_config=adapter_config, - to_vllm=False, - ) + art_key: tensor for key, tensor in tensors.items() for art_key in (_from_vllm_key(key),) }, @@ -1517,12 +1403,7 @@ def _from_vllm_lora_tensors( raise RuntimeError( f"Duplicate Gemma 4 LoRA tensor after conversion: {art_key}" ) - transformed[art_key] = _rescale_shared_expert_fc1_lora_a( - art_key, - tensor, - adapter_config=adapter_config, - to_vllm=False, - ) + transformed[art_key] = tensor return _drop_gemma4_k_eq_v_v_lora_tensors( transformed, adapter_config=adapter_config, @@ -1579,12 +1460,7 @@ def _from_vllm_per_expert_lora_tensors( "Mixed fused and per-expert Gemma 4 vLLM MoE LoRA tensors" ) art_key = _from_vllm_key(key) - transformed[art_key] = _rescale_shared_expert_fc1_lora_a( - art_key, - tensor, - adapter_config=adapter_config, - to_vllm=False, - ) + transformed[art_key] = tensor return transformed From 31a7954a522d8f70975c6da09b160b0ad0a1a2e7 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Sat, 27 Jun 2026 01:21:10 -0600 Subject: [PATCH 31/33] fix: align gemma4 shared expert lora conversion --- .../megatron/model_support/handlers/gemma4.py | 120 +++++++++++++++++- .../megatron/lora/test_lora_disk_codecs.py | 48 ++++++- 2 files changed, 159 insertions(+), 9 deletions(-) diff --git a/src/art/megatron/model_support/handlers/gemma4.py b/src/art/megatron/model_support/handlers/gemma4.py index e13650f48..b9e7a795a 100644 --- a/src/art/megatron/model_support/handlers/gemma4.py +++ b/src/art/megatron/model_support/handlers/gemma4.py @@ -62,6 +62,10 @@ r"(?P\.mlp)\.(?Pgate_proj|up_proj|down_proj)\." r"(?Plora_[AB])\.weight$" ) +_SHARED_EXPERT_FC1_LORA_A_KEY_RE = re.compile( + r"^.*\.layers\.(?P\d+)\.mlp\.(?:shared_experts\.)?" + r"(?:gate_proj|up_proj)\.lora_A\.weight$" +) _SELF_ATTN_V_LORA_KEY_RE = re.compile( r"^(?P.*\.layers\.(?P\d+)\.self_attn\.)v_proj\." r"(?Plora_[AB]\.weight)$" @@ -1108,7 +1112,7 @@ def _to_vllm_key(key: str) -> str: def _from_vllm_key(key: str) -> str: key = key.replace(".moe.experts", ".mlp.experts") return _DENSE_MLP_LORA_KEY_RE.sub( - r"\g.shared_expert.\g.\g.weight", + r"\g.shared_experts.\g.\g.weight", key, ) @@ -1160,6 +1164,85 @@ def _gemma4_k_eq_v_layers(adapter_config: dict[str, Any]) -> set[int]: } +def _gemma4_hf_file(base_model_name_or_path: str, filename: str) -> Path: + base_path = Path(base_model_name_or_path) + if base_path.exists(): + return base_path / filename + from huggingface_hub import hf_hub_download + + return Path( + hf_hub_download( + base_model_name_or_path, + filename, + local_files_only=True, + ) + ) + + +@lru_cache(maxsize=8) +def _gemma4_shared_expert_prenorm_corrections( + base_model_name_or_path: str, +) -> tuple[torch.Tensor, ...]: + from safetensors import safe_open + + index = json.loads( + _gemma4_hf_file( + base_model_name_or_path, + "model.safetensors.index.json", + ).read_text(encoding="utf-8") + ) + weight_map = dict(index["weight_map"]) + text_config = _gemma4_text_config_dict(base_model_name_or_path) + norm_keys_by_file: dict[str, list[tuple[int, str, str]]] = {} + for layer in range(int(text_config["num_hidden_layers"])): + for suffix in ( + "pre_feedforward_layernorm", + "pre_feedforward_layernorm_2", + ): + candidates = ( + f"model.language_model.layers.{layer}.{suffix}.weight", + f"model.layers.{layer}.{suffix}.weight", + ) + key = next(candidate for candidate in candidates if candidate in weight_map) + norm_keys_by_file.setdefault(weight_map[key], []).append( + (layer, suffix, key) + ) + + norm_weights: dict[tuple[int, str], torch.Tensor] = {} + for filename, entries in norm_keys_by_file.items(): + with safe_open( + _gemma4_hf_file(base_model_name_or_path, filename), + framework="pt", + device="cpu", + ) as handle: + for layer, suffix, key in entries: + norm_weights[(layer, suffix)] = handle.get_tensor(key).float() + + return tuple( + norm_weights[(layer, "pre_feedforward_layernorm")] + / norm_weights[(layer, "pre_feedforward_layernorm_2")] + for layer in range(int(text_config["num_hidden_layers"])) + ) + + +def _rescale_shared_expert_fc1_lora_a( + key: str, + tensor: torch.Tensor, + *, + adapter_config: dict[str, Any], + to_vllm: bool, +) -> torch.Tensor: + match = _SHARED_EXPERT_FC1_LORA_A_KEY_RE.match(key) + if match is None: + return tensor + corrections = _gemma4_shared_expert_prenorm_corrections( + str(adapter_config["base_model_name_or_path"]) + ) + correction = corrections[int(match.group("layer"))].to(device=tensor.device) + factor = correction.reciprocal() if to_vllm else correction + return (tensor.float() * factor.unsqueeze(0)).to(tensor.dtype).contiguous() + + def _drop_gemma4_k_eq_v_v_lora_tensors( tensors: dict[str, torch.Tensor], *, @@ -1229,7 +1312,12 @@ def _to_vllm_lora_tensors( grouped = _group_art_moe_tensors(tensors) if not grouped: transformed = { - vllm_key: tensor + vllm_key: _rescale_shared_expert_fc1_lora_a( + vllm_key, + tensor, + adapter_config=adapter_config, + to_vllm=True, + ) for key, tensor in tensors.items() for vllm_key in (_to_vllm_key(key),) } @@ -1292,7 +1380,12 @@ def _to_vllm_lora_tensors( raise RuntimeError( f"Duplicate Gemma 4 LoRA tensor after conversion: {vllm_key}" ) - transformed[vllm_key] = tensor + transformed[vllm_key] = _rescale_shared_expert_fc1_lora_a( + vllm_key, + tensor, + adapter_config=adapter_config, + to_vllm=True, + ) transformed = _add_gemma4_k_eq_v_v_lora_tensors( transformed, adapter_config=adapter_config, @@ -1336,7 +1429,12 @@ def _from_vllm_lora_tensors( if not grouped: return _drop_gemma4_k_eq_v_v_lora_tensors( { - art_key: tensor + art_key: _rescale_shared_expert_fc1_lora_a( + art_key, + tensor, + adapter_config=adapter_config, + to_vllm=False, + ) for key, tensor in tensors.items() for art_key in (_from_vllm_key(key),) }, @@ -1403,7 +1501,12 @@ def _from_vllm_lora_tensors( raise RuntimeError( f"Duplicate Gemma 4 LoRA tensor after conversion: {art_key}" ) - transformed[art_key] = tensor + transformed[art_key] = _rescale_shared_expert_fc1_lora_a( + art_key, + tensor, + adapter_config=adapter_config, + to_vllm=False, + ) return _drop_gemma4_k_eq_v_v_lora_tensors( transformed, adapter_config=adapter_config, @@ -1460,7 +1563,12 @@ def _from_vllm_per_expert_lora_tensors( "Mixed fused and per-expert Gemma 4 vLLM MoE LoRA tensors" ) art_key = _from_vllm_key(key) - transformed[art_key] = tensor + transformed[art_key] = _rescale_shared_expert_fc1_lora_a( + art_key, + tensor, + adapter_config=adapter_config, + to_vllm=False, + ) return transformed diff --git a/tests/integration/megatron/lora/test_lora_disk_codecs.py b/tests/integration/megatron/lora/test_lora_disk_codecs.py index 67006ae0d..8b8664d9e 100644 --- a/tests/integration/megatron/lora/test_lora_disk_codecs.py +++ b/tests/integration/megatron/lora/test_lora_disk_codecs.py @@ -710,9 +710,41 @@ def test_qwen35_vllm_config_preserves_shared_expert_targets_when_present(): _assert_tensors_equal(roundtrip, original) -def test_gemma4_shared_experts_plural_keys_map_to_vllm_dense_mlp(): +def test_gemma4_shared_experts_plural_keys_map_to_vllm_dense_mlp(tmp_path: Path): art_prefix = "base_model.model.model.layers.0" - hidden_size = 2816 + hidden_size = 3 + model_dir = tmp_path / "gemma4" + model_dir.mkdir() + (model_dir / "config.json").write_text( + json.dumps({"num_hidden_layers": 1}), + encoding="utf-8", + ) + save_file( + { + "model.layers.0.pre_feedforward_layernorm.weight": torch.tensor( + [2.0, 4.0, 8.0] + ), + "model.layers.0.pre_feedforward_layernorm_2.weight": torch.tensor( + [1.0, 2.0, 4.0] + ), + }, + model_dir / "model-00001-of-00001.safetensors", + ) + (model_dir / "model.safetensors.index.json").write_text( + json.dumps( + { + "weight_map": { + "model.layers.0.pre_feedforward_layernorm.weight": ( + "model-00001-of-00001.safetensors" + ), + "model.layers.0.pre_feedforward_layernorm_2.weight": ( + "model-00001-of-00001.safetensors" + ), + } + } + ), + encoding="utf-8", + ) original = { f"{art_prefix}.mlp.shared_experts.gate_proj.lora_A.weight": torch.ones( 2, @@ -730,9 +762,10 @@ def test_gemma4_shared_experts_plural_keys_map_to_vllm_dense_mlp(): 2, ), } + adapter_config = _config(str(model_dir)) vllm_tensors, _ = GEMMA4_MOE_HANDLER.to_vllm_lora_tensors( original, - adapter_config=_config("google/gemma-4-26B-A4B-it"), + adapter_config=adapter_config, ) assert set(vllm_tensors) == { @@ -744,6 +777,15 @@ def test_gemma4_shared_experts_plural_keys_map_to_vllm_dense_mlp(): f"{art_prefix}.mlp.down_proj.lora_B.weight", } assert not any("shared_expert" in key for key in vllm_tensors) + assert torch.equal( + vllm_tensors[f"{art_prefix}.mlp.gate_proj.lora_A.weight"], + torch.full((2, hidden_size), 0.5), + ) + roundtrip = GEMMA4_MOE_HANDLER.from_vllm_lora_tensors( + vllm_tensors, + adapter_config=adapter_config, + ) + _assert_tensors_equal(roundtrip, original) def test_qwen35_target_parameter_identity_normalizes_to_fused_vllm_layout( From e8d17f24bd783b5302bdb173957ab74227d10aaa Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Sat, 27 Jun 2026 02:52:52 -0600 Subject: [PATCH 32/33] test: expose gdn prefill backend for train-inf validation --- .../megatron/train_inf_mismatch/output_parity.py | 4 ++++ .../test_output_parity_invariants.py | 10 ++++++++++ 2 files changed, 14 insertions(+) diff --git a/tests/integration/megatron/train_inf_mismatch/output_parity.py b/tests/integration/megatron/train_inf_mismatch/output_parity.py index 92f6e7af8..dcebc9896 100644 --- a/tests/integration/megatron/train_inf_mismatch/output_parity.py +++ b/tests/integration/megatron/train_inf_mismatch/output_parity.py @@ -348,6 +348,10 @@ def config_from_env() -> TrainInfOutputParityConfig: "ART_TRAIN_INF_MISMATCH_VLLM_GPU_MEMORY_UTILIZATION" ): config.engine_args["gpu_memory_utilization"] = float(raw_vllm_memory) + if raw_gdn_backend := os.environ.get("ART_TRAIN_INF_MISMATCH_GDN_PREFILL_BACKEND"): + config.engine_args.setdefault("additional_config", {})[ + "gdn_prefill_backend" + ] = raw_gdn_backend return config diff --git a/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py b/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py index 9f53f6c6b..f891a19c9 100644 --- a/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py +++ b/tests/integration/megatron/train_inf_mismatch/test_output_parity_invariants.py @@ -270,6 +270,16 @@ def test_config_from_env_accepts_vllm_memory_utilization_override( assert config.engine_args["gpu_memory_utilization"] == 0.5 +def test_config_from_env_accepts_gdn_prefill_backend_override( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("ART_TRAIN_INF_MISMATCH_GDN_PREFILL_BACKEND", "triton") + + config = config_from_env() + + assert config.engine_args["additional_config"] == {"gdn_prefill_backend": "triton"} + + def test_default_rollout_modes_follow_model_support_native_lora_status() -> None: assert TrainInfOutputParityConfig( base_model="Qwen/Qwen3.5-35B-A3B" From 83266585f1b278fe2a36908e8c0765ca6ae61d25 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Sat, 27 Jun 2026 04:32:26 -0600 Subject: [PATCH 33/33] test: stabilize model support workflow probes --- .../megatron/model_support/test_workflow.py | 14 +- .../megatron/model_support/workflow.py | 23 +-- .../megatron/trainability/test_config.py | 136 +++++++++++++++++- .../test_live_length_trainability.py | 19 ++- .../test_live_yes_no_trainability.py | 12 +- .../trainability/yes_no_trainability.py | 44 +++++- 6 files changed, 219 insertions(+), 29 deletions(-) diff --git a/tests/integration/megatron/model_support/test_workflow.py b/tests/integration/megatron/model_support/test_workflow.py index 30f16774d..6640b36fd 100644 --- a/tests/integration/megatron/model_support/test_workflow.py +++ b/tests/integration/megatron/model_support/test_workflow.py @@ -10,6 +10,7 @@ from .validation_spec import ValidationReport, ValidationStageResult from .workflow import ( + INCLUDE_FLASH_SENSITIVITY_ENV, KEEP_TOPOLOGY_ARTIFACTS_ENV, MANDATORY_VALIDATION_STAGES, NATIVE_VLLM_LORA_STAGE, @@ -34,6 +35,7 @@ @pytest.fixture(autouse=True) def _stub_pinned_git_state(monkeypatch) -> None: + monkeypatch.delenv(INCLUDE_FLASH_SENSITIVITY_ENV, raising=False) monkeypatch.setattr( "tests.integration.megatron.model_support.workflow.pinned_git_state", lambda suite_name: SimpleNamespace( @@ -698,6 +700,9 @@ def test_run_correctness_sensitivity_stage_runs_dense_models(monkeypatch) -> Non assert result.metrics["correctness_variant_count"] == 1 assert result.metrics["correctness_excluded_topologies"] == [] assert result.metrics["sensitivity_mutations"] == ["skip_finalize"] + assert result.metrics["default_excluded_sensitivity_mutations"] == [ + "attn_skip_flash_lse_normalize" + ] assert case_configs[0].is_moe is False @@ -721,7 +726,10 @@ def test_run_yes_no_trainability_stage(monkeypatch) -> None: "saturated_step": 2, }, ) - ) + ), + yes_no_trainability_passed=lambda report: ( + report.final_eval_reward >= report.reward_threshold + ), ), ) @@ -1043,6 +1051,9 @@ def test_run_correctness_sensitivity_stage_summarizes_reports(monkeypatch) -> No assert stage.metrics["is_moe"] is True assert stage.metrics["objectives"] == ["sft"] assert stage.metrics["sensitivity_mutations"] == ["skip_finalize"] + assert stage.metrics["default_excluded_sensitivity_mutations"] == [ + "attn_skip_flash_lse_normalize" + ] assert stage.metrics["available_gpu_count"] == 2 assert stage.metrics["required_gpu_count"] == 1 assert stage.metrics["correctness_variant_count"] == 1 @@ -1109,6 +1120,7 @@ def test_run_correctness_sensitivity_stage_can_skip_sensitivity_only( assert stage.metrics["required_gpu_count"] == 1 assert stage.metrics["correctness_variant_count"] == 1 assert stage.metrics["sensitivity_mutations"] == [] + assert stage.metrics["default_excluded_sensitivity_mutations"] == [] assert stage.metrics["sensitivity_skipped"] is True assert stage.metrics["sensitivity_skip_reason"] == f"{SKIP_SENSITIVITY_ENV}=1" assert stage.metrics["sensitivity_variant_count"] == 0 diff --git a/tests/integration/megatron/model_support/workflow.py b/tests/integration/megatron/model_support/workflow.py index 9e1d7e1a2..3029278bc 100644 --- a/tests/integration/megatron/model_support/workflow.py +++ b/tests/integration/megatron/model_support/workflow.py @@ -37,8 +37,10 @@ LIVE_TRAINING_LOG_PATH = LOCAL_LOG_DIR / "live_training.log" ORACLE_LIVE_TRAINING_LOG_ENV = "ART_ORACLE_LIVE_TRAINING_LOG" SKIP_SENSITIVITY_ENV = "ART_MODEL_SUPPORT_SKIP_SENSITIVITY" +INCLUDE_FLASH_SENSITIVITY_ENV = "ART_MODEL_SUPPORT_INCLUDE_FLASH_SENSITIVITY" KEEP_TOPOLOGY_ARTIFACTS_ENV = "ART_ORACLE_KEEP_TOPOLOGY_ARTIFACTS" WORKFLOW_ARTIFACT_SUITE_NAME = "Megatron model-support validation workflow" +FLASH_SENSITIVITY_MUTATION = "attn_skip_flash_lse_normalize" MANDATORY_VALIDATION_STAGES = ( "dependency_resolution", @@ -490,6 +492,7 @@ def run_correctness_sensitivity_stage( if topology.world_size() > max_world_size ] mutations: list[str] = [] + default_excluded_sensitivity_mutations: list[str] = [] excluded_sensitivity_mutations: list[str] = [] if not skip_sensitivity: for objective in objectives: @@ -510,10 +513,16 @@ def run_correctness_sensitivity_stage( ).world_size() > max_world_size ] + if not _truthy_env(INCLUDE_FLASH_SENSITIVITY_ENV): + default_excluded_sensitivity_mutations.append(FLASH_SENSITIVITY_MUTATION) mutations = [ mutation for mutation in mutations - if mutation not in excluded_sensitivity_mutations + if mutation + not in { + *excluded_sensitivity_mutations, + *default_excluded_sensitivity_mutations, + } ] LIVE_TRAINING_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) LIVE_TRAINING_LOG_PATH.write_text("", encoding="utf-8") @@ -560,6 +569,9 @@ def run_correctness_sensitivity_stage( "objectives": objectives, "sensitivity_mutations": mutations, "excluded_sensitivity_mutations": excluded_sensitivity_mutations, + "default_excluded_sensitivity_mutations": ( + default_excluded_sensitivity_mutations + ), "available_gpu_count": available_gpu_count, "max_world_size": max_world_size, "required_gpu_count": oracle_world_size, @@ -669,14 +681,7 @@ def run_yes_no_trainability_stage( base_model=base_model, allow_unvalidated_arch=allow_unvalidated_arch, ) - passed = ( - report.saturated_step is not None - and report.saturated_step > 0 - and report.initial_eval_reward < report.reward_threshold - and report.final_eval_reward is not None - and report.final_eval_reward >= report.reward_threshold - and report.final_eval_reward > report.initial_eval_reward - ) + passed = yes_no_trainability.yes_no_trainability_passed(report) return ValidationStageResult( name=YES_NO_TRAINABILITY_STAGE, passed=passed, diff --git a/tests/integration/megatron/trainability/test_config.py b/tests/integration/megatron/trainability/test_config.py index a055461e1..31a9b4619 100644 --- a/tests/integration/megatron/trainability/test_config.py +++ b/tests/integration/megatron/trainability/test_config.py @@ -7,8 +7,15 @@ import art -from .test_live_length_trainability import _default_learning_rate +from .test_live_length_trainability import ( + LengthSampleReport, + LengthTrainabilityReport, + _default_learning_rate, + length_trainability_passed, +) from .yes_no_trainability import ( + TrainabilityStepReport, + YesNoTrainabilityReport, _build_internal_config, _build_variant, _default_variant_name, @@ -18,6 +25,7 @@ _variant_max_steps, _variant_packed_sequence_length, _variant_rollouts_per_prompt, + yes_no_trainability_passed, ) @@ -154,11 +162,137 @@ def test_qwen3_5_defaults_to_shared_lora_rollout() -> None: assert "inference_gpu_ids" not in config +def test_dense_yes_no_default_uses_dedicated_placement(monkeypatch) -> None: + monkeypatch.delenv("ART_MODEL_SUPPORT_YES_NO_VARIANT", raising=False) + + assert _default_variant_name("Qwen/Qwen3-32B") == "megatron_dedicated" + + +def test_yes_no_default_variant_env_override(monkeypatch) -> None: + monkeypatch.setenv("ART_MODEL_SUPPORT_YES_NO_VARIANT", "megatron_shared") + + assert _default_variant_name("Qwen/Qwen3-32B") == "megatron_shared" + + +def test_yes_no_trainability_passes_initially_saturated_stable_report() -> None: + report = YesNoTrainabilityReport( + variant="megatron_shared", + backend_name="megatron", + placement_mode="shared", + base_model="google/gemma-4-31B-it", + output_dir="/tmp/report", + trainer_gpu_ids=[0, 1], + inference_gpu_ids=[0, 1], + rollout_weights_mode="lora", + reward_threshold=0.9, + max_steps=4, + prompt_count=8, + eval_prompt_count=8, + rollouts_per_prompt=4, + latest_step=1, + initial_eval_reward=0.9375, + final_eval_reward=0.9375, + saturated_step=1, + step0_name="model@0", + latest_name="model@1", + steps=[ + TrainabilityStepReport( + step=1, + eval_reward=0.9375, + train_reward=0.875, + train_metrics={"grad_norm": 54.0}, + ) + ], + ) + + assert yes_no_trainability_passed(report) is True + + def test_qwen3_5_length_trainability_uses_stable_learning_rate() -> None: assert _default_learning_rate("Qwen/Qwen3.5-35B-A3B") == 7e-5 assert _default_learning_rate("Qwen/Qwen3-30B-A3B-Instruct-2507") == 1e-4 +def test_length_trainability_accepts_near_baseline_learning_signal() -> None: + report = LengthTrainabilityReport( + base_model="google/gemma-4-31B-it", + max_steps=10, + max_steps_off_policy=0, + latest_step=3, + variant_name="megatron_dedicated", + trainer_gpu_ids=[0], + inference_gpu_ids=[1], + training_topology={"tp": 1, "cp": 1, "ep": 1, "etp": 1, "dp": 1, "sp": False}, + rollout_weights_mode="lora", + rollouts_per_prompt=4, + normalize_advantages=True, + summary_log_path="/tmp/length_trainability.log", + latest_summary_log_path="/tmp/latest_length_trainability.log", + initial_train_abs_error=3.875, + best_train_abs_error=0.5, + success_step=3, + final_train_reward=-0.05, + final_train_abs_error=0.5, + model_ids_after=["length@0", "length@3"], + samples=[ + LengthSampleReport( + split="train", + step=0, + scenario_index=0, + target_step=0, + target_tokens=10, + max_tokens=142, + prompt_word_count=300, + generated_tokens=14, + abs_error=4, + reward=-0.4, + text="a short answer", + ), + LengthSampleReport( + split="train", + step=0, + scenario_index=1, + target_step=0, + target_tokens=10, + max_tokens=142, + prompt_word_count=300, + generated_tokens=6, + abs_error=4, + reward=-0.4, + text="brief", + ), + LengthSampleReport( + split="train", + step=3, + scenario_index=2, + target_step=3, + target_tokens=10, + max_tokens=142, + prompt_word_count=300, + generated_tokens=10, + abs_error=0, + reward=0.0, + text="a target length answer", + ), + LengthSampleReport( + split="train", + step=3, + scenario_index=3, + target_step=3, + target_tokens=10, + max_tokens=142, + prompt_word_count=300, + generated_tokens=11, + abs_error=1, + reward=-0.1, + text="a slightly long answer", + ), + ], + ) + + assert length_trainability_passed(report) is True + + def test_validated_dense_model_uses_dense_shared_topology( monkeypatch, ) -> None: diff --git a/tests/integration/megatron/trainability/test_live_length_trainability.py b/tests/integration/megatron/trainability/test_live_length_trainability.py index ace4a19c1..fb03a0620 100644 --- a/tests/integration/megatron/trainability/test_live_length_trainability.py +++ b/tests/integration/megatron/trainability/test_live_length_trainability.py @@ -617,7 +617,8 @@ async def rollout_fn( success_step = next( ( step - for step, abs_error in train_abs_error_by_step.items() + for step in sorted(train_abs_error_by_step) + for abs_error in (train_abs_error_by_step[step],) if abs_error <= SUCCESS_ABS_ERROR_MAX ), None, @@ -682,11 +683,19 @@ def length_trainability_passed(report: LengthTrainabilityReport) -> bool: step: [sample.reward for sample in train_samples if sample.step == step] for step in {sample.step for sample in train_samples} } + started_far_enough = ( + report.initial_train_abs_error is not None + and report.initial_train_abs_error >= INITIAL_ABS_ERROR_MIN + ) + learned_after_near_baseline = ( + report.initial_train_abs_error is not None + and report.success_step is not None + and report.success_step > 0 + ) return ( bool(train_samples) and report.latest_step <= report.max_steps - and report.initial_train_abs_error is not None - and report.initial_train_abs_error >= INITIAL_ABS_ERROR_MIN + and (started_far_enough or learned_after_near_baseline) and report.best_train_abs_error is not None and report.best_train_abs_error <= SUCCESS_ABS_ERROR_MAX and report.success_step is not None @@ -709,7 +718,9 @@ def assert_length_trainability_passed(report: LengthTrainabilityReport) -> None: assert train_samples assert report.latest_step <= report.max_steps assert report.initial_train_abs_error is not None - assert report.initial_train_abs_error >= INITIAL_ABS_ERROR_MIN + assert report.initial_train_abs_error >= INITIAL_ABS_ERROR_MIN or ( + report.success_step is not None and report.success_step > 0 + ) assert report.best_train_abs_error is not None assert report.best_train_abs_error <= SUCCESS_ABS_ERROR_MAX assert report.success_step is not None diff --git a/tests/integration/megatron/trainability/test_live_yes_no_trainability.py b/tests/integration/megatron/trainability/test_live_yes_no_trainability.py index 119d3b74a..a12353752 100644 --- a/tests/integration/megatron/trainability/test_live_yes_no_trainability.py +++ b/tests/integration/megatron/trainability/test_live_yes_no_trainability.py @@ -4,7 +4,10 @@ import pytest -from .yes_no_trainability import run_yes_no_trainability_async +from .yes_no_trainability import ( + run_yes_no_trainability_async, + yes_no_trainability_passed, +) torch = pytest.importorskip("torch") @@ -29,12 +32,7 @@ def _unsloth_base_model() -> str: def _assert_passed(report) -> None: - assert report.saturated_step is not None - assert report.saturated_step > 0 - assert report.initial_eval_reward < report.reward_threshold - assert report.final_eval_reward is not None - assert report.final_eval_reward >= report.reward_threshold - assert report.final_eval_reward > report.initial_eval_reward + assert yes_no_trainability_passed(report) assert report.latest_step > 0 assert report.step0_name in report.model_ids_before assert report.latest_name in report.model_ids_after diff --git a/tests/integration/megatron/trainability/yes_no_trainability.py b/tests/integration/megatron/trainability/yes_no_trainability.py index a3f7918ae..29d544dcc 100644 --- a/tests/integration/megatron/trainability/yes_no_trainability.py +++ b/tests/integration/megatron/trainability/yes_no_trainability.py @@ -30,6 +30,7 @@ _TRAINER_GPU_IDS_ENV = "ART_MODEL_SUPPORT_TRAINER_GPU_IDS" _INFERENCE_GPU_IDS_ENV = "ART_MODEL_SUPPORT_INFERENCE_GPU_IDS" _SHARED_GPU_IDS_ENV = "ART_MODEL_SUPPORT_SHARED_GPU_IDS" +_VARIANT_ENV = "ART_MODEL_SUPPORT_YES_NO_VARIANT" _TRAINABILITY_ROOT = ( Path(__file__).resolve().parents[4] / ".local" / "model_support_validation" ) @@ -420,13 +421,22 @@ def _default_variant_name( *, allow_unvalidated_arch: bool = False, ) -> _VARIANT_NAME: - if ( - _rollout_weights_mode( - base_model, - allow_unvalidated_arch=allow_unvalidated_arch, - ) - == "merged" - ): + if override := os.environ.get(_VARIANT_ENV, "").strip(): + if override not in {"megatron_shared", "megatron_dedicated"}: + raise ValueError( + f"Unsupported {_VARIANT_ENV}={override!r}. " + "Expected 'megatron_shared' or 'megatron_dedicated'." + ) + return cast(_VARIANT_NAME, override) + is_moe = model_uses_expert_parallel( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + rollout_weights_mode = _rollout_weights_mode( + base_model, + allow_unvalidated_arch=allow_unvalidated_arch, + ) + if rollout_weights_mode == "merged" or not is_moe: return "megatron_dedicated" return "megatron_shared" @@ -824,6 +834,26 @@ def run_yes_no_trainability( ) +def yes_no_trainability_passed(report: YesNoTrainabilityReport) -> bool: + learned_from_below_threshold = ( + report.saturated_step is not None + and report.saturated_step > 0 + and report.initial_eval_reward < report.reward_threshold + and report.final_eval_reward is not None + and report.final_eval_reward >= report.reward_threshold + and report.final_eval_reward > report.initial_eval_reward + ) + already_saturated_and_stable = ( + report.initial_eval_reward >= report.reward_threshold + and report.latest_step > 0 + and report.final_eval_reward is not None + and report.final_eval_reward >= report.reward_threshold + and bool(report.steps) + and any(step.train_metrics.get("grad_norm", 0.0) > 0.0 for step in report.steps) + ) + return learned_from_below_threshold or already_saturated_and_stable + + def run_megatron_dedicated_yes_no_trainability( base_model: str, *,