Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 66 additions & 37 deletions lucene/core/src/java/org/apache/lucene/util/PriorityQueue.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.function.IntFunction;
import java.util.function.Supplier;

Expand Down Expand Up @@ -72,6 +76,7 @@ public static <T> PriorityQueue<T> usingComparator(
private final int maxSize;
private final T[] heap;
private final LessThan<? super T> lessThan;
private final Map<T, Set<Integer>> indexMap = new HashMap<>();

/** Create an empty priority queue of the configured size using the specified {@link LessThan}. */
public PriorityQueue(int maxSize, LessThan<? super T> lessThan) {
Expand Down Expand Up @@ -182,9 +187,9 @@ public void addAll(Collection<T> elements) {
* @return the new 'top' element in the queue.
*/
public final T add(T element) {
// don't modify size until we know heap access didn't throw AIOOB.
int index = size + 1;
heap[index] = element;
addIndex(element, index);
size = index;
upHeap(index);
return heap[1];
Expand Down Expand Up @@ -272,6 +277,7 @@ public final int size() {
public final void clear() {
Arrays.fill(heap, 0, size + 1, null);
size = 0;
indexMap.clear();
}

/**
Expand All @@ -280,20 +286,25 @@ public final void clear() {
* constant remove time but the trade-off would be extra cost to all additions/insertions)
*/
public final boolean remove(T element) {
for (int i = 1; i <= size; i++) {
if (heap[i] == element) {
heap[i] = heap[size];
heap[size] = null; // permit GC of objects
size--;
if (i <= size) {
if (!upHeap(i)) {
downHeap(i);
}
}
return true;
}
Set<Integer> indices = indexMap.get(element);
if (indices == null || indices.isEmpty()) return false;
Integer idx = indices.iterator().next();
removeIndex(element, idx);
T last = heap[size];
if (idx == size) {
heap[size] = null;
size--;
return true;
}
return false;
removeIndex(last, size);
heap[idx] = last;
addIndex(last, idx);
heap[size] = null;
size--;
if (!upHeap(idx)) {
downHeap(idx);
}
return true;
}

/**
Expand All @@ -320,36 +331,54 @@ public T[] drainToArrayHighestFirst(IntFunction<T[]> newArray) {
return array;
}

private boolean upHeap(int origPos) {
int i = origPos;
T node = heap[i]; // save bottom node
int j = i >>> 1;
while (j > 0 && lessThan.lessThan(node, heap[j])) {
heap[i] = heap[j]; // shift parents down
i = j;
j = j >>> 1;
private void addIndex(T element, int idx) {
indexMap.computeIfAbsent(element, k -> new HashSet<>()).add(idx);
}

private void removeIndex(T element, int idx) {
Set<Integer> indices = indexMap.get(element);
if (indices != null) {
indices.remove(idx);
if (indices.isEmpty()) indexMap.remove(element);
}
heap[i] = node; // install saved node
return i != origPos;
}

private void downHeap(int i) {
T node = heap[i]; // save top node
int j = i << 1; // find smaller child
int k = j + 1;
if (k <= size && lessThan.lessThan(heap[k], heap[j])) {
j = k;
protected boolean upHeap(int i) {
T node = heap[i];
int j = i;
while (j > 1 && lessThan.lessThan(node, heap[j >> 1])) {
heap[j] = heap[j >> 1];
removeIndex(heap[j], j >> 1);
addIndex(heap[j], j);
j >>= 1;
}
while (j <= size && lessThan.lessThan(heap[j], node)) {
heap[i] = heap[j]; // shift up child
i = j;
j = i << 1;
k = j + 1;
if (k <= size && lessThan.lessThan(heap[k], heap[j])) {
heap[j] = node;
removeIndex(node, i);
addIndex(node, j);
return j < i;
}

protected boolean downHeap(int i) {
T node = heap[i];
int j = i;
int k;
while ((k = j << 1) <= size) {
if (k < size && lessThan.lessThan(heap[k + 1], heap[k])) {
k++;
}
if (lessThan.lessThan(heap[k], node)) {
heap[j] = heap[k];
removeIndex(heap[j], k);
addIndex(heap[j], j);
j = k;
} else {
break;
}
}
heap[i] = node; // install saved node
heap[j] = node;
removeIndex(node, i);
addIndex(node, j);
return j > i;
}

/**
Expand Down
89 changes: 61 additions & 28 deletions lucene/core/src/test/org/apache/lucene/util/TestPriorityQueue.java
Original file line number Diff line number Diff line change
Expand Up @@ -208,47 +208,80 @@ public void testAddAllDoesNotFitIntoQueue() {
() -> pq.addAll(list));
}

/** Randomly add and remove elements, comparing against the reference java.util.PriorityQueue. */
public void testRemovalsAndInsertions() {
/** Randomly remove elements, comparing against the reference java.util.PriorityQueue by value. */
public void testRemovals() {
int maxElement = RandomNumbers.randomIntBetween(random(), 1, 10_000);
int size = maxElement / 2 + 1;

var reference = new java.util.PriorityQueue<Integer>();
var pq = new IntegerQueue(size);

Random localRandom = nonAssertingRandom(random());

// Lucene's PriorityQueue.remove uses reference equality, not .equals to determine which
// elements
// to remove (!).
HashMap<Integer, Integer> ints = new HashMap<>();
// Fill both queues with up to maxSize elements
for (int i = 0; i < size; i++) {
Integer element = ints.computeIfAbsent(localRandom.nextInt(maxElement), k -> k);
pq.add(element);
reference.add(element);
}
// Perform random removals and compare by value
for (int i = 0; i < size; i++) {
Integer element = ints.computeIfAbsent(localRandom.nextInt(maxElement), k -> k);
int pqCount = 0, refCount = 0;
for (Integer val : pq) if (val.equals(element)) pqCount++;
for (Integer val : reference) if (val.equals(element)) refCount++;
boolean pqRemoved = pq.remove(element);
boolean refRemoved = reference.remove(element);
assertEquals("remove() should return true if value was present", refCount > 0, pqRemoved);
assertEquals("remove() should return true if value was present", refCount > 0, refRemoved);
int pqCountAfter = 0, refCountAfter = 0;
for (Integer val : pq) if (val.equals(element)) pqCountAfter++;
for (Integer val : reference) if (val.equals(element)) refCountAfter++;
assertEquals("Should remove only one instance (value)", Math.max(0, refCount - 1), refCountAfter);
assertEquals("Should remove only one instance (value)", Math.max(0, pqCount - 1), pqCountAfter);
assertEquals("pq and reference should match counts after removal", refCountAfter, pqCountAfter);
assertEquals("size after removal should match", reference.size(), pq.size());
Integer pqTop = pq.top();
Integer refTop = reference.peek();
if (pqTop != null && refTop != null) {
assertEquals("top() value difference after removal?", refTop.intValue(), pqTop.intValue());
} else {
assertEquals("top() value difference after removal?", refTop, pqTop);
}
}
pq.checkValidity();
}

/** Randomly add elements, comparing against the reference java.util.PriorityQueue by value. */
public void testInsertions() {
int maxElement = RandomNumbers.randomIntBetween(random(), 1, 10_000);
int size = maxElement / 2 + 1;
var reference = new java.util.PriorityQueue<Integer>();
var pq = new IntegerQueue(size);
Random localRandom = nonAssertingRandom(random());
HashMap<Integer, Integer> ints = new HashMap<>();
for (int i = 0, iters = size * 2; i < iters; i++) {
Integer element = ints.computeIfAbsent(localRandom.nextInt(maxElement), k -> k);

var action = localRandom.nextInt(100);
if (action < 25) {
// removals, possibly misses.
assertEquals("remove() difference: " + i, reference.remove(element), pq.remove(element));
var dropped = pq.insertWithOverflow(element);
reference.add(element);
Integer droppedReference;
if (reference.size() > size) {
droppedReference = reference.remove();
} else {
// additions.
var dropped = pq.insertWithOverflow(element);

reference.add(element);
Integer droppedReference;
if (reference.size() > size) {
droppedReference = reference.remove();
} else {
droppedReference = null;
}

assertEquals("insertWithOverflow() difference.", dropped, droppedReference);
droppedReference = null;
}
if (dropped != null && droppedReference != null) {
assertEquals("insertWithOverflow() dropped value difference.", dropped.intValue(), droppedReference.intValue());
} else {
assertEquals("insertWithOverflow() dropped value difference.", droppedReference, dropped);
}

assertEquals("insertWithOverflow() size difference?", reference.size(), pq.size());
assertEquals("top() difference?", reference.peek(), pq.top());
Integer pqTop = pq.top();
Integer refTop = reference.peek();
if (pqTop != null && refTop != null) {
assertEquals("top() value difference?", refTop.intValue(), pqTop.intValue());
} else {
assertEquals("top() value difference?", refTop, pqTop);
}
}

pq.checkValidity();
}

Expand Down
Loading