Skip to content

Commit 497a50b

Browse files
committed
Enable adaptive bitround
1 parent b429e4b commit 497a50b

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

numcodecs/bitround.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22

3+
from typing import Callable
34

45
from .abc import Codec
56
from .compat import ensure_ndarray_like, ndarray_copy
@@ -29,19 +30,30 @@ class BitRound(Codec):
2930
Parameters
3031
----------
3132
32-
keepbits: int
33+
keepbits: int or function
3334
The number of bits of the mantissa to keep. The range allowed
3435
depends on the dtype input data. If keepbits is
3536
equal to the maximum allowed for the data type, this is equivalent
36-
to no transform.
37+
to no transform. Alternatively, pass a function to determine the
38+
number of bits to keep from the input data. The function should
39+
take a single argument, the input data, and return an integer
40+
specifying the number of bits to keep.
3741
"""
3842

3943
codec_id = 'bitround'
4044

41-
def __init__(self, keepbits: int):
42-
if keepbits < 0:
45+
def __init__(self, keepbits: [int, Callable]):
46+
if isinstance(keepbits, int) and keepbits < 0:
4347
raise ValueError("keepbits must be zero or positive")
44-
self.keepbits = keepbits
48+
49+
elif isinstance(keepbits, int):
50+
self.keepbits = [lambda x: keepbits]
51+
52+
elif isinstance(keepbits, Callable):
53+
self.keepbits = keepbits
54+
55+
else:
56+
raise TypeError("keepbits must be an integer or function")
4557

4658
def encode(self, buf):
4759
"""Create int array by rounding floating-point data
@@ -56,12 +68,13 @@ def encode(self, buf):
5668
# cast float to int type of same width (preserve endianness)
5769
a_int_dtype = np.dtype(a.dtype.str.replace("f", "i"))
5870
all_set = np.array(-1, dtype=a_int_dtype)
59-
if self.keepbits == bits:
71+
buf_keepbits = self.keepbits(buf)
72+
if buf_keepbits == bits:
6073
return a
61-
if self.keepbits > bits:
74+
if buf_keepbits > bits:
6275
raise ValueError("Keepbits too large for given dtype")
6376
b = a.view(a_int_dtype)
64-
maskbits = bits - self.keepbits
77+
maskbits = bits - buf_keepbits
6578
mask = (all_set >> maskbits) << maskbits
6679
half_quantum1 = (1 << (maskbits - 1)) - 1
6780
b += ((b >> maskbits) & 1) + half_quantum1

0 commit comments

Comments
 (0)