diff --git a/Libraries/MLXLLM/Models/GatedDelta.swift b/Libraries/MLXLMCommon/GatedDelta.swift similarity index 99% rename from Libraries/MLXLLM/Models/GatedDelta.swift rename to Libraries/MLXLMCommon/GatedDelta.swift index cf34c2107..0b5e346eb 100644 --- a/Libraries/MLXLLM/Models/GatedDelta.swift +++ b/Libraries/MLXLMCommon/GatedDelta.swift @@ -270,7 +270,7 @@ func gatedDeltaOps( // MARK: - Public API -func gatedDeltaUpdate( +public func gatedDeltaUpdate( q: MLXArray, k: MLXArray, v: MLXArray, diff --git a/Libraries/MLXVLM/Models/Qwen35.swift b/Libraries/MLXVLM/Models/Qwen35.swift index efa6f5248..184c78943 100644 --- a/Libraries/MLXVLM/Models/Qwen35.swift +++ b/Libraries/MLXVLM/Models/Qwen35.swift @@ -21,135 +21,6 @@ private let precomputedPositionIdsKey = LMOutput.Key( private let ropeDeltasKey = LMOutput.Key( "qwen35.ropeDeltas") -// MARK: - Gated Delta Helpers - -private func computeGatedDeltaG(_ aLog: MLXArray, _ a: MLXArray, _ dtBias: MLXArray) - -> MLXArray -{ - let decay = exp(-exp(aLog.asType(.float32)) * softplus(a + dtBias)) - return decay.asType(a.dtype) -} - -private func gatedDeltaStepOps( - q: MLXArray, - k: MLXArray, - v: MLXArray, - g: MLXArray, - beta: MLXArray, - state: MLXArray, - mask: MLXArray? = nil -) -> (MLXArray, MLXArray) { - let oldState = state - let decay: MLXArray - if g.ndim == 2 { - decay = expandedDimensions(g, axes: [2, 3]) - } else if g.ndim == 3 { - decay = expandedDimensions(g, axis: -2) - } else { - fatalError("Unsupported gating shape \(g.shape)") - } - - var state = state * decay - let kvMem = (state * expandedDimensions(k, axis: -2)).sum(axis: -1) - let delta = (v - kvMem) * expandedDimensions(beta, axis: -1) - state = state + expandedDimensions(k, axis: -2) * expandedDimensions(delta, axis: -1) - let y = (state * expandedDimensions(q, axis: -2)).sum(axis: -1) - - if let mask { - let expandedMask: MLXArray - if mask.ndim == 1 { - expandedMask = expandedDimensions(mask, axes: [1, 2, 3]) - } else if mask.ndim == 2 { - expandedMask = expandedDimensions(mask, axes: [2, 3]) - } else if mask.ndim == 3 { - expandedMask = expandedDimensions(mask, axis: -1) - } else { - fatalError("Unsupported mask shape \(mask.shape)") - } - state = MLX.where(expandedMask, state, oldState) - } - - return (y, state) -} - -private func gatedDeltaOps( - q: MLXArray, - k: MLXArray, - v: MLXArray, - g: MLXArray, - beta: MLXArray, - state: MLXArray? = nil, - mask: MLXArray? = nil -) -> (MLXArray, MLXArray) { - let B = q.dim(0) - let T = q.dim(1) - let Hk = q.dim(2) - let Dk = q.dim(3) - let Hv = v.dim(2) - let Dv = v.dim(3) - - var q = q - var k = k - - let repeatFactor = Hv / Hk - if repeatFactor > 1 { - q = repeated(q, count: repeatFactor, axis: -2) - k = repeated(k, count: repeatFactor, axis: -2) - } - - var state = state ?? MLXArray.zeros([B, Hv, Dv, Dk], dtype: q.dtype) - - var ys = [MLXArray]() - ys.reserveCapacity(T) - - for t in 0 ..< T { - let qT = q[0..., t] - let kT = k[0..., t] - let vT = v[0..., t] - let gT = g[0..., t] - let betaT = beta[0..., t] - let maskT = mask == nil ? nil : mask![0..., t] - - let (y, newState) = gatedDeltaStepOps( - q: qT, - k: kT, - v: vT, - g: gT, - beta: betaT, - state: state, - mask: maskT - ) - ys.append(y) - state = newState - } - - let y = MLX.stacked(ys, axis: 1) - return (y, state) -} - -private func gatedDeltaUpdate( - q: MLXArray, - k: MLXArray, - v: MLXArray, - a: MLXArray, - b: MLXArray, - aLog: MLXArray, - dtBias: MLXArray, - state: MLXArray? = nil, - mask: MLXArray? = nil -) -> (MLXArray, MLXArray) { - let beta = sigmoid(b) - let g = computeGatedDeltaG(aLog, a, dtBias) - - let B = q.dim(0) - let Dk = q.dim(3) - let Hv = v.dim(2) - let Dv = v.dim(3) - - let state = state ?? MLXArray.zeros([B, Hv, Dv, Dk], dtype: q.dtype) - return gatedDeltaOps(q: q, k: k, v: v, g: g, beta: beta, state: state, mask: mask) -} - // MARK: - Configuration public struct Qwen35Configuration: Codable, Sendable { diff --git a/Tests/MLXLMTests/GatedDeltaTests.swift b/Tests/MLXLMTests/GatedDeltaTests.swift index 6457bcd23..203b06a42 100644 --- a/Tests/MLXLMTests/GatedDeltaTests.swift +++ b/Tests/MLXLMTests/GatedDeltaTests.swift @@ -2,10 +2,9 @@ import Foundation import MLX +import MLXLMCommon import XCTest -@testable import MLXLLM - public class GatedDeltaTests: XCTestCase { private struct Inputs {