Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Source/Examples/Tutorial.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ struct Tutorial {
print(x[1])

// make an array of shape [2, 2] filled with ones
let y = MLXArray.ones([2, 2])
let y = MLXArray.ones([2, 2], dtype: .float32)

// pointwise add x and y
let z = x + y
Expand Down
186 changes: 167 additions & 19 deletions Source/MLX/Factory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,29 @@ extension MLXArray {
/// - ``zeros(like:stream:)``
/// - ``ones(_:type:stream:)``
static public func zeros(
_ shape: some Collection<Int>, type: (some HasDType).Type = Float.self,
_ shape: some Collection<Int>, type: (some HasDType).Type,
stream: StreamOrDevice = .default
) -> MLXArray {
MLX.zeros(shape, type: type, stream: stream)
}

/// Construct an array of zeros, defaulting to `Float` (float32).
///
/// > Deprecated: pass the dtype explicitly via ``zeros(_:dtype:stream:)``
/// or ``zeros(_:type:stream:)``. Allowing the dtype to default silently
/// promotes computations to float32, which can mask precision issues in
/// half-precision (bfloat16 / float16) pipelines. See
/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390).
@available(
*, deprecated,
message: "Pass dtype explicitly: zeros(shape, dtype: .float32) or zeros(shape, type: Float.self). See ml-explore/mlx-swift#390."
)
static public func zeros(
_ shape: some Collection<Int>, stream: StreamOrDevice = .default
) -> MLXArray {
MLX.zeros(shape, dtype: .float32, stream: stream)
}

/// Construct an array of zeros with a given ``DType``
///
/// Example:
Expand Down Expand Up @@ -91,12 +108,27 @@ extension MLXArray {
/// - ``ones(like:stream:)``
/// - ``zeros(_:type:stream:)``
static public func ones(
_ shape: some Collection<Int>, type: (some HasDType).Type = Float.self,
_ shape: some Collection<Int>, type: (some HasDType).Type,
stream: StreamOrDevice = .default
) -> MLXArray {
MLX.ones(shape, type: type, stream: stream)
}

/// Construct an array of ones, defaulting to `Float` (float32).
///
/// > Deprecated: pass the dtype explicitly via ``ones(_:dtype:stream:)``
/// or ``ones(_:type:stream:)``. See
/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390).
@available(
*, deprecated,
message: "Pass dtype explicitly: ones(shape, dtype: .float32) or ones(shape, type: Float.self). See ml-explore/mlx-swift#390."
)
static public func ones(
_ shape: some Collection<Int>, stream: StreamOrDevice = .default
) -> MLXArray {
MLX.ones(shape, dtype: .float32, stream: stream)
}

/// Construct an array of ones with a given ``DType``
///
/// Example:
Expand Down Expand Up @@ -147,7 +179,7 @@ extension MLXArray {
///
/// ```swift
/// // create [10, 10] array with 1's on the diagonal.
/// let r = MLXArray.eye(10)
/// let r = MLXArray.eye(10, type: Int.self)
/// ```
///
/// - Parameters:
Expand All @@ -161,12 +193,27 @@ extension MLXArray {
/// - <doc:initialization>
/// - ``identity(_:type:stream:)``
static public func eye(
_ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type = Float.self,
_ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type,
stream: StreamOrDevice = .default
) -> MLXArray {
MLX.eye(n, m: m, k: k, type: type, stream: stream)
}

/// Create an identity matrix or a general diagonal matrix, defaulting to `Float` (float32).
///
/// > Deprecated: pass the dtype explicitly via ``eye(_:m:k:dtype:stream:)``
/// or ``eye(_:m:k:type:stream:)``. See
/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390).
@available(
*, deprecated,
message: "Pass dtype explicitly: eye(n, m:, k:, dtype: .float32) or eye(n, m:, k:, type: Float.self). See ml-explore/mlx-swift#390."
)
static public func eye(
_ n: Int, m: Int? = nil, k: Int = 0, stream: StreamOrDevice = .default
) -> MLXArray {
MLX.eye(n, m: m, k: k, dtype: .float32, stream: stream)
}

/// Create an identity matrix or a general diagonal matrix given a ``DType``.
///
/// Example:
Expand Down Expand Up @@ -215,7 +262,7 @@ extension MLXArray {
/// - ``full(_:values:stream:)``
/// - ``repeated(_:count:axis:stream:)``
static public func full(
_ shape: some Collection<Int>, values: MLXArray, type: (some HasDType).Type = Float.self,
_ shape: some Collection<Int>, values: MLXArray, type: (some HasDType).Type,
stream: StreamOrDevice = .default
) -> MLXArray {
MLX.full(shape, values: values, type: type, stream: stream)
Expand Down Expand Up @@ -285,7 +332,7 @@ extension MLXArray {
///
/// ```swift
/// // create [10, 10] array with 1's on the diagonal.
/// let r = MLXArray.identity(10)
/// let r = MLXArray.identity(10, type: Int.self)
/// ```
///
/// - Parameters:
Expand All @@ -297,11 +344,24 @@ extension MLXArray {
/// - <doc:initialization>
/// - ``eye(_:m:k:type:stream:)``
static public func identity(
_ n: Int, type: (some HasDType).Type = Float.self, stream: StreamOrDevice = .default
_ n: Int, type: (some HasDType).Type, stream: StreamOrDevice = .default
) -> MLXArray {
MLX.identity(n, type: type, stream: stream)
}

/// Create a square identity matrix, defaulting to `Float` (float32).
///
/// > Deprecated: pass the dtype explicitly via ``identity(_:dtype:stream:)``
/// or ``identity(_:type:stream:)``. See
/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390).
@available(
*, deprecated,
message: "Pass dtype explicitly: identity(n, dtype: .float32) or identity(n, type: Float.self). See ml-explore/mlx-swift#390."
)
static public func identity(_ n: Int, stream: StreamOrDevice = .default) -> MLXArray {
MLX.identity(n, dtype: .float32, stream: stream)
}

/// Create a square identity matrix with a given ``DType``.
///
/// Example:
Expand Down Expand Up @@ -599,7 +659,7 @@ extension MLXArray {
///
/// ```swift
/// // [5, 5] array with the lower triangle filled with 1s
/// let r = MLXArray.triangle(5)
/// let r = MLXArray.tri(5, type: Int.self)
/// ```
///
/// - Parameters:
Expand All @@ -612,12 +672,27 @@ extension MLXArray {
/// ### See Also
/// - <doc:initialization>
static public func tri(
_ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type = Float.self,
_ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type,
stream: StreamOrDevice = .default
) -> MLXArray {
MLX.tri(n, m: m, k: k, type: type, stream: stream)
}

/// An array with ones at and below the given diagonal and zeros elsewhere, defaulting to `Float` (float32).
///
/// > Deprecated: pass the dtype explicitly via ``tri(_:m:k:dtype:stream:)``
/// or ``tri(_:m:k:type:stream:)``. See
/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390).
@available(
*, deprecated,
message: "Pass dtype explicitly: tri(n, m:, k:, dtype: .float32) or tri(n, m:, k:, type: Float.self). See ml-explore/mlx-swift#390."
)
static public func tri(
_ n: Int, m: Int? = nil, k: Int = 0, stream: StreamOrDevice = .default
) -> MLXArray {
MLX.tri(n, m: m, k: k, dtype: .float32, stream: stream)
}

/// An array with ones at and below the given diagonal and zeros elsewhere and a given ``DType``.
///
/// Example:
Expand Down Expand Up @@ -649,7 +724,7 @@ extension MLXArray {
/// Example:
///
/// ```swift
/// let z = MLXArray.zeros([5, 10], type: Int.self)
/// let z = zeros([5, 10], type: Int.self)
/// ```
///
/// - Parameters:
Expand All @@ -662,14 +737,29 @@ extension MLXArray {
/// - ``zeros(like:stream:)``
/// - ``ones(_:type:stream:)``
public func zeros(
_ shape: some Collection<Int>, type: (some HasDType).Type = Float.self,
_ shape: some Collection<Int>, type: (some HasDType).Type,
stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
mlx_zeros(&result, shape.map { Int32($0) }, shape.count, type.dtype.cmlxDtype, stream.ctx)
return MLXArray(result)
}

/// Construct an array of zeros, defaulting to `Float` (float32).
///
/// > Deprecated: pass the dtype explicitly via ``zeros(_:dtype:stream:)``
/// or ``zeros(_:type:stream:)``. See
/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390).
@available(
*, deprecated,
message: "Pass dtype explicitly: zeros(shape, dtype: .float32) or zeros(shape, type: Float.self). See ml-explore/mlx-swift#390."
)
public func zeros(
_ shape: some Collection<Int>, stream: StreamOrDevice = .default
) -> MLXArray {
zeros(shape, dtype: .float32, stream: stream)
}

/// Construct an array of zeros with a given ``DType``
///
/// Example:
Expand Down Expand Up @@ -723,7 +813,7 @@ public func zeros(like array: MLXArray, stream: StreamOrDevice = .default) -> ML
/// Example:
///
/// ```swift
/// let r = MLXArray.ones([5, 10], type: Int.self)
/// let r = ones([5, 10], type: Int.self)
/// ```
///
/// - Parameters:
Expand All @@ -736,14 +826,29 @@ public func zeros(like array: MLXArray, stream: StreamOrDevice = .default) -> ML
/// - ``ones(like:stream:)``
/// - ``zeros(_:type:stream:)``
public func ones(
_ shape: some Collection<Int>, type: (some HasDType).Type = Float.self,
_ shape: some Collection<Int>, type: (some HasDType).Type,
stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
mlx_ones(&result, shape.map { Int32($0) }, shape.count, type.dtype.cmlxDtype, stream.ctx)
return MLXArray(result)
}

/// Construct an array of ones, defaulting to `Float` (float32).
///
/// > Deprecated: pass the dtype explicitly via ``ones(_:dtype:stream:)``
/// or ``ones(_:type:stream:)``. See
/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390).
@available(
*, deprecated,
message: "Pass dtype explicitly: ones(shape, dtype: .float32) or ones(shape, type: Float.self). See ml-explore/mlx-swift#390."
)
public func ones(
_ shape: some Collection<Int>, stream: StreamOrDevice = .default
) -> MLXArray {
ones(shape, dtype: .float32, stream: stream)
}

/// Construct an array of ones with a given ``DType``
///
/// Example:
Expand Down Expand Up @@ -798,7 +903,7 @@ public func ones(like array: MLXArray, stream: StreamOrDevice = .default) -> MLX
///
/// ```swift
/// // create [10, 10] array with 1's on the diagonal.
/// let r = MLXArray.eye(10)
/// let r = eye(10, type: Int.self)
/// ```
///
/// - Parameters:
Expand All @@ -812,14 +917,29 @@ public func ones(like array: MLXArray, stream: StreamOrDevice = .default) -> MLX
/// - <doc:initialization>
/// - ``identity(_:type:stream:)``
public func eye(
_ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type = Float.self,
_ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type,
stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
mlx_eye(&result, n.int32, (m ?? n).int32, k.int32, type.dtype.cmlxDtype, stream.ctx)
return MLXArray(result)
}

/// Create an identity matrix or a general diagonal matrix, defaulting to `Float` (float32).
///
/// > Deprecated: pass the dtype explicitly via ``eye(_:m:k:dtype:stream:)``
/// or ``eye(_:m:k:type:stream:)``. See
/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390).
@available(
*, deprecated,
message: "Pass dtype explicitly: eye(n, m:, k:, dtype: .float32) or eye(n, m:, k:, type: Float.self). See ml-explore/mlx-swift#390."
)
public func eye(
_ n: Int, m: Int? = nil, k: Int = 0, stream: StreamOrDevice = .default
) -> MLXArray {
eye(n, m: m, k: k, dtype: .float32, stream: stream)
}

/// Create an identity matrix or a general diagonal matrix given a ``DType``.
///
/// Example:
Expand Down Expand Up @@ -945,7 +1065,7 @@ public func full(
///
/// ```swift
/// // create [10, 10] array with 1's on the diagonal.
/// let r = MLXArray.identity(10)
/// let r = identity(10, type: Int.self)
/// ```
///
/// - Parameters:
Expand All @@ -957,13 +1077,26 @@ public func full(
/// - <doc:initialization>
/// - ``eye(_:m:k:type:stream:)``
public func identity(
_ n: Int, type: (some HasDType).Type = Float.self, stream: StreamOrDevice = .default
_ n: Int, type: (some HasDType).Type, stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
mlx_identity(&result, n.int32, type.dtype.cmlxDtype, stream.ctx)
return MLXArray(result)
}

/// Create a square identity matrix, defaulting to `Float` (float32).
///
/// > Deprecated: pass the dtype explicitly via ``identity(_:dtype:stream:)``
/// or ``identity(_:type:stream:)``. See
/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390).
@available(
*, deprecated,
message: "Pass dtype explicitly: identity(n, dtype: .float32) or identity(n, type: Float.self). See ml-explore/mlx-swift#390."
)
public func identity(_ n: Int, stream: StreamOrDevice = .default) -> MLXArray {
identity(n, dtype: .float32, stream: stream)
}

/// Create a square identity matrix with a given ``DType``.
///
/// Example:
Expand Down Expand Up @@ -1280,7 +1413,7 @@ public func repeated(_ array: MLXArray, count: Int, stream: StreamOrDevice = .de
///
/// ```swift
/// // [5, 5] array with the lower triangle filled with 1s
/// let r = MLXArray.triangle(5)
/// let r = tri(5, type: Int.self)
/// ```
///
/// - Parameters:
Expand All @@ -1293,14 +1426,29 @@ public func repeated(_ array: MLXArray, count: Int, stream: StreamOrDevice = .de
/// ### See Also
/// - <doc:initialization>
public func tri(
_ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type = Float.self,
_ n: Int, m: Int? = nil, k: Int = 0, type: (some HasDType).Type,
stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
mlx_tri(&result, n.int32, (m ?? n).int32, k.int32, type.dtype.cmlxDtype, stream.ctx)
return MLXArray(result)
}

/// An array with ones at and below the given diagonal and zeros elsewhere, defaulting to `Float` (float32).
///
/// > Deprecated: pass the dtype explicitly via ``tri(_:m:k:dtype:stream:)``
/// or ``tri(_:m:k:type:stream:)``. See
/// [ml-explore/mlx-swift#390](https://github.com/ml-explore/mlx-swift/issues/390).
@available(
*, deprecated,
message: "Pass dtype explicitly: tri(n, m:, k:, dtype: .float32) or tri(n, m:, k:, type: Float.self). See ml-explore/mlx-swift#390."
)
public func tri(
_ n: Int, m: Int? = nil, k: Int = 0, stream: StreamOrDevice = .default
) -> MLXArray {
tri(n, m: m, k: k, dtype: .float32, stream: stream)
}

/// An array with ones at and below the given diagonal and zeros elsewhere and a given ``DType``.
///
/// Example:
Expand Down
Loading