diff --git a/Source/MLX/Distributed.swift b/Source/MLX/Distributed.swift new file mode 100644 index 00000000..1423c67e --- /dev/null +++ b/Source/MLX/Distributed.swift @@ -0,0 +1,410 @@ +// Copyright © 2024 Apple Inc. + +import Cmlx +import Foundation + +/// Error type for synchronous distributed API failures. +/// +/// Distributed collectives and layers are often lazy. These errors only +/// describe failures that can be detected at call time; execution-time backend +/// failures may still surface later when the returned value is evaluated. +public enum DistributedError: LocalizedError, Sendable, Equatable { + case initializationFailed(backend: DistributedBackend) + case initializationError(backend: DistributedBackend, message: String) + case runtime(String) + case invalidConfiguration(String) + case unsupportedModuleType(String) + + public var errorDescription: String? { + switch self { + case .initializationFailed(let backend): + "Failed to initialize a distributed group for backend '\(backend.rawValue)'." + case .initializationError(let backend, let message): + "Failed to initialize distributed backend '\(backend.rawValue)': \(message)" + case .runtime(let message): + "Distributed runtime error: \(message)" + case .invalidConfiguration(let message): + "Invalid distributed configuration: \(message)" + case .unsupportedModuleType(let typeName): + "Unsupported distributed module type: \(typeName)" + } + } +} + +private func withDistributedRuntimeError(_ body: () throws -> R) throws -> R { + do { + return try withError(body) + } catch let MLXError.caught(message) { + throw DistributedError.runtime(message) + } +} + +private func withDistributedInitializationError( + backend: DistributedBackend, _ body: () throws -> R +) throws -> R { + do { + return try withError(body) + } catch let MLXError.caught(message) { + if backend == .any, message.contains("Couldn't initialize any backend") { + throw DistributedError.initializationFailed(backend: backend) + } + throw DistributedError.initializationError(backend: backend, message: message) + } +} + +private func requireDistributedGroup( + _ group: mlx_distributed_group, operation: String +) throws -> DistributedGroup { + guard group.ctx != nil else { + throw DistributedError.runtime("\(operation) returned an empty distributed group.") + } + return DistributedGroup(group) +} + +private func requireDistributedArray(_ array: mlx_array, operation: String) throws -> MLXArray { + guard array.ctx != nil else { + throw DistributedError.runtime("\(operation) returned an empty MLXArray.") + } + return MLXArray(array) +} + +/// The distributed communication backend to use. +/// +/// When ``DistributedBackend/any`` is specified, MLX chooses the best available +/// backend automatically. Use a specific case to force a particular backend. +public enum DistributedBackend: String, CaseIterable, Sendable { + /// Let MLX choose the best available backend automatically. + case any + /// TCP socket-based ring backend. + case ring + /// Joint Accelerator Communication Library (Thunderbolt 5 RDMA). + case jaccl + /// Message Passing Interface backend. + case mpi + /// NVIDIA Collective Communications Library backend. + case nccl + + /// Whether this backend can be initialized on the current runtime. + public var isAvailable: Bool { + rawValue.withCString { mlx_distributed_is_available($0) } + } +} + +/// Wrapper around the MLX C distributed group handle. +/// +/// A `DistributedGroup` represents a group of independent MLX processes that +/// can communicate using distributed operations. Create the initial group with +/// ``init()``, ``init(backend:)``, or ``init(strict:)``, then use +/// ``split(color:key:)`` to create sub-groups. +/// +/// `DistributedGroup()` preserves MLX's size-1 fallback behavior: if no real +/// distributed backend can be formed, MLX returns a singleton group (rank 0, +/// size 1). On that singleton group, collective operations such as `allSum`, +/// `allGather`, `allMax`, `allMin`, and `sumScatter` behave as no-ops. +/// +/// `DistributedGroup` is an opaque runtime handle and is intentionally not +/// `Sendable`. +public final class DistributedGroup { + + let ctx: mlx_distributed_group + + init(_ ctx: mlx_distributed_group) { + self.ctx = ctx + } + + private static func initialize(strict: Bool, backend: DistributedBackend) + -> mlx_distributed_group + { + backend.rawValue.withCString { mlx_distributed_init(strict, $0) } + } + + /// Initialize the distributed backend and return the group containing all + /// discoverable processes. + /// + /// When the backend cannot form a real distributed group, this initializer + /// preserves MLX's fallback behavior and returns a singleton group (rank 0, + /// size 1). This is equivalent to calling ``init(backend:)`` with + /// ``DistributedBackend/any``. + /// + public convenience init() { + self.init(backend: .any) + } + + /// Initialize the distributed backend and return the group containing all + /// discoverable processes. + /// + /// Unlike ``init(strict:)``, this initializer preserves MLX's fallback + /// behavior and returns a singleton group (rank 0, size 1) when the chosen + /// backend cannot form a real distributed group. + /// + /// - Parameter backend: the backend to use + public convenience init(backend: DistributedBackend) { + let group = Self.initialize(strict: false, backend: backend) + precondition( + group.ctx != nil, + "MLX unexpectedly failed to create a distributed group for backend '\(backend.rawValue)'." + ) + self.init(group) + } + + /// Initialize the distributed backend and return a real distributed group. + /// + /// Unlike ``init(backend:)``, this initializer does not fall back to a + /// singleton group. It succeeds only when the chosen backend can form a + /// real distributed group at runtime, and throws when strict initialization + /// reports a backend-specific configuration error. + /// + /// - Parameter backend: the backend to use + public convenience init(strict backend: DistributedBackend) throws { + let group = try withDistributedInitializationError(backend: backend) { + Self.initialize(strict: true, backend: backend) + } + guard group.ctx != nil else { + throw DistributedError.initializationFailed(backend: backend) + } + self.init(group) + } + + deinit { + // UPSTREAM GAP: mlx_distributed_group is a value type wrapping a + // heap-allocated C++ Group object (void* ctx). Other MLX-C handle + // types (mlx_device, mlx_stream, mlx_array, etc.) expose a public + // free function (e.g., mlx_device_free), but MLX-C v0.5.0 does NOT + // expose mlx_distributed_group_free(). The private C++ header + // (mlx/c/private/distributed_group.h) has mlx_distributed_group_free_() + // but it is an inline C++ function, inaccessible from Swift/C. + // + // Calling C free() on ctx is NOT safe because the underlying object + // is allocated with C++ new and may have a non-trivial destructor. + // + // Practical impact is minimal: groups are typically singleton-like and + // long-lived (one per distributed init, occasionally split). The C++ + // Group internally holds a shared_ptr to the backend, so the leaked + // memory per group is small. + // + // TODO: File upstream issue to add mlx_distributed_group_free() to + // the public MLX-C API, then call it here like Device.deinit calls + // mlx_device_free(ctx). + } + + /// The rank of this process in the group. + public var rank: Int { + Int(mlx_distributed_group_rank(ctx)) + } + + /// The number of processes in the group. + public var size: Int { + Int(mlx_distributed_group_size(ctx)) + } + + /// Split this group into sub-groups based on the provided color. + /// + /// Processes that use the same color will be placed in the same sub-group. + /// The key defines the rank of the process in the new group; the smaller + /// the key, the smaller the rank. If the key is negative, the rank in the + /// current group is used. + /// + /// This method throws only for failures that are detectable when the split + /// is requested. It does not force later communication on the returned + /// group to evaluate. + /// + /// - Parameters: + /// - color: processes with the same color go to the same sub-group + /// - key: determines rank ordering in the new group (negative = use current rank) + /// - Returns: a new ``DistributedGroup`` for the sub-group + public func split(color: Int, key: Int = -1) throws -> DistributedGroup { + let result = try withDistributedRuntimeError { + mlx_distributed_group_split(ctx, Int32(color), Int32(key)) + } + return try requireDistributedGroup(result, operation: "split(color:key:)") + } + + /// Sum-reduce the array across all processes in the group. + /// + /// Each process contributes its local array and all processes receive + /// the element-wise sum. + /// + /// On a singleton group, this behaves as identity. + /// This method is lazy and non-throwing: backend failures may still + /// surface only when the returned array is evaluated. Use + /// ``withError(_:)-6g4wn`` or ``checkedEval(_:)`` around the operation plus + /// its evaluation boundary if you need a Swift error. + /// + /// - Parameters: + /// - array: the local array to sum + /// - stream: stream or device to evaluate on + /// - Returns: the element-wise sum across all processes + public func allSum(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { + var result = mlx_array_new() + mlx_distributed_all_sum(&result, array.ctx, ctx, stream.ctx) + return MLXArray(result) + } + + /// Gather arrays from all processes in the group. + /// + /// Each process contributes its local array and all processes receive + /// the concatenated result. + /// + /// On a singleton group, this behaves as identity. + /// This method is lazy and non-throwing: backend failures may still + /// surface only when the returned array is evaluated. Use + /// ``withError(_:)-6g4wn`` or ``checkedEval(_:)`` around the operation plus + /// its evaluation boundary if you need a Swift error. + /// + /// - Parameters: + /// - array: the local array to gather + /// - stream: stream or device to evaluate on + /// - Returns: the concatenation of arrays from all processes + public func allGather(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { + var result = mlx_array_new() + mlx_distributed_all_gather(&result, array.ctx, ctx, stream.ctx) + return MLXArray(result) + } + + /// Max-reduce the array across all processes in the group. + /// + /// Each process contributes its local array and all processes receive + /// the element-wise maximum. + /// + /// On a singleton group, this behaves as identity. + /// This method is lazy and non-throwing: backend failures may still + /// surface only when the returned array is evaluated. Use + /// ``withError(_:)-6g4wn`` or ``checkedEval(_:)`` around the operation plus + /// its evaluation boundary if you need a Swift error. + /// + /// - Parameters: + /// - array: the local array to max-reduce + /// - stream: stream or device to evaluate on + /// - Returns: the element-wise maximum across all processes + public func allMax(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { + var result = mlx_array_new() + mlx_distributed_all_max(&result, array.ctx, ctx, stream.ctx) + return MLXArray(result) + } + + /// Min-reduce the array across all processes in the group. + /// + /// Each process contributes its local array and all processes receive + /// the element-wise minimum. + /// + /// On a singleton group, this behaves as identity. + /// This method is lazy and non-throwing: backend failures may still + /// surface only when the returned array is evaluated. Use + /// ``withError(_:)-6g4wn`` or ``checkedEval(_:)`` around the operation plus + /// its evaluation boundary if you need a Swift error. + /// + /// - Parameters: + /// - array: the local array to min-reduce + /// - stream: stream or device to evaluate on + /// - Returns: the element-wise minimum across all processes + public func allMin(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray { + var result = mlx_array_new() + mlx_distributed_all_min(&result, array.ctx, ctx, stream.ctx) + return MLXArray(result) + } + + /// Sum-reduce and scatter the array across all processes in the group. + /// + /// The array is sum-reduced and the result is scattered (split) across + /// processes so each process receives its portion. + /// + /// On a singleton group, this behaves as identity. + /// This method throws only for immediate validation or setup failures such + /// as an invalid input shape. Backend support and execution failures may + /// still surface later when the returned array is evaluated. Wrap the + /// operation plus its evaluation boundary in ``withError(_:)-6g4wn`` or + /// use ``checkedEval(_:)`` when you need a Swift error. + /// + /// - Parameters: + /// - array: the local array to sum-scatter + /// - stream: stream or device to evaluate on + /// - Returns: this process's portion of the sum-scattered result + public func sumScatter(_ array: MLXArray, stream: StreamOrDevice = .default) throws -> MLXArray + { + var result = mlx_array_new() + _ = try withDistributedRuntimeError { + mlx_distributed_sum_scatter(&result, array.ctx, ctx, stream.ctx) + } + return try requireDistributedArray(result, operation: "sumScatter(_:stream:)") + } + + /// Send an array to another process in the group. + /// + /// Returns a dependency token (an ``MLXArray``) that can be used to + /// sequence operations. + /// + /// Requires a group size of at least 2. + /// This method throws only for immediate validation or setup failures such + /// as an invalid destination rank. Transport and backend failures may + /// still surface later when the returned dependency token is evaluated. + /// Wrap the operation plus its evaluation boundary in + /// ``withError(_:)-6g4wn`` or use ``checkedEval(_:)`` when you need a + /// Swift error. + /// + /// - Parameters: + /// - array: the array to send + /// - dst: the destination rank + /// - stream: stream or device to evaluate on + /// - Returns: a dependency token + public func send(_ array: MLXArray, to dst: Int, stream: StreamOrDevice = .default) throws + -> MLXArray + { + var result = mlx_array_new() + _ = try withDistributedRuntimeError { + mlx_distributed_send(&result, array.ctx, Int32(dst), ctx, stream.ctx) + } + return try requireDistributedArray(result, operation: "send(_:to:stream:)") + } + + /// Receive an array from another process in the group. + /// + /// Requires a group size of at least 2. + /// This method throws only for immediate validation or setup failures such + /// as an invalid source rank. Transport and backend failures may still + /// surface later when the returned array is evaluated. Wrap the operation + /// plus its evaluation boundary in ``withError(_:)-6g4wn`` or use + /// ``checkedEval(_:)`` when you need a Swift error. + /// + /// - Parameters: + /// - shape: the shape of the expected array + /// - dtype: the data type of the expected array + /// - src: the source rank + /// - stream: stream or device to evaluate on + /// - Returns: the received array + public func recv( + shape: [Int], dtype: DType, from src: Int, stream: StreamOrDevice = .default + ) throws -> MLXArray { + var result = mlx_array_new() + let cShape = shape.map { Int32($0) } + _ = try withDistributedRuntimeError { + mlx_distributed_recv( + &result, cShape, cShape.count, dtype.cmlxDtype, Int32(src), ctx, stream.ctx) + } + return try requireDistributedArray(result, operation: "recv(shape:dtype:from:stream:)") + } + + /// Receive an array from another process, using a template array for + /// shape and dtype. + /// + /// Requires a group size of at least 2. + /// This method throws only for immediate validation or setup failures. + /// Transport and backend failures may still surface later when the returned + /// array is evaluated. Wrap the operation plus its evaluation boundary in + /// ``withError(_:)-6g4wn`` or use ``checkedEval(_:)`` when you need a + /// Swift error. + /// + /// - Parameters: + /// - array: template array whose shape and dtype define the expected result + /// - src: the source rank + /// - stream: stream or device to evaluate on + /// - Returns: the received array with the same shape and dtype as the template + public func recvLike( + _ array: MLXArray, from src: Int, stream: StreamOrDevice = .default + ) throws -> MLXArray { + var result = mlx_array_new() + _ = try withDistributedRuntimeError { + mlx_distributed_recv_like(&result, array.ctx, Int32(src), ctx, stream.ctx) + } + return try requireDistributedArray(result, operation: "recvLike(_:from:stream:)") + } +} diff --git a/Source/MLXNN/Distributed.swift b/Source/MLXNN/Distributed.swift new file mode 100644 index 00000000..ad9f593b --- /dev/null +++ b/Source/MLXNN/Distributed.swift @@ -0,0 +1,990 @@ +// Copyright © 2024 Apple Inc. + +import Foundation +import MLX + +// MARK: - sumGradients Helper + +/// Each closure uses `CustomFunction` with an identity forward pass and an +/// `allSum` VJP so that gradients are aggregated across the distributed group +/// during backpropagation. +/// Returns a closure that is the identity in the forward pass but performs +/// `allSum` on the cotangents during the backward pass. +/// +/// This helper is internal. Callers that reuse it on a hot path should retain +/// the returned closure themselves. On a singleton group, the returned closure +/// is just identity. +/// +/// - Parameter group: the distributed group to aggregate gradients over +/// - Returns: a closure `(MLXArray) -> MLXArray` that is identity forward, +/// allSum backward +func sumGradients(group: DistributedGroup) -> (MLXArray) -> MLXArray { + if group.size == 1 { + // Optimization: on a size-1 group, just return identity + return { x in x } + } + + // Build a CustomFunction with identity forward and allSum VJP + let cf = CustomFunction { + Forward { inputs in inputs } + VJP { _, cotangents in + cotangents.map { group.allSum($0) } + } + } + + return { x in + cf([x])[0] + } +} + +private func validateShardedDimension( + _ dimension: Int, across groupSize: Int, description: String +) throws { + guard dimension % groupSize == 0 else { + throw DistributedError.invalidConfiguration(description) + } +} + +private func validatePositiveSegments(_ segments: Int) throws { + guard segments > 0 else { + throw DistributedError.invalidConfiguration( + "segments must be positive and non-zero but got \(segments).") + } +} + +private func normalizeShardAxis(path: String, value: MLXArray, axis: Int) throws -> Int { + let normalizedAxis = axis < 0 ? value.ndim + axis : axis + guard normalizedAxis >= 0, normalizedAxis < value.ndim else { + throw DistributedError.invalidConfiguration( + "Cannot shard parameter '\(path)' with axis \(axis) for shape \(value.shape).") + } + return normalizedAxis +} + +private func applyShardedParameters(_ module: Module, parameters: ModuleParameters) throws { + do { + try module.update(parameters: parameters, verify: .none) + } catch { + throw DistributedError.invalidConfiguration( + "Failed to apply sharded parameters: \(error.localizedDescription)") + } +} + +// MARK: - AllToShardedLinear + +/// Each member of the group applies part of the affine transformation such +/// that the result is sharded across the group. +/// +/// The gradients are automatically aggregated from each member of the group +/// via an internal gradient reducer for the distributed group. +/// +/// ### See Also +/// - ``ShardedToAllLinear`` +open class AllToShardedLinear: Module, UnaryLayer { + + public let weight: MLXArray + public let bias: MLXArray? + private let gradientReducer: (MLXArray) -> MLXArray + + /// The distributed group. Stored as a plain property so it is excluded + /// from `parameters()` and `children()`. + public let group: DistributedGroup + + /// Initialize an ``AllToShardedLinear`` layer. + /// + /// Validates that `outputDimensions` is divisible by the group size and + /// throws instead of trapping when the requested sharding is invalid. + /// + /// - Parameters: + /// - inputDimensions: number of input dimensions + /// - outputDimensions: number of output dimensions (must be divisible by group size) + /// - bias: if `true`, apply a bias + /// - group: the distributed group (defaults to `DistributedGroup()`) + public init( + inputDimensions: Int, outputDimensions: Int, bias: Bool = true, + group: DistributedGroup? = nil + ) throws { + let group = group ?? DistributedGroup() + let N = group.size + + try validateShardedDimension( + outputDimensions, across: N, + description: + "Cannot shard the output of size \(outputDimensions) across \(N) devices." + ) + + self.group = group + self.gradientReducer = sumGradients(group: group) + let scale = sqrt(1.0 / Float(inputDimensions)) + self.weight = MLXRandom.uniform( + low: -scale, high: scale, [outputDimensions / N, inputDimensions]) + if bias { + self.bias = MLXRandom.uniform( + low: -scale, high: scale, [outputDimensions / N]) + } else { + self.bias = nil + } + super.init() + } + + /// Internal initializer for providing weight and bias directly (used by `fromLinear`). + init(weight: MLXArray, bias: MLXArray?, group: DistributedGroup) { + self.weight = weight + self.bias = bias + self.group = group + self.gradientReducer = sumGradients(group: group) + super.init() + } + + open override func describeExtra(_ indent: Int) -> String { + let (outDims, inDims) = weight.shape2 + let N = group.size + return + "(inputDimensions=\(inDims), outputDimensions=\(outDims * N), bias=\(bias != nil))" + } + + /// This forward pass remains lazy and non-throwing. Distributed backend + /// failures may still surface only when the returned array is evaluated. + open func callAsFunction(_ x: MLXArray) -> MLXArray { + // Aggregate the gradients coming from each shard + var x = gradientReducer(x) + + // Compute the affine projection + if let bias { + x = addMM(bias, x, weight.T) + } else { + x = matmul(x, weight.T) + } + return x + } + + /// Create an ``AllToShardedLinear`` from an existing ``Linear`` layer. + /// + /// For a size-1 group, the sharded weights are identical to the original. + /// + /// - Parameters: + /// - linear: the linear layer to convert + /// - segments: number of segments for fused weights (e.g. 3 for QKV). Default is 1. + /// - group: the distributed group + /// - Returns: a new ``AllToShardedLinear`` layer with sharded weights + public class func fromLinear( + _ linear: Linear, segments: Int = 1, group: DistributedGroup? = nil + ) throws -> AllToShardedLinear { + let group = group ?? DistributedGroup() + let (outputDimensions, inputDimensions) = linear.weight.shape2 + + let layer = try AllToShardedLinear( + inputDimensions: inputDimensions, outputDimensions: outputDimensions, + bias: linear.bias != nil, group: group) + + // Shard the parameters from the original linear layer + let shardedParams = try shardParameterTree( + linear.parameters(), predicate: allToShardedPredicate(segments: segments), + group: group) + try applyShardedParameters(layer, parameters: shardedParams) + + return layer + } +} + +// MARK: - ShardedToAllLinear + +/// Each rank applies part of the affine transformation and then aggregates the +/// partial results via ``DistributedGroup/allSum(_:stream:)``. +/// +/// All ranks receive the same result after this layer. +/// +/// ### See Also +/// - ``AllToShardedLinear`` +open class ShardedToAllLinear: Module, UnaryLayer { + + public let weight: MLXArray + public let bias: MLXArray? + + /// The distributed group. Stored as a plain property so it is excluded + /// from `parameters()` and `children()`. + public let group: DistributedGroup + + /// Initialize a ``ShardedToAllLinear`` layer. + /// + /// Validates that `inputDimensions` is divisible by the group size and + /// throws instead of trapping when the requested sharding is invalid. + /// + /// - Parameters: + /// - inputDimensions: number of input dimensions (must be divisible by group size) + /// - outputDimensions: number of output dimensions + /// - bias: if `true`, apply a bias + /// - group: the distributed group (defaults to `DistributedGroup()`) + public init( + inputDimensions: Int, outputDimensions: Int, bias: Bool = true, + group: DistributedGroup? = nil + ) throws { + let group = group ?? DistributedGroup() + let N = group.size + + try validateShardedDimension( + inputDimensions, across: N, + description: + "The input of size \(inputDimensions) cannot be sharded across \(N) devices." + ) + + self.group = group + let scale = sqrt(1.0 / Float(inputDimensions)) + self.weight = MLXRandom.uniform( + low: -scale, high: scale, [outputDimensions, inputDimensions / N]) + if bias { + self.bias = MLXRandom.uniform( + low: -scale, high: scale, [outputDimensions]) + } else { + self.bias = nil + } + super.init() + } + + /// Internal initializer for providing weight and bias directly (used by `fromLinear`). + init(weight: MLXArray, bias: MLXArray?, group: DistributedGroup) { + self.weight = weight + self.bias = bias + self.group = group + super.init() + } + + open override func describeExtra(_ indent: Int) -> String { + let (outDims, inDims) = weight.shape2 + let N = group.size + return + "(inputDimensions=\(inDims * N), outputDimensions=\(outDims), bias=\(bias != nil))" + } + + /// This forward pass remains lazy and non-throwing. Distributed backend + /// failures may still surface only when the returned array is evaluated. + open func callAsFunction(_ x: MLXArray) -> MLXArray { + var x = matmul(x, weight.T) + + x = group.allSum(x) + + if let bias { + x = x + bias + } + return x + } + + /// Create a ``ShardedToAllLinear`` from an existing ``Linear`` layer. + /// + /// For a size-1 group, the sharded weights are identical to the original. + /// + /// - Parameters: + /// - linear: the linear layer to convert + /// - segments: number of segments for fused weights (e.g. 3 for QKV). Default is 1. + /// - group: the distributed group + /// - Returns: a new ``ShardedToAllLinear`` layer with sharded weights + public class func fromLinear( + _ linear: Linear, segments: Int = 1, group: DistributedGroup? = nil + ) throws -> ShardedToAllLinear { + let group = group ?? DistributedGroup() + let (outputDimensions, inputDimensions) = linear.weight.shape2 + + let layer = try ShardedToAllLinear( + inputDimensions: inputDimensions, outputDimensions: outputDimensions, + bias: linear.bias != nil, group: group) + + // Shard the parameters from the original linear layer + let shardedParams = try shardParameterTree( + linear.parameters(), predicate: shardedToAllPredicate(segments: segments), + group: group) + try applyShardedParameters(layer, parameters: shardedParams) + + return layer + } +} + +// MARK: - QuantizedAllToShardedLinear + +/// Each member of the group applies part of the affine transformation with +/// a quantized matrix such that the result is sharded across the group. +/// +/// It is the quantized equivalent of ``AllToShardedLinear``. +/// Similar to ``QuantizedLinear``, its parameters are frozen and will not be +/// included in any gradient computation. +/// +/// ### See Also +/// - ``AllToShardedLinear`` +/// - ``QuantizedShardedToAllLinear`` +open class QuantizedAllToShardedLinear: Module, UnaryLayer, Quantized { + + public let groupSize: Int + public let bits: Int + public let mode: QuantizationMode + + public let weight: MLXArray + public let scales: MLXArray + public let biases: MLXArray? + public let bias: MLXArray? + private let gradientReducer: (MLXArray) -> MLXArray + + /// The distributed group. Stored as a plain property so it is excluded + /// from `parameters()` and `children()`. + public let group: DistributedGroup + + /// Initialize a ``QuantizedAllToShardedLinear`` layer. + /// + /// Validates that `outputDimensions` is divisible by the group size and + /// throws instead of trapping when the requested sharding is invalid. + /// + /// - Parameters: + /// - inputDimensions: number of input dimensions + /// - outputDimensions: number of output dimensions (must be divisible by group size) + /// - bias: if `true`, apply a bias + /// - groupSize: the group size used for quantization. Default is 64. + /// - bits: the bit width used for quantization. Default is 4. + /// - mode: the quantization mode. Default is `.affine`. + /// - group: the distributed group (defaults to `DistributedGroup()`) + public init( + inputDimensions: Int, outputDimensions: Int, bias: Bool = true, + groupSize: Int = 64, bits: Int = 4, mode: QuantizationMode = .affine, + group: DistributedGroup? = nil + ) throws { + let group = group ?? DistributedGroup() + let N = group.size + + try validateShardedDimension( + outputDimensions, across: N, + description: + "Cannot shard the output of size \(outputDimensions) across \(N) devices." + ) + + self.group = group + self.gradientReducer = sumGradients(group: group) + self.groupSize = groupSize + self.bits = bits + self.mode = mode + let scale = sqrt(1.0 / Float(inputDimensions)) + let w = MLXRandom.uniform( + low: -scale, high: scale, [outputDimensions / N, inputDimensions]) + let (quantizedWeight, scales, biases) = MLX.quantized( + w, groupSize: groupSize, bits: bits, mode: mode) + self.weight = quantizedWeight + self.scales = scales + self.biases = biases + + if bias { + self.bias = MLXArray.zeros([outputDimensions / N]) + } else { + self.bias = nil + } + super.init() + + self.freeze() + } + + /// Internal initializer for providing arrays directly (used by `fromQuantizedLinear`). + init( + weight: MLXArray, bias: MLXArray?, scales: MLXArray, biases: MLXArray?, + groupSize: Int, bits: Int, mode: QuantizationMode, + group: DistributedGroup + ) { + self.weight = weight + self.bias = bias + self.scales = scales + self.biases = biases + self.groupSize = groupSize + self.bits = bits + self.mode = mode + self.group = group + self.gradientReducer = sumGradients(group: group) + super.init() + + self.freeze() + } + + public override func unfreeze( + recursive: Bool = true, keys: [String]? = nil, strict: Bool = false + ) throws { + try super.unfreeze(recursive: recursive, keys: keys, strict: strict) + self.freeze(recursive: false) + } + + open override func describeExtra(_ indent: Int) -> String { + let (outDims, inDims) = weight.shape2 + let inDimsReal = (inDims * 32) / bits + let outDimsReal = outDims * group.size + return + "(inputDimensions=\(inDimsReal), outputDimensions=\(outDimsReal), bias=\(bias != nil), groupSize=\(groupSize), bits=\(bits))" + } + + /// This forward pass remains lazy and non-throwing. Distributed backend + /// failures may still surface only when the returned array is evaluated. + open func callAsFunction(_ x: MLXArray) -> MLXArray { + // Aggregate the gradients coming from each shard + var x = gradientReducer(x) + + x = quantizedMM( + x, + weight, + scales: scales, + biases: biases, + transpose: true, + groupSize: groupSize, + bits: bits, + mode: mode + ) + if let bias { + x = x + bias + } + return x + } + + /// Create a ``QuantizedAllToShardedLinear`` from an existing ``QuantizedLinear`` layer. + /// + /// For a size-1 group, the sharded weights are identical to the original. + /// + /// - Parameters: + /// - quantizedLinear: the quantized linear layer to convert + /// - segments: number of segments for fused weights (e.g. 3 for QKV). Default is 1. + /// - group: the distributed group + /// - Returns: a new ``QuantizedAllToShardedLinear`` layer with sharded weights + public class func fromQuantizedLinear( + _ quantizedLinear: QuantizedLinear, segments: Int = 1, + group: DistributedGroup? = nil + ) throws -> QuantizedAllToShardedLinear { + let group = group ?? DistributedGroup() + let (outputDimensions, inputDimensions) = quantizedLinear.weight.shape2 + let inputDimsReal = (inputDimensions * 32) / quantizedLinear.bits + + let layer = try QuantizedAllToShardedLinear( + inputDimensions: inputDimsReal, outputDimensions: outputDimensions, + bias: quantizedLinear.bias != nil, + groupSize: quantizedLinear.groupSize, + bits: quantizedLinear.bits, + mode: quantizedLinear.mode, + group: group) + + // Shard the parameters from the original quantized linear layer + let shardedParams = try shardParameterTree( + quantizedLinear.parameters(), predicate: allToShardedPredicate(segments: segments), + group: group) + try applyShardedParameters(layer, parameters: shardedParams) + + return layer + } +} + +// MARK: - QuantizedShardedToAllLinear + +/// Each rank applies part of the affine transformation using the quantized +/// matrix and then aggregates the partial results. +/// +/// All ranks receive the same result after this layer. +/// +/// It is the quantized equivalent of ``ShardedToAllLinear``. +/// Similar to ``QuantizedLinear``, its parameters are frozen and will not be +/// included in any gradient computation. +/// +/// ### See Also +/// - ``ShardedToAllLinear`` +/// - ``QuantizedAllToShardedLinear`` +open class QuantizedShardedToAllLinear: Module, UnaryLayer, Quantized { + + public let groupSize: Int + public let bits: Int + public let mode: QuantizationMode + + public let weight: MLXArray + public let scales: MLXArray + public let biases: MLXArray? + public let bias: MLXArray? + + /// The distributed group. Stored as a plain property so it is excluded + /// from `parameters()` and `children()`. + public let group: DistributedGroup + + /// Initialize a ``QuantizedShardedToAllLinear`` layer. + /// + /// Validates that `inputDimensions` is divisible by the group size and + /// throws instead of trapping when the requested sharding is invalid. + /// + /// - Parameters: + /// - inputDimensions: number of input dimensions (must be divisible by group size) + /// - outputDimensions: number of output dimensions + /// - bias: if `true`, apply a bias + /// - groupSize: the group size used for quantization. Default is 64. + /// - bits: the bit width used for quantization. Default is 4. + /// - mode: the quantization mode. Default is `.affine`. + /// - group: the distributed group (defaults to `DistributedGroup()`) + public init( + inputDimensions: Int, outputDimensions: Int, bias: Bool = true, + groupSize: Int = 64, bits: Int = 4, mode: QuantizationMode = .affine, + group: DistributedGroup? = nil + ) throws { + let group = group ?? DistributedGroup() + let N = group.size + + try validateShardedDimension( + inputDimensions, across: N, + description: + "The input of size \(inputDimensions) cannot be sharded across \(N) devices." + ) + + self.group = group + self.groupSize = groupSize + self.bits = bits + self.mode = mode + let scale = sqrt(1.0 / Float(inputDimensions)) + let w = MLXRandom.uniform( + low: -scale, high: scale, [outputDimensions, inputDimensions / N]) + let (quantizedWeight, scales, biases) = MLX.quantized( + w, groupSize: groupSize, bits: bits, mode: mode) + self.weight = quantizedWeight + self.scales = scales + self.biases = biases + + if bias { + self.bias = MLXArray.zeros([outputDimensions]) + } else { + self.bias = nil + } + super.init() + + self.freeze() + } + + /// Internal initializer for providing arrays directly (used by `fromQuantizedLinear`). + init( + weight: MLXArray, bias: MLXArray?, scales: MLXArray, biases: MLXArray?, + groupSize: Int, bits: Int, mode: QuantizationMode, + group: DistributedGroup + ) { + self.weight = weight + self.bias = bias + self.scales = scales + self.biases = biases + self.groupSize = groupSize + self.bits = bits + self.mode = mode + self.group = group + super.init() + + self.freeze() + } + + public override func unfreeze( + recursive: Bool = true, keys: [String]? = nil, strict: Bool = false + ) throws { + try super.unfreeze(recursive: recursive, keys: keys, strict: strict) + self.freeze(recursive: false) + } + + open override func describeExtra(_ indent: Int) -> String { + let (outDims, inDims) = weight.shape2 + let inDimsReal = (inDims * 32) / bits * group.size + return + "(inputDimensions=\(inDimsReal), outputDimensions=\(outDims), bias=\(bias != nil), groupSize=\(groupSize), bits=\(bits))" + } + + /// This forward pass remains lazy and non-throwing. Distributed backend + /// failures may still surface only when the returned array is evaluated. + open func callAsFunction(_ x: MLXArray) -> MLXArray { + var x = quantizedMM( + x, + weight, + scales: scales, + biases: biases, + transpose: true, + groupSize: groupSize, + bits: bits, + mode: mode + ) + + x = group.allSum(x) + + if let bias { + x = x + bias + } + return x + } + + /// Create a ``QuantizedShardedToAllLinear`` from an existing ``QuantizedLinear`` layer. + /// + /// For a size-1 group, the sharded weights are identical to the original. + /// + /// - Parameters: + /// - quantizedLinear: the quantized linear layer to convert + /// - segments: number of segments for fused weights (e.g. 3 for QKV). Default is 1. + /// - group: the distributed group + /// - Returns: a new ``QuantizedShardedToAllLinear`` layer with sharded weights + public class func fromQuantizedLinear( + _ quantizedLinear: QuantizedLinear, segments: Int = 1, + group: DistributedGroup? = nil + ) throws -> QuantizedShardedToAllLinear { + let group = group ?? DistributedGroup() + let (outputDimensions, inputDimensions) = quantizedLinear.weight.shape2 + let inputDimsReal = (inputDimensions * 32) / quantizedLinear.bits + + let layer = try QuantizedShardedToAllLinear( + inputDimensions: inputDimsReal, outputDimensions: outputDimensions, + bias: quantizedLinear.bias != nil, + groupSize: quantizedLinear.groupSize, + bits: quantizedLinear.bits, + mode: quantizedLinear.mode, + group: group) + + // Shard the parameters from the original quantized linear layer + let shardedParams = try shardParameterTree( + quantizedLinear.parameters(), predicate: shardedToAllPredicate(segments: segments), + group: group) + try applyShardedParameters(layer, parameters: shardedParams) + + return layer + } +} + +// MARK: - Internal Sharding Helpers + +/// Sharding predicate result: axis to shard on, and number of segments. +/// Returns `nil` if the parameter should not be sharded. +private typealias ShardInfo = (axis: Int, segments: Int) + +/// Returns a sharding predicate for "all-to-sharded" conversion. +/// +/// For bias: shard along last axis (-1). For weight: shard along axis 0 +/// (max(ndim - 2, 0) in Python, which is axis 0 for 2D weights). +private func allToShardedPredicate(segments: Int) -> (String, MLXArray) -> ShardInfo? { + return { path, weight in + if path.hasSuffix("bias") { + return (axis: -1, segments: segments) + } + // For 2D weight [outDims, inDims], max(ndim - 2, 0) = 0 + return (axis: max(weight.ndim - 2, 0), segments: segments) + } +} + +/// Returns a sharding predicate for "sharded-to-all" conversion. +/// +/// For bias: don't shard (return nil). For weight: shard along last axis (-1). +private func shardedToAllPredicate(segments: Int) -> (String, MLXArray) -> ShardInfo? { + return { path, weight in + if path.hasSuffix("bias") { + return nil + } + return (axis: -1, segments: segments) + } +} + +/// Shard a flat parameter tree according to the given predicate and group. +/// +/// This mirrors the Python `_shard` function using `tree_map_with_path`. +/// For each parameter, the predicate determines the sharding axis and segments. +/// The weight is split into segments, each segment is split across the group, +/// and the rank-local shard is taken and concatenated. +private func shardParameterTree( + _ parameters: ModuleParameters, + predicate: (String, MLXArray) -> ShardInfo?, + group: DistributedGroup +) throws -> ModuleParameters { + let N = group.size + let r = group.rank + + // Flatten to get (path, MLXArray) pairs + let flat = parameters.flattened() + + // Shard each parameter + let sharded = try flat.map { (path, value) -> (String, MLXArray) in + guard let info = predicate(path, value) else { + return (path, value) + } + + try validatePositiveSegments(info.segments) + let axis = try normalizeShardAxis(path: path, value: value, axis: info.axis) + let segments = info.segments + + if segments > 1 { + try validateShardedDimension( + value.shape[axis], across: segments, + description: + "Parameter '\(path)' with shape \(value.shape) cannot be split into \(segments) segments along axis \(axis)." + ) + } + + // Split into segments, then split each segment across group, take rank-th part + let segmentParts: [MLXArray] + if segments > 1 { + segmentParts = value.split(parts: segments, axis: axis) + } else { + segmentParts = [value] + } + + let shardedParts = try segmentParts.map { part -> MLXArray in + try validateShardedDimension( + part.shape[axis], across: N, + description: + "Parameter '\(path)' with shape \(part.shape) cannot be sharded across \(N) devices along axis \(axis)." + ) + let groupParts = part.split(parts: N, axis: axis) + return groupParts[r] + } + + let result: MLXArray + if shardedParts.count > 1 { + result = concatenated(shardedParts, axis: axis).contiguous() + } else { + result = shardedParts[0].contiguous() + } + + return (path, result) + } + + return ModuleParameters.unflattened(sharded) +} + +// MARK: - ShardingType + +/// Describes the type of sharding for distributed linear layers. +/// +/// - ``allToSharded``: Common (replicated) input is projected into a sharded +/// representation. Each rank holds a slice of the output features. +/// - ``shardedToAll``: Sharded input is projected and then aggregated so that +/// every rank obtains the full (common) output. +/// +/// ### See Also +/// - ``shardLinear(module:sharding:segments:group:)`` +/// - ``shardInPlace(module:sharding:segments:group:)`` +public enum ShardingType { + case allToSharded + case shardedToAll +} + +// MARK: - shardLinear + +/// Create a new distributed linear layer from an existing ``Linear`` or +/// ``QuantizedLinear``. +/// +/// The returned layer has its parameters sharded across the group and +/// performs distributed communication in either the forward or backward pass +/// depending on the sharding type. +/// +/// > Note: The `segments` parameter accepts an integer count (e.g. 3 for fused QKV). +/// > Python's upstream `_shard`/`_split` helpers also support list-based and fractional +/// > segment boundaries; these can be added here if upstream use cases require them. +/// +/// - Parameters: +/// - module: the ``Linear`` or ``QuantizedLinear`` layer to shard +/// - sharding: the type of sharding (``ShardingType/allToSharded`` or +/// ``ShardingType/shardedToAll``) +/// - segments: number of segments for fused weights (e.g. 3 for QKV). +/// Default is 1. +/// - group: the distributed group. If `nil`, uses `DistributedGroup()`. +/// - Returns: a new distributed ``Module`` with sharded parameters +/// - Throws: ``DistributedError/invalidConfiguration(_:)`` for invalid +/// segment or divisibility requests, or +/// ``DistributedError/unsupportedModuleType(_:)`` when the module cannot be +/// sharded by this helper. +/// +/// ### See Also +/// - ``shardInPlace(module:sharding:segments:group:)`` +/// - ``AllToShardedLinear`` +/// - ``ShardedToAllLinear`` +public func shardLinear( + module: Module, sharding: ShardingType, segments: Int = 1, + group: DistributedGroup? = nil +) throws -> Module { + // QuantizedLinear must be checked before Linear because QuantizedLinear + // is a subclass of Linear and would otherwise match the Linear case. + switch (sharding, module) { + case (.allToSharded, let quantized as QuantizedLinear): + return try QuantizedAllToShardedLinear.fromQuantizedLinear( + quantized, segments: segments, group: group) + case (.allToSharded, let linear as Linear): + return try AllToShardedLinear.fromLinear(linear, segments: segments, group: group) + case (.shardedToAll, let quantized as QuantizedLinear): + return try QuantizedShardedToAllLinear.fromQuantizedLinear( + quantized, segments: segments, group: group) + case (.shardedToAll, let linear as Linear): + return try ShardedToAllLinear.fromLinear(linear, segments: segments, group: group) + default: + throw DistributedError.unsupportedModuleType(String(describing: type(of: module))) + } +} + +// MARK: - shardInPlace + +/// Shard a module's parameters in-place using ``Module/update(parameters:)``. +/// +/// Unlike ``shardLinear(module:sharding:segments:group:)`` which returns a new +/// distributed layer type, this function modifies the parameters of the +/// existing module without changing its type. The module itself must +/// natively support distributed communication for the collective ops to +/// take effect. +/// +/// - Parameters: +/// - module: the module whose parameters will be sharded in-place +/// - sharding: the type of sharding (``ShardingType/allToSharded`` or +/// ``ShardingType/shardedToAll``), or a custom predicate +/// - segments: number of segments for fused weights (e.g. 3 for QKV). +/// Default is 1. +/// - group: the distributed group. If `nil`, uses `DistributedGroup()`. +/// - Throws: ``DistributedError/invalidConfiguration(_:)`` when the parameter +/// tree cannot be sharded with the requested configuration. +/// +/// ### See Also +/// - ``shardLinear(module:sharding:segments:group:)`` +public func shardInPlace( + module: Module, sharding: ShardingType, segments: Int = 1, + group: DistributedGroup? = nil +) throws { + let group = group ?? DistributedGroup() + let predicate: (String, MLXArray) -> ShardInfo? + + switch sharding { + case .allToSharded: + predicate = allToShardedPredicate(segments: segments) + case .shardedToAll: + predicate = shardedToAllPredicate(segments: segments) + } + + let shardedParams = try shardParameterTree( + module.parameters(), predicate: predicate, group: group) + try applyShardedParameters(module, parameters: shardedParams) +} + +// MARK: - averageGradients + +/// Average a gradient tree across the processes in the distributed group. +/// +/// When the group has a single member the gradients are returned unchanged. +/// Otherwise each gradient array is sum-reduced across the group and divided +/// by the group size. +/// +/// This helper supports batching small gradient arrays into larger +/// concatenated chunks before performing the all-reduce, which can improve +/// communication performance. +/// This API is lazy and non-throwing: runtime communication failures may still +/// surface only when the returned arrays are evaluated. +/// +/// - Parameters: +/// - gradients: the gradient tree (typically from ``Module/parameters()`` +/// or ``Module/trainableParameters()``) +/// - group: the distributed group. If `nil`, uses `DistributedGroup()`. +/// - allReduceSize: maximum byte size for batching gradient arrays into a +/// single all-reduce call. Set to 0 or negative to disable batching. +/// Default is 32 MiB. +/// - communicationType: if provided, cast each gradient to this type before +/// communication and cast back to the original type after. Typically used +/// to cast to a smaller float (e.g. `.float16`) to reduce communication +/// size. Default is `nil`. +/// - communicationStream: optional stream for the communication. If `nil`, +/// the default stream is used. +/// - Returns: the averaged gradient tree with the same structure as the input +/// +/// ### See Also +/// - ``shardLinear(module:sharding:segments:group:)`` +/// - ``shardInPlace(module:sharding:segments:group:)`` +public func averageGradients( + gradients: ModuleParameters, + group: DistributedGroup? = nil, + allReduceSize: Int = 32 * 1024 * 1024, + communicationType: DType? = nil, + communicationStream: StreamOrDevice? = nil +) -> ModuleParameters { + let group = group ?? DistributedGroup() + let N = group.size + + if N == 1 { + return gradients + } + + let stream: StreamOrDevice = communicationStream ?? .default + + // Helper to average a single gradient array, optionally casting to + // communicationType before the all-reduce and back after. + func average(_ x: MLXArray) -> MLXArray { + let dt = x.dtype + let y = communicationType != nil ? x.asType(communicationType!) : x + return group.allSum(y, stream: stream).asType(dt) / Float(N) + } + + if allReduceSize <= 0 { + // No batching: average each gradient independently + return gradients.mapValues(transform: { array in + average(array) + }) + } + + // Batched mode: concatenate small gradients, reduce, split back + let flat = gradients.flattened() + if flat.isEmpty { + return gradients + } + + // Collect metadata + let keys = flat.map { $0.0 } + let values = flat.map { $0.1 } + let shapes = values.map { $0.shape } + let sizes = values.map { $0.size } + let dtypes = values.map { $0.dtype } + + // Check for mixed types -- if mixed, fall back to non-batched + let firstDtype = dtypes[0] + if !dtypes.allSatisfy({ $0 == firstDtype }) { + return averageGradients( + gradients: gradients, group: group, allReduceSize: 0, + communicationType: communicationType, + communicationStream: communicationStream) + } + + // Use communicationType size for batching threshold if provided, + // matching Python's behavior + let itemSize = communicationType?.size ?? firstDtype.size + + // Group gradients into batches that are at least allReduceSize bytes + var gradGroups = [[Int]]() + var currentGroup = [Int]() + var currentSize = 0 + + for i in 0 ..< keys.count { + currentGroup.append(i) + currentSize += sizes[i] * itemSize + if currentSize >= allReduceSize { + gradGroups.append(currentGroup) + currentGroup = [] + currentSize = 0 + } + } + if !currentGroup.isEmpty { + gradGroups.append(currentGroup) + } + + // Concatenate-reduce-split for each group + var newFlat = [(String, MLXArray)]() + for group in gradGroups { + // Flatten each gradient to 1D and concatenate + let flatArrays = group.map { values[$0].reshaped(-1) } + let bigGrad = concatenated(flatArrays, axis: 0) + + // Average the concatenated gradient + let averaged = average(bigGrad) + + // Split back using cumulative sizes as indices + var indices = [Int]() + var cumulative = 0 + for (i, idx) in group.enumerated() { + cumulative += sizes[idx] + if i < group.count - 1 { + indices.append(cumulative) + } + } + + let splitGrads: [MLXArray] + if indices.isEmpty { + splitGrads = [averaged] + } else { + splitGrads = split(averaged, indices: indices, axis: 0) + } + + for (i, idx) in group.enumerated() { + let reshaped = splitGrads[i].reshaped(shapes[idx]) + newFlat.append((keys[idx], reshaped)) + } + } + + return ModuleParameters.unflattened(newFlat) +}