diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 0000000..a393180 --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,17 @@ +coverage: + status: + project: + default: + target: 95% + threshold: 1% + patch: + default: + target: 95% + +ignore: + - "ext/**/*" + +comment: + layout: "reach,diff,flags,files" + behavior: default + require_changes: true diff --git a/docs/src/architecture/macro-internals.md b/docs/src/architecture/macro-internals.md index d8e5f93..4c2c920 100644 --- a/docs/src/architecture/macro-internals.md +++ b/docs/src/architecture/macro-internals.md @@ -115,50 +115,56 @@ end # If only checkpoint!(pool, Int64), Float64 arrays won't be rewound! ``` -### The Solution: `_untracked_flags` +### The Solution: Bitmask-Based Untracked Tracking -Every `acquire!` call (and convenience functions) marks itself as "untracked": +Every `acquire!` call (and convenience functions) marks itself as "untracked" with type-specific bitmask information: ```julia # Public API (called from user code outside macro) @inline function acquire!(pool, ::Type{T}, n::Int) where {T} - _mark_untracked!(pool) # ← Sets flag! + _mark_untracked!(pool, T) # ← Sets type-specific bitmask! _acquire_impl!(pool, T, n) end # Macro-transformed calls skip the marking # (because macro already knows about them) -_acquire_impl!(pool, T, n) # ← No flag +_acquire_impl!(pool, T, n) # ← No marking ``` +Each fixed-slot type maps to a bit in a `UInt16` bitmask via `_fixed_slot_bit(T)`. +Non-fixed-slot types set a separate `_untracked_has_others` flag. + ### Flow Diagram ``` -@with_pool pool begin State of pool._untracked_flags - │ ───────────────────────────────── - ├─► checkpoint!(pool, Int64) depth=2, flag[2]=false +@with_pool pool begin Bitmask state at depth 2 + │ ───────────────────────────── + ├─► checkpoint!(pool, Int64) masks[2]=0x0000, others[2]=false │ - │ A = _acquire_impl!(...) (macro-transformed, no flag set) + │ A = _acquire_impl!(...) (macro-transformed, no mark) │ B = helper!(pool) │ └─► zeros!(pool, Float64, N) - │ └─► _mark_untracked!(pool) flag[2]=TRUE ←──┐ - │ │ - │ ... more code ... │ - │ │ - └─► rewind! check: │ - if pool._untracked_flags[2] ─────────────────────────┘ - rewind!(pool) # Full rewind (safe) - else + │ └─► _mark_untracked!(pool, Float64) + │ masks[2] |= 0x0001 (Float64 bit) ←───┐ + │ │ + │ ... more code ... │ + │ │ + └─► rewind! check: │ + tracked_mask = _tracked_mask_for_types(Int64) │ + if _can_use_typed_path(pool, tracked_mask) ────────┘ rewind!(pool, Int64) # Typed rewind (fast) + else # Float64 not in {Int64} → full + rewind!(pool) # Full rewind (safe) end end ``` ### Why This Works -1. **Macro-tracked calls**: Transformed to `_acquire_impl!` → no flag → typed rewind -2. **Untracked calls**: Use public API → sets flag → triggers full rewind -3. **Result**: Always safe, with optimization when possible +1. **Macro-tracked calls**: Transformed to `_acquire_impl!` → no bitmask mark → typed path +2. **Untracked calls**: Use public API → sets type-specific bitmask → subset check at rewind +3. **Subset optimization**: If untracked types are a subset of tracked types, the typed path is still safe +4. **Result**: Always safe, with finer-grained optimization than a single boolean flag ## Nested `@with_pool` Handling @@ -170,13 +176,13 @@ Each `@with_pool` maintains its own checkpoint depth: │ ├─► @with_pool p2 begin depth: 2 → 3 │ v2 = acquire!(p2, Int64, 5) - │ helper!(p2) # sets flag[3]=true + │ helper!(p2) # marks bitmask at depth 3 │ sum(v2) - │ end depth: 3 → 2, flag[3] checked + │ end depth: 3 → 2, bitmask checked │ │ # v1 still valid here! sum(v1) -end depth: 2 → 1, flag[2] checked +end depth: 2 → 1, bitmask checked ``` ### Depth Tracking Data Structures @@ -184,13 +190,15 @@ end depth: 2 → 1, flag[2] checked ```julia struct AdaptiveArrayPool # ... type pools ... - _current_depth::Int # Current scope depth (1 = global) - _untracked_flags::Vector{Bool} # Per-depth flag array + _current_depth::Int # Current scope depth (1 = global) + _untracked_fixed_masks::Vector{UInt16} # Per-depth: which fixed slots untracked + _untracked_has_others::Vector{Bool} # Per-depth: any non-fixed-slot untracked end # Initialized with sentinel: -_current_depth = 1 # Global scope -_untracked_flags = [false] # Sentinel for depth=1 +_current_depth = 1 # Global scope +_untracked_fixed_masks = [UInt16(0)] # Sentinel for depth=1 +_untracked_has_others = [false] # Sentinel for depth=1 ``` ## Performance Impact @@ -199,9 +207,12 @@ _untracked_flags = [false] # Sentinel for depth=1 |----------|-------------------|----------------| | 1 type, no untracked | `checkpoint!(pool, T)` | **~77% faster** | | Multiple types, no untracked | `checkpoint!(pool, T1, T2, ...)` | **~50% faster** | -| Any untracked acquire | `checkpoint!(pool)` | Baseline | +| Untracked subset of tracked | `checkpoint!(pool, T...)` | **~77% faster** | +| Unknown untracked types | `checkpoint!(pool)` | Baseline | -The optimization matters most in tight loops with many iterations. +The optimization matters most in tight loops with many iterations. The bitmask subset +check allows the typed path even when untracked acquires occur, as long as those types +are already covered by the macro's tracked set. ## Code Generation Summary @@ -217,11 +228,11 @@ end function compute(data) pool = get_task_local_pool() - # Check if parent scope had untracked (for nested pools) - if pool._untracked_flags[pool._current_depth] - checkpoint!(pool) # Full checkpoint + # Bitmask subset check: can typed path handle any untracked acquires? + if _can_use_typed_path(pool, _tracked_mask_for_types(Float64)) + checkpoint!(pool, Float64) # Typed checkpoint (fast) else - checkpoint!(pool, Float64) # Typed checkpoint + checkpoint!(pool) # Full checkpoint (safe) end try @@ -229,11 +240,10 @@ function compute(data) result = helper!(pool, A) return result finally - # Check if untracked acquires occurred in this scope - if pool._untracked_flags[pool._current_depth] - rewind!(pool) # Full rewind + if _can_use_typed_path(pool, _tracked_mask_for_types(Float64)) + rewind!(pool, Float64) # Typed rewind (fast) else - rewind!(pool, Float64) # Typed rewind + rewind!(pool) # Full rewind (safe) end end end @@ -246,8 +256,10 @@ end | `_extract_acquire_types(expr, pool_name)` | AST walk to find types | | `_filter_static_types(types, local_vars)` | Filter out locally-defined types | | `_transform_acquire_calls(expr, pool_name)` | Replace `acquire!` → `_acquire_impl!` | -| `_mark_untracked!(pool)` | Set untracked flag for current depth | -| `_generate_typed_checkpoint_call(pool, types)` | Generate `checkpoint!(pool, T...)` | +| `_mark_untracked!(pool, T)` | Set type-specific bitmask for current depth | +| `_can_use_typed_path(pool, mask)` | Bitmask subset check for typed vs full path | +| `_tracked_mask_for_types(T...)` | Compile-time bitmask for tracked types | +| `_generate_typed_checkpoint_call(pool, types)` | Generate bitmask-aware checkpoint | ## See Also diff --git a/ext/AdaptiveArrayPoolsCUDAExt/state.jl b/ext/AdaptiveArrayPoolsCUDAExt/state.jl index a7ccd03..e4e6354 100644 --- a/ext/AdaptiveArrayPoolsCUDAExt/state.jl +++ b/ext/AdaptiveArrayPoolsCUDAExt/state.jl @@ -31,9 +31,10 @@ end # ============================================================================== function AdaptiveArrayPools.checkpoint!(pool::CuAdaptiveArrayPool) - # Increment depth and initialize untracked flag + # Increment depth and initialize untracked bitmask state pool._current_depth += 1 - push!(pool._untracked_flags, false) + push!(pool._untracked_fixed_masks, UInt16(0)) + push!(pool._untracked_has_others, false) depth = pool._current_depth # Fixed slots - zero allocation via @generated iteration @@ -52,17 +53,27 @@ end # Type-specific checkpoint (single type) @inline function AdaptiveArrayPools.checkpoint!(pool::CuAdaptiveArrayPool, ::Type{T}) where {T} pool._current_depth += 1 - push!(pool._untracked_flags, false) + push!(pool._untracked_fixed_masks, UInt16(0)) + push!(pool._untracked_has_others, false) _checkpoint_typed_pool!(AdaptiveArrayPools.get_typed_pool!(pool, T), pool._current_depth) nothing end # Type-specific checkpoint (multiple types) @generated function AdaptiveArrayPools.checkpoint!(pool::CuAdaptiveArrayPool, types::Type...) - checkpoint_exprs = [:(_checkpoint_typed_pool!(AdaptiveArrayPools.get_typed_pool!(pool, types[$i]), pool._current_depth)) for i in 1:length(types)] + seen = Set{Any}() + unique_indices = Int[] + for i in eachindex(types) + if !(types[i] in seen) + push!(seen, types[i]) + push!(unique_indices, i) + end + end + checkpoint_exprs = [:(_checkpoint_typed_pool!(AdaptiveArrayPools.get_typed_pool!(pool, types[$i]), pool._current_depth)) for i in unique_indices] quote pool._current_depth += 1 - push!(pool._untracked_flags, false) + push!(pool._untracked_fixed_masks, UInt16(0)) + push!(pool._untracked_has_others, false) $(checkpoint_exprs...) nothing end @@ -91,7 +102,8 @@ function AdaptiveArrayPools.rewind!(pool::CuAdaptiveArrayPool) _rewind_typed_pool!(tp, cur_depth) end - pop!(pool._untracked_flags) + pop!(pool._untracked_fixed_masks) + pop!(pool._untracked_has_others) pool._current_depth -= 1 return nothing @@ -104,22 +116,32 @@ end return nothing end _rewind_typed_pool!(AdaptiveArrayPools.get_typed_pool!(pool, T), pool._current_depth) - pop!(pool._untracked_flags) + pop!(pool._untracked_fixed_masks) + pop!(pool._untracked_has_others) pool._current_depth -= 1 nothing end # Type-specific rewind (multiple types) @generated function AdaptiveArrayPools.rewind!(pool::CuAdaptiveArrayPool, types::Type...) - rewind_exprs = [:(_rewind_typed_pool!(AdaptiveArrayPools.get_typed_pool!(pool, types[$i]), pool._current_depth)) for i in length(types):-1:1] - reset_exprs = [:(reset!(AdaptiveArrayPools.get_typed_pool!(pool, types[$i]))) for i in 1:length(types)] + seen = Set{Any}() + unique_indices = Int[] + for i in eachindex(types) + if !(types[i] in seen) + push!(seen, types[i]) + push!(unique_indices, i) + end + end + rewind_exprs = [:(_rewind_typed_pool!(AdaptiveArrayPools.get_typed_pool!(pool, types[$i]), pool._current_depth)) for i in reverse(unique_indices)] + reset_exprs = [:(reset!(AdaptiveArrayPools.get_typed_pool!(pool, types[$i]))) for i in unique_indices] quote if pool._current_depth == 1 $(reset_exprs...) return nothing end $(rewind_exprs...) - pop!(pool._untracked_flags) + pop!(pool._untracked_fixed_masks) + pop!(pool._untracked_has_others) pool._current_depth -= 1 nothing end @@ -140,10 +162,12 @@ function AdaptiveArrayPools.reset!(pool::CuAdaptiveArrayPool) reset!(tp) end - # Reset untracked detection state + # Reset depth and bitmask sentinel state pool._current_depth = 1 - empty!(pool._untracked_flags) - push!(pool._untracked_flags, false) + empty!(pool._untracked_fixed_masks) + push!(pool._untracked_fixed_masks, UInt16(0)) # Sentinel: no bits set + empty!(pool._untracked_has_others) + push!(pool._untracked_has_others, false) # Sentinel: no others return pool end @@ -197,10 +221,12 @@ function Base.empty!(pool::CuAdaptiveArrayPool) end empty!(pool.others) - # Reset state + # Reset depth and bitmask sentinel state pool._current_depth = 1 - empty!(pool._untracked_flags) - push!(pool._untracked_flags, false) + empty!(pool._untracked_fixed_masks) + push!(pool._untracked_fixed_masks, UInt16(0)) # Sentinel: no bits set + empty!(pool._untracked_has_others) + push!(pool._untracked_has_others, false) # Sentinel: no others return pool end diff --git a/ext/AdaptiveArrayPoolsCUDAExt/types.jl b/ext/AdaptiveArrayPoolsCUDAExt/types.jl index 096984b..056bd18 100644 --- a/ext/AdaptiveArrayPoolsCUDAExt/types.jl +++ b/ext/AdaptiveArrayPoolsCUDAExt/types.jl @@ -112,7 +112,8 @@ mutable struct CuAdaptiveArrayPool <: AbstractArrayPool # State management (same as CPU) _current_depth::Int - _untracked_flags::Vector{Bool} + _untracked_fixed_masks::Vector{UInt16} # Per-depth: which fixed slots had untracked acquires + _untracked_has_others::Vector{Bool} # Per-depth: any non-fixed-slot untracked acquire? # Device tracking (safety) device_id::Int @@ -131,7 +132,8 @@ function CuAdaptiveArrayPool() CuTypedPool{Bool}(), IdDict{DataType, Any}(), 1, # _current_depth (1 = global scope) - [false], # _untracked_flags sentinel + [UInt16(0)], # _untracked_fixed_masks: sentinel (no bits set) + [false], # _untracked_has_others: sentinel (no others) CUDA.deviceid(dev) # Use public API ) end diff --git a/src/acquire.jl b/src/acquire.jl index 6113202..25038cc 100644 --- a/src/acquire.jl +++ b/src/acquire.jl @@ -164,17 +164,23 @@ end # ============================================================================== """ - _mark_untracked!(pool::AbstractArrayPool) + _mark_untracked!(pool::AbstractArrayPool, ::Type{T}) -Mark that an untracked acquire has occurred at the current checkpoint depth. +Mark that an untracked acquire of type `T` has occurred at the current checkpoint depth. Called by `acquire!` wrapper; macro-transformed calls use `_acquire_impl!` directly. -With 1-indexed _current_depth (starting at 1 for global scope), this always marks -the current scope's _untracked_flags. +For fixed-slot types, sets the corresponding bit in `_untracked_fixed_masks`. +For non-fixed-slot types, sets `_untracked_has_others` flag. """ -@inline function _mark_untracked!(pool::AbstractArrayPool) - # Always mark (_current_depth >= 1 guaranteed by sentinel) - @inbounds pool._untracked_flags[pool._current_depth] = true +@inline function _mark_untracked!(pool::AbstractArrayPool, ::Type{T}) where {T} + depth = pool._current_depth + b = _fixed_slot_bit(T) + if b == UInt16(0) + @inbounds pool._untracked_has_others[depth] = true + else + @inbounds pool._untracked_fixed_masks[depth] |= b + end + nothing end # ============================================================================== @@ -265,19 +271,19 @@ end See also: [`unsafe_acquire!`](@ref) for native array access. """ @inline function acquire!(pool::AbstractArrayPool, ::Type{T}, n::Int) where {T} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _acquire_impl!(pool, T, n) end # Multi-dimensional support (zero-allocation with N-D cache) @inline function acquire!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _acquire_impl!(pool, T, dims...) end # Tuple support: allows acquire!(pool, T, size(A)) where size(A) returns NTuple{N,Int} @inline function acquire!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _acquire_impl!(pool, T, dims...) end @@ -297,7 +303,7 @@ end ``` """ @inline function acquire!(pool::AbstractArrayPool, x::AbstractArray) - _mark_untracked!(pool) + _mark_untracked!(pool, eltype(x)) _acquire_impl!(pool, eltype(x), size(x)) end @@ -352,18 +358,18 @@ end See also: [`acquire!`](@ref) for view-based access. """ @inline function unsafe_acquire!(pool::AbstractArrayPool, ::Type{T}, n::Int) where {T} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _unsafe_acquire_impl!(pool, T, n) end @inline function unsafe_acquire!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _unsafe_acquire_impl!(pool, T, dims...) end # Tuple support @inline function unsafe_acquire!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _unsafe_acquire_impl!(pool, T, dims) end @@ -383,7 +389,7 @@ end ``` """ @inline function unsafe_acquire!(pool::AbstractArrayPool, x::AbstractArray) - _mark_untracked!(pool) + _mark_untracked!(pool, eltype(x)) _unsafe_acquire_impl!(pool, eltype(x), size(x)) end diff --git a/src/convenience.jl b/src/convenience.jl index 4d5d8f7..62d919f 100644 --- a/src/convenience.jl +++ b/src/convenience.jl @@ -43,22 +43,22 @@ end See also: [`ones!`](@ref), [`similar!`](@ref), [`acquire!`](@ref) """ @inline function zeros!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int,N}) where {T,N} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _zeros_impl!(pool, T, dims...) end @inline function zeros!(pool::AbstractArrayPool, dims::Vararg{Int,N}) where {N} - _mark_untracked!(pool) + _mark_untracked!(pool, default_eltype(pool)) _zeros_impl!(pool, default_eltype(pool), dims...) end @inline function zeros!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N,Int}) where {T,N} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _zeros_impl!(pool, T, dims...) end @inline function zeros!(pool::AbstractArrayPool, dims::NTuple{N,Int}) where {N} - _mark_untracked!(pool) + _mark_untracked!(pool, default_eltype(pool)) _zeros_impl!(pool, default_eltype(pool), dims...) end @@ -116,22 +116,22 @@ end See also: [`zeros!`](@ref), [`similar!`](@ref), [`acquire!`](@ref) """ @inline function ones!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int,N}) where {T,N} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _ones_impl!(pool, T, dims...) end @inline function ones!(pool::AbstractArrayPool, dims::Vararg{Int,N}) where {N} - _mark_untracked!(pool) + _mark_untracked!(pool, default_eltype(pool)) _ones_impl!(pool, default_eltype(pool), dims...) end @inline function ones!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N,Int}) where {T,N} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _ones_impl!(pool, T, dims...) end @inline function ones!(pool::AbstractArrayPool, dims::NTuple{N,Int}) where {N} - _mark_untracked!(pool) + _mark_untracked!(pool, default_eltype(pool)) _ones_impl!(pool, default_eltype(pool), dims...) end @@ -186,11 +186,11 @@ end See also: [`falses!`](@ref), [`ones!`](@ref), [`acquire!`](@ref) """ @inline function trues!(pool::AbstractArrayPool, dims::Vararg{Int,N}) where {N} - _mark_untracked!(pool) + _mark_untracked!(pool, Bit) _trues_impl!(pool, dims...) end @inline function trues!(pool::AbstractArrayPool, dims::NTuple{N,Int}) where {N} - _mark_untracked!(pool) + _mark_untracked!(pool, Bit) _trues_impl!(pool, dims...) end @@ -226,11 +226,11 @@ end See also: [`trues!`](@ref), [`zeros!`](@ref), [`acquire!`](@ref) """ @inline function falses!(pool::AbstractArrayPool, dims::Vararg{Int,N}) where {N} - _mark_untracked!(pool) + _mark_untracked!(pool, Bit) _falses_impl!(pool, dims...) end @inline function falses!(pool::AbstractArrayPool, dims::NTuple{N,Int}) where {N} - _mark_untracked!(pool) + _mark_untracked!(pool, Bit) _falses_impl!(pool, dims...) end @@ -273,22 +273,22 @@ end See also: [`zeros!`](@ref), [`ones!`](@ref), [`acquire!`](@ref) """ @inline function similar!(pool::AbstractArrayPool, x::AbstractArray) - _mark_untracked!(pool) + _mark_untracked!(pool, eltype(x)) _similar_impl!(pool, x) end @inline function similar!(pool::AbstractArrayPool, x::AbstractArray, ::Type{T}) where {T} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _similar_impl!(pool, x, T) end @inline function similar!(pool::AbstractArrayPool, x::AbstractArray, dims::Vararg{Int,N}) where {N} - _mark_untracked!(pool) + _mark_untracked!(pool, eltype(x)) _similar_impl!(pool, x, dims...) end @inline function similar!(pool::AbstractArrayPool, x::AbstractArray, ::Type{T}, dims::Vararg{Int,N}) where {T,N} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _similar_impl!(pool, x, T, dims...) end @@ -336,22 +336,22 @@ end See also: [`unsafe_ones!`](@ref), [`zeros!`](@ref), [`unsafe_acquire!`](@ref) """ @inline function unsafe_zeros!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int,N}) where {T,N} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _unsafe_zeros_impl!(pool, T, dims...) end @inline function unsafe_zeros!(pool::AbstractArrayPool, dims::Vararg{Int,N}) where {N} - _mark_untracked!(pool) + _mark_untracked!(pool, default_eltype(pool)) _unsafe_zeros_impl!(pool, default_eltype(pool), dims...) end @inline function unsafe_zeros!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N,Int}) where {T,N} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _unsafe_zeros_impl!(pool, T, dims...) end @inline function unsafe_zeros!(pool::AbstractArrayPool, dims::NTuple{N,Int}) where {N} - _mark_untracked!(pool) + _mark_untracked!(pool, default_eltype(pool)) _unsafe_zeros_impl!(pool, default_eltype(pool), dims...) end @@ -403,22 +403,22 @@ end See also: [`unsafe_zeros!`](@ref), [`ones!`](@ref), [`unsafe_acquire!`](@ref) """ @inline function unsafe_ones!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int,N}) where {T,N} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _unsafe_ones_impl!(pool, T, dims...) end @inline function unsafe_ones!(pool::AbstractArrayPool, dims::Vararg{Int,N}) where {N} - _mark_untracked!(pool) + _mark_untracked!(pool, default_eltype(pool)) _unsafe_ones_impl!(pool, default_eltype(pool), dims...) end @inline function unsafe_ones!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N,Int}) where {T,N} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _unsafe_ones_impl!(pool, T, dims...) end @inline function unsafe_ones!(pool::AbstractArrayPool, dims::NTuple{N,Int}) where {N} - _mark_untracked!(pool) + _mark_untracked!(pool, default_eltype(pool)) _unsafe_ones_impl!(pool, default_eltype(pool), dims...) end @@ -473,22 +473,22 @@ end See also: [`similar!`](@ref), [`unsafe_acquire!`](@ref) """ @inline function unsafe_similar!(pool::AbstractArrayPool, x::AbstractArray) - _mark_untracked!(pool) + _mark_untracked!(pool, eltype(x)) _unsafe_similar_impl!(pool, x) end @inline function unsafe_similar!(pool::AbstractArrayPool, x::AbstractArray, ::Type{T}) where {T} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _unsafe_similar_impl!(pool, x, T) end @inline function unsafe_similar!(pool::AbstractArrayPool, x::AbstractArray, dims::Vararg{Int,N}) where {N} - _mark_untracked!(pool) + _mark_untracked!(pool, eltype(x)) _unsafe_similar_impl!(pool, x, dims...) end @inline function unsafe_similar!(pool::AbstractArrayPool, x::AbstractArray, ::Type{T}, dims::Vararg{Int,N}) where {T,N} - _mark_untracked!(pool) + _mark_untracked!(pool, T) _unsafe_similar_impl!(pool, x, T, dims...) end diff --git a/src/macros.jl b/src/macros.jl index d79a65e..c0011f6 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -340,31 +340,14 @@ function _generate_pool_code(pool_name, expr, force_enable; source::Union{LineNu # Transform acquire! calls to _acquire_impl! (bypasses untracked marking) transformed_expr = _transform_acquire_calls(expr, pool_name) - # For typed checkpoint, add _untracked_flags check for fallback to full checkpoint - # This protects parent scope arrays when entering nested @with_pool if use_typed - typed_checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) - checkpoint_call = quote - if @inbounds $(esc(pool_name))._untracked_flags[$(esc(pool_name))._current_depth] - $checkpoint!($(esc(pool_name))) # Full checkpoint (parent had untracked) - else - $typed_checkpoint_call # Fast typed checkpoint - end - end + checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) else checkpoint_call = :($checkpoint!($(esc(pool_name)))) end - # For typed checkpoint, add _untracked_flags check for fallback to full rewind if use_typed - typed_rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) - rewind_call = quote - if @inbounds $(esc(pool_name))._untracked_flags[$(esc(pool_name))._current_depth] - $rewind!($(esc(pool_name))) # Full rewind (untracked detected) - else - $typed_rewind_call # Fast typed rewind - end - end + rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) else rewind_call = :($rewind!($(esc(pool_name)))) end @@ -449,22 +432,8 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc pool_getter = :($_get_pool_for_backend($(Val{backend}()))) if use_typed - typed_checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) - checkpoint_call = quote - if @inbounds $(esc(pool_name))._untracked_flags[$(esc(pool_name))._current_depth] - $checkpoint!($(esc(pool_name))) - else - $typed_checkpoint_call - end - end - typed_rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) - rewind_call = quote - if @inbounds $(esc(pool_name))._untracked_flags[$(esc(pool_name))._current_depth] - $rewind!($(esc(pool_name))) - else - $typed_rewind_call - end - end + checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) + rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) else checkpoint_call = :($checkpoint!($(esc(pool_name)))) rewind_call = :($rewind!($(esc(pool_name)))) @@ -509,30 +478,14 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc # Use Val{backend}() for compile-time dispatch - fully inlinable pool_getter = :($_get_pool_for_backend($(Val{backend}()))) - # Generate checkpoint call (typed or full) if use_typed - typed_checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) - checkpoint_call = quote - if @inbounds $(esc(pool_name))._untracked_flags[$(esc(pool_name))._current_depth] - $checkpoint!($(esc(pool_name))) # Full checkpoint (parent had untracked) - else - $typed_checkpoint_call # Fast typed checkpoint - end - end + checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) else checkpoint_call = :($checkpoint!($(esc(pool_name)))) end - # Generate rewind call (typed or full) if use_typed - typed_rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) - rewind_call = quote - if @inbounds $(esc(pool_name))._untracked_flags[$(esc(pool_name))._current_depth] - $rewind!($(esc(pool_name))) # Full rewind (untracked detected) - else - $typed_rewind_call # Fast typed rewind - end - end + rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) else rewind_call = :($rewind!($(esc(pool_name)))) end @@ -586,30 +539,14 @@ function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, f # Use Val{backend}() for compile-time dispatch pool_getter = :($_get_pool_for_backend($(Val{backend}()))) - # Generate checkpoint call (typed or full) if use_typed - typed_checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) - checkpoint_call = quote - if @inbounds $(esc(pool_name))._untracked_flags[$(esc(pool_name))._current_depth] - $checkpoint!($(esc(pool_name))) - else - $typed_checkpoint_call - end - end + checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) else checkpoint_call = :($checkpoint!($(esc(pool_name)))) end - # Generate rewind call (typed or full) if use_typed - typed_rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) - rewind_call = quote - if @inbounds $(esc(pool_name))._untracked_flags[$(esc(pool_name))._current_depth] - $rewind!($(esc(pool_name))) - else - $typed_rewind_call - end - end + rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) else rewind_call = :($rewind!($(esc(pool_name)))) end @@ -655,31 +592,14 @@ function _generate_function_pool_code(pool_name, func_def, force_enable, disable # Transform acquire! calls to _acquire_impl! (bypasses untracked marking) transformed_body = _transform_acquire_calls(body, pool_name) - # For typed checkpoint, add _untracked_flags check for fallback to full checkpoint - # This protects parent scope arrays when entering nested @with_pool if use_typed - typed_checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) - checkpoint_call = quote - if @inbounds $(esc(pool_name))._untracked_flags[$(esc(pool_name))._current_depth] - $checkpoint!($(esc(pool_name))) # Full checkpoint (parent had untracked) - else - $typed_checkpoint_call # Fast typed checkpoint - end - end + checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) else checkpoint_call = :($checkpoint!($(esc(pool_name)))) end - # For typed checkpoint, add _untracked_flags check for fallback to full rewind if use_typed - typed_rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) - rewind_call = quote - if @inbounds $(esc(pool_name))._untracked_flags[$(esc(pool_name))._current_depth] - $rewind!($(esc(pool_name))) # Full rewind (untracked detected) - else - $typed_rewind_call # Fast typed rewind - end - end + rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) else rewind_call = :($rewind!($(esc(pool_name)))) end @@ -982,30 +902,48 @@ end """ _generate_typed_checkpoint_call(pool_expr, types) -Generate checkpoint!(pool, T1, T2, ...) call expression. +Generate bitmask-aware checkpoint call. When types are known at compile time, +emits a conditional: if untracked types ⊆ tracked types → typed checkpoint, +otherwise → full checkpoint. """ function _generate_typed_checkpoint_call(pool_expr, types) if isempty(types) return :($checkpoint!($pool_expr)) else - # esc types so they resolve in caller's namespace (Float64, not AdaptiveArrayPools.Float64) escaped_types = [esc(t) for t in types] - return :($checkpoint!($pool_expr, $(escaped_types...))) + typed_call = :($checkpoint!($pool_expr, $(escaped_types...))) + full_call = :($checkpoint!($pool_expr)) + return quote + if $_can_use_typed_path($pool_expr, $_tracked_mask_for_types($(escaped_types...))) + $typed_call + else + $full_call + end + end end end """ _generate_typed_rewind_call(pool_expr, types) -Generate rewind!(pool, T1, T2, ...) call expression. -Types are passed in original order; rewind! handles reversal internally. +Generate bitmask-aware rewind call. When types are known at compile time, +emits a conditional: if untracked types ⊆ tracked types → typed rewind, +otherwise → full rewind. """ function _generate_typed_rewind_call(pool_expr, types) if isempty(types) return :($rewind!($pool_expr)) else escaped_types = [esc(t) for t in types] - return :($rewind!($pool_expr, $(escaped_types...))) + typed_call = :($rewind!($pool_expr, $(escaped_types...))) + full_call = :($rewind!($pool_expr)) + return quote + if $_can_use_typed_path($pool_expr, $_tracked_mask_for_types($(escaped_types...))) + $typed_call + else + $full_call + end + end end end diff --git a/src/state.jl b/src/state.jl index c9b2a66..9a831d5 100644 --- a/src/state.jl +++ b/src/state.jl @@ -13,9 +13,10 @@ After warmup, this function has **zero allocation**. See also: [`rewind!`](@ref), [`@with_pool`](@ref) """ function checkpoint!(pool::AdaptiveArrayPool) - # Increment depth and initialize untracked flag + # Increment depth and initialize untracked bitmask state pool._current_depth += 1 - push!(pool._untracked_flags, false) + push!(pool._untracked_fixed_masks, UInt16(0)) + push!(pool._untracked_has_others, false) depth = pool._current_depth # Fixed slots - zero allocation via @generated iteration @@ -37,13 +38,14 @@ end Save state for a specific type only. Used by optimized macros that know which types will be used at compile time. -Also updates _current_depth and _untracked_flags for untracked acquire detection. +Also updates _current_depth and bitmask state for untracked acquire detection. ~77% faster than full checkpoint! when only one type is used. """ @inline function checkpoint!(pool::AdaptiveArrayPool, ::Type{T}) where T pool._current_depth += 1 - push!(pool._untracked_flags, false) + push!(pool._untracked_fixed_masks, UInt16(0)) + push!(pool._untracked_has_others, false) _checkpoint_typed_pool!(get_typed_pool!(pool, T), pool._current_depth) nothing end @@ -67,7 +69,8 @@ compile-time unrolling. Increments _current_depth once for all types. checkpoint_exprs = [:(_checkpoint_typed_pool!(get_typed_pool!(pool, types[$i]), pool._current_depth)) for i in unique_indices] quote pool._current_depth += 1 - push!(pool._untracked_flags, false) + push!(pool._untracked_fixed_masks, UInt16(0)) + push!(pool._untracked_has_others, false) $(checkpoint_exprs...) nothing end @@ -118,7 +121,8 @@ function rewind!(pool::AdaptiveArrayPool) _rewind_typed_pool!(tp, cur_depth) end - pop!(pool._untracked_flags) + pop!(pool._untracked_fixed_masks) + pop!(pool._untracked_has_others) pool._current_depth -= 1 return nothing @@ -128,7 +132,7 @@ end rewind!(pool::AdaptiveArrayPool, ::Type{T}) Restore state for a specific type only. -Also updates _current_depth and _untracked_flags. +Also updates _current_depth and bitmask state. """ @inline function rewind!(pool::AdaptiveArrayPool, ::Type{T}) where T # Safety guard: at global scope (depth=1), delegate to reset! @@ -137,7 +141,8 @@ Also updates _current_depth and _untracked_flags. return nothing end _rewind_typed_pool!(get_typed_pool!(pool, T), pool._current_depth) - pop!(pool._untracked_flags) + pop!(pool._untracked_fixed_masks) + pop!(pool._untracked_has_others) pool._current_depth -= 1 nothing end @@ -167,7 +172,8 @@ Decrements _current_depth once after all types are rewound. return nothing end $(rewind_exprs...) - pop!(pool._untracked_flags) + pop!(pool._untracked_fixed_masks) + pop!(pool._untracked_has_others) pool._current_depth -= 1 nothing end @@ -284,8 +290,10 @@ function Base.empty!(pool::AdaptiveArrayPool) # Reset untracked detection state (1-based sentinel pattern) pool._current_depth = 1 # 1 = global scope (sentinel) - empty!(pool._untracked_flags) - push!(pool._untracked_flags, false) # Sentinel: global scope starts with false + empty!(pool._untracked_fixed_masks) + push!(pool._untracked_fixed_masks, UInt16(0)) # Sentinel: no bits set + empty!(pool._untracked_has_others) + push!(pool._untracked_has_others, false) # Sentinel: no others return pool end @@ -318,7 +326,7 @@ Reset pool state without clearing allocated storage. This function: - Resets all `n_active` counters to 0 - Restores all checkpoint stacks to sentinel state -- Resets `_current_depth` and `_untracked_flags` +- Resets `_current_depth` and untracked bitmask state Unlike `empty!`, this **preserves** all allocated vectors, views, and N-D arrays for reuse, avoiding reallocation costs. @@ -363,8 +371,10 @@ function reset!(pool::AdaptiveArrayPool) # Reset untracked detection state (1-based sentinel pattern) pool._current_depth = 1 # 1 = global scope (sentinel) - empty!(pool._untracked_flags) - push!(pool._untracked_flags, false) # Sentinel: global scope starts with false + empty!(pool._untracked_fixed_masks) + push!(pool._untracked_fixed_masks, UInt16(0)) # Sentinel: no bits set + empty!(pool._untracked_has_others) + push!(pool._untracked_has_others, false) # Sentinel: no others return pool end @@ -398,6 +408,46 @@ See also: [`reset!(::AdaptiveArrayPool)`](@ref), [`rewind!`](@ref) end end +# ============================================================================== +# Bitmask Helpers for Typed Path Decisions +# ============================================================================== + +""" + _tracked_mask_for_types(types::Type...) -> UInt16 + +Compute compile-time bitmask for the types tracked by a typed checkpoint/rewind. +Uses `@generated` for zero-overhead constant folding. + +Returns `UInt16(0)` when called with no arguments. +Non-fixed-slot types contribute `UInt16(0)` (their bit is 0). +""" +@generated function _tracked_mask_for_types(types::Type...) + mask = UInt16(0) + for i in 1:length(types) + T = types[i].parameters[1] + mask |= _fixed_slot_bit(T) + end + return :(UInt16($mask)) +end + +""" + _can_use_typed_path(pool::AbstractArrayPool, tracked_mask::UInt16) -> Bool + +Check if the typed (fast) checkpoint/rewind path is safe to use. + +Returns `true` when all untracked acquires at the current depth are a subset +of the tracked types (bitmask subset check) AND no non-fixed-slot types were used. + +The subset check: `(untracked_mask & ~tracked_mask) == 0` means every bit set +in `untracked_mask` is also set in `tracked_mask`. +""" +@inline function _can_use_typed_path(pool::AbstractArrayPool, tracked_mask::UInt16) + depth = pool._current_depth + untracked_mask = @inbounds pool._untracked_fixed_masks[depth] + has_others = @inbounds pool._untracked_has_others[depth] + return (untracked_mask & ~tracked_mask) == UInt16(0) && !has_others +end + # ============================================================================== # DisabledPool State Management (no-ops) # ============================================================================== diff --git a/src/types.jl b/src/types.jl index 0b6f62f..e6adb4c 100644 --- a/src/types.jl +++ b/src/types.jl @@ -366,6 +366,22 @@ Tests verify synchronization automatically. """ const FIXED_SLOT_FIELDS = (:float64, :float32, :int64, :int32, :complexf64, :complexf32, :bool, :bits) +# ============================================================================== +# Fixed-Slot Bit Mapping (for typed untracked tracking) +# ============================================================================== +# Maps each fixed-slot type to a unique bit in a UInt16 bitmask. +# Bit ordering matches FIXED_SLOT_FIELDS. Non-fixed types return UInt16(0). + +@inline _fixed_slot_bit(::Type{Float64}) = UInt16(1) << 0 +@inline _fixed_slot_bit(::Type{Float32}) = UInt16(1) << 1 +@inline _fixed_slot_bit(::Type{Int64}) = UInt16(1) << 2 +@inline _fixed_slot_bit(::Type{Int32}) = UInt16(1) << 3 +@inline _fixed_slot_bit(::Type{ComplexF64}) = UInt16(1) << 4 +@inline _fixed_slot_bit(::Type{ComplexF32}) = UInt16(1) << 5 +@inline _fixed_slot_bit(::Type{Bool}) = UInt16(1) << 6 +@inline _fixed_slot_bit(::Type{Bit}) = UInt16(1) << 7 +@inline _fixed_slot_bit(::Type) = UInt16(0) # non-fixed-slot → triggers has_others + # ============================================================================== # AdaptiveArrayPool # ============================================================================== @@ -392,7 +408,8 @@ mutable struct AdaptiveArrayPool <: AbstractArrayPool # Untracked acquire detection (1-based sentinel pattern) _current_depth::Int # Current scope depth (1 = global scope) - _untracked_flags::Vector{Bool} # Per-depth flag: true if untracked acquire occurred + _untracked_fixed_masks::Vector{UInt16} # Per-depth: which fixed slots had untracked acquires + _untracked_has_others::Vector{Bool} # Per-depth: any non-fixed-slot untracked acquire? end function AdaptiveArrayPool() @@ -407,7 +424,8 @@ function AdaptiveArrayPool() BitTypedPool(), IdDict{DataType, Any}(), 1, # _current_depth: 1 = global scope (sentinel) - [false] # _untracked_flags: sentinel for global scope + [UInt16(0)], # _untracked_fixed_masks: sentinel (no bits set) + [false] # _untracked_has_others: sentinel (no others) ) end diff --git a/test/test_macro_expansion.jl b/test/test_macro_expansion.jl index 76638d6..67c0dae 100644 --- a/test/test_macro_expansion.jl +++ b/test/test_macro_expansion.jl @@ -708,3 +708,76 @@ end end end # Source Location Preservation + +# ============================================================================== +# Phase 3: Bitmask-aware checkpoint/rewind in macro expansion +# ============================================================================== + +@testset "Bitmask-aware typed path in expansion" begin + @testset "@with_pool typed expansion uses _can_use_typed_path" begin + expr = @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + sum(v) + end + + expr_str = string(expr) + + # Should use _can_use_typed_path instead of _untracked_flags + @test occursin("_can_use_typed_path", expr_str) + @test !occursin("_untracked_flags", expr_str) + + # Should use _tracked_mask_for_types + @test occursin("_tracked_mask_for_types", expr_str) + end + + @testset "@with_pool full path (dynamic types) has no bitmask" begin + expr = @macroexpand @with_pool pool begin + local_arr = rand(10) + v = acquire!(pool, local_arr) + sum(v) + end + + expr_str = string(expr) + + # Full path (dynamic type) → no _can_use_typed_path + @test !occursin("_can_use_typed_path", expr_str) + @test !occursin("_tracked_mask_for_types", expr_str) + end + + @testset "@maybe_with_pool typed expansion uses _can_use_typed_path" begin + expr = @macroexpand @maybe_with_pool pool begin + v = acquire!(pool, Float64, 10) + sum(v) + end + + expr_str = string(expr) + + @test occursin("_can_use_typed_path", expr_str) + @test !occursin("_untracked_flags", expr_str) + end + + @testset "@with_pool :cpu backend uses _can_use_typed_path" begin + expr = @macroexpand @with_pool :cpu pool begin + v = acquire!(pool, Float64, 10) + sum(v) + end + + expr_str = string(expr) + + @test occursin("_can_use_typed_path", expr_str) + @test !occursin("_untracked_flags", expr_str) + end + + @testset "@with_pool function def uses _can_use_typed_path" begin + expr = @macroexpand @with_pool pool function test_fn(x) + v = acquire!(pool, Float64, length(x)) + v .= x + sum(v) + end + + expr_str = string(expr) + + @test occursin("_can_use_typed_path", expr_str) + @test !occursin("_untracked_flags", expr_str) + end +end diff --git a/test/test_state.jl b/test/test_state.jl index 9881f41..52d25bd 100644 --- a/test/test_state.jl +++ b/test/test_state.jl @@ -306,7 +306,8 @@ reset!(pool) @test pool._current_depth == 1 - @test pool._untracked_flags == [false] + @test pool._untracked_fixed_masks == [UInt16(0)] + @test pool._untracked_has_others == [false] @test pool.float64._checkpoint_n_active == [0] # Sentinel only @test pool.float64._checkpoint_depths == [0] # Sentinel only end @@ -493,7 +494,7 @@ rewind!(pool) @test pool.float64.n_active == 0 @test pool._current_depth == 1 - @test pool._untracked_flags == [false] + @test pool._untracked_fixed_masks == [UInt16(0)] end @testset "rewind! after reset!" begin @@ -734,25 +735,24 @@ @testset "Parent scope protection via full checkpoint" begin # Test: Parent scope arrays are protected by automatic full checkpoint - # when entering @with_pool with _untracked_flags[_current_depth] = true + # when untracked acquire occurred at current depth # Helper function that acquires Int64 (called from inside @with_pool) # Since it's defined outside the macro, acquire! won't be transformed function untracked_helper(p) - acquire!(p, Int64, 5) # This will mark _untracked_flags = true + acquire!(p, Int64, 5) # This will mark untracked bitmask end pool = get_task_local_pool() empty!(pool) # Start fresh - # Acquire Int64 array OUTSIDE @with_pool - marks global _untracked_flags + # Acquire Int64 array OUTSIDE @with_pool - marks untracked bitmask v_parent = acquire!(pool, Int64, 10) v_parent .= 42 # Initialize @test pool.int64.n_active == 1 - @test pool._untracked_flags[1] == true # Global scope marked + @test pool._untracked_fixed_masks[1] == AdaptiveArrayPools._fixed_slot_bit(Int64) - # Enter @with_pool - should do FULL checkpoint (because _untracked_flags[1] = true) - # This protects the parent's Int64 arrays + # Enter @with_pool - full checkpoint protects parent's Int64 arrays @with_pool pool begin v_float = acquire!(pool, Float64, 100) # Tracked untracked_helper(pool) # Untracked Int64 acquire! @@ -776,7 +776,7 @@ v_parent = acquire!(pool, Int32, 7) v_parent .= Int32(123) @test pool.int32.n_active == 1 - @test pool._untracked_flags[1] == true + @test pool._untracked_fixed_masks[1] == AdaptiveArrayPools._fixed_slot_bit(Int32) # Helper for Int32 function int32_helper(p) @@ -827,7 +827,7 @@ pool = AdaptiveArrayPool() # No global untracked acquire - @test pool._untracked_flags[1] == false + @test pool._untracked_fixed_masks[1] == UInt16(0) # Checkpoint/rewind with typed - should work normally checkpoint!(pool) @@ -1002,7 +1002,7 @@ @test pool.bool.n_active == 0 @test pool.complexf64.n_active == 0 @test pool._current_depth == 1 - @test pool._untracked_flags == [false] + @test pool._untracked_fixed_masks == [UInt16(0)] empty!(pool) end @@ -1318,4 +1318,613 @@ empty!(pool) end + # ========================================================================== + # Typed-Aware Untracked Tracking — Phase 1: Bitmask Metadata Lifecycle + # ========================================================================== + + @testset "Bitmask metadata: constructor sentinel" begin + pool = AdaptiveArrayPool() + + # New fields exist + @test hasfield(AdaptiveArrayPool, :_untracked_fixed_masks) + @test hasfield(AdaptiveArrayPool, :_untracked_has_others) + + # Sentinel values at depth=1 (global scope) + @test pool._untracked_fixed_masks == [UInt16(0)] + @test pool._untracked_has_others == [false] + @test length(pool._untracked_fixed_masks) == 1 + @test length(pool._untracked_has_others) == 1 + end + + @testset "Bitmask metadata: checkpoint! pushes sentinels" begin + pool = AdaptiveArrayPool() + + # Full checkpoint + checkpoint!(pool) + @test length(pool._untracked_fixed_masks) == 2 + @test length(pool._untracked_has_others) == 2 + @test pool._untracked_fixed_masks[2] == UInt16(0) + @test pool._untracked_has_others[2] == false + + # Another checkpoint + checkpoint!(pool) + @test length(pool._untracked_fixed_masks) == 3 + @test length(pool._untracked_has_others) == 3 + + # Cleanup + rewind!(pool) + rewind!(pool) + end + + @testset "Bitmask metadata: typed checkpoint! pushes sentinels" begin + pool = AdaptiveArrayPool() + + # Single-type checkpoint + checkpoint!(pool, Float64) + @test length(pool._untracked_fixed_masks) == 2 + @test length(pool._untracked_has_others) == 2 + @test pool._untracked_fixed_masks[2] == UInt16(0) + @test pool._untracked_has_others[2] == false + rewind!(pool, Float64) + + # Multi-type checkpoint + checkpoint!(pool, Float64, Float32) + @test length(pool._untracked_fixed_masks) == 2 + @test length(pool._untracked_has_others) == 2 + @test pool._untracked_fixed_masks[2] == UInt16(0) + @test pool._untracked_has_others[2] == false + rewind!(pool, Float64, Float32) + end + + @testset "Bitmask metadata: rewind! pops" begin + pool = AdaptiveArrayPool() + + checkpoint!(pool) + @test length(pool._untracked_fixed_masks) == 2 + @test length(pool._untracked_has_others) == 2 + + rewind!(pool) + @test length(pool._untracked_fixed_masks) == 1 + @test length(pool._untracked_has_others) == 1 + # Sentinel preserved + @test pool._untracked_fixed_masks[1] == UInt16(0) + @test pool._untracked_has_others[1] == false + end + + @testset "Bitmask metadata: typed rewind! pops" begin + pool = AdaptiveArrayPool() + + checkpoint!(pool, Float64) + @test length(pool._untracked_fixed_masks) == 2 + + rewind!(pool, Float64) + @test length(pool._untracked_fixed_masks) == 1 + @test length(pool._untracked_has_others) == 1 + + # Multi-type + checkpoint!(pool, Float64, Int64) + @test length(pool._untracked_fixed_masks) == 2 + + rewind!(pool, Float64, Int64) + @test length(pool._untracked_fixed_masks) == 1 + end + + @testset "Bitmask metadata: reset! restores sentinel" begin + pool = AdaptiveArrayPool() + + # Build up state + checkpoint!(pool) + checkpoint!(pool) + @test length(pool._untracked_fixed_masks) == 3 + + reset!(pool) + @test pool._untracked_fixed_masks == [UInt16(0)] + @test pool._untracked_has_others == [false] + @test pool._current_depth == 1 + end + + @testset "Bitmask metadata: empty! restores sentinel" begin + pool = AdaptiveArrayPool() + + # Build up state + checkpoint!(pool) + acquire!(pool, Float64, 10) + checkpoint!(pool) + @test length(pool._untracked_fixed_masks) == 3 + + empty!(pool) + @test pool._untracked_fixed_masks == [UInt16(0)] + @test pool._untracked_has_others == [false] + @test pool._current_depth == 1 + end + + @testset "Bitmask metadata: multiple checkpoint/rewind cycles" begin + pool = AdaptiveArrayPool() + + for _ in 1:5 + checkpoint!(pool) + rewind!(pool) + end + + # No stack leaks — should be back to sentinel only + @test length(pool._untracked_fixed_masks) == 1 + @test length(pool._untracked_has_others) == 1 + @test pool._current_depth == 1 + end + + @testset "Bitmask metadata: nested depth tracking" begin + pool = AdaptiveArrayPool() + + # Depth 2 + checkpoint!(pool) + @test length(pool._untracked_fixed_masks) == 2 + + # Depth 3 + checkpoint!(pool) + @test length(pool._untracked_fixed_masks) == 3 + + # Depth 4 + checkpoint!(pool) + @test length(pool._untracked_fixed_masks) == 4 + + # Pop back + rewind!(pool) + @test length(pool._untracked_fixed_masks) == 3 + + rewind!(pool) + @test length(pool._untracked_fixed_masks) == 2 + + rewind!(pool) + @test length(pool._untracked_fixed_masks) == 1 + end + + # ========================================================================== + # Typed-Aware Untracked Tracking — Phase 2: _fixed_slot_bit + typed marking + # ========================================================================== + + @testset "_fixed_slot_bit dispatch" begin + using AdaptiveArrayPools: _fixed_slot_bit, Bit + + # Each fixed slot returns a unique nonzero bit + @test _fixed_slot_bit(Float64) == UInt16(1) << 0 + @test _fixed_slot_bit(Float32) == UInt16(1) << 1 + @test _fixed_slot_bit(Int64) == UInt16(1) << 2 + @test _fixed_slot_bit(Int32) == UInt16(1) << 3 + @test _fixed_slot_bit(ComplexF64) == UInt16(1) << 4 + @test _fixed_slot_bit(ComplexF32) == UInt16(1) << 5 + @test _fixed_slot_bit(Bool) == UInt16(1) << 6 + @test _fixed_slot_bit(Bit) == UInt16(1) << 7 + + # Non-fixed-slot types return 0 + @test _fixed_slot_bit(UInt8) == UInt16(0) + @test _fixed_slot_bit(UInt16) == UInt16(0) + @test _fixed_slot_bit(Float16) == UInt16(0) + @test _fixed_slot_bit(String) == UInt16(0) + + # All 8 bits are unique (no collisions) + bits = [_fixed_slot_bit(T) for T in (Float64, Float32, Int64, Int32, ComplexF64, ComplexF32, Bool, Bit)] + @test length(unique(bits)) == 8 + @test all(b -> b != UInt16(0), bits) + end + + @testset "Typed _mark_untracked!: fixed-slot types set mask bits" begin + using AdaptiveArrayPools: _mark_untracked!, _fixed_slot_bit + + pool = AdaptiveArrayPool() + checkpoint!(pool) # depth=2 + + # Mark Float64 untracked + _mark_untracked!(pool, Float64) + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Float64) + @test pool._untracked_has_others[2] == false + + # Mark Float32 additionally — bits accumulate + _mark_untracked!(pool, Float32) + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Float64) | _fixed_slot_bit(Float32) + @test pool._untracked_has_others[2] == false + + # Mark Float64 again — idempotent + _mark_untracked!(pool, Float64) + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Float64) | _fixed_slot_bit(Float32) + + rewind!(pool) + end + + @testset "Typed _mark_untracked!: non-fixed-slot types set has_others" begin + using AdaptiveArrayPools: _mark_untracked!, _fixed_slot_bit + + pool = AdaptiveArrayPool() + checkpoint!(pool) # depth=2 + + # Mark UInt8 (not a fixed slot) + _mark_untracked!(pool, UInt8) + @test pool._untracked_fixed_masks[2] == UInt16(0) # mask unchanged + @test pool._untracked_has_others[2] == true + + rewind!(pool) + end + + @testset "Typed _mark_untracked!: mixed fixed + others" begin + using AdaptiveArrayPools: _mark_untracked!, _fixed_slot_bit + + pool = AdaptiveArrayPool() + checkpoint!(pool) + + _mark_untracked!(pool, Float64) + _mark_untracked!(pool, UInt8) # others + _mark_untracked!(pool, Int64) + + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Float64) | _fixed_slot_bit(Int64) + @test pool._untracked_has_others[2] == true + + rewind!(pool) + end + + @testset "Typed _mark_untracked!: nested depth isolation" begin + using AdaptiveArrayPools: _mark_untracked!, _fixed_slot_bit + + pool = AdaptiveArrayPool() + + # Depth 2 + checkpoint!(pool) + _mark_untracked!(pool, Float64) + + # Depth 3 + checkpoint!(pool) + _mark_untracked!(pool, Int32) + + # Depth 3 has only Int32 + @test pool._untracked_fixed_masks[3] == _fixed_slot_bit(Int32) + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Float64) + + # Depth 1 (sentinel) untouched + @test pool._untracked_fixed_masks[1] == UInt16(0) + + rewind!(pool) + rewind!(pool) + end + + @testset "Public acquire! sets typed bitmask" begin + using AdaptiveArrayPools: _fixed_slot_bit + + pool = AdaptiveArrayPool() + checkpoint!(pool) # depth=2 + + # acquire! outside @with_pool calls _mark_untracked!(pool, T) + acquire!(pool, Float64, 10) + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Float64) + + acquire!(pool, Int64, 5) + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Float64) | _fixed_slot_bit(Int64) + + rewind!(pool) + end + + @testset "Public unsafe_acquire! sets typed bitmask" begin + using AdaptiveArrayPools: _fixed_slot_bit + + pool = AdaptiveArrayPool() + checkpoint!(pool) + + unsafe_acquire!(pool, Float32, 10) + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Float32) + + rewind!(pool) + end + + @testset "Convenience functions set typed bitmask" begin + using AdaptiveArrayPools: _fixed_slot_bit, Bit + + pool = AdaptiveArrayPool() + + # zeros! with explicit type + checkpoint!(pool) + zeros!(pool, Float64, 10) + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Float64) + rewind!(pool) + + # zeros! without type → default_eltype → Float64 + checkpoint!(pool) + zeros!(pool, 10) + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Float64) + rewind!(pool) + + # ones! with type + checkpoint!(pool) + ones!(pool, Int32, 10) + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Int32) + rewind!(pool) + + # trues! → Bit type + checkpoint!(pool) + trues!(pool, 10) + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Bit) + rewind!(pool) + + # falses! → Bit type + checkpoint!(pool) + falses!(pool, 10) + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Bit) + rewind!(pool) + + # similar! with template array + checkpoint!(pool) + similar!(pool, rand(Float32, 5)) + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Float32) + rewind!(pool) + end + + @testset "Convenience functions: non-fixed-slot type sets has_others" begin + pool = AdaptiveArrayPool() + + checkpoint!(pool) + zeros!(pool, UInt8, 10) + @test pool._untracked_has_others[2] == true + @test pool._untracked_fixed_masks[2] == UInt16(0) + rewind!(pool) + end + + # ================================================================== + # Phase 3: _tracked_mask_for_types and _can_use_typed_path + # ================================================================== + + @testset "_tracked_mask_for_types: computes correct mask" begin + using AdaptiveArrayPools: _tracked_mask_for_types + + # No args → zero mask + @test _tracked_mask_for_types() == UInt16(0) + + # Single types + @test _tracked_mask_for_types(Float64) == _fixed_slot_bit(Float64) + @test _tracked_mask_for_types(Float32) == _fixed_slot_bit(Float32) + @test _tracked_mask_for_types(Bit) == _fixed_slot_bit(Bit) + + # Multiple types → OR combination + @test _tracked_mask_for_types(Float64, Float32) == (_fixed_slot_bit(Float64) | _fixed_slot_bit(Float32)) + @test _tracked_mask_for_types(Float64, Int32) == (_fixed_slot_bit(Float64) | _fixed_slot_bit(Int32)) + + # All 8 fixed types + all_mask = _tracked_mask_for_types(Float64, Float32, Int64, Int32, ComplexF64, ComplexF32, Bool, Bit) + expected = UInt16(0) + for T in (Float64, Float32, Int64, Int32, ComplexF64, ComplexF32, Bool, Bit) + expected |= _fixed_slot_bit(T) + end + @test all_mask == expected + + # Non-fixed-slot types contribute UInt16(0) + @test _tracked_mask_for_types(UInt8) == UInt16(0) + @test _tracked_mask_for_types(Float64, UInt8) == _fixed_slot_bit(Float64) + + # Duplicate types → idempotent + @test _tracked_mask_for_types(Float64, Float64) == _fixed_slot_bit(Float64) + end + + @testset "_can_use_typed_path: truth table" begin + using AdaptiveArrayPools: _can_use_typed_path, _tracked_mask_for_types + + pool = AdaptiveArrayPool() + checkpoint!(pool) # depth = 2 + + # Case 1: no untracked at all → typed path OK + @test _can_use_typed_path(pool, _tracked_mask_for_types(Float64)) == true + + # Case 2: untracked Float64, tracked includes Float64 → subset → OK + pool._untracked_fixed_masks[2] = _fixed_slot_bit(Float64) + @test _can_use_typed_path(pool, _tracked_mask_for_types(Float64)) == true + + # Case 3: untracked Float64, tracked is Float32 only → NOT subset → full + @test _can_use_typed_path(pool, _tracked_mask_for_types(Float32)) == false + + # Case 4: untracked Float64|Float32, tracked Float64 only → partial → full + pool._untracked_fixed_masks[2] = _fixed_slot_bit(Float64) | _fixed_slot_bit(Float32) + @test _can_use_typed_path(pool, _tracked_mask_for_types(Float64)) == false + + # Case 5: untracked Float64|Float32, tracked Float64|Float32 → exact match → OK + @test _can_use_typed_path(pool, _tracked_mask_for_types(Float64, Float32)) == true + + # Case 6: untracked Float64 + has_others → always full + pool._untracked_fixed_masks[2] = _fixed_slot_bit(Float64) + pool._untracked_has_others[2] = true + @test _can_use_typed_path(pool, _tracked_mask_for_types(Float64)) == false + + # Case 7: no fixed untracked but has_others → always full + pool._untracked_fixed_masks[2] = UInt16(0) + pool._untracked_has_others[2] = true + @test _can_use_typed_path(pool, _tracked_mask_for_types(Float64)) == false + + rewind!(pool) + end + + # ================================================================== + # Phase 3: End-to-end runtime scenarios + # ================================================================== + + @testset "Scenario A: typed rewind preserved when untracked ⊆ tracked" begin + # Helper function acquires Float64 OUTSIDE @with_pool → untracked + function _scenario_a_helper!(pool) + acquire!(pool, Float64, 5) + end + + pool = AdaptiveArrayPool() + checkpoint!(pool) # outer scope + + # Macro scope uses Float64 → tracked set = {Float64} + @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 1.0 + _scenario_a_helper!(pool) # untracked Float64 → subset of tracked + end + + # Pool state should be correct after rewind + @test pool.float64.n_active == 0 # all rewound (outer had no acquires) + rewind!(pool) + end + + @testset "Scenario B: full rewind when untracked NOT ⊆ tracked" begin + # Helper acquires Float32 while @with_pool only tracks Float64 + function _scenario_b_helper!(pool) + acquire!(pool, Float32, 5) + end + + pool = AdaptiveArrayPool() + checkpoint!(pool) + + @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 1.0 + _scenario_b_helper!(pool) # untracked Float32 → NOT subset of {Float64} + end + + # Both types should be correctly rewound + @test pool.float64.n_active == 0 + @test pool.float32.n_active == 0 + rewind!(pool) + end + + @testset "Scenario C: others triggers full rewind" begin + # Helper acquires UInt8 (non-fixed-slot) → has_others = true + function _scenario_c_helper!(pool) + acquire!(pool, UInt8, 5) + end + + pool = AdaptiveArrayPool() + checkpoint!(pool) + + @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 1.0 + _scenario_c_helper!(pool) # untracked UInt8 → has_others → full + end + + @test pool.float64.n_active == 0 + @test get_typed_pool!(pool, UInt8).n_active == 0 + rewind!(pool) + end + + @testset "Scenario D: nested checkpoint with parent untracked ⊆ child tracked" begin + # Helper acquires Float64 OUTSIDE @with_pool → untracked + function _scenario_d_helper!(pool) + acquire!(pool, Float64, 3) + end + + # Inner scope as function to avoid Julia scoping conflict: + # nested `local pool` in same try-block soft scope causes UndefVarError + function _scenario_d_inner!() + @with_pool pool begin + v2 = acquire!(pool, Float64, 5) + v2 .= 2.0 + @test pool.float64.n_active >= 3 # v1 + helper + v2 + end + end + + pool = AdaptiveArrayPool() + + # Parent @with_pool uses Float64 + @with_pool pool begin + v1 = acquire!(pool, Float64, 10) + v1 .= 1.0 + _scenario_d_helper!(pool) # marks untracked Float64 in parent scope + + # Nested @with_pool also uses Float64 → can use typed checkpoint + _scenario_d_inner!() + + # After nested rewind, v1 and helper should still be active + @test v1[1] == 1.0 # v1 still valid + end + + @test pool.float64.n_active == 0 + end + + @testset "Scenario E: nested checkpoint with parent untracked has_others → full" begin + # Helper acquires UInt8 (non-fixed-slot) → has_others = true + function _scenario_e_helper!(pool) + acquire!(pool, UInt8, 3) + end + + # Inner scope as function (avoids Julia scoping conflict with same-name local) + function _scenario_e_inner!() + @with_pool pool begin + v2 = acquire!(pool, Float64, 5) + v2 .= 2.0 + end + end + + pool = AdaptiveArrayPool() + + @with_pool pool begin + v1 = acquire!(pool, Float64, 10) + v1 .= 1.0 + _scenario_e_helper!(pool) # marks has_others in parent scope + + # Nested @with_pool: parent had has_others → must use full checkpoint + _scenario_e_inner!() + + @test v1[1] == 1.0 # v1 still valid after nested rewind + end + + @test pool.float64.n_active == 0 + @test get_typed_pool!(pool, UInt8).n_active == 0 + end + + # ================================================================== + # Phase 4: Legacy _untracked_flags removal verification + # ================================================================== + @testset "Phase 4: _untracked_flags field removed from AdaptiveArrayPool" begin + # The legacy boolean _untracked_flags field has been replaced by + # bitmask-based tracking (_untracked_fixed_masks + _untracked_has_others). + # Verify it no longer exists as a struct field. + @test !(:_untracked_flags in fieldnames(AdaptiveArrayPool)) + + # Verify the bitmask fields ARE present (they are the replacement) + @test :_untracked_fixed_masks in fieldnames(AdaptiveArrayPool) + @test :_untracked_has_others in fieldnames(AdaptiveArrayPool) + end + + @testset "Phase 4: bitmask stacks have no stale state after lifecycle ops" begin + pool = AdaptiveArrayPool() + + # Initial sentinel state + @test pool._untracked_fixed_masks == [UInt16(0)] + @test pool._untracked_has_others == [false] + + # Checkpoint → mark → rewind cycle leaves no stale bits + checkpoint!(pool) + _mark_untracked!(pool, Float64) + @test pool._untracked_fixed_masks[2] == _fixed_slot_bit(Float64) + rewind!(pool) + @test pool._untracked_fixed_masks == [UInt16(0)] # back to sentinel + @test pool._untracked_has_others == [false] + + # Nested checkpoint → mark others → rewind cleans up + checkpoint!(pool) # depth 2 + checkpoint!(pool) # depth 3 + _mark_untracked!(pool, UInt8) # others at depth 3 + @test pool._untracked_has_others[3] == true + rewind!(pool) # back to depth 2 + @test length(pool._untracked_has_others) == 2 + @test pool._untracked_has_others[2] == false # depth 2 clean + rewind!(pool) # back to depth 1 + @test pool._untracked_fixed_masks == [UInt16(0)] + @test pool._untracked_has_others == [false] + + # reset! restores sentinel state after deep nesting + checkpoint!(pool) + checkpoint!(pool) + _mark_untracked!(pool, Float32) + _mark_untracked!(pool, Int64) + reset!(pool) + @test pool._untracked_fixed_masks == [UInt16(0)] + @test pool._untracked_has_others == [false] + @test pool._current_depth == 1 + + # empty! also restores sentinel state + checkpoint!(pool) + _mark_untracked!(pool, ComplexF64) + _mark_untracked!(pool, UInt16) + empty!(pool) + @test pool._untracked_fixed_masks == [UInt16(0)] + @test pool._untracked_has_others == [false] + @test pool._current_depth == 1 + end + end # State Management \ No newline at end of file