Skip to content

Add LlamaModelTP: tensor-parallel variant of LlamaModel#25

Open
anupsv wants to merge 3 commits into
feat/llama-pipeline-parallelfrom
feat/llama-tensor-parallel
Open

Add LlamaModelTP: tensor-parallel variant of LlamaModel#25
anupsv wants to merge 3 commits into
feat/llama-pipeline-parallelfrom
feat/llama-tensor-parallel

Conversation

@anupsv
Copy link
Copy Markdown

@anupsv anupsv commented May 21, 2026

Summary

Adds LlamaModelTP, a tensor-parallel variant of LlamaModel that two (or more) ranks can use to split each layer's compute and weight memory. On a 2-Mac Thunderbolt 5 cluster, TP roughly halves single-stream decode latency vs the existing PP path because both Macs run all layers in parallel rather than taking turns.

Stack: Stacked on #24 (Llama callPartial for pipeline-parallel inference, kept as fallback). Requires the sharded-linear primitives from Layr-Labs/mlx-swift#3 (which in turn stacks on Layr-Labs/mlx-swift#2).

What's in this PR

File Change
Libraries/MLXLLM/Models/LlamaTP.swift (new) LlamaModelTP, LlamaModelInnerTP, LlamaTransformerBlockTP, LlamaAttentionTP, LlamaMLPTP. Q/K/V/gate/up → AllToShardedLinear (column-parallel). O/down → ShardedToAllLinear (row-parallel, allSum in forward). Embedding + layernorms + LM head stay replicated. Per-rank local head counts derived from attentionHeads / group.size and kvHeads / group.size. sanitize(weights:) slices loaded weights into the rank's shard (column-parallel along axis 0, row-parallel along axis 1). LoRAModel conformance for parity with LlamaModel.
Tests/MLXLMTests/LlamaTPTests.swift (new) 8 tests: parameter-shape parity (TP world=1 vs LlamaModel), forward output shape, numerical equivalence (identical weights → identical logits to within 1e-3), divisibility validation, and static shardWeightIfNeeded behavior for column-parallel, row-parallel, embedding, layernorm key patterns. All pass on Apple Silicon with mlx.metallib present.

Net: 2 files, +515 lines, additive. Existing LlamaModel is untouched; consumers that don't need TP keep using it.

Behavior on singleton group (world == 1)

AllToShardedLinear with world=1 is shape-equivalent to a regular Linear. ShardedToAllLinear with world=1 does an allSum that's a no-op. sanitize(weights:) early-returns for world == 1. So LlamaModelTP(config, group: singleton) with the same weights as LlamaModel(config) produces logits equal to the unsharded reference within float-accumulation tolerance — this is the equivalence-test baseline.

What's NOT in this PR (deferred)

  • Quantized TP. 4-bit packed weights (.scales / .biases) are currently passed through unmodified in sanitize. Adding quantized TP requires packed-uint32-aware slicing of weight (each uint32 packs 8 4-bit weights, so axis-1 slicing on row-parallel layers needs inDim % (8 * world) == 0). Stock Llama dims satisfy this; left as follow-up so this PR stays reviewable.
  • Multi-process tests. The numerical equivalence test runs on a singleton group (in-process). Real two-Mac validation lives in the d-inference cluster smoke test that lands in the third PR of this stack.
  • LM head sharding. Currently replicated. Column-parallel + allGather is a future optimization once memory becomes the bottleneck.
  • Other model families. Mistral, Qwen, Gemma, etc. will fall back to the existing PP path (callPartial) until they get their own *TP variants. Pattern is identical — copy this file, swap the model-specific module structure.

Test plan

  • swift build succeeds via the d-inference workspace (uses local sibling submodules)
  • swift test --filter "LlamaTPTests" passes all 8 tests on macOS 14+ (Apple Silicon, requires mlx.metallib in the test bundle)
  • No regression for existing LlamaModel / LlamaModelInner consumers — files are pure additions
  • Two-Mac Thunderbolt 5 numerical equivalence (real world=2 distributed group via jaccl) — lives in the d-inference cluster smoke test of the third PR

LlamaModelTP mirrors LlamaModel but uses sharded linear layers from
mlx-swift's AllToShardedLinear / ShardedToAllLinear so two (or more)
ranks can split each layer's compute and weight memory:

- Q/K/V projections + gate/up are AllToShardedLinear (column-parallel):
  each rank holds outDim/world_size rows; output is sharded across heads.
  No collective op in forward.
- O projection + down are ShardedToAllLinear (row-parallel): each rank
  holds inDim/world_size cols; allSum in forward to combine partial
  contributions across ranks.
- Embedding, layernorms, and the (optional, non-tied) LM head stay
  replicated — LM head sharding is a future optimization.

Per-rank: attentionHeads/world_size local Q heads and kvHeads/world_size
local KV heads. The KV cache is sized to the local KV head count so
each rank only stores its shard.

Weight loading: LlamaModelTP.sanitize(weights:) slices column-parallel
weights along axis 0 (output dim) and row-parallel weights along axis 1
(input dim) into each rank's shard before module assignment. Singleton
group (world=1) is a pass-through. Quantized weights (.scales/.biases)
are passed through unmodified — quantized TP support is a follow-up
that needs packed-uint32-aware slicing.

The exposed shardWeightIfNeeded static helper centralizes the
key-pattern → axis decision so tests can validate the sharding logic
without spinning up a real distributed group.

Tests/MLXLMTests/LlamaTPTests.swift covers:
- Parameter shape parity between LlamaModel and LlamaModelTP(world=1)
- Forward-pass output shape with singleton group
- Numerical equivalence: with identical weights, LlamaModelTP(world=1)
  matches LlamaModel logits to within 1e-3 (float-accumulation noise)
- Static shardWeightIfNeeded behavior for column-parallel, row-parallel,
  embedding, and layernorm key patterns

Stacked on feat/llama-pipeline-parallel (PR #24) for callPartial; that
remains the fallback path for non-Llama models. Requires Cmlx + the
sharded linear primitives from Layr-Labs/mlx-swift#3.
Critical: stock MLX Llama checkpoints are 4-bit quantized; the original
LlamaModelTP only worked on fp16/bf16 weights because AllToShardedLinear
and QuantizedAllToShardedLinear aren't in an inheritance hierarchy in
mlx-swift — you can't swap one for the other via dynamic dispatch the
way MLXNN.quantize() swaps Linear for QuantizedLinear.

This commit adds a parallel quantized class hierarchy:
- LlamaAttentionTPQ, LlamaMLPTPQ, LlamaTransformerBlockTPQ,
  LlamaModelInnerTPQ, LlamaModelTPQ
- Uses QuantizedAllToShardedLinear / QuantizedShardedToAllLinear for the
  Q/K/V/gate/up and O/down projections respectively
- LM head stays replicated + quantized (QuantizedLinear) — same memory
  trade-off as the fp16 path, just with quantized weights
- sanitize() / shardQuantizedWeightIfNeeded handle .weight (uint32
  packed shape [outDim, inDim/8]) plus .scales / .biases (fp16 shape
  [outDim, inDim/groupSize]) along the right axes. Row-parallel biases
  pass through unsliced (added once after the allSum).
- A `makeLlamaTP(args:quantization:group:)` factory picks the right
  variant from a BaseConfiguration.Quantization value.

Also tightens shardWeightIfNeeded on the fp16 LlamaModelTP:
- "norm" substring check replaced with explicit isLayerNormKey() that
  matches input_layernorm / post_attention_layernorm / model.norm only —
  future TP variants with q_norm / k_norm per-head norms (Qwen3, Gemma)
  won't accidentally skip sharding for those.
- Projection name matches use path-segment boundaries via isProjectionKey
  (e.g. ".q_proj.weight" not "q_proj.weight") so a "q_proj_extra" never
  matches.
- Loading a quantized blob into LlamaModelTP now precondition-fails with
  a clear message pointing at LlamaModelTPQ — silent corruption replaced
  with a loud routing error.
- Row-parallel biases are now also passed through unsliced in the fp16
  path (was an open TODO comment before).

High: the previous numerical-equivalence test on a singleton group
(world=1) didn't exercise any actual sharding — sharded layers degenerate
to regular Linear. To actually validate the sharding math without
needing a multi-process distributed backend, this commit adds three
matrix-level tests that:
- Slice weights with the SAME shardWeightIfNeeded function the production
  sanitize path uses
- Manually run the per-rank matmuls and (for row-parallel) simulate the
  allSum by summing partial outputs
- Compare against the unsharded reference to within 1e-4

Covered patterns: column-parallel (concat), row-parallel (allSum),
column-then-row pipeline (the typical attention / MLP shape).

Real multi-rank DistributedGroup tests still need spawned subprocesses
(jaccl or ring backend) — the 2-Mac Thunderbolt 5 smoke test in
d-inference remains the final validation. But the in-process math
checks catch axis confusion and slicing bugs that would have shipped
silently before.

Plus 5 unit tests for the quantized weight slicing helper and a small
test that q_norm / k_norm aren't tripped by the layernorm gate.

Net: 18 tests pass, up from 8. No regression to existing LlamaModel /
LlamaModelInner consumers; LlamaModelTP behavior unchanged for fp16
checkpoints (only the diagnostic precondition is new). LlamaModelTPQ is
additive.
anupsv added a commit to Layr-Labs/d-inference that referenced this pull request May 21, 2026
Picks up the quantized TP variant + matrix-level sharding tests on
Layr-Labs/mlx-swift-lm#25 (commit 56a5ca6). This makes makeLlamaTP()
available for routing the dispatcher to the right TP variant when
a 4-bit / 8-bit quantized checkpoint is loaded.
LoRA: LLMModel protocol requires LoRAModel conformance, but TP+LoRA
isn't actually supported. MLXLLM's LoRA training code walks loraLayers
looking for Linear sub-modules to wrap with adapters; our sharded
variants (AllToShardedLinear, QuantizedAllToShardedLinear, etc.) are
NOT Linear subclasses so adapter insertion either no-ops silently or
hits surprising shape errors at train time.

Both LlamaModelTP and LlamaModelTPQ now return `[]` from loraLayers
with a docstring pointing at LlamaModel (single-rank) for LoRA
workflows. Sharded LoRA is its own research problem.

Test: testLlamaTPRejectsNonDivisibleHeads was misleading — it actually
asserted XCTAssertNoThrow because 7 % 1 == 0 on a singleton group.
Renamed to testLlamaTPAcceptsAnyHeadCountOnSingletonGroup so the name
matches what's tested. Real world>=2 rejection paths are exercised
end-to-end by the 2-Mac smoke test in d-inference, not here.
anupsv added a commit to Layr-Labs/d-inference that referenced this pull request May 21, 2026
Picks up Layr-Labs/mlx-swift-lm#25 commit 735c28a — drops the
misleading LoRA conformance on LlamaModelTP/Q (returns [] with a
docstring pointing at LlamaModel for LoRA workflows) and renames the
testLlamaTPRejectsNonDivisibleHeads test to match what it actually
exercises.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant