Skip to content

Commit cbc0087

Browse files
committed
normalize reported rows in BlockIndex, updated caching
1 parent ff184de commit cbc0087

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

src/_arraykit.c

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5304,6 +5304,7 @@ BlockIndex_register(BlockIndexObject *self, PyObject *value) {
53045304
Py_ssize_t alignment = PyArray_DIM(a, 0);
53055305
if (self->row_count == -1) {
53065306
self->row_count = alignment;
5307+
self->shape_recache = true; // setting rows, must recache shape
53075308
}
53085309
else if (self->row_count != alignment) {
53095310
PyErr_Format(ErrorInitTypeBlocks,
@@ -5422,12 +5423,17 @@ BlockIndex_setstate(BlockIndexObject *self, PyObject *state)
54225423
//------------------------------------------------------------------------------
54235424
// getters
54245425

5426+
// Never expose a negative row value to the caller
5427+
#define AK_BI_ROWS(rows) ((rows) < 0 ? 0 : (rows))
5428+
54255429
static PyObject *
54265430
BlockIndex_shape_getter(BlockIndexObject *self, void* Py_UNUSED(closure))
54275431
{
54285432
if (self->shape == NULL || self->shape_recache) {
54295433
Py_XDECREF(self->shape); // get rid of old if it exists
5430-
self->shape = AK_build_pair_ssize_t(self->row_count, self->bir_count);
5434+
self->shape = AK_build_pair_ssize_t(
5435+
AK_BI_ROWS(self->row_count),
5436+
self->bir_count);
54315437
}
54325438
// shape is not null and shape_recache is false
54335439
Py_INCREF(self->shape); // for caller
@@ -5437,7 +5443,7 @@ BlockIndex_shape_getter(BlockIndexObject *self, void* Py_UNUSED(closure))
54375443

54385444
static PyObject *
54395445
BlockIndex_rows_getter(BlockIndexObject *self, void* Py_UNUSED(closure)){
5440-
return PyLong_FromSsize_t(self->row_count);
5446+
return PyLong_FromSsize_t(AK_BI_ROWS(self->row_count));
54415447
}
54425448

54435449
static PyObject *

test/test_block_index.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ def test_block_index_init_a(self) -> None:
1616
bi1 = BlockIndex()
1717
self.assertEqual(bi1.dtype, np.dtype(float))
1818
s = bi1.shape
19-
self.assertEqual(s, (-1, 0))
19+
self.assertEqual(s, (0, 0))
2020
del bi1
21-
self.assertEqual(s, (-1, 0))
21+
self.assertEqual(s, (0, 0))
2222
del s
2323

2424
def test_block_index_init_b1(self) -> None:
@@ -850,3 +850,17 @@ def test_block_index_iter_block_c(self) -> None:
850850

851851
slc = slice(None)
852852
self.assertEqual(list(bi1.iter_block()), [(i, slc) for i in range(8)])
853+
854+
#---------------------------------------------------------------------------
855+
856+
def test_block_index_shape_a(self) -> None:
857+
bi1 = BlockIndex()
858+
self.assertEqual(bi1.shape, (0, 0))
859+
self.assertEqual(bi1.rows, 0)
860+
861+
bi1.register(np.array(()).reshape(2,0))
862+
self.assertEqual(bi1.shape, (2, 0))
863+
self.assertEqual(bi1.rows, 2)
864+
865+
with self.assertRaises(ErrorInitTypeBlocks):
866+
bi1.register(np.array(()).reshape(3,0))

0 commit comments

Comments
 (0)