Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/DomainBuffers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,19 +135,19 @@ Get the set of items stored in `db` or `dbs[domain]`
"""
getset(b::DomainBuffers, domain) = getset(b[domain])

struct DomainBuffer{I,B,S,SDH<:SubDofHandler} <: AbstractDomainBuffer
struct DomainBuffer{I,B,SV<:StateVariables,SDH<:SubDofHandler} <: AbstractDomainBuffer
set::Vector{I}
itembuffer::B
states::StateVariables{S}
states::SV
sdh::SDH
end

struct ThreadedDomainBuffer{I,B,S,SDH<:SubDofHandler} <: AbstractDomainBuffer
struct ThreadedDomainBuffer{I,B,SV<:StateVariables,SDH<:SubDofHandler} <: AbstractDomainBuffer
chunks::Vector{Vector{Vector{I}}} # I=Int (cell), I=FacetIndex (facet), or
set::Vector{I} # I=NTuple{2,FacetIndex} (interface)
num_tasks::Int
itembuffer::TaskLocals{B,B} # cell, facet, or interface buffer
states::StateVariables{S}
states::SV
sdh::SDH
end
function ThreadedDomainBuffer(set, itembuffer::AbstractItemBuffer, states::StateVariables, sdh::SubDofHandler, colors_or_chunks=nothing; num_tasks = Threads.nthreads())
Expand Down
1 change: 1 addition & 0 deletions src/FerriteAssembly.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module FerriteAssembly
using Ferrite, ForwardDiff
using Ferrite.CollectionsOfViews: ArrayOfVectorViews
using ConstructionBase: setproperties

include("Multithreading/TaskLocals.jl") # Task-local storage model
Expand Down
16 changes: 8 additions & 8 deletions src/Utils/MaterialModelsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import MaterialModelsBase as MMB

"""
FerriteAssembly.element_routine!(
Ke, re, state::Vector{<:MMB.AbstractMaterialState}, ae,
Ke, re, state::AbstractVector{<:MMB.AbstractMaterialState}, ae,
m::MMB.AbstractMaterial, cv::AbstractCellValues, buffer)

Solve the weak form
Expand All @@ -16,13 +16,13 @@ where ``\\sigma`` is calculated with the `material_response` function from
Note that `create_cell_state` is already implemented for `<:AbstractMaterial`.
"""
function FerriteAssembly.element_routine!(
Ke, re, state::Vector{<:MMB.AbstractMaterialState},
Ke, re, state::AbstractVector{<:MMB.AbstractMaterialState},
ae, material::MMB.AbstractMaterial, cellvalues::AbstractCellValues, buffer)
return mechanical_element_routine!(MMB.get_tensorbase(material), Ke, re, state, ae, material, cellvalues, buffer)
end

function mechanical_element_routine!(::Type{<:SymmetricTensor{2}},
Ke, re, state::Vector{<:MMB.AbstractMaterialState},
Ke, re, state::AbstractVector{<:MMB.AbstractMaterialState},
ae, material::MMB.AbstractMaterial, cellvalues::AbstractCellValues, buffer)
cache = FerriteAssembly.get_user_cache(buffer)
Δt = FerriteAssembly.get_time_increment(buffer)
Expand All @@ -47,7 +47,7 @@ function mechanical_element_routine!(::Type{<:SymmetricTensor{2}},
end

function mechanical_element_routine!(::Type{<:Tensor{2}},
Ke, re, state::Vector{<:MMB.AbstractMaterialState},
Ke, re, state::AbstractVector{<:MMB.AbstractMaterialState},
ae, material::MMB.AbstractMaterial, cellvalues::AbstractCellValues, buffer)
cache = FerriteAssembly.get_user_cache(buffer)
Δt = FerriteAssembly.get_time_increment(buffer)
Expand All @@ -73,20 +73,20 @@ end

"""
FerriteAssembly.element_residual!(
re, state::Vector{<:MMB.AbstractMaterialState}, ae,
re, state::AbstractVector{<:MMB.AbstractMaterialState}, ae,
m::MMB.AbstractMaterial, cv::AbstractCellValues, buffer)

The `element_residual!` implementation corresponding to the `element_routine!` implementation
for a `MaterialModelsBase.AbstractMaterial`
"""
function FerriteAssembly.element_residual!(
re, state::Vector{<:MMB.AbstractMaterialState},
re, state::AbstractVector{<:MMB.AbstractMaterialState},
ae, material::MMB.AbstractMaterial, cellvalues::AbstractCellValues, buffer)
return mechanical_element_residual!(MMB.get_tensorbase(material), re, state, ae, material, cellvalues, buffer)
end

function mechanical_element_residual!(::Type{<:SymmetricTensor{2}},
re, state::Vector{<:MMB.AbstractMaterialState},
re, state::AbstractVector{<:MMB.AbstractMaterialState},
ae, material::MMB.AbstractMaterial, cellvalues::AbstractCellValues, buffer)
cache = FerriteAssembly.get_user_cache(buffer)
Δt = FerriteAssembly.get_time_increment(buffer)
Expand All @@ -105,7 +105,7 @@ function mechanical_element_residual!(::Type{<:SymmetricTensor{2}},
end

function mechanical_element_residual!(::Type{<:Tensor{2}},
re, state::Vector{<:MMB.AbstractMaterialState},
re, state::AbstractVector{<:MMB.AbstractMaterialState},
ae, material::MMB.AbstractMaterial, cellvalues::AbstractCellValues, buffer)
cache = FerriteAssembly.get_user_cache(buffer)
Δt = FerriteAssembly.get_time_increment(buffer)
Expand Down
2 changes: 0 additions & 2 deletions src/Workers/QuadPointEvaluator.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using Ferrite.CollectionsOfViews: ArrayOfVectorViews

"""
QuadPointEvaluator{VT}(db::Union{DomainBuffer, DomainBuffers}, f::Function)

Expand Down
9 changes: 4 additions & 5 deletions src/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ function setup_domainbuffer(domain::DomainSpec; threading=Val(false), kwargs...)
end

create_states(domain::DomainSpec{Int}, a) = create_states(domain.sdh, domain.material, domain.fe_values, a, domain.set, create_dofrange(domain.sdh))
create_states(::DomainSpec{FacetIndex}, ::Any) = Dict{Int,Nothing}()
create_states(::DomainSpec{FacetIndex}, ::Any) = (s = Dict{Int,Nothing}(); StateVariables(Int[], s, s))

function setup_itembuffer(adb, domain::DomainSpec{FacetIndex}, args...)
dofrange = create_dofrange(domain.sdh)
Expand All @@ -120,10 +120,9 @@ function setup_itembuffer(adb, domain::DomainSpec{Int}, states)
end

function _setup_domainbuffer(threaded, domain; a=nothing, autodiffbuffer=Val(false), kwargs...)
new_states = create_states(domain, a)
old_states = create_states(domain, a)
itembuffer = setup_itembuffer(autodiffbuffer, domain, new_states)
return _setup_domainbuffer(threaded, domain.set, itembuffer, StateVariables(old_states, new_states), domain.sdh, domain.colors_or_chunks; kwargs...)
statevars = create_states(domain, a)
itembuffer = setup_itembuffer(autodiffbuffer, domain, statevars.old)
return _setup_domainbuffer(threaded, domain.set, itembuffer, statevars, domain.sdh, domain.colors_or_chunks; kwargs...)
end

# Type-unstable switch
Expand Down
66 changes: 56 additions & 10 deletions src/states.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,57 @@
# Minimal interface for a vector, storage format will probably be updated later.
mutable struct StateVector{SV}
vals::Dict{Int, SV}
mutable struct StateVector{SV, VV <: AbstractVector{SV}}
vals::VV
inds::Vector{Int} # Can as an optimization be shared between all `StateVectors` (also across domains)
end
Base.getindex(s::StateVector, cellnum::Int) = s.vals[cellnum]
Base.setindex!(s::StateVector, v, cellnum::Int) = setindex!(s.vals, v, cellnum)
Base.getindex(s::StateVector, cellnum::Int) = s.vals[s.inds[cellnum]]
Base.setindex!(s::StateVector, v, cellnum::Int) = setindex!(s.vals, v, s.inds[cellnum])
Base.:(==)(a::StateVector, b::StateVector) = (a.vals == b.vals)
function Base.iterate(s::StateVector, i::Int = 1)
i > length(s.vals) && return nothing
return s.vals[i], i + 1
end

struct StateVariables{SV}
old::StateVector{SV} # Rule: Referenced during assembly, not changed (ever)
new::StateVector{SV} # Rule: Updated during assembly, not referenced (before updated)
struct StateVariables{SV, VV}
old::StateVector{SV, VV} # Rule: Referenced during assembly, not changed (ever)
new::StateVector{SV, VV} # Rule: Updated during assembly, not referenced (before updated)
end
function StateVariables(inds::Vector{Int}, old::Dict{K, SV}, new::Dict{K, SV}) where {K, T, SV <: Vector{T}}
# if eltype(old) isa Vector => ArrayOfVectorViews
# else => Vector{eltype(old)}
num = length(old)
num_total = sum(length, values(old))
old_data = Vector{T}(undef, num_total)
new_data = Vector{T}(undef, num_total)
indices = Vector{Int}(undef, num + 1)
i = 1
j = 1
for key in sort(collect(keys(old)))
indices[i] = j
inds[key] = i
for (oldval, newval) in zip(old[key], new[key])
old_data[j] = oldval
new_data[j] = newval
j += 1
end
i += 1
end
indices[i] = j
oldvals = ArrayOfVectorViews(indices, old_data, LinearIndices((num,)))
newvals = ArrayOfVectorViews(indices, new_data, LinearIndices((num,)))
StateVariables(StateVector(oldvals, inds), StateVector(newvals, inds))
end
function StateVariables(inds::Vector{Int}, old::Dict{K, SV}, new::Dict{K, SV}) where {K, SV}
oldvals = Vector{SV}(undef, length(old))
newvals = Vector{SV}(undef, length(new))
i = 1
for key in sort(collect(keys(old)))
oldvals[i] = old[key]
newvals[i] = new[key]
inds[key] = i
i += 1
end
StateVariables(StateVector(oldvals, inds), StateVector(newvals, inds))
end
StateVariables(old::Dict, new::Dict) = StateVariables(StateVector(old), StateVector(new))

function update_states!(sv::StateVariables)
tmp = sv.old.vals
Expand Down Expand Up @@ -66,7 +107,12 @@ define the [`create_cell_state`](@ref) function for their `material` (and corres
"""
function create_states(sdh::SubDofHandler, material, cellvalues, a, cellset, dofrange)
ae = zeros(ndofs_per_cell(sdh))
coords = getcoordinates(_getgrid(sdh), first(cellset))
grid = _getgrid(sdh)
coords = getcoordinates(grid, first(cellset))
dofs = zeros(Int, ndofs_per_cell(sdh))
return Dict(cellnr => _create_cell_state(coords, dofs, material, cellvalues, a, ae, dofrange, sdh, cellnr) for cellnr in cellset)
# Could make construction more efficient by doing this when creating the ArrayOfVectorViews
old = Dict(cellnr => _create_cell_state(coords, dofs, material, cellvalues, a, ae, dofrange, sdh, cellnr) for cellnr in cellset)
new = Dict(cellnr => _create_cell_state(coords, dofs, material, cellvalues, a, ae, dofrange, sdh, cellnr) for cellnr in cellset)
inds = zeros(Int, getncells(grid)) # Could be moved out and shared between all domains...
return StateVariables(inds, old, new)
end
10 changes: 6 additions & 4 deletions test/quadpoint_evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@
foo(::QEMat{4}, u, ∇u, qp_state) = 3 * qp_state[2]
qe = QuadPointEvaluator{Float64}(db, foo)
work!(qe, db)
for (i, s) in states["left"].vals # TODO: Using internals here
@test qe.data[i] ≈ 3 * s
for cellnr in FerriteAssembly.getset(db, "left")
s = states["left"][cellnr]
@test qe.data[cellnr] ≈ 3 * s
end
for (i, s) in states["right"].vals # TODO: Using internals here
@test qe.data[i] ≈ 3 * last.(s)
for cellnr in FerriteAssembly.getset(db, "right")
s = states["right"][cellnr]
@test qe.data[cellnr] ≈ 3 * last.(s)
@test all(first.(s) .≥ 0)
@test all(last.(s) .≤ 0)
end
Expand Down
10 changes: 8 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ import MaterialModelsBase as MMB
import MechanicalMaterialModels as MMM
using Logging

# Sometimes calling `@allocated f(args...)` allocates when called in global scope,
# calling inside a function using local variables solves this
function get_allocations(f::F, args::Vararg{N}) where {F, N}
@allocated f(args...)
end

include("replacements.jl")
include("states.jl")
include("threading_utils.jl")
Expand Down Expand Up @@ -45,8 +51,8 @@ include("errors.jl")
@test FerriteAssembly.get_dofhandler(buffer_threaded) === dh
@test FerriteAssembly.get_dofhandler(buffers) === dh

@test isa(FerriteAssembly.get_state(buffer, 1), Vector{Nothing})
@test isa(FerriteAssembly.get_old_state(buffer, 1), Vector{Nothing})
@test isa(FerriteAssembly.get_state(buffer, 1), AbstractVector{Nothing})
@test isa(FerriteAssembly.get_old_state(buffer, 1), AbstractVector{Nothing})
@test length(FerriteAssembly.get_state(buffer, 1)) == getnquadpoints(cv)
@test length(FerriteAssembly.get_old_state(buffer, 1)) == getnquadpoints(cv)

Expand Down
28 changes: 13 additions & 15 deletions test/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module TestStateModule
quadnr::Int
end
FerriteAssembly.create_cell_state(::MatA, cv, args...) = [StateA(-1, 0) for _ in 1:getnquadpoints(cv)]
function FerriteAssembly.element_residual!(re, states::Vector{StateA}, ae, ::MatA, cv, buffer)
function FerriteAssembly.element_residual!(re, states::AbstractVector{StateA}, ae, ::MatA, cv, buffer)
cellnr = cellid(buffer)
for i in 1:getnquadpoints(cv)
states[i] = StateA(cellnr, i)
Expand All @@ -33,7 +33,7 @@ module TestStateModule
counter::Int
end
FerriteAssembly.create_cell_state(::MatC, cv, args...) = [StateC(0) for _ in 1:getnquadpoints(cv)]
function FerriteAssembly.element_residual!(re, states::Vector{StateC}, ae, ::MatC, cv, buffer)
function FerriteAssembly.element_residual!(re, states::AbstractVector{StateC}, ae, ::MatC, cv, buffer)
old_states = FerriteAssembly.get_old_state(buffer)
for i in 1:getnquadpoints(cv)
states[i] = StateC(old_states[i].counter + 1)
Expand Down Expand Up @@ -69,7 +69,7 @@ end
buffer = setup_domainbuffer(DomainSpec(dh, MatA(), cv))
states = FerriteAssembly.get_state(buffer)
old_states = FerriteAssembly.get_old_state(buffer)
@test isa(old_states, FerriteAssembly.StateVector{Vector{StateA}})
@test isa(old_states, FerriteAssembly.StateVector{<:AbstractVector{StateA}})
@test old_states == states
@test old_states[1] == [StateA(-1, 0) for _ in 1:getnquadpoints(cv)]
work!(r_assembler, buffer)
Expand All @@ -83,8 +83,8 @@ end
@test old_states == states_dc # Correctly updated values
states[1][1] = StateA(0,0)
@test old_states[1][1] == StateA(1,1) # But not aliased
allocs = @allocated update_states!(container)
@test allocs == 0 # Vector{T} where isbitstype(T) should not allocate (MatA fulfills this)
# Vector{T} where isbitstype(T) should not allocate (MatA fulfills this)
@test get_allocations(update_states!, container) == 0
end

# MatB (not bitstype)
Expand All @@ -111,15 +111,15 @@ end
x_values = [spatial_coordinate(cv, i, coords) for i in 1:getnquadpoints(cv)]
states[cellnr] = StateB(0, -x_values)
@test old_states[cellnr] == StateB(cellnr, x_values) # But not aliased
allocs = @allocated update_states!(buffer)
@test allocs == 0 # Vector{T} where !isbitstype(T) should no longer allocate
# Vector{T} where !isbitstype(T) should no longer allocate
@test get_allocations(update_states!, buffer) == 0

# MatC (accumulation), using threading as well
colors = create_coloring(grid)
buffer = setup_domainbuffer(DomainSpec(dh, MatC(), cv; colors=colors))
states = FerriteAssembly.get_state(buffer)
old_states = FerriteAssembly.get_old_state(buffer)
@test isa(old_states, FerriteAssembly.StateVector{Vector{StateC}})
@test isa(old_states, FerriteAssembly.StateVector{<:AbstractVector{StateC}})
@test old_states == states
@test old_states[1][1] == StateC(0)
work!(kr_assembler, buffer)
Expand All @@ -134,8 +134,8 @@ end
for cellnr in 1:getncells(grid)
@test states[cellnr][2] == StateC(2) # Check that all are updated
end
allocs = @allocated update_states!(buffer)
@test allocs == 0 # Vector{T} where isbitstype(T) should not allocate (MatC fulfills this)
# Vector{T} where isbitstype(T) should not allocate (MatC fulfills this)
@test get_allocations(update_states!, buffer) == 0
end
end

Expand All @@ -145,15 +145,13 @@ end
# Smoke-test of update_states! for nothing states (and check no allocations)
cv = CellValues(QuadratureRule{RefTriangle}(2), ip)
buffer = setup_domainbuffer(DomainSpec(dh, nothing, cv))
@test isa(FerriteAssembly.get_state(buffer), FerriteAssembly.StateVector{Vector{Nothing}})
@test isa(FerriteAssembly.get_state(buffer), FerriteAssembly.StateVector{<:AbstractVector{Nothing}})
update_states!(buffer) # Compile
allocs = @allocated update_states!(buffer)
@test allocs == 0
@test get_allocations(update_states!, buffer) == 0

gda = DomainSpec(dh, nothing, cv; set=1:getncells(dh.grid)÷2)
gdb = DomainSpec(dh, nothing, cv; set=setdiff!(Set(1:getncells(dh.grid)), gda.set))
buffers = setup_domainbuffers(Dict("a"=>gda, "b"=>gdb))
update_states!(buffers) # Compile
allocs = @allocated update_states!(buffers)
@test allocs == 0
@test get_allocations(update_states!, buffer) == 0
end
Loading