Share the fused gated-delta kernel between MLXLLM and MLXVLM #257
john-rocky wants to merge 1 commit into
// Gated-delta helpers (`gatedDeltaUpdate`, `gatedDeltaOps`, `gatedDeltaKernel`,
// `computeGatedDeltaG`) are shared with MLXLLM via `MLXLMCommon/GatedDelta.swift`,
// which selects the fused Metal kernel when available and falls back to the ops
// path otherwise. Keeping a single implementation also keeps both paths in sync
// when the upstream Python kernel evolves.
We don't need the comment for removed code.
Done — the duplicate helpers are just deleted now, nothing left in their place.
import MLXLMCommon
import XCTest

final class GatedDeltaBenchTests: XCTestCase {
I don't think we need benchmark tests, though if you did want to add it, look in IntegrationTesting -- that is a more appropriate place.
Dropped the benchmark test file. The microbench and end-to-end numbers stay in the PR description as the supporting evidence.
davidkoski left a comment
See comments. This needs a rebase to pull the updated GatedDelta code.
Thanks!
…ore#124)

`GatedDelta.swift` defines a fused `gatedDeltaKernel` (custom Metal kernel, single threadgroup over the time loop) plus a `gatedDeltaOps` fallback. `gatedDeltaUpdate` prefers the kernel when MLX Metal is available, which is the only path the LLM-side Qwen 3.5 / Qwen 3 Next models take on Apple Silicon.

The VLM-side `MLXVLM/Models/Qwen35.swift` was a copy/paste of the same code with the kernel branch omitted, so every VLM Qwen 3.5 checkpoint with `linear_attn` ran the unfused per-step expression-graph version, generating far more intermediates per token.

Move the shared file to `MLXLMCommon` and mark `gatedDeltaUpdate` public; its helpers stay internal since nothing outside the file calls them. Delete the duplicate helpers from `MLXVLM/Models/Qwen35.swift` so the VLM call site goes through the same kernel-preferred dispatch as the LLM side. `Tests/MLXLMTests/GatedDeltaTests.swift` imports `MLXLMCommon` instead of `@testable import MLXLLM`, following the move.

No behavior change for MLXLLM (same code in a new module); MLXVLM Qwen 3.5 with `linear_attn` now uses the fused kernel.
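The kernel-preferred dispatch described in the commit message can be sketched in plain Swift. This is an illustration only, not the code in `GatedDelta.swift`: it uses `[Float]` arrays instead of `MLXArray`, handles a single head and batch element, and the names `gatedDeltaStepSketch`, `gatedDeltaUpdateSketch`, `FusedKernel`, and `State` are hypothetical. The per-step recurrence is the commonly stated form of the gated delta rule and is assumed here rather than taken from the PR.

```swift
// Hypothetical stand-ins for illustration; not the MLXLMCommon API.

/// Recurrent state for one head, stored as [Dv][Dk].
typealias State = [[Float]]

/// A fused implementation walks all T steps in one call (assumed shape).
typealias FusedKernel =
    ([[Float]], [[Float]], [[Float]], [Float], [Float], inout State) -> [[Float]]

/// One unfused step of the (assumed) gated delta rule:
///   S <- g * S;  delta = beta * (v - S k);  S <- S + delta k^T;  y = S q
func gatedDeltaStepSketch(
    q: [Float], k: [Float], v: [Float], g: Float, beta: Float, state: inout State
) -> [Float] {
    var y = [Float](repeating: 0, count: v.count)
    for i in v.indices {
        // Decay row i, then read the value currently stored under key k.
        var kvMem: Float = 0
        for j in k.indices {
            state[i][j] *= g
            kvMem += state[i][j] * k[j]
        }
        let delta = (v[i] - kvMem) * beta
        // Write the correction back along k, then read out with q.
        for j in k.indices {
            state[i][j] += k[j] * delta
            y[i] += state[i][j] * q[j]
        }
    }
    return y
}

/// Prefer the fused kernel when one is supplied; otherwise loop over time steps.
func gatedDeltaUpdateSketch(
    q: [[Float]], k: [[Float]], v: [[Float]], g: [Float], beta: [Float],
    state: inout State, fused: FusedKernel?
) -> [[Float]] {
    if let fused = fused {
        return fused(q, k, v, g, beta, &state)
    }
    var outputs: [[Float]] = []
    for t in q.indices {
        outputs.append(
            gatedDeltaStepSketch(
                q: q[t], k: k[t], v: v[t], g: g[t], beta: beta[t], state: &state))
    }
    return outputs
}
```

Passing `fused: nil` exercises the fallback loop, which is the situation the VLM copy was stuck in before this change: every time step builds its own chain of intermediate values instead of one fused pass.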
90ba4ce to 46f46f8
Thanks for the review! Addressed all three points:
Two things worth flagging since they're new in this diff:
Net is +2 / −132 across 3 files ( |
Summary
`Libraries/MLXLLM/Models/GatedDelta.swift` defines a fused `gatedDeltaKernel` (custom Metal kernel, single threadgroup walking the time loop) and a `gatedDeltaOps` expression-graph fallback. `gatedDeltaUpdate` prefers the kernel when MLX Metal is available, which is the only path the LLM-side Qwen 3.5 / Qwen 3 Next models take on Apple Silicon.

`MLXVLM/Models/Qwen35.swift` was a copy/paste of the same code with the kernel branch removed, so every VLM Qwen 3.5 checkpoint with `linear_attn` was running the unfused per-step expression-graph version, generating far more intermediates per token.

This is a stand-alone VLM-side missed-branch fix. For context on the broader #124 perf discussion, see my profiling comment there showing the 35B-A3B-4bit text-only MoE no longer reproduces a gap on current `mlx-swift 0.31.3` (mlx-c 0.31.1) builds; the VLM-side branch removal addressed in this PR is independent of that finding.

Change
- Move `Libraries/MLXLLM/Models/GatedDelta.swift` → `Libraries/MLXLMCommon/GatedDelta.swift`.
- Mark `computeGatedDeltaG`, `gatedDeltaKernel`, `gatedDeltaOps`, `gatedDeltaUpdate` `public` so both LLM and VLM call through `MLXLMCommon`.
- Delete the duplicate `computeGatedDeltaG` / `gatedDeltaStepOps` / `gatedDeltaOps` / `gatedDeltaUpdate` helpers from `MLXVLM/Models/Qwen35.swift`. The single VLM call site now resolves to the shared public function with the kernel-preferred dispatch.

Net (excluding the bench test): +10 / −132. No behavior change for MLXLLM (same code, new home); MLXVLM Qwen 3.5 with `linear_attn` now hits the fused Metal kernel.

Microbench (M4 Max, 128 GB, bfloat16)
`Tests/MLXLMTests/GatedDeltaBenchTests.swift` constructs realistic Qwen 3.5 linear-attn input shapes (B=1, Hk=16, Hv=64, Dk=192, Dv=128) at four representative T values and times `gatedDeltaKernel` vs `gatedDeltaOps` over 20 iterations after warmup.

[Table: `gatedDeltaOps` median (ms) vs. `gatedDeltaKernel` median (ms)]

Reproduce:

(SPM `swift test` is unsupported because the Metal kernel needs the `.metallib` bundle that only `xcodebuild` produces.)

End-to-end (M4 Max, 128 GB, `mlx-community/Qwen3.5-0.8B-MLX-4bit`)

The 0.8B unified VLM checkpoint has 24 transformer blocks; 18 of them are `linear_attention` (gated-delta) per its `text_config.layer_types` and 6 are `full_attention`. With `max-tokens 64`, `temperature 0`, `seed 42`, prompt `"What is the capital of Japan?"` (24 tokens), 6 trials per branch with the first dropped as warmup:

[Table: results per branch; surviving row label: `mlx_lm` 0.31.3 (kernel path)]

Decode improvement (1.54× / 2.16×) is smaller than the microbench (1.92× – 59.69×) because the 0.8B model still spends a meaningful fraction of each forward pass in non-`linear_attn` work (full-attention blocks, MoE / MLP, RMS norm, embed). The kernel fuses the per-token decode cost of the `linear_attn` blocks specifically, so the end-to-end gain scales with the share of runtime spent in those blocks. Models with a higher `linear_attn` ratio (the larger 35B-A3B / 122B-A10B Qwen 3.5 unified VLMs) should see proportionally larger wins.

Greedy output is identical between the two Swift builds:
(Python emits the same text minus the bold markers because it doesn't go through Swift's markdown wrapper.)
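The end-to-end scaling argument above (a large kernel speedup yields a smaller overall gain when only part of each forward pass is accelerated) follows the shape of Amdahl's law. The sketch below is an illustrative framing rather than something taken from the PR; the function name `endToEndSpeedup` and the fractions used with it are hypothetical and are not measurements from this benchmark.

```swift
/// Overall speedup when a fraction `p` of the original runtime is made
/// `s` times faster and the rest is unchanged (Amdahl's law).
/// `p` and `s` are illustrative inputs, not values measured in this PR.
func endToEndSpeedup(fraction p: Double, kernelSpeedup s: Double) -> Double {
    precondition((0.0...1.0).contains(p), "fraction must lie in [0, 1]")
    precondition(s > 0, "speedup must be positive")
    // Unaccelerated share stays at (1 - p); accelerated share shrinks to p / s.
    return 1.0 / ((1.0 - p) + p / s)
}

// Example: half of the runtime accelerated 2x gives roughly 1.33x overall.
let example = endToEndSpeedup(fraction: 0.5, kernelSpeedup: 2.0)
print(example)
```

Under this model, raising the accelerated fraction matters as much as raising the kernel speedup, which is consistent with the expectation that checkpoints with a higher `linear_attn` share benefit more.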