diff --git a/lucene/core/src/java/org/apache/lucene/index/ConcurrentApproximatePriorityQueue.java b/lucene/core/src/java/org/apache/lucene/index/ConcurrentApproximatePriorityQueue.java index 8a8fc72ab4c3..42b20fc39260 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ConcurrentApproximatePriorityQueue.java +++ b/lucene/core/src/java/org/apache/lucene/index/ConcurrentApproximatePriorityQueue.java @@ -16,6 +16,8 @@ */ package org.apache.lucene.index; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Predicate; @@ -116,20 +118,32 @@ T poll(Predicate predicate) { } } } - for (int i = 0; i < concurrency; ++i) { - final int index = (threadHash + i) % concurrency; - final Lock lock = locks[index]; - final ApproximatePriorityQueue queue = queues[index]; - lock.lock(); - try { + + // We want to make sure we return a non-null entry if this queue is not empty. This requires us + // to not release locks until we're done, otherwise if there is a single non-empty sub queue, as + // we iterate through all sub queues, there is a chance that an entry gets added to a queue we + // just checked and that the existing entry gets removed from a queue we haven't checked yet. + // This would make this method return `null` even though the queue was empty at no point in + // time. + + final List toUnlock = new ArrayList<>(); + try { + for (int index = 0; index < concurrency; ++index) { + final Lock lock = locks[index]; + final ApproximatePriorityQueue queue = queues[index]; + lock.lock(); + toUnlock.add(lock); T entry = queue.poll(predicate); if (entry != null) { return entry; } - } finally { + } + } finally { + for (Lock lock : toUnlock) { lock.unlock(); } } + return null; } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestConcurrentApproximatePriorityQueue.java b/lucene/core/src/test/org/apache/lucene/index/TestConcurrentApproximatePriorityQueue.java index 2656e4a38855..45543dd247f0 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestConcurrentApproximatePriorityQueue.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestConcurrentApproximatePriorityQueue.java @@ -100,4 +100,41 @@ public void run() { assertEquals(Integer.valueOf(3), pq.poll(x -> true)); assertNull(pq.poll(x -> true)); } + + public void testNeverReturnNullOnNonEmptyQueue() throws Exception { + final int iters = atLeast(10); + for (int iter = 0; iter < iters; ++iter) { + final int concurrency = TestUtil.nextInt(random(), 1, 16); + final ConcurrentApproximatePriorityQueue queue = + new ConcurrentApproximatePriorityQueue<>(concurrency); + final int numThreads = TestUtil.nextInt(random(), 2, 16); + final Thread[] threads = new Thread[numThreads]; + final CountDownLatch startingGun = new CountDownLatch(1); + for (int t = 0; t < threads.length; ++t) { + threads[t] = + new Thread( + () -> { + try { + startingGun.await(); + } catch (InterruptedException e) { + throw new ThreadInterruptedException(e); + } + Integer v = TestUtil.nextInt(random(), 0, 100); + queue.add(v, v); + for (int i = 0; i < 1_000; ++i) { + v = queue.poll(x -> true); + assertNotNull(v); + queue.add(v, v); + } + }); + } + for (Thread thread : threads) { + thread.start(); + } + startingGun.countDown(); + for (Thread thread : threads) { + thread.join(); + } + } + } }