diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index 392668b..adb05d6 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -3,7 +3,18 @@ module PoissonRandom using Random using LogExpFunctions: log1pmx -export pois_rand +export pois_rand, PassthroughRNG + +# GPU-compatible Poisson sampling via PassthroughRNG +struct PassthroughRNG <: AbstractRNG end + +rand(rng::PassthroughRNG) = Random.rand() +randexp(rng::PassthroughRNG) = Random.randexp() +randn(rng::PassthroughRNG) = Random.randn() + +rand(rng::AbstractRNG) = Random.rand(rng) +randexp(rng::AbstractRNG) = Random.randexp(rng) +randn(rng::AbstractRNG) = Random.randn(rng) count_rand(λ) = count_rand(Random.GLOBAL_RNG, λ) function count_rand(rng::AbstractRNG, λ) @@ -88,7 +99,7 @@ function procf(λ, K::Int, s::Float64) if K < 10 px = -float(λ) - py = λ^K / factorial(K) + py = λ^K / prod(2:K) else δ = inv(12) / K δ -= 4.8 * δ^3 @@ -121,6 +132,9 @@ pois_rand(λ) using RandomNumbers rng = Xorshifts.Xoroshiro128Plus() pois_rand(rng, λ) + +# Simple Poisson random on GPU +pois_rand(PoissonRandom.PassthroughRNG(), λ) ``` """ pois_rand(λ) = pois_rand(Random.GLOBAL_RNG, λ)