diff --git a/Project.toml b/Project.toml index 29123b5..d578dcb 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.4.7" [deps] LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] @@ -13,6 +14,7 @@ Distributions = "0.25" ExplicitImports = "1.14.0" JET = "0.9, 0.10, 0.11" LogExpFunctions = "0.3" +PrecompileTools = "1" Random = "1.10" Statistics = "1" Test = "1" diff --git a/src/PoissonRandom.jl b/src/PoissonRandom.jl index 8c4509b..94c9418 100644 --- a/src/PoissonRandom.jl +++ b/src/PoissonRandom.jl @@ -2,6 +2,7 @@ module PoissonRandom using Random: Random, AbstractRNG, randexp using LogExpFunctions: log1pmx +using PrecompileTools: @compile_workload export pois_rand, PassthroughRNG @@ -136,4 +137,14 @@ pois_rand(PoissonRandom.PassthroughRNG(), λ) pois_rand(λ::Real) = pois_rand(Random.GLOBAL_RNG, λ) pois_rand(rng::AbstractRNG, λ::Real) = λ < 6 ? count_rand(rng, λ) : ad_rand(rng, λ) +@compile_workload begin + # Precompile the most common code paths + # Small λ uses count_rand, large λ uses ad_rand + pois_rand(3.0) # count_rand path (λ < 6) + pois_rand(50.0) # ad_rand path (λ >= 6) + # PassthroughRNG for GPU compatibility + pois_rand(PassthroughRNG(), 3.0) + pois_rand(PassthroughRNG(), 50.0) +end + end # module