diff --git a/Source/Examples/Tutorial.swift b/Source/Examples/Tutorial.swift index be00f83e..7cdbfa5b 100644 --- a/Source/Examples/Tutorial.swift +++ b/Source/Examples/Tutorial.swift @@ -53,7 +53,7 @@ struct Tutorial { print(x[1]) // make an array of shape [2, 2] filled with ones - let y = MLXArray.ones([2, 2]) + let y = MLXArray.ones([2, 2], dtype: .float32) // pointwise add x and y let z = x + y diff --git a/Source/MLX/Factory.swift b/Source/MLX/Factory.swift index 91354051..ccac35f4 100644 --- a/Source/MLX/Factory.swift +++ b/Source/MLX/Factory.swift @@ -23,12 +23,29 @@ extension MLXArray { /// - ``zeros(like:stream:)`` /// - ``ones(_:type:stream:)`` static public func zeros( - _ shape: some Collection, type: (some HasDType).Type = Float.self, + _ shape: some Collection, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { MLX.zeros(shape, type: type, stream: stream) } + /// Construct an array of zeros, defaulting to `Float` (float32). + /// + /// > Deprecated: pass the dtype explicitly via ``zeros(_:dtype:stream:)`` + /// or ``zeros(_:type:stream:)``. Allowing the dtype to default silently + /// promotes computations to float32, which can mask precision issues in + /// half-precision (bfloat16 / float16) pipelines. See + /// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390). + @available( + *, deprecated, + message: "Pass dtype explicitly: zeros(shape, dtype: .float32) or zeros(shape, type: Float.self). See ml-explore/mlx-swift#390." + ) + static public func zeros( + _ shape: some Collection, stream: StreamOrDevice = .default + ) -> MLXArray { + MLX.zeros(shape, dtype: .float32, stream: stream) + } + /// Construct an array of zeros with a given ``DType`` /// /// Example: @@ -91,12 +108,27 @@ extension MLXArray { /// - ``ones(like:stream:)`` /// - ``zeros(_:type:stream:)`` static public func ones( - _ shape: some Collection, type: (some HasDType).Type = Float.self, + _ shape: some Collection, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { MLX.ones(shape, type: type, stream: stream) } + /// Construct an array of ones, defaulting to `Float` (float32). + /// + /// > Deprecated: pass the dtype explicitly via ``ones(_:dtype:stream:)`` + /// or ``ones(_:type:stream:)``. See + /// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390). + @available( + *, deprecated, + message: "Pass dtype explicitly: ones(shape, dtype: .float32) or ones(shape, type: Float.self). See ml-explore/mlx-swift#390." + ) + static public func ones( + _ shape: some Collection, stream: StreamOrDevice = .default + ) -> MLXArray { + MLX.ones(shape, dtype: .float32, stream: stream) + } + /// Construct an array of ones with a given ``DType`` /// /// Example: @@ -147,7 +179,7 @@ extension MLXArray { /// /// ```swift /// // create [10, 10] array with 1's on the diagonal. - /// let r = MLXArray.eye(10) + /// let r = MLXArray.eye(10, type: Int.self) /// ``` /// /// - Parameters: @@ -161,12 +193,27 @@ extension MLXArray { /// - /// - ``identity(_:type:stream:)`` static public func eye( - _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type = Float.self, + _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { MLX.eye(n, m: m, k: k, type: type, stream: stream) } + /// Create an identity matrix or a general diagonal matrix, defaulting to `Float` (float32). + /// + /// > Deprecated: pass the dtype explicitly via ``eye(_:m:k:dtype:stream:)`` + /// or ``eye(_:m:k:type:stream:)``. See + /// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390). + @available( + *, deprecated, + message: "Pass dtype explicitly: eye(n, m:, k:, dtype: .float32) or eye(n, m:, k:, type: Float.self). See ml-explore/mlx-swift#390." + ) + static public func eye( + _ n: Int, m: Int? = nil, k: Int = 0, stream: StreamOrDevice = .default + ) -> MLXArray { + MLX.eye(n, m: m, k: k, dtype: .float32, stream: stream) + } + /// Create an identity matrix or a general diagonal matrix given a ``DType``. /// /// Example: @@ -215,7 +262,7 @@ extension MLXArray { /// - ``full(_:values:stream:)`` /// - ``repeated(_:count:axis:stream:)`` static public func full( - _ shape: some Collection, values: MLXArray, type: (some HasDType).Type = Float.self, + _ shape: some Collection, values: MLXArray, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { MLX.full(shape, values: values, type: type, stream: stream) @@ -285,7 +332,7 @@ extension MLXArray { /// /// ```swift /// // create [10, 10] array with 1's on the diagonal. - /// let r = MLXArray.identity(10) + /// let r = MLXArray.identity(10, type: Int.self) /// ``` /// /// - Parameters: @@ -297,11 +344,24 @@ extension MLXArray { /// - /// - ``eye(_:m:k:type:stream:)`` static public func identity( - _ n: Int, type: (some HasDType).Type = Float.self, stream: StreamOrDevice = .default + _ n: Int, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { MLX.identity(n, type: type, stream: stream) } + /// Create a square identity matrix, defaulting to `Float` (float32). + /// + /// > Deprecated: pass the dtype explicitly via ``identity(_:dtype:stream:)`` + /// or ``identity(_:type:stream:)``. See + /// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390). + @available( + *, deprecated, + message: "Pass dtype explicitly: identity(n, dtype: .float32) or identity(n, type: Float.self). See ml-explore/mlx-swift#390." + ) + static public func identity(_ n: Int, stream: StreamOrDevice = .default) -> MLXArray { + MLX.identity(n, dtype: .float32, stream: stream) + } + /// Create a square identity matrix with a given ``DType``. /// /// Example: @@ -599,7 +659,7 @@ extension MLXArray { /// /// ```swift /// // [5, 5] array with the lower triangle filled with 1s - /// let r = MLXArray.triangle(5) + /// let r = MLXArray.tri(5, type: Int.self) /// ``` /// /// - Parameters: @@ -612,12 +672,27 @@ extension MLXArray { /// ### See Also /// - static public func tri( - _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type = Float.self, + _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { MLX.tri(n, m: m, k: k, type: type, stream: stream) } + /// An array with ones at and below the given diagonal and zeros elsewhere, defaulting to `Float` (float32). + /// + /// > Deprecated: pass the dtype explicitly via ``tri(_:m:k:dtype:stream:)`` + /// or ``tri(_:m:k:type:stream:)``. See + /// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390). + @available( + *, deprecated, + message: "Pass dtype explicitly: tri(n, m:, k:, dtype: .float32) or tri(n, m:, k:, type: Float.self). See ml-explore/mlx-swift#390." + ) + static public func tri( + _ n: Int, m: Int? = nil, k: Int = 0, stream: StreamOrDevice = .default + ) -> MLXArray { + MLX.tri(n, m: m, k: k, dtype: .float32, stream: stream) + } + /// An array with ones at and below the given diagonal and zeros elsewhere and a given ``DType``. /// /// Example: @@ -649,7 +724,7 @@ extension MLXArray { /// Example: /// /// ```swift -/// let z = MLXArray.zeros([5, 10], type: Int.self) +/// let z = zeros([5, 10], type: Int.self) /// ``` /// /// - Parameters: @@ -662,7 +737,7 @@ extension MLXArray { /// - ``zeros(like:stream:)`` /// - ``ones(_:type:stream:)`` public func zeros( - _ shape: some Collection, type: (some HasDType).Type = Float.self, + _ shape: some Collection, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() @@ -670,6 +745,21 @@ public func zeros( return MLXArray(result) } +/// Construct an array of zeros, defaulting to `Float` (float32). +/// +/// > Deprecated: pass the dtype explicitly via ``zeros(_:dtype:stream:)`` +/// or ``zeros(_:type:stream:)``. See +/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390). +@available( + *, deprecated, + message: "Pass dtype explicitly: zeros(shape, dtype: .float32) or zeros(shape, type: Float.self). See ml-explore/mlx-swift#390." +) +public func zeros( + _ shape: some Collection, stream: StreamOrDevice = .default +) -> MLXArray { + zeros(shape, dtype: .float32, stream: stream) +} + /// Construct an array of zeros with a given ``DType`` /// /// Example: @@ -723,7 +813,7 @@ public func zeros(like array: MLXArray, stream: StreamOrDevice = .default) -> ML /// Example: /// /// ```swift -/// let r = MLXArray.ones([5, 10], type: Int.self) +/// let r = ones([5, 10], type: Int.self) /// ``` /// /// - Parameters: @@ -736,7 +826,7 @@ public func zeros(like array: MLXArray, stream: StreamOrDevice = .default) -> ML /// - ``ones(like:stream:)`` /// - ``zeros(_:type:stream:)`` public func ones( - _ shape: some Collection, type: (some HasDType).Type = Float.self, + _ shape: some Collection, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() @@ -744,6 +834,21 @@ public func ones( return MLXArray(result) } +/// Construct an array of ones, defaulting to `Float` (float32). +/// +/// > Deprecated: pass the dtype explicitly via ``ones(_:dtype:stream:)`` +/// or ``ones(_:type:stream:)``. See +/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390). +@available( + *, deprecated, + message: "Pass dtype explicitly: ones(shape, dtype: .float32) or ones(shape, type: Float.self). See ml-explore/mlx-swift#390." +) +public func ones( + _ shape: some Collection, stream: StreamOrDevice = .default +) -> MLXArray { + ones(shape, dtype: .float32, stream: stream) +} + /// Construct an array of ones with a given ``DType`` /// /// Example: @@ -798,7 +903,7 @@ public func ones(like array: MLXArray, stream: StreamOrDevice = .default) -> MLX /// /// ```swift /// // create [10, 10] array with 1's on the diagonal. -/// let r = MLXArray.eye(10) +/// let r = eye(10, type: Int.self) /// ``` /// /// - Parameters: @@ -812,7 +917,7 @@ public func ones(like array: MLXArray, stream: StreamOrDevice = .default) -> MLX /// - /// - ``identity(_:type:stream:)`` public func eye( - _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type = Float.self, + _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() @@ -820,6 +925,21 @@ public func eye( return MLXArray(result) } +/// Create an identity matrix or a general diagonal matrix, defaulting to `Float` (float32). +/// +/// > Deprecated: pass the dtype explicitly via ``eye(_:m:k:dtype:stream:)`` +/// or ``eye(_:m:k:type:stream:)``. See +/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390). +@available( + *, deprecated, + message: "Pass dtype explicitly: eye(n, m:, k:, dtype: .float32) or eye(n, m:, k:, type: Float.self). See ml-explore/mlx-swift#390." +) +public func eye( + _ n: Int, m: Int? = nil, k: Int = 0, stream: StreamOrDevice = .default +) -> MLXArray { + eye(n, m: m, k: k, dtype: .float32, stream: stream) +} + /// Create an identity matrix or a general diagonal matrix given a ``DType``. /// /// Example: @@ -945,7 +1065,7 @@ public func full( /// /// ```swift /// // create [10, 10] array with 1's on the diagonal. -/// let r = MLXArray.identity(10) +/// let r = identity(10, type: Int.self) /// ``` /// /// - Parameters: @@ -957,13 +1077,26 @@ public func full( /// - /// - ``eye(_:m:k:type:stream:)`` public func identity( - _ n: Int, type: (some HasDType).Type = Float.self, stream: StreamOrDevice = .default + _ n: Int, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_identity(&result, n.int32, type.dtype.cmlxDtype, stream.ctx) return MLXArray(result) } +/// Create a square identity matrix, defaulting to `Float` (float32). +/// +/// > Deprecated: pass the dtype explicitly via ``identity(_:dtype:stream:)`` +/// or ``identity(_:type:stream:)``. See +/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390). +@available( + *, deprecated, + message: "Pass dtype explicitly: identity(n, dtype: .float32) or identity(n, type: Float.self). See ml-explore/mlx-swift#390." +) +public func identity(_ n: Int, stream: StreamOrDevice = .default) -> MLXArray { + identity(n, dtype: .float32, stream: stream) +} + /// Create a square identity matrix with a given ``DType``. /// /// Example: @@ -1280,7 +1413,7 @@ public func repeated(_ array: MLXArray, count: Int, stream: StreamOrDevice = .de /// /// ```swift /// // [5, 5] array with the lower triangle filled with 1s -/// let r = MLXArray.triangle(5) +/// let r = tri(5, type: Int.self) /// ``` /// /// - Parameters: @@ -1293,7 +1426,7 @@ public func repeated(_ array: MLXArray, count: Int, stream: StreamOrDevice = .de /// ### See Also /// - public func tri( - _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type = Float.self, + _ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() @@ -1301,6 +1434,21 @@ public func tri( return MLXArray(result) } +/// An array with ones at and below the given diagonal and zeros elsewhere, defaulting to `Float` (float32). +/// +/// > Deprecated: pass the dtype explicitly via ``tri(_:m:k:dtype:stream:)`` +/// or ``tri(_:m:k:type:stream:)``. See +/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390). +@available( + *, deprecated, + message: "Pass dtype explicitly: tri(n, m:, k:, dtype: .float32) or tri(n, m:, k:, type: Float.self). See ml-explore/mlx-swift#390." +) +public func tri( + _ n: Int, m: Int? = nil, k: Int = 0, stream: StreamOrDevice = .default +) -> MLXArray { + tri(n, m: m, k: k, dtype: .float32, stream: stream) +} + /// An array with ones at and below the given diagonal and zeros elsewhere and a given ``DType``. /// /// Example: diff --git a/Source/MLXNN/Convolution.swift b/Source/MLXNN/Convolution.swift index 824e98e8..c70ea4e2 100644 --- a/Source/MLXNN/Convolution.swift +++ b/Source/MLXNN/Convolution.swift @@ -56,7 +56,7 @@ open class Conv1d: Module, UnaryLayer { kernelSize, inputChannels / groups, ]) - self.bias = bias ? MLXArray.zeros([outputChannels]) : nil + self.bias = bias ? MLXArray.zeros([outputChannels], dtype: .float32) : nil self.padding = padding self.dilation = dilation self.stride = stride @@ -127,7 +127,7 @@ open class Conv2d: Module, UnaryLayer { kernelSize.first, kernelSize.second, inputChannels / groups, ]) - self.bias = bias ? MLXArray.zeros([outputChannels]) : nil + self.bias = bias ? MLXArray.zeros([outputChannels], dtype: .float32) : nil self.padding = padding.values self.dilation = dilation.values self.stride = stride.values @@ -199,7 +199,7 @@ open class Conv3d: Module, UnaryLayer { kernelSize.first, kernelSize.second, kernelSize.third, inputChannels / groups, ]) - self.bias = bias ? MLXArray.zeros([outputChannels]) : nil + self.bias = bias ? MLXArray.zeros([outputChannels], dtype: .float32) : nil self.padding = padding.values self.dilation = dilation.values self.stride = stride.values diff --git a/Source/MLXNN/ConvolutionTransposed.swift b/Source/MLXNN/ConvolutionTransposed.swift index dbba317f..10ebd22f 100644 --- a/Source/MLXNN/ConvolutionTransposed.swift +++ b/Source/MLXNN/ConvolutionTransposed.swift @@ -57,7 +57,7 @@ open class ConvTransposed1d: Module, UnaryLayer { kernelSize, inputChannels / groups, ]) - self.bias = bias ? MLXArray.zeros([outputChannels]) : nil + self.bias = bias ? MLXArray.zeros([outputChannels], dtype: .float32) : nil self.padding = padding self.dilation = dilation self.outputPadding = outputPadding @@ -133,7 +133,7 @@ open class ConvTransposed2d: Module, UnaryLayer { kernelSize.first, kernelSize.second, inputChannels / groups, ]) - self.bias = bias ? MLXArray.zeros([outputChannels]) : nil + self.bias = bias ? MLXArray.zeros([outputChannels], dtype: .float32) : nil self.padding = padding.values self.dilation = dilation.values self.outputPadding = outputPadding.values @@ -210,7 +210,7 @@ open class ConvTransposed3d: Module, UnaryLayer { kernelSize.first, kernelSize.second, kernelSize.third, inputChannels / groups, ]) - self.bias = bias ? MLXArray.zeros([outputChannels]) : nil + self.bias = bias ? MLXArray.zeros([outputChannels], dtype: .float32) : nil self.padding = padding.values self.dilation = dilation.values self.outputPadding = outputPadding.values diff --git a/Source/MLXNN/Normalization.swift b/Source/MLXNN/Normalization.swift index 844daedc..a444feb4 100644 --- a/Source/MLXNN/Normalization.swift +++ b/Source/MLXNN/Normalization.swift @@ -34,8 +34,8 @@ open class InstanceNorm: Module, UnaryLayer { self.eps = eps if affine { - self.weight = MLXArray.ones([dimensions]) - self.bias = MLXArray.zeros([dimensions]) + self.weight = MLXArray.ones([dimensions], dtype: .float32) + self.bias = MLXArray.zeros([dimensions], dtype: .float32) } else { self.weight = nil self.bias = nil @@ -95,8 +95,8 @@ open class LayerNorm: Module, UnaryLayer { self.eps = eps if affine { - self.weight = MLXArray.ones([dimensions]) - self.bias = bias ? MLXArray.zeros([dimensions]) : nil + self.weight = MLXArray.ones([dimensions], dtype: .float32) + self.bias = bias ? MLXArray.zeros([dimensions], dtype: .float32) : nil } else { self.weight = nil self.bias = nil @@ -130,7 +130,7 @@ open class RMSNorm: Module, UnaryLayer { public let eps: Float public init(dimensions: Int, eps: Float = 1e-5) { - self.weight = MLXArray.ones([dimensions]) + self.weight = MLXArray.ones([dimensions], dtype: .float32) self.eps = eps super.init() } @@ -189,8 +189,8 @@ open class GroupNorm: Module, UnaryLayer { self.pytorchCompatible = pytorchCompatible if affine { - self.weight = MLXArray.ones([dimensions]) - self.bias = MLXArray.zeros([dimensions]) + self.weight = MLXArray.ones([dimensions], dtype: .float32) + self.bias = MLXArray.zeros([dimensions], dtype: .float32) } else { self.weight = nil self.bias = nil @@ -294,16 +294,16 @@ open class BatchNorm: Module, UnaryLayer { self.momentum = momentum if affine { - self.weight = MLXArray.ones([featureCount]) - self.bias = MLXArray.zeros([featureCount]) + self.weight = MLXArray.ones([featureCount], dtype: .float32) + self.bias = MLXArray.zeros([featureCount], dtype: .float32) } else { self.weight = nil self.bias = nil } if trackRunningStats { - self._runningMean.wrappedValue = MLXArray.zeros([featureCount]) - self._runningVar.wrappedValue = MLXArray.ones([featureCount]) + self._runningMean.wrappedValue = MLXArray.zeros([featureCount], dtype: .float32) + self._runningVar.wrappedValue = MLXArray.ones([featureCount], dtype: .float32) } super.init() diff --git a/Source/MLXNN/Quantized.swift b/Source/MLXNN/Quantized.swift index 076e91ca..9caf482c 100644 --- a/Source/MLXNN/Quantized.swift +++ b/Source/MLXNN/Quantized.swift @@ -269,7 +269,7 @@ open class QuantizedLinear: Linear, Quantized { let weight = MLXRandom.uniform( low: -scale, high: scale, [outputDimensions, inputDimensions]) - let bias = bias ? MLXArray.zeros([outputDimensions]) : nil + let bias = bias ? MLXArray.zeros([outputDimensions], dtype: .float32) : nil self.init(weight: weight, bias: bias, groupSize: groupSize, bits: bits, mode: mode) }