Add Llama callPartial for pipeline-parallel inference#24
Open
anupsv wants to merge 1 commit into
Open
Conversation
anupsv
added a commit
to Layr-Labs/d-inference
that referenced
this pull request
May 21, 2026
Picks up Layr-Labs/mlx-swift-lm#24 which adds the callPartial methods on LlamaModelInner and LlamaModel that EncryptedPipelineInference and PipelineInference need for two-rank pipeline-parallel inference. Without this bump, CI's `swift build` fails with: error: value of type 'LlamaModel' has no member 'callPartial' Related: #193 (upstream mlx-swift distributed deviation tracker). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Adds callPartial entry points on LlamaModelInner and LlamaModel for running a contiguous sub-range of transformer blocks on a single rank. The caller specifies which layers to run, whether to apply token embedding (rank 0 only), the final RMS norm (last rank only), and the LM head (last rank only). Used by downstream pipeline-parallel inference (e.g. d-inference's EncryptedPipelineInference.swift / PipelineInference.swift) where two Macs split a Llama-family model across Thunderbolt 5 and exchange activation tensors between ranks. KV cache indices are relative to the rank's layerRange, so each rank maintains a cache slice sized to its own layer count rather than the full model. No change to the existing forward path; callPartial is additive.
anupsv
added a commit
to Layr-Labs/d-inference
that referenced
this pull request
May 21, 2026
Picks up Layr-Labs/mlx-swift-lm#24 which adds the callPartial methods on LlamaModelInner and LlamaModel that EncryptedPipelineInference and PipelineInference need for two-rank pipeline-parallel inference. Without this bump, CI's `swift build` fails with: error: value of type 'LlamaModel' has no member 'callPartial' Related: #193 (upstream mlx-swift distributed deviation tracker).
fb999f1 to
c2fbbdc
Compare
anupsv
added a commit
to Layr-Labs/d-inference
that referenced
this pull request
May 21, 2026
Force-pushes on Layr-Labs/mlx-swift#2 and Layr-Labs/mlx-swift-lm#24 landed new SHAs (fa6a4e8, c2fbbdc) — bump the submodule pointers to match.
4 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds
callPartialentry points onLlamaModelInnerandLlamaModelso a single rank can run a contiguous sub-range of transformer blocks, with control over whether the rank applies token embedding, the final RMS norm, and the LM head.Used by downstream pipeline-parallel inference (e.g. d-inference
EncryptedPipelineInference.swift/PipelineInference.swift) where two Macs split a Llama-family model across Thunderbolt 5 and exchange activation tensors between ranks.Shape of the change
LlamaModelInner.callPartial(_:layerRange:applyEmbedding:applyNorm:cache:)Libraries/MLXLLM/Models/Llama.swiftlayerRange.lowerBound ..< layerRange.upperBound. Rank 0 setsapplyEmbedding=true; last rank setsapplyNorm=true. KV cache indices are relative tolayerRange.LlamaModel.callPartial(_:layerRange:applyEmbedding:applyNorm:applyHead:cache:)Libraries/MLXLLM/Models/Llama.swiftLlamaModelInner.callPartialand optionally projects to vocab logits via the LM head (last rank only).Net: 1 file, +52 lines, additive. No existing forward path changes.
Test plan
swift buildsucceeds on macOS 14+ (Apple Silicon)provider-swifttests)Why upstream
Pinning this on our fork unblocks d-inference CI (the d-inference encrypted-pipeline-inference stack imports
callPartialdirectly). Once landed we'll bump the d-inference submodule pointer. Related: Layr-Labs/mlx-swift#2 (the matching change on the underlying mlx-swift fork to exposeCmlx+ enable jaccl).