Skip to content

Commit 5ea3ca5

Browse files
committed
Optimizes out need for argsort. Credit to @ForeverWintr. Frees GIL in main loop.
1 parent 9a013f8 commit 5ea3ca5

File tree

2 files changed

+71
-52
lines changed

2 files changed

+71
-52
lines changed

performance/reference/util.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -247,18 +247,4 @@ def get_new_indexers_and_screen_ak(
247247
if len(positions) > len(indexers):
248248
return np.unique(indexers, return_inverse=True)
249249

250-
# Will return same *objects* back if it was able to finish early.
251-
new_indexers, index_screen = ak_routine(indexers, positions)
252-
if new_indexers is indexers and index_screen is positions:
253-
return positions, indexers
254-
255-
# Use a more helpful alias!
256-
element_locations = index_screen
257-
258-
found_mask = element_locations != len(positions)
259-
260-
found_element_locations = element_locations[found_mask]
261-
order_found = np.argsort(found_element_locations)
262-
263-
found_positions = positions[found_mask]
264-
return found_positions[order_found], new_indexers
250+
return ak_routine(indexers, positions)

src/_arraykit.c

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -576,30 +576,23 @@ get_new_indexers_and_screen(PyObject *Py_UNUSED(m), PyObject *args, PyObject *kw
576576
577577
num_unique = len(positions)
578578
element_locations = np.full(num_unique, num_unique, dtype=np.int64)
579+
order_found = np.full(num_unique, num_unique, dtype=np.int64)
579580
new_indexers = np.empty(len(indexers), dtype=np.int64)
580581
581582
num_found = 0
582583
583584
for i, element in enumerate(indexers):
584585
if element_locations[element] == num_unique:
585586
element_locations[element] = num_found
587+
order_found[num_found] = element
586588
num_found += 1
587589
588590
if num_found == num_unique:
589591
return positions, indexers
590592
591593
new_indexers[i] = element_locations[element]
592594
593-
return element_locations, new_indexers
594-
# ...
595-
# NOTE: These return values will be used in a Python wrapper like this:
596-
found_mask = element_locations != num_unique
597-
598-
found_element_locations = element_locations[found_mask]
599-
order_found = np.argsort(found_element_locations)
600-
601-
found_positions = positions[found_mask]
602-
return found_positions[order_found], new_indexers
595+
return order_found[:num_found], new_indexers
603596
*/
604597
PyArrayObject *indexers;
605598
PyArrayObject *positions;
@@ -644,46 +637,61 @@ get_new_indexers_and_screen(PyObject *Py_UNUSED(m), PyObject *args, PyObject *kw
644637
}
645638

646639
npy_intp dims = {num_unique};
647-
PyArrayObject *element_locations = (PyArrayObject*)PyArray_Empty(
648-
1, // ndim
649-
&dims, // shape
650-
PyArray_DescrFromType(NPY_INT64), // dtype
651-
0 // fortran
640+
PyArrayObject *element_locations = (PyArrayObject*)PyArray_EMPTY(
641+
1, // ndim
642+
&dims, // shape
643+
NPY_INT64, // dtype
644+
0 // fortran
652645
);
653646
if (element_locations == NULL) {
654647
return NULL;
655648
}
656649

657-
PyObject *num_unique_pyint = PyLong_FromLong(num_unique);
658-
if (num_unique_pyint == NULL) {
650+
PyArrayObject *order_found = (PyArrayObject*)PyArray_EMPTY(
651+
1, // ndim
652+
&dims, // shape
653+
NPY_INT64, // dtype
654+
0 // fortran
655+
);
656+
if (order_found == NULL) {
659657
Py_DECREF(element_locations);
660658
return NULL;
661659
}
662660

663-
// We use ``num_unique`` to signal that we haven't found the element yet
661+
PyObject *num_unique_pyint = PyLong_FromLong(num_unique);
662+
if (num_unique_pyint == NULL) {
663+
goto fail;
664+
}
665+
666+
// We use ``num_unique`` here to signal that we haven't found the element yet
664667
// This works, because each element must be 0 < num_unique.
665668
int fill_success = PyArray_FillWithScalar(element_locations, num_unique_pyint);
669+
if (fill_success != 0) {
670+
Py_DECREF(num_unique_pyint);
671+
goto fail;
672+
}
673+
674+
fill_success = PyArray_FillWithScalar(order_found, num_unique_pyint);
666675
Py_DECREF(num_unique_pyint);
667676
if (fill_success != 0) {
668-
Py_DECREF(element_locations);
669-
return NULL;
677+
goto fail;
670678
}
671679

672-
PyArrayObject *new_indexers = (PyArrayObject*)PyArray_Empty(
673-
1, // ndim
674-
PyArray_DIMS(indexers), // shape
675-
PyArray_DescrFromType(NPY_INT64), // dtype
676-
0 // fortran
680+
PyArrayObject *new_indexers = (PyArrayObject*)PyArray_EMPTY(
681+
1, // ndim
682+
PyArray_DIMS(indexers), // shape
683+
NPY_INT64, // dtype
684+
0 // fortran
677685
);
678686
if (new_indexers == NULL) {
679-
Py_DECREF(element_locations);
680-
return NULL;
687+
goto fail;
681688
}
682689

683690
// We know that our incoming dtypes are all int64! This is a safe cast.
684691
// Plus, it's easier (and less error prone) to work with native C-arrays
685692
// over using numpy's iteration APIs.
686693
npy_int64 *element_location_values = (npy_int64*)PyArray_DATA(element_locations);
694+
npy_int64 *order_found_values = (npy_int64*)PyArray_DATA(order_found);
687695
npy_int64 *new_indexers_values = (npy_int64*)PyArray_DATA(new_indexers);
688696

689697
// Now, implement the core algorithm by looping over the ``indexers``.
@@ -698,25 +706,27 @@ get_new_indexers_and_screen(PyObject *Py_UNUSED(m), PyObject *args, PyObject *kw
698706
NULL // dtype
699707
);
700708
if (indexer_iter == NULL) {
701-
Py_DECREF(element_locations);
702709
Py_DECREF(new_indexers);
703-
return NULL;
710+
goto fail;
704711
}
705712

706713
// The iternext function gets stored in a local variable so it can be called repeatedly in an efficient manner.
707714
NpyIter_IterNextFunc *indexer_iternext = NpyIter_GetIterNext(indexer_iter, NULL);
708715
if (indexer_iternext == NULL) {
709716
NpyIter_Deallocate(indexer_iter);
710-
Py_DECREF(element_locations);
711717
Py_DECREF(new_indexers);
712-
return NULL;
718+
goto fail;
713719
}
714720

715721
// All of these will be updated by the iterator
716722
char **dataptr = NpyIter_GetDataPtrArray(indexer_iter);
717723
npy_intp *strideptr = NpyIter_GetInnerStrideArray(indexer_iter);
718724
npy_intp *innersizeptr = NpyIter_GetInnerLoopSizePtr(indexer_iter);
719725

726+
// No gil is required from here on!
727+
NPY_BEGIN_THREADS_DEF;
728+
NPY_BEGIN_THREADS;
729+
720730
size_t i = 0;
721731
npy_int64 num_found = 0;
722732
do {
@@ -731,17 +741,15 @@ get_new_indexers_and_screen(PyObject *Py_UNUSED(m), PyObject *args, PyObject *kw
731741

732742
if (element_location_values[element] == num_unique) {
733743
element_location_values[element] = num_found;
744+
order_found_values[num_found] = element;
734745
++num_found;
735746

736747
if (num_found == num_unique) {
737748
// This insight is core to the performance of the algorithm.
738749
// If we have found every possible indexer, we can simply return
739750
// back the inputs! Essentially, we can observe on <= single pass
740751
// that we have the opportunity for re-use
741-
NpyIter_Deallocate(indexer_iter);
742-
Py_DECREF(element_locations);
743-
Py_DECREF(new_indexers);
744-
return PyTuple_Pack(2, indexers, positions);
752+
goto finish_early;
745753
}
746754
}
747755

@@ -754,12 +762,37 @@ get_new_indexers_and_screen(PyObject *Py_UNUSED(m), PyObject *args, PyObject *kw
754762
// Increment the iterator to the next inner loop
755763
} while(indexer_iternext(indexer_iter));
756764

757-
NpyIter_Deallocate(indexer_iter);
765+
NPY_END_THREADS;
758766

759-
PyObject *result = PyTuple_Pack(2, new_indexers, element_locations);
767+
NpyIter_Deallocate(indexer_iter);
760768
Py_DECREF(element_locations);
769+
770+
// new_positions = order_found[:num_unique]
771+
PyObject *new_positions = PySequence_GetSlice(order_found, 0, num_found);
772+
Py_DECREF(order_found);
773+
if (new_positions == NULL) {
774+
return NULL;
775+
}
776+
777+
// return new_positions, new_indexers
778+
PyObject *result = PyTuple_Pack(2, new_positions, new_indexers);
761779
Py_DECREF(new_indexers);
780+
Py_DECREF(new_positions);
762781
return result;
782+
783+
finish_early:
784+
NPY_END_THREADS;
785+
786+
NpyIter_Deallocate(indexer_iter);
787+
Py_DECREF(element_locations);
788+
Py_DECREF(order_found);
789+
Py_DECREF(new_indexers);
790+
return PyTuple_Pack(2, positions, indexers);
791+
792+
fail:
793+
Py_DECREF(element_locations);
794+
Py_DECREF(order_found);
795+
return NULL;
763796
}
764797

765798
//------------------------------------------------------------------------------

0 commit comments

Comments
 (0)