diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b1f3a343..d276b663 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -16,6 +16,7 @@ X.X.X (unreleased) * Implement basic retrieval of bitinformation in python as alternative to julia implementation (:pr:`156`, :issue:`155`, :pr:`126`, :issue:`125`) `Hauke Schulz`_ with helpful comments from `Milan Klöwer`_. * Make julia binding to BitInformation.jl optional (:pr:`153`, :issue:`151`) `Aaron Spring`_. * Add support for python 3.13 (:pr:`303`) and use uv for GitHub Actions `Hauke Schulz`_. +* Simplify get_bitinformation function (:pr:`262`, :issue:`261`) `Hauke Schulz`_. 0.0.3 (2022-07-11) ------------------ diff --git a/xbitinfo/xbitinfo.py b/xbitinfo/xbitinfo.py index 15d6d67c..2a7fb54e 100644 --- a/xbitinfo/xbitinfo.py +++ b/xbitinfo/xbitinfo.py @@ -108,7 +108,30 @@ def dict_to_dataset(info_per_bit): return dsb -def get_bitinformation( # noqa: C901 +def _check_bitinfo_kwargs(implementation=None, axis=None, dim=None, kwargs=None): + if kwargs is None: + kwargs = {} + # check keywords + if implementation == "julia" and not julia_installed: + raise ImportError('Please install julia or use implementation="python".') + if axis is not None and dim is not None: + raise ValueError("Please provide either `axis` or `dim` but not both.") + if axis: + if not isinstance(axis, int): + raise ValueError(f"Please provide `axis` as `int`, found {type(axis)}.") + if dim: + if not isinstance(dim, str) and not isinstance(dim, list): + raise ValueError( + f"Please provide `dim` as `str` or `list`, found {type(dim)}." + ) + if "mask" in kwargs: + raise ValueError( + "`xbitinfo` does not wrap the mask argument. Mask your xr.Dataset with NaNs instead." + ) + return + + +def get_bitinformation( ds, dim=None, axis=None, @@ -123,7 +146,7 @@ def get_bitinformation( # noqa: C901 ---------- ds : :py:class:`xarray.Dataset` Input dataset to analyse - dim : str + dim : str or list Dimension over which to apply mean. Only one of the ``dim`` and ``axis`` arguments can be supplied. If no ``dim`` or ``axis`` is given (default), the bitinformation is retrieved along all dimensions. axis : int @@ -185,8 +208,38 @@ def get_bitinformation( # noqa: C901 xbitinfo_version: ... BitInformation.jl_version: ... """ - if implementation == "julia" and not julia_installed: - raise ImportError('Please install julia or use implementation="python".') + if overwrite is False and label is not None: + try: + info_per_bit = load_bitinformation(label) + except FileNotFoundError: + logging.info( + f"No bitinformation could be found for {label}. Please set `overwrite=True` for recalculation..." + ) + else: + return info_per_bit + else: + _check_bitinfo_kwargs(implementation, axis, dim, kwargs) + + return _get_bitinformation( + ds, + dim=dim, + axis=axis, + label=label, + overwrite=overwrite, + implementation=implementation, + **kwargs, + ) + + +def _get_bitinformation( + ds, + dim=None, + axis=None, + label=None, + overwrite=False, + implementation="julia", + **kwargs, +): if dim is None and axis is None: # gather bitinformation on all axis return _get_bitinformation_along_dims( @@ -209,63 +262,17 @@ def get_bitinformation( # noqa: C901 ) else: # gather bitinformation along one axis - if overwrite is False and label is not None: - try: - info_per_bit = load_bitinformation(label) - return info_per_bit - except FileNotFoundError: - logging.info( - f"No bitinformation could be found for {label}. Recalculating..." - ) + info_per_bit = _get_bitinformation_along_axis( + ds, implementation, axis, dim, **kwargs + ) - # check keywords - if axis is not None and dim is not None: - raise ValueError("Please provide either `axis` or `dim` but not both.") - if axis: - if not isinstance(axis, int): - raise ValueError(f"Please provide `axis` as `int`, found {type(axis)}.") - if dim: - if not isinstance(dim, str): - raise ValueError(f"Please provide `dim` as `str`, found {type(dim)}.") - if "mask" in kwargs: - raise ValueError( - "`xbitinfo` does not wrap the mask argument. Mask your xr.Dataset with NaNs instead." - ) + if label is not None: + out_fn = label + ".json" + if not os.path.exists(out_fn) or overwrite: + save_bitinformation(info_per_bit, out_fn) - info_per_bit = {} - pbar = tqdm(ds.data_vars) - for var in pbar: - pbar.set_description(f"Processing var: {var} for dim: {dim}") - - if _quantized_variable_is_scaled(ds, var): - loaded_dtype = ds[var].dtype - quantized_storage_dtype = ds[var].encoding["dtype"] - warnings.warn( - f"Variable {var} is quantized as {quantized_storage_dtype}, but loaded as {loaded_dtype}. Consider reopening using `mask_and_scale=False` to get sensible results", - category=UserWarning, - ) + info_per_bit = dict_to_dataset(info_per_bit) - if implementation == "julia": - info_per_bit_var = _jl_get_bitinformation(ds, var, axis, dim, kwargs) - if info_per_bit_var is None: - continue - else: - info_per_bit[var] = info_per_bit_var - elif implementation == "python": - info_per_bit_var = _py_get_bitinformation(ds, var, axis, dim, kwargs) - if info_per_bit_var is None: - continue - else: - info_per_bit[var] = info_per_bit_var - else: - raise ValueError( - f"Implementation of bitinformation algorithm {implementation} is unknown. Please choose a different one." - ) - if label is not None: - with open(label + ".json", "w") as f: - logging.debug(f"Save bitinformation to {label + '.json'}") - json.dump(info_per_bit, f, cls=JsonCustomEncoder) - info_per_bit = dict_to_dataset(info_per_bit) for var in info_per_bit.data_vars: # keep attrs from input with source_ prefix for a in ds[var].attrs.keys(): info_per_bit[var].attrs["source_" + a] = ds[var].attrs[a] @@ -377,7 +384,7 @@ def _get_bitinformation_along_dims( logging.info(f"Get bitinformation along dimension {d}") if label is not None: label = "_".join([label, d]) - info_per_bit_per_dim[d] = get_bitinformation( + info_per_bit_per_dim[d] = _get_bitinformation( ds, dim=d, axis=None, @@ -390,6 +397,41 @@ def _get_bitinformation_along_dims( return info_per_bit +def _get_bitinformation_along_axis(ds, implementation, axis, dim, **kwargs): + """ + Helper function for :py:func:`xbitinfo.xbitinfo.get_bitinformation` to handle analysis along one axis. + """ + info_per_bit = {} + pbar = tqdm(ds.data_vars) + for var in pbar: + pbar.set_description(f"Processing var: {var} for dim: {dim}") + if _quantized_variable_is_scaled(ds, var): + loaded_dtype = ds[var].dtype + quantized_storage_dtype = ds[var].encoding["dtype"] + warnings.warn( + f"Variable {var} is quantized as {quantized_storage_dtype}, but loaded as {loaded_dtype}. Consider reopening using `mask_and_scale=False` to get sensible results", + category=UserWarning, + ) + if implementation == "julia": + info_per_bit_var = _jl_get_bitinformation(ds, var, axis, dim, kwargs) + if info_per_bit_var is None: + continue + else: + info_per_bit[var] = info_per_bit_var + elif implementation == "python": + info_per_bit_var = _py_get_bitinformation(ds, var, axis, dim, kwargs) + if info_per_bit_var is None: + continue + else: + info_per_bit[var] = info_per_bit_var + else: + raise ValueError( + f"Implementation of bitinformation algorithm {implementation} is unknown. Please choose a different one." + ) + + return info_per_bit + + def _get_bitinformation_kwargs_handler(da, kwargs): """Helper function to preprocess kwargs args of :py:func:`xbitinfo.xbitinfo.get_bitinformation`.""" kwargs_var = kwargs.copy() @@ -539,6 +581,14 @@ def get_cdf_without_artificial_information( return cdf +def save_bitinformation(info_per_bit, out_fn, overwrite=False): + """Save bitinformation to JSON file""" + with open(out_fn, "w") as f: + logging.debug(f"Save bitinformation to {out_fn}") + json.dump(info_per_bit, f, cls=JsonCustomEncoder) + return + + def get_keepbits(info_per_bit, inflevel=0.99, information_filter=None, **kwargs): """Get the number of mantissa bits to keep. To be used in :py:func:`xbitinfo.bitround.xr_bitround` and :py:func:`xbitinfo.bitround.jl_bitround`.