Skip to content

Commit

Permalink
Allow rescaling flat networks with bp
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Dec 10, 2024
1 parent 4f4e2e5 commit 620da37
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 34 deletions.
81 changes: 59 additions & 22 deletions src/normalize.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,31 @@
using LinearAlgebra

function rescale(tn::AbstractITensorNetwork, c::Number, vs=collect(vertices(tn)))
function rescale(tn::AbstractITensorNetwork; alg="exact", kwargs...)
return rescale(Algorithm(alg), tn; kwargs...)
end

function rescale(
alg::Algorithm"exact", tn::AbstractITensorNetwork, vs=collect(vertices(tn)); kwargs...
)
logn = logscalar(alg, tn; kwargs...)
c = 1.0 / (exp(logn / length(vs)))
tn = copy(tn)
for v in vs
tn[v] *= c
end
return tn
end

function LinearAlgebra.normalize(tn::AbstractITensorNetwork; alg="exact", kwargs...)
return normalize(Algorithm(alg), tn; kwargs...)
end

function LinearAlgebra.normalize(alg::Algorithm"exact", tn::AbstractITensorNetwork)
norm_tn = QuadraticFormNetwork(tn)
c = exp(logscalar(alg, norm_tn) / (2 * length(vertices(tn))))
return rescale(tn, 1 / c)
end

function LinearAlgebra.normalize(
function rescale(
alg::Algorithm"bp",
tn::AbstractITensorNetwork;
tn::AbstractITensorNetwork,
vs=collect(vertices(tn));
(cache!)=nothing,
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(cache!),
)
if isnothing(cache!)
cache! = Ref(BeliefPropagationCache(QuadraticFormNetwork(tn)))
cache! = Ref(BeliefPropagationCache(tn, group(v -> v, vertices(tn))))
end

if update_cache
Expand All @@ -35,19 +34,57 @@ function LinearAlgebra.normalize(

tn = copy(tn)
cache![] = normalize_messages(cache![])
norm_tn = tensornetwork(cache![])

vertices_states = Dictionary()
for v in vertices(tn)
v_ket, v_bra = ket_vertex(norm_tn, v), bra_vertex(norm_tn, v)
pv = only(partitionvertices(cache![], [v_ket]))
for pv in partitionvertices(cache![])
pv_vs = filter(v -> v vs, vertices(cache![], pv))

isempty(pv_vs) && continue

vn = region_scalar(cache![], pv)
norm_tn = rescale(norm_tn, 1 / sqrt(vn), [v_ket, v_bra])
set!(vertices_states, v_ket, norm_tn[v_ket])
set!(vertices_states, v_bra, norm_tn[v_bra])
if isreal(vn) && vn < 0
tn[first(pv_vs)] *= -1
vn = abs(vn)
end

vn = vn^(1 / length(pv_vs))
for v in pv_vs
tn[v] /= vn
set!(vertices_states, v, tn[v])
end
end

cache![] = update_factors(cache![], vertices_states)
return tn
end

function LinearAlgebra.normalize(tn::AbstractITensorNetwork; alg="exact", kwargs...)
return normalize(Algorithm(alg), tn; kwargs...)
end

function LinearAlgebra.normalize(
alg::Algorithm"exact", tn::AbstractITensorNetwork; kwargs...
)
norm_tn = QuadraticFormNetwork(tn)
vs = filter(v -> v operator_vertices(norm_tn), collect(vertices(norm_tn)))
return ket_network(rescale(alg, norm_tn, vs; kwargs...))
end

function LinearAlgebra.normalize(
alg::Algorithm,
tn::AbstractITensorNetwork;
(cache!)=nothing,
cache_construction_function=tn ->
cache(alg, tn; default_cache_construction_kwargs(alg, tn)...),
update_cache=isnothing(cache!),
cache_update_kwargs=default_cache_update_kwargs(cache!),
)
norm_tn = QuadraticFormNetwork(tn)
if isnothing(cache!)
cache! = Ref(cache_construction_function(norm_tn))
end

vs = filter(v -> v operator_vertices(norm_tn), collect(vertices(norm_tn)))
norm_tn = rescale(alg, norm_tn, vs; cache!, update_cache, cache_update_kwargs)

return ket_network(norm_tn)
end
24 changes: 12 additions & 12 deletions test/test_normalize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,32 @@ using ITensorNetworks:
edge_scalars,
norm_sqr_network,
random_tensornetwork,
vertex_scalars
vertex_scalars,
rescale
using ITensors: dag, inner, siteinds, scalar
using Graphs: SimpleGraph, uniform_tree
using LinearAlgebra: normalize
using NamedGraphs: NamedGraph
using NamedGraphs.NamedGraphGenerators: named_grid
using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree
using StableRNGs: StableRNG
using Test: @test, @testset
@testset "Normalize" begin

#First lets do a tree
L = 6
#First lets do a flat tree
nx, ny = 2, 3
χ = 2
rng = StableRNG(1234)

g = NamedGraph(SimpleGraph(uniform_tree(L)))
s = siteinds("S=1/2", g)
x = random_tensornetwork(rng, s; link_space=χ)
g = named_comb_tree((nx, ny))
tn = random_tensornetwork(rng, g; link_space=χ)

ψ = normalize(x; alg="exact")
@test scalar(norm_sqr_network(ψ); alg="exact") 1.0
tn_r = rescale(tn; alg="exact")
@test scalar(tn_r; alg="exact") 1.0

ψ = normalize(x; alg="bp")
@test scalar(norm_sqr_network(ψ); alg="exact") 1.0
tn_r = rescale(tn; alg="bp")
@test scalar(tn_r; alg="exact") 1.0

#Now a loopy graph
#Now a state on a loopy graph
Lx, Ly = 3, 2
χ = 2
rng = StableRNG(1234)
Expand Down

0 comments on commit 620da37

Please sign in to comment.