Skip to content

Add Llama callPartial for pipeline-parallel inference#24

Open
anupsv wants to merge 1 commit into
mainfrom
feat/llama-pipeline-parallel
Open

Add Llama callPartial for pipeline-parallel inference#24
anupsv wants to merge 1 commit into
mainfrom
feat/llama-pipeline-parallel

Conversation

@anupsv
Copy link
Copy Markdown

@anupsv anupsv commented May 21, 2026

Summary

Adds callPartial entry points on LlamaModelInner and LlamaModel so a single rank can run a contiguous sub-range of transformer blocks, with control over whether the rank applies token embedding, the final RMS norm, and the LM head.

Used by downstream pipeline-parallel inference (e.g. d-inference EncryptedPipelineInference.swift / PipelineInference.swift) where two Macs split a Llama-family model across Thunderbolt 5 and exchange activation tensors between ranks.

Shape of the change

Method Where Purpose
LlamaModelInner.callPartial(_:layerRange:applyEmbedding:applyNorm:cache:) Libraries/MLXLLM/Models/Llama.swift Run blocks layerRange.lowerBound ..< layerRange.upperBound. Rank 0 sets applyEmbedding=true; last rank sets applyNorm=true. KV cache indices are relative to layerRange.
LlamaModel.callPartial(_:layerRange:applyEmbedding:applyNorm:applyHead:cache:) Libraries/MLXLLM/Models/Llama.swift Wraps LlamaModelInner.callPartial and optionally projects to vocab logits via the LM head (last rank only).

Net: 1 file, +52 lines, additive. No existing forward path changes.

Test plan

  • swift build succeeds on macOS 14+ (Apple Silicon)
  • Existing Llama / Mistral consumers still compile and produce identical outputs (verified via d-inference provider-swift tests)
  • Two-rank pipeline split on a Llama-class model produces logits equal to a single-rank run (within numerical tolerance)

Why upstream

Pinning this on our fork unblocks d-inference CI (the d-inference encrypted-pipeline-inference stack imports callPartial directly). Once landed we'll bump the d-inference submodule pointer. Related: Layr-Labs/mlx-swift#2 (the matching change on the underlying mlx-swift fork to expose Cmlx + enable jaccl).

anupsv added a commit to Layr-Labs/d-inference that referenced this pull request May 21, 2026
Picks up Layr-Labs/mlx-swift-lm#24 which adds the callPartial methods
on LlamaModelInner and LlamaModel that EncryptedPipelineInference and
PipelineInference need for two-rank pipeline-parallel inference.

Without this bump, CI's `swift build` fails with:
  error: value of type 'LlamaModel' has no member 'callPartial'

Related: #193 (upstream mlx-swift distributed deviation tracker).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds callPartial entry points on LlamaModelInner and LlamaModel for
running a contiguous sub-range of transformer blocks on a single rank.
The caller specifies which layers to run, whether to apply token
embedding (rank 0 only), the final RMS norm (last rank only), and the
LM head (last rank only).

Used by downstream pipeline-parallel inference (e.g. d-inference's
EncryptedPipelineInference.swift / PipelineInference.swift) where two
Macs split a Llama-family model across Thunderbolt 5 and exchange
activation tensors between ranks.

KV cache indices are relative to the rank's layerRange, so each rank
maintains a cache slice sized to its own layer count rather than the
full model.

No change to the existing forward path; callPartial is additive.
anupsv added a commit to Layr-Labs/d-inference that referenced this pull request May 21, 2026
Picks up Layr-Labs/mlx-swift-lm#24 which adds the callPartial methods
on LlamaModelInner and LlamaModel that EncryptedPipelineInference and
PipelineInference need for two-rank pipeline-parallel inference.

Without this bump, CI's `swift build` fails with:
  error: value of type 'LlamaModel' has no member 'callPartial'

Related: #193 (upstream mlx-swift distributed deviation tracker).
@anupsv anupsv force-pushed the feat/llama-pipeline-parallel branch from fb999f1 to c2fbbdc Compare May 21, 2026 04:22
anupsv added a commit to Layr-Labs/d-inference that referenced this pull request May 21, 2026
Force-pushes on Layr-Labs/mlx-swift#2 and Layr-Labs/mlx-swift-lm#24
landed new SHAs (fa6a4e8, c2fbbdc) — bump the submodule pointers to
match.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant