Skip to content

Commit

Permalink
Generic naming support
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Jan 10, 2025
1 parent 19630aa commit 35fe94f
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 101 deletions.
8 changes: 5 additions & 3 deletions src/caches/abstractbeliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,16 @@ Compute message tensor as product of incoming mts and local state
function updated_message(
bpc::AbstractBeliefPropagationCache,
edge::PartitionEdge;
message_update=default_message_update,
message_update_kwargs=(;),
message_update_function=default_message_update,
message_update_function_kwargs=(;),
)
vertex = src(edge)
incoming_ms = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)])
state = factor(bpc, vertex)

return message_update(ITensor[incoming_ms; state]; message_update_kwargs...)
return message_update_function(
ITensor[incoming_ms; state]; message_update_function_kwargs...
)
end

function update(
Expand Down
169 changes: 75 additions & 94 deletions src/caches/boundarympscache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@ using SplitApplyCombine: group
using ITensors: commoninds, random_itensor
using LinearAlgebra: pinv

struct BoundaryMPSCache{BPC,PG} <: AbstractBeliefPropagationCache
struct BoundaryMPSCache{BPC,PG,PP} <: AbstractBeliefPropagationCache
bp_cache::BPC
partitionedplanargraph::PG
partitionpair_partitionedges::PP
maximum_virtual_dimension::Int64
end

bp_cache(bmpsc::BoundaryMPSCache) = bmpsc.bp_cache
partitionedplanargraph(bmpsc::BoundaryMPSCache) = bmpsc.partitionedplanargraph
maximum_virtual_dimension(bmpsc::BoundaryMPSCache) = bmpsc.maximum_virtual_dimension
ppg(bmpsc) = partitionedplanargraph(bmpsc)
maximum_virtual_dimension(bmpsc::BoundaryMPSCache) = bmpsc.maximum_virtual_dimension
partitionpair_partitionedges(bmpsc::BoundaryMPSCache) = bmpsc.partitionpair_partitionedges
planargraph(bmpsc::BoundaryMPSCache) = unpartitioned_graph(partitionedplanargraph(bmpsc))

function partitioned_tensornetwork(bmpsc::BoundaryMPSCache)
Expand All @@ -38,16 +40,23 @@ end
function default_message_update_kwargs(alg::Algorithm"orthogonal", bmpsc::BoundaryMPSCache)
return (; niters=50, tolerance=1e-10)
end
default_message_update_kwargs(alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache) = (;)
function default_message_update_kwargs(
alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache
)
return (; niters=3, tolerance=nothing)
end

function Base.copy(bmpsc::BoundaryMPSCache)
return BoundaryMPSCache(
copy(bp_cache(bmpsc)), copy(ppg(bmpsc)), maximum_virtual_dimension(bmpsc)
copy(bp_cache(bmpsc)),
copy(ppg(bmpsc)),
copy(partitionpair_partitionedges(bmpsc)),
maximum_virtual_dimension(bmpsc),
)
end

function default_message(bmpsc::BoundaryMPSCache, args...; kwargs...)
return default_message(bp_cache(bmpsc), args...; kwargs...)
function default_message(bmpsc::BoundaryMPSCache, pe::PartitionEdge; kwargs...)
return default_message(bp_cache(bmpsc), pe::PartitionEdge; kwargs...)
end

function virtual_index_dimension(
Expand All @@ -56,27 +65,20 @@ function virtual_index_dimension(
pes = planargraph_partitionpair_partitionedges(
bmpsc, planargraph_partitionpair(bmpsc, pe1)
)

if findfirst(x -> x == pe1, pes) > findfirst(x -> x == pe2, pes)
inds_above = reduce(
vcat, [linkinds(bmpsc, pe) for pe in partitionedges_above(bmpsc, pe2)]
)
inds_below = reduce(
vcat, [linkinds(bmpsc, pe) for pe in partitionedges_below(bmpsc, pe1)]
)
lower_pe, upper_pe = pe2, pe1
else
inds_above = reduce(
vcat, [linkinds(bmpsc, pe) for pe in partitionedges_above(bmpsc, pe1)]
)
inds_below = reduce(
vcat, [linkinds(bmpsc, pe) for pe in partitionedges_below(bmpsc, pe2)]
)
lower_pe, upper_pe = pe1, pe2
end
inds_above = reduce(vcat, linkinds.((bmpsc,), partitionedges_above(bmpsc, lower_pe)))
inds_below = reduce(vcat, linkinds.((bmpsc,), partitionedges_below(bmpsc, upper_pe)))
return minimum((
prod(dim.(inds_above)), prod(dim.(inds_below)), maximum_virtual_dimension(bmpsc)
))
end

function planargraph_vertices(bmpsc::BoundaryMPSCache, partition::Int64)
function planargraph_vertices(bmpsc::BoundaryMPSCache, partition)
return vertices(ppg(bmpsc), PartitionVertex(partition))
end
function planargraph_partition(bmpsc::BoundaryMPSCache, vertex)
Expand All @@ -88,44 +90,47 @@ end
function planargraph_partitionpair(bmpsc::BoundaryMPSCache, pe::PartitionEdge)
return pair(partitionedge(ppg(bmpsc), parent(pe)))
end
function planargraph_partitionpair_partitionedges(
bmpsc::BoundaryMPSCache, partition_pair::Pair
)
return partitionpair_partitionedges(bmpsc)[partition_pair]
end

function BoundaryMPSCache(
bpc::BeliefPropagationCache;
grouping_function::Function=v -> first(v),
group_sorting_function::Function=v -> last(v),
message_rank::Int64=1,
)
bpc = insert_pseudo_planar_edges(bpc; grouping_function)
planar_graph = partitioned_graph(bpc)
vertex_groups = group(grouping_function, collect(vertices(planar_graph)))
vertex_groups = map(x -> sort(x; by=group_sorting_function), vertex_groups)
ppg = PartitionedGraph(planar_graph, vertex_groups)
bmpsc = BoundaryMPSCache(bpc, ppg, message_rank)
partitionpairs = vcat(pair.(partitionedges(ppg)), reverse.(pair.(partitionedges(ppg))))
pp_pe = Dictionary(partitionpairs, sorted_partitionedges.((ppg,), partitionpairs))
bmpsc = BoundaryMPSCache(bpc, ppg, pp_pe, message_rank)
return set_interpartition_messages(bmpsc)
end

function BoundaryMPSCache(tn::AbstractITensorNetwork, args...; kwargs...)
return BoundaryMPSCache(BeliefPropagationCache(tn, args...); kwargs...)
end

#Get all partitionedges within a column/row, ordered top to bottom
function planargraph_partitionedges(bmpsc::BoundaryMPSCache, partition::Int64)
vs = sort(planargraph_vertices(bmpsc, partition))
return PartitionEdge.([vs[i] => vs[i + 1] for i in 1:(length(vs) - 1)])
end

#Get all partitionedges between the pair of neighboring partitions, sorted top to bottom
#TODO: Bring in line with NamedGraphs change
function planargraph_partitionpair_partitionedges(
bmpsc::BoundaryMPSCache, partitionpair::Pair
)
pg = planargraph(bmpsc)
src_vs, dst_vs = planargraph_vertices(bmpsc, first(partitionpair)),
planargraph_vertices(bmpsc, last(partitionpair))
es = filter(
x -> !isempty(last(x)),
[src_v => intersect(neighbors(pg, src_v), dst_vs) for src_v in src_vs],
#Get all partitionedges between the pair of neighboring partitions, sorted
#by the position of the source in the pg
function sorted_partitionedges(pg::PartitionedGraph, partitionpair::Pair)
src_vs, dst_vs = vertices(pg, PartitionVertex(first(partitionpair))),
vertices(pg, PartitionVertex(last(partitionpair)))
es = reduce(
vcat,
[
[src_v => dst_v for dst_v in intersect(neighbors(pg, src_v), dst_vs)] for
src_v in src_vs
],
)
es = map(x -> first(x) => only(last(x)), es)
return sort(PartitionEdge.(NamedEdge.(es)); by=x -> src(parent(x)))
es = sort(NamedEdge.(es); by=x -> findfirst(isequal(src(x)), src_vs))
return PartitionEdge.(es)
end

#Functions to get the parellel partitionedges sitting above and below a partitionedge
Expand Down Expand Up @@ -159,8 +164,9 @@ end

#Get the sequence of pairs partitionedges that need to be updated to move the MPS gauge from pe1 to pe2
function mps_gauge_update_sequence(
bmpsc::BoundaryMPSCache, pe1::PartitionEdge, pe2::PartitionEdge
bmpsc::BoundaryMPSCache, pe1::Union{Nothing,PartitionEdge}, pe2::PartitionEdge
)
isnothing(pe1) && return mps_gauge_update_sequence(bmpsc, pe2)
ppgpe1, ppgpe2 = planargraph_partitionpair(bmpsc, pe1),
planargraph_partitionpair(bmpsc, pe2)
@assert ppgpe1 == ppgpe2
Expand Down Expand Up @@ -225,15 +231,18 @@ function switch_messages(bmpsc::BoundaryMPSCache, partitionpair::Pair)
end

#Update all messages tensors within a partition
function partition_update(bmpsc::BoundaryMPSCache, partition::Int64)
vs = sort(planargraph_vertices(bmpsc, partition))
bmpsc = partition_update(bmpsc, first(vs))
bmpsc = partition_update(bmpsc, last(vs))
function partition_update(bmpsc::BoundaryMPSCache, partition)
vs = planargraph_vertices(bmpsc, partition)
bmpsc = partition_update(bmpsc, first(vs), last(vs))
bmpsc = partition_update(bmpsc, last(vs), first(vs))
return bmpsc
end

function partition_update_sequence(bmpsc::BoundaryMPSCache, v1, v2)
return PartitionEdge.(a_star(ppg(bmpsc), v1, v2))
isnothing(v1) && return partition_update_sequence(bmpsc, v2)
pv = planargraph_partition(bmpsc, v1)
g = subgraph(unpartitioned_graph(ppg(bmpsc)), planargraph_vertices(bmpsc, pv))
return PartitionEdge.(a_star(g, v1, v2))
end
function partition_update_sequence(bmpsc::BoundaryMPSCache, v)
pv = planargraph_partition(bmpsc, v)
Expand All @@ -247,7 +256,7 @@ function partition_update(bmpsc::BoundaryMPSCache, args...)
Algorithm("SimpleBP"),
bmpsc,
partition_update_sequence(bmpsc, args...);
message_update_kwargs=(; normalize=false),
message_update_function_kwargs=(; normalize=false),
)
end

Expand Down Expand Up @@ -285,8 +294,8 @@ function gauge_step(
m2, m2r = only(message(bmpsc, pe2)), only(message(bmpsc, reverse(pe2)))
top_cind, bottom_cind = commonind(m1, m2), commonind(m1r, m2r)
m1_siteinds, m2_siteinds = commoninds(m1, m1r), commoninds(m2, m2r)
top_ncind, bottom_ncind = setdiff(inds(m1), [m1_siteinds; top_cind]),
setdiff(inds(m1r), [m1_siteinds; bottom_cind])
top_ncind = setdiff(inds(m1), [m1_siteinds; top_cind])
bottom_ncind = setdiff(inds(m1r), [m1_siteinds; bottom_cind])

E = if isempty(top_ncind)
m1 * m1r
Expand All @@ -298,16 +307,14 @@ function gauge_step(
S_sqrtinv = map_diag(x -> pinv(sqrt(x)), S)
S_sqrt = map_diag(x -> sqrt(x), S)

set!(ms, pe1, ITensor[replaceind((m1 * dag(V)) * S_sqrtinv, commonind(U, S), top_cind)])
set!(
ms,
reverse(pe1),
ITensor[replaceind((m1r * dag(U)) * S_sqrtinv, commonind(V, S), bottom_cind)],
)
set!(ms, pe2, ITensor[replaceind((m2 * V) * S_sqrt, commonind(U, S), top_cind)])
set!(
ms, reverse(pe2), ITensor[replaceind((m2r * U) * S_sqrt, commonind(V, S), bottom_cind)]
)
m1 = replaceind((m1 * dag(V)) * S_sqrtinv, commonind(U, S), top_cind)
m1r = replaceind((m1r * dag(U)) * S_sqrtinv, commonind(V, S), bottom_cind)
m2 = replaceind((m2 * V) * S_sqrt, commonind(U, S), top_cind)
m2r = replaceind((m2r * U) * S_sqrt, commonind(V, S), bottom_cind)
set!(ms, pe1, ITensor[m1])
set!(ms, reverse(pe1), ITensor[m1r])
set!(ms, pe2, ITensor[m2])
set!(ms, reverse(pe2), ITensor[m2r])

return bmpsc
end
Expand Down Expand Up @@ -352,52 +359,28 @@ function default_inserter(
p_above, p_below = partitionedge_above(bmpsc, pe), partitionedge_below(bmpsc, pe)
me = only(me)
me_prev = only(message(bmpsc, pe))
if !isnothing(p_above)
me = replaceind(
me,
commonind(me, only(message(bmpsc, reverse(p_above)))),
commonind(me_prev, only(message(bmpsc, p_above))),
)
end
if !isnothing(p_below)
me = replaceind(
me,
commonind(me, only(message(bmpsc, reverse(p_below)))),
commonind(me_prev, only(message(bmpsc, p_below))),
)
for pe in filter(x -> !isnothing(x), [p_above, p_below])
ind1 = commonind(me, only(message(bmpsc, reverse(pe))))
ind2 = commonind(me_prev, only(message(bmpsc, pe)))
me *= delta(ind1, ind2)
end
return set_message(bmpsc, pe, ITensor[me])
end

function default_updater(
alg::Algorithm"orthogonal", bmpsc::BoundaryMPSCache, prev_pe, update_pe, prev_v, cur_v
)
bmpsc = if !isnothing(prev_pe)
gauge(alg, bmpsc, reverse(prev_pe), reverse(update_pe))
else
gauge(alg, bmpsc, reverse(update_pe))
end
bmpsc = if !isnothing(prev_v)
partition_update(bmpsc, prev_v, cur_v)
else
partition_update(bmpsc, cur_v)
end
rev_prev_pe = isnothing(prev_pe) ? nothing : reverse(prev_pe)
bmpsc = gauge(alg, bmpsc, rev_prev_pe, reverse(update_pe))
bmpsc = partition_update(bmpsc, prev_v, cur_v)
return bmpsc
end

function default_updater(
alg::Algorithm"biorthogonal", bmpsc::BoundaryMPSCache, prev_pe, update_pe, prev_v, cur_v
)
bmpsc = if !isnothing(prev_pe)
gauge(alg, bmpsc, prev_pe, update_pe)
else
gauge(alg, bmpsc, update_pe)
end
bmpsc = if !isnothing(prev_v)
partition_update(bmpsc, prev_v, cur_v)
else
partition_update(bmpsc, cur_v)
end
bmpsc = gauge(alg, bmpsc, prev_pe, update_pe)
bmpsc = partition_update(bmpsc, prev_v, cur_v)
return bmpsc
end

Expand Down Expand Up @@ -460,9 +443,7 @@ function update(
for update_pe in update_seq
cur_v = parent(src(update_pe))
bmpsc = updater(alg, bmpsc, prev_pe, update_pe, prev_v, cur_v)
me = updated_message(
bmpsc, update_pe; message_update=ms -> default_message_update(ms; normalize)
)
me = updated_message(bmpsc, update_pe; message_update_function_kwargs=(; normalize))
cf += costfunction(alg, bmpsc, update_pe, me)
bmpsc = inserter(alg, bmpsc, update_pe, me)
prev_v, prev_pe = cur_v, update_pe
Expand Down
4 changes: 2 additions & 2 deletions src/caches/boundarympscacheutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ pair(pe::PartitionEdge) = parent(src(pe)) => parent(dst(pe))

#Return a sequence of pairs to go from item1 to item2 in an ordered_list
function pair_sequence(ordered_list::Vector, item1, item2)
item1_pos, item2_pos = only(findall(x -> x == item1, ordered_list)),
only(findall(x -> x == item2, ordered_list))
item1_pos, item2_pos = findfirst(x -> x == item1, ordered_list),
findfirst(x -> x == item2, ordered_list)
item1_pos < item2_pos &&
return [ordered_list[i] => ordered_list[i + 1] for i in item1_pos:(item2_pos - 1)]
return [ordered_list[i] => ordered_list[i - 1] for i in item1_pos:-1:(item2_pos + 1)]
Expand Down
4 changes: 2 additions & 2 deletions test/test_boundarymps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ using LinearAlgebra: norm
x -> abs(real(x)), A, first(inds(A)), last(inds(A)); ishermitian=true
)
end
message_update_f = ms -> make_posdef.(default_message_update(ms))
f = ms -> make_posdef.(default_message_update(ms))
ψ_vidal = VidalITensorNetwork(
ψ;
cache_update_kwargs=(;
maxiter=50, tol=1e-14, message_update_kwargs=(; message_update=message_update_f)
maxiter=50, tol=1e-14, message_update_kwargs=(; message_update_function=f)
),
)
cache_ref = Ref{BeliefPropagationCache}()
Expand Down

0 comments on commit 35fe94f

Please sign in to comment.