1
1
import numpy as np
2
2
3
+ from typing import Callable
3
4
4
5
from .abc import Codec
5
6
from .compat import ensure_ndarray_like , ndarray_copy
@@ -29,19 +30,30 @@ class BitRound(Codec):
29
30
Parameters
30
31
----------
31
32
32
- keepbits: int
33
+ keepbits: int or function
33
34
The number of bits of the mantissa to keep. The range allowed
34
35
depends on the dtype input data. If keepbits is
35
36
equal to the maximum allowed for the data type, this is equivalent
36
- to no transform.
37
+ to no transform. Alternatively, pass a function to determine the
38
+ number of bits to keep from the input data. The function should
39
+ take a single argument, the input data, and return an integer
40
+ specifying the number of bits to keep.
37
41
"""
38
42
39
43
codec_id = 'bitround'
40
44
41
- def __init__ (self , keepbits : int ):
42
- if keepbits < 0 :
45
+ def __init__ (self , keepbits : [ int , Callable ] ):
46
+ if isinstance ( keepbits , int ) and keepbits < 0 :
43
47
raise ValueError ("keepbits must be zero or positive" )
44
- self .keepbits = keepbits
48
+
49
+ elif isinstance (keepbits , int ):
50
+ self .keepbits = [lambda x : keepbits ]
51
+
52
+ elif isinstance (keepbits , Callable ):
53
+ self .keepbits = keepbits
54
+
55
+ else :
56
+ raise TypeError ("keepbits must be an integer or function" )
45
57
46
58
def encode (self , buf ):
47
59
"""Create int array by rounding floating-point data
@@ -56,12 +68,13 @@ def encode(self, buf):
56
68
# cast float to int type of same width (preserve endianness)
57
69
a_int_dtype = np .dtype (a .dtype .str .replace ("f" , "i" ))
58
70
all_set = np .array (- 1 , dtype = a_int_dtype )
59
- if self .keepbits == bits :
71
+ buf_keepbits = self .keepbits (buf )
72
+ if buf_keepbits == bits :
60
73
return a
61
- if self . keepbits > bits :
74
+ if buf_keepbits > bits :
62
75
raise ValueError ("Keepbits too large for given dtype" )
63
76
b = a .view (a_int_dtype )
64
- maskbits = bits - self . keepbits
77
+ maskbits = bits - buf_keepbits
65
78
mask = (all_set >> maskbits ) << maskbits
66
79
half_quantum1 = (1 << (maskbits - 1 )) - 1
67
80
b += ((b >> maskbits ) & 1 ) + half_quantum1
0 commit comments