Skip to content

Commit

Permalink
Efficient implementation of expect for AbstractTTN (#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
b-kloss authored Jan 26, 2024
1 parent a6250ac commit e1abe27
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
24 changes: 24 additions & 0 deletions src/treetensornetworks/abstracttreetensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,27 @@ function inner(
end
return O[]
end

function expect(
operator::String,
state::AbstractTTN;
vertices=vertices(state),
# ToDo: verify that this is a sane default
root_vertex=default_root_vertex(siteinds(state)),
)
# ToDo: for performance it may be beneficial to also implement expect! and change the orthogonality center in place
# assuming that the next algorithmic step can make use of the orthogonality center being moved to a different vertex
# ToDo: Verify that this is indeed the correct order for performance
sites = siteinds(state)
ordered_vertices = reverse(post_order_dfs_vertices(sites, root_vertex))
res = Dictionary(vertices, undef)
for v in ordered_vertices
!(v in vertices) && continue
state = orthogonalize(state, v)
@assert isone(length(sites[v]))
#ToDo: Add compatibility with more than a single index per vertex
op_v = op(operator, only(sites[v]))
res[v] = (dag(state[v]) * apply(op_v, state[v]))[]
end
return mapreduce(typeof, promote_type, res).(res)
end
32 changes: 32 additions & 0 deletions test/test_treetensornetworks/test_expect.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using ITensors
using ITensorNetworks
using Test

@testset "MPS expect comparison with ITensors" begin
N = 25
s = siteinds("S=1/2", N)
a = random_mps(s; internal_inds_space=100)
b = MPS([a[v] for v in vertices(a)])
res_a = expect("Sz", a)
res_b = expect(b, "Sz")
res_a = [res_a[v] for v in vertices(a)]
@test res_a res_b rtol = 1e-6
end

@testset "TTN expect" begin
tooth_lengths = fill(5, 6)
c = named_comb_tree(tooth_lengths)
s = siteinds("S=1/2", c)
d = Dict()
magnetization = Dict()
for (i, v) in enumerate(vertices(s))
d[v] = isodd(i) ? "Up" : "Dn"
magnetization[v] = isodd(i) ? 0.5 : -0.5
end
states = v -> d[v]
state = TTN(s, states)
res = expect("Sz", state)
@test all([isapprox(res[v], magnetization[v]; atol=1e-8) for v in vertices(s)])
end

nothing

0 comments on commit e1abe27

Please sign in to comment.