diff --git a/CMakeLists.txt b/CMakeLists.txt index d018a982..9f50a973 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,6 +64,12 @@ if(NOT MLX_BUILD_METAL) ${CMAKE_CURRENT_LIST_DIR}/Source/MLX/MLXArray+Metal.swift) endif() +if(MLX_BUILD_METAL OR MLX_BUILD_CUDA) + list(REMOVE_ITEM MLX-src ${CMAKE_CURRENT_LIST_DIR}/Source/MLX/MLXFast+CPU.swift) +else() + list(REMOVE_ITEM MLX-src ${CMAKE_CURRENT_LIST_DIR}/Source/MLX/MLXFast+GPU.swift) +endif() + add_library(MLX STATIC ${MLX-src}) target_include_directories(MLX PUBLIC ${CMAKE_CURRENT_LIST_DIR}/Source/Cmlx/include) diff --git a/Package.swift b/Package.swift index 99cf6fba..3968f2fd 100644 --- a/Package.swift +++ b/Package.swift @@ -68,7 +68,7 @@ import PackageDescription let mlxSwiftExcludes: [String] = [ "GPU+Metal.swift", "MLXArray+Metal.swift", - "MLXFast.swift", + "MLXFast+GPU.swift", "MLXFastKernel.swift", ] #else @@ -101,7 +101,9 @@ import PackageDescription .linkedFramework("Accelerate"), ] - let mlxSwiftExcludes: [String] = [] + let mlxSwiftExcludes: [String] = [ + "MLXFast+CPU.swift" + ] #endif let cmlx = Target.target( diff --git a/Source/MLX/MLXFast+CPU.swift b/Source/MLX/MLXFast+CPU.swift new file mode 100644 index 00000000..c64e3091 --- /dev/null +++ b/Source/MLX/MLXFast+CPU.swift @@ -0,0 +1,230 @@ +// Copyright © 2024 Apple Inc. + +import Cmlx + +extension MLXFast { + + /// Core RoPE implementation using pure MLX operations. + /// Matches the C++ fallback in fast.cpp (lines 417-501). + private static func _ropeImpl( + _ x: MLXArray, + dimensions: Int, + traditional: Bool, + base: Float, + scale: Float, + offset: MLXArray, + freqs: MLXArray? 
+ ) -> MLXArray { + let shape = x.shape + var x = x + + // Reshape to 4D [B, N, T, D] + if x.ndim == 3 { + x = x.expandedDimensions(axis: 1) + } else if x.ndim > 4 { + x = x.flattened(start: 1, end: 1 + (x.ndim - 4)) + } + + let B = x.dim(0) + let N = x.dim(1) + let T = x.dim(2) + let t = x.dtype + let halfDims = dimensions / 2 + + // Expand batch offsets [B] -> [B, 1, 1] for broadcasting + var off = offset + if off.size > 1 { + off = off.expandedDimensions(axes: [-1, -2]) + } + + // positions = (arange(T) + offset) * scale + let positions = (arange(T, dtype: .float32) + off) * MLXArray(scale) + + // Compute inverse frequencies + let invFreqs: MLXArray + if let freqs { + invFreqs = reciprocal(freqs) + } else { + // inv_freqs = exp(arange(0, -halfDims, -1) * log(base) / halfDims) + // = [base^0, base^(-1/halfDims), base^(-2/halfDims), ...] + let logBasePerHalfDim = log(MLXArray(base)) / MLXArray(Float(halfDims)) + invFreqs = exp( + arange(0.0, Double(-halfDims), step: -1.0, dtype: .float32) * logBasePerHalfDim + ) + } + + // theta: [T, halfDims] or [B, 1, T, halfDims] + let theta = positions.expandedDimensions(axis: -1) * invFreqs + let coss = cos(theta).asType(t) + let sins = sin(theta).asType(t) + + if traditional { + // Traditional: rotate consecutive pairs (even/odd interleaved) + let x1 = x[.ellipsis, .stride(from: 0, to: dimensions, by: 2)] + let x2 = x[.ellipsis, .stride(from: 1, to: dimensions, by: 2)] + let out1 = (x1 * coss - x2 * sins).expandedDimensions(axis: -1) + let out2 = (x1 * sins + x2 * coss).expandedDimensions(axis: -1) + // Interleave back: [.., halfDims, 2] -> reshape [.., dims] + var out = concatenated([out1, out2], axis: -1).reshaped(B, N, T, dimensions) + if dimensions < x.dim(-1) { + out = concatenated([out, x[.ellipsis, dimensions...]], axis: -1) + } + return out.reshaped(shape) + } else { + // Modern: split at halfDims boundary (more efficient) + let x1 = x[.ellipsis, .. 
MLXArray { + _ropeImpl( + x, dimensions: dimensions, traditional: traditional, + base: base ?? 10000.0, scale: scale, + offset: MLXArray(Int32(offset)), freqs: freqs) + } + + public static func RoPE( + _ x: MLXArray, + dimensions: Int, + traditional: Bool, + base: Float?, + scale: Float, + offset: MLXArray, + freqs: MLXArray? = nil, + stream: StreamOrDevice = .default + ) -> MLXArray { + _ropeImpl( + x, dimensions: dimensions, traditional: traditional, + base: base ?? 10000.0, scale: scale, + offset: offset, freqs: freqs) + } + + // Fallback rmsNorm: weight * x * rsqrt(mean(x^2) + eps) over the last axis; statistics are accumulated in float32 so half-precision inputs do not overflow + public static func rmsNorm( + _ x: MLXArray, weight: MLXArray, eps: Float, stream: StreamOrDevice = .default + ) -> MLXArray { + let x32 = x.asType(.float32) + let meanSquare = mean(x32 * x32, axis: -1, keepDims: true) + return weight * (x32 * rsqrt(meanSquare + eps)).asType(x.dtype) + } + + // Fallback layerNorm over the last axis; mean/variance are computed in float32 and the normalized result is cast back to x's dtype before weight/bias are applied + public static func layerNorm( + _ x: MLXArray, weight: MLXArray? = nil, bias: MLXArray? = nil, eps: Float, + stream: StreamOrDevice = .default + ) -> MLXArray { + let x32 = x.asType(.float32) + let centered = x32 - MLX.mean(x32, axis: -1, keepDims: true) + var normalized = (centered * rsqrt(MLX.variance(x32, axis: -1, keepDims: true) + eps)).asType(x.dtype) + if let weight { + normalized = normalized * weight + } + if let bias { + normalized = normalized + bias + } + return normalized + } + + // Fallback scaledDotProductAttention implementation + public static func scaledDotProductAttention( + queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, + mask: MLXArray?, + sinks: MLXArray? = nil, + memoryEfficientThreshold: Int? = nil, + stream: StreamOrDevice = .default + ) -> MLXArray { + Self.scaledDotProductAttention( + queries: queries, keys: keys, values: values, scale: scale, + mask: mask.map { .array($0) } ?? 
.none, + sinks: sinks, memoryEfficientThreshold: memoryEfficientThreshold, stream: stream + ) + } + + public static func scaledDotProductAttention( + queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, + mask: ScaledDotProductAttentionMaskMode, + sinks: MLXArray? = nil, + memoryEfficientThreshold: Int? = nil, stream: StreamOrDevice = .default + ) -> MLXArray { + // Handle GQA (Grouped Query Attention) where nHeads > nKVHeads + let nHeads = queries.dim(1) + let nKVHeads = keys.dim(1) + + var expandedKeys = keys + var expandedValues = values + + if nHeads != nKVHeads { + // Repeat KV heads to match query heads + // e.g., if nHeads=32, nKVHeads=8, each KV head is repeated 4 times + let repeats = nHeads / nKVHeads + let B = keys.dim(0) + let L = keys.dim(2) + let D = keys.dim(3) + + // Expand and repeat: [B, nKVHeads, L, D] -> [B, nHeads, L, D] + // Use repeated() to tile along the new axis; values infer their own head dim (-1) so it may differ from the key head dim + expandedKeys = repeated( + keys.reshaped(B, nKVHeads, 1, L, D), + count: repeats, + axis: 2 + ).reshaped(B, nHeads, L, D) + expandedValues = repeated( + values.reshaped(B, nKVHeads, 1, L, -1), + count: repeats, + axis: 2 + ).reshaped(B, nHeads, L, -1) + } + + var scores = (queries * scale).matmul(expandedKeys.transposed(0, 1, 3, 2)) + + switch mask { + case .none: + break + case .causal: + let L = queries.dim(2) + let S = keys.dim(2) + let indices_q = MLXArray(0 ..< L) + let indices_k = MLXArray(0 ..< S) + let causalMask = + indices_q.expandedDimensions(axis: 1) .>= (indices_k - MLXArray(S - L)) + let maskValues = MLXArray(Float(-1e9)) + scores = MLX.where(causalMask, scores, maskValues) + case .array(let maskArray): + if maskArray.dtype == .bool { + let maskValues = MLXArray(Float(-1e9)) + scores = MLX.where(maskArray, scores, maskValues) + } else { + scores = scores + maskArray + } + case .arrays(let maskArrays): + if let maskArray = maskArrays.first { + if maskArray.dtype == .bool { + let maskValues = MLXArray(Float(-1e9)) + 
scores = MLX.where(maskArray, scores, maskValues) + } else { + scores = scores + maskArray + } + } + } + + scores = softmax(scores.asType(.float32), axis: -1).asType(queries.dtype) // the float32 mask fill promotes `scores`; cast back to the query dtype so half-precision inputs keep their dtype + return matmul(scores, expandedValues) + } +} diff --git a/Source/MLX/MLXFast+GPU.swift b/Source/MLX/MLXFast+GPU.swift new file mode 100644 index 00000000..436f8632 --- /dev/null +++ b/Source/MLX/MLXFast+GPU.swift @@ -0,0 +1,231 @@ +// Copyright © 2024 Apple Inc. + +import Cmlx + +extension MLXFast { + /// Optimized implementation of `NN.RoPE`. + /// + /// Used like this: + /// + /// ```swift + /// let x: MLXArray + /// let dimensions: Int + /// let traditional: Bool + /// let base: Float + /// let scale: Float + /// let offset: Int + /// + /// let shape = x.shape + /// var x = x.reshaped(-1, x.dim(-2), x.dim(-1)) + /// x = MLXFast.RoPE(x, dimensions: dimensions, traditional: traditional, base: base, scale: scale, offset: offset) + /// return x.reshaped(shape) + /// ``` + /// + /// > Note: `MLXNN.RoPE` uses this implementation internally. + public static func RoPE( + _ array: MLXArray, dimensions: Int, traditional: Bool, base: Float?, scale: Float, + offset: Int, + freqs: MLXArray? = nil, stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + let base = mlx_optional_float(value: base ?? 0, has_value: base != nil) + mlx_fast_rope( + &result, + array.ctx, Int32(dimensions), traditional, base, scale, Int32(offset), + (freqs ?? .mlxNone).ctx, stream.ctx) + return MLXArray(result) + } + + /// Optimized implementation of `NN.RoPE` with array offset for batched inference. + /// + /// This overload accepts an array offset, allowing different position offsets for each + /// sequence in a batch. The offset can be a scalar array or a vector with length + /// matching the batch size. + /// + /// - Parameters: + /// - array: input array + /// - dimensions: The feature dimensions to be rotated. 
If the input feature is larger + /// than dims then the rest is left unchanged. + /// - traditional: If `true` choose the traditional implementation which is slightly less efficient. + /// - base: The base used to compute angular frequency for each dimension in the positional encodings. + /// - scale: The scale used to scale the positions. + /// - offset: The position offset as an array. Can be a scalar or a vector of offsets for each batch element. + /// - freqs: Optional frequencies to use with RoPE. + /// - stream: stream or device to evaluate on + /// - Returns: The input with rotary positional encoding applied. + public static func RoPE( + _ array: MLXArray, + dimensions: Int, + traditional: Bool, + base: Float?, + scale: Float, + offset: MLXArray, + freqs: MLXArray? = nil, + stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + let base = mlx_optional_float(value: base ?? 0, has_value: base != nil) + let offset = offset + mlx_fast_rope_dynamic( + &result, + array.ctx, Int32(dimensions), traditional, base, scale, offset.ctx, + (freqs ?? .mlxNone).ctx, stream.ctx) + return MLXArray(result) + } + + /// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V` + /// + /// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150). + /// + /// This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations. + /// + /// > Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32). + /// + /// > Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array. 
+ /// + /// Specifically this implements: + /// + /// ```swift + /// var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2)) + /// if let mask { + /// scores = scores + mask + /// } + /// + /// scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype) + /// + /// return matmul(scores, values).transposed(0, 2, 1, 3) + /// ``` + /// + /// In the following the dimensions are given by: + /// + /// * `B`: The batch size. + /// * `N_q`: The number of query heads. + /// * `N_kv`: The number of key and value heads. + /// * `T_q`: The number of queries per example. + /// * `T_kv`: The number of keys and values per example. + /// * `D`: The per-head dimension. + /// + /// - Parameters: + /// - queries: queries with shape `[B, N_q, T_q, D]` + /// - keys: keys with shape `[B, N_kv, T_kv, D]` + /// - values: values with shape `[B, N_kv, T_kv, D]` + /// - scale: scale for queries, typically `1 / sqrt(q.dim(-1))` + /// - mask: mask array + /// - sinks: optional array of attention sinks + /// - memoryEfficientThreshold: unused + /// - stream: stream to evaluate on + public static func scaledDotProductAttention( + queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, mask: MLXArray?, + sinks: MLXArray? = nil, + memoryEfficientThreshold: Int? = nil, stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + + mlx_fast_scaled_dot_product_attention( + &result, + queries.ctx, keys.ctx, values.ctx, scale, + "", mask?.ctx ?? MLXArray.mlxNone.ctx, + (sinks ?? .mlxNone).ctx, + stream.ctx) + return MLXArray(result) + } + + /// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V` + /// + /// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150). + /// + /// This function will dispatch to an optimized Metal kernel when the query sequence length is 1. 
It handles other cases with regular MLX operations. + /// + /// > Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32). + /// + /// > Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array. + /// + /// Specifically this implements: + /// + /// ```swift + /// var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2)) + /// if let mask { + /// scores = scores + mask + /// } + /// + /// scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype) + /// + /// return matmul(scores, values).transposed(0, 2, 1, 3) + /// ``` + /// + /// In the following the dimensions are given by: + /// + /// * `B`: The batch size. + /// * `N_q`: The number of query heads. + /// * `N_kv`: The number of key and value heads. + /// * `T_q`: The number of queries per example. + /// * `T_kv`: The number of keys and values per example. + /// * `D`: The per-head dimension. + /// + /// - Parameters: + /// - queries: queries with shape `[B, N_q, T_q, D]` + /// - keys: keys with shape `[B, N_kv, T_kv, D]` + /// - values: values with shape `[B, N_kv, T_kv, D]` + /// - scale: scale for queries, typically `1 / sqrt(q.dim(-1))` + /// - mask: a ``ScaledDotProductAttentionMaskMode`` + /// - sinks: optional array of attention sinks + /// - stream: stream to evaluate on + public static func scaledDotProductAttention( + queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, + mask: ScaledDotProductAttentionMaskMode, + sinks: MLXArray? = nil, + stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + + mlx_fast_scaled_dot_product_attention( + &result, + queries.ctx, keys.ctx, values.ctx, scale, + mask.mode, mask.mask?.ctx ?? MLXArray.mlxNone.ctx, + (sinks ?? .mlxNone).ctx, + stream.ctx) + return MLXArray(result) + } + + /// Root Mean Square normalization (RMS norm). 
+ /// + /// The normalization is with respect to the last axis of the input `x`. + /// + /// - Parameters: + /// - x: input array + /// - weight: A multiplicative weight to scale the result by. The `weight` should be one-dimensional + /// with the same size as the last axis of `x`. + /// - eps: A small additive constant for numerical stability + /// - stream: stream or device to evaluate on + public static func rmsNorm( + _ x: MLXArray, weight: MLXArray, eps: Float, stream: StreamOrDevice = .default + ) + -> MLXArray + { + var result = mlx_array_new() + mlx_fast_rms_norm(&result, x.ctx, weight.ctx, eps, stream.ctx) + return MLXArray(result) + } + + /// Layer normalization. + /// + /// The normalization is with respect to the last axis of the input `x`. + /// + /// - Parameters: + /// - x: input array + /// - weight: A multiplicative weight to scale the result by. The `weight` should be one-dimensional + /// with the same size as the last axis of `x`. If not given, no scaling will occur. + /// - bias: An additive offset to be added to the result. The `bias` should be one-dimensional + /// with the same size as the last axis of `x`. If not given, no offset will occur. + /// - eps: A small additive constant for numerical stability + /// - stream: stream or device to evaluate on + public static func layerNorm( + _ x: MLXArray, weight: MLXArray? = nil, bias: MLXArray? = nil, eps: Float, + stream: StreamOrDevice = .default + ) -> MLXArray { + var result = mlx_array_new() + mlx_fast_layer_norm( + &result, x.ctx, (weight ?? .mlxNone).ctx, (bias ?? .mlxNone).ctx, eps, stream.ctx) + return MLXArray(result) + } +} diff --git a/Source/MLX/MLXFast.swift b/Source/MLX/MLXFast.swift index 92c96da8..d56d93d6 100644 --- a/Source/MLX/MLXFast.swift +++ b/Source/MLX/MLXFast.swift @@ -3,134 +3,6 @@ import Cmlx public enum MLXFast { - - /// Optimized implementation of `NN.RoPE`. 
- /// - /// Used like this: - /// - /// ```swift - /// let x: MLXArray - /// let dimensions: Int - /// let traditional: Bool - /// let base: Float - /// let scale: Float - /// let offset: Int - /// - /// let shape = x.shape - /// var x = x.reshaped(-1, x.dim(-2), x.dim(-1)) - /// x = MLXFast.RoPE(x, dimensions: dimensions, traditional: traditional, base: base, scale: scale, offset: offset) - /// return x.reshaped(shape) - /// ``` - /// - /// > Note: `MLXNN.RoPE` uses this implementation internally. - public static func RoPE( - _ array: MLXArray, dimensions: Int, traditional: Bool, base: Float?, scale: Float, - offset: Int, - freqs: MLXArray? = nil, stream: StreamOrDevice = .default - ) -> MLXArray { - var result = mlx_array_new() - let base = mlx_optional_float(value: base ?? 0, has_value: base != nil) - mlx_fast_rope( - &result, - array.ctx, Int32(dimensions), traditional, base, scale, Int32(offset), - (freqs ?? .mlxNone).ctx, stream.ctx) - return MLXArray(result) - } - - /// Optimized implementation of `NN.RoPE` with array offset for batched inference. - /// - /// This overload accepts an array offset, allowing different position offsets for each - /// sequence in a batch. The offset can be a scalar array or a vector with length - /// matching the batch size. - /// - /// - Parameters: - /// - array: input array - /// - dimensions: The feature dimensions to be rotated. If the input feature is larger - /// than dims then the rest is left unchanged. - /// - traditional: If `true` choose the traditional implementation which is slightly less efficient. - /// - base: The base used to compute angular frequency for each dimension in the positional encodings. - /// - scale: The scale used to scale the positions. - /// - offset: The position offset as an array. Can be a scalar or a vector of offsets for each batch element. - /// - freqs: Optional frequencies to use with RoPE. 
- /// - stream: stream or device to evaluate on - /// - Returns: The input with rotary positional encoding applied. - public static func RoPE( - _ array: MLXArray, - dimensions: Int, - traditional: Bool, - base: Float?, - scale: Float, - offset: MLXArray, - freqs: MLXArray? = nil, - stream: StreamOrDevice = .default - ) -> MLXArray { - var result = mlx_array_new() - let base = mlx_optional_float(value: base ?? 0, has_value: base != nil) - let offset = offset - mlx_fast_rope_dynamic( - &result, - array.ctx, Int32(dimensions), traditional, base, scale, offset.ctx, - (freqs ?? .mlxNone).ctx, stream.ctx) - return MLXArray(result) - } - - /// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V` - /// - /// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150). - /// - /// This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations. - /// - /// > Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32). - /// - /// > Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array. - /// - /// Specifically this implements: - /// - /// ```swift - /// var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2)) - /// if let mask { - /// scores = scores + mask - /// } - /// - /// scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype) - /// - /// return matmul(scores, values).transposed(0, 2, 1, 3) - /// ``` - /// - /// In the following the dimensions are given by: - /// - /// * `B`: The batch size. - /// * `N_q`: The number of query heads. - /// * `N_kv`: The number of key and value heads. - /// * `T_q`: The number of queries per example. 
- /// * `T_kv`: The number of keys and values per example. - /// * `D`: The per-head dimension. - /// - /// - Parameters: - /// - queries: queries with shape `[B, N_q, T_q, D]` - /// - keys: keys with shape `[B, N_kv, T_kv, D]` - /// - values: values with shape `[B, N_kv, T_kv, D]` - /// - scale: scale for queries, typically `1 / sqrt(q.dim(-1))` - /// - mask: mask array - /// - sinks: optional array of attention sinks - /// - memoryEfficientThreshold: unused - /// - stream: stream to evaluate on - public static func scaledDotProductAttention( - queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, mask: MLXArray?, - sinks: MLXArray? = nil, - memoryEfficientThreshold: Int? = nil, stream: StreamOrDevice = .default - ) -> MLXArray { - var result = mlx_array_new() - - mlx_fast_scaled_dot_product_attention( - &result, - queries.ctx, keys.ctx, values.ctx, scale, - "", mask?.ctx ?? MLXArray.mlxNone.ctx, - (sinks ?? .mlxNone).ctx, - stream.ctx) - return MLXArray(result) - } - public enum ScaledDotProductAttentionMaskMode { case none case array(MLXArray) @@ -159,106 +31,6 @@ public enum MLXFast { } } } - - /// A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V` - /// - /// Supports [Multi-Head Attention](https://arxiv.org/abs/1706.03762), [Grouped Query Attention](https://arxiv.org/abs/2305.13245), and [Multi-Query Attention](https://arxiv.org/abs/1911.02150). - /// - /// This function will dispatch to an optimized Metal kernel when the query sequence length is 1. It handles other cases with regular MLX operations. - /// - /// > Note: The softmax operation is performed in float32 precision regardless of input precision (float16 or float32). - /// - /// > Note: For Grouped Query Attention and Multi-Query Attention, the input arrays for `key` and `value` should not be pre-tiled to match the `query` array. 
- /// - /// Specifically this implements: - /// - /// ```swift - /// var scores = (queries * self.scale).matmul(keys.transposed(0, 1, 3, 2)) - /// if let mask { - /// scores = scores + mask - /// } - /// - /// scores = softMax(scores.asType(.float32), axis: -1).asType(scores.dtype) - /// - /// return matmul(scores, values).transposed(0, 2, 1, 3) - /// ``` - /// - /// In the following the dimensions are given by: - /// - /// * `B`: The batch size. - /// * `N_q`: The number of query heads. - /// * `N_kv`: The number of key and value heads. - /// * `T_q`: The number of queries per example. - /// * `T_kv`: The number of keys and values per example. - /// * `D`: The per-head dimension. - /// - /// - Parameters: - /// - queries: queries with shape `[B, N_q, T_q, D]` - /// - keys: keys with shape `[B, N_kv, T_kv, D]` - /// - values: values with shape `[B, N_kv, T_kv, D]` - /// - scale: scale for queries, typically `1 / sqrt(q.dim(-1))` - /// - mask: a ``ScaledDotProductAttentionMaskMode`` - /// - sinks: optional array of attention sinks - /// - stream: stream to evaluate on - public static func scaledDotProductAttention( - queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, - mask: ScaledDotProductAttentionMaskMode, - sinks: MLXArray? = nil, - stream: StreamOrDevice = .default - ) -> MLXArray { - var result = mlx_array_new() - - mlx_fast_scaled_dot_product_attention( - &result, - queries.ctx, keys.ctx, values.ctx, scale, - mask.mode, mask.mask?.ctx ?? MLXArray.mlxNone.ctx, - (sinks ?? .mlxNone).ctx, - stream.ctx) - return MLXArray(result) - } - - /// Root Mean Square normalization (RMS norm). - /// - /// The normalization is with respect to the last axis of the input `x`. - /// - /// - Parameters: - /// - x: input array - /// - weight: A multiplicative weight to scale the result by. The `weight` should be one-dimensional - /// with the same size as the last axis of `x`. 
- /// - eps: A small additive constant for numerical stability - /// - stream: stream or device to evaluate on - public static func rmsNorm( - _ x: MLXArray, weight: MLXArray, eps: Float, stream: StreamOrDevice = .default - ) - -> MLXArray - { - var result = mlx_array_new() - mlx_fast_rms_norm(&result, x.ctx, weight.ctx, eps, stream.ctx) - return MLXArray(result) - } - - /// Layer normalization. - /// - /// The normalization is with respect to the last axis of the input `x`. - /// - /// - Parameters: - /// - x: input array - /// - weight: A multiplicative weight to scale the result by. The `weight` should be one-dimensional - /// with the same size as the last axis of `x`. If not given no scaling will occur. - /// - bias: An additive offset to be added to the result. The `bias` should be one-dimensional - /// with the same size as the last axis of `x`. It not given no offset will occur. - /// - eps: A small additive constant for numerical stability - /// - stream: stream or device to evaluate on - public static func layerNorm( - _ x: MLXArray, weight: MLXArray? = nil, bias: MLXArray? = nil, eps: Float, - stream: StreamOrDevice = .default - ) -> MLXArray { - var result = mlx_array_new() - mlx_fast_layer_norm( - &result, x.ctx, (weight ?? .mlxNone).ctx, (bias ?? .mlxNone).ctx, eps, stream.ctx) - return MLXArray(result) - } - } /// Optimized implementation of `NN.RoPE`. diff --git a/Source/MLXFast/MLXFastKernel.swift b/Source/MLXFast/MLXFastKernel.swift index cb95bf10..09070074 100644 --- a/Source/MLXFast/MLXFastKernel.swift +++ b/Source/MLXFast/MLXFastKernel.swift @@ -1,57 +1,59 @@ -// Copyright © 2024 Apple Inc. +#if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) || os(visionOS) + // Copyright © 2024 Apple Inc. 
-import Cmlx -import MLX + import Cmlx + import MLX -/// Container for a kernel created by -/// ``metalKernel(name:inputNames:outputNames:source:header:ensureRowContiguous:atomicOutputs:template:grid:threadGroup:outputShapes:outputDTypes:initValue:verbose:)`` -/// -/// The ``MLXFast/MLXFastKernel`` can be used to evaluate the kernel with inputs: -/// -/// ```swift -/// let a = normal([2, 2]) -/// let kernel = MLXFast.metalKernel( -/// name: "basic", -/// inputNames: ["a"], -/// outputNames: ["out1"], -/// source: """ -/// uint elem = thread_position_in_grid.x; -/// out1[elem] = a[elem]; -/// """, -/// grid: (4, 1, 1), -/// threadGroup: (2, 1, 1), -/// outputShapes: [[2, 2]], -/// outputDTypes: [.float32]) -/// -/// let out = kernel([a]) -/// ``` -@available(*, deprecated, renamed: "MLXFast.MLXFastKernel") -public typealias MLXFastKernel = MLXFast.MLXFastKernel + /// Container for a kernel created by + /// ``metalKernel(name:inputNames:outputNames:source:header:ensureRowContiguous:atomicOutputs:template:grid:threadGroup:outputShapes:outputDTypes:initValue:verbose:)`` + /// + /// The ``MLXFast/MLXFastKernel`` can be used to evaluate the kernel with inputs: + /// + /// ```swift + /// let a = normal([2, 2]) + /// let kernel = MLXFast.metalKernel( + /// name: "basic", + /// inputNames: ["a"], + /// outputNames: ["out1"], + /// source: """ + /// uint elem = thread_position_in_grid.x; + /// out1[elem] = a[elem]; + /// """, + /// grid: (4, 1, 1), + /// threadGroup: (2, 1, 1), + /// outputShapes: [[2, 2]], + /// outputDTypes: [.float32]) + /// + /// let out = kernel([a]) + /// ``` + @available(*, deprecated, renamed: "MLXFast.MLXFastKernel") + public typealias MLXFastKernel = MLXFast.MLXFastKernel -/// A jit-compiled custom Metal kernel defined from a source string. 
-/// -/// - Parameters: -/// - name: name for the kernel -/// - inputNames: parameter names of the inputs in the function signature -/// - outputNames: parameter names of the outputs in the function signature -/// - source: source code -- this is the body of a function in Metal, -/// the function signature will be automatically generated. -/// - header: header source code to include before the main function. Useful -/// for helper functions or includes that should live outside of the main function body. -/// - ensureRowContiguous: whether to ensure the inputs are row contiguous -/// before the kernel runs (at a performance cost) -/// - atomicOutputs: whether to use atomic outputs in the function signature, -/// e.g. `device atomic` -/// - Returns: an ``MLXFastKernel`` -- see that for information on how to call it -public func metalKernel( - name: String, inputNames: [String], outputNames: [String], - source: String, header: String = "", - ensureRowContiguous: Bool = true, - atomicOutputs: Bool = false -) -> MLXFast.MLXFastKernel { - return MLX.MLXFast.metalKernel( - name: name, inputNames: inputNames, outputNames: outputNames, - source: source, header: header, - ensureRowContiguous: ensureRowContiguous, atomicOutputs: atomicOutputs - ) -} + /// A jit-compiled custom Metal kernel defined from a source string. + /// + /// - Parameters: + /// - name: name for the kernel + /// - inputNames: parameter names of the inputs in the function signature + /// - outputNames: parameter names of the outputs in the function signature + /// - source: source code -- this is the body of a function in Metal, + /// the function signature will be automatically generated. + /// - header: header source code to include before the main function. Useful + /// for helper functions or includes that should live outside of the main function body. 
+ /// - ensureRowContiguous: whether to ensure the inputs are row contiguous + /// before the kernel runs (at a performance cost) + /// - atomicOutputs: whether to use atomic outputs in the function signature, + /// e.g. `device atomic` + /// - Returns: an ``MLXFastKernel`` -- see that for information on how to call it + public func metalKernel( + name: String, inputNames: [String], outputNames: [String], + source: String, header: String = "", + ensureRowContiguous: Bool = true, + atomicOutputs: Bool = false + ) -> MLXFast.MLXFastKernel { + return MLX.MLXFast.metalKernel( + name: name, inputNames: inputNames, outputNames: outputNames, + source: source, header: header, + ensureRowContiguous: ensureRowContiguous, atomicOutputs: atomicOutputs + ) + } +#endif diff --git a/Source/MLXNN/Module.swift b/Source/MLXNN/Module.swift index ae0c0491..76b4d019 100644 --- a/Source/MLXNN/Module.swift +++ b/Source/MLXNN/Module.swift @@ -1395,10 +1395,12 @@ public enum ModuleValue { get { // note: this gives a warning but it does in fact do something // in the case where this is e.g. ParameterInfo - if let value = value as? T { + if let value = value as? T { // keep the cast: when T is itself Optional, an unset value must read back as nil instead of trapping return value } else { - return value! + preconditionFailure( + "`value` should have been set in init -- this is a bug in the property wrapper implementation" + ) } } set {