Skip to content

Commit b49320c

Browse files
committed
implemented slice reduction for iter_contiguous
1 parent 6038cca commit b49320c

File tree

2 files changed

+59
-12
lines changed

2 files changed

+59
-12
lines changed

src/_arraykit.c

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3356,12 +3356,14 @@ AK_build_slice(Py_ssize_t start, Py_ssize_t stop, Py_ssize_t step)
33563356
return new;
33573357
}
33583358

3359-
// Given inclusive start, end indices, returns a new reference to a slice. Returns NULL on error.
3359+
// Given inclusive start, end indices, returns a new reference to a slice. Returns NULL on error. If `reduce` is true and `start` equals `end`, returns a new reference to an integer (the value of `start`) instead of a single-width slice.
33603360
static inline PyObject*
3361-
AK_build_slice_inclusive(Py_ssize_t start, Py_ssize_t end)
3361+
AK_build_slice_inclusive(Py_ssize_t start, Py_ssize_t end, bool reduce)
33623362
{
3363-
// TODO: if start and end are the same, return an integer; this can be configured with a parameter
3364-
assert(start >= 0);
3363+
if (reduce && start == end) {
3364+
return PyLong_FromSsize_t(start); // new ref
3365+
}
3366+
// assert(start >= 0);
33653367
if (start <= end) {
33663368
return AK_build_slice(start, end + 1, 1);
33673369
}
@@ -4654,10 +4656,15 @@ typedef struct BIIterContiguousObject {
46544656
Py_ssize_t last_column;
46554657
Py_ssize_t next_block;
46564658
Py_ssize_t next_column;
4659+
bool reduce; // optionally reduce slices to integers
46574660
} BIIterContiguousObject;
46584661

46594662
static PyObject *
4660-
BIIterContiguous_new(BlockIndexObject *bi, int8_t reversed, PyObject* iter) {
4663+
BIIterContiguous_new(BlockIndexObject *bi,
4664+
int8_t reversed,
4665+
PyObject* iter,
4666+
bool reduce)
4667+
{
46614668
BIIterContiguousObject *bii = PyObject_New(BIIterContiguousObject, &BIIterContiguousType);
46624669
if (!bii) {
46634670
return NULL;
@@ -4672,6 +4679,7 @@ BIIterContiguous_new(BlockIndexObject *bi, int8_t reversed, PyObject* iter) {
46724679
bii->last_column = -1;
46734680
bii->next_block = -1;
46744681
bii->next_column = -1;
4682+
bii->reduce = reduce;
46754683

46764684
return (PyObject *)bii;
46774685
}
@@ -4726,7 +4734,9 @@ BIIterContiguous_iternext(BIIterContiguousObject *self) {
47264734
self->next_block = -2;
47274735
return Py_BuildValue("nN", // N steals ref
47284736
self->last_block,
4729-
AK_build_slice_inclusive(slice_start, self->last_column));
4737+
AK_build_slice_inclusive(slice_start,
4738+
self->last_column,
4739+
self->reduce));
47304740
}
47314741
// i is guaranteed to be within the range of self->bit_count at this point; the only source of arbitrary indices is in BIIterSeq_iternext_core, and that function validates the range
47324742
BlockIndexRecord* biri = &self->bi->bir[i];
@@ -4749,7 +4759,9 @@ BIIterContiguous_iternext(BIIterContiguousObject *self) {
47494759
self->next_column = column;
47504760
return Py_BuildValue("nN", // N steals ref
47514761
self->last_block,
4752-
AK_build_slice_inclusive(slice_start, self->last_column));
4762+
AK_build_slice_inclusive(slice_start,
4763+
self->last_column,
4764+
self->reduce));
47534765
}
47544766
return NULL;
47554767
}
@@ -4778,7 +4790,10 @@ BIIterContiguous_reversed(BIIterContiguousObject *self) {
47784790
!self->reversed,
47794791
BIIS_UNKNOWN, // let type be determined by selector
47804792
0);
4781-
PyObject* biiter = BIIterContiguous_new(self->bi, !self->reversed, self->iter);
4793+
PyObject* biiter = BIIterContiguous_new(self->bi,
4794+
!self->reversed,
4795+
self->iter,
4796+
self->reduce);
47824797
Py_DECREF(iter);
47834798
return biiter;
47844799
}
@@ -5389,6 +5404,7 @@ BlockIndex_iter_select(BlockIndexObject *self, PyObject *selector){
53895404
static char *iter_contiguous_kargs_names[] = {
53905405
"selector",
53915406
"ascending",
5407+
"reduce",
53925408
NULL
53935409
};
53945410

@@ -5398,19 +5414,21 @@ BlockIndex_iter_contiguous(BlockIndexObject *self, PyObject *args, PyObject *kwa
53985414
{
53995415
PyObject* selector;
54005416
int ascending = 0;
5417+
int reduce = 0;
54015418

54025419
if (!PyArg_ParseTupleAndKeywords(args, kwargs,
5403-
"O|$p:iter_contiguous",
5420+
"O|$pp:iter_contiguous",
54045421
iter_contiguous_kargs_names,
54055422
&selector,
5406-
&ascending
5423+
&ascending,
5424+
&reduce
54075425
)) {
54085426
return NULL;
54095427
}
54105428

54115429
// might need to store enum type for branching
54125430
PyObject* iter = BIIterSelector_new(self, selector, 0, BIIS_UNKNOWN, ascending);
5413-
PyObject* biiter = BIIterContiguous_new(self, 0, iter); // will incref iter
5431+
PyObject* biiter = BIIterContiguous_new(self, 0, iter, reduce); // will incref iter
54145432
Py_DECREF(iter);
54155433

54165434
return biiter;

test/test_block_index.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ def test_block_index_iter_contiguous_d(self) -> None:
645645
[(0, slice(0, 4)), (1, slice(0, 4))]
646646
)
647647

648-
def test_block_index_iter_contiguous_e(self) -> None:
648+
def test_block_index_iter_contiguous_e1(self) -> None:
649649
bi1 = BlockIndex()
650650
bi1.register(np.arange(2))
651651
bi1.register(np.arange(2))
@@ -673,3 +673,32 @@ def test_block_index_iter_contiguous_e(self) -> None:
673673
list(bi1.iter_contiguous(np.array([6, 0, 7]), ascending=True)),
674674
[(0, slice(0, 1)), (6, slice(0, 1)), (7, slice(0, 1))]
675675
)
676+
677+
def test_block_index_iter_contiguous_e2(self) -> None:
678+
bi1 = BlockIndex()
679+
bi1.register(np.arange(2))
680+
bi1.register(np.arange(2))
681+
bi1.register(np.arange(2))
682+
bi1.register(np.arange(2))
683+
bi1.register(np.arange(2))
684+
bi1.register(np.arange(2))
685+
bi1.register(np.arange(2))
686+
bi1.register(np.arange(2))
687+
688+
self.assertEqual(
689+
list(bi1.iter_contiguous([6, 0, 7], reduce=True)),
690+
[(6, 0), (0, 0), (7, 0)]
691+
)
692+
self.assertEqual(
693+
list(bi1.iter_contiguous([6, 0, 7], ascending=True, reduce=True)),
694+
[(0, 0), (6, 0), (7, 0)]
695+
)
696+
697+
self.assertEqual(
698+
list(bi1.iter_contiguous(np.array([6, 0, 7]), reduce=True)),
699+
[(6, 0), (0, 0), (7, 0)]
700+
)
701+
self.assertEqual(
702+
list(bi1.iter_contiguous(np.array([6, 0, 7]), ascending=True, reduce=True)),
703+
[(0, 0), (6, 0), (7, 0)]
704+
)

0 commit comments

Comments
 (0)