Skip to content

Commit 1b323c8

Browse files
simplify get_bitinformation
1 parent 6fd035b commit 1b323c8

File tree

1 file changed

+95
-77
lines changed

1 file changed

+95
-77
lines changed

xbitinfo/xbitinfo.py

+95-77
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,28 @@ def dict_to_dataset(info_per_bit):
105105
return dsb
106106

107107

108-
def get_bitinformation( # noqa: C901
108+
def _check_bitinfo_kwargs(implementation=None, axis=None, dim=None, kwargs=None):
109+
if kwargs is None:
110+
kwargs = {}
111+
# check keywords
112+
if implementation == "julia" and not julia_installed:
113+
raise ImportError('Please install julia or use implementation="python".')
114+
if axis is not None and dim is not None:
115+
raise ValueError("Please provide either `axis` or `dim` but not both.")
116+
if axis:
117+
if not isinstance(axis, int):
118+
raise ValueError(f"Please provide `axis` as `int`, found {type(axis)}.")
119+
if dim:
120+
if not isinstance(dim, str):
121+
raise ValueError(f"Please provide `dim` as `str`, found {type(dim)}.")
122+
if "mask" in kwargs:
123+
raise ValueError(
124+
"`xbitinfo` does not wrap the mask argument. Mask your xr.Dataset with NaNs instead."
125+
)
126+
return
127+
128+
129+
def get_bitinformation(
109130
ds,
110131
dim=None,
111132
axis=None,
@@ -182,81 +203,46 @@ def get_bitinformation( # noqa: C901
182203
xbitinfo_version: ...
183204
BitInformation.jl_version: ...
184205
"""
185-
if implementation == "julia" and not julia_installed:
186-
raise ImportError('Please install julia or use implementation="python".')
187-
if dim is None and axis is None:
188-
# gather bitinformation on all axis
189-
return _get_bitinformation_along_dims(
190-
ds,
191-
dim=dim,
192-
label=label,
193-
overwrite=overwrite,
194-
implementation=implementation,
195-
**kwargs,
196-
)
197-
if isinstance(dim, list) and axis is None:
198-
# gather bitinformation on dims specified
199-
return _get_bitinformation_along_dims(
200-
ds,
201-
dim=dim,
202-
label=label,
203-
overwrite=overwrite,
204-
implementation=implementation,
205-
**kwargs,
206-
)
206+
if overwrite is False and label is not None:
207+
try:
208+
info_per_bit = load_bitinformation(label)
209+
except FileNotFoundError:
210+
logging.info(
211+
f"No bitinformation could be found for {label}. Please set `overwrite=True` for recalculation..."
212+
)
207213
else:
208-
# gather bitinformation along one axis
209-
if overwrite is False and label is not None:
210-
try:
211-
info_per_bit = load_bitinformation(label)
212-
return info_per_bit
213-
except FileNotFoundError:
214-
logging.info(
215-
f"No bitinformation could be found for {label}. Recalculating..."
216-
)
217-
218-
# check keywords
219-
if axis is not None and dim is not None:
220-
raise ValueError("Please provide either `axis` or `dim` but not both.")
221-
if axis:
222-
if not isinstance(axis, int):
223-
raise ValueError(f"Please provide `axis` as `int`, found {type(axis)}.")
224-
if dim:
225-
if not isinstance(dim, str):
226-
raise ValueError(f"Please provide `dim` as `str`, found {type(dim)}.")
227-
if "mask" in kwargs:
228-
raise ValueError(
229-
"`xbitinfo` does not wrap the mask argument. Mask your xr.Dataset with NaNs instead."
214+
_check_bitinfo_kwargs(implementation, axis, dim, kwargs)
215+
if dim is None and axis is None:
216+
# gather bitinformation on all axis
217+
info_per_bit, label = _get_bitinformation_along_dims(
218+
ds,
219+
dim=dim,
220+
label=label,
221+
implementation=implementation,
222+
**kwargs,
223+
)
224+
elif isinstance(dim, list) and axis is None:
225+
# gather bitinformation on dims specified
226+
info_per_bit, label = _get_bitinformation_along_dims(
227+
ds,
228+
dim=dim,
229+
label=label,
230+
implementation=implementation,
231+
**kwargs,
232+
)
233+
else:
234+
# gather bitinformation along one axis
235+
info_per_bit = _get_bitinformation_along_axis(
236+
ds, implementation, axis, dim, kwargs
230237
)
231-
232-
info_per_bit = {}
233-
pbar = tqdm(ds.data_vars)
234-
for var in pbar:
235-
pbar.set_description(f"Processing var: {var} for dim: {dim}")
236-
if implementation == "julia":
237-
info_per_bit_var = _jl_get_bitinformation(ds, var, axis, dim, kwargs)
238-
if info_per_bit_var is None:
239-
continue
240-
else:
241-
info_per_bit[var] = info_per_bit_var
242-
elif implementation == "python":
243-
info_per_bit_var = _py_get_bitinformation(ds, var, axis, dim, kwargs)
244-
if info_per_bit_var is None:
245-
continue
246-
else:
247-
info_per_bit[var] = info_per_bit_var
248-
else:
249-
raise ValueError(
250-
f"Implementation of bitinformation algorithm {implementation} is unknown. Please choose a different one."
251-
)
252238
if label is not None:
253-
with open(label + ".json", "w") as f:
254-
logging.debug(f"Save bitinformation to {label + '.json'}")
255-
json.dump(info_per_bit, f, cls=JsonCustomEncoder)
256-
info_per_bit = dict_to_dataset(info_per_bit)
257-
for var in info_per_bit.data_vars: # keep attrs from input with source_ prefix
258-
for a in ds[var].attrs.keys():
259-
info_per_bit[var].attrs["source_" + a] = ds[var].attrs[a]
239+
out_fn = label + ".json"
240+
if not os.path.exists(out_fn) or overwrite:
241+
save_bitinformation(info_per_bit, out_fn)
242+
info_per_bit = dict_to_dataset(info_per_bit)
243+
for var in info_per_bit.data_vars: # keep attrs from input with source_ prefix
244+
for a in ds[var].attrs.keys():
245+
info_per_bit[var].attrs["source_" + a] = ds[var].attrs[a]
260246
return info_per_bit
261247

262248

@@ -328,7 +314,6 @@ def _get_bitinformation_along_dims(
328314
ds,
329315
dim=None,
330316
label=None,
331-
overwrite=False,
332317
implementation="julia",
333318
**kwargs,
334319
):
@@ -345,16 +330,41 @@ def _get_bitinformation_along_dims(
345330
logging.info(f"Get bitinformation along dimension {d}")
346331
if label is not None:
347332
label = "_".join([label, d])
348-
info_per_bit_per_dim[d] = get_bitinformation(
333+
info_per_bit_per_dim[d] = _get_bitinformation_along_axis(
349334
ds,
350335
dim=d,
351336
axis=None,
352-
label=label,
353-
overwrite=overwrite,
354337
implementation=implementation,
355338
**kwargs,
356339
).expand_dims("dim", axis=0)
357340
info_per_bit = xr.merge(info_per_bit_per_dim.values()).squeeze()
341+
return info_per_bit, label
342+
343+
344+
def _get_bitinformation_along_axis(ds, implementation, axis, dim, kwargs):
    """
    Compute the bitinformation of every data variable of ``ds`` along one axis.

    Helper for :py:func:`xbitinfo.xbitinfo.get_bitinformation`. Each variable
    in ``ds.data_vars`` is handed to the backend selected by
    ``implementation`` (``"julia"`` or ``"python"``) together with ``axis``,
    ``dim`` and the ``kwargs`` dict. Variables for which the backend returns
    ``None`` are left out of the result.

    Returns
    -------
    dict
        Maps variable name to whatever the backend returned for it.

    Raises
    ------
    ValueError
        If a variable is processed and ``implementation`` is neither
        ``"julia"`` nor ``"python"``.
    """
    collected = {}
    progress = tqdm(ds.data_vars)
    for var in progress:
        progress.set_description(f"Processing var: {var} for dim: {dim}")
        # Resolve the backend per iteration so only the requested backend's
        # name is looked up, and the error surfaces at the same point as before.
        if implementation == "julia":
            backend = _jl_get_bitinformation
        elif implementation == "python":
            backend = _py_get_bitinformation
        else:
            raise ValueError(
                f"Implementation of bitinformation algorithm {implementation} is unknown. Please choose a different one."
            )
        var_info = backend(ds, var, axis, dim, kwargs)
        if var_info is not None:
            collected[var] = var_info
    return collected
359369

360370

@@ -385,6 +395,14 @@ def load_bitinformation(label):
385395
raise FileNotFoundError(f"No bitinformation could be found at {label+'.json'}")
386396

387397

398+
def save_bitinformation(info_per_bit, out_fn, overwrite=False):
    """Write bitinformation to a JSON file.

    Parameters
    ----------
    info_per_bit
        Object to serialise; encoded with ``JsonCustomEncoder``.
    out_fn : str
        Path of the file to write. It is opened in ``"w"`` mode, so any
        existing content is replaced.
    overwrite : bool, optional
        Accepted but not consulted by this function.
    """
    with open(out_fn, "w") as json_file:
        logging.debug(f"Save bitinformation to {out_fn}")
        json.dump(info_per_bit, json_file, cls=JsonCustomEncoder)
404+
405+
388406
def get_keepbits(info_per_bit, inflevel=0.99):
389407
"""Get the number of mantissa bits to keep. To be used in :py:func:`xbitinfo.bitround.xr_bitround` and :py:func:`xbitinfo.bitround.jl_bitround`.
390408

0 commit comments

Comments
 (0)