Skip to content

Commit a248965

Browse files
committed
Added sorted_query method.
1 parent 78785e0 commit a248965

File tree

3 files changed

+100
-2
lines changed

3 files changed

+100
-2
lines changed

CHANGELOG.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,9 @@ Changelog
1212
------------------
1313

1414
* Initial implementation.
15+
16+
17+
0.0.2 (2024-03-24)
18+
------------------
19+
20+
* Added sorted_query method.

src/priority_search_tree/__init__.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,12 @@ def query(self, tree_left: _V, tree_right: _V, heap_bottom: _V) -> [_V]:
378378
while queue:
379379
node = queue.popleft()
380380

381-
if node == Node.NULL_NODE:
381+
if node == Node.NULL_NODE or node.placeholder:
382382
continue
383383

384384
if node.heap_value:
385385
if self.heap_key(node.heap_value) >= self.heap_key(heap_bottom):
386-
if self.tree_key(tree_left) <= self.tree_key(node.heap_value) <= self.tree_key(tree_right) and not node.placeholder:
386+
if self.tree_key(tree_left) <= self.tree_key(node.heap_value) <= self.tree_key(tree_right):
387387
result.append(node.heap_value)
388388
else:
389389
continue
@@ -398,6 +398,73 @@ def query(self, tree_left: _V, tree_right: _V, heap_bottom: _V) -> [_V]:
398398

399399
return result
400400

401+
def sorted_query(self, tree_left: _V, tree_right: _V, heap_bottom: _V, items_limit: int = 0) -> [_V]:
402+
"""Performs 3 sided query on PST.
403+
404+
This function returns list of items that meet the following criteria:
405+
1. items have **tree_key** grater or equal to **tree_key** of tree_left argument
406+
2. items have **tree_key** smaller or equal to **tree_key** of tree_right argument
407+
3. items have **heap_key** grater or equal to **heap_key** of heap_bottom argument
408+
409+
Args:
410+
tree_left: Left bound for query (**tree_key** is used).
411+
tree_right: Right bound for query (**tree_key** is used).
412+
heap_bottom: Bottom bound for query (**heap_key** is used).
413+
items_limit (int): Number of items to return. Default value is ``0`` - no limit.
414+
415+
Returns:
416+
List: list of items that satisfy criteria and sorted by **heap_key**
417+
(in case of limit, items with largest **heap_key** will be returned), or empty list if no items found
418+
419+
Complexity:
420+
O(log(N)+K*log(K)) where **N** is number of items in PST and **K** is number of returned items
421+
"""
422+
tree_left_key = self.tree_key(tree_left)
423+
tree_right_key = self.tree_key(tree_right)
424+
heap_bottom_key = self.heap_key(heap_bottom)
425+
if items_limit <= 0:
426+
items_limit = self._len
427+
428+
def _query_node(node, limit):
429+
result = []
430+
if node == Node.NULL_NODE or node.placeholder or limit == 0:
431+
return result
432+
433+
if node.heap_value:
434+
if self.heap_key(node.heap_value) >= heap_bottom_key:
435+
if tree_left_key <= self.tree_key(node.heap_value) <= tree_right_key:
436+
result.append(node.heap_value)
437+
limit -= 1
438+
else:
439+
return result
440+
441+
if tree_right_key < self.tree_key(node.tree_value):
442+
result.extend(_query_node(node.left, limit))
443+
elif tree_left_key >= self.tree_key(node.tree_value):
444+
result.extend(_query_node(node.right, limit))
445+
else:
446+
left = _query_node(node.left, limit)
447+
right = _query_node(node.right, limit)
448+
# merge
449+
i, j = 0, 0
450+
while i < len(left) and j < len(right) and len(result) < items_limit:
451+
if self.heap_key(left[i]) >= self.heap_key(right[j]):
452+
result.append(left[i])
453+
i += 1
454+
else:
455+
result.append(right[j])
456+
j += 1
457+
while i < len(left) and len(result) < items_limit:
458+
result.append(left[i])
459+
i += 1
460+
while j < len(right) and len(result) < items_limit:
461+
result.append(right[j])
462+
j += 1
463+
464+
return result
465+
466+
return _query_node(self._root, items_limit)
467+
401468
def _fix_insert(self, node: Node) -> None:
402469
while node.parent.color == 1:
403470
if node.parent.parent.right == node.parent:

tests/test_priority_search_tree.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,31 @@ def test_query():
9494
if (x_min, 0) <= item <= (x_max, 8) and item[1:] >= (y_min,):
9595
query_expected.add(item)
9696
assert set(pst.query((x_min, 0), (x_max, 8), (0, y_min))) == query_expected
97+
sqr = pst.sorted_query((x_min, 0), (x_max, 8), (0, y_min))
98+
assert sqr == sorted(query_expected, key=lambda x: x[1:], reverse=True)
99+
100+
101+
def test_sorted_query_limit():
102+
items = [(0, 0), (1, 6), (2, 2), (3, 7), (4, 4), (5, 2), (6, 3), (7, 5), (8, 8)]
103+
pst = PrioritySearchTree(items)
104+
result = pst.sorted_query((0, 0), (8, 8), (8, 0), items_limit=1)
105+
assert result == [(8, 8)]
106+
result = pst.sorted_query((0, 0), (8, 8), (8, 0), items_limit=2)
107+
assert result == [(8, 8), (3, 7)]
108+
result = pst.sorted_query((0, 0), (8, 8), (8, 0), items_limit=3)
109+
assert result == [(8, 8), (3, 7), (1, 6)]
110+
result = pst.sorted_query((0, 0), (8, 8), (8, 0), items_limit=4)
111+
assert result == [(8, 8), (3, 7), (1, 6), (7, 5)]
112+
result = pst.sorted_query((0, 0), (8, 8), (8, 0), items_limit=5)
113+
assert result == [(8, 8), (3, 7), (1, 6), (7, 5), (4, 4)]
114+
result = pst.sorted_query((0, 0), (8, 8), (8, 0), items_limit=6)
115+
assert result == [(8, 8), (3, 7), (1, 6), (7, 5), (4, 4), (6, 3)]
116+
result = pst.sorted_query((0, 0), (8, 8), (8, 0), items_limit=7)
117+
assert result == [(8, 8), (3, 7), (1, 6), (7, 5), (4, 4), (6, 3), (2, 2)]
118+
result = pst.sorted_query((0, 0), (8, 8), (8, 0), items_limit=8)
119+
assert result == [(8, 8), (3, 7), (1, 6), (7, 5), (4, 4), (6, 3), (2, 2), (5, 2)]
120+
result = pst.sorted_query((0, 0), (8, 8), (8, 0), items_limit=0)
121+
assert result == [(8, 8), (3, 7), (1, 6), (7, 5), (4, 4), (6, 3), (2, 2), (5, 2), (0, 0)]
97122

98123

99124
def test_stress_tester():

0 commit comments

Comments
 (0)