Skip to content

Commit 87cd098

Browse files
committed
Add basic tests
1 parent ffde755 commit 87cd098

File tree

6 files changed

+112
-1
lines changed

6 files changed

+112
-1
lines changed

docs/bitinfo.rst

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
PCodec
2+
======
3+
4+
.. automodule:: numcodecs.bitinfo
5+
6+
.. autoclass:: BitInfo
7+
8+
.. autoattribute:: codec_id
9+
.. automethod:: encode
10+
.. automethod:: decode

docs/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ Contents
7777
delta
7878
fixedscaleoffset
7979
quantize
80+
bitinfo
8081
bitround
8182
packbits
8283
categorize

docs/release.rst

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ Unreleased
1414
Enhancements
1515
~~~~~~~~~~~~
1616

17+
* Add BitInfo codec
18+
By :user:`Tim Hodson <thodson-usgs>`.
1719
* Use PyData theme for docs
1820
By :user:`John Kirkham <jakirkham>`, :issue:`485`.
1921

numcodecs/bitinfo.py

+9
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ def exponent_mask(dtype):
115115
mask = 0x7F80_0000
116116
elif dtype == np.float64:
117117
mask = 0x7FF0_0000_0000_0000
118+
else:
119+
raise ValueError(f"Unsupported dtype {dtype}")
118120
return mask
119121

120122

@@ -175,6 +177,7 @@ def signed_exponent(A):
175177
def bitpaircount_u1(a, b):
176178
assert a.dtype == "u1"
177179
assert b.dtype == "u1"
180+
178181
unpack_a = np.unpackbits(a.flatten()).astype("u1")
179182
unpack_b = np.unpackbits(b.flatten()).astype("u1")
180183

@@ -188,6 +191,7 @@ def bitpaircount_u1(a, b):
188191
def bitpaircount(a, b):
189192
assert a.dtype.kind == "u"
190193
assert b.dtype.kind == "u"
194+
191195
nbytes = max(a.dtype.itemsize, b.dtype.itemsize)
192196

193197
a, b = np.broadcast_arrays(a, b)
@@ -203,6 +207,9 @@ def bitpaircount(a, b):
203207
def mutual_information(a, b, base=2):
204208
"""Calculate the mutual information between two arrays.
205209
"""
210+
assert a.dtype == b.dtype
211+
assert a.dtype.kind == "u"
212+
206213
size = np.prod(np.broadcast_shapes(a.shape, b.shape))
207214
counts = bitpaircount(a, b)
208215

@@ -228,6 +235,8 @@ def bitinformation(a, axis=0):
228235
-------
229236
info_per_bit : array
230237
"""
238+
assert a.dtype.kind == "u"
239+
231240
sa = tuple(slice(0, -1) if i == axis else slice(None) for i in range(len(a.shape)))
232241
sb = tuple(
233242
slice(1, None) if i == axis else slice(None) for i in range(len(a.shape))

numcodecs/bitround.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,23 @@ def decode(self, buf, out=None):
7373
return ndarray_copy(data, out)
7474

7575
@staticmethod
76-
def bitround(buf, keepbits, dtype):
76+
def bitround(buf, keepbits: int, dtype):
77+
"""Drop bits from the mantissa of a floating point array
78+
79+
Parameters
80+
----------
81+
buf: ndarray
82+
The input array
83+
keepbits: int
84+
The number of bits to keep
85+
dtype: dtype
86+
The dtype of the input array
87+
88+
Returns
89+
-------
90+
ndarray
91+
The bitrounded array transformed to an integer type
92+
"""
7793
bits = max_bits[str(dtype)]
7894
a_int_dtype = np.dtype(buf.dtype.str.replace("f", "i"))
7995
all_set = np.array(-1, dtype=a_int_dtype)

numcodecs/tests/test_bitinfo.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import numpy as np
2+
3+
import pytest
4+
5+
from numcodecs.bitinfo import BitInfo, exponent_bias, mutual_information
6+
7+
def test_bitinfo_initialization():
8+
bitinfo = BitInfo(0.5)
9+
assert bitinfo.info_level == 0.5
10+
assert bitinfo.axes is None
11+
12+
bitinfo = BitInfo(0.5, axes=1)
13+
assert bitinfo.axes == [1]
14+
15+
bitinfo = BitInfo(0.5, axes=[1, 2])
16+
assert bitinfo.axes == [1, 2]
17+
18+
with pytest.raises(ValueError):
19+
BitInfo(-0.1)
20+
21+
with pytest.raises(ValueError):
22+
BitInfo(1.1)
23+
24+
with pytest.raises(ValueError):
25+
BitInfo(0.5, axes=1.5)
26+
27+
with pytest.raises(ValueError):
28+
BitInfo(0.5, axes=[1, 1.5])
29+
30+
31+
def test_bitinfo_encode():
32+
bitinfo = BitInfo(info_level=0.5)
33+
a = np.array([1.0, 2.0, 3.0], dtype="float32")
34+
encoded = bitinfo.encode(a)
35+
decoded = bitinfo.decode(encoded)
36+
assert decoded.dtype == a.dtype
37+
38+
39+
def test_bitinfo_encode_errors():
40+
bitinfo = BitInfo(0.5)
41+
a = np.array([1, 2, 3], dtype="int32")
42+
with pytest.raises(TypeError):
43+
bitinfo.encode(a)
44+
45+
a = np.array([1.0, 2.0, 3.0], dtype="float128")
46+
with pytest.raises(TypeError):
47+
bitinfo.encode(a)
48+
49+
50+
def test_exponent_bias():
51+
assert exponent_bias("f2") == 15
52+
assert exponent_bias("f4") == 127
53+
assert exponent_bias("f8") == 1023
54+
55+
with pytest.raises(ValueError):
56+
exponent_bias("int32")
57+
58+
59+
def test_mutual_information():
60+
""" Test mutual information calculation
61+
62+
Tests for changes to the mutual_information
63+
but not the correcteness of the original.
64+
"""
65+
a = np.arange(10.0, dtype='float32')
66+
b = a + 1000
67+
c = a[::-1].copy()
68+
dt = np.dtype('uint32')
69+
a,b,c = map(lambda x: x.view(dt), [a,b,c])
70+
71+
assert mutual_information(a, a).sum() == 7.020411549771797
72+
assert mutual_information(a, b).sum() == 0.0
73+
assert mutual_information(a, c).sum() == 0.6545015579460758

0 commit comments

Comments
 (0)