diff --git a/lindi/LindiH5pyFile/LindiH5pyDataset.py b/lindi/LindiH5pyFile/LindiH5pyDataset.py index 8e12b93..c1423f0 100644 --- a/lindi/LindiH5pyFile/LindiH5pyDataset.py +++ b/lindi/LindiH5pyFile/LindiH5pyDataset.py @@ -229,9 +229,9 @@ def _get_item_for_zarr(self, zarr_array: zarr.Array, selection: Any): ) return ret else: - raise TypeError( - f"Compound dataset {self.name} does not support selection with {selection}" - ) + # Integer or slice indexing on a compound dataset - return + # rows as numpy structured array (or np.void for scalar index) + return self._get_compound_rows(zarr_array, selection) # We use zarr's slicing, except in the case of a scalar dataset if self.ndim == 0: @@ -243,6 +243,36 @@ def _get_item_for_zarr(self, zarr_array: zarr.Array, selection: Any): return zarr_array[:][0] return decode_references(zarr_array[selection]) + def _get_compound_rows(self, zarr_array: zarr.Array, selection): + """Return rows from a compound dataset as a numpy structured array. + + For integer indexing, returns a single np.void. For slices, returns + a numpy structured array with the compound dtype. + """ + assert self._compound_dtype is not None + raw = zarr_array[selection] + # raw is either a single list (integer index) or list of lists (slice) + if isinstance(selection, (int, np.integer)): + # Single row - return np.void + row = raw + tup = tuple( + LindiH5pyReference(row[i]['_REFERENCE']) if isinstance(row[i], dict) and '_REFERENCE' in row[i] + else row[i] + for i in range(len(self._compound_dtype)) + ) + return np.void(tup, dtype=self._compound_dtype) + else: + # Multiple rows - return structured array + result = np.empty(len(raw), dtype=self._compound_dtype) + for row_idx, row in enumerate(raw): + tup = tuple( + LindiH5pyReference(row[i]['_REFERENCE']) if isinstance(row[i], dict) and '_REFERENCE' in row[i] + else row[i] + for i in range(len(self._compound_dtype)) + ) + result[row_idx] = tup + return result + def _get_external_hdf5_client(self, url: str) -> h5py.File: if url not in _external_hdf5_clients: if url.startswith("http://") or url.startswith("https://"): diff --git a/tests/test_numpy_array_conversion.py b/tests/test_numpy_array_conversion.py index f2afffa..69abb9e 100644 --- a/tests/test_numpy_array_conversion.py +++ b/tests/test_numpy_array_conversion.py @@ -74,3 +74,16 @@ def test_numpy_array_conversion_compound(): y_vals = np.asarray(ds['y'][:]) np.testing.assert_array_equal(y_vals, np.array([2.5, 4.5, 6.5], dtype=np.float64)) + + # Test integer indexing - returns np.void + row0 = ds[0] + assert isinstance(row0, np.void) + assert row0['x'] == 1 + assert row0['y'] == 2.5 + + # Test slice indexing - returns structured array + rows = ds[0:2] + assert rows.dtype == compound_dtype + assert len(rows) == 2 + np.testing.assert_array_equal(rows['x'], np.array([1, 3], dtype=np.int32)) + np.testing.assert_array_equal(rows['y'], np.array([2.5, 4.5], dtype=np.float64))