Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor alternating_update to use sweep_regions #77

Merged
merged 24 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
68de2e2
Generalize data of SweepStep to NamedTuple
emstoudenmire Mar 11, 2023
7843fc5
Change tree sweeping interface a bit
emstoudenmire Mar 11, 2023
ccf208c
Define simplified one and two site sweep data
emstoudenmire Mar 16, 2023
9f86b56
Change alternating_update to use sweep data and test with DMRG
emstoudenmire Mar 16, 2023
d7e4189
Remove reverse_step from code
emstoudenmire Mar 16, 2023
f319c71
Simplify sweep data structures and move inside update_step
emstoudenmire Mar 19, 2023
0391c71
Comment noise term for now
emstoudenmire Mar 19, 2023
17c731f
Fix issue when nsweeps is smaller than parameter arrays
emstoudenmire Mar 21, 2023
1998648
Custom sweep patterns for TDVP
emstoudenmire Mar 22, 2023
2b7729a
Reimplement TDVP and refactor TDVPOrder
emstoudenmire Mar 22, 2023
c8eb4cb
Make sweep_region generation more generic
emstoudenmire Mar 28, 2023
a8ac62f
Pass more args to observer
emstoudenmire Mar 28, 2023
08097d7
Formatting
emstoudenmire Mar 28, 2023
7877132
Rename TDVPOrder to TrotterOrder
emstoudenmire May 1, 2023
0698d98
Partial work on making region
emstoudenmire May 23, 2023
f82ad24
Refactor one_site_sweep and two_site_sweep functions and remove tdvpo…
emstoudenmire May 23, 2023
f8090f3
Remove TrotterOrder type
emstoudenmire May 24, 2023
9679fe2
Remove TDVPInfo type
emstoudenmire May 24, 2023
0d7dbb5
Better implementation of appending named tuple
emstoudenmire May 24, 2023
114aa0e
Improvements, mainly to get sweep regions to have a concrete element …
emstoudenmire May 24, 2023
e8e6225
Formatting
emstoudenmire May 24, 2023
2250cc0
Format tests
emstoudenmire May 24, 2023
32d7c24
Replace flatmap(...) with flatten(map(..))
emstoudenmire May 24, 2023
acf5aac
Remove redundant definition of to_tuple
emstoudenmire May 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ include(joinpath("treetensornetworks", "projttns", "projttnsum.jl"))
include(joinpath("treetensornetworks", "projttns", "projttn_apply.jl"))
include(joinpath("treetensornetworks", "solvers", "solver_utils.jl"))
include(joinpath("treetensornetworks", "solvers", "applyexp.jl"))
include(joinpath("treetensornetworks", "solvers", "tdvporder.jl"))
include(joinpath("treetensornetworks", "solvers", "tdvpinfo.jl"))
include(joinpath("treetensornetworks", "solvers", "update_step.jl"))
include(joinpath("treetensornetworks", "solvers", "alternating_update.jl"))
include(joinpath("treetensornetworks", "solvers", "tdvp.jl"))
Expand Down
26 changes: 6 additions & 20 deletions src/treetensornetworks/solvers/alternating_update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ function _extend_sweeps_param(param, nsweeps)
if param isa Number
eparam = fill(param, nsweeps)
else
length(param) == nsweeps && return param
length(param) >= nsweeps && return param[1:nsweeps]
eparam = Vector(undef, nsweeps)
eparam[1:length(param)] = param
eparam[(length(param) + 1):end] .= param[end]
Expand All @@ -23,34 +23,25 @@ function process_sweeps(
mindim = _extend_sweeps_param(mindim, nsweeps)
cutoff = _extend_sweeps_param(cutoff, nsweeps)
noise = _extend_sweeps_param(noise, nsweeps)
return maxdim, mindim, cutoff, noise
return maxdim, mindim, cutoff, noise, kwargs
end

function alternating_update(
solver,
PH,
psi0::AbstractTTN;
checkdone=nothing,
tdvp_order=TDVPOrder(2, Base.Forward),
outputlevel=0,
time_start=0.0,
time_step=0.0,
nsweeps=1,
write_when_maxdim_exceeds::Union{Int,Nothing}=nothing,
kwargs...,
)
maxdim, mindim, cutoff, noise = process_sweeps(nsweeps; kwargs...)
maxdim, mindim, cutoff, noise, kwargs = process_sweeps(nsweeps; kwargs...)

step_observer = get(kwargs, :step_observer!, nothing)

psi = copy(psi0)

# Keep track of the start of the current time step.
# Helpful for tracking the total time, for example
# when using time-dependent solvers.
# This will be passed as a keyword argument to the
# `solver`.
current_time = time_start
info = nothing
for sw in 1:nsweeps
if !isnothing(write_when_maxdim_exceeds) && maxdim[sw] > write_when_maxdim_exceeds
Expand All @@ -64,31 +55,26 @@ function alternating_update(

sw_time = @elapsed begin
psi, PH, info = update_step(
tdvp_order,
solver,
PH,
time_step,
psi;
kwargs...,
current_time,
outputlevel,
sweep=sw,
maxdim=maxdim[sw],
mindim=mindim[sw],
cutoff=cutoff[sw],
noise=noise[sw],
kwargs...,
)
end

current_time += time_step

update!(step_observer; psi, sweep=sw, outputlevel, current_time)
update!(step_observer; psi, sweep=sw, outputlevel)

if outputlevel >= 1
print("After sweep ", sw, ":")
print(" maxlinkdim=", maxlinkdim(psi))
@printf(" maxerr=%.2E", info.maxtruncerr)
print(" current_time=", round(current_time; digits=3))
#print(" current_time=", round(current_time; digits=3))
print(" time=", round(sw_time; digits=3))
println()
flush(stdout)
Expand Down
3 changes: 1 addition & 2 deletions src/treetensornetworks/solvers/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ function contract(
## )
## end

reverse_step = false
PH = ProjTTNApply(tn2, tn1)
psi = alternating_update(contract_solver, PH, init; nsweeps, reverse_step, kwargs...)
psi = alternating_update(contract_solver, PH, init; nsweeps, kwargs...)

return psi
end
Expand Down
8 changes: 2 additions & 6 deletions src/treetensornetworks/solvers/dmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,10 @@ end
Overload of `ITensors.dmrg`.
"""
function dmrg(H, init::AbstractTTN; kwargs...)
reverse_step = false
psi = alternating_update(eigsolve_solver(; kwargs...), H, init; reverse_step, kwargs...)
return psi
return alternating_update(eigsolve_solver(; kwargs...), H, init; kwargs...)
end

"""
Overload of `KrylovKit.eigsolve`.
"""
function eigsolve(H, init::AbstractTTN; kwargs...)
return dmrg(H, init; kwargs...)
end
eigsolve(H, init::AbstractTTN; kwargs...) = dmrg(H, init; kwargs...)
4 changes: 2 additions & 2 deletions src/treetensornetworks/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function dmrg_x_solver(PH, init; kwargs...)
return U_max, nothing
end

function dmrg_x(PH, init::AbstractTTN; reverse_step=false, kwargs...)
psi = alternating_update(dmrg_x_solver, PH, init; reverse_step, kwargs...)
function dmrg_x(PH, init::AbstractTTN; kwargs...)
psi = alternating_update(dmrg_x_solver, PH, init; kwargs...)
return psi
end
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,5 @@ function linsolve(
# TODO: Define `itensornetwork_cache`
# TODO: Define `linsolve_cache`
P = linsolve_cache(itensornetwork_cache(x₀', A, x₀), itensornetwork_cache(x₀', b))
return alternating_update(linsolve_solver, P, x₀; reverse_step=false, kwargs...)
return alternating_update(linsolve_solver, P, x₀; kwargs...)
end
39 changes: 37 additions & 2 deletions src/treetensornetworks/solvers/tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ function exponentiate_solver(; kwargs...)
init;
ishermitian=true,
issymmetric=true,
region,
solver_krylovdim=30,
solver_maxiter=100,
solver_outputlevel=0,
Expand Down Expand Up @@ -75,19 +76,53 @@ function _compute_nsweeps(nsteps, t, time_step)
return nsweeps
end

function sub_time_steps(order)
if order == 1
return [1.0]
elseif order == 2
return [1 / 2, 1 / 2]
elseif order == 4
s = 1.0 / (2 - 2^(1 / 3))
return [s / 2, s / 2, (1 - 2 * s) / 2, (1 - 2 * s) / 2, s / 2, s / 2]
else
error("Trotter order of $order not supported")
end
end

function tdvp_sweep(
order::Int, nsite::Int, time_step::Number, graph::AbstractGraph; kwargs...
)
sweep = []
for (substep, fac) in enumerate(sub_time_steps(order))
sub_time_step = time_step * fac
half = half_sweep(
direction(substep),
graph,
make_region;
nsite,
region_args=(; substep, time_step=sub_time_step),
reverse_args=(; substep, time_step=-sub_time_step),
reverse_step=true,
)
append!(sweep, half)
end
return sweep
end

function tdvp(
solver,
H,
t::Number,
init::AbstractTTN;
time_step::Number=t,
nsite=2,
nsteps=nothing,
order::Integer=2,
kwargs...,
)
nsweeps = _compute_nsweeps(nsteps, t, time_step)
tdvp_order = TDVPOrder(order, Base.Forward)
return alternating_update(solver, H, init; nsweeps, tdvp_order, time_step, kwargs...)
sweep_regions = tdvp_sweep(order, nsite, time_step, init; kwargs...)
return alternating_update(solver, H, init; nsweeps, sweep_regions, nsite, kwargs...)
end

"""
Expand Down
7 changes: 0 additions & 7 deletions src/treetensornetworks/solvers/tdvpinfo.jl

This file was deleted.

29 changes: 0 additions & 29 deletions src/treetensornetworks/solvers/tdvporder.jl

This file was deleted.

117 changes: 53 additions & 64 deletions src/treetensornetworks/solvers/tree_sweeping.jl
Original file line number Diff line number Diff line change
@@ -1,76 +1,65 @@
#
# Sweep step
#
direction(step_number) = isodd(step_number) ? Base.Forward : Base.Reverse

mtfishman marked this conversation as resolved.
Show resolved Hide resolved
"""
struct SweepStep{V}

Auxiliary object specifying a single local update step in a tree sweeping algorithm.
"""
struct SweepStep{V} # TODO: parametrize on position type instead of vertex type
pos::Union{Vector{<:V},NamedEdge{V}}
time_direction::Int
end

# field access
pos(st::SweepStep) = st.pos
nsite(st::SweepStep) = isa(pos(st), AbstractEdge) ? 0 : length(pos(st))
time_direction(st::SweepStep) = st.time_direction

# utility
current_ortho(st::SweepStep) = current_ortho(typeof(pos(st)), st)
current_ortho(::Type{<:Vector{<:V}}, st::SweepStep{V}) where {V} = first(pos(st)) # not very clean...
current_ortho(::Type{NamedEdge{V}}, st::SweepStep{V}) where {V} = src(pos(st))

# reverse
Base.reverse(s::SweepStep{V}) where {V} = SweepStep{V}(reverse(pos(s)), time_direction(s))

function Base.:(==)(s1::SweepStep{V}, s2::SweepStep{V}) where {V}
return pos(s1) == pos(s2) && time_direction(s1) == time_direction(s2)
end

#
# Pre-defined sweeping paradigms
#

function one_site_sweep(
direction::Base.ForwardOrdering,
graph::AbstractGraph{V},
root_vertex::V,
reverse_step;
kwargs...,
) where {V}
edges = post_order_dfs_edges(graph, root_vertex)
steps = SweepStep{V}[]
for e in edges
push!(steps, SweepStep{V}([src(e)], +1))
reverse_step && push!(steps, SweepStep{V}(e, -1))
function make_region(
edge;
last_edge=false,
nsite=1,
region_args=(;),
reverse_args=region_args,
reverse_step=false,
)
if nsite == 1
site = ([src(edge)], region_args)
bond = (edge, reverse_args)
region = reverse_step ? (site, bond) : (site,)
if last_edge
return (region..., ([dst(edge)], region_args))
else
return region
end
elseif nsite == 2
sites_two = ([src(edge), dst(edge)], region_args)
sites_one = ([dst(edge)], reverse_args)
region = reverse_step ? (sites_two, sites_one) : (sites_two,)
if last_edge
return (sites_two,)
else
return region
end
else
error("nsite=$nsite not supported in alternating_update / update_step")
end
push!(steps, SweepStep{V}([root_vertex], +1))
return steps
end

function one_site_sweep(direction::Base.ReverseOrdering, args...; kwargs...)
return reverse(reverse.(one_site_sweep(Base.Forward, args...; kwargs...)))
#
# Helper functions to take a tuple like ([1],[2])
# and append an empty named tuple to it, giving ([1],[2],(;))
#
prepend_missing_namedtuple(t::Tuple) = ((;), t...)
prepend_missing_namedtuple(t::Tuple{<:NamedTuple,Vararg}) = t
function append_missing_namedtuple(t::Tuple)
return reverse(prepend_missing_namedtuple(reverse(t)))
end

function two_site_sweep(
direction::Base.ForwardOrdering,
graph::AbstractGraph{V},
root_vertex::V,
reverse_step;
function half_sweep(
dir::Base.ForwardOrdering,
graph::AbstractGraph,
region_function;
root_vertex=default_root_vertex(graph),
kwargs...,
) where {V}
)
edges = post_order_dfs_edges(graph, root_vertex)
steps = SweepStep{V}[]
for e in edges[1:(end - 1)]
push!(steps, SweepStep{V}([src(e), dst(e)], +1))
reverse_step && push!(steps, SweepStep{V}([dst(e)], -1))
end
push!(steps, SweepStep{V}([src(edges[end]), dst(edges[end])], +1))
steps = collect(
flatten(map(e -> region_function(e; last_edge=(e == edges[end]), kwargs...), edges))
)
# Append empty namedtuple to each element if not already present
steps = append_missing_namedtuple.(to_tuple.(steps))
return steps
end

function two_site_sweep(direction::Base.ReverseOrdering, args...; kwargs...)
return reverse(reverse.(two_site_sweep(Base.Forward, args...; kwargs...)))
function half_sweep(dir::Base.ReverseOrdering, args...; kwargs...)
return map(
region -> (reverse(region[1]), region[2:end]...),
reverse(half_sweep(Base.Forward, args...; kwargs...)),
)
end
Loading