
Add MLX backend support for Nutpie compilation#254

Open
cetagostini wants to merge 11 commits into pymc-devs:main from cetagostini:cetagostini/adding_mlx_backend

Conversation

@cetagostini

Introduces MLX as a backend option in compile_pymc_model, allowing gradient computation via MLX or PyTensor. Updates dependency groups to include MLX, extends internal functions to handle MLX mode, and adds corresponding tests for deterministic sampling with MLX.

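A hedged usage sketch of what the description proposes (the toy model and `chains=1` are illustrative choices, not the PR's test code; the snippet only runs when pymc, nutpie, and mlx are all installed):

```python
# Hedged usage sketch: selecting the MLX backend in compile_pymc_model.
import importlib.util

deps = ("pymc", "nutpie", "mlx")
have_deps = all(importlib.util.find_spec(d) is not None for d in deps)

if have_deps:
    import pymc as pm
    import nutpie

    with pm.Model() as model:
        pm.Normal("x")

    # backend="mlx" is the option this PR adds; gradient_backend chooses
    # between MLX and PyTensor gradient computation.
    compiled = nutpie.compile_pymc_model(
        model, backend="mlx", gradient_backend="pytensor"
    )
    trace = nutpie.sample(compiled, chains=1)
```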
@aseyboldt
Member

Thanks, that looks great!
I think we probably should call mlx.compile on the final functions though?

Bump MLX version requirement to >=0.29.0 in pyproject.toml for all relevant extras. In compile_pymc.py, JIT compile the logp function using mx.compile for improved performance, aligning with JAX backend behavior.
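A minimal sketch of the change (the `logp` kernel is a toy stand-in, not nutpie's actual log-density function; the snippet falls back to the uncompiled function when MLX is not installed):

```python
# Toy example of JIT-compiling a log-density with mx.compile.
def logp(x):
    return -0.5 * (x * x)  # standard-normal log-density, constants dropped

try:
    import mlx.core as mx

    compiled_logp = mx.compile(logp)  # traced and fused on first call
    result = float(compiled_logp(mx.array(1.0)))
except ImportError:
    compiled_logp = logp  # MLX unavailable: run the plain Python function
    result = compiled_logp(1.0)
```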
@cetagostini
Author

cetagostini commented Oct 27, 2025

Thanks, that looks great! I think we probably should call mlx.compile on the final functions though?

Good point, that simple addition brings between 5% and 20% more performance! @aseyboldt

@cetagostini
Author

@aseyboldt solved the test issue that affected only Macs with Intel chips.

@cetagostini
Author

@aseyboldt can you give me a hand? The failing test is strange; everything passes locally for me.

@cetagostini cetagostini requested review from aseyboldt and jessegrabowski and removed request for aseyboldt October 30, 2025 12:46
Member

@aseyboldt left a comment


That failure is annoying. For some reason the results seem to differ between different machines? I think we really should figure out what's going on here. Maybe it helps if we print the first couple of values in warmup_posterior to see if the initial values are already different, or if small differences accumulate?

Comment thread: python/nutpie/compiled_pyfunc.py (Outdated)
updated.update(**updates)

# Convert to MLX arrays if using MLX backend (indicated by force_single_core)
if self._force_single_core:
Member


We should not use that argument to detect mlx.
How about we add an attribute _convert_data_item or so to the dataclass, that contains a function that transforms data arrays? We could then also use that for jax.
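The suggestion might look roughly like this (the names `_convert_data_item` and `with_data` follow the thread; the dataclass itself is a simplified stand-in, with the identity as default and e.g. `mx.array` plugged in for an MLX build):

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict


@dataclass(frozen=True)
class PyFuncModelSketch:
    # Identity by default; an MLX build would pass mx.array here,
    # and a JAX build could pass jnp.asarray.
    _convert_data_item: Callable[[Any], Any] = field(default=lambda x: x)

    def with_data(self, **updates: Any) -> Dict[str, Any]:
        # Convert every updated data array with the backend's converter
        # instead of branching on _force_single_core.
        return {name: self._convert_data_item(val) for name, val in updates.items()}


default_model = PyFuncModelSketch()
tupled_model = PyFuncModelSketch(_convert_data_item=tuple)

plain = default_model.with_data(y=[1, 2, 3])
converted = tupled_model.with_data(y=[1, 2, 3])
```

This keeps the "is MLX" question out of the sampler logic entirely: each backend supplies its own converter at construction time.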

Resolve conflicts to keep the MLX backend goals while adopting upstream
fixes and CI restructure:

- compiled_pyfunc.py: combine upstream's new extra_callback / extra_callback_rate
  parameters in PyFuncModel._make_sampler with the MLX force_single_core
  guard.
- .github/workflows/ci.yml: adopt upstream's split build/test jobs with
  the suite matrix (stan/pymc/flow), test_pymc_dev, docs, deploy-docs.
  Install mlx only in the macOS pymc suite on aarch64 (x86_64 macOS
  pymc is excluded per upstream's matrix).
- compile_pymc.py auto-merge: keep _compile_pymc_model_mlx, MLX backend
  in compile_pymc_model and _make_functions, and adopt upstream's
  PyTensor compat refactor (pt.grad, allow_xtensor_conversion, pytensor
  imports).
- tests/test_pymc.py auto-merge: keep dynamic backend_params with MLX
  guarded by find_spec, plus upstream's test_progress_callback.

Made-with: Cursor
Address review comments and CI failures on PR pymc-devs#254:

* MLX is not thread-safe (Metal command-buffer race,
  ml-explore/mlx#2133), so always set ``force_single_core=True`` for
  ``backend="mlx"`` regardless of ``gradient_backend``. This unblocks the
  default config (``gradient_backend="pytensor"``) with ``chains>=2``,
  which previously aborted with
  "A command encoder is already encoding to this command buffer".

* Decouple shared-data type conversion from ``_force_single_core``:
  ``PyFuncModel`` now carries an optional ``_shared_data_converter``
  callable (set to ``mx.array`` for MLX) used by ``with_data``. The old
  code was abusing ``_force_single_core`` as a "is MLX backend" proxy.

* Drop the stale ``raw_logp_fn`` plumbing in the MLX backend. The
  transform adapter is flowjax-based (JAX-only), so MLX could never
  expose a usable raw logp.

* Pin ``mlx<0.31`` in ``pyproject.toml`` extras and the CI install.
  mlx 0.31.x crashes inside ``Compiled::eval_gpu`` when the sampler
  worker thread evaluates auto-fused element-wise kernels
  (ml-explore/mlx#3329), which is the root cause of the macOS aarch64
  pymc test segfaults. mlx 0.29.x and 0.30.x are unaffected.

* Add a runtime guard in ``_compile_pymc_model_mlx`` that raises a
  helpful ``RuntimeError`` if mlx>=0.31 is installed anyway, instead of
  segfaulting partway through sampling.

* Make ``test_deterministic_sampling_mlx`` a smoke test (drop
  ``array_compare`` + reference file). MLX sampling is not bit-identical
  across machines/MLX versions, mirroring the situation already
  documented for ``test_normalizing_flow``.

* Mark ``test_dims_model[mlx-*]`` as ``xfail`` while
  pymc-devs/pytensor#1350 is open (PyTensor's MLX linker has no
  ``XTensorFromTensor`` op yet).

* Mirror logp's ``mx.array(_x)`` conversion in the expand closure for
  consistency.

* Ignore local AI tooling folders (``.ai/``, ``.cursor/``, ``.claude/``).

Made-with: Cursor
Keep only the upstream-issue references (mlx#2133, mlx#3329) since
those are non-obvious context; drop the explanatory comments whose
content is already conveyed by the surrounding code.

Made-with: Cursor