@@ -105,7 +105,28 @@ def dict_to_dataset(info_per_bit):
105
105
return dsb
106
106
107
107
108
- def get_bitinformation ( # noqa: C901
108
def _check_bitinfo_kwargs(implementation=None, axis=None, dim=None, kwargs=None):
    """Validate the arguments passed to ``get_bitinformation``.

    Parameters
    ----------
    implementation : str, optional
        Name of the requested implementation. ``"julia"`` is rejected when
        the module-level ``julia_installed`` flag is false.
    axis : int, optional
        Axis to analyse. Mutually exclusive with ``dim``.
    dim : str or list, optional
        Dimension name, or a list of dimension names. Mutually exclusive
        with ``axis``.
    kwargs : dict, optional
        Additional keyword arguments supplied by the caller. Must not
        contain the key ``"mask"``.

    Returns
    -------
    None

    Raises
    ------
    ImportError
        If ``implementation == "julia"`` and julia is not installed.
    ValueError
        If both ``axis`` and ``dim`` are given, if either has the wrong
        type, or if ``"mask"`` is present in ``kwargs``.
    """
    if kwargs is None:
        kwargs = {}
    # check keywords
    if implementation == "julia" and not julia_installed:
        raise ImportError('Please install julia or use implementation="python".')
    if axis is not None and dim is not None:
        raise ValueError("Please provide either `axis` or `dim` but not both.")
    if axis and not isinstance(axis, int):
        raise ValueError(f"Please provide `axis` as `int`, found {type(axis)}.")
    # A list of dimension names is a valid input: ``get_bitinformation``
    # dispatches lists to the multi-dimension code path *after* this check,
    # so rejecting them here would make that path unreachable.
    if dim and not isinstance(dim, (str, list)):
        raise ValueError(
            f"Please provide `dim` as `str` or `list`, found {type(dim)}."
        )
    if "mask" in kwargs:
        raise ValueError(
            "`xbitinfo` does not wrap the mask argument. Mask your xr.Dataset with NaNs instead."
        )
127
+
128
+
129
+ def get_bitinformation (
109
130
ds ,
110
131
dim = None ,
111
132
axis = None ,
@@ -182,81 +203,46 @@ def get_bitinformation( # noqa: C901
182
203
xbitinfo_version: ...
183
204
BitInformation.jl_version: ...
184
205
"""
185
- if implementation == "julia" and not julia_installed :
186
- raise ImportError ('Please install julia or use implementation="python".' )
187
- if dim is None and axis is None :
188
- # gather bitinformation on all axis
189
- return _get_bitinformation_along_dims (
190
- ds ,
191
- dim = dim ,
192
- label = label ,
193
- overwrite = overwrite ,
194
- implementation = implementation ,
195
- ** kwargs ,
196
- )
197
- if isinstance (dim , list ) and axis is None :
198
- # gather bitinformation on dims specified
199
- return _get_bitinformation_along_dims (
200
- ds ,
201
- dim = dim ,
202
- label = label ,
203
- overwrite = overwrite ,
204
- implementation = implementation ,
205
- ** kwargs ,
206
- )
206
+ if overwrite is False and label is not None :
207
+ try :
208
+ info_per_bit = load_bitinformation (label )
209
+ except FileNotFoundError :
210
+ logging .info (
211
+ f"No bitinformation could be found for { label } . Please set `overwrite=True` for recalculation..."
212
+ )
207
213
else :
208
- # gather bitinformation along one axis
209
- if overwrite is False and label is not None :
210
- try :
211
- info_per_bit = load_bitinformation (label )
212
- return info_per_bit
213
- except FileNotFoundError :
214
- logging .info (
215
- f"No bitinformation could be found for { label } . Recalculating..."
216
- )
217
-
218
- # check keywords
219
- if axis is not None and dim is not None :
220
- raise ValueError ("Please provide either `axis` or `dim` but not both." )
221
- if axis :
222
- if not isinstance (axis , int ):
223
- raise ValueError (f"Please provide `axis` as `int`, found { type (axis )} ." )
224
- if dim :
225
- if not isinstance (dim , str ):
226
- raise ValueError (f"Please provide `dim` as `str`, found { type (dim )} ." )
227
- if "mask" in kwargs :
228
- raise ValueError (
229
- "`xbitinfo` does not wrap the mask argument. Mask your xr.Dataset with NaNs instead."
214
+ _check_bitinfo_kwargs (implementation , axis , dim , kwargs )
215
+ if dim is None and axis is None :
216
+ # gather bitinformation on all axis
217
+ info_per_bit , label = _get_bitinformation_along_dims (
218
+ ds ,
219
+ dim = dim ,
220
+ label = label ,
221
+ implementation = implementation ,
222
+ ** kwargs ,
223
+ )
224
+ elif isinstance (dim , list ) and axis is None :
225
+ # gather bitinformation on dims specified
226
+ info_per_bit , label = _get_bitinformation_along_dims (
227
+ ds ,
228
+ dim = dim ,
229
+ label = label ,
230
+ implementation = implementation ,
231
+ ** kwargs ,
232
+ )
233
+ else :
234
+ # gather bitinformation along one axis
235
+ info_per_bit = _get_bitinformation_along_axis (
236
+ ds , implementation , axis , dim , kwargs
230
237
)
231
-
232
- info_per_bit = {}
233
- pbar = tqdm (ds .data_vars )
234
- for var in pbar :
235
- pbar .set_description (f"Processing var: { var } for dim: { dim } " )
236
- if implementation == "julia" :
237
- info_per_bit_var = _jl_get_bitinformation (ds , var , axis , dim , kwargs )
238
- if info_per_bit_var is None :
239
- continue
240
- else :
241
- info_per_bit [var ] = info_per_bit_var
242
- elif implementation == "python" :
243
- info_per_bit_var = _py_get_bitinformation (ds , var , axis , dim , kwargs )
244
- if info_per_bit_var is None :
245
- continue
246
- else :
247
- info_per_bit [var ] = info_per_bit_var
248
- else :
249
- raise ValueError (
250
- f"Implementation of bitinformation algorithm { implementation } is unknown. Please choose a different one."
251
- )
252
238
if label is not None :
253
- with open ( label + ".json" , "w" ) as f :
254
- logging . debug ( f"Save bitinformation to { label + '.json' } " )
255
- json . dump (info_per_bit , f , cls = JsonCustomEncoder )
256
- info_per_bit = dict_to_dataset (info_per_bit )
257
- for var in info_per_bit .data_vars : # keep attrs from input with source_ prefix
258
- for a in ds [var ].attrs .keys ():
259
- info_per_bit [var ].attrs ["source_" + a ] = ds [var ].attrs [a ]
239
+ out_fn = label + ".json"
240
+ if not os . path . exists ( out_fn ) or overwrite :
241
+ save_bitinformation (info_per_bit , out_fn )
242
+ info_per_bit = dict_to_dataset (info_per_bit )
243
+ for var in info_per_bit .data_vars : # keep attrs from input with source_ prefix
244
+ for a in ds [var ].attrs .keys ():
245
+ info_per_bit [var ].attrs ["source_" + a ] = ds [var ].attrs [a ]
260
246
return info_per_bit
261
247
262
248
@@ -328,7 +314,6 @@ def _get_bitinformation_along_dims(
328
314
ds ,
329
315
dim = None ,
330
316
label = None ,
331
- overwrite = False ,
332
317
implementation = "julia" ,
333
318
** kwargs ,
334
319
):
@@ -345,16 +330,41 @@ def _get_bitinformation_along_dims(
345
330
logging .info (f"Get bitinformation along dimension { d } " )
346
331
if label is not None :
347
332
label = "_" .join ([label , d ])
348
- info_per_bit_per_dim [d ] = get_bitinformation (
333
+ info_per_bit_per_dim [d ] = _get_bitinformation_along_axis (
349
334
ds ,
350
335
dim = d ,
351
336
axis = None ,
352
- label = label ,
353
- overwrite = overwrite ,
354
337
implementation = implementation ,
355
338
** kwargs ,
356
339
).expand_dims ("dim" , axis = 0 )
357
340
info_per_bit = xr .merge (info_per_bit_per_dim .values ()).squeeze ()
341
+ return info_per_bit , label
342
+
343
+
344
def _get_bitinformation_along_axis(
    ds, implementation, axis, dim, kwargs=None, **extra_kwargs
):
    """
    Helper function for :py:func:`xbitinfo.xbitinfo.get_bitinformation` to handle analysis along one axis.

    Parameters
    ----------
    ds : object exposing ``data_vars``
        Input whose variables are processed one by one (presumably an
        ``xr.Dataset`` -- confirm against callers).
    implementation : str
        Either ``"julia"`` or ``"python"``; selects the backend helper.
    axis : int or None
        Forwarded unchanged to the backend helper.
    dim : str or None
        Forwarded unchanged to the backend helper; also shown in the
        progress-bar description.
    kwargs : dict, optional
        Options forwarded to the backend helper as a single dict.
    **extra_kwargs
        Options given as individual keyword arguments. They are merged on
        top of ``kwargs`` so that callers which spread their options with
        ``**kwargs`` work as well as callers which pass a dict.

    Returns
    -------
    dict
        Maps each variable name to the backend helper's result. Variables
        for which the helper returns ``None`` are omitted.

    Raises
    ------
    ValueError
        If ``implementation`` is unknown (raised when the first variable is
        processed).
    """
    if kwargs is None:
        kwargs = {}
    if extra_kwargs:
        # Build a new dict only when needed so a dict passed by the caller is
        # otherwise handed to the backend helpers as the very same object.
        kwargs = {**kwargs, **extra_kwargs}

    info_per_bit = {}
    pbar = tqdm(ds.data_vars)
    for var in pbar:
        pbar.set_description(f"Processing var: {var} for dim: {dim}")
        if implementation == "julia":
            info_per_bit_var = _jl_get_bitinformation(ds, var, axis, dim, kwargs)
        elif implementation == "python":
            info_per_bit_var = _py_get_bitinformation(ds, var, axis, dim, kwargs)
        else:
            raise ValueError(
                f"Implementation of bitinformation algorithm {implementation} is unknown. Please choose a different one."
            )
        # Backend helpers return None for variables they skip.
        if info_per_bit_var is not None:
            info_per_bit[var] = info_per_bit_var
    return info_per_bit
359
369
360
370
@@ -385,6 +395,14 @@ def load_bitinformation(label):
385
395
raise FileNotFoundError (f"No bitinformation could be found at { label + '.json' } " )
386
396
387
397
398
def save_bitinformation(info_per_bit, out_fn, overwrite=False):
    """Save bitinformation to JSON file

    Parameters
    ----------
    info_per_bit : object
        Data to serialise; encoded with ``JsonCustomEncoder``.
    out_fn : str
        Path of the output file. It is opened in write mode, so an existing
        file at this path is replaced.
    overwrite : bool, optional
        Accepted for interface compatibility; this function does not
        consult it. Callers decide beforehand whether to write.

    Returns
    -------
    None
    """
    with open(out_fn, "w") as json_file:
        logging.debug(f"Save bitinformation to {out_fn}")
        json.dump(info_per_bit, json_file, cls=JsonCustomEncoder)
404
+
405
+
388
406
def get_keepbits (info_per_bit , inflevel = 0.99 ):
389
407
"""Get the number of mantissa bits to keep. To be used in :py:func:`xbitinfo.bitround.xr_bitround` and :py:func:`xbitinfo.bitround.jl_bitround`.
390
408
0 commit comments