Skip to content

Commit

Permalink
Updated expect, inner and environment interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Jan 10, 2025
1 parent 35fe94f commit 4ae3b53
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 58 deletions.
15 changes: 5 additions & 10 deletions src/caches/abstractbeliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertice
function default_partitioned_vertices(f::AbstractFormNetwork)
return group(v -> original_state_vertex(f, v), vertices(f))
end
default_cache_update_kwargs(cache) = (; maxiter=25, tol=1e-8)

partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented()
messages(bpc::AbstractBeliefPropagationCache) = not_implemented()
Expand All @@ -70,6 +69,8 @@ end
function region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; kwargs...)
return not_implemented()
end
partitions(bpc::AbstractBeliefPropagationCache) = not_implemented()
partitionpairs(bpc::AbstractBeliefPropagationCache) = not_implemented()

function default_edge_sequence(
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
Expand Down Expand Up @@ -99,18 +100,12 @@ function factor(bpc::AbstractBeliefPropagationCache, vertex::PartitionVertex)
return factors(bpc, vertices(bpc, vertex))
end

function vertex_scalars(
bpc::AbstractBeliefPropagationCache,
pvs=partitionvertices(partitioned_tensornetwork(bpc));
kwargs...,
)
function vertex_scalars(bpc::AbstractBeliefPropagationCache, pvs=partitions(bpc); kwargs...)
return map(pv -> region_scalar(bpc, pv; kwargs...), pvs)
end

function edge_scalars(
bpc::AbstractBeliefPropagationCache,
pes=partitionedges(partitioned_tensornetwork(bpc));
kwargs...,
bpc::AbstractBeliefPropagationCache, pes=partitionpairs(bpc); kwargs...
)
return map(pe -> region_scalar(bpc, pe; kwargs...), pes)
end
Expand Down Expand Up @@ -203,7 +198,7 @@ function updated_message(
end

function update(
alg::Algorithm"SimpleBP",
alg::Algorithm"simplebp",
bpc::AbstractBeliefPropagationCache,
edge::PartitionEdge;
kwargs...,
Expand Down
16 changes: 12 additions & 4 deletions src/caches/beliefpropagationcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ function default_cache_construction_kwargs(alg::Algorithm"bp", ψ::AbstractITens
return (; partitioned_vertices=default_partitioned_vertices(ψ))
end

function default_cache_construction_kwargs(alg::Algorithm"bp", pg::PartitionedGraph)
return (;)
end

struct BeliefPropagationCache{PTN,MTS} <: AbstractBeliefPropagationCache
partitioned_tensornetwork::PTN
messages::MTS
Expand All @@ -44,6 +48,7 @@ end
function cache(alg::Algorithm"bp", tn; kwargs...)
return BeliefPropagationCache(tn; kwargs...)
end
default_cache_update_kwargs(alg::Algorithm"bp") = (; maxiter=25, tol=1e-8)

function partitioned_tensornetwork(bp_cache::BeliefPropagationCache)
return bp_cache.partitioned_tensornetwork
Expand All @@ -61,20 +66,23 @@ function Base.copy(bp_cache::BeliefPropagationCache)
)
end

default_message_update_alg(bp_cache::BeliefPropagationCache) = "SimpleBP"
default_message_update_alg(bp_cache::BeliefPropagationCache) = "simplebp"

function default_bp_maxiter(alg::Algorithm"SimpleBP", bp_cache::BeliefPropagationCache)
function default_bp_maxiter(alg::Algorithm"simplebp", bp_cache::BeliefPropagationCache)
return default_bp_maxiter(partitioned_graph(bp_cache))
end
function default_edge_sequence(alg::Algorithm"SimpleBP", bp_cache::BeliefPropagationCache)
function default_edge_sequence(alg::Algorithm"simplebp", bp_cache::BeliefPropagationCache)
return default_edge_sequence(partitioned_tensornetwork(bp_cache))
end
function default_message_update_kwargs(
alg::Algorithm"SimpleBP", bpc::AbstractBeliefPropagationCache
alg::Algorithm"simplebp", bpc::AbstractBeliefPropagationCache
)
return (;)
end

partitions(bpc::BeliefPropagationCache) = partitionvertices(partitioned_tensornetwork(bpc))
partitionpairs(bpc::BeliefPropagationCache) = partitionedges(partitioned_tensornetwork(bpc))

function set_messages(cache::BeliefPropagationCache, messages)
return BeliefPropagationCache(partitioned_tensornetwork(cache), messages)
end
Expand Down
45 changes: 42 additions & 3 deletions src/caches/boundarympscache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,30 @@ function default_message_update_kwargs(
)
return (; niters=3, tolerance=nothing)
end
default_boundarymps_message_rank(tn::AbstractITensorNetwork) = maxlinkdim(tn)^2
partitions(bmpsc::BoundaryMPSCache) = parent.(collect(partitionvertices(ppg(bmpsc))))
partitionpairs(bmpsc::BoundaryMPSCache) = pair.(partitionedges(ppg(bmpsc)))

function cache(
alg::Algorithm"boundarymps",
tn;
bp_cache_construction_kwargs=default_cache_construction_kwargs(Algorithm("bp"), tn),
kwargs...,
)
return BoundaryMPSCache(
BeliefPropagationCache(tn; bp_cache_construction_kwargs...); kwargs...
)
end

function default_cache_construction_kwargs(alg::Algorithm"boundarymps", tn)
return (;
bp_cache_construction_kwargs=default_cache_construction_kwargs(Algorithm("bp"), tn)
)
end

function default_cache_update_kwargs(alg::Algorithm"boundarymps")
return (; alg="orthogonal", message_update_kwargs=(; niters=25, tolerance=1e-10))
end

function Base.copy(bmpsc::BoundaryMPSCache)
return BoundaryMPSCache(
Expand Down Expand Up @@ -100,7 +124,7 @@ function BoundaryMPSCache(
bpc::BeliefPropagationCache;
grouping_function::Function=v -> first(v),
group_sorting_function::Function=v -> last(v),
message_rank::Int64=1,
message_rank::Int64=default_boundarymps_message_rank(tensornetwork(bpc)),
)
bpc = insert_pseudo_planar_edges(bpc; grouping_function)
planar_graph = partitioned_graph(bpc)
Expand All @@ -113,7 +137,7 @@ function BoundaryMPSCache(
return set_interpartition_messages(bmpsc)
end

function BoundaryMPSCache(tn::AbstractITensorNetwork, args...; kwargs...)
function BoundaryMPSCache(tn, args...; kwargs...)
return BoundaryMPSCache(BeliefPropagationCache(tn, args...); kwargs...)
end

Expand Down Expand Up @@ -253,7 +277,7 @@ end
#Update all messages within a partition along the path from from v1 to v2
function partition_update(bmpsc::BoundaryMPSCache, args...)
return update(
Algorithm("SimpleBP"),
Algorithm("simplebp"),
bmpsc,
partition_update_sequence(bmpsc, args...);
message_update_function_kwargs=(; normalize=false),
Expand Down Expand Up @@ -465,3 +489,18 @@ function ITensorNetworks.environment(bmpsc::BoundaryMPSCache, verts::Vector; kwa
bmpsc = partition_update(bmpsc, pv)
return environment(bp_cache(bmpsc), verts; kwargs...)
end

function region_scalar(bmpsc::BoundaryMPSCache, partition)
partition_vs = planargraph_vertices(bmpsc, partition)
bmpsc = partition_update(bmpsc, first(partition_vs), last(partition_vs))
return region_scalar(bp_cache(bmpsc), PartitionVertex(last(partition_vs)))
end

function region_scalar(bmpsc::BoundaryMPSCache, partitionpair::Pair)
pes = planargraph_partitionpair_partitionedges(bmpsc, partitionpair)
out = ITensor(1.0)
for pe in pes
out = (out * (only(message(bmpsc, pe)))) * only(message(bmpsc, reverse(pe)))
end
return out[]
end
6 changes: 3 additions & 3 deletions src/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function NDTensors.contract(
return contract_approx(alg, tn, output_structure; kwargs...)
end

function ITensors.scalar(alg::Algorithm, tn::AbstractITensorNetwork; kwargs...)
function ITensors.scalar(alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...)
return contract(alg, tn; kwargs...)[]
end

Expand All @@ -54,7 +54,7 @@ function logscalar(
(cache!)=nothing,
cache_construction_kwargs=default_cache_construction_kwargs(alg, tn),
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(cache!),
cache_update_kwargs=default_cache_update_kwargs(alg),
)
if isnothing(cache!)
cache! = Ref(cache(alg, tn; cache_construction_kwargs...))
Expand All @@ -77,6 +77,6 @@ function logscalar(
return sum(log.(numerator_terms)) - sum(log.((denominator_terms)))
end

function ITensors.scalar(alg::Algorithm"bp", tn::AbstractITensorNetwork; kwargs...)
function ITensors.scalar(alg::Algorithm, tn::AbstractITensorNetwork; kwargs...)
return exp(logscalar(alg, tn; kwargs...))
end
11 changes: 6 additions & 5 deletions src/environment.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using ITensors: contract
using NamedGraphs.PartitionedGraphs: PartitionedGraph

default_environment_algorithm() = "exact"
default_environment_algorithm() = "bp"

function environment(
tn::AbstractITensorNetwork,
Expand All @@ -19,15 +19,16 @@ function environment(
end

function environment(
::Algorithm"bp",
alg::Algorithm,
ptn::PartitionedGraph,
vertices::Vector;
(cache!)=nothing,
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(cache!),
cache_construction_kwargs=default_cache_construction_kwargs(alg, ptn),
cache_update_kwargs=default_cache_update_kwargs(alg),
)
if isnothing(cache!)
cache! = Ref(BeliefPropagationCache(ptn))
cache! = Ref(cache(alg, ptn; cache_construction_kwargs...))
end

if update_cache
Expand All @@ -38,7 +39,7 @@ function environment(
end

function environment(
alg::Algorithm"bp",
alg::Algorithm,
tn::AbstractITensorNetwork,
vertices::Vector;
partitioned_vertices=default_partitioned_vertices(tn),
Expand Down
2 changes: 1 addition & 1 deletion src/expect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function ITensorMPS.expect(
ops;
(cache!)=nothing,
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(cache!),
cache_update_kwargs=default_cache_update_kwargs(alg),
cache_construction_function=tn ->
cache(alg, tn; default_cache_construction_kwargs(alg, tn)...),
kwargs...,
Expand Down
8 changes: 4 additions & 4 deletions src/inner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ function ITensorMPS.loginner(
end

function ITensorMPS.loginner(
alg::Algorithm"bp",
alg::Algorithm,
ϕ::AbstractITensorNetwork,
ψ::AbstractITensorNetwork;
dual_link_index_map=sim,
Expand All @@ -101,7 +101,7 @@ function ITensorMPS.loginner(
end

function ITensorMPS.loginner(
alg::Algorithm"bp",
alg::Algorithm,
ϕ::AbstractITensorNetwork,
A::AbstractITensorNetwork,
ψ::AbstractITensorNetwork;
Expand All @@ -113,7 +113,7 @@ function ITensorMPS.loginner(
end

function ITensors.inner(
alg::Algorithm"bp",
alg::Algorithm,
ϕ::AbstractITensorNetwork,
ψ::AbstractITensorNetwork;
dual_link_index_map=sim,
Expand All @@ -124,7 +124,7 @@ function ITensors.inner(
end

function ITensors.inner(
alg::Algorithm"bp",
alg::Algorithm,
ϕ::AbstractITensorNetwork,
A::AbstractITensorNetwork,
ψ::AbstractITensorNetwork;
Expand Down
47 changes: 19 additions & 28 deletions test/test_boundarymps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,35 +55,26 @@ using LinearAlgebra: norm
rng = StableRNG(1234)

#First a comb tree (which is still a planar graph) and a flat tensor network
g = named_comb_tree((4, 4))
g = named_comb_tree((3, 3))
χ = 2
tn = random_tensornetwork(rng, elt, g; link_space=χ)
vc = first(center(g))

∂tn_∂vc_bp = contract(environment(tn, [vc]; alg="bp"); sequence="automatic")
∂tn_∂vc_bp /= norm(∂tn_∂vc_bp)

∂tn_∂vc = subgraph(tn, setdiff(vertices(tn), [vc]))
∂tn_∂vc_exact = contract(∂tn_∂vc; sequence=contraction_sequence(∂tn_∂vc; alg="greedy"))
∂tn_∂vc_exact /= norm(∂tn_∂vc_exact)

#Orthogonal Boundary MPS, group by row
tn_boundaryMPS = BoundaryMPSCache(tn; grouping_function=v -> last(v), message_rank=1)
tn_boundaryMPS = update(tn_boundaryMPS)
∂tn_∂vc_boundaryMPS = contract(environment(tn_boundaryMPS, [vc]); sequence="automatic")
∂tn_∂vc_boundaryMPS /= norm(∂tn_∂vc_boundaryMPS)

@test norm(∂tn_∂vc_boundaryMPS - ∂tn_∂vc_exact) <= 10 * eps(real(elt))
@test norm(∂tn_∂vc_boundaryMPS - ∂tn_∂vc_bp) <= 10 * eps(real(elt))

#Biorthogonal Boundary MPS, , group by row
tn_boundaryMPS = BoundaryMPSCache(tn; grouping_function=v -> last(v), message_rank=1)
tn_boundaryMPS = update(tn_boundaryMPS; alg="biorthogonal")
∂tn_∂vc_boundaryMPS = contract(environment(tn_boundaryMPS, [vc]); sequence="automatic")
∂tn_∂vc_boundaryMPS /= norm(∂tn_∂vc_boundaryMPS)
λ_bp = scalar(tn; alg="bp")
λ_exact = scalar(tn; alg="exact")
#Orthogonal boundary MPS group by column (default)
λ_ortho_bmps = scalar(tn; alg="boundarymps")
#Orthogonal boundary MPS group by column
λ_biortho_bmps = scalar(
tn;
alg="boundarymps",
cache_construction_kwargs=(; grouping_function=v -> last(v)),
cache_update_kwargs=(; maxiter=1, alg="biorthogonal"),
)

@test norm(∂tn_∂vc_boundaryMPS - ∂tn_∂vc_exact) <= 10 * eps(real(elt))
@test norm(∂tn_∂vc_boundaryMPS - ∂tn_∂vc_bp) <= 10 * eps(real(elt))
@test abs(λ_ortho_bmps - λ_exact) <= 10 * eps(real(elt))
@test abs(λ_ortho_bmps - λ_bp) <= 10 * eps(real(elt))
@test abs(λ_ortho_bmps - λ_biortho_bmps) <= 10 * eps(real(elt))

#Now the norm tensor network of a square graph with a few vertices missing for added complexity
g = named_grid((5, 5))
Expand All @@ -98,14 +89,14 @@ using LinearAlgebra: norm
ρ_exact /= tr(ρ_exact)

#Orthogonal Boundary MPS, group by column (default)
ψIψ_boundaryMPS = BoundaryMPSCache(ψIψ; message_rank=χ * χ)
ψIψ_boundaryMPS = update(ψIψ_boundaryMPS)
ρ_boundaryMPS = contract(environment(ψIψ_boundaryMPS, [vc]); sequence="automatic")
ρ_boundaryMPS = contract(
environment(ψIψ, [vc]; alg="boundarymps"); sequence="automatic"
)
ρ_boundaryMPS /= tr(ρ_boundaryMPS)

@test norm(ρ_boundaryMPS - ρ_exact) <= 10 * eps(real(elt))

#Now we test BP and orthogonal and biorthogonl Boundary MPS are equivalent when run from in the symmetric gauge
#Now we test BP and orthogonal and biorthogonal Boundary MPS are equivalent when run from in the symmetric gauge
g = named_hexagonal_lattice_graph(3, 3)
s = siteinds("S=1/2", g)
χ = 2
Expand Down

0 comments on commit 4ae3b53

Please sign in to comment.