From 620da374d776f3ab1a62b924bc0e31875b85fef4 Mon Sep 17 00:00:00 2001 From: Joey Date: Tue, 10 Dec 2024 11:57:30 -0500 Subject: [PATCH] Allow rescaling flat networks with bp --- src/normalize.jl | 81 ++++++++++++++++++++++++++++++------------ test/test_normalize.jl | 24 ++++++------- 2 files changed, 71 insertions(+), 34 deletions(-) diff --git a/src/normalize.jl b/src/normalize.jl index 872a347b..7fa6b4db 100644 --- a/src/normalize.jl +++ b/src/normalize.jl @@ -1,6 +1,14 @@ using LinearAlgebra -function rescale(tn::AbstractITensorNetwork, c::Number, vs=collect(vertices(tn))) +function rescale(tn::AbstractITensorNetwork; alg="exact", kwargs...) + return rescale(Algorithm(alg), tn; kwargs...) +end + +function rescale( + alg::Algorithm"exact", tn::AbstractITensorNetwork, vs=collect(vertices(tn)); kwargs... +) + logn = logscalar(alg, tn; kwargs...) + c = 1.0 / (exp(logn / length(vs))) tn = copy(tn) for v in vs tn[v] *= c @@ -8,25 +16,16 @@ function rescale(tn::AbstractITensorNetwork, c::Number, vs=collect(vertices(tn)) return tn end -function LinearAlgebra.normalize(tn::AbstractITensorNetwork; alg="exact", kwargs...) - return normalize(Algorithm(alg), tn; kwargs...) -end - -function LinearAlgebra.normalize(alg::Algorithm"exact", tn::AbstractITensorNetwork) - norm_tn = QuadraticFormNetwork(tn) - c = exp(logscalar(alg, norm_tn) / (2 * length(vertices(tn)))) - return rescale(tn, 1 / c) -end - -function LinearAlgebra.normalize( +function rescale( alg::Algorithm"bp", - tn::AbstractITensorNetwork; + tn::AbstractITensorNetwork, + vs=collect(vertices(tn)); (cache!)=nothing, update_cache=isnothing(cache!), cache_update_kwargs=default_cache_update_kwargs(cache!), ) if isnothing(cache!) - cache! = Ref(BeliefPropagationCache(QuadraticFormNetwork(tn))) + cache! = Ref(BeliefPropagationCache(tn, group(v -> v, vertices(tn)))) end if update_cache @@ -35,19 +34,57 @@ function LinearAlgebra.normalize( tn = copy(tn) cache![] = normalize_messages(cache![]) - norm_tn = tensornetwork(cache![]) - vertices_states = Dictionary() - for v in vertices(tn) - v_ket, v_bra = ket_vertex(norm_tn, v), bra_vertex(norm_tn, v) - pv = only(partitionvertices(cache![], [v_ket])) + for pv in partitionvertices(cache![]) + pv_vs = filter(v -> v ∈ vs, vertices(cache![], pv)) + + isempty(pv_vs) && continue + vn = region_scalar(cache![], pv) - norm_tn = rescale(norm_tn, 1 / sqrt(vn), [v_ket, v_bra]) - set!(vertices_states, v_ket, norm_tn[v_ket]) - set!(vertices_states, v_bra, norm_tn[v_bra]) + if isreal(vn) && vn < 0 + tn[first(pv_vs)] *= -1 + vn = abs(vn) + end + + vn = vn^(1 / length(pv_vs)) + for v in pv_vs + tn[v] /= vn + set!(vertices_states, v, tn[v]) + end end cache![] = update_factors(cache![], vertices_states) + return tn +end + +function LinearAlgebra.normalize(tn::AbstractITensorNetwork; alg="exact", kwargs...) + return normalize(Algorithm(alg), tn; kwargs...) +end + +function LinearAlgebra.normalize( + alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs... +) + norm_tn = QuadraticFormNetwork(tn) + vs = filter(v -> v ∉ operator_vertices(norm_tn), collect(vertices(norm_tn))) + return ket_network(rescale(alg, norm_tn, vs; kwargs...)) +end + +function LinearAlgebra.normalize( + alg::Algorithm, + tn::AbstractITensorNetwork; + (cache!)=nothing, + cache_construction_function=tn -> + cache(alg, tn; default_cache_construction_kwargs(alg, tn)...), + update_cache=isnothing(cache!), + cache_update_kwargs=default_cache_update_kwargs(cache!), +) + norm_tn = QuadraticFormNetwork(tn) + if isnothing(cache!) + cache! = Ref(cache_construction_function(norm_tn)) + end + + vs = filter(v -> v ∉ operator_vertices(norm_tn), collect(vertices(norm_tn))) + norm_tn = rescale(alg, norm_tn, vs; cache!, update_cache, cache_update_kwargs) return ket_network(norm_tn) end diff --git a/test/test_normalize.jl b/test/test_normalize.jl index 09bd6849..cd95d635 100644 --- a/test/test_normalize.jl +++ b/test/test_normalize.jl @@ -5,32 +5,32 @@ using ITensorNetworks: edge_scalars, norm_sqr_network, random_tensornetwork, - vertex_scalars + vertex_scalars, + rescale using ITensors: dag, inner, siteinds, scalar using Graphs: SimpleGraph, uniform_tree using LinearAlgebra: normalize using NamedGraphs: NamedGraph -using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree using StableRNGs: StableRNG using Test: @test, @testset @testset "Normalize" begin - #First lets do a tree - L = 6 + #First lets do a flat tree + nx, ny = 2, 3 χ = 2 rng = StableRNG(1234) - g = NamedGraph(SimpleGraph(uniform_tree(L))) - s = siteinds("S=1/2", g) - x = random_tensornetwork(rng, s; link_space=χ) + g = named_comb_tree((nx, ny)) + tn = random_tensornetwork(rng, g; link_space=χ) - ψ = normalize(x; alg="exact") - @test scalar(norm_sqr_network(ψ); alg="exact") ≈ 1.0 + tn_r = rescale(tn; alg="exact") + @test scalar(tn_r; alg="exact") ≈ 1.0 - ψ = normalize(x; alg="bp") - @test scalar(norm_sqr_network(ψ); alg="exact") ≈ 1.0 + tn_r = rescale(tn; alg="bp") + @test scalar(tn_r; alg="exact") ≈ 1.0 - #Now a loopy graph + #Now a state on a loopy graph Lx, Ly = 3, 2 χ = 2 rng = StableRNG(1234)