From dba8ba32a0851730b058d7a1b9ebe8523e669be0 Mon Sep 17 00:00:00 2001 From: Tortar <68152031+Tortar@users.noreply.github.com> Date: Mon, 1 Jan 2024 22:16:53 +0100 Subject: [PATCH 01/11] Faster algorithm for ordered sampling with replacement --- src/sampling.jl | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/sampling.jl b/src/sampling.jl index 610dde99b..9b854a128 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -101,6 +101,29 @@ sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractArray, sampler!(rng, a, wv, x) end +""" +This algorithm generates a sorted sample with replacement by +adapting the classic result that the cumulative sum of n+1 +exponentially-distributed random numbers divided by the overall sum +(dropping the last) is a sorted sample from a uniform[0,1] +""" +function uniform_orderstat_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) + 1 == firstindex(a) == firstindex(x) || + throw(ArgumentError("non 1-based arrays are not supported")) + Base.mightalias(a, x) && + throw(ArgumentError("output array x must not share memory with input array a")) + n = length(a) + k = length(x) + exp_rands = randexp(rng, k) + sorted_rands = accumulate!(+, exp_rands, exp_rands) + cum_step = (sorted_rands[end] + randexp(rng)) / n + @inbounds for i in eachindex(x) + j = ceil(Int, sorted_rands[i] / cum_step) + x[i] = a[j] + end + return x +end + ### draw a pair of distinct integers in [1:n] """ @@ -499,7 +522,11 @@ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; if replace # with replacement if ordered - sample_ordered!(direct_sample!, rng, a, x) + if k < 10 + sample_ordered!(direct_sample!, rng, a, x) + else + uniform_orderstat_sample!(rng, a, x) + end else direct_sample!(rng, a, x) end From c922c90f9257cd1fee8d8ca74225fd6bb3896532 Mon Sep 17 00:00:00 2001 From: Tortar <68152031+Tortar@users.noreply.github.com> Date: Mon, 1 Jan 2024 23:00:25 +0100 Subject: [PATCH 02/11] Update sampling.jl --- src/sampling.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sampling.jl b/src/sampling.jl index 9b854a128..b71181789 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -522,10 +522,10 @@ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; if replace # with replacement if ordered - if k < 10 - sample_ordered!(direct_sample!, rng, a, x) - else + if k > 10 uniform_orderstat_sample!(rng, a, x) + else + sample_ordered!(direct_sample!, rng, a, x) end else direct_sample!(rng, a, x) From a7ef9726654e2d6486f4415f3a9244c2108b5a3b Mon Sep 17 00:00:00 2001 From: Tortar <68152031+Tortar@users.noreply.github.com> Date: Mon, 1 Jan 2024 23:11:09 +0100 Subject: [PATCH 03/11] let's keep only uniform_orderstat_sample! to see if it is numerically stable --- src/sampling.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/sampling.jl b/src/sampling.jl index b71181789..8160a1652 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -522,11 +522,7 @@ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; if replace # with replacement if ordered - if k > 10 - uniform_orderstat_sample!(rng, a, x) - else - sample_ordered!(direct_sample!, rng, a, x) - end + uniform_orderstat_sample!(rng, a, x) else direct_sample!(rng, a, x) end From b424eaf76df5a13c8c972eadfa90dc989c48bae2 Mon Sep 17 00:00:00 2001 From: Tortar <68152031+Tortar@users.noreply.github.com> Date: Mon, 1 Jan 2024 23:37:12 +0100 Subject: [PATCH 04/11] previous methodology then --- src/sampling.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/sampling.jl b/src/sampling.jl index 8160a1652..987f51a4b 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -522,7 +522,11 @@ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; if replace # with replacement if ordered - uniform_orderstat_sample!(rng, a, x) + if k <= 10 + sample_ordered!(direct_sample!, rng, a, x) + else + uniform_orderdist_sample!(rng, a, x) + end else direct_sample!(rng, a, x) end From 3ae19bfe3fc32e859d0ba37c81c3e0cf2e30d7be Mon Sep 17 00:00:00 2001 From: Tortar <68152031+Tortar@users.noreply.github.com> Date: Mon, 1 Jan 2024 23:46:04 +0100 Subject: [PATCH 05/11] Update sampling.jl --- test/sampling.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/sampling.jl b/test/sampling.jl index 34b339dec..a30051eaa 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -77,11 +77,11 @@ check_sample_wrep(a, (3, 12), 5.0e-3; ordered=false) for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int}) r = rev ? reverse(3:12) : (3:12) r = T===Int ? r : T.(r) - aa = Int.(sample(r, n; ordered=true)) - check_sample_wrep(aa, (3, 12), 5.0e-3; ordered=true, rev=rev) - aa = Int.(sample(r, 10; ordered=true)) check_sample_wrep(aa, (3, 12), 0; ordered=true, rev=rev) + + aa = Int.(sample(r, n; ordered=true)) + check_sample_wrep(aa, (3, 12), 5.0e-3; ordered=true, rev=rev) end @test StatsBase._storeindices(1, 1, BigFloat) == StatsBase._storeindices(1, 1, BigFloat) == false From c5d90b38f74196d30f08d9be7eb2b6a7ca036cb7 Mon Sep 17 00:00:00 2001 From: Tortar <68152031+Tortar@users.noreply.github.com> Date: Mon, 1 Jan 2024 23:56:44 +0100 Subject: [PATCH 06/11] use better test for small ordered sampling --- test/sampling.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/sampling.jl b/test/sampling.jl index a30051eaa..50fcc48ee 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -77,11 +77,15 @@ check_sample_wrep(a, (3, 12), 5.0e-3; ordered=false) for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int}) r = rev ? reverse(3:12) : (3:12) r = T===Int ? r : T.(r) - aa = Int.(sample(r, 10; ordered=true)) - check_sample_wrep(aa, (3, 12), 0; ordered=true, rev=rev) - aa = Int.(sample(r, n; ordered=true)) check_sample_wrep(aa, (3, 12), 5.0e-3; ordered=true, rev=rev) + + aa = Int[] + for i in 1:Int(n/10) + bb = Int.(sample(r, 10; ordered=true)) + append!(aa, bb) + end + check_sample_wrep(aa, (3, 12), 5.0e-3; ordered=false, rev=rev) end @test StatsBase._storeindices(1, 1, BigFloat) == StatsBase._storeindices(1, 1, BigFloat) == false From e032965e416c4dbd9e7d23c45290f7bfdd2cb087 Mon Sep 17 00:00:00 2001 From: Tortar <68152031+Tortar@users.noreply.github.com> Date: Tue, 2 Jan 2024 00:01:10 +0100 Subject: [PATCH 07/11] Update sampling.jl --- src/sampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sampling.jl b/src/sampling.jl index 987f51a4b..f2c5e3338 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -525,7 +525,7 @@ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; if k <= 10 sample_ordered!(direct_sample!, rng, a, x) else - uniform_orderdist_sample!(rng, a, x) + uniform_orderstat_sample!(rng, a, x) end else direct_sample!(rng, a, x) From 397b8b37ff5624bb508086950d25319862cc6640 Mon Sep 17 00:00:00 2001 From: Tortar <68152031+Tortar@users.noreply.github.com> Date: Tue, 2 Jan 2024 00:01:35 +0100 Subject: [PATCH 08/11] Update sampling.jl --- test/sampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/sampling.jl b/test/sampling.jl index 50fcc48ee..841893b4e 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -85,7 +85,7 @@ for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64 bb = Int.(sample(r, 10; ordered=true)) append!(aa, bb) end - check_sample_wrep(aa, (3, 12), 5.0e-3; ordered=false, rev=rev) + check_sample_wrep(sort!(aa), (3, 12), 5.0e-3; ordered=true, rev=rev) end @test StatsBase._storeindices(1, 1, BigFloat) == StatsBase._storeindices(1, 1, BigFloat) == false From 1e9c9ecd9cce820ab24b7bc9a307b4e87ce3fb2a Mon Sep 17 00:00:00 2001 From: Tortar <68152031+Tortar@users.noreply.github.com> Date: Tue, 2 Jan 2024 00:03:35 +0100 Subject: [PATCH 09/11] Update sampling.jl --- test/sampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/sampling.jl b/test/sampling.jl index 841893b4e..f2c3003b6 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -85,7 +85,7 @@ for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64 bb = Int.(sample(r, 10; ordered=true)) append!(aa, bb) end - check_sample_wrep(sort!(aa), (3, 12), 5.0e-3; ordered=true, rev=rev) + check_sample_wrep(sort!(aa, rev=rev), (3, 12), 5.0e-3; ordered=true, rev=rev) end @test StatsBase._storeindices(1, 1, BigFloat) == StatsBase._storeindices(1, 1, BigFloat) == false From 2a4683cead6a50b553a2798838b42d7d83f5b906 Mon Sep 17 00:00:00 2001 From: Tortar <68152031+Tortar@users.noreply.github.com> Date: Tue, 2 Jan 2024 00:45:21 +0100 Subject: [PATCH 10/11] try stablerng(1) --- test/sampling.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/sampling.jl b/test/sampling.jl index f2c3003b6..5630bdcc1 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -74,15 +74,16 @@ test_rng_use(direct_sample!, 1:10, zeros(Int, 6)) a = sample(3:12, n) check_sample_wrep(a, (3, 12), 5.0e-3; ordered=false) +rng = StableRNG(1) for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int}) r = rev ? reverse(3:12) : (3:12) r = T===Int ? r : T.(r) - aa = Int.(sample(r, n; ordered=true)) + aa = Int.(sample(rng, r, n; ordered=true)) check_sample_wrep(aa, (3, 12), 5.0e-3; ordered=true, rev=rev) aa = Int[] for i in 1:Int(n/10) - bb = Int.(sample(r, 10; ordered=true)) + bb = Int.(sample(rng, r, 10; ordered=true)) append!(aa, bb) end check_sample_wrep(sort!(aa, rev=rev), (3, 12), 5.0e-3; ordered=true, rev=rev) From 5367e69cd4fa47cfc0e87a85b65d9e7d95c55ebb Mon Sep 17 00:00:00 2001 From: Tortar <68152031+Tortar@users.noreply.github.com> Date: Sat, 13 Apr 2024 12:06:55 +0200 Subject: [PATCH 11/11] use cumsum --- src/sampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sampling.jl b/src/sampling.jl index f2c5e3338..c19b40ecd 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -115,7 +115,7 @@ function uniform_orderstat_sample!(rng::AbstractRNG, a::AbstractArray, x::Abstra n = length(a) k = length(x) exp_rands = randexp(rng, k) - sorted_rands = accumulate!(+, exp_rands, exp_rands) + sorted_rands = cumsum(exp_rands) cum_step = (sorted_rands[end] + randexp(rng)) / n @inbounds for i in eachindex(x) j = ceil(Int, sorted_rands[i] / cum_step)