Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify get_bitinformation #262

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 99 additions & 78 deletions xbitinfo/xbitinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,30 @@ def dict_to_dataset(info_per_bit):
return dsb


def get_bitinformation( # noqa: C901
def _check_bitinfo_kwargs(implementation=None, axis=None, dim=None, kwargs=None):
if kwargs is None:
kwargs = {}
# check keywords
if implementation == "julia" and not julia_installed:
raise ImportError('Please install julia or use implementation="python".')
if axis is not None and dim is not None:
raise ValueError("Please provide either `axis` or `dim` but not both.")
if axis:
if not isinstance(axis, int):
raise ValueError(f"Please provide `axis` as `int`, found {type(axis)}.")
if dim:
if not isinstance(dim, str) and not isinstance(dim, list):
raise ValueError(
f"Please provide `dim` as `str` or `list`, found {type(dim)}."
)
if "mask" in kwargs:
raise ValueError(
"`xbitinfo` does not wrap the mask argument. Mask your xr.Dataset with NaNs instead."
)
return


def get_bitinformation(
ds,
dim=None,
axis=None,
Expand All @@ -120,7 +143,7 @@ def get_bitinformation( # noqa: C901
----------
ds : :py:class:`xarray.Dataset`
Input dataset to analyse
dim : str
dim : str or list
Dimension over which to apply mean. Only one of the ``dim`` and ``axis`` arguments can be supplied.
If no ``dim`` or ``axis`` is given (default), the bitinformation is retrieved along all dimensions.
axis : int
Expand Down Expand Up @@ -182,81 +205,46 @@ def get_bitinformation( # noqa: C901
xbitinfo_version: ...
BitInformation.jl_version: ...
"""
if implementation == "julia" and not julia_installed:
raise ImportError('Please install julia or use implementation="python".')
if dim is None and axis is None:
# gather bitinformation on all axis
return _get_bitinformation_along_dims(
ds,
dim=dim,
label=label,
overwrite=overwrite,
implementation=implementation,
**kwargs,
)
if isinstance(dim, list) and axis is None:
# gather bitinformation on dims specified
return _get_bitinformation_along_dims(
ds,
dim=dim,
label=label,
overwrite=overwrite,
implementation=implementation,
**kwargs,
)
if overwrite is False and label is not None:
try:
info_per_bit = load_bitinformation(label)
info_per_bit = dict_to_dataset(info_per_bit)
except FileNotFoundError:
logging.info(
f"No bitinformation could be found for {label}. Please set `overwrite=True` for recalculation..."
)
else:
# gather bitinformation along one axis
if overwrite is False and label is not None:
try:
info_per_bit = load_bitinformation(label)
return info_per_bit
except FileNotFoundError:
logging.info(
f"No bitinformation could be found for {label}. Recalculating..."
)

# check keywords
if axis is not None and dim is not None:
raise ValueError("Please provide either `axis` or `dim` but not both.")
if axis:
if not isinstance(axis, int):
raise ValueError(f"Please provide `axis` as `int`, found {type(axis)}.")
if dim:
if not isinstance(dim, str):
raise ValueError(f"Please provide `dim` as `str`, found {type(dim)}.")
if "mask" in kwargs:
raise ValueError(
"`xbitinfo` does not wrap the mask argument. Mask your xr.Dataset with NaNs instead."
_check_bitinfo_kwargs(implementation, axis, dim, kwargs)
if dim is None and axis is None:
# gather bitinformation on all axis
info_per_bit, label = _get_bitinformation_along_dims(
ds,
dim=dim,
label=label,
implementation=implementation,
**kwargs,
)
elif isinstance(dim, list) and axis is None:
# gather bitinformation on dims specified
info_per_bit, label = _get_bitinformation_along_dims(
ds,
dim=dim,
label=label,
implementation=implementation,
**kwargs,
)
else:
# gather bitinformation along one axis
info_per_bit = _get_bitinformation_along_axis(
ds, implementation, axis, dim, **kwargs
)

info_per_bit = {}
pbar = tqdm(ds.data_vars)
for var in pbar:
pbar.set_description(f"Processing var: {var} for dim: {dim}")
if implementation == "julia":
info_per_bit_var = _jl_get_bitinformation(ds, var, axis, dim, kwargs)
if info_per_bit_var is None:
continue
else:
info_per_bit[var] = info_per_bit_var
elif implementation == "python":
info_per_bit_var = _py_get_bitinformation(ds, var, axis, dim, kwargs)
if info_per_bit_var is None:
continue
else:
info_per_bit[var] = info_per_bit_var
else:
raise ValueError(
f"Implementation of bitinformation algorithm {implementation} is unknown. Please choose a different one."
)
if label is not None:
with open(label + ".json", "w") as f:
logging.debug(f"Save bitinformation to {label + '.json'}")
json.dump(info_per_bit, f, cls=JsonCustomEncoder)
info_per_bit = dict_to_dataset(info_per_bit)
for var in info_per_bit.data_vars: # keep attrs from input with source_ prefix
for a in ds[var].attrs.keys():
info_per_bit[var].attrs["source_" + a] = ds[var].attrs[a]
out_fn = label + ".json"
if not os.path.exists(out_fn) or overwrite:
save_bitinformation(info_per_bit.to_dict(), out_fn)
for var in info_per_bit.data_vars: # keep attrs from input with source_ prefix
for a in ds[var].attrs.keys():
info_per_bit[var].attrs["source_" + a] = ds[var].attrs[a]
return info_per_bit


Expand Down Expand Up @@ -328,7 +316,6 @@ def _get_bitinformation_along_dims(
ds,
dim=None,
label=None,
overwrite=False,
implementation="julia",
**kwargs,
):
Expand All @@ -345,16 +332,42 @@ def _get_bitinformation_along_dims(
logging.info(f"Get bitinformation along dimension {d}")
if label is not None:
label = "_".join([label, d])
info_per_bit_per_dim[d] = get_bitinformation(
info_per_bit_per_dim[d] = _get_bitinformation_along_axis(
ds,
dim=d,
axis=None,
label=label,
overwrite=overwrite,
implementation=implementation,
**kwargs,
).expand_dims("dim", axis=0)
info_per_bit = xr.merge(info_per_bit_per_dim.values()).squeeze()
return info_per_bit, label


def _get_bitinformation_along_axis(ds, implementation, axis, dim, **kwargs):
"""
Helper function for :py:func:`xbitinfo.xbitinfo.get_bitinformation` to handle analysis along one axis.
"""
info_per_bit = {}
pbar = tqdm(ds.data_vars)
for var in pbar:
pbar.set_description(f"Processing var: {var} for dim: {dim}")
if implementation == "julia":
info_per_bit_var = _jl_get_bitinformation(ds, var, axis, dim, kwargs)
if info_per_bit_var is None:
continue
else:
info_per_bit[var] = info_per_bit_var
elif implementation == "python":
info_per_bit_var = _py_get_bitinformation(ds, var, axis, dim, kwargs)
if info_per_bit_var is None:
continue
else:
info_per_bit[var] = info_per_bit_var
else:
raise ValueError(
f"Implementation of bitinformation algorithm {implementation} is unknown. Please choose a different one."
)
info_per_bit = dict_to_dataset(info_per_bit)
return info_per_bit


Expand Down Expand Up @@ -385,6 +398,14 @@ def load_bitinformation(label):
raise FileNotFoundError(f"No bitinformation could be found at {label+'.json'}")


def save_bitinformation(info_per_bit, out_fn, overwrite=False):
"""Save bitinformation to JSON file"""
with open(out_fn, "w") as f:
logging.debug(f"Save bitinformation to {out_fn}")
json.dump(info_per_bit, f, cls=JsonCustomEncoder)
return


def get_keepbits(info_per_bit, inflevel=0.99):
"""Get the number of mantissa bits to keep. To be used in :py:func:`xbitinfo.bitround.xr_bitround` and :py:func:`xbitinfo.bitround.jl_bitround`.

Expand Down
Loading