Add LlamaModelTP: tensor-parallel variant of LlamaModel #25
Open
anupsv wants to merge 3 commits into
LlamaModelTP mirrors LlamaModel but uses sharded linear layers from mlx-swift's AllToShardedLinear / ShardedToAllLinear so two (or more) ranks can split each layer's compute and weight memory:

- Q/K/V projections + gate/up are AllToShardedLinear (column-parallel): each rank holds outDim/world_size rows; output is sharded across heads. No collective op in forward.
- O projection + down are ShardedToAllLinear (row-parallel): each rank holds inDim/world_size cols; allSum in forward to combine partial contributions across ranks.
- Embedding, layernorms, and the (optional, non-tied) LM head stay replicated; LM head sharding is a future optimization.

Per-rank: attentionHeads/world_size local Q heads and kvHeads/world_size local KV heads. The KV cache is sized to the local KV head count so each rank only stores its shard.

Weight loading: LlamaModelTP.sanitize(weights:) slices column-parallel weights along axis 0 (output dim) and row-parallel weights along axis 1 (input dim) into each rank's shard before module assignment. Singleton group (world=1) is a pass-through. Quantized weights (.scales/.biases) are passed through unmodified; quantized TP support is a follow-up that needs packed-uint32-aware slicing.

The exposed shardWeightIfNeeded static helper centralizes the key-pattern → axis decision so tests can validate the sharding logic without spinning up a real distributed group.

Tests/MLXLMTests/LlamaTPTests.swift covers:

- Parameter shape parity between LlamaModel and LlamaModelTP(world=1)
- Forward-pass output shape with singleton group
- Numerical equivalence: with identical weights, LlamaModelTP(world=1) matches LlamaModel logits to within 1e-3 (float-accumulation noise)
- Static shardWeightIfNeeded behavior for column-parallel, row-parallel, embedding, and layernorm key patterns

Stacked on feat/llama-pipeline-parallel (PR #24) for callPartial; that remains the fallback path for non-Llama models.
Requires Cmlx + the sharded linear primitives from Layr-Labs/mlx-swift#3.
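To make the axis convention concrete, here is a minimal sketch of the key-pattern → axis decision and the slicing it drives, written on plain Swift arrays rather than MLX arrays. `ShardAxis`, `shardAxis(forKey:)`, and `shard(_:axis:rank:world:)` are hypothetical names chosen for illustration, and the key suffixes assume the usual Llama checkpoint layout; the PR's `shardWeightIfNeeded` may be structured differently.

```swift
import Foundation

// Hypothetical sketch: names and signatures are illustrative, not the PR's API.
enum ShardAxis { case rows, cols }  // rows = axis 0 (output dim), cols = axis 1 (input dim)

/// Decide which axis (if any) a weight should be sliced along, based on its key.
func shardAxis(forKey key: String) -> ShardAxis? {
    // Assumed key names, following the common Llama checkpoint layout.
    let columnParallel = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]
    let rowParallel = ["o_proj", "down_proj"]
    if columnParallel.contains(where: { key.hasSuffix(".\($0).weight") }) { return .rows }
    if rowParallel.contains(where: { key.hasSuffix(".\($0).weight") }) { return .cols }
    return nil  // embedding, layernorms, LM head: replicated on every rank
}

/// Slice a row-major [outDim][inDim] matrix into the shard owned by `rank`.
func shard(_ w: [[Float]], axis: ShardAxis?, rank: Int, world: Int) -> [[Float]] {
    // Replicated weights and singleton groups pass through untouched.
    guard let axis = axis, world > 1 else { return w }
    switch axis {
    case .rows:
        let n = w.count / world
        return Array(w[(rank * n)..<((rank + 1) * n)])
    case .cols:
        let n = w[0].count / world
        return w.map { Array($0[(rank * n)..<((rank + 1) * n)]) }
    }
}
```

For a 4×4 weight and `world = 2`, rank 1's shard under a `q_proj` key is the last two rows, while under an `o_proj` key it is the last two columns of every row. A key such as `model.norm.weight` yields `nil` and is left whole.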
Critical: stock MLX Llama checkpoints are 4-bit quantized; the original LlamaModelTP only worked on fp16/bf16 weights because AllToShardedLinear and QuantizedAllToShardedLinear aren't in an inheritance hierarchy in mlx-swift, so you can't swap one for the other via dynamic dispatch the way MLXNN.quantize() swaps Linear for QuantizedLinear.

This commit adds a parallel quantized class hierarchy:

- LlamaAttentionTPQ, LlamaMLPTPQ, LlamaTransformerBlockTPQ, LlamaModelInnerTPQ, LlamaModelTPQ
- Uses QuantizedAllToShardedLinear / QuantizedShardedToAllLinear for the Q/K/V/gate/up and O/down projections respectively
- LM head stays replicated + quantized (QuantizedLinear): same memory trade-off as the fp16 path, just with quantized weights
- sanitize() / shardQuantizedWeightIfNeeded handle .weight (uint32 packed shape [outDim, inDim/8]) plus .scales / .biases (fp16 shape [outDim, inDim/groupSize]) along the right axes. Row-parallel biases pass through unsliced (added once after the allSum).
- A `makeLlamaTP(args:quantization:group:)` factory picks the right variant from a BaseConfiguration.Quantization value.

Also tightens shardWeightIfNeeded on the fp16 LlamaModelTP:

- "norm" substring check replaced with explicit isLayerNormKey() that matches input_layernorm / post_attention_layernorm / model.norm only, so future TP variants with q_norm / k_norm per-head norms (Qwen3, Gemma) won't accidentally skip sharding for those.
- Projection name matches use path-segment boundaries via isProjectionKey (e.g. ".q_proj.weight" not "q_proj.weight") so a "q_proj_extra" never matches.
- Loading a quantized blob into LlamaModelTP now precondition-fails with a clear message pointing at LlamaModelTPQ: silent corruption replaced with a loud routing error.
- Row-parallel biases are now also passed through unsliced in the fp16 path (was an open TODO comment before).
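The two tightened key checks can be illustrated with a small sketch. This is a hypothetical re-creation based only on the description above (the function bodies and the `name:` parameter are assumptions, not the PR's code); it shows why matching whole path segments avoids the false positives that substring checks allow.

```swift
import Foundation

// Hypothetical re-creation of the two predicates described in the commit message.

/// True only for the replicated norm weights: input_layernorm,
/// post_attention_layernorm, and the final model.norm.
func isLayerNormKey(_ key: String) -> Bool {
    let segments = key.split(separator: ".").map(String.init)
    if segments.contains("input_layernorm") || segments.contains("post_attention_layernorm") {
        return true
    }
    // "model.norm" must appear as two adjacent path segments.
    return zip(segments, segments.dropFirst()).contains { pair in
        pair.0 == "model" && pair.1 == "norm"
    }
}

/// True when `name` is a whole path segment of `key`,
/// so "q_proj_extra" never matches "q_proj".
func isProjectionKey(_ key: String, name: String) -> Bool {
    key.split(separator: ".").map(String.init).contains(name)
}
```

With these definitions, a per-head norm key such as `model.layers.0.self_attn.q_norm.weight` is not treated as a layernorm, whereas a plain `"norm"` substring check would have matched it.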
High: the previous numerical-equivalence test on a singleton group (world=1) didn't exercise any actual sharding, because sharded layers degenerate to regular Linear. To actually validate the sharding math without needing a multi-process distributed backend, this commit adds three matrix-level tests that:

- Slice weights with the SAME shardWeightIfNeeded function the production sanitize path uses
- Manually run the per-rank matmuls and (for row-parallel) simulate the allSum by summing partial outputs
- Compare against the unsharded reference to within 1e-4

Covered patterns: column-parallel (concat), row-parallel (allSum), column-then-row pipeline (the typical attention / MLP shape).

Real multi-rank DistributedGroup tests still need spawned subprocesses (jaccl or ring backend); the 2-Mac Thunderbolt 5 smoke test in d-inference remains the final validation. But the in-process math checks catch axis confusion and slicing bugs that would have shipped silently before.

Plus 5 unit tests for the quantized weight slicing helper and a small test that q_norm / k_norm aren't tripped by the layernorm gate.

Net: 18 tests pass, up from 8. No regression to existing LlamaModel / LlamaModelInner consumers; LlamaModelTP behavior unchanged for fp16 checkpoints (only the diagnostic precondition is new). LlamaModelTPQ is additive.
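The column-then-row check described above can be sketched in a few lines of plain Swift. This is an illustration of the idea only: the PR's tests operate on MLX arrays and slice with `shardWeightIfNeeded`, whereas the helpers here (`matVec`, `rowSlice`, `colSlice`, `shardedPipeline`) are hypothetical stand-ins, and the nonlinearity between the two projections is omitted for brevity.

```swift
// Plain-array sketch of the in-process sharding check; all names are hypothetical.
typealias Matrix = [[Float]]

/// y = W x for a row-major W of shape [outDim][inDim].
func matVec(_ w: Matrix, _ x: [Float]) -> [Float] {
    w.map { row in zip(row, x).reduce(Float(0)) { acc, pair in acc + pair.0 * pair.1 } }
}

/// Column-parallel shard: a contiguous block of output rows (axis 0).
func rowSlice(_ w: Matrix, rank: Int, world: Int) -> Matrix {
    let n = w.count / world
    return Array(w[(rank * n)..<((rank + 1) * n)])
}

/// Row-parallel shard: a contiguous block of input columns (axis 1).
func colSlice(_ w: Matrix, rank: Int, world: Int) -> Matrix {
    let n = w[0].count / world
    return w.map { Array($0[(rank * n)..<((rank + 1) * n)]) }
}

/// Column-parallel `up` followed by row-parallel `down`, with the allSum
/// simulated by accumulating every rank's partial output.
func shardedPipeline(up: Matrix, down: Matrix, x: [Float], world: Int) -> [Float] {
    var total = [Float](repeating: 0, count: down.count)
    for rank in 0..<world {
        // Each rank sees the full input and produces its slice of the hidden vector.
        let hidden = matVec(rowSlice(up, rank: rank, world: world), x)
        // The matching column slice of `down` consumes exactly that hidden slice.
        let partial = matVec(colSlice(down, rank: rank, world: world), hidden)
        total = zip(total, partial).map { $0 + $1 }
    }
    return total
}
```

Comparing `shardedPipeline(up:down:x:world:)` against `matVec(down, matVec(up, x))` for `world = 2` exercises the same axis choices a real two-rank run depends on; swapping `rowSlice` and `colSlice` makes the comparison fail (or the shapes mismatch), which is the class of bug these tests are meant to catch.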
anupsv added a commit to Layr-Labs/d-inference that referenced this pull request on May 21, 2026
Picks up the quantized TP variant + matrix-level sharding tests on Layr-Labs/mlx-swift-lm#25 (commit 56a5ca6). This makes makeLlamaTP() available for routing the dispatcher to the right TP variant when a 4-bit / 8-bit quantized checkpoint is loaded.
LoRA: LLMModel protocol requires LoRAModel conformance, but TP+LoRA isn't actually supported. MLXLLM's LoRA training code walks loraLayers looking for Linear sub-modules to wrap with adapters; our sharded variants (AllToShardedLinear, QuantizedAllToShardedLinear, etc.) are NOT Linear subclasses, so adapter insertion either no-ops silently or hits surprising shape errors at train time. Both LlamaModelTP and LlamaModelTPQ now return `[]` from loraLayers with a docstring pointing at LlamaModel (single-rank) for LoRA workflows. Sharded LoRA is its own research problem.

Test: testLlamaTPRejectsNonDivisibleHeads was misleading: it actually asserted XCTAssertNoThrow because 7 % 1 == 0 on a singleton group. Renamed to testLlamaTPAcceptsAnyHeadCountOnSingletonGroup so the name matches what's tested. Real world>=2 rejection paths are exercised end-to-end by the 2-Mac smoke test in d-inference, not here.
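The arithmetic behind the renamed test is simple enough to show directly. The helper below is a hypothetical sketch, not the PR's validation code (which may throw or precondition-fail instead of returning `nil`); it only illustrates why any head count is accepted on a singleton group while an odd count is rejected once `world >= 2`.

```swift
/// Hypothetical helper: the per-rank head count, or nil when the heads
/// cannot be split evenly across `world` ranks.
func localHeadCount(heads: Int, world: Int) -> Int? {
    guard world > 0, heads % world == 0 else { return nil }
    return heads / world
}

// 7 % 1 == 0, so a singleton group accepts 7 heads and keeps all of them.
let singleton = localHeadCount(heads: 7, world: 1)
// 7 % 2 != 0, so two ranks cannot split 7 heads evenly.
let twoRanks = localHeadCount(heads: 7, world: 2)
```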
anupsv added a commit to Layr-Labs/d-inference that referenced this pull request on May 21, 2026
Picks up Layr-Labs/mlx-swift-lm#25 commit 735c28a — drops the misleading LoRA conformance on LlamaModelTP/Q (returns [] with a docstring pointing at LlamaModel for LoRA workflows) and renames the testLlamaTPRejectsNonDivisibleHeads test to match what it actually exercises.
Summary
Adds `LlamaModelTP`, a tensor-parallel variant of `LlamaModel` that two (or more) ranks can use to split each layer's compute and weight memory. On a 2-Mac Thunderbolt 5 cluster, TP roughly halves single-stream decode latency vs the existing PP path because both Macs run all layers in parallel rather than taking turns.

Stack: Stacked on #24 (Llama callPartial for pipeline-parallel inference, kept as fallback). Requires the sharded-linear primitives from `Layr-Labs/mlx-swift#3` (which in turn stacks on `Layr-Labs/mlx-swift#2`).

What's in this PR

- `Libraries/MLXLLM/Models/LlamaTP.swift` (new): `LlamaModelTP`, `LlamaModelInnerTP`, `LlamaTransformerBlockTP`, `LlamaAttentionTP`, `LlamaMLPTP`. Q/K/V/gate/up → `AllToShardedLinear` (column-parallel). O/down → `ShardedToAllLinear` (row-parallel, allSum in forward). Embedding + layernorms + LM head stay replicated. Per-rank local head counts derived from `attentionHeads / group.size` and `kvHeads / group.size`. `sanitize(weights:)` slices loaded weights into the rank's shard (column-parallel along axis 0, row-parallel along axis 1). `LoRAModel` conformance for parity with `LlamaModel`.
- `Tests/MLXLMTests/LlamaTPTests.swift` (new): […] `shardWeightIfNeeded` behavior for column-parallel, row-parallel, embedding, layernorm key patterns. All pass on Apple Silicon with mlx.metallib present.

Net: 2 files, +515 lines, additive. Existing `LlamaModel` is untouched; consumers that don't need TP keep using it.

Behavior on singleton group (`world == 1`)

- `AllToShardedLinear` with `world=1` is shape-equivalent to a regular `Linear`.
- `ShardedToAllLinear` with `world=1` does an `allSum` that's a no-op.
- `sanitize(weights:)` early-returns for `world == 1`.

So `LlamaModelTP(config, group: singleton)` with the same weights as `LlamaModel(config)` produces logits equal to the unsharded reference within float-accumulation tolerance; this is the equivalence-test baseline.

What's NOT in this PR (deferred)

- […] `sanitize`. Adding quantized TP requires packed-uint32-aware slicing of `weight` (each uint32 packs 8 4-bit weights, so axis-1 slicing on row-parallel layers needs `inDim % (8 * world) == 0`). Stock Llama dims satisfy this; left as follow-up so this PR stays reviewable.
- […] `allGather` is a future optimization once memory becomes the bottleneck.
- […] `callPartial`) until they get their own `*TP` variants. Pattern is identical: copy this file, swap the model-specific module structure.

Test plan

- `swift build` succeeds via the d-inference workspace (uses local sibling submodules)
- `swift test --filter "LlamaTPTests"` passes all 8 tests on macOS 14+ (Apple Silicon, requires mlx.metallib in the test bundle)
- […] `LlamaModel` / `LlamaModelInner` consumers: files are pure additions
- […] `world=2` distributed group via jaccl): lives in the d-inference cluster smoke test of the third PR
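The packed-uint32 constraint noted under the deferred items (`inDim % (8 * world) == 0`) can be sanity-checked with shape arithmetic alone. The sketch below is hypothetical (`QuantShapes`, `Parallelism`, and `quantShardShapes` are illustrative names) and assumes 4-bit weights packed 8 per uint32 with per-group scales of shape `[outDim, inDim / groupSize]`; the extra `groupSize * world` divisibility check is an assumption added here so the scales slice evenly too.

```swift
// Hypothetical shape calculator for per-rank quantized shards (4-bit only).
struct QuantShapes: Equatable {
    var weight: [Int]  // packed uint32: [outDim, inDim / 8]
    var scales: [Int]  // per-group: [outDim, inDim / groupSize]
}

enum Parallelism { case column, row }

/// Per-rank shard shapes, or nil when the dimensions don't divide cleanly.
func quantShardShapes(outDim: Int, inDim: Int, groupSize: Int,
                      world: Int, kind: Parallelism) -> QuantShapes? {
    let perWord = 8  // eight 4-bit values per uint32
    switch kind {
    case .column:
        // Axis-0 slicing never cuts through a packed word.
        guard outDim % world == 0 else { return nil }
        let o = outDim / world
        return QuantShapes(weight: [o, inDim / perWord], scales: [o, inDim / groupSize])
    case .row:
        // Axis-1 slicing must land on uint32 boundaries (and, by assumption,
        // on quantization-group boundaries).
        guard inDim % (perWord * world) == 0, inDim % (groupSize * world) == 0 else {
            return nil
        }
        let i = inDim / world
        return QuantShapes(weight: [outDim, i / perWord], scales: [outDim, i / groupSize])
    }
}
```

For example, with illustrative values `outDim = inDim = 4096`, `groupSize = 64`, and `world = 2`, a row-parallel shard has packed weight shape `[4096, 256]` and scales shape `[4096, 32]`, while an `inDim` of 12 is rejected because 12 is not a multiple of 16.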