diff --git a/lucene/core/src/java/org/apache/lucene/index/ConcurrentApproximatePriorityQueue.java b/lucene/core/src/java/org/apache/lucene/index/ConcurrentApproximatePriorityQueue.java index 8a8fc72ab4c3..c665904d8e05 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ConcurrentApproximatePriorityQueue.java +++ b/lucene/core/src/java/org/apache/lucene/index/ConcurrentApproximatePriorityQueue.java @@ -116,20 +116,28 @@ T poll(Predicate predicate) { } } } - for (int i = 0; i < concurrency; ++i) { - final int index = (threadHash + i) % concurrency; - final Lock lock = locks[index]; - final ApproximatePriorityQueue queue = queues[index]; + + // We want to make sure we return a non-null entry if this queue is not empty. This requires us + // to take all the locks at once, otherwise if there is a single non-empty sub queue, as we + // iterate through all sub queues, there is a chance that an entry gets added to a queue we just + // checked and that the existing entry gets removed from a queue we haven't checked yet. This + // would make this method return `null` even though the queue was empty at no point in time. + for (Lock lock : locks) { lock.lock(); - try { + } + try { + for (ApproximatePriorityQueue queue : queues) { T entry = queue.poll(predicate); if (entry != null) { return entry; } - } finally { + } + } finally { + for (Lock lock : locks) { lock.unlock(); } } + return null; } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestConcurrentApproximatePriorityQueue.java b/lucene/core/src/test/org/apache/lucene/index/TestConcurrentApproximatePriorityQueue.java index 2656e4a38855..45543dd247f0 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestConcurrentApproximatePriorityQueue.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestConcurrentApproximatePriorityQueue.java @@ -100,4 +100,41 @@ public void run() { assertEquals(Integer.valueOf(3), pq.poll(x -> true)); assertNull(pq.poll(x -> true)); } + + public void testNeverReturnNullOnNonEmptyQueue() throws Exception { + final int iters = atLeast(10); + for (int iter = 0; iter < iters; ++iter) { + final int concurrency = TestUtil.nextInt(random(), 1, 16); + final ConcurrentApproximatePriorityQueue queue = + new ConcurrentApproximatePriorityQueue<>(concurrency); + final int numThreads = TestUtil.nextInt(random(), 2, 16); + final Thread[] threads = new Thread[numThreads]; + final CountDownLatch startingGun = new CountDownLatch(1); + for (int t = 0; t < threads.length; ++t) { + threads[t] = + new Thread( + () -> { + try { + startingGun.await(); + } catch (InterruptedException e) { + throw new ThreadInterruptedException(e); + } + Integer v = TestUtil.nextInt(random(), 0, 100); + queue.add(v, v); + for (int i = 0; i < 1_000; ++i) { + v = queue.poll(x -> true); + assertNotNull(v); + queue.add(v, v); + } + }); + } + for (Thread thread : threads) { + thread.start(); + } + startingGun.countDown(); + for (Thread thread : threads) { + thread.join(); + } + } + } }