Port sharded linear primitives + DistributedGroup from upstream #371#3
Open
anupsv wants to merge 1 commit into
Open
Port sharded linear primitives + DistributedGroup from upstream #371#3anupsv wants to merge 1 commit into
anupsv wants to merge 1 commit into
Conversation
…plore#371 Vendors two files from ml-explore#371 ("Add distributed communication framework for multi-device tensor parallelism") to enable tensor-parallel inference on the Layr-Labs fork ahead of the upstream merge: Source/MLX/Distributed.swift - DistributedGroup class wrapping mlx_distributed_group_t with rank, size, split, send/recv, recvLike, allSum, allGather, allMax, allMin, sumScatter - DistributedBackend enum (.any, .mpi, .ring, .jaccl) - DistributedError type with structured failure modes - Throw-style API on top of the existing C mlx_distributed_* symbols exposed via Cmlx (depends on the earlier #2 PR that enables the jaccl backend and surfaces the distributed-group headers) Source/MLXNN/Distributed.swift - AllToShardedLinear / ShardedToAllLinear (column-parallel and row-parallel linear layers) - QuantizedAllToShardedLinear / QuantizedShardedToAllLinear (4-bit variants — critical since most production MLX checkpoints are 4-bit quantized) - shardLinear / shardInPlace utilities for slicing existing weights into rank-local shards - averageGradients helper with batched allReduce, cast-on-wire communicationType for bandwidth reduction, and mixed-dtype fallback - sumGradients identity-forward / allSum-backward VJP helper Both files come from PicoMLX/mlx-swift @ ba68dfa unmodified (upstream PR head at port time). No behavior change for existing MLX, MLXNN, MLXRandom, MLXFast, MLXOptimizers, MLXFFT, MLXLinalg consumers — the new files are pure additions and the targets auto-discover sources by directory. Stacked on top of #2 (Cmlx product + jaccl backend); merge that first.
This was referenced May 21, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Vendors two files from ml-explore/mlx-swift#371 ("Add distributed communication framework for multi-device tensor parallelism") to enable tensor-parallel inference on this fork ahead of the upstream merge.
Stacked on #2 (Cmlx product + jaccl backend) — base is
feat/cmlx-jaccl-distributed, notmain. Merge #2 first; then this PR can rebase ontomain.What's in this PR
Source/MLX/Distributed.swiftDistributedGroupclass wrappingmlx_distributed_group_twithrank,size,split,send/recv,recvLike,allSum,allGather,allMax,allMin,sumScatter.DistributedBackendenum (.any,.mpi,.ring,.jaccl).DistributedErrorwith structured failure modes. Throw-style API on top of the Cmlx_distributed_*symbols exposed viaCmlx(depends on #2).Source/MLXNN/Distributed.swiftAllToShardedLinear(column-parallel — output sharded, input replicated, no collective in forward).ShardedToAllLinear(row-parallel — input sharded,allSumin forward).QuantizedAllToShardedLinear/QuantizedShardedToAllLinear(4-bit variants — critical since most production MLX checkpoints are 4-bit).shardLinear/shardInPlacefor slicing existing weights into rank-local shards.averageGradientswith batched all-reduce + cast-on-wire bandwidth reduction.sumGradientsidentity-forward / allSum-backward VJP helper.Net: 2 files, +1400 lines, additive. Both files come from PicoMLX/mlx-swift @ ba68dfa unmodified (upstream PR head at port time). No behavior change for existing
MLX,MLXNN,MLXRandom,MLXFast,MLXOptimizers,MLXFFT,MLXLinalgconsumers — theMLXandMLXNNtargets auto-discover sources by directory.Why we need it before upstream lands
LlamaModelTPwill be added in a follow-up onLayr-Labs/mlx-swift-lm, and the dispatcher (TP default, PP fallback) lands inLayr-Labs/d-inference.Test plan
swift buildsucceeds on macOS 14+ (Apple Silicon) with the Cmlx + jaccl backend from Expose Cmlx product + enable jaccl distributed backend #2mlx_distributed_*C APIRisks worth naming
mlx-c/mlx/c/distributed*.cppand the surfaced headers, this branch won't link. Hence the stack ordering.Sendableannotations.DistributedGroupis declared@unchecked Sendablebecause the underlying C handle is thread-safe per jaccl/ring documentation but the Swift compiler can't prove it.Quantized*variants. 4-bit packed weights needin_dimdivisible by8 * world_size = 16for row-parallel layers. All standard Llama / Mistral / Qwen hidden dims (4096, 8192) clear this.Upstream-bug awareness (from d-inference#193)
Same caveats as #2 —
ml-explore/mlx#3149(varying-shape p2p hang),#3467(RTR GID regression),#3442(backend="any" picks ring). The TP forward path uses fixed-shapeallSumso #3149's failure mode shouldn't apply, but worth verifying empirically.