Skip to content

Commit 5a8935b

Browse files
committed
Added __len__, __contains__, clear methods.
Fixed issue with not unique heap keys
1 parent a248965 commit 5a8935b

File tree

4 files changed

+145
-22
lines changed

4 files changed

+145
-22
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ Changelog
1414
* Initial implementation.
1515

1616

17-
0.0.2 (2024-03-24)
17+
0.0.2 (2024-03-26)
1818
------------------
1919

2020
* Added sorted_query method.
21+
* Added __len__ and __contains__ methods.
22+
* Added clear method.
23+
* Fixed issue with not unique heap keys

docs/reference/priority_search_tree.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ priority_search_tree
77

88
.. automodule:: priority_search_tree
99
:members:
10+
:special-members:

src/priority_search_tree/__init__.py

Lines changed: 95 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,47 @@ class PrioritySearchTree:
4040
"""
4141

4242
def _push_down(self, node: Node, value: _V) -> None:
43-
while node != Node.NULL_NODE:
44-
if node.heap_value is None:
45-
node.heap_value = value
46-
return
43+
if node == Node.NULL_NODE:
44+
return
45+
if node.heap_value is None:
46+
node.heap_value = value
47+
return
4748

48-
if self.heap_key(value) > self.heap_key(node.heap_value):
49-
node.heap_value, value = value, node.heap_value
50-
elif self.heap_key(value) == self.heap_key(node.heap_value):
51-
node.placeholder = False
49+
if self.tree_key(node.heap_value) < self.tree_key(node.tree_value):
50+
self._push_down(node.left, node.heap_value)
51+
else:
52+
self._push_down(node.right, node.heap_value)
5253

53-
if self.tree_key(value) < self.tree_key(node.tree_value):
54-
node = node.left
55-
else:
56-
node = node.right
54+
if node.placeholder:
55+
node.placeholder = False
56+
else:
57+
node.heap_value = value
58+
59+
def _sieve_down(self, node: Node, value: _V) -> None:
60+
61+
if node.heap_value is None:
62+
node.heap_value = value
63+
return
64+
65+
if node.placeholder:
66+
node.placeholder = False
67+
return
68+
69+
heap_key_value = self.heap_key(value)
70+
tree_key_value = self.tree_key(value)
71+
72+
if heap_key_value > self.heap_key(node.heap_value):
73+
self._push_down(node, value)
74+
return
75+
76+
if heap_key_value == self.heap_key(node.heap_value) and tree_key_value < self.tree_key(node.heap_value):
77+
self._push_down(node, value)
78+
return
79+
80+
if tree_key_value < self.tree_key(node.tree_value):
81+
self._sieve_down(node.left, value)
82+
else:
83+
self._sieve_down(node.right, value)
5784

5885
def _push_up(self, node: Node) -> None:
5986
vl, vr = None, None
@@ -65,7 +92,7 @@ def _push_up(self, node: Node) -> None:
6592
vr = node.right.heap_value
6693

6794
if vl and vr:
68-
if self.heap_key(vl) > self.heap_key(vr):
95+
if self.heap_key(vl) >= self.heap_key(vr):
6996
node.heap_value = vl
7097
self._push_up(node.left)
7198
else:
@@ -230,7 +257,7 @@ def add(self, value: _V) -> None:
230257
new_internal_node.heap_value = prev.heap_value
231258
prev.placeholder = True
232259

233-
self._push_down(self._root, value)
260+
self._sieve_down(self._root, value)
234261
self._fix_insert(new_leaf_node)
235262
self._len += 1
236263

@@ -246,17 +273,19 @@ def remove(self, value: _V) -> None:
246273
247274
Complexity:
248275
O(log(N)) where **N** is number of items in PST
276+
277+
Note:
278+
this function is using ``tree_key(value)`` to compare the items
249279
"""
250280
node = self._root
251281
value_tree_key = self.tree_key(value)
252-
value_heap_key = self.heap_key(value)
253282
heap_node = None
254283
tree_node = None
255284
leaf_node = None
256285

257286
while node != Node.NULL_NODE:
258287
leaf_node = node
259-
if heap_node is None and self.tree_key(node.heap_value) == value_tree_key and self.heap_key(node.heap_value) == value_heap_key:
288+
if heap_node is None and self.tree_key(node.heap_value) == value_tree_key:
260289
heap_node = node
261290
if tree_node is None and self.tree_key(node.tree_value) == value_tree_key:
262291
tree_node = node
@@ -286,7 +315,7 @@ def remove(self, value: _V) -> None:
286315
cut_node = leaf_node.parent
287316
fix_node = leaf_node.parent.right
288317

289-
self._push_down(cut_node, cut_node.heap_value)
318+
self._push_down(cut_node, None)
290319
self._transplant(cut_node, fix_node)
291320

292321
if cut_node.color == 0:
@@ -535,5 +564,53 @@ def _rotate_left(self, x: Node) -> None:
535564
y.set_left(x)
536565
self._push_up(x)
537566

538-
def __len__(self):
567+
def __len__(self) -> int:
568+
"""
569+
Implements the built-in function len()
570+
571+
Returns:
572+
int: Number of items in PST.
573+
574+
Complexity:
575+
O(1)
576+
"""
539577
return self._len
578+
579+
def __contains__(self, value) -> bool:
580+
"""
581+
Implements membership test operator.
582+
583+
Args:
584+
value: Value to test for membership
585+
586+
Returns:
587+
bool: ``True`` if value is in ``self``, ``False`` otherwise.
588+
589+
Complexity:
590+
O(log(N)) where **N** is number of items in PST
591+
592+
Note:
593+
this function is using ``tree_key(value)`` to compare the items
594+
"""
595+
value_tree_key = self.tree_key(value)
596+
597+
node = self._root
598+
while node != Node.NULL_NODE:
599+
if value_tree_key < self.tree_key(node.tree_value):
600+
node = node.left
601+
elif value_tree_key == self.tree_key(node.tree_value):
602+
return True
603+
else:
604+
node = node.right
605+
606+
return False
607+
608+
def clear(self) -> None:
609+
"""
610+
Removes **all** items from PST.
611+
612+
Complexity:
613+
O(1)
614+
"""
615+
self._root = Node.NULL_NODE
616+
self._len = 0

tests/test_priority_search_tree.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
def test_empty_pst():
1313
pst = PrioritySearchTree()
1414
assert not pst
15+
assert (1, 1) not in pst
1516
result = pst.query((0, 0), (1, 2), (1, 1))
1617
assert len(result) == 0
1718
with pytest.raises(ValueError, match="Value not found:"):
@@ -67,6 +68,20 @@ def test_heap_get_max():
6768
assert pst.heap_get_max() == (5, 5)
6869

6970

71+
def test_contains():
72+
items = [(1, 1), (2, 2), (3, 4), (4, 4)]
73+
pst = PrioritySearchTree(items, tree_key=lambda x: x[0])
74+
for itm in items:
75+
assert itm in pst
76+
77+
# key_tree and heap_tree not in pst
78+
assert (5, 5) not in pst
79+
# key_tree not in pst and heap_key is
80+
assert (5, 4) not in pst
81+
# key_tree in pst heap_key is not
82+
assert (1, 5) in pst
83+
84+
7085
def test_large_pst():
7186
pst = PrioritySearchTree(LARGE_PST_INITIAL_DATA)
7287
assert len(pst) == len(LARGE_PST_INITIAL_DATA)
@@ -134,6 +149,35 @@ def test_unique_tree_key():
134149
PrioritySearchTree([(1, 1), (1, 1)])
135150

136151

152+
def test_not_unique_heap_keys():
153+
pst = PrioritySearchTree()
154+
for i in range(100, -1, -1):
155+
pst.add((i, 5))
156+
assert_rb_tree(pst._root)
157+
result = pst.sorted_query((0, 0), (1000, 0), (0, 0))
158+
assert len(result) == len(pst)
159+
for i, r in enumerate(result):
160+
assert r[0] == i
161+
162+
pst.clear()
163+
for i in range(100):
164+
pst.add((i, 5))
165+
assert_rb_tree(pst._root)
166+
result = pst.sorted_query((0, 0), (1000, 0), (0, 0))
167+
assert len(result) == len(pst)
168+
for i, r in enumerate(result):
169+
assert r[0] == i
170+
171+
pst.clear()
172+
for itm in [(0, 5), (1, 5), (2, 5), (3, 5), (4, 5), (5, 5), (6, 5), (7, 5), (8, 5)]:
173+
pst.add(itm)
174+
assert_rb_tree(pst._root)
175+
result = pst.sorted_query((0, 0), (1000, 0), (0, 0))
176+
assert len(result) == len(pst)
177+
for i, r in enumerate(result):
178+
assert r[0] == i
179+
180+
137181
def test_custom_keys():
138182
class Point:
139183
def __init__(self, x: int, y: int):
@@ -155,8 +199,7 @@ def __init__(self, x: int, y: int):
155199
pst.remove(Point(2, 2))
156200

157201
# same tree_key different hash_key
158-
with pytest.raises(ValueError, match="Value not found:"):
159-
pst.remove(Point(4, 1))
202+
pst.remove(Point(4, 1))
160203

161204
# different tree_key same hash_key
162205
with pytest.raises(ValueError, match="Value not found:"):
@@ -170,5 +213,4 @@ def __init__(self, x: int, y: int):
170213

171214
assert pst.heap_pop().y == 6
172215
assert pst.heap_pop().y == 6
173-
assert pst.heap_pop().y == 4
174216
assert pst.heap_pop().y == 1

0 commit comments

Comments
 (0)