Add MLX backend support for Nutpie compilation#254
Add MLX backend support for Nutpie compilation#254cetagostini wants to merge 11 commits intopymc-devs:mainfrom
Conversation
Introduces MLX as a backend option in compile_pymc_model, allowing gradient computation via MLX or Pytensor. Updates dependency groups to include MLX, extends internal functions to handle MLX mode, and adds corresponding tests for deterministic sampling with MLX.
|
Thanks, that looks great! |
Bump MLX version requirement to >=0.29.0 in pyproject.toml for all relevant extras. In compile_pymc.py, JIT compile the logp function using mx.compile for improved performance, aligning with JAX backend behavior.
Good point, that simple addition brings between 5% to 20% more performance! @aseyboldt |
|
@aseyboldt solve the test issue to work only on macs with intel chips. |
|
@aseyboldt can you give me a hand? The test failing its strange. My local pass everythig. |
aseyboldt
left a comment
There was a problem hiding this comment.
That failure is annoying. For some reason the results seem to differ between different machines? I think we really should figure out what's going on here. Maybe it helps if we print the first couple of values in warmup_posterior to see if the initial values are already different, or if small differences accumulate?
| updated.update(**updates) | ||
|
|
||
| # Convert to MLX arrays if using MLX backend (indicated by force_single_core) | ||
| if self._force_single_core: |
There was a problem hiding this comment.
We should not use that argument to detect mlx.
How about we add an attribute _convert_data_item or so to the dataclass, that contains a function that transforms data arrays? We could then also use that for jax.
Resolve conflicts to keep the MLX backend goals while adopting upstream fixes and CI restructure: - compiled_pyfunc.py: combine upstream's new extra_callback / extra_callback_rate parameters in PyFuncModel._make_sampler with the MLX force_single_core guard. - .github/workflows/ci.yml: adopt upstream's split build/test jobs with the suite matrix (stan/pymc/flow), test_pymc_dev, docs, deploy-docs. Install mlx only in the macOS pymc suite on aarch64 (x86_64 macOS pymc is excluded per upstream's matrix). - compile_pymc.py auto-merge: keep _compile_pymc_model_mlx, MLX backend in compile_pymc_model and _make_functions, and adopt upstream's PyTensor compat refactor (pt.grad, allow_xtensor_conversion, pytensor imports). - tests/test_pymc.py auto-merge: keep dynamic backend_params with MLX guarded by find_spec, plus upstream's test_progress_callback. Made-with: Cursor
Address review comments and CI failures on PR pymc-devs#254: * MLX is not thread-safe (Metal command-buffer race, ml-explore/mlx#2133), so always set ``force_single_core=True`` for ``backend="mlx"`` regardless of ``gradient_backend``. This unblocks the default config (``gradient_backend="pytensor"``) with ``chains>=2``, which previously aborted with "A command encoder is already encoding to this command buffer". * Decouple shared-data type conversion from ``_force_single_core``: ``PyFuncModel`` now carries an optional ``_shared_data_converter`` callable (set to ``mx.array`` for MLX) used by ``with_data``. The old code was abusing ``_force_single_core`` as a "is MLX backend" proxy. * Drop the stale ``raw_logp_fn`` plumbing in the MLX backend. The transform adapter is flowjax-based (JAX-only), so MLX could never expose a usable raw logp. * Pin ``mlx<0.31`` in ``pyproject.toml`` extras and the CI install. mlx 0.31.x crashes inside ``Compiled::eval_gpu`` when the sampler worker thread evaluates auto-fused element-wise kernels (ml-explore/mlx#3329), which is the root cause of the macOS aarch64 pymc test segfaults. mlx 0.29.x and 0.30.x are unaffected. * Add a runtime guard in ``_compile_pymc_model_mlx`` that raises a helpful ``RuntimeError`` if mlx>=0.31 is installed anyway, instead of segfaulting partway through sampling. * Make ``test_deterministic_sampling_mlx`` a smoke test (drop ``array_compare`` + reference file). MLX sampling is not bit-identical across machines/MLX versions, mirroring the situation already documented for ``test_normalizing_flow``. * Mark ``test_dims_model[mlx-*]`` as ``xfail`` while pymc-devs/pytensor#1350 is open (PyTensor's MLX linker has no ``XTensorFromTensor`` op yet). * Mirror logp's ``mx.array(_x)`` conversion in the expand closure for consistency. * Ignore local AI tooling folders (``.ai/``, ``.cursor/``, ``.claude/``). Made-with: Cursor
Keep only the upstream-issue references (mlx#2133, mlx#3329) since those are non-obvious context; drop the explanatory comments whose content is already conveyed by the surrounding code. Made-with: Cursor
Introduces MLX as a backend option in compile_pymc_model, allowing gradient computation via MLX or Pytensor. Updates dependency groups to include MLX, extends internal functions to handle MLX mode, and adds corresponding tests for deterministic sampling with MLX.