diff --git a/ext/ITensorNetworksEinExprsExt/src/ITensorNetworksEinExprsExt.jl b/ext/ITensorNetworksEinExprsExt/src/ITensorNetworksEinExprsExt.jl index 2c37b2b6..5767be8e 100644 --- a/ext/ITensorNetworksEinExprsExt/src/ITensorNetworksEinExprsExt.jl +++ b/ext/ITensorNetworksEinExprsExt/src/ITensorNetworksEinExprsExt.jl @@ -10,30 +10,37 @@ using ITensorNetworks: contraction_sequence using EinExprs: EinExprs, EinExpr, einexpr, SizedEinExpr -function to_einexpr(tn::ITensorNetwork) +function to_einexpr(tn::AbstractITensorNetwork) IndexType = Any - VertexType = vertextype(tn) tensor_exprs = EinExpr{IndexType}[] - tensor_inds_to_vertex = Dict{Set{IndexType},VertexType}() inds_dims = Dict{IndexType,Int}() for v in vertices(tn) tensor_v = tn[v] inds_v = collect(inds(tensor_v)) push!(tensor_exprs, EinExpr{IndexType}(; head=inds_v)) - tensor_inds_to_vertex[Set(inds_v)] = key merge!(inds_dims, Dict(inds_v .=> size(tensor_v))) end externalinds_tn = collect(externalinds(tn)) - expr = SizedEinExpr(sum(tensor_exprs; skip=externalinds_tn), inds_dims) + return SizedEinExpr(sum(tensor_exprs; skip=externalinds_tn), inds_dims) +end + +function tensor_inds_to_vertex(tn::AbstractITensorNetwork) + mapping = Dict{Set{IndexType},VertexType}() + + for v in vertices(tn) + tensor_v = tn[v] + inds_v = collect(inds(tensor_v)) + mapping[Set(inds_v)] = v + end - return expr, tensor_inds_to_vertex + return mapping end function EinExprs.einexpr(tn::ITensorNetwork; optimizer::EinExprs.Optimizer) - expr, _ = to_einexpr(tn) + expr = to_einexpr(tn) return einexpr(optimizer, expr) end @@ -46,8 +53,8 @@ end function ITensorNetworks.contraction_sequence( ::Algorithm"einexpr", tn::ITensorNetwork{T}; optimizer=EinExprs.Exhaustive() ) - expr, tensor_inds_to_vertex = to_einexpr(tn) - return to_contraction_sequence(expr, tensor_inds_to_vertex) + expr = einexpr(tn; optimizer) + return to_contraction_sequence(expr, tensor_inds_to_vertex(tn)) end function to_contraction_sequence(expr, tensor_inds_to_vertex)