From 46f46f82ce9f9229c147bfee917086d3bf33466c Mon Sep 17 00:00:00 2001 From: john-rocky Date: Fri, 15 May 2026 04:09:51 +0900 Subject: [PATCH] Share the fused gated-delta kernel between MLXLLM and MLXVLM (#124) `GatedDelta.swift` defines a fused `gatedDeltaKernel` (custom Metal kernel, single threadgroup over the time loop) plus a `gatedDeltaOps` fallback. `gatedDeltaUpdate` prefers the kernel when MLX Metal is available, which is the only path the LLM-side Qwen 3.5 / Qwen 3 Next models take on Apple Silicon. The VLM-side `MLXVLM/Models/Qwen35.swift` was a copy/paste of the same code with the kernel branch omitted, so every VLM Qwen 3.5 checkpoint with `linear_attn` ran the unfused per-step expression-graph version, generating far more intermediates per token. Move the shared file to `MLXLMCommon` and mark `gatedDeltaUpdate` public; its helpers stay internal since nothing outside the file calls them. Delete the duplicate helpers from `MLXVLM/Models/Qwen35.swift` so the VLM call site goes through the same kernel-preferred dispatch as the LLM side. `Tests/MLXLMTests/GatedDeltaTests.swift` imports `MLXLMCommon` instead of `@testable import MLXLLM`, following the move. No behavior change for MLXLLM (same code in a new module); MLXVLM Qwen 3.5 with `linear_attn` now uses the fused kernel. --- .../Models => MLXLMCommon}/GatedDelta.swift | 2 +- Libraries/MLXVLM/Models/Qwen35.swift | 129 ------------------ Tests/MLXLMTests/GatedDeltaTests.swift | 3 +- 3 files changed, 2 insertions(+), 132 deletions(-) rename Libraries/{MLXLLM/Models => MLXLMCommon}/GatedDelta.swift (99%) diff --git a/Libraries/MLXLLM/Models/GatedDelta.swift b/Libraries/MLXLMCommon/GatedDelta.swift similarity index 99% rename from Libraries/MLXLLM/Models/GatedDelta.swift rename to Libraries/MLXLMCommon/GatedDelta.swift index cf34c2107..0b5e346eb 100644 --- a/Libraries/MLXLLM/Models/GatedDelta.swift +++ b/Libraries/MLXLMCommon/GatedDelta.swift @@ -270,7 +270,7 @@ func gatedDeltaOps( // MARK: - Public API -func gatedDeltaUpdate( +public func gatedDeltaUpdate( q: MLXArray, k: MLXArray, v: MLXArray, diff --git a/Libraries/MLXVLM/Models/Qwen35.swift b/Libraries/MLXVLM/Models/Qwen35.swift index d0909da4e..2a83dfbc0 100644 --- a/Libraries/MLXVLM/Models/Qwen35.swift +++ b/Libraries/MLXVLM/Models/Qwen35.swift @@ -16,135 +16,6 @@ private enum Qwen35VLError: Error { case featureTokenMismatch(expected: Int, actual: Int) } -// MARK: - Gated Delta Helpers - -private func computeGatedDeltaG(_ aLog: MLXArray, _ a: MLXArray, _ dtBias: MLXArray) - -> MLXArray -{ - let decay = exp(-exp(aLog.asType(.float32)) * softplus(a + dtBias)) - return decay.asType(a.dtype) -} - -private func gatedDeltaStepOps( - q: MLXArray, - k: MLXArray, - v: MLXArray, - g: MLXArray, - beta: MLXArray, - state: MLXArray, - mask: MLXArray? = nil -) -> (MLXArray, MLXArray) { - let oldState = state - let decay: MLXArray - if g.ndim == 2 { - decay = expandedDimensions(g, axes: [2, 3]) - } else if g.ndim == 3 { - decay = expandedDimensions(g, axis: -2) - } else { - fatalError("Unsupported gating shape \(g.shape)") - } - - var state = state * decay - let kvMem = (state * expandedDimensions(k, axis: -2)).sum(axis: -1) - let delta = (v - kvMem) * expandedDimensions(beta, axis: -1) - state = state + expandedDimensions(k, axis: -2) * expandedDimensions(delta, axis: -1) - let y = (state * expandedDimensions(q, axis: -2)).sum(axis: -1) - - if let mask { - let expandedMask: MLXArray - if mask.ndim == 1 { - expandedMask = expandedDimensions(mask, axes: [1, 2, 3]) - } else if mask.ndim == 2 { - expandedMask = expandedDimensions(mask, axes: [2, 3]) - } else if mask.ndim == 3 { - expandedMask = expandedDimensions(mask, axis: -1) - } else { - fatalError("Unsupported mask shape \(mask.shape)") - } - state = MLX.where(expandedMask, state, oldState) - } - - return (y, state) -} - -private func gatedDeltaOps( - q: MLXArray, - k: MLXArray, - v: MLXArray, - g: MLXArray, - beta: MLXArray, - state: MLXArray? = nil, - mask: MLXArray? = nil -) -> (MLXArray, MLXArray) { - let B = q.dim(0) - let T = q.dim(1) - let Hk = q.dim(2) - let Dk = q.dim(3) - let Hv = v.dim(2) - let Dv = v.dim(3) - - var q = q - var k = k - - let repeatFactor = Hv / Hk - if repeatFactor > 1 { - q = repeated(q, count: repeatFactor, axis: -2) - k = repeated(k, count: repeatFactor, axis: -2) - } - - var state = state ?? MLXArray.zeros([B, Hv, Dv, Dk], dtype: q.dtype) - - var ys = [MLXArray]() - ys.reserveCapacity(T) - - for t in 0 ..< T { - let qT = q[0..., t] - let kT = k[0..., t] - let vT = v[0..., t] - let gT = g[0..., t] - let betaT = beta[0..., t] - let maskT = mask == nil ? nil : mask![0..., t] - - let (y, newState) = gatedDeltaStepOps( - q: qT, - k: kT, - v: vT, - g: gT, - beta: betaT, - state: state, - mask: maskT - ) - ys.append(y) - state = newState - } - - let y = MLX.stacked(ys, axis: 1) - return (y, state) -} - -private func gatedDeltaUpdate( - q: MLXArray, - k: MLXArray, - v: MLXArray, - a: MLXArray, - b: MLXArray, - aLog: MLXArray, - dtBias: MLXArray, - state: MLXArray? = nil, - mask: MLXArray? = nil -) -> (MLXArray, MLXArray) { - let beta = sigmoid(b) - let g = computeGatedDeltaG(aLog, a, dtBias) - - let B = q.dim(0) - let Dk = q.dim(3) - let Hv = v.dim(2) - let Dv = v.dim(3) - - let state = state ?? MLXArray.zeros([B, Hv, Dv, Dk], dtype: q.dtype) - return gatedDeltaOps(q: q, k: k, v: v, g: g, beta: beta, state: state, mask: mask) -} - // MARK: - Configuration public struct Qwen35Configuration: Codable, Sendable { diff --git a/Tests/MLXLMTests/GatedDeltaTests.swift b/Tests/MLXLMTests/GatedDeltaTests.swift index 6457bcd23..203b06a42 100644 --- a/Tests/MLXLMTests/GatedDeltaTests.swift +++ b/Tests/MLXLMTests/GatedDeltaTests.swift @@ -2,10 +2,9 @@ import Foundation import MLX +import MLXLMCommon import XCTest -@testable import MLXLLM - public class GatedDeltaTests: XCTestCase { private struct Inputs {