From 35fe94f406b232475cb1930860c83c3b40e25bcd Mon Sep 17 00:00:00 2001 From: Joey Date: Fri, 10 Jan 2025 10:42:58 -0500 Subject: [PATCH] Generic naming support --- src/caches/abstractbeliefpropagationcache.jl | 8 +- src/caches/boundarympscache.jl | 169 ++++++++----------- src/caches/boundarympscacheutils.jl | 4 +- test/test_boundarymps.jl | 4 +- 4 files changed, 84 insertions(+), 101 deletions(-) diff --git a/src/caches/abstractbeliefpropagationcache.jl b/src/caches/abstractbeliefpropagationcache.jl index 36700fcf..d010043d 100644 --- a/src/caches/abstractbeliefpropagationcache.jl +++ b/src/caches/abstractbeliefpropagationcache.jl @@ -190,14 +190,16 @@ Compute message tensor as product of incoming mts and local state function updated_message( bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; - message_update=default_message_update, - message_update_kwargs=(;), + message_update_function=default_message_update, + message_update_function_kwargs=(;), ) vertex = src(edge) incoming_ms = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)]) state = factor(bpc, vertex) - return message_update(ITensor[incoming_ms; state]; message_update_kwargs...) + return message_update_function( + ITensor[incoming_ms; state]; message_update_function_kwargs... + ) end function update( diff --git a/src/caches/boundarympscache.jl b/src/caches/boundarympscache.jl index c4ab049e..fd69c5f9 100644 --- a/src/caches/boundarympscache.jl +++ b/src/caches/boundarympscache.jl @@ -9,16 +9,18 @@ using SplitApplyCombine: group using ITensors: commoninds, random_itensor using LinearAlgebra: pinv -struct BoundaryMPSCache{BPC,PG} <: AbstractBeliefPropagationCache +struct BoundaryMPSCache{BPC,PG,PP} <: AbstractBeliefPropagationCache bp_cache::BPC partitionedplanargraph::PG + partitionpair_partitionedges::PP maximum_virtual_dimension::Int64 end bp_cache(bmpsc::BoundaryMPSCache) = bmpsc.bp_cache partitionedplanargraph(bmpsc::BoundaryMPSCache) = bmpsc.partitionedplanargraph -maximum_virtual_dimension(bmpsc::BoundaryMPSCache) = bmpsc.maximum_virtual_dimension ppg(bmpsc) = partitionedplanargraph(bmpsc) +maximum_virtual_dimension(bmpsc::BoundaryMPSCache) = bmpsc.maximum_virtual_dimension +partitionpair_partitionedges(bmpsc::BoundaryMPSCache) = bmpsc.partitionpair_partitionedges planargraph(bmpsc::BoundaryMPSCache) = unpartitioned_graph(partitionedplanargraph(bmpsc)) function partitioned_tensornetwork(bmpsc::BoundaryMPSCache) @@ -38,16 +40,23 @@ end function default_message_update_kwargs(alg::Algorithm"orthogonal", bmpsc::BoundaryMPSCache) return (; niters=50, tolerance=1e-10) end -default_message_update_kwargs(alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache) = (;) +function default_message_update_kwargs( + alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache +) + return (; niters=3, tolerance=nothing) +end function Base.copy(bmpsc::BoundaryMPSCache) return BoundaryMPSCache( - copy(bp_cache(bmpsc)), copy(ppg(bmpsc)), maximum_virtual_dimension(bmpsc) + copy(bp_cache(bmpsc)), + copy(ppg(bmpsc)), + copy(partitionpair_partitionedges(bmpsc)), + maximum_virtual_dimension(bmpsc), ) end -function default_message(bmpsc::BoundaryMPSCache, args...; kwargs...) - return default_message(bp_cache(bmpsc), args...; kwargs...) +function default_message(bmpsc::BoundaryMPSCache, pe::PartitionEdge; kwargs...) + return default_message(bp_cache(bmpsc), pe::PartitionEdge; kwargs...) end function virtual_index_dimension( @@ -56,27 +65,20 @@ function virtual_index_dimension( pes = planargraph_partitionpair_partitionedges( bmpsc, planargraph_partitionpair(bmpsc, pe1) ) + if findfirst(x -> x == pe1, pes) > findfirst(x -> x == pe2, pes) - inds_above = reduce( - vcat, [linkinds(bmpsc, pe) for pe in partitionedges_above(bmpsc, pe2)] - ) - inds_below = reduce( - vcat, [linkinds(bmpsc, pe) for pe in partitionedges_below(bmpsc, pe1)] - ) + lower_pe, upper_pe = pe2, pe1 else - inds_above = reduce( - vcat, [linkinds(bmpsc, pe) for pe in partitionedges_above(bmpsc, pe1)] - ) - inds_below = reduce( - vcat, [linkinds(bmpsc, pe) for pe in partitionedges_below(bmpsc, pe2)] - ) + lower_pe, upper_pe = pe1, pe2 end + inds_above = reduce(vcat, linkinds.((bmpsc,), partitionedges_above(bmpsc, lower_pe))) + inds_below = reduce(vcat, linkinds.((bmpsc,), partitionedges_below(bmpsc, upper_pe))) return minimum(( prod(dim.(inds_above)), prod(dim.(inds_below)), maximum_virtual_dimension(bmpsc) )) end -function planargraph_vertices(bmpsc::BoundaryMPSCache, partition::Int64) +function planargraph_vertices(bmpsc::BoundaryMPSCache, partition) return vertices(ppg(bmpsc), PartitionVertex(partition)) end function planargraph_partition(bmpsc::BoundaryMPSCache, vertex) @@ -88,17 +90,26 @@ end function planargraph_partitionpair(bmpsc::BoundaryMPSCache, pe::PartitionEdge) return pair(partitionedge(ppg(bmpsc), parent(pe))) end +function planargraph_partitionpair_partitionedges( + bmpsc::BoundaryMPSCache, partition_pair::Pair +) + return partitionpair_partitionedges(bmpsc)[partition_pair] +end function BoundaryMPSCache( bpc::BeliefPropagationCache; grouping_function::Function=v -> first(v), + group_sorting_function::Function=v -> last(v), message_rank::Int64=1, ) bpc = insert_pseudo_planar_edges(bpc; grouping_function) planar_graph = partitioned_graph(bpc) vertex_groups = group(grouping_function, collect(vertices(planar_graph))) + vertex_groups = map(x -> sort(x; by=group_sorting_function), vertex_groups) ppg = PartitionedGraph(planar_graph, vertex_groups) - bmpsc = BoundaryMPSCache(bpc, ppg, message_rank) + partitionpairs = vcat(pair.(partitionedges(ppg)), reverse.(pair.(partitionedges(ppg)))) + pp_pe = Dictionary(partitionpairs, sorted_partitionedges.((ppg,), partitionpairs)) + bmpsc = BoundaryMPSCache(bpc, ppg, pp_pe, message_rank) return set_interpartition_messages(bmpsc) end @@ -106,26 +117,20 @@ function BoundaryMPSCache(tn::AbstractITensorNetwork, args...; kwargs...) return BoundaryMPSCache(BeliefPropagationCache(tn, args...); kwargs...) end -#Get all partitionedges within a column/row, ordered top to bottom -function planargraph_partitionedges(bmpsc::BoundaryMPSCache, partition::Int64) - vs = sort(planargraph_vertices(bmpsc, partition)) - return PartitionEdge.([vs[i] => vs[i + 1] for i in 1:(length(vs) - 1)]) -end - -#Get all partitionedges between the pair of neighboring partitions, sorted top to bottom -#TODO: Bring in line with NamedGraphs change -function planargraph_partitionpair_partitionedges( - bmpsc::BoundaryMPSCache, partitionpair::Pair -) - pg = planargraph(bmpsc) - src_vs, dst_vs = planargraph_vertices(bmpsc, first(partitionpair)), - planargraph_vertices(bmpsc, last(partitionpair)) - es = filter( - x -> !isempty(last(x)), - [src_v => intersect(neighbors(pg, src_v), dst_vs) for src_v in src_vs], +#Get all partitionedges between the pair of neighboring partitions, sorted +#by the position of the source in the pg +function sorted_partitionedges(pg::PartitionedGraph, partitionpair::Pair) + src_vs, dst_vs = vertices(pg, PartitionVertex(first(partitionpair))), + vertices(pg, PartitionVertex(last(partitionpair))) + es = reduce( + vcat, + [ + [src_v => dst_v for dst_v in intersect(neighbors(pg, src_v), dst_vs)] for + src_v in src_vs + ], ) - es = map(x -> first(x) => only(last(x)), es) - return sort(PartitionEdge.(NamedEdge.(es)); by=x -> src(parent(x))) + es = sort(NamedEdge.(es); by=x -> findfirst(isequal(src(x)), src_vs)) + return PartitionEdge.(es) end #Functions to get the parellel partitionedges sitting above and below a partitionedge @@ -159,8 +164,9 @@ end #Get the sequence of pairs partitionedges that need to be updated to move the MPS gauge from pe1 to pe2 function mps_gauge_update_sequence( - bmpsc::BoundaryMPSCache, pe1::PartitionEdge, pe2::PartitionEdge + bmpsc::BoundaryMPSCache, pe1::Union{Nothing,PartitionEdge}, pe2::PartitionEdge ) + isnothing(pe1) && return mps_gauge_update_sequence(bmpsc, pe2) ppgpe1, ppgpe2 = planargraph_partitionpair(bmpsc, pe1), planargraph_partitionpair(bmpsc, pe2) @assert ppgpe1 == ppgpe2 @@ -225,15 +231,18 @@ function switch_messages(bmpsc::BoundaryMPSCache, partitionpair::Pair) end #Update all messages tensors within a partition -function partition_update(bmpsc::BoundaryMPSCache, partition::Int64) - vs = sort(planargraph_vertices(bmpsc, partition)) - bmpsc = partition_update(bmpsc, first(vs)) - bmpsc = partition_update(bmpsc, last(vs)) +function partition_update(bmpsc::BoundaryMPSCache, partition) + vs = planargraph_vertices(bmpsc, partition) + bmpsc = partition_update(bmpsc, first(vs), last(vs)) + bmpsc = partition_update(bmpsc, last(vs), first(vs)) return bmpsc end function partition_update_sequence(bmpsc::BoundaryMPSCache, v1, v2) - return PartitionEdge.(a_star(ppg(bmpsc), v1, v2)) + isnothing(v1) && return partition_update_sequence(bmpsc, v2) + pv = planargraph_partition(bmpsc, v1) + g = subgraph(unpartitioned_graph(ppg(bmpsc)), planargraph_vertices(bmpsc, pv)) + return PartitionEdge.(a_star(g, v1, v2)) end function partition_update_sequence(bmpsc::BoundaryMPSCache, v) pv = planargraph_partition(bmpsc, v) @@ -247,7 +256,7 @@ function partition_update(bmpsc::BoundaryMPSCache, args...) Algorithm("SimpleBP"), bmpsc, partition_update_sequence(bmpsc, args...); - message_update_kwargs=(; normalize=false), + message_update_function_kwargs=(; normalize=false), ) end @@ -285,8 +294,8 @@ function gauge_step( m2, m2r = only(message(bmpsc, pe2)), only(message(bmpsc, reverse(pe2))) top_cind, bottom_cind = commonind(m1, m2), commonind(m1r, m2r) m1_siteinds, m2_siteinds = commoninds(m1, m1r), commoninds(m2, m2r) - top_ncind, bottom_ncind = setdiff(inds(m1), [m1_siteinds; top_cind]), - setdiff(inds(m1r), [m1_siteinds; bottom_cind]) + top_ncind = setdiff(inds(m1), [m1_siteinds; top_cind]) + bottom_ncind = setdiff(inds(m1r), [m1_siteinds; bottom_cind]) E = if isempty(top_ncind) m1 * m1r @@ -298,16 +307,14 @@ function gauge_step( S_sqrtinv = map_diag(x -> pinv(sqrt(x)), S) S_sqrt = map_diag(x -> sqrt(x), S) - set!(ms, pe1, ITensor[replaceind((m1 * dag(V)) * S_sqrtinv, commonind(U, S), top_cind)]) - set!( - ms, - reverse(pe1), - ITensor[replaceind((m1r * dag(U)) * S_sqrtinv, commonind(V, S), bottom_cind)], - ) - set!(ms, pe2, ITensor[replaceind((m2 * V) * S_sqrt, commonind(U, S), top_cind)]) - set!( - ms, reverse(pe2), ITensor[replaceind((m2r * U) * S_sqrt, commonind(V, S), bottom_cind)] - ) + m1 = replaceind((m1 * dag(V)) * S_sqrtinv, commonind(U, S), top_cind) + m1r = replaceind((m1r * dag(U)) * S_sqrtinv, commonind(V, S), bottom_cind) + m2 = replaceind((m2 * V) * S_sqrt, commonind(U, S), top_cind) + m2r = replaceind((m2r * U) * S_sqrt, commonind(V, S), bottom_cind) + set!(ms, pe1, ITensor[m1]) + set!(ms, reverse(pe1), ITensor[m1r]) + set!(ms, pe2, ITensor[m2]) + set!(ms, reverse(pe2), ITensor[m2r]) return bmpsc end @@ -352,19 +359,10 @@ function default_inserter( p_above, p_below = partitionedge_above(bmpsc, pe), partitionedge_below(bmpsc, pe) me = only(me) me_prev = only(message(bmpsc, pe)) - if !isnothing(p_above) - me = replaceind( - me, - commonind(me, only(message(bmpsc, reverse(p_above)))), - commonind(me_prev, only(message(bmpsc, p_above))), - ) - end - if !isnothing(p_below) - me = replaceind( - me, - commonind(me, only(message(bmpsc, reverse(p_below)))), - commonind(me_prev, only(message(bmpsc, p_below))), - ) + for pe in filter(x -> !isnothing(x), [p_above, p_below]) + ind1 = commonind(me, only(message(bmpsc, reverse(pe)))) + ind2 = commonind(me_prev, only(message(bmpsc, pe))) + me *= delta(ind1, ind2) end return set_message(bmpsc, pe, ITensor[me]) end @@ -372,32 +370,17 @@ end function default_updater( alg::Algorithm"orthogonal", bmpsc::BoundaryMPSCache, prev_pe, update_pe, prev_v, cur_v ) - bmpsc = if !isnothing(prev_pe) - gauge(alg, bmpsc, reverse(prev_pe), reverse(update_pe)) - else - gauge(alg, bmpsc, reverse(update_pe)) - end - bmpsc = if !isnothing(prev_v) - partition_update(bmpsc, prev_v, cur_v) - else - partition_update(bmpsc, cur_v) - end + rev_prev_pe = isnothing(prev_pe) ? nothing : reverse(prev_pe) + bmpsc = gauge(alg, bmpsc, rev_prev_pe, reverse(update_pe)) + bmpsc = partition_update(bmpsc, prev_v, cur_v) return bmpsc end function default_updater( alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache, prev_pe, update_pe, prev_v, cur_v ) - bmpsc = if !isnothing(prev_pe) - gauge(alg, bmpsc, prev_pe, update_pe) - else - gauge(alg, bmpsc, update_pe) - end - bmpsc = if !isnothing(prev_v) - partition_update(bmpsc, prev_v, cur_v) - else - partition_update(bmpsc, cur_v) - end + bmpsc = gauge(alg, bmpsc, prev_pe, update_pe) + bmpsc = partition_update(bmpsc, prev_v, cur_v) return bmpsc end @@ -460,9 +443,7 @@ function update( for update_pe in update_seq cur_v = parent(src(update_pe)) bmpsc = updater(alg, bmpsc, prev_pe, update_pe, prev_v, cur_v) - me = updated_message( - bmpsc, update_pe; message_update=ms -> default_message_update(ms; normalize) - ) + me = updated_message(bmpsc, update_pe; message_update_function_kwargs=(; normalize)) cf += costfunction(alg, bmpsc, update_pe, me) bmpsc = inserter(alg, bmpsc, update_pe, me) prev_v, prev_pe = cur_v, update_pe diff --git a/src/caches/boundarympscacheutils.jl b/src/caches/boundarympscacheutils.jl index 2c6758f5..b1b66acb 100644 --- a/src/caches/boundarympscacheutils.jl +++ b/src/caches/boundarympscacheutils.jl @@ -38,8 +38,8 @@ pair(pe::PartitionEdge) = parent(src(pe)) => parent(dst(pe)) #Return a sequence of pairs to go from item1 to item2 in an ordered_list function pair_sequence(ordered_list::Vector, item1, item2) - item1_pos, item2_pos = only(findall(x -> x == item1, ordered_list)), - only(findall(x -> x == item2, ordered_list)) + item1_pos, item2_pos = findfirst(x -> x == item1, ordered_list), + findfirst(x -> x == item2, ordered_list) item1_pos < item2_pos && return [ordered_list[i] => ordered_list[i + 1] for i in item1_pos:(item2_pos - 1)] return [ordered_list[i] => ordered_list[i - 1] for i in item1_pos:-1:(item2_pos + 1)] diff --git a/test/test_boundarymps.jl b/test/test_boundarymps.jl index 430ba1e6..0549e186 100644 --- a/test/test_boundarymps.jl +++ b/test/test_boundarymps.jl @@ -117,11 +117,11 @@ using LinearAlgebra: norm x -> abs(real(x)), A, first(inds(A)), last(inds(A)); ishermitian=true ) end - message_update_f = ms -> make_posdef.(default_message_update(ms)) + f = ms -> make_posdef.(default_message_update(ms)) ψ_vidal = VidalITensorNetwork( ψ; cache_update_kwargs=(; - maxiter=50, tol=1e-14, message_update_kwargs=(; message_update=message_update_f) + maxiter=50, tol=1e-14, message_update_kwargs=(; message_update_function=f) ), ) cache_ref = Ref{BeliefPropagationCache}()