@@ -40,20 +40,47 @@ class PrioritySearchTree:
4040 """
4141
4242 def _push_down (self , node : Node , value : _V ) -> None :
43- while node != Node .NULL_NODE :
44- if node .heap_value is None :
45- node .heap_value = value
46- return
43+ if node == Node .NULL_NODE :
44+ return
45+ if node .heap_value is None :
46+ node .heap_value = value
47+ return
4748
48- if self .heap_key ( value ) > self .heap_key (node .heap_value ):
49- node .heap_value , value = value , node .heap_value
50- elif self . heap_key ( value ) == self . heap_key ( node . heap_value ) :
51- node .placeholder = False
49+ if self .tree_key ( node . heap_value ) < self .tree_key (node .tree_value ):
50+ self . _push_down ( node .left , node .heap_value )
51+ else :
52+ self . _push_down ( node .right , node . heap_value )
5253
53- if self .tree_key (value ) < self .tree_key (node .tree_value ):
54- node = node .left
55- else :
56- node = node .right
54+ if node .placeholder :
55+ node .placeholder = False
56+ else :
57+ node .heap_value = value
58+
59+ def _sieve_down (self , node : Node , value : _V ) -> None :
60+
61+ if node .heap_value is None :
62+ node .heap_value = value
63+ return
64+
65+ if node .placeholder :
66+ node .placeholder = False
67+ return
68+
69+ heap_key_value = self .heap_key (value )
70+ tree_key_value = self .tree_key (value )
71+
72+ if heap_key_value > self .heap_key (node .heap_value ):
73+ self ._push_down (node , value )
74+ return
75+
76+ if heap_key_value == self .heap_key (node .heap_value ) and tree_key_value < self .tree_key (node .heap_value ):
77+ self ._push_down (node , value )
78+ return
79+
80+ if tree_key_value < self .tree_key (node .tree_value ):
81+ self ._sieve_down (node .left , value )
82+ else :
83+ self ._sieve_down (node .right , value )
5784
5885 def _push_up (self , node : Node ) -> None :
5986 vl , vr = None , None
@@ -65,7 +92,7 @@ def _push_up(self, node: Node) -> None:
6592 vr = node .right .heap_value
6693
6794 if vl and vr :
68- if self .heap_key (vl ) > self .heap_key (vr ):
95+ if self .heap_key (vl ) >= self .heap_key (vr ):
6996 node .heap_value = vl
7097 self ._push_up (node .left )
7198 else :
@@ -230,7 +257,7 @@ def add(self, value: _V) -> None:
230257 new_internal_node .heap_value = prev .heap_value
231258 prev .placeholder = True
232259
233- self ._push_down (self ._root , value )
260+ self ._sieve_down (self ._root , value )
234261 self ._fix_insert (new_leaf_node )
235262 self ._len += 1
236263
@@ -246,17 +273,19 @@ def remove(self, value: _V) -> None:
246273
247274 Complexity:
248275 O(log(N)) where **N** is number of items in PST
276+
277+ Note:
278+ this function is using ``tree_key(value)`` to compare the items
249279 """
250280 node = self ._root
251281 value_tree_key = self .tree_key (value )
252- value_heap_key = self .heap_key (value )
253282 heap_node = None
254283 tree_node = None
255284 leaf_node = None
256285
257286 while node != Node .NULL_NODE :
258287 leaf_node = node
259- if heap_node is None and self .tree_key (node .heap_value ) == value_tree_key and self . heap_key ( node . heap_value ) == value_heap_key :
288+ if heap_node is None and self .tree_key (node .heap_value ) == value_tree_key :
260289 heap_node = node
261290 if tree_node is None and self .tree_key (node .tree_value ) == value_tree_key :
262291 tree_node = node
@@ -286,7 +315,7 @@ def remove(self, value: _V) -> None:
286315 cut_node = leaf_node .parent
287316 fix_node = leaf_node .parent .right
288317
289- self ._push_down (cut_node , cut_node . heap_value )
318+ self ._push_down (cut_node , None )
290319 self ._transplant (cut_node , fix_node )
291320
292321 if cut_node .color == 0 :
@@ -535,5 +564,53 @@ def _rotate_left(self, x: Node) -> None:
535564 y .set_left (x )
536565 self ._push_up (x )
537566
538- def __len__ (self ):
567+ def __len__ (self ) -> int :
568+ """
569+ Implements the built-in function len()
570+
571+ Returns:
572+ int: Number of items in PST.
573+
574+ Complexity:
575+ O(1)
576+ """
539577 return self ._len
578+
579+ def __contains__ (self , value ) -> bool :
580+ """
581+ Implements membership test operator.
582+
583+ Args:
584+ value: Value to test for membership
585+
586+ Returns:
587+ bool: ``True`` if value is in ``self``, ``False`` otherwise.
588+
589+ Complexity:
590+ O(log(N)) where **N** is number of items in PST
591+
592+ Note:
593+ this function is using ``tree_key(value)`` to compare the items
594+ """
595+ value_tree_key = self .tree_key (value )
596+
597+ node = self ._root
598+ while node != Node .NULL_NODE :
599+ if value_tree_key < self .tree_key (node .tree_value ):
600+ node = node .left
601+ elif value_tree_key == self .tree_key (node .tree_value ):
602+ return True
603+ else :
604+ node = node .right
605+
606+ return False
607+
608+ def clear (self ) -> None :
609+ """
610+ Removes **all** items from PST.
611+
612+ Complexity:
613+ O(1)
614+ """
615+ self ._root = Node .NULL_NODE
616+ self ._len = 0
0 commit comments