diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py index eafa53904..6be0e29e0 100644 --- a/helion/runtime/__init__.py +++ b/helion/runtime/__init__.py @@ -665,7 +665,6 @@ def _make_interpret_callable() -> _PallasInterpretCallable: kernel_name = getattr(pallas_kernel, "__name__", "pallas_kernel") - jax.config.update("jax_export_ignore_forward_compatibility", True) jax_callable = JaxCallable( name=kernel_name, jit_fn=jax.jit(jit_fn),