From ea4e350a99fbadd57f4869508ec31cbb90b1f11b Mon Sep 17 00:00:00 2001 From: JordiManyer Date: Fri, 13 Dec 2024 09:29:17 +1100 Subject: [PATCH] Added MappedDofBasis --- src/ReferenceFEs/Pullbacks.jl | 98 +++++++++++++++++++++++++++++++++-- 1 file changed, 93 insertions(+), 5 deletions(-) diff --git a/src/ReferenceFEs/Pullbacks.jl b/src/ReferenceFEs/Pullbacks.jl index f9f8ad81f..5e0a47a43 100644 --- a/src/ReferenceFEs/Pullbacks.jl +++ b/src/ReferenceFEs/Pullbacks.jl @@ -57,6 +57,18 @@ Represents the inverse of a pushforward map F*, defined as where - V̂ is a function space on the reference cell K̂ and - V is a function space on the physical cell K. + +# Note: + +Given a pushforward F*, we provide a default implementation for the inverse (F*)^-1 that +is not optimal (and I think wrong for nonlinear geometries???). + +For better performance, the user should overload the following methods for each +specific pushforward type: + +- `return_cache(k::InversePushforward{PF}, v_phys::Number, args...)` +- `evaluate!(cache, k::InversePushforward{PF}, v_phys::Number, args...)` + """ struct InversePushforward{PF} <: Map pushforward::PF @@ -98,6 +110,58 @@ function evaluate!( Broadcasting(Operation(k))(f_phys,args...) end +# MappedDofBasis + +""" + struct MappedDofBasis{T<:Dof,MT,BT} <: AbstractVector{T} + F :: MT + σ :: BT + args + end + +Represents η = σ∘F, evaluated as η(φ) = σ(F(φ,args...)) + + - σ : V* -> R is a dof basis + - F : W -> V is a map between function spaces + +Intended combinations would be: + +- σ : V* -> R dof basis in the physical domain and F* : V̂ -> V is a pushforward map. +- ̂σ : V̂* -> R dof basis in the reference domain and (F*)^-1 : V -> V̂ is an inverse pushforward map. + +""" +struct MappedDofBasis{T<:Dof,MT,BT,A} <: AbstractVector{T} + F :: MT + dofs :: BT + args :: A + + function MappedDofBasis(F::Map, dofs::AbstractVector{<:Dof}, args...) + T = eltype(dofs) + MT = typeof(F) + BT = typeof(dofs) + A = typeof(args) + new{T,MT,BT,A}(F,dofs,args) + end +end + +Base.size(b::MappedDofBasis) = size(b.dofs) + +# Technically wrong, but the individual dofs are fake anyway +Base.getindex(b::MappedDofBasis, i) = getindex(b.dofs,i) + +function Arrays.return_cache(b::MappedDofBasis, fields) + f_cache = return_cache(b.F,fields,b.args...) + ffields = evaluate!(f_cache,b.F,fields,b.args...) + dofs_cache = return_cache(b.dofs,ffields) + return f_cache, dofs_cache +end + +function Arrays.evaluate!(cache, b::MappedDofBasis, fields) + f_cache, dofs_cache = cache + ffields = evaluate!(f_cache,b.F,fields,b.args...) + evaluate!(dofs_cache,b.dofs,ffields) +end + # Pullback """ @@ -131,8 +195,7 @@ end function evaluate!( cache, k::Pullback, σ_phys::AbstractVector{<:Dof}, args... ) - pf(f_ref) = evaluate(k.pushforward,f_ref,args...) # TODO: Can we avoid this? - return Arrays.OperationMap(σ_phys,pf) + return MappedDofBasis(k.pushforward,σ_phys,args...) end # InversePullback @@ -171,8 +234,7 @@ end function evaluate!( cache, k::InversePullback, σ_ref::AbstractVector{<:Dof}, args... ) - ipf(f_phys) = evaluate(inverse_map(k.pushforward),f_phys,args...) # TODO: Can we avoid this? - return Arrays.OperationMap(σ_ref,ipf) + return MappedDofBasis(inverse_map(k.pushforward),σ_ref,args...) end # ContraVariantPiolaMap @@ -184,7 +246,21 @@ function evaluate!( ) sign = (-1)^sign_flip idetJ = 1. / meas(Jt) - return (sign*v_ref)⋅(idetJ*Jt) + return sign * idetJ * (v_ref⋅Jt) +end + +function return_cache( + ::InversePushforward{ContraVariantPiolaMap}, v_phys::Number, Jt::Number, sign_flip::Bool +) + nothing +end + +function evaluate!( + cache, ::InversePushforward{ContraVariantPiolaMap}, v_phys::Number, Jt::Number, sign_flip::Bool +) + sign = (-1)^sign_flip + detJ = meas(Jt) + return sign * detJ * (inv(Jt)⋅v_phys) end # CoVariantPiolaMap @@ -197,3 +273,15 @@ function evaluate!( # we right-multiply to compute the gradient correctly return v_ref⋅transpose(inv(Jt)) end + +function return_cache( + ::InversePushforward{CoVariantPiolaMap}, v_phys::Number, Jt::Number +) + return nothing +end + +function evaluate!( + cache, ::InversePushforward{CoVariantPiolaMap}, v_phys::Number, Jt::Number +) + return transpose(Jt)⋅v_phys +end