Skip to content

Commit 1ad5259

Browse files
committed
Cleans up new numpy iteration code.
1 parent 88efa30 commit 1ad5259

File tree

1 file changed

+30
-21
lines changed

1 file changed

+30
-21
lines changed

src/_arraykit.c

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -692,39 +692,47 @@ get_new_indexers_and_screen(PyObject *Py_UNUSED(m), PyObject *args, PyObject *kw
692692

693693
npy_int64 num_found = 0;
694694

695-
NpyIter* iter;
696-
NpyIter_IterNextFunc *iternext;
697-
char** dataptr;
698-
npy_intp* strideptr,* innersizeptr;
699-
npy_int64 element;
695+
// Now, implement the core algorithm by looping over the ``indexers``.
696+
// We need to use numpy's iteration API, as the ``indexers`` could be
697+
// C-contiguous, F-contiguous, both, or neither.
698+
// See https://numpy.org/doc/stable/reference/c-api/iterator.html#simple-iteration-example
700699

701-
iter = NpyIter_New(
700+
NpyIter *indexer_iter = NpyIter_New(
702701
indexers,
703-
NPY_ITER_READONLY| NPY_ITER_EXTERNAL_LOOP| NPY_ITER_REFS_OK,
702+
NPY_ITER_READONLY| NPY_ITER_EXTERNAL_LOOP,
704703
NPY_KEEPORDER,
705704
NPY_NO_CASTING,
706705
NULL
707706
);
708-
if (iter == NULL) {
709-
return -1;
707+
if (indexer_iter == NULL) {
708+
Py_DECREF(element_locations);
709+
Py_DECREF(new_indexers);
710+
return NULL;
710711
}
711712

712-
iternext = NpyIter_GetIterNext(iter, NULL);
713-
if (iternext == NULL) {
714-
NpyIter_Deallocate(iter);
715-
return -1;
713+
// The iternext function gets stored in a local variable so it can be called repeatedly in an efficient manner.
714+
NpyIter_IterNextFunc *indexer_iternext = NpyIter_GetIterNext(indexer_iter, NULL);
715+
if (indexer_iternext == NULL) {
716+
NpyIter_Deallocate(indexer_iter);
717+
Py_DECREF(element_locations);
718+
Py_DECREF(new_indexers);
719+
return NULL;
716720
}
717-
dataptr = NpyIter_GetDataPtrArray(iter);
718-
strideptr = NpyIter_GetInnerStrideArray(iter);
719-
innersizeptr = NpyIter_GetInnerLoopSizePtr(iter);
721+
722+
// All of these will be updated by the iterator
723+
char **dataptr = NpyIter_GetDataPtrArray(indexer_iter);
724+
npy_intp *strideptr = NpyIter_GetInnerStrideArray(indexer_iter);
725+
npy_intp *innersizeptr = NpyIter_GetInnerLoopSizePtr(indexer_iter);
720726

721727
size_t i = 0;
722728
do {
729+
// Get the inner loop data/stride/inner_size values
723730
char* data = *dataptr;
724731
npy_intp stride = *strideptr;
725-
npy_intp count = *innersizeptr;
732+
npy_intp inner_size = *innersizeptr;
733+
npy_int64 element;
726734

727-
while (count--) {
735+
while (inner_size--) {
728736
memcpy (&element, data, sizeof (long));
729737

730738
if (element_location_values[element] == num_unique)
@@ -738,7 +746,7 @@ get_new_indexers_and_screen(PyObject *Py_UNUSED(m), PyObject *args, PyObject *kw
738746
// If we have found every possible indexer, we can simply return
739747
// back the inputs! Essentially, we can observe in at most a single pass
740748
// that we have the opportunity for re-use
741-
NpyIter_Deallocate(iter);
749+
NpyIter_Deallocate(indexer_iter);
742750
Py_DECREF(element_locations);
743751
Py_DECREF(new_indexers);
744752
return PyTuple_Pack(2, indexers, positions);
@@ -751,9 +759,10 @@ get_new_indexers_and_screen(PyObject *Py_UNUSED(m), PyObject *args, PyObject *kw
751759
++i;
752760
}
753761

754-
} while(iternext(iter));
762+
// Increment the iterator to the next inner loop
763+
} while(indexer_iternext(indexer_iter));
755764

756-
NpyIter_Deallocate(iter);
765+
NpyIter_Deallocate(indexer_iter);
757766

758767
PyObject *result = PyTuple_Pack(2, new_indexers, element_locations);
759768
Py_DECREF(element_locations);

0 commit comments

Comments
 (0)