Skip to content

Commit 26a70de

Browse files
committed
Added __iter__ method
1 parent f3670b1 commit 26a70de

File tree

7 files changed

+84
-42
lines changed

7 files changed

+84
-42
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,4 @@ Changelog
2828

2929
* Refactoring
3030
* Added PrioritySearchSet class
31+
* Added __iter__ method

docs/reference/priority_search_tree_print.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
priority_search_tree.print_helpers
2-
====================
2+
==================================
33

44
.. testsetup::
55

src/priority_search_tree/ps_set.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import collections
21
from typing import Callable
32
from typing import Iterable
3+
from typing import Iterator
4+
from typing import MutableSet
45
from typing import Optional
56
from typing import TypeVar
67

@@ -11,13 +12,15 @@
1112
_V = TypeVar("_V")
1213

1314

14-
class PrioritySearchSet(collections.abc.MutableSet):
15+
class PrioritySearchSet(MutableSet):
1516
"""Mutable Set that maintains priority search tree properties.
1617
1718
PrioritySearchSet can be used to store any type of objects.
1819
2 functions should be passed to PrioritySearchSet constructor:
19-
* ``key_func`` to extract **key** for the object
20-
* ``priority_func`` to extract **priority** for the object
20+
21+
* ``key_func`` to extract **key** for the object
22+
* ``priority_func`` to extract **priority** for the object
23+
2124
extracted **key**, **priority** values will be used in underlying PrioritySearchTree
2225
2326
Example::
@@ -138,7 +141,7 @@ def remove(self, value: _V) -> None:
138141
del self._pst[key]
139142
del self._values[key]
140143

141-
def query(self, left: _V, right: _V, bottom: _V) -> list[_V]:
144+
def query(self, left: _V, right: _V, bottom: _V) -> list:
142145
"""Performs 3 sided query on PSS.
143146
144147
This function returns list of items that meet the following criteria:
@@ -162,7 +165,7 @@ def query(self, left: _V, right: _V, bottom: _V) -> list[_V]:
162165
priority_bottom = self.priority_func(bottom)
163166
return [self._values[x] for x in self._pst.query(key_left, key_right, priority_bottom)]
164167

165-
def sorted_query(self, left: _V, right: _V, bottom: _V, items_limit: int = 0) -> [_V]:
168+
def sorted_query(self, left: _V, right: _V, bottom: _V, items_limit: int = 0) -> list:
166169
"""Performs sorted 3 sided query on PSS.
167170
168171
This function returns list of items that meet the following criteria:
@@ -242,5 +245,11 @@ def discard(self, value) -> None:
242245
del self._pst[key]
243246
del self._values[key]
244247

245-
def __iter__(self):
246-
raise NotImplementedError
248+
def __iter__(self) -> Iterator:
249+
"""Create an iterator that iterates values in sorted by **key** order
250+
251+
Returns:
252+
Iterator: in order iterator
253+
"""
254+
for key in self._pst:
255+
yield self._values[key]

src/priority_search_tree/ps_tree.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import collections.abc
21
from typing import Iterable
2+
from typing import Iterator
3+
from typing import MutableMapping
34
from typing import Optional
45
from typing import Tuple
56
from typing import TypeVar
@@ -10,7 +11,7 @@
1011
_PRIORITY = TypeVar("_PRIORITY")
1112

1213

13-
class PrioritySearchTree(collections.abc.MutableMapping):
14+
class PrioritySearchTree(MutableMapping):
1415
"""Class that represents Priority search tree.
1516
1617
PrioritySearchTree is a mutable mapping that stores **keys** and corresponding **priorities**.
@@ -241,7 +242,7 @@ def _transplant(self, u: Node, v: Node) -> None:
241242
else:
242243
u.parent.set_right(v)
243244

244-
def query(self, key_left: _KEY, key_right: _KEY, priority_bottom: _PRIORITY) -> list[_KEY]:
245+
def query(self, key_left: _KEY, key_right: _KEY, priority_bottom: _PRIORITY) -> list:
245246
"""Performs 3 sided query on PST.
246247
247248
This function returns list of items that meet the following criteria:
@@ -283,7 +284,7 @@ def _query_node(node) -> None:
283284
_query_node(self._root)
284285
return result
285286

286-
def sorted_query(self, key_left: _KEY, key_right: _KEY, priority_bottom: _PRIORITY, items_limit: int = 0) -> [_KEY]:
287+
def sorted_query(self, key_left: _KEY, key_right: _KEY, priority_bottom: _PRIORITY, items_limit: int = 0) -> list:
287288
"""Performs 3 sided query on PST.
288289
289290
This function returns list of items that meet the following criteria:
@@ -613,5 +614,24 @@ def __getitem__(self, key: _KEY) -> _PRIORITY:
613614

614615
return heap_node.heap_key[0]
615616

616-
def __iter__(self):
617-
raise NotImplementedError
617+
def __iter__(self) -> Iterator:
618+
"""Create an iterator that iterates **keys** in sorted order
619+
620+
Returns:
621+
Iterator: in order iterator
622+
"""
623+
stack = []
624+
current = self._root
625+
yielded_key = None
626+
while True:
627+
if current != Node.NULL_NODE:
628+
stack.append(current)
629+
current = current.left
630+
elif stack:
631+
current = stack.pop()
632+
if current.tree_key != yielded_key:
633+
yielded_key = current.tree_key
634+
yield yielded_key
635+
current = current.right
636+
else:
637+
break

tests/manual/performance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from priority_search_tree import PrioritySearchSet
1111

1212
NUMBER_OF_ITEMS = [1, 100, 10000, 1000000]
13-
REPEAT_COUNT = 1
13+
REPEAT_COUNT = 10
1414

1515
TIMEOUT_SEC = 60
1616

tests/test_priority_search_set.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,41 +12,44 @@ def __init__(self, x: int, y: int):
1212

1313
items = [Point(1, 1), Point(2, 2), Point(3, 3), Point(4, 4), Point(5, 6), Point(6, 6)]
1414

15-
pst = PrioritySearchSet(key_func=lambda v: v.x, priority_func=lambda v: v.y, iterable=items)
16-
assert_rb_tree(pst._pst._root)
15+
pss = PrioritySearchSet(key_func=lambda v: v.x, priority_func=lambda v: v.y, iterable=items)
16+
assert_rb_tree(pss._pst._root)
1717

18-
assert pst.query(Point(1, 1), Point(2, 2), Point(2, 2)) == [items[1]]
19-
assert pst.query(Point(1, 1), Point(5, 1), Point(1, 6)) == [items[4]]
20-
assert pst.sorted_query(Point(1, 1), Point(6, 1), Point(1, 6)) == [items[5], items[4]]
21-
assert pst.sorted_query(Point(1, 1), Point(4, 1), Point(1, 1), items_limit=1) == [items[3]]
18+
assert pss.query(Point(1, 1), Point(2, 2), Point(2, 2)) == [items[1]]
19+
assert pss.query(Point(1, 1), Point(5, 1), Point(1, 6)) == [items[4]]
20+
assert pss.sorted_query(Point(1, 1), Point(6, 1), Point(1, 6)) == [items[5], items[4]]
21+
assert pss.sorted_query(Point(1, 1), Point(4, 1), Point(1, 1), items_limit=1) == [items[3]]
2222

23-
assert items[2] in pst
24-
assert Point(10, 10) not in pst
25-
assert Point(2, 10) in pst
23+
assert items[2] in pss
24+
assert Point(10, 10) not in pss
25+
assert Point(2, 10) in pss
2626

27-
pst.remove(items[2])
28-
pst.remove(Point(2, 10))
29-
pst.discard(Point(4, 1))
30-
pst.discard(Point(4, 1))
27+
pss.remove(items[2])
28+
pss.remove(Point(2, 10))
29+
pss.discard(Point(4, 1))
30+
pss.discard(Point(4, 1))
3131

3232
with pytest.raises(KeyError, match="Key not found:"):
33-
pst.remove(Point(7, 1))
33+
pss.remove(Point(7, 1))
3434

3535
with pytest.raises(KeyError, match="Key not found:"):
36-
pst.remove(Point(7, 7))
36+
pss.remove(Point(7, 7))
3737

38-
assert_rb_tree(pst._pst._root)
38+
assert_rb_tree(pss._pst._root)
3939

40-
assert pst.get_with_max_priority().y == 6
40+
assert pss.get_with_max_priority().y == 6
4141

42-
assert pst.pop().y == 6
43-
assert pst.pop().y == 6
44-
assert pst.pop().y == 1
42+
assert pss.pop().y == 6
43+
assert pss.pop().y == 6
44+
assert pss.pop().y == 1
4545

46-
assert not pst
46+
assert not pss
4747

48-
pst = PrioritySearchSet(key_func=lambda v: v.x, priority_func=lambda v: v.y)
49-
pst.add(Point(1, 1))
50-
assert len(pst) == 1
51-
pst.clear()
52-
assert not pst
48+
pss = PrioritySearchSet(key_func=lambda v: v.x, priority_func=lambda v: v.y)
49+
pss.add(Point(1, 4))
50+
assert len(pss) == 1
51+
for p in pss:
52+
assert p.x == 1
53+
assert p.y == 4
54+
pss.clear()
55+
assert not pss

tests/test_priority_search_tree.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ def test_empty_pst():
3737
assert len(pst) == 0
3838
with pytest.raises(KeyError):
3939
pst.get_with_max_priority()
40+
for x in pst:
41+
raise AssertionError(x)
4042
pst[1] = 5
4143
assert pst.pop(1) == 5
4244
with pytest.raises(KeyError, match="Key not found:"):
@@ -113,6 +115,13 @@ def test_query():
113115
assert pst.sorted_query(x_min, x_max, y_min) == query_expected
114116

115117

118+
def test_iterator():
119+
items = [(0, 0), (1, 6), (6, 3), (7, 5), (8, 8), (2, 1), (3, 7), (4, 4), (5, 2)]
120+
pst = PrioritySearchTree(items)
121+
for e, k in enumerate(pst):
122+
assert e == k
123+
124+
116125
def test_sorted_query_limit():
117126
items = [(0, 0), (1, 6), (2, 2), (3, 7), (4, 4), (5, 2), (6, 3), (7, 5), (8, 8)]
118127
pst = PrioritySearchTree(items)

0 commit comments

Comments
 (0)