From 7f010a7668c0f52b0cd734f590fc0fb766bc2add Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Tue, 5 Aug 2025 06:34:08 +0530 Subject: [PATCH 01/12] added gpu version for pois_rand --- Project.toml | 2 + src/PoissonRandom.jl | 96 +++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 93 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index e8fbd31..9b72c8f 100644 --- a/Project.toml +++ b/Project.toml @@ -5,12 +5,14 @@ version = "0.4.5" [deps] LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] Aqua = "0.8" Distributions = "0.25" LogExpFunctions = "0.3" Random = "1.10" +SpecialFunctions = "2" Statistics = "1" Test = "1" julia = "1.10" diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index 392668b..d90122a 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -2,10 +2,25 @@ module PoissonRandom using Random using LogExpFunctions: log1pmx +using SpecialFunctions: loggamma export pois_rand -count_rand(λ) = count_rand(Random.GLOBAL_RNG, λ) +# GPU-compatible Poisson sampling +randexp(T::Type) = -log(rand(T)) +randexp() = randexp(Float64) + +function count_rand(λ) + λ = Float64(λ) + n = 0 + c = randexp(Float64) + while c < λ + n += 1 + c += randexp(Float64) + end + return n +end + function count_rand(rng::AbstractRNG, λ) n = 0 c = randexp(rng) @@ -24,7 +39,49 @@ end # # For μ sufficiently large, (i.e. >= 10.0) # -ad_rand(λ) = ad_rand(Random.GLOBAL_RNG, λ) +function ad_rand(λ) + λ = Float64(λ) + s = sqrt(λ) + d = 6.0 * λ^2 + L = floor(Int, λ - 1.1484) + + G = λ + s * randn() + + if G >= 0 + K = floor(Int, G) + if K >= L + return K + end + + U = rand() + if d * U >= (λ - K)^3 + return K + end + + px, py, fx, fy = procf(λ, K, s) + if fy * (1 - U) <= py * exp(px - fx) + return K + end + end + + while true + E = randexp() + U = 2 * rand() - 1 + T_val = 1.8 + copysign(E, U) + if T_val <= -0.6744 + continue + end + + K = floor(Int, λ + s * T_val) + px, py, fx, fy = procf(λ, K, s) + c = 0.1069 / λ + + @fastmath if c * abs(U) <= py * exp(px + E) - fy * exp(fx + E) + return K + end + end +end + function ad_rand(rng::AbstractRNG, λ) s = sqrt(λ) d = 6 * λ^2 @@ -75,6 +132,35 @@ function ad_rand(rng::AbstractRNG, λ) end # Procedure F +function procf(λ, K::Int, s::Float64) + INV_SQRT_2PI = 0.3989422804014327 # 1/sqrt(2π) + ω = INV_SQRT_2PI / s + b1 = 1 / (24 * λ) + b2 = 0.3 * b1^2 + c3 = b1 * b2 / 7 + c2 = b2 - 15 * c3 + c1 = b1 - 6 * b2 + 45 * c3 + c0 = 1 - b1 + 3 * b2 - 15 * c3 + + if K < 10 + px = -λ + log_py = K * log(λ) - loggamma(K + 1) # log(K!) via loggamma + py = exp(log_py) + else + δ = 1 / (12 * K) + δ -= 4.8 * δ^3 + V = (λ - K) / K + px = K * log1pmx(V) - δ + py = INV_SQRT_2PI / sqrt(K) + end + + X = (K - λ + 0.5) / s + X2 = X^2 + fx = -X2 / 2 + fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0) + return px, py, fx, fy +end + function procf(λ, K::Int, s::Float64) # can be pre-computed, but does not seem to affect performance INV_SQRT_2PI = inv(sqrt(2pi)) @@ -114,16 +200,16 @@ Generates Poisson(λ) distributed random numbers using a fast polyalgorithm. ## Examples ```julia -# Simple Poisson random +# Simple Poisson random which works on GPU pois_rand(λ) -# Using another RNG +# Using RNG using RandomNumbers rng = Xorshifts.Xoroshiro128Plus() pois_rand(rng, λ) ``` """ -pois_rand(λ) = pois_rand(Random.GLOBAL_RNG, λ) +pois_rand(λ) = λ < 6 ? count_rand(λ) : ad_rand(λ) pois_rand(rng::AbstractRNG, λ) = λ < 6 ? count_rand(rng, λ) : ad_rand(rng, λ) end # module From 56b172092e66d77504a0f10aab73e4f8bfcd4ee8 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 6 Aug 2025 02:55:42 +0530 Subject: [PATCH 02/12] added PassthroughRNG implementation --- src/PoissonRandom.jl | 109 ++++--------------------------------------- 1 file changed, 10 insertions(+), 99 deletions(-) diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index d90122a..409a211 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -6,27 +6,15 @@ using SpecialFunctions: loggamma export pois_rand -# GPU-compatible Poisson sampling -randexp(T::Type) = -log(rand(T)) -randexp() = randexp(Float64) - -function count_rand(λ) - λ = Float64(λ) - n = 0 - c = randexp(Float64) - while c < λ - n += 1 - c += randexp(Float64) - end - return n -end +# GPU-compatible Poisson sampling PassthroughRNG +struct PassthroughRNG <: AbstractRNG end function count_rand(rng::AbstractRNG, λ) n = 0 - c = randexp(rng) + c = rng isa PassthroughRNG ? randexp() : randexp(rng) while c < λ n += 1 - c += randexp(rng) + c += rng isa PassthroughRNG ? randexp() : randexp(rng) end return n end @@ -39,13 +27,13 @@ end # # For μ sufficiently large, (i.e. >= 10.0) # -function ad_rand(λ) +function ad_rand(rng::AbstractRNG, λ) λ = Float64(λ) s = sqrt(λ) d = 6.0 * λ^2 L = floor(Int, λ - 1.1484) - G = λ + s * randn() + G = λ + s * (rng isa PassthroughRNG ? randn() : randn(rng)) if G >= 0 K = floor(Int, G) @@ -53,7 +41,7 @@ function ad_rand(λ) return K end - U = rand() + U = rng isa PassthroughRNG ? rand() : rand(rng) if d * U >= (λ - K)^3 return K end @@ -65,8 +53,8 @@ function ad_rand(λ) end while true - E = randexp() - U = 2 * rand() - 1 + E = rng isa PassthroughRNG ? randexp() : randexp(rng) + U = 2 * (rng isa PassthroughRNG ? rand() : rand(rng)) - 1 T_val = 1.8 + copysign(E, U) if T_val <= -0.6744 continue @@ -82,55 +70,6 @@ function ad_rand(λ) end end -function ad_rand(rng::AbstractRNG, λ) - s = sqrt(λ) - d = 6 * λ^2 - L = floor(Int, λ - 1.1484) - # Step N - G = λ + s * randn(rng) - - if G >= 0 - K = floor(Int, G) - # Step I - if K >= L - return K - end - - # Step S - U = rand(rng) - if d * U >= (λ - K)^3 - return K - end - - # Step P - px, py, fx, fy = procf(λ, K, s) - - # Step Q - if fy * (1 - U) <= py * exp(px - fx) - return K - end - end - - while true - # Step E - E = randexp(rng) - U = 2 * rand(rng) - 1 - T = 1.8 + copysign(E, U) - if T <= -0.6744 - continue - end - - K = floor(Int, λ + s * T) - px, py, fx, fy = procf(λ, K, s) - c = 0.1069 / λ - - # Step H - @fastmath if c * abs(U) <= py * exp(px + E) - fy * exp(fx + E) - return K - end - end -end - # Procedure F function procf(λ, K::Int, s::Float64) INV_SQRT_2PI = 0.3989422804014327 # 1/sqrt(2π) @@ -161,34 +100,6 @@ function procf(λ, K::Int, s::Float64) return px, py, fx, fy end -function procf(λ, K::Int, s::Float64) - # can be pre-computed, but does not seem to affect performance - INV_SQRT_2PI = inv(sqrt(2pi)) - ω = INV_SQRT_2PI / s - b1 = inv(24) / λ - b2 = 0.3 * b1 * b1 - c3 = inv(7) * b1 * b2 - c2 = b2 - 15 * c3 - c1 = b1 - 6 * b2 + 45 * c3 - c0 = 1 - b1 + 3 * b2 - 15 * c3 - - if K < 10 - px = -float(λ) - py = λ^K / factorial(K) - else - δ = inv(12) / K - δ -= 4.8 * δ^3 - V = (λ - K) / K - px = K * log1pmx(V) - δ # avoids need for table - py = INV_SQRT_2PI / sqrt(K) - end - X = (K - λ + 0.5) / s - X2 = X^2 - fx = X2 / -2 # missing negation in pseudo-algorithm, but appears in fortran code. - fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0) - return px, py, fx, fy -end - """ ```julia pois_rand(λ) @@ -209,7 +120,7 @@ rng = Xorshifts.Xoroshiro128Plus() pois_rand(rng, λ) ``` """ -pois_rand(λ) = λ < 6 ? count_rand(λ) : ad_rand(λ) +pois_rand(λ) = λ < 6 ? count_rand(PassthroughRNG(), λ) : ad_rand(PassthroughRNG(), λ) pois_rand(rng::AbstractRNG, λ) = λ < 6 ? count_rand(rng, λ) : ad_rand(rng, λ) end # module From 2a452e7c4793e08dc8a1a9ba21351a8db73bbd29 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 6 Aug 2025 03:26:36 +0530 Subject: [PATCH 03/12] refactor --- src/PoissonRandom.jl | 45 +++++++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index 409a211..4cc3fab 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -4,17 +4,22 @@ using Random using LogExpFunctions: log1pmx using SpecialFunctions: loggamma -export pois_rand +export pois_rand, PassthroughRNG # GPU-compatible Poisson sampling PassthroughRNG struct PassthroughRNG <: AbstractRNG end +rand(rng::PassthroughRNG) = Random.rand() +randexp(rng::PassthroughRNG) = Random.randexp() +randn(rng::PassthroughRNG) = Random.randn() + +count_rand(λ) = count_rand(Random.GLOBAL_RNG, λ) function count_rand(rng::AbstractRNG, λ) n = 0 - c = rng isa PassthroughRNG ? randexp() : randexp(rng) + c = randexp(rng) while c < λ n += 1 - c += rng isa PassthroughRNG ? randexp() : randexp(rng) + c += randexp(rng) end return n end @@ -27,43 +32,50 @@ end # # For μ sufficiently large, (i.e. >= 10.0) # +ad_rand(λ) = ad_rand(Random.GLOBAL_RNG, λ) function ad_rand(rng::AbstractRNG, λ) - λ = Float64(λ) s = sqrt(λ) - d = 6.0 * λ^2 + d = 6 * λ^2 L = floor(Int, λ - 1.1484) - - G = λ + s * (rng isa PassthroughRNG ? randn() : randn(rng)) + # Step N + G = λ + s * randn(rng) if G >= 0 K = floor(Int, G) + # Step I if K >= L return K end - U = rng isa PassthroughRNG ? rand() : rand(rng) + # Step S + U = rand(rng) if d * U >= (λ - K)^3 return K end + # Step P px, py, fx, fy = procf(λ, K, s) + + # Step Q if fy * (1 - U) <= py * exp(px - fx) return K end end while true - E = rng isa PassthroughRNG ? randexp() : randexp(rng) - U = 2 * (rng isa PassthroughRNG ? rand() : rand(rng)) - 1 - T_val = 1.8 + copysign(E, U) - if T_val <= -0.6744 + # Step E + E = randexp(rng) + U = 2 * rand(rng) - 1 + T = 1.8 + copysign(E, U) + if T <= -0.6744 continue end - K = floor(Int, λ + s * T_val) + K = floor(Int, λ + s * T) px, py, fx, fy = procf(λ, K, s) c = 0.1069 / λ + # Step H @fastmath if c * abs(U) <= py * exp(px + E) - fy * exp(fx + E) return K end @@ -89,13 +101,12 @@ function procf(λ, K::Int, s::Float64) δ = 1 / (12 * K) δ -= 4.8 * δ^3 V = (λ - K) / K - px = K * log1pmx(V) - δ + px = K * log1pmx(V) - δ # avoids need for table py = INV_SQRT_2PI / sqrt(K) end - X = (K - λ + 0.5) / s X2 = X^2 - fx = -X2 / 2 + fx = -X2 / 2 # missing negation in pseudo-algorithm, but appears in fortran code. fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0) return px, py, fx, fy end @@ -120,7 +131,7 @@ rng = Xorshifts.Xoroshiro128Plus() pois_rand(rng, λ) ``` """ -pois_rand(λ) = λ < 6 ? count_rand(PassthroughRNG(), λ) : ad_rand(PassthroughRNG(), λ) +pois_rand(λ) = pois_rand(Random.GLOBAL_RNG, λ) pois_rand(rng::AbstractRNG, λ) = λ < 6 ? count_rand(rng, λ) : ad_rand(rng, λ) end # module From f94dd83d3ef109a84a3dd6a09a8ff9f20d6f9a0b Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 6 Aug 2025 03:30:49 +0530 Subject: [PATCH 04/12] refactor 2 --- src/PoissonRandom.jl | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index 4cc3fab..3511511 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -84,29 +84,30 @@ end # Procedure F function procf(λ, K::Int, s::Float64) - INV_SQRT_2PI = 0.3989422804014327 # 1/sqrt(2π) + # can be pre-computed, but does not seem to affect performance + INV_SQRT_2PI = inv(sqrt(2pi)) ω = INV_SQRT_2PI / s - b1 = 1 / (24 * λ) - b2 = 0.3 * b1^2 - c3 = b1 * b2 / 7 + b1 = inv(24) / λ + b2 = 0.3 * b1 * b1 + c3 = inv(7) * b1 * b2 c2 = b2 - 15 * c3 c1 = b1 - 6 * b2 + 45 * c3 c0 = 1 - b1 + 3 * b2 - 15 * c3 if K < 10 px = -λ - log_py = K * log(λ) - loggamma(K + 1) # log(K!) via loggamma + log_py = K * log(λ) - loggamma(K + 1) # log(K!) via loggamma py = exp(log_py) else - δ = 1 / (12 * K) + δ = inv(12) / K δ -= 4.8 * δ^3 V = (λ - K) / K - px = K * log1pmx(V) - δ # avoids need for table + px = K * log1pmx(V) - δ # avoids need for table py = INV_SQRT_2PI / sqrt(K) end X = (K - λ + 0.5) / s X2 = X^2 - fx = -X2 / 2 # missing negation in pseudo-algorithm, but appears in fortran code. + fx = X2 / -2 # missing negation in pseudo-algorithm, but appears in fortran code. fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0) return px, py, fx, fy end From 3a48a5cdebe129d632b5a2ddd4ab1b35d714d3f6 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 6 Aug 2025 03:34:22 +0530 Subject: [PATCH 05/12] refactor 3 --- src/PoissonRandom.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index 3511511..5ac9f24 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -123,13 +123,16 @@ Generates Poisson(λ) distributed random numbers using a fast polyalgorithm. ## Examples ```julia -# Simple Poisson random which works on GPU +# Simple Poisson random pois_rand(λ) -# Using RNG +# Using another RNG using RandomNumbers rng = Xorshifts.Xoroshiro128Plus() pois_rand(rng, λ) + +# Simple Poisson random on GPU +pois_rand(PassthroughRNG(), λ) ``` """ pois_rand(λ) = pois_rand(Random.GLOBAL_RNG, λ) From 2bd65fe8aa9308f4188e4422c6f3c0f4f7cdda80 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 6 Aug 2025 03:44:22 +0530 Subject: [PATCH 06/12] typo --- src/PoissonRandom.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index 5ac9f24..0a4ef37 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -6,7 +6,7 @@ using SpecialFunctions: loggamma export pois_rand, PassthroughRNG -# GPU-compatible Poisson sampling PassthroughRNG +# GPU-compatible Poisson sampling via PassthroughRNG struct PassthroughRNG <: AbstractRNG end rand(rng::PassthroughRNG) = Random.rand() From 3e3e1f5718704df79ebe14514f4ffc3d9b109923 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan D N <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Wed, 6 Aug 2025 03:59:28 +0530 Subject: [PATCH 07/12] Update src/PoissonRandom.jl Co-authored-by: Christopher Rackauckas --- src/PoissonRandom.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index 0a4ef37..505dd0d 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -132,7 +132,7 @@ rng = Xorshifts.Xoroshiro128Plus() pois_rand(rng, λ) # Simple Poisson random on GPU -pois_rand(PassthroughRNG(), λ) +pois_rand(PoissonRandom.PassthroughRNG(), λ) ``` """ pois_rand(λ) = pois_rand(Random.GLOBAL_RNG, λ) From d7733b1df269191469a8aabc44ea464cda21df27 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 6 Aug 2025 04:02:37 +0530 Subject: [PATCH 08/12] minor change --- src/PoissonRandom.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index 505dd0d..3387a05 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -95,7 +95,7 @@ function procf(λ, K::Int, s::Float64) c0 = 1 - b1 + 3 * b2 - 15 * c3 if K < 10 - px = -λ + px = -float(λ) log_py = K * log(λ) - loggamma(K + 1) # log(K!) via loggamma py = exp(log_py) else From 3b6317435ba2ac6b544d97a6aa985e0917a8bfd9 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 6 Aug 2025 05:19:51 +0530 Subject: [PATCH 09/12] added rand functions to AbstractRNG --- src/PoissonRandom.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index 3387a05..ef4be41 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -13,6 +13,10 @@ rand(rng::PassthroughRNG) = Random.rand() randexp(rng::PassthroughRNG) = Random.randexp() randn(rng::PassthroughRNG) = Random.randn() +rand(rng::AbstractRNG) = Random.rand(rng) +randexp(rng::AbstractRNG) = Random.randexp(rng) +randn(rng::AbstractRNG) = Random.randn(rng) + count_rand(λ) = count_rand(Random.GLOBAL_RNG, λ) function count_rand(rng::AbstractRNG, λ) n = 0 From a3dce736af8cb3c1fe3e2827ec6f16036fc47dd1 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 6 Aug 2025 05:42:57 +0530 Subject: [PATCH 10/12] factorial calculation fixed --- src/PoissonRandom.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index ef4be41..cceef40 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -100,8 +100,7 @@ function procf(λ, K::Int, s::Float64) if K < 10 px = -float(λ) - log_py = K * log(λ) - loggamma(K + 1) # log(K!) via loggamma - py = exp(log_py) + py = λ^K / prod(1:K) else δ = inv(12) / K δ -= 4.8 * δ^3 From ae49431e4036572ef498d79b1fc92c7acc35a78c Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 6 Aug 2025 05:48:31 +0530 Subject: [PATCH 11/12] optimization in factorial(K) --- src/PoissonRandom.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index cceef40..e74bc17 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -100,7 +100,7 @@ function procf(λ, K::Int, s::Float64) if K < 10 px = -float(λ) - py = λ^K / prod(1:K) + py = λ^K / prod(2:K) else δ = inv(12) / K δ -= 4.8 * δ^3 From 14e39406d4bc61dcf7bc23a9306e5edd9cc1b952 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Wed, 6 Aug 2025 05:52:02 +0530 Subject: [PATCH 12/12] removed SpecialFunctions --- Project.toml | 2 -- src/PoissonRandom.jl | 1 - 2 files changed, 3 deletions(-) diff --git a/Project.toml b/Project.toml index 9b72c8f..e8fbd31 100644 --- a/Project.toml +++ b/Project.toml @@ -5,14 +5,12 @@ version = "0.4.5" [deps] LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] Aqua = "0.8" Distributions = "0.25" LogExpFunctions = "0.3" Random = "1.10" -SpecialFunctions = "2" Statistics = "1" Test = "1" julia = "1.10" diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index e74bc17..adb05d6 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -2,7 +2,6 @@ module PoissonRandom using Random using LogExpFunctions: log1pmx -using SpecialFunctions: loggamma export pois_rand, PassthroughRNG