AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
Which versions of jax, chex, and optax are used here?
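
To make the question easier to answer, here is a minimal sketch for reporting the installed versions in the environment where the error occurs. It uses only the standard library's `importlib.metadata`, so it works even when `import jax` itself fails. The helper name `installed_versions` is hypothetical, and checking `jaxlib` as well is an assumption on my part (the error comes from `jaxlib.xla_extension`, so its version seems relevant next to `jax`).

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "chex", "optax")):
    """Return {package name: installed version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            # Reads the version from package metadata without importing the package.
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Posting this output along with the full traceback should help narrow down whether the installed packages are compatible with each other.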