diff --git a/src/cell.jl b/src/cell.jl index 2cf1b0c..78be534 100644 --- a/src/cell.jl +++ b/src/cell.jl @@ -94,8 +94,8 @@ end end end -function allcells(cell::Cell) - Channel() do c +function allcells(cell::Cell{Data, N, T, L}) where {Data, N, T, L} + Channel{Cell{Data, N, T, L}}() do c queue = [cell] while !isempty(queue) current = pop!(queue) @@ -107,8 +107,8 @@ function allcells(cell::Cell) end end -function allleaves(cell::Cell) - Channel() do c +function allleaves(cell::Cell{Data, N, T, L}) where {Data, N, T, L} + Channel{Cell{Data, N, T, L}}() do c for child in allcells(cell) if isleaf(child) put!(c, child) @@ -117,8 +117,8 @@ function allleaves(cell::Cell) end end -function allparents(cell::Cell) - Channel() do c +function allparents(cell::Cell{Data, N, T, L}) where {Data, N, T, L} + Channel{Cell{Data, N, T, L}}() do c queue = [cell] while !isempty(queue) current = pop!(queue) diff --git a/test/cell.jl b/test/cell.jl index 9a7a2d5..89dd846 100644 --- a/test/cell.jl +++ b/test/cell.jl @@ -63,3 +63,41 @@ end split!(root) @test length(root.children) == 8 end + +@testset "allcells, alleaves, allparents functions" begin + struct MyRefinery <: RegionTrees.AbstractRefinery + tolerance::Float64 + end + + function RegionTrees.needs_refinement(r::MyRefinery, cell) + maximum(cell.boundary.widths) > r.tolerance + end + + function RegionTrees.refine_data(r::MyRefinery, cell::Cell, indices) + boundary = child_boundary(cell, indices) + "child with widths: $(boundary.widths)" + end + + r = MyRefinery(0.5) + root = Cell(SVector(0., 0), SVector(1., 1), "root") + adaptivesampling!(root, r) + + ac = allcells(root) + @test typeof(ac) == Channel{Cell{String, 2, Float64, 4}} + acv = collect(ac) + @test typeof(acv) == Vector{Cell{String, 2, Float64, 4}} + @test size(acv, 1) == 5 + + al = allleaves(root) + @test typeof(al) == Channel{Cell{String, 2, Float64, 4}} + alv = collect(al) + @test typeof(alv) == Vector{Cell{String, 2, Float64, 4}} + @test size(alv, 1) == 4 + + leaf = alv[1] + ap = allparents(leaf) + @test typeof(ap) == Channel{Cell{String, 2, Float64, 4}} + apv = collect(ap) + @test typeof(apv) == Vector{Cell{String, 2, Float64, 4}} + @test size(apv, 1) == 1 +end