@@ -23,11 +23,9 @@ class BitInfo(BitRound):
23
23
Parameters
24
24
----------
25
25
26
- inflevel: float
27
- The number of bits of the mantissa to keep. The range allowed
28
- depends on the dtype input data. If keepbits is
29
- equal to the maximum allowed for the data type, this is equivalent
30
- to no transform.
26
+ info_level: float
27
+ The level of information to preserve in the data. The value should be
28
+ between 0. and 1.0. Higher values preserve more information.
31
29
32
30
axes: int or list of int, optional
33
31
Axes along which to calculate the bit information. If None, all axes
@@ -36,13 +34,22 @@ class BitInfo(BitRound):
36
34
37
35
codec_id = 'bitinfo'
38
36
39
- def __init__ (self , inflevel : float , axes = None ):
40
- if (inflevel < 0 ) or (inflevel > 1.0 ):
41
- raise ValueError ("Please provide `inflevel ` from interval [0.,1.]" )
37
+ def __init__ (self , info_level : float , axes = None ):
38
+ if (info_level < 0 ) or (info_level > 1.0 ):
39
+ raise ValueError ("Please provide `info_level ` from interval [0.,1.]" )
42
40
43
- self .inflevel = inflevel
41
+ elif axes is not None and not isinstance (axes , list ):
42
+ if int (axes ) != axes :
43
+ raise ValueError ("axis must be an integer or a list of integers." )
44
+ axes = [axes ]
45
+
46
+ elif isinstance (axes , list ) and not all (int (ax ) == ax for ax in axes ):
47
+ raise ValueError ("axis must be an integer or a list of integers." )
48
+
49
+ self .info_level = info_level
44
50
self .axes = axes
45
51
52
+
46
53
def encode (self , buf ):
47
54
"""Create int array by rounding floating-point data
48
55
@@ -68,11 +75,11 @@ def encode(self, buf):
68
75
69
76
for ax in self .axes :
70
77
info_per_bit = bitinformation (a , axis = ax )
71
- keepbits .append (get_keepbits (info_per_bit , self .inflevel ))
78
+ keepbits .append (get_keepbits (info_per_bit , self .info_level ))
72
79
73
80
keepbits = max (keepbits )
74
81
75
- return BitRound ._bitround ( a , keepbits , dtype )
82
+ return BitRound .bitround ( buf , keepbits , dtype )
76
83
77
84
78
85
def exponent_bias (dtype ):
@@ -117,12 +124,12 @@ def signed_exponent(A):
117
124
118
125
Parameters
119
126
----------
120
- A : :py:class:`numpy. array`
127
+ a : array
121
128
Array to transform
122
129
123
130
Returns
124
131
-------
125
- B : :py:class:`numpy. array`
132
+ array
126
133
127
134
Example
128
135
-------
@@ -162,8 +169,7 @@ def signed_exponent(A):
162
169
eabs = np .uint64 (eabs )
163
170
esign = np .uint64 (esign )
164
171
esigned = esign | (eabs << sbits )
165
- B = (sf | esigned ).view (np .int64 )
166
- return B
172
+ return (sf | esigned ).view (np .int64 )
167
173
168
174
169
175
def bitpaircount_u1 (a , b ):
@@ -260,7 +266,8 @@ def get_keepbits(info_per_bit, inflevel=0.99):
260
266
261
267
def _cdf_from_info_per_bit (info_per_bit ):
262
268
"""Convert info_per_bit to cumulative distribution function"""
263
- tol = info_per_bit [- 4 :].max () * 1.5
264
- info_per_bit [info_per_bit < tol ] = 0
269
+ # TODO this threshold isn't working yet
270
+ #tol = info_per_bit[-4:].max() * 1.5
271
+ #info_per_bit[info_per_bit < tol] = 0
265
272
cdf = info_per_bit .cumsum ()
266
273
return cdf / cdf [- 1 ]
0 commit comments