Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Jan 14, 2025
1 parent e9e0336 commit e98a056
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 36 deletions.
81 changes: 47 additions & 34 deletions src/caches/boundarympscache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ function default_message(bmpsc::BoundaryMPSCache, pe::PartitionEdge; kwargs...)
return default_message(bp_cache(bmpsc), pe::PartitionEdge; kwargs...)
end

#Get the dimension of the virtual index between the two message tensors on pe1 and pe2
function virtual_index_dimension(
bmpsc::BoundaryMPSCache, pe1::PartitionEdge, pe2::PartitionEdge
)
Expand All @@ -95,22 +96,43 @@ function virtual_index_dimension(
))
end

#Vertices of the planargraph
function planargraph_vertices(bmpsc::BoundaryMPSCache, partition)
return vertices(ppg(bmpsc), PartitionVertex(partition))
end
function planargraph_partition(bmpsc::BoundaryMPSCache, vertex)
return parent(partitionvertex(ppg(bmpsc), vertex))

#Get partition(s) of vertices of the planargraph
function planargraph_partitions(bmpsc::BoundaryMPSCache, vertices::Vector)
return parent.(partitionvertices(ppg(bmpsc), vertices))
end
function planargraph_partitions(bmpsc::BoundaryMPSCache, verts)
return parent.(partitionvertices(ppg(bmpsc), verts))

function planargraph_partition(bmpsc::BoundaryMPSCache, vertex)
return only(planargraph_partitions(bmpsc, [vertex]))
end

#Get interpartition pairs of partition edges in the underlying partitioned tensornetwork
function planargraph_partitionpair(bmpsc::BoundaryMPSCache, pe::PartitionEdge)
return pair(partitionedge(ppg(bmpsc), parent(pe)))
end
function planargraph_sorted_partitionedges(bmpsc::BoundaryMPSCache, partition_pair::Pair)
return sorted_partitionedges(ppg(bmpsc), partition_pair)

#Sort (bottom to top) partitoonedges between pair of partitions in the planargraph
function planargraph_sorted_partitionedges(bmpsc::BoundaryMPSCache, partitionpair::Pair)
pg = ppg(bmpsc)
src_vs, dst_vs = vertices(pg, PartitionVertex(first(partitionpair))),
vertices(pg, PartitionVertex(last(partitionpair)))
es = reduce(
vcat,
[
[src_v => dst_v for dst_v in intersect(neighbors(pg, src_v), dst_vs)] for
src_v in src_vs
],
)
es = sort(NamedEdge.(es); by=x -> findfirst(isequal(src(x)), src_vs))
return PartitionEdge.(es)
end

#Constructor, inserts missing edge in the planar graph to ensure each partition is connected
#allowing the code to work for arbitrary grids and not just square grids
function BoundaryMPSCache(
bpc::BeliefPropagationCache;
grouping_function::Function=v -> first(v),
Expand All @@ -130,22 +152,6 @@ function BoundaryMPSCache(tn, args...; kwargs...)
return BoundaryMPSCache(BeliefPropagationCache(tn, args...); kwargs...)
end

#Get all partitionedges between the pair of neighboring partitions, sorted
#by the position of the source in the pg
function sorted_partitionedges(pg::PartitionedGraph, partitionpair::Pair)
src_vs, dst_vs = vertices(pg, PartitionVertex(first(partitionpair))),
vertices(pg, PartitionVertex(last(partitionpair)))
es = reduce(
vcat,
[
[src_v => dst_v for dst_v in intersect(neighbors(pg, src_v), dst_vs)] for
src_v in src_vs
],
)
es = sort(NamedEdge.(es); by=x -> findfirst(isequal(src(x)), src_vs))
return PartitionEdge.(es)
end

#Functions to get the parellel partitionedges sitting above and below a partitionedge
function partitionedges_above(bmpsc::BoundaryMPSCache, pe::PartitionEdge)
pes = planargraph_sorted_partitionedges(bmpsc, planargraph_partitionpair(bmpsc, pe))
Expand All @@ -171,8 +177,8 @@ function partitionedge_below(bmpsc::BoundaryMPSCache, pe::PartitionEdge)
return last(pes_below)
end

#Given a sequence of message tensor updates in a partition, get the sequence of gauge moves needed to move the MPS gauge
#from the start of the sequence to the end of the sequence
#Given a sequence of message tensor updates within a partition, get the sequence of gauge moves on the interpartition
# needed to move the MPS gauge from the start of the sequence to the end of the sequence
function mps_gauge_update_sequence(
bmpsc::BoundaryMPSCache,
partition_update_seq::Vector{<:PartitionEdge},
Expand All @@ -190,7 +196,7 @@ function mps_gauge_update_sequence(
]
end

#Get the sequence of pairs partitionedges that need to be updated to move the MPS gauge from one region of partitionedges to another
#Returns the sequence of pairs of partitionedges that need to be updated to move the MPS gauge between regions
function mps_gauge_update_sequence(
bmpsc::BoundaryMPSCache,
pe_region1::Vector{<:PartitionEdge},
Expand All @@ -216,7 +222,7 @@ function mps_gauge_update_sequence(bmpsc::BoundaryMPSCache, pe::PartitionEdge)
return mps_gauge_update_sequence(bmpsc, [pe])
end

#Initialise all the message tensors for the pairs of neighboring partitions
#Initialise all the interpartition message tensors
function set_interpartition_messages(
bmpsc::BoundaryMPSCache, partitionpairs::Vector{<:Pair}
)
Expand All @@ -238,13 +244,12 @@ function set_interpartition_messages(
return bmpsc
end

#Initialise all the interpartition message tensors
function set_interpartition_messages(bmpsc::BoundaryMPSCache)
partitionpairs = pair.(partitionedges(ppg(bmpsc)))
return set_interpartition_messages(bmpsc, vcat(partitionpairs, reverse.(partitionpairs)))
end

#Switch the message on partition edge pe with its reverse (and dagger them)
#Switch the message tensors on partition edges with their reverse (and dagger them)
function switch_message(bmpsc::BoundaryMPSCache, pe::PartitionEdge)
bmpsc = copy(bmpsc)
ms = messages(bmpsc)
Expand All @@ -254,7 +259,6 @@ function switch_message(bmpsc::BoundaryMPSCache, pe::PartitionEdge)
return bmpsc
end

#Switch the message tensors from partitionpair i -> i + 1 with those from i + 1 -> i
function switch_messages(bmpsc::BoundaryMPSCache, partitionpair::Pair)
for pe in planargraph_sorted_partitionedges(bmpsc, partitionpair)
bmpsc = switch_message(bmpsc, pe)
Expand All @@ -271,7 +275,7 @@ function partition_update_sequence(bmpsc::BoundaryMPSCache, partition)
)
end

#Get sequence necessary to move correct message tensors in a partition from region1 to region 2
#Get sequence necessary to move correct message tensors in a partition from region1 to region2
function partition_update_sequence(
bmpsc::BoundaryMPSCache, region1::Vector, region2::Vector
)
Expand All @@ -284,7 +288,7 @@ function partition_update_sequence(
return PartitionEdge.(path)
end

#Get sequence necessary to move correct message tensors to region
#Get sequence necessary to move correct message tensors to a region
function partition_update_sequence(bmpsc::BoundaryMPSCache, region::Vector)
pv = planargraph_partition(bmpsc, first(region))
return partition_update_sequence(bmpsc, planargraph_vertices(bmpsc, pv), region)
Expand Down Expand Up @@ -386,6 +390,7 @@ default_inserter_transform(alg::Algorithm"orthogonal") = dag
default_region_transform(alg::Algorithm"biorthogonal") = identity
default_region_transform(alg::Algorithm"orthogonal") = reverse

#Default inserter for the MPS fitting (one and two-site support)
function default_inserter(
alg::Algorithm,
bmpsc::BoundaryMPSCache,
Expand Down Expand Up @@ -419,6 +424,7 @@ function default_inserter(
return bmpsc
end

#Default updater for the MPS fitting
function default_updater(
alg::Algorithm, bmpsc::BoundaryMPSCache, prev_pe_region, update_pe_region
)
Expand All @@ -434,6 +440,7 @@ function default_updater(
return bmpsc
end

#Default extracter for the MPS fitting (1 and two-site support)
function default_extracter(
alg::Algorithm"orthogonal",
bmpsc::BoundaryMPSCache,
Expand Down Expand Up @@ -462,6 +469,8 @@ function ITensors.commonind(bmpsc::BoundaryMPSCache, pe1::PartitionEdge, pe2::Pa
return commonind(only(m1), only(m2))
end

#Transformers for switching the virtual index of message tensors on boundary of pe_region
# to those of their reverse
function virtual_index_transformers(
bmpsc::BoundaryMPSCache, pe_region::Vector{<:PartitionEdge}
)
Expand All @@ -486,6 +495,7 @@ function virtual_index_transformers(
return transformers
end

#Biorthogonal extracter for MPS fitting (needs virtual index transformer)
function default_extracter(
alg::Algorithm"biorthogonal",
bmpsc::BoundaryMPSCache,
Expand All @@ -512,6 +522,7 @@ default_niters(alg::Algorithm"biorthogonal") = 3
default_tolerance(alg::Algorithm"orthogonal") = 1e-10
default_tolerance(alg::Algorithm"biorthogonal") = nothing

#Cost functions
function default_costfunction(
alg::Algorithm"orthogonal", bmpsc::BoundaryMPSCache, pe_region::Vector{<:PartitionEdge}
)
Expand All @@ -537,6 +548,7 @@ function default_costfunction(
return region_scalar(bp_cache(bmpsc), src(pe)) / contract(ms; sequence="automatic")[]
end

#Sequences
function update_sequence(
alg::Algorithm, bmpsc::BoundaryMPSCache, partitionpair::Pair; nsites::Int64=1
)
Expand All @@ -550,8 +562,7 @@ function update_sequence(
end
end

#Update all the message tensors on an interpartition via a specified fitting procedure
#TODO: Make two-site possible
#Update all the message tensors on an interpartition via an n-site fitting procedure
function update(
alg::Algorithm,
bmpsc::BoundaryMPSCache,
Expand Down Expand Up @@ -579,6 +590,7 @@ function update(
cf += !isnothing(tolerance) ? costfunction(alg, bmpsc, update_pe_region) : 0.0
end
epsilon = abs(cf - prev_cf) / length(update_seq)
@show cf
if !isnothing(tolerance) && epsilon < tolerance
return cache_prep_function(alg, bmpsc, partitionpair)
else
Expand All @@ -588,13 +600,14 @@ function update(
return cache_prep_function(alg, bmpsc, partitionpair)
end

#Assert all vertices live in the same partition for now
#Environment support, assume all vertices live in the same partition for now
function ITensorNetworks.environment(bmpsc::BoundaryMPSCache, verts::Vector; kwargs...)
vs = parent.((partitionvertices(bp_cache(bmpsc), verts)))
bmpsc = partition_update(bmpsc, vs)
return environment(bp_cache(bmpsc), verts; kwargs...)
end

#Region scalars, allowing computation of the free energy within boundary MPS
function region_scalar(bmpsc::BoundaryMPSCache, partition)
partition_vs = planargraph_vertices(bmpsc, partition)
bmpsc = partition_update(bmpsc, [first(partition_vs)], [last(partition_vs)])
Expand Down
2 changes: 0 additions & 2 deletions src/caches/boundarympscacheutils.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using ITensorNetworks: ITensorNetworks, BeliefPropagationCache
using NamedGraphs.PartitionedGraphs: PartitionedGraph

#Add partition edges that may not have meaning in the underlying graph
function add_partitionedges(pg::PartitionedGraph, pes::Vector{<:PartitionEdge})
g = partitioned_graph(pg)
g = add_edges(g, parent.(pes))
Expand All @@ -10,7 +9,6 @@ function add_partitionedges(pg::PartitionedGraph, pes::Vector{<:PartitionEdge})
)
end

#Add partition edges that may not have meaning in the underlying graph
function add_partitionedges(bpc::BeliefPropagationCache, pes::Vector{<:PartitionEdge})
pg = add_partitionedges(partitioned_tensornetwork(bpc), pes)
return BeliefPropagationCache(pg, messages(bpc))
Expand Down

0 comments on commit e98a056

Please sign in to comment.