Skip to content
18 changes: 16 additions & 2 deletions src/PoissonRandom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@ module PoissonRandom
using Random
using LogExpFunctions: log1pmx

export pois_rand
export pois_rand, PassthroughRNG

# GPU-compatible Poisson sampling via PassthroughRNG
struct PassthroughRNG <: AbstractRNG end

rand(rng::PassthroughRNG) = Random.rand()
randexp(rng::PassthroughRNG) = Random.randexp()
randn(rng::PassthroughRNG) = Random.randn()

rand(rng::AbstractRNG) = Random.rand(rng)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChrisRackauckas aren't these three here pirating from base?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No they are shadowing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh wait there is a direct using. ehhh

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should probably delete the last of these.

randexp(rng::AbstractRNG) = Random.randexp(rng)
randn(rng::AbstractRNG) = Random.randn(rng)

count_rand(λ) = count_rand(Random.GLOBAL_RNG, λ)
function count_rand(rng::AbstractRNG, λ)
Expand Down Expand Up @@ -88,7 +99,7 @@ function procf(λ, K::Int, s::Float64)

if K < 10
px = -float(λ)
py = λ^K / factorial(K)
py = λ^K / prod(2:K)
else
δ = inv(12) / K
δ -= 4.8 * δ^3
Expand Down Expand Up @@ -121,6 +132,9 @@ pois_rand(λ)
using RandomNumbers
rng = Xorshifts.Xoroshiro128Plus()
pois_rand(rng, λ)

# Simple Poisson random on GPU
pois_rand(PoissonRandom.PassthroughRNG(), λ)
```
"""
pois_rand(λ) = pois_rand(Random.GLOBAL_RNG, λ)
Expand Down
Loading