diff --git a/CHANGELOG.md b/CHANGELOG.md index bf3ba2d6f..6fd659e5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Fixes: - Fix issues with tag false reassignment (mihran113) +- Fix issue of detecting defice arrays with newer versions of jax (alberttorosyan) ## 3.29.1 May 8, 2025: diff --git a/aim/sdk/num_utils.py b/aim/sdk/num_utils.py index ff1678278..b11b29106 100644 --- a/aim/sdk/num_utils.py +++ b/aim/sdk/num_utils.py @@ -71,6 +71,9 @@ def is_jax_device_array(inst): return True if inst_has_typename(inst, ['jaxlib', 'xla_extension', 'DeviceArray']): return True + if inst_has_typename(inst, ['jaxlib', '_jax', 'ArrayImpl']): + return True + return False