@@ -23,11 +23,9 @@ class BitInfo(BitRound):
2323    Parameters 
2424    ---------- 
2525
26-     inflevel: float 
27-         The number of bits of the mantissa to keep. The range allowed 
28-         depends on the dtype input data. If keepbits is 
29-         equal to the maximum allowed for the data type, this is equivalent 
30-         to no transform. 
26+     info_level: float 
27+         The level of information to preserve in the data. The value should be 
28+         between 0. and 1.0. Higher values preserve more information. 
3129
3230    axes: int or list of int, optional 
3331        Axes along which to calculate the bit information. If None, all axes 
@@ -36,13 +34,22 @@ class BitInfo(BitRound):
3634
3735    codec_id  =  'bitinfo' 
3836
39-     def  __init__ (self , inflevel : float , axes = None ):
40-         if  (inflevel  <  0 ) or  (inflevel  >  1.0 ):
41-             raise  ValueError ("Please provide `inflevel ` from interval [0.,1.]" )
37+     def  __init__ (self , info_level : float , axes = None ):
38+         if  (info_level  <  0 ) or  (info_level  >  1.0 ):
39+             raise  ValueError ("Please provide `info_level ` from interval [0.,1.]" )
4240
43-         self .inflevel  =  inflevel 
41+         elif  axes  is  not   None  and  not  isinstance (axes , list ):
42+             if  int (axes ) !=  axes :
43+                 raise  ValueError ("axis must be an integer or a list of integers." )
44+             axes  =  [axes ]
45+ 
46+         elif  isinstance (axes , list ) and  not  all (int (ax ) ==  ax  for  ax  in  axes ):
47+             raise  ValueError ("axis must be an integer or a list of integers." )
48+ 
49+         self .info_level  =  info_level 
4450        self .axes  =  axes 
4551
52+ 
4653    def  encode (self , buf ):
4754        """Create int array by rounding floating-point data 
4855
@@ -68,11 +75,11 @@ def encode(self, buf):
6875
6976        for  ax  in  self .axes :
7077            info_per_bit  =  bitinformation (a , axis = ax )
71-             keepbits .append (get_keepbits (info_per_bit , self .inflevel ))
78+             keepbits .append (get_keepbits (info_per_bit , self .info_level ))
7279
7380        keepbits  =  max (keepbits )
7481
75-         return  BitRound ._bitround ( a , keepbits , dtype )
82+         return  BitRound .bitround ( buf , keepbits , dtype )
7683
7784
7885def  exponent_bias (dtype ):
@@ -117,12 +124,12 @@ def signed_exponent(A):
117124
118125    Parameters 
119126    ---------- 
120-     A  : :py:class:`numpy. array`  
127+     a  : array 
121128        Array to transform 
122129
123130    Returns 
124131    ------- 
125-     B : :py:class:`numpy. array`  
132+     array 
126133
127134    Example 
128135    ------- 
@@ -162,8 +169,7 @@ def signed_exponent(A):
162169        eabs  =  np .uint64 (eabs )
163170        esign  =  np .uint64 (esign )
164171    esigned  =  esign  |  (eabs  <<  sbits )
165-     B  =  (sf  |  esigned ).view (np .int64 )
166-     return  B 
172+     return  (sf  |  esigned ).view (np .int64 )
167173
168174
169175def  bitpaircount_u1 (a , b ):
@@ -260,7 +266,8 @@ def get_keepbits(info_per_bit, inflevel=0.99):
260266
261267def  _cdf_from_info_per_bit (info_per_bit ):
262268    """Convert info_per_bit to cumulative distribution function""" 
263-     tol  =  info_per_bit [- 4 :].max () *  1.5 
264-     info_per_bit [info_per_bit  <  tol ] =  0 
269+     # TODO this threshold isn't working yet 
270+     #tol = info_per_bit[-4:].max() * 1.5 
271+     #info_per_bit[info_per_bit < tol] = 0 
265272    cdf  =  info_per_bit .cumsum ()
266273    return  cdf  /  cdf [- 1 ]
0 commit comments