diff --git a/mlx/c/array.cpp b/mlx/c/array.cpp index 7c7342d..8267a41 100644 --- a/mlx/c/array.cpp +++ b/mlx/c/array.cpp @@ -330,7 +330,12 @@ extern "C" const size_t* mlx_array_strides(const mlx_array arr) { } extern "C" int mlx_array_dim(const mlx_array arr, int dim) { try { - return mlx_array_get_(arr).shape(dim); + auto& a = mlx_array_get_(arr); + int ndim = a.ndim(); + if (ndim == 0) return 1; // scalar: treat as size 1 for any dim + if (dim < 0) dim += ndim; + if (dim < 0 || dim >= ndim) return 0; // out of bounds: return 0 instead of crashing + return a.shape(dim); } catch (std::exception& e) { mlx_error(e.what()); return 0;