Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ func gatedDeltaOps(

// MARK: - Public API

func gatedDeltaUpdate(
public func gatedDeltaUpdate(
q: MLXArray,
k: MLXArray,
v: MLXArray,
Expand Down
129 changes: 0 additions & 129 deletions Libraries/MLXVLM/Models/Qwen35.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,135 +16,6 @@ private enum Qwen35VLError: Error {
case featureTokenMismatch(expected: Int, actual: Int)
}

// MARK: - Gated Delta Helpers

/// Computes the decay gate `exp(-exp(aLog) * softplus(a + dtBias))`.
///
/// `aLog` is cast to `float32` before it is exponentiated, and the final
/// result is cast back to the dtype of `a`.
///
/// - Parameters:
///   - aLog: Array that is converted to `float32` and exponentiated.
///   - a: Array added to `dtBias` before the softplus; its dtype is the dtype
///     of the returned array.
///   - dtBias: Bias added to `a` before the softplus.
/// - Returns: The decay gate, in the dtype of `a`.
private func computeGatedDeltaG(_ aLog: MLXArray, _ a: MLXArray, _ dtBias: MLXArray)
    -> MLXArray
{
    let rate = exp(aLog.asType(.float32))
    let step = softplus(a + dtBias)
    let gate = exp(-rate * step)
    return gate.asType(a.dtype)
}

/// Applies one gated delta update to `state` and reads out the step's output.
///
/// The state is first scaled by the gate `g`, then corrected by the
/// `beta`-weighted difference between `v` and what the scaled state yields
/// for `k`, and finally queried with `q`.
///
/// - Parameters:
///   - q: Query for this step.
///   - k: Key for this step.
///   - v: Value for this step.
///   - g: Gate with 2 or 3 dimensions; any other rank is a fatal error.
///   - beta: Weight applied to the correction term.
///   - state: State carried in from the previous step.
///   - mask: Optional mask with 1, 2 or 3 dimensions (any other rank is a
///     fatal error). Where it is false the incoming `state` is kept instead
///     of the updated one. The output is not masked.
/// - Returns: The step output and the state to carry into the next step.
private func gatedDeltaStepOps(
    q: MLXArray,
    k: MLXArray,
    v: MLXArray,
    g: MLXArray,
    beta: MLXArray,
    state: MLXArray,
    mask: MLXArray? = nil
) -> (MLXArray, MLXArray) {
    // Add singleton axes to the gate so it broadcasts against the state.
    let gate: MLXArray
    switch g.ndim {
    case 2:
        gate = expandedDimensions(g, axes: [2, 3])
    case 3:
        gate = expandedDimensions(g, axis: -2)
    default:
        fatalError("Unsupported gating shape \(g.shape)")
    }

    let keyRow = expandedDimensions(k, axis: -2)
    let queryRow = expandedDimensions(q, axis: -2)

    let decayed = state * gate
    let recalled = (decayed * keyRow).sum(axis: -1)
    let correction = (v - recalled) * expandedDimensions(beta, axis: -1)
    let updated = decayed + keyRow * expandedDimensions(correction, axis: -1)
    let output = (updated * queryRow).sum(axis: -1)

    guard let mask else {
        return (output, updated)
    }

    // Add trailing singleton axes to the mask so it selects over the state.
    let stateMask: MLXArray
    switch mask.ndim {
    case 1:
        stateMask = expandedDimensions(mask, axes: [1, 2, 3])
    case 2:
        stateMask = expandedDimensions(mask, axes: [2, 3])
    case 3:
        stateMask = expandedDimensions(mask, axis: -1)
    default:
        fatalError("Unsupported mask shape \(mask.shape)")
    }

    // Masked-out entries keep the state that was passed in.
    return (output, MLX.where(stateMask, updated, state))
}

/// Runs the gated delta recurrence step by step along axis 1 of the inputs.
///
/// Dimensions 0...3 of `q` are read as `B`, `T`, `Hk`, `Dk`, and dimensions
/// 2 and 3 of `v` as `Hv`, `Dv`. When `Hv / Hk` is greater than 1, `q` and `k`
/// are repeated along axis -2 by that factor. Each index `t` in `0 ..< T` is
/// passed to `gatedDeltaStepOps`, and the per-step outputs are stacked on
/// axis 1.
///
/// - Parameters:
///   - q: Queries; must have at least four dimensions.
///   - k: Keys, sliced and repeated the same way as `q`.
///   - v: Values; must have at least four dimensions.
///   - g: Gate, sliced at index `t` of axis 1 for each step.
///   - beta: Correction weight, sliced at index `t` of axis 1 for each step.
///   - state: Initial state; defaults to zeros of shape `[B, Hv, Dv, Dk]` in
///     the dtype of `q`.
///   - mask: Optional mask, sliced at index `t` of axis 1 for each step.
/// - Returns: The stacked per-step outputs and the final state.
private func gatedDeltaOps(
    q: MLXArray,
    k: MLXArray,
    v: MLXArray,
    g: MLXArray,
    beta: MLXArray,
    state: MLXArray? = nil,
    mask: MLXArray? = nil
) -> (MLXArray, MLXArray) {
    let B = q.dim(0)
    let T = q.dim(1)
    let Hk = q.dim(2)
    let Dk = q.dim(3)
    let Hv = v.dim(2)
    let Dv = v.dim(3)

    var q = q
    var k = k

    // Repeat q/k along the head axis when v has more heads than q.
    let repeatFactor = Hv / Hk
    if repeatFactor > 1 {
        q = repeated(q, count: repeatFactor, axis: -2)
        k = repeated(k, count: repeatFactor, axis: -2)
    }

    var state = state ?? MLXArray.zeros([B, Hv, Dv, Dk], dtype: q.dtype)

    var ys = [MLXArray]()
    ys.reserveCapacity(T)

    for t in 0 ..< T {
        let (y, newState) = gatedDeltaStepOps(
            q: q[0..., t],
            k: k[0..., t],
            v: v[0..., t],
            g: g[0..., t],
            beta: beta[0..., t],
            state: state,
            // Optional chaining yields nil when there is no mask, without a force-unwrap.
            mask: mask?[0..., t]
        )
        ys.append(y)
        state = newState
    }

    let y = MLX.stacked(ys, axis: 1)
    return (y, state)
}

/// Derives the gate and correction weight from raw inputs, then runs the
/// gated delta recurrence.
///
/// `beta` is `sigmoid(b)` and the gate is
/// `computeGatedDeltaG(aLog, a, dtBias)`. When no `state` is supplied, a zero
/// state of shape `[q.dim(0), v.dim(2), v.dim(3), q.dim(3)]` in the dtype of
/// `q` is used.
///
/// - Parameters:
///   - q: Queries; must have at least four dimensions.
///   - k: Keys.
///   - v: Values; must have at least four dimensions.
///   - a: Input to the gate computation.
///   - b: Input passed through a sigmoid to obtain `beta`.
///   - aLog: Passed to `computeGatedDeltaG` together with `a` and `dtBias`.
///   - dtBias: Bias passed to `computeGatedDeltaG`.
///   - state: Optional initial state.
///   - mask: Optional mask forwarded to `gatedDeltaOps`.
/// - Returns: The tuple returned by `gatedDeltaOps`.
private func gatedDeltaUpdate(
    q: MLXArray,
    k: MLXArray,
    v: MLXArray,
    a: MLXArray,
    b: MLXArray,
    aLog: MLXArray,
    dtBias: MLXArray,
    state: MLXArray? = nil,
    mask: MLXArray? = nil
) -> (MLXArray, MLXArray) {
    let correctionWeight = sigmoid(b)
    let gate = computeGatedDeltaG(aLog, a, dtBias)

    // Read the dimensions up front, whether or not a state was supplied.
    let batch = q.dim(0)
    let keyDim = q.dim(3)
    let valueHeads = v.dim(2)
    let valueDim = v.dim(3)

    let initialState: MLXArray
    if let state {
        initialState = state
    } else {
        initialState = MLXArray.zeros([batch, valueHeads, valueDim, keyDim], dtype: q.dtype)
    }

    return gatedDeltaOps(
        q: q,
        k: k,
        v: v,
        g: gate,
        beta: correctionWeight,
        state: initialState,
        mask: mask
    )
}

// MARK: - Configuration

public struct Qwen35Configuration: Codable, Sendable {
Expand Down
3 changes: 1 addition & 2 deletions Tests/MLXLMTests/GatedDeltaTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import Foundation
import MLX
import MLXLMCommon
import XCTest

@testable import MLXLLM

public class GatedDeltaTests: XCTestCase {

private struct Inputs {
Expand Down