Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic Boundary Matrix Product States Contraction Code #212

Open
wants to merge 53 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
c845947
Blah
JoeyT1994 Sep 13, 2024
90c7251
Merge remote-tracking branch 'origin/main'
JoeyT1994 Oct 17, 2024
86f3087
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Oct 17, 2024
6ff0cd5
Bug fix in current ortho. Change test
JoeyT1994 Oct 17, 2024
34e8e5e
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Nov 22, 2024
d096722
Fix bug
JoeyT1994 Nov 26, 2024
70a3f7e
Merge remote-tracking branch 'upstream/main'
JoeyT1994 Dec 5, 2024
921810d
Save First Run
JoeyT1994 Dec 18, 2024
c1c1f94
Preliminary ideas
JoeyT1994 Dec 18, 2024
c7ab747
Save
JoeyT1994 Dec 19, 2024
6f9dc32
First commit
JoeyT1994 Dec 19, 2024
08f574a
Sequence sorting
JoeyT1994 Dec 19, 2024
63b18ee
Test change
JoeyT1994 Dec 19, 2024
008b981
Stuff
JoeyT1994 Dec 26, 2024
4febcc3
Gauging
JoeyT1994 Dec 26, 2024
50c53d7
Workinggit add examples/test_boundarymps.jl
JoeyT1994 Dec 27, 2024
9b9b05d
Missing edges
JoeyT1994 Dec 28, 2024
a8c04d3
Simplify
JoeyT1994 Dec 29, 2024
d8ce9cb
Code clean
JoeyT1994 Dec 29, 2024
d095c6f
Improved nomenclature
JoeyT1994 Dec 30, 2024
6e8d7b8
Adding biorthogonalization feature
JoeyT1994 Dec 30, 2024
2ca6404
Biorthogonal algorithm
JoeyT1994 Dec 31, 2024
05ae49a
File Structure
JoeyT1994 Dec 31, 2024
9c5f6c4
Utils
JoeyT1994 Dec 31, 2024
12722da
Add test
JoeyT1994 Dec 31, 2024
f7c6beb
More tests
JoeyT1994 Jan 1, 2025
7d4465b
Testing
JoeyT1994 Jan 1, 2025
1b23fbb
Working Commit
JoeyT1994 Jan 2, 2025
338a41f
BoundaryMPS
JoeyT1994 Jan 2, 2025
62c87b4
Merge branch 'BoundaryMPS' of github.com:JoeyT1994/ITensorNetworks.jl…
JoeyT1994 Jan 2, 2025
9d64029
Revert BP Cache Code
JoeyT1994 Jan 2, 2025
dc46339
Rename kwarg
JoeyT1994 Jan 2, 2025
dd22863
examples/test_boundarymps.jl
JoeyT1994 Jan 2, 2025
807739f
Updated tests
JoeyT1994 Jan 2, 2025
ee078e5
Formatting
JoeyT1994 Jan 2, 2025
a028422
AbstractCache
JoeyT1994 Jan 2, 2025
81dcea4
Revert "AbstractCache"
JoeyT1994 Jan 2, 2025
5dc9677
Abstract Cahce
JoeyT1994 Jan 2, 2025
611bf18
Working Abstract Cache
JoeyT1994 Jan 2, 2025
6c7a2a8
Simplify initialization of messges
JoeyT1994 Jan 3, 2025
531c5af
Rm redundant file
JoeyT1994 Jan 3, 2025
93d905f
Merge branch 'AbstractCache' into BoundaryMPS
JoeyT1994 Jan 3, 2025
b11c2f8
Bug fix in test_apply
JoeyT1994 Jan 3, 2025
99b4929
Generic update interface for boundary mps and simple bp
JoeyT1994 Jan 8, 2025
19630aa
Unify update function
JoeyT1994 Jan 9, 2025
35fe94f
Generic naming support
JoeyT1994 Jan 10, 2025
4ae3b53
Updated expect, inner and environment interfaces
JoeyT1994 Jan 10, 2025
df7a4c5
Fix failing tests
JoeyT1994 Jan 10, 2025
e9e0336
Two site working
JoeyT1994 Jan 13, 2025
e98a056
Cleanup
JoeyT1994 Jan 14, 2025
5b9266e
Remove @show
JoeyT1994 Jan 15, 2025
0228aa1
Fix Bug in Expect
JoeyT1994 Jan 16, 2025
1fcf947
Fix expect test
JoeyT1994 Jan 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ include("edge_sequences.jl")
include("formnetworks/abstractformnetwork.jl")
include("formnetworks/bilinearformnetwork.jl")
include("formnetworks/quadraticformnetwork.jl")
include("caches/abstractbeliefpropagationcache.jl")
include("caches/beliefpropagationcache.jl")
include("caches/boundarympscacheutils.jl")
include("caches/boundarympscache.jl")
include("contraction_tree_to_graph.jl")
include("gauging.jl")
include("utils.jl")
Expand Down
288 changes: 288 additions & 0 deletions src/caches/abstractbeliefpropagationcache.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
using Graphs: IsDirected
using SplitApplyCombine: group
using LinearAlgebra: diag, dot
using ITensors: dir
using ITensorMPS: ITensorMPS
using NamedGraphs.PartitionedGraphs:
PartitionedGraphs,
PartitionedGraph,
PartitionVertex,
boundary_partitionedges,
partitionvertices,
partitionedges,
unpartitioned_graph
using SimpleTraits: SimpleTraits, Not, @traitfn
using NDTensors: NDTensors

abstract type AbstractBeliefPropagationCache end

function default_message_update(contract_list::Vector{ITensor}; normalize=true, kwargs...)
sequence = optimal_contraction_sequence(contract_list)
updated_messages = contract(contract_list; sequence, kwargs...)
message_norm = norm(updated_messages)
if normalize && !iszero(message_norm)
updated_messages /= message_norm
end
return ITensor[updated_messages]
end

#TODO: Take `dot` without precontracting the messages to allow scaling to more complex messages
function message_diff(message_a::Vector{ITensor}, message_b::Vector{ITensor})
lhs, rhs = contract(message_a), contract(message_b)
f = abs2(dot(lhs / norm(lhs), rhs / norm(rhs)))
return 1 - f
end

default_message(elt, inds_e) = ITensor[denseblocks(delta(elt, i)) for i in inds_e]
default_messages(ptn::PartitionedGraph) = Dictionary()
@traitfn default_bp_maxiter(g::::(!IsDirected)) = is_tree(g) ? 1 : nothing
@traitfn function default_bp_maxiter(g::::IsDirected)
return default_bp_maxiter(undirected_graph(underlying_graph(g)))
end
default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ))
function default_partitioned_vertices(f::AbstractFormNetwork)
return group(v -> original_state_vertex(f, v), vertices(f))
end

partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented()
messages(bpc::AbstractBeliefPropagationCache) = not_implemented()
function default_message(
bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...
)
return not_implemented()
end
default_message_update_alg(bpc::AbstractBeliefPropagationCache) = not_implemented()
Base.copy(bpc::AbstractBeliefPropagationCache) = not_implemented()
default_bp_maxiter(alg::Algorithm, bpc::AbstractBeliefPropagationCache) = not_implemented()
function default_edge_sequence(alg::Algorithm, bpc::AbstractBeliefPropagationCache)
return not_implemented()
end
function default_message_update_kwargs(alg::Algorithm, bpc::AbstractBeliefPropagationCache)
return not_implemented()
end
function environment(bpc::AbstractBeliefPropagationCache, verts::Vector; kwargs...)
return not_implemented()
end
function region_scalar(bpc::AbstractBeliefPropagationCache, pv::PartitionVertex; kwargs...)
return not_implemented()
end
function region_scalar(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge; kwargs...)
return not_implemented()
end
partitions(bpc::AbstractBeliefPropagationCache) = not_implemented()
partitionpairs(bpc::AbstractBeliefPropagationCache) = not_implemented()

function default_edge_sequence(
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
)
return default_edge_sequence(Algorithm(alg), bpc)
end
function default_bp_maxiter(
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
)
return default_bp_maxiter(Algorithm(alg), bpc)
end
function default_message_update_kwargs(
bpc::AbstractBeliefPropagationCache; alg=default_message_update_alg(bpc)
)
return default_message_update_kwargs(Algorithm(alg), bpc)
end

function tensornetwork(bpc::AbstractBeliefPropagationCache)
return unpartitioned_graph(partitioned_tensornetwork(bpc))
end

function factors(bpc::AbstractBeliefPropagationCache, verts::Vector)
return ITensor[tensornetwork(bpc)[v] for v in verts]
end

function factor(bpc::AbstractBeliefPropagationCache, vertex::PartitionVertex)
return factors(bpc, vertices(bpc, vertex))
end

function vertex_scalars(bpc::AbstractBeliefPropagationCache, pvs=partitions(bpc); kwargs...)
return map(pv -> region_scalar(bpc, pv; kwargs...), pvs)
end

function edge_scalars(
bpc::AbstractBeliefPropagationCache, pes=partitionpairs(bpc); kwargs...
)
return map(pe -> region_scalar(bpc, pe; kwargs...), pes)
end

function scalar_factors_quotient(bpc::AbstractBeliefPropagationCache)
return vertex_scalars(bpc), edge_scalars(bpc)
end

function incoming_messages(
bpc::AbstractBeliefPropagationCache,
partition_vertices::Vector{<:PartitionVertex};
ignore_edges=(),
)
bpes = boundary_partitionedges(bpc, partition_vertices; dir=:in)
ms = messages(bpc, setdiff(bpes, ignore_edges))
return reduce(vcat, ms; init=ITensor[])
end

function incoming_messages(
bpc::AbstractBeliefPropagationCache, partition_vertex::PartitionVertex; kwargs...
)
return incoming_messages(bpc, [partition_vertex]; kwargs...)
end

#Forward from partitioned graph
for f in [
:(PartitionedGraphs.partitioned_graph),
:(PartitionedGraphs.partitionedge),
:(PartitionedGraphs.partitionvertices),
:(PartitionedGraphs.vertices),
:(PartitionedGraphs.boundary_partitionedges),
:(ITensorMPS.linkinds),
]
@eval begin
function $f(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
return $f(partitioned_tensornetwork(bpc), args...; kwargs...)
end
end
end

NDTensors.scalartype(bpc::AbstractBeliefPropagationCache) = scalartype(tensornetwork(bpc))

"""
Update the tensornetwork inside the cache
"""
function update_factors(bpc::AbstractBeliefPropagationCache, factors)
bpc = copy(bpc)
tn = tensornetwork(bpc)
for vertex in eachindex(factors)
# TODO: Add a check that this preserves the graph structure.
setindex_preserve_graph!(tn, factors[vertex], vertex)
end
return bpc
end

function update_factor(bpc, vertex, factor)
return update_factors(bpc, Dictionary([vertex], [factor]))
end

function message(bpc::AbstractBeliefPropagationCache, edge::PartitionEdge; kwargs...)
mts = messages(bpc)
return get(() -> default_message(bpc, edge; kwargs...), mts, edge)
end
function messages(bpc::AbstractBeliefPropagationCache, edges; kwargs...)
return map(edge -> message(bpc, edge; kwargs...), edges)
end
function set_message(bpc::AbstractBeliefPropagationCache, pe::PartitionEdge, message)
bpc = copy(bpc)
ms = messages(bpc)
set!(ms, pe, message)
return bpc
end

"""
Compute message tensor as product of incoming mts and local state
"""
function updated_message(
bpc::AbstractBeliefPropagationCache,
edge::PartitionEdge;
message_update_function=default_message_update,
message_update_function_kwargs=(;),
)
vertex = src(edge)
incoming_ms = incoming_messages(bpc, vertex; ignore_edges=PartitionEdge[reverse(edge)])
state = factor(bpc, vertex)

return message_update_function(
ITensor[incoming_ms; state]; message_update_function_kwargs...
)
end

function update(
alg::Algorithm"simplebp",
bpc::AbstractBeliefPropagationCache,
edge::PartitionEdge;
kwargs...,
)
new_m = updated_message(bpc, edge; kwargs...)
bpc = set_message(bpc, edge, new_m)
return bpc
end

"""
Do a sequential update of the message tensors on `edges`
"""
function update(
alg::Algorithm,
bpc::AbstractBeliefPropagationCache,
edges::Vector;
(update_diff!)=nothing,
kwargs...,
)
bpc = copy(bpc)
for e in edges
prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing
bpc = update(alg, bpc, e; kwargs...)
if !isnothing(update_diff!)
update_diff![] += message_diff(message(bpc, e), prev_message)
end
end
return bpc
end

"""
Do parallel updates between groups of edges of all message tensors
Currently we send the full message tensor data struct to update for each edge_group. But really we only need the
mts relevant to that group.
"""
function update(
alg::Algorithm,
bpc::AbstractBeliefPropagationCache,
edge_groups::Vector{<:Vector{<:PartitionEdge}};
kwargs...,
)
new_mts = copy(messages(bpc))
for edges in edge_groups
bpc_t = update(alg, bpc, edges; kwargs...)
for e in edges
new_mts[e] = message(bpc_t, e)
end
end
return set_messages(bpc, new_mts)
end

"""
More generic interface for update, with default params
"""
function update(
alg::Algorithm,
bpc::AbstractBeliefPropagationCache;
edges=default_edge_sequence(alg, bpc),
maxiter=default_bp_maxiter(alg, bpc),
message_update_kwargs=default_message_update_kwargs(alg, bpc),
tol=nothing,
verbose=false,
)
compute_error = !isnothing(tol)
if isnothing(maxiter)
error("You need to specify a number of iterations for BP!")
end
for i in 1:maxiter
diff = compute_error ? Ref(0.0) : nothing
bpc = update(alg, bpc, edges; (update_diff!)=diff, message_update_kwargs...)
if compute_error && (diff.x / length(edges)) <= tol
if verbose
println("BP converged to desired precision after $i iterations.")
end
break
end
end
return bpc
end

function update(
bpc::AbstractBeliefPropagationCache;
alg::String=default_message_update_alg(bpc),
kwargs...,
)
return update(Algorithm(alg), bpc; kwargs...)
end
Loading
Loading