Skip to content

Commit e3d3503

Browse files
committed
implemented count_iteration
1 parent 60e0c9d commit e3d3503

File tree

6 files changed

+72
-2
lines changed

6 files changed

+72
-2
lines changed

performance/__main__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from performance.reference.util import get_new_indexers_and_screen_ak
2727
from performance.reference.util import get_new_indexers_and_screen_ref
2828
from performance.reference.util import split_after_count as split_after_count_ref
29+
from performance.reference.util import count_iteration as count_iteration_ref
2930

3031
from performance.reference.array_go import ArrayGO as ArrayGOREF
3132

@@ -43,6 +44,7 @@
4344
from arraykit import delimited_to_arrays as delimited_to_arrays_ak
4445
from arraykit import isna_element as isna_element_ak
4546
from arraykit import split_after_count as split_after_count_ak
47+
from arraykit import count_iteration as count_iteration_ak
4648

4749
from arraykit import ArrayGO as ArrayGOAK
4850

@@ -716,6 +718,24 @@ class SplitAfterCountREF(SplitAfterCount):
716718
entry = staticmethod(split_after_count_ref)
717719

718720

721+
#-------------------------------------------------------------------------------
722+
class CountIterations(Perf):
723+
NUMBER = 10_000
724+
725+
def __init__(self):
726+
self.strio = io.StringIO('\n'.join(['abcd'] * 10_000))
727+
728+
def main(self):
729+
post = self.entry(self.strio)
730+
self.strio.seek(0)
731+
732+
class CountIterationsAK(CountIterations):
733+
entry = staticmethod(count_iteration_ak)
734+
735+
class CountIterationsREF(CountIterations):
736+
entry = staticmethod(count_iteration_ref)
737+
738+
719739
#-------------------------------------------------------------------------------
720740

721741
def get_arg_parser():

performance/reference/util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,4 +254,12 @@ def split_after_count(string: str, delimiter: str, count: int):
254254
*left, right = string.split(delimiter, maxsplit=count)
255255
return ','.join(left), right
256256

257+
def count_iteration(iterable: tp.Iterable):
258+
count = 0
259+
for i in iterable:
260+
count += 1
261+
return count
262+
263+
264+
257265

src/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
from ._arraykit import iterable_str_to_array_1d as iterable_str_to_array_1d
2121
from ._arraykit import get_new_indexers_and_screen as get_new_indexers_and_screen
2222
from ._arraykit import split_after_count as split_after_count
23+
from ._arraykit import count_iteration as count_iteration

src/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def split_after_count(
5252
count: int,
5353
) -> tp.Tuple[str, str]: ...
5454

55+
def count_iteration(__iterable: tp.Iterable) -> int: ...
56+
5557
def immutable_filter(__array: np.ndarray) -> np.ndarray: ...
5658
def mloc(__array: np.ndarray) -> int: ...
5759
def name_filter(__name: tp.Hashable) -> tp.Hashable: ...

src/_arraykit.c

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1442,10 +1442,10 @@ AK_CPL_FromIterable(PyObject* iterable, bool type_parse, Py_UCS4 tsep, Py_UCS4 d
14421442
}
14431443
Py_DECREF(field);
14441444
}
1445+
Py_DECREF(iter);
14451446
if (PyErr_Occurred()) {
14461447
return NULL;
14471448
}
1448-
Py_DECREF(iter);
14491449
return cpl;
14501450
}
14511451

@@ -3059,6 +3059,30 @@ split_after_count(PyObject *Py_UNUSED(m), PyObject *args)
30593059
}
30603060

30613061

3062+
3063+
static PyObject *
3064+
count_iteration(PyObject *Py_UNUSED(m), PyObject *iterable)
3065+
{
3066+
PyObject *iter = PyObject_GetIter(iterable);
3067+
if (iter == NULL) return NULL;
3068+
3069+
int count = 0;
3070+
PyObject *v;
3071+
3072+
while ((v = PyIter_Next(iter))) {
3073+
count++;
3074+
Py_DECREF(v);
3075+
}
3076+
Py_DECREF(iter);
3077+
if (PyErr_Occurred()) {
3078+
return NULL;
3079+
}
3080+
PyObject* result = PyLong_FromLong(count);
3081+
if (result == NULL) return NULL;
3082+
return result;
3083+
}
3084+
3085+
30623086
//------------------------------------------------------------------------------
30633087

30643088
// Return the integer version of the pointer to underlying data-buffer of array.
@@ -3922,6 +3946,7 @@ static PyMethodDef arraykit_methods[] = {
39223946
METH_VARARGS | METH_KEYWORDS,
39233947
NULL},
39243948
{"split_after_count", split_after_count, METH_VARARGS, NULL},
3949+
{"count_iteration", count_iteration, METH_O, NULL},
39253950
{"isna_element", isna_element, METH_O, NULL},
39263951
{"dtype_from_element", dtype_from_element, METH_O, NULL},
39273952
{"get_new_indexers_and_screen",

test/test_util.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import datetime
44
import unittest
55
import warnings
6-
6+
from io import StringIO
77
import numpy as np # type: ignore
88

99
from arraykit import resolve_dtype
@@ -18,6 +18,7 @@
1818
from arraykit import isna_element
1919
from arraykit import dtype_from_element
2020
from arraykit import split_after_count
21+
from arraykit import count_iteration
2122

2223
from performance.reference.util import get_new_indexers_and_screen_ak as get_new_indexers_and_screen_full
2324
from arraykit import get_new_indexers_and_screen
@@ -484,5 +485,18 @@ def test_split_after_count_h(self) -> None:
484485
self.assertEqual(post[0], 'a,b,c,d,e')
485486
self.assertEqual(post[1], '')
486487

488+
489+
#---------------------------------------------------------------------------
490+
def test_count_iteration_a(self) -> None:
491+
post = count_iteration(('a', 'b', 'c', 'd'))
492+
self.assertEqual(post, 4)
493+
494+
def test_count_iteration_b(self) -> None:
495+
s1 = StringIO(',1,a,b\n-,1,43,54\nX,2,1,3\nY,1,8,10\n-,2,6,20')
496+
post = count_iteration(s1)
497+
self.assertEqual(post, 5)
498+
499+
500+
487501
if __name__ == '__main__':
488502
unittest.main()

0 commit comments

Comments
 (0)