Deprecate Float-defaulted factory functions (zeros/ones/eye/identity/tri) #408

Closed
john-rocky wants to merge 1 commit into ml-explore:main from john-rocky:fix/issue-390-explicit-dtype

@john-rocky

Closes #390.

Implements the path called out in the issue: factory functions whose dtype silently defaults to Float.self are now deprecated, while the existing dtype: / type: overloads are kept as the recommended replacements.

API change

For each of zeros, ones, eye, identity, tri (both `MLXArray.` and free `MLX.` forms):

```swift
// Already existed — keep as the recommended explicit-dtype paths:
zeros(shape, dtype: .float32)
zeros(shape, type: Float.self)

// Existed but silently returned float32 — now deprecated:
zeros(shape) // ⚠️ #DeprecatedDeclaration
```

The deprecated overload still returns float32, so the change is source-compatible. Callers that omitted the type argument get a Swift warning pointing them at the explicit-dtype form and at this issue.

A small mechanical cleanup also runs alongside:

  • `type: (some HasDType).Type = Float.self` had its `= Float.self` default removed. Swift overload resolution ranks the new no-type deprecated overload as more specific than the type-explicit overload with its default applied, so the default became unreachable as soon as the new overload was added; the cleanup just removes dead syntax. Explicit `type: T.self` calls continue to compile without warnings.
  • `full(shape, values, type: T = Float.self)` had its default removed for the same reason — `full(shape, values)` already exists and inherits `values.dtype`, so the default was never selected.
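
The resolution behavior this cleanup relies on can be sketched without MLX. This is a minimal illustration under stated assumptions: `makeZeros` and `ChosenOverload` are hypothetical stand-ins, not the real factory signatures, and the generic default needs Swift 5.7 or later.

```swift
// Records which stand-in overload a call resolved to.
enum ChosenOverload { case typeExplicit, deprecatedNoType }

// Stand-in for the existing type-explicit factory, shown with the default it
// used to carry (a generic default like this needs Swift 5.7 or later).
// `makeZeros` is a hypothetical name, not an MLX declaration.
func makeZeros<T: BinaryFloatingPoint>(_ count: Int, type: T.Type = Float.self) -> ChosenOverload {
    .typeExplicit
}

// Stand-in for the new overload that takes no type argument.
@available(*, deprecated, message: "Pass type: or dtype: explicitly (see #390).")
func makeZeros(_ count: Int) -> ChosenOverload {
    .deprecatedNoType
}

// No type argument: the overload that needs no default is preferred, so this
// call should pick the deprecated one and emit a deprecation warning.
let noTypeCall = makeZeros(3)

// Explicit type argument: only the generic overload accepts it, no warning.
let typedCall = makeZeros(3, type: Double.self)

print(noTypeCall, typedCall)
```

Because the call without a type argument never reaches the generic overload, dropping `= Float.self` from it should not change which declaration any existing call site binds to.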

In-tree callers tightened to keep the build warning-free

The deprecation surfaced silent float32 promotion at several existing internal call sites in MLXNN and the tutorial example:

  • `Convolution.bias = MLXArray.zeros([outputChannels])` (1D / 2D / 3D)
  • `ConvolutionTransposed.bias = MLXArray.zeros([outputChannels])` (1D / 2D / 3D)
  • `LayerNorm` / `RMSNorm` / `GroupNorm` / `BatchNorm` affine + running-stats parameters
  • `QuantizedLinear.bias`
  • `Source/Examples/Tutorial.swift`

All become `...zeros([d], dtype: .float32)` / `...ones([d], dtype: .float32)`. No behavior change — the original implicit dtype was `.float32` — but the dtype is now visible at the layer-construction site, which is also a stepping-stone toward the broader "thread dtype through layer initializers" follow-up that #390 alludes to via ml-explore/mlx-swift-lm#124.

Intentional non-goals (separate PRs welcome)

  • `arange(_ stop: Double, dtype: DType = .float32, ...)`: matches Python (`mx.arange(3.0)` returns float32), so the default is parity, not a footgun. Leaving in place.
  • `MLXRandom.{uniform, normal, multivariateNormal, truncatedNormal, gumbel, laplace}`: these have the same `dtype: DType = .float32` shape. Worth tackling, but the right design depends on whether the dtype should be inferred from `low`/`high` when those are `MLXArray` (Python-parity behavior). I left them out so the scope of this PR stays focused and reviewable, and so the inference-vs-explicit question can be discussed independently.
  • Layer-level dtype threading (`Linear(dtype:)`, `Conv2d(dtype:)`, `*Norm(dtype:)`, etc.). This is the real fix for the issue's underlying motivation (avoiding implicit float32 weights in bf16 pipelines). It's a much larger API surface — happy to follow up with a draft once the deprecation direction here is settled.
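
To make the last non-goal concrete, here is a rough sketch of what layer-level dtype threading could look like. Everything in it is hypothetical: `DType`, `Param`, and `SketchLinear` are stand-ins written for this illustration, not MLX or MLXNN declarations, and the real design is still open.

```swift
// Stand-in types only; none of these are MLX declarations.
enum DType { case float16, bfloat16, float32 }

// Minimal parameter record: just a shape and a dtype, no storage.
struct Param {
    let shape: [Int]
    let dtype: DType

    static func zeros(_ shape: [Int], dtype: DType) -> Param {
        Param(shape: shape, dtype: dtype)
    }
}

// Hypothetical layer whose initializer takes a dtype with no default,
// so every parameter it creates uses the caller's choice.
struct SketchLinear {
    let weight: Param
    let bias: Param

    init(_ inputDims: Int, _ outputDims: Int, dtype: DType) {
        weight = .zeros([outputDims, inputDims], dtype: dtype)
        bias = .zeros([outputDims], dtype: dtype)
    }
}

// A bf16 pipeline states its dtype once, at construction.
let layer = SketchLinear(4, 8, dtype: .bfloat16)
print(layer.weight.dtype, layer.bias.dtype)
```

The explicit `dtype: .float32` arguments added at the in-tree call sites in this PR are the places where such a layer-wide value would eventually be passed through.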

Verification

```
swift build # Build complete!
# 0 #390-related warnings remain in-tree
```

`swift test` could not be run end-to-end on this machine: the SPM test target hits the pre-existing "Failed to load the default metallib" issue (#349) before any of my changes get exercised. Plain library / example builds succeed without warnings. The changes are purely additive overloads plus default-argument removals, so existing call sites keep their runtime behavior and the compile-time effects are exercised by the build.

I'm happy to:

  • Strip the in-tree call-site updates if you'd rather see them in a separate "clean up internal float32 callers" PR.
  • Convert the no-type overloads to a hard removal (i.e., remove `= Float.self` and skip adding the deprecated overload) if you prefer a breaking change announced in the release notes over a deprecation cycle.
  • Extend the same pattern to `MLXRandom.*` once the inference-from-inputs question is decided.

Whichever direction you prefer, this PR can be reshaped to match.

Closes ml-explore#390.

Several `MLX.<factory>(...)` and `MLXArray.<factory>(...)` overloads
silently default `type:` to `Float.self`, returning float32 arrays
even when they participate in bfloat16 / float16 graphs (each implicit
float32 value then triggers an AsType cast in MLX's C++ engine, with
the bandwidth penalty described in the issue body).
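
A simplified model of that effect, for illustration only: the `DType` enum and the `promote` and `add` helpers below are stand-ins written for this message, and the promotion rule is an assumption that covers just these three dtypes (equal dtypes are kept, any mixed pair widens to float32).

```swift
// Stand-in dtype list; not the MLX `DType`.
enum DType { case float16, bfloat16, float32 }

// Assumed, simplified promotion rule for these three dtypes only.
func promote(_ a: DType, _ b: DType) -> DType {
    a == b ? a : .float32
}

// Models a binary op: returns the result dtype and how many operands
// would need a cast to reach it.
func add(_ a: DType, _ b: DType) -> (result: DType, casts: Int) {
    let out = promote(a, b)
    let casts = (a != out ? 1 : 0) + (b != out ? 1 : 0)
    return (out, casts)
}

// bf16 activations plus a bias created with the implicit float32 default.
let mixed = add(.bfloat16, .float32)

// Same op with the bias created in bf16.
let uniform = add(.bfloat16, .bfloat16)

print(mixed, uniform)
```

In this model the mixed case widens the result to float32 and needs one cast, while the uniform case stays in bfloat16 with none.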

This change:

1. **Adds a deprecated no-`type:` overload** for each affected factory
   (`zeros`, `ones`, `eye`, `identity`, `tri`) on both `MLXArray` and
   the free function. The deprecated overload still returns float32, so
   source compatibility is preserved, but callers that elided the
   `type:` argument now see a `#DeprecatedDeclaration` warning pointing
   to the explicit `dtype:` / `type:` variants and to issue ml-explore#390.

2. **Removes `= Float.self` from the existing `type:` overloads.**
   Because the new no-type overload is more specific than the
   default-eliding variant, Swift's overload resolution prefers the
   deprecated one for `zeros(shape)` / `ones(shape)` etc. The
   `= Float.self` default on the type-explicit overload became
   unreachable, so the cleanup removes it without behavior change.
   Callers passing `type:` explicitly continue to compile silently.

3. **Removes the unreachable `= Float.self` default from
   `full(shape, values, type:)`.** `full(shape, values)` already exists
   and inherits the dtype of `values`, so the defaulted variant was
   never selected by overload resolution.

4. **Threads `dtype: .float32` through internal call sites in
   `MLXNN` (Convolution / ConvolutionTransposed / Normalization /
   Quantized) and `Examples/Tutorial`** so the in-tree build stays
   warning-free. These call sites previously relied on the silent
   default; tightening them is a no-op behaviorally but documents the
   dtype choice at the layer-construction site.

What's intentionally *not* in this PR:

- `arange` (`Double` overloads with `dtype: DType = .float32`):
  matches Python's behavior (`mx.arange(3.0) -> float32`) so leaving
  the default in place keeps parity.
- `MLXRandom` factories (`uniform` / `normal` / `multivariateNormal` /
  `truncatedNormal` / `gumbel` / `laplace`): these need a follow-up
  PR; the existing call sites would otherwise produce a wave of
  warnings before the inference-from-low/high direction is decided.
- Threading a dtype through `Linear` / `Conv*` / `*Norm` initialisers
  so that constructed weights / biases inherit a layer-wide dtype.
  That is the next logical step (the bandwidth issue surfaced in
  ml-explore/mlx-swift-lm#124), but it is a much larger surface change
  and is best discussed separately.

Build verification:

    swift build      # Build complete!, no ml-explore#390 warnings
@john-rocky

Apologies — I missed your existing draft #391 when I surveyed the issue tracker, and opened this on the same target. Closing in favor of #391.

For what it's worth, the only piece in here that wasn't in #391 was the in-tree call-site cleanup in MLXNN (Convolution / ConvolutionTransposed / LayerNorm / RMSNorm / GroupNorm / BatchNorm / QuantizedLinear / Tutorial) — those would silence the wave of internal deprecation warnings that #391 will surface once it leaves draft. Happy to send that as a tiny follow-up after #391 lands; no attribution needed and please feel free to copy the diff into #391 if it's easier.

Sorry for the noise.

@john-rocky closed this on May 15, 2026

Successfully merging this pull request may close these issues.

[BUG] many mlx-swift functions incorrectly promote to float32, e.g. zeros, ones, etc.