Skip to content

Port sharded linear primitives + DistributedGroup from upstream #371#3

Open
anupsv wants to merge 1 commit into
feat/cmlx-jaccl-distributedfrom
feat/sharded-linear-primitives
Open

Port sharded linear primitives + DistributedGroup from upstream #371#3
anupsv wants to merge 1 commit into
feat/cmlx-jaccl-distributedfrom
feat/sharded-linear-primitives

Conversation

@anupsv
Copy link
Copy Markdown

@anupsv anupsv commented May 21, 2026

Summary

Vendors two files from ml-explore/mlx-swift#371 ("Add distributed communication framework for multi-device tensor parallelism") to enable tensor-parallel inference on this fork ahead of the upstream merge.

Stacked on #2 (Cmlx product + jaccl backend) — base is feat/cmlx-jaccl-distributed, not main. Merge #2 first; then this PR can rebase onto main.

What's in this PR

File Purpose
Source/MLX/Distributed.swift DistributedGroup class wrapping mlx_distributed_group_t with rank, size, split, send/recv, recvLike, allSum, allGather, allMax, allMin, sumScatter. DistributedBackend enum (.any, .mpi, .ring, .jaccl). DistributedError with structured failure modes. Throw-style API on top of the C mlx_distributed_* symbols exposed via Cmlx (depends on #2).
Source/MLXNN/Distributed.swift AllToShardedLinear (column-parallel — output sharded, input replicated, no collective in forward). ShardedToAllLinear (row-parallel — input sharded, allSum in forward). QuantizedAllToShardedLinear / QuantizedShardedToAllLinear (4-bit variants — critical since most production MLX checkpoints are 4-bit). shardLinear / shardInPlace for slicing existing weights into rank-local shards. averageGradients with batched all-reduce + cast-on-wire bandwidth reduction. sumGradients identity-forward / allSum-backward VJP helper.

Net: 2 files, +1400 lines, additive. Both files come from PicoMLX/mlx-swift @ ba68dfa unmodified (upstream PR head at port time). No behavior change for existing MLX, MLXNN, MLXRandom, MLXFast, MLXOptimizers, MLXFFT, MLXLinalg consumers — the MLX and MLXNN targets auto-discover sources by directory.

Why we need it before upstream lands

  • d-inference's cluster work is moving from pipeline-parallel inference (one rank runs layers 0..N/2, sends activations) to tensor-parallel inference (both ranks run all layers in parallel, exchanging allreduces per layer) — TP is roughly 2× lower decode latency on a 2-Mac Thunderbolt 5 cluster.
  • The sharded linear primitives are the building blocks: TP-aware LlamaModelTP will be added in a follow-up on Layr-Labs/mlx-swift-lm, and the dispatcher (TP default, PP fallback) lands in Layr-Labs/d-inference.
  • Upstream Add distributed communication framework for multi-device tensor parallelism ml-explore/mlx-swift#371 has been open since 2026-03-15 with no clear merge ETA. We track it via d-inference#193; when it lands, we rebase the fork onto upstream and retire any local divergence.

Test plan

  • swift build succeeds on macOS 14+ (Apple Silicon) with the Cmlx + jaccl backend from Expose Cmlx product + enable jaccl distributed backend #2
  • No symbol/link errors against the existing mlx_distributed_* C API
  • No regression for existing MLX/MLXNN consumers — files are pure additions, targets auto-discover sources
  • Numerical equivalence via the LlamaModelTP follow-up PR on mlx-swift-lm (the test harness for sharded layers lives there, not here)
  • Two-Mac Thunderbolt 5 smoke test in d-inference once all three PRs land

Risks worth naming

  • Linker dep on Cmlx symbols. Without #2's un-exclusion of mlx-c/mlx/c/distributed*.cpp and the surfaced headers, this branch won't link. Hence the stack ordering.
  • Sendable annotations. DistributedGroup is declared @unchecked Sendable because the underlying C handle is thread-safe per jaccl/ring documentation but the Swift compiler can't prove it.
  • Quantized* variants. 4-bit packed weights need in_dim divisible by 8 * world_size = 16 for row-parallel layers. All standard Llama / Mistral / Qwen hidden dims (4096, 8192) clear this.

Upstream-bug awareness (from d-inference#193)

Same caveats as #2ml-explore/mlx#3149 (varying-shape p2p hang), #3467 (RTR GID regression), #3442 (backend="any" picks ring). The TP forward path uses fixed-shape allSum so #3149's failure mode shouldn't apply, but worth verifying empirically.

…plore#371

Vendors two files from ml-explore#371 ("Add distributed
communication framework for multi-device tensor parallelism") to
enable tensor-parallel inference on the Layr-Labs fork ahead of the
upstream merge:

Source/MLX/Distributed.swift
- DistributedGroup class wrapping mlx_distributed_group_t with rank,
  size, split, send/recv, recvLike, allSum, allGather, allMax, allMin,
  sumScatter
- DistributedBackend enum (.any, .mpi, .ring, .jaccl)
- DistributedError type with structured failure modes
- Throw-style API on top of the existing C mlx_distributed_* symbols
  exposed via Cmlx (depends on the earlier #2 PR that enables the
  jaccl backend and surfaces the distributed-group headers)

Source/MLXNN/Distributed.swift
- AllToShardedLinear / ShardedToAllLinear (column-parallel and
  row-parallel linear layers)
- QuantizedAllToShardedLinear / QuantizedShardedToAllLinear (4-bit
  variants — critical since most production MLX checkpoints are
  4-bit quantized)
- shardLinear / shardInPlace utilities for slicing existing weights
  into rank-local shards
- averageGradients helper with batched allReduce, cast-on-wire
  communicationType for bandwidth reduction, and mixed-dtype fallback
- sumGradients identity-forward / allSum-backward VJP helper

Both files come from PicoMLX/mlx-swift @ ba68dfa unmodified (upstream
PR head at port time).

No behavior change for existing MLX, MLXNN, MLXRandom, MLXFast,
MLXOptimizers, MLXFFT, MLXLinalg consumers — the new files are pure
additions and the targets auto-discover sources by directory.

Stacked on top of #2 (Cmlx product + jaccl backend); merge that first.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant