From 7d1cfdfe77446824da184a488139362bbf17bff5 Mon Sep 17 00:00:00 2001
From: Marcus Better
Date: Tue, 31 Dec 2024 12:27:11 -0500
Subject: [PATCH] Optimize sampling from IndexedSeq

Indexed sequences allow us to skip over items without examining each one.

---
 .../ReservoirSamplingBenchmark.scala          |  4 ++
 .../algebird/mutable/ReservoirSampling.scala  | 70 +++++++++++++++----
 .../mutable/ReservoirSamplingTest.scala       | 11 ++-
 3 files changed, 72 insertions(+), 13 deletions(-)

diff --git a/algebird-benchmark/src/main/scala/com/twitter/algebird/benchmark/ReservoirSamplingBenchmark.scala b/algebird-benchmark/src/main/scala/com/twitter/algebird/benchmark/ReservoirSamplingBenchmark.scala
index 41e0bf884..4ac53a568 100644
--- a/algebird-benchmark/src/main/scala/com/twitter/algebird/benchmark/ReservoirSamplingBenchmark.scala
+++ b/algebird-benchmark/src/main/scala/com/twitter/algebird/benchmark/ReservoirSamplingBenchmark.scala
@@ -36,6 +36,10 @@ class ReservoirSamplingBenchmark {
   def timeAlgorithmL(state: BenchmarkState, bh: Blackhole): Unit =
     bh.consume(new ReservoirSamplingToListAggregator[Int](state.samples).apply(0 until state.collectionSize))
 
+  @Benchmark
+  def timeAlgorithmLSeq(state: BenchmarkState, bh: Blackhole): Unit =
+    bh.consume(new ReservoirSamplingToListAggregator[Int](state.samples).apply((0 until state.collectionSize).asInstanceOf[Seq[Int]]))
+
   @Benchmark
   def timePriorityQeueue(state: BenchmarkState, bh: Blackhole): Unit =
     bh.consume(prioQueueSampler(state.samples).apply(0 until state.collectionSize))
diff --git a/algebird-core/src/main/scala/com/twitter/algebird/mutable/ReservoirSampling.scala b/algebird-core/src/main/scala/com/twitter/algebird/mutable/ReservoirSampling.scala
index a80a4c32e..6f2e3b1ba 100644
--- a/algebird-core/src/main/scala/com/twitter/algebird/mutable/ReservoirSampling.scala
+++ b/algebird-core/src/main/scala/com/twitter/algebird/mutable/ReservoirSampling.scala
@@ -14,7 +14,7 @@ import scala.util.Random
  *   the element type
  */
 sealed class Reservoir[T](val capacity: Int) {
-  var reservoir: mutable.Buffer[T] = mutable.Buffer()
+  var reservoir: mutable.ArrayBuffer[T] = new mutable.ArrayBuffer
 
   // When the reservoir is full, w is the threshold for accepting an element into the reservoir, and
   // the following invariant holds: The maximum score of the elements in the reservoir is w,
@@ -52,6 +52,13 @@
     }
   }
 
+  // The number of items to skip before accepting the next item is geometrically distributed
+  // with probability of success w / prior. The prior will be 1 when adding to a single reservoir,
+  // but when merging reservoirs it will be the threshold of the reservoir being pulled from,
+  // and in this case we require that w < prior.
+  private def nextAcceptTime(rng: Random, prior: Double = 1.0): Int =
+    (-rng.self.nextExponential / Math.log1p(-w / prior)).toInt
+
   /**
    * Add multiple elements to the reservoir.
    * @param xs
@@ -64,26 +71,55 @@
    * @return
    *   this reservoir
    */
-  def append(xs: TraversableOnce[T], rng: Random, prior: Double = 1): Reservoir[T] = {
-    // The number of items to skip before accepting the next item is geometrically distributed
-    // with probability of success w / prior. The prior will be 1 when adding to a single reservoir,
-    // but when merging reservoirs it will be the threshold of the reservoir being pulled from,
-    // and in this case we require that w < prior.
-    def nextAcceptTime = (-rng.self.nextExponential / Math.log1p(-w / prior)).toInt
-
-    var skip = if (isFull) nextAcceptTime else 0
+  def append(xs: TraversableOnce[T], rng: Random): Reservoir[T] = {
+    var skip = if (isFull) nextAcceptTime(rng) else 0
     for (x <- xs) {
       if (!isFull) {
         // keep adding while reservoir is not full
         accept(x, rng)
         if (isFull) {
-          skip = nextAcceptTime
+          skip = nextAcceptTime(rng)
         }
       } else if (skip > 0) {
         skip -= 1
       } else {
         accept(x, rng)
-        skip = nextAcceptTime
+        skip = nextAcceptTime(rng)
       }
     }
     this
+  }
+
+  /**
+   * Add multiple elements to the reservoir. This overload is optimized for indexed sequences, where we can
+   * skip over multiple indexes without accessing the elements.
+   *
+   * @param xs
+   *   the elements to add
+   * @param rng
+   *   the random source
+   * @param prior
+   *   the threshold of the elements being added, such that the added element's value is distributed as
+   *   U[0, prior]
+   * @return
+   *   this reservoir
+   */
+  def append(xs: IndexedSeq[T], rng: Random, prior: Double): Reservoir[T] = {
+    var i = xs.size.min(capacity - size)
+    for (j <- 0 until i) {
+      accept(xs(j), rng)
+    }
+    assert(isFull)
+
+    val end = xs.size
+    i -= 1
+    while (i >= 0 && i < end) {
+      i += 1 + nextAcceptTime(rng, prior)
+      // the addition can overflow, in which case i < 0
+      if (i >= 0 && i < end) {
+        // element enters the reservoir
+        reservoir(rng.nextInt(capacity)) = xs(i)
+        w *= Math.pow(rng.nextDouble, kInv)
+      }
+    }
+    this
@@ -147,7 +183,7 @@ class ReservoirMonoid[T](implicit val randomSupplier: () => Random) extends Mono
         s2.reservoir(i) = s2.reservoir.head
         s1.append(s2.reservoir.drop(1), rng, s2.w)
       } else {
-        s1.append(s2.reservoir, rng)
+        s1.append(s2.reservoir, rng, 1.0)
       }
     }
   }
@@ -157,6 +193,10 @@ class ReservoirMonoid[T](implicit val randomSupplier: () => Random) extends Mono
  * reservoir is mutable, it is a good idea to copy the result to an immutable view before using it, as is done
  * by [[ReservoirSamplingToListAggregator]].
  *
+ * The aggregator defines operations for [[IndexedSeq]]s that allow for more efficient aggregation; however,
+ * care must be taken with methods such as [[composePrepare()]] which return a regular [[MonoidAggregator]]
+ * that loses this optimized behavior.
+ *
  * @param k
  *   the number of elements to sample
  * @param randomSupplier
@@ -172,6 +212,7 @@ abstract class ReservoirSamplingAggregator[T, +C](k: Int)(implicit val randomSup
   override def prepare(x: T): Reservoir[T] = monoid.build(k, x)
 
   override def apply(xs: TraversableOnce[T]): C = present(agg(xs))
+  def apply(xs: IndexedSeq[T]): C = present(agg(xs))
 
   override def applyOption(inputs: TraversableOnce[T]): Option[C] =
     if (inputs.isEmpty) None else Some(apply(inputs))
@@ -180,11 +221,16 @@ abstract class ReservoirSamplingAggregator[T, +C](k: Int)(implicit val randomSup
 
   override def appendAll(r: Reservoir[T], xs: TraversableOnce[T]): Reservoir[T] =
     r.append(xs, randomSupplier())
+  def appendAll(r: Reservoir[T], xs: IndexedSeq[T]): Reservoir[T] =
+    r.append(xs, randomSupplier(), 1.0)
 
   override def appendAll(xs: TraversableOnce[T]): Reservoir[T] = agg(xs)
+  def appendAll(xs: IndexedSeq[T]): Reservoir[T] = agg(xs)
 
   private def agg(xs: TraversableOnce[T]): Reservoir[T] =
     appendAll(monoid.zero(k), xs)
+  private def agg(xs: IndexedSeq[T]): Reservoir[T] =
+    appendAll(monoid.zero(k), xs)
 }
 
 class ReservoirSamplingToListAggregator[T](k: Int)(implicit randomSupplier: () => Random)
diff --git a/algebird-test/src/test/scala/com/twitter/algebird/mutable/ReservoirSamplingTest.scala b/algebird-test/src/test/scala/com/twitter/algebird/mutable/ReservoirSamplingTest.scala
index 5d29a0085..6fd0e5915 100644
--- a/algebird-test/src/test/scala/com/twitter/algebird/mutable/ReservoirSamplingTest.scala
+++ b/algebird-test/src/test/scala/com/twitter/algebird/mutable/ReservoirSamplingTest.scala
@@ -1,7 +1,9 @@
 package com.twitter.algebird.mutable
 
-import com.twitter.algebird.{Aggregator, CheckProperties, Preparer}
 import com.twitter.algebird.RandomSamplingLaws._
+import com.twitter.algebird.scalacheck.Distribution.{forAllSampled, uniform}
+import com.twitter.algebird.{Aggregator, CheckProperties, Preparer}
+import org.scalacheck.Gen
 
 import scala.util.Random
 
@@ -23,4 +25,11 @@ class ReservoirSamplingTest extends CheckProperties {
   property("reservoir sampling with priority queue works") {
     randomSamplingDistributions(prioQueueSampler)
   }
+
+  property("sampling from non-indexed Seq") {
+    val n = 100
+    "sampleList" |: forAllSampled(10000, Gen.choose(1, 20))(_ => uniform(n)) { k =>
+      new ReservoirSamplingToListAggregator[Int](k).apply((0 until n).asInstanceOf[Seq[Int]]).head
+    }
+  }
 }
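
Usage note (illustrative, not part of the patch): the sketch below assumes this patch has been applied and algebird-core is on the classpath; the object name, seed, and collection sizes are invented for the example. It shows that the optimized path is chosen purely by the static type of the argument, which is also why the benchmark and test above cast a Range to Seq[Int] to force the element-by-element path.

    import scala.util.Random

    import com.twitter.algebird.mutable.ReservoirSamplingToListAggregator

    object ReservoirOverloadSketch {
      def main(args: Array[String]): Unit = {
        // The aggregator draws its randomness from an implicit supplier; the fixed
        // seed here only makes the sketch reproducible.
        implicit val randomSupplier: () => Random = () => new Random(42L)

        val sampler = new ReservoirSamplingToListAggregator[Int](10)

        // A Range is an IndexedSeq, so overload resolution picks the IndexedSeq
        // apply added by this patch, which skips over runs of rejected indexes
        // without reading the skipped elements.
        val viaIndexed = sampler(0 until 1000000)

        // Ascribing Seq[Int] hides the IndexedSeq type, so the TraversableOnce
        // overload runs instead and every element is examined (the path that
        // timeAlgorithmLSeq measures).
        val viaSeq = sampler((0 until 1000000): Seq[Int])

        println(s"indexed path sample: $viaIndexed")
        println(s"seq path sample:     $viaSeq")
      }
    }

Both calls draw samples with the same distribution; only the work done per skipped element differs, which is what the timeAlgorithmL and timeAlgorithmLSeq benchmarks compare.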