diff --git a/Project.toml b/Project.toml index 1f45cd5..1da995e 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] Aqua = "0.8" Distributions = "0.25" +JET = "0.9, 0.10, 0.11" LogExpFunctions = "0.3" Random = "1.10" Statistics = "1" @@ -18,8 +19,9 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Statistics", "Test", "Distributions"] +test = ["Aqua", "Statistics", "Test", "Distributions", "JET"] diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index 2a0a200..6db8f0c 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -12,8 +12,8 @@ Random.rand(rng::PassthroughRNG) = rand() Random.randexp(rng::PassthroughRNG) = randexp() Random.randn(rng::PassthroughRNG) = randn() -count_rand(λ) = count_rand(Random.GLOBAL_RNG, λ) -function count_rand(rng::AbstractRNG, λ) +count_rand(λ::Real) = count_rand(Random.GLOBAL_RNG, λ) +function count_rand(rng::AbstractRNG, λ::Real) n = 0 c = randexp(rng) while c < λ @@ -31,8 +31,8 @@ end # # For μ sufficiently large, (i.e. >= 10.0) # -ad_rand(λ) = ad_rand(Random.GLOBAL_RNG, λ) -function ad_rand(rng::AbstractRNG, λ) +ad_rand(λ::Real) = ad_rand(Random.GLOBAL_RNG, λ) +function ad_rand(rng::AbstractRNG, λ::Real) s = sqrt(λ) d = 6 * λ^2 L = floor(Int, λ - 1.1484) @@ -82,7 +82,7 @@ function ad_rand(rng::AbstractRNG, λ) end # Procedure F -function procf(λ, K::Int, s::Float64) +function procf(λ::Real, K::Int, s::Float64) # can be pre-computed, but does not seem to affect performance INV_SQRT_2PI = inv(sqrt(2pi)) ω = INV_SQRT_2PI / s @@ -133,7 +133,7 @@ pois_rand(rng, λ) pois_rand(PoissonRandom.PassthroughRNG(), λ) ``` """ -pois_rand(λ) = pois_rand(Random.GLOBAL_RNG, λ) -pois_rand(rng::AbstractRNG, λ) = λ < 6 ? count_rand(rng, λ) : ad_rand(rng, λ) +pois_rand(λ::Real) = pois_rand(Random.GLOBAL_RNG, λ) +pois_rand(rng::AbstractRNG, λ::Real) = λ < 6 ? count_rand(rng, λ) : ad_rand(rng, λ) end # module diff --git a/test/qa.jl b/test/qa.jl index 4a633e1..f0456be 100644 --- a/test/qa.jl +++ b/test/qa.jl @@ -1,4 +1,6 @@ -using PoissonRandom, Aqua +using PoissonRandom, Aqua, JET +using Random + @testset "Aqua" begin Aqua.find_persistent_tasks_deps(PoissonRandom) Aqua.test_ambiguities(PoissonRandom, recursive = false) @@ -10,3 +12,17 @@ using PoissonRandom, Aqua Aqua.test_unbound_args(PoissonRandom) Aqua.test_undefined_exports(PoissonRandom) end + +@testset "JET static analysis" begin + @testset "Type stability" begin + JET.@test_opt target_modules = (PoissonRandom,) pois_rand(10.0) + JET.@test_opt target_modules = (PoissonRandom,) pois_rand(Random.default_rng(), 10.0) + JET.@test_opt target_modules = (PoissonRandom,) pois_rand(PassthroughRNG(), 10.0) + end + + @testset "Error analysis" begin + JET.@test_call target_modules = (PoissonRandom,) pois_rand(10.0) + JET.@test_call target_modules = (PoissonRandom,) pois_rand(Random.default_rng(), 10.0) + JET.@test_call target_modules = (PoissonRandom,) pois_rand(PassthroughRNG(), 10.0) + end +end