Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into BP_Update_Order
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Nov 6, 2023
2 parents 6df26bd + 94eb10f commit 070996a
Show file tree
Hide file tree
Showing 11 changed files with 117 additions and 83 deletions.
4 changes: 2 additions & 2 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ function svd(
)
tn = copy(tn)
left_inds = uniqueinds(tn, edge)
U, S, V = svd(tn[src(edge)], left_inds; lefttags=u_tags, right_tags=v_tags, kwargs...)
U, S, V = svd(tn[src(edge)], left_inds; lefttags=u_tags, righttags=v_tags, kwargs...)

rem_vertex!(tn, src(edge))
add_vertex!(tn, U_vertex)
Expand Down Expand Up @@ -562,7 +562,7 @@ function _truncate_edge(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs..
tn = copy(tn)
left_inds = uniqueinds(tn, edge)
ltags = tags(tn, edge)
U, S, V = svd(tn[src(edge)], left_inds; lefttags=ltags, ortho="left", kwargs...)
U, S, V = svd(tn[src(edge)], left_inds; lefttags=ltags, kwargs...)
tn[src(edge)] = U
tn[dst(edge)] *= (S * V)
return tn
Expand Down
4 changes: 1 addition & 3 deletions src/gauging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,7 @@ function vidal_gauge(
target_precision=target_canonicalness,
verbose,
)
return vidal_gauge(
ψ, mts; eigen_message_tensor_cutoff, regularization, niters, svd_kwargs...
)
return vidal_gauge(ψ, mts; eigen_message_tensor_cutoff, regularization, svd_kwargs...)
end

"""Transform from an ITensor in the Vidal Gauge (bond tensors) to the Symmetric Gauge (message tensors)"""
Expand Down
7 changes: 1 addition & 6 deletions src/treetensornetworks/solvers/applyexp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@ struct ApplyExpInfo
converged::Int
end

function applyexp(H, tau::Number, x0; kwargs...)
maxiter = get(kwargs, :maxiter, 30)
tol = get(kwargs, :tol, 1E-12)
outputlevel = get(kwargs, :outputlevel, 0)
beta_tol = get(kwargs, :normcutoff, 1E-7)

function applyexp(H, tau::Number, x0; maxiter=30, tol=1e-12, outputlevel=0, normcutoff=1e-7)
# Initialize Lanczos vectors
v1 = copy(x0)
nrm = norm(v1)
Expand Down
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/contract.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function contract_solver(PH, psi; kwargs...)
function contract_solver(PH, psi; normalize, region, half_sweep)
v = ITensor(1.0)
for j in sites(PH)
v *= PH.psi0[j]
Expand Down
54 changes: 43 additions & 11 deletions src/treetensornetworks/solvers/dmrg.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
function eigsolve_solver(; solver_which_eigenvalue=:SR, kwargs...)
function solver(H, init; kws...)
function eigsolve_solver(;
solver_which_eigenvalue=:SR,
ishermitian=true,
solver_tol=1e-14,
solver_krylovdim=3,
solver_maxiter=1,
solver_verbosity=0,
)
function solver(H, init; normalize=nothing, region=nothing, half_sweep=nothing)
howmany = 1
which = solver_which_eigenvalue
solver_kwargs = (;
ishermitian=get(kwargs, :ishermitian, true),
tol=get(kwargs, :solver_tol, 1E-14),
krylovdim=get(kwargs, :solver_krylovdim, 3),
maxiter=get(kwargs, :solver_maxiter, 1),
verbosity=get(kwargs, :solver_verbosity, 0),
vals, vecs, info = eigsolve(
H,
init,
howmany,
which;
ishermitian,
tol=solver_tol,
krylovdim=solver_krylovdim,
maxiter=solver_maxiter,
verbosity=solver_verbosity,
)
vals, vecs, info = eigsolve(H, init, howmany, which; solver_kwargs...)
psi = vecs[1]
return psi, (; solver_info=info, energies=vals)
end
Expand All @@ -19,8 +29,30 @@ end
"""
Overload of `ITensors.dmrg`.
"""
function dmrg(H, init::AbstractTTN; kwargs...)
return alternating_update(eigsolve_solver(; kwargs...), H, init; kwargs...)
function dmrg(
H,
init::AbstractTTN;
solver_which_eigenvalue=:SR,
ishermitian=true,
solver_tol=1e-14,
solver_krylovdim=3,
solver_maxiter=1,
solver_verbosity=0,
kwargs...,
)
return alternating_update(
eigsolve_solver(;
solver_which_eigenvalue,
ishermitian,
solver_tol,
solver_krylovdim,
solver_maxiter,
solver_verbosity,
),
H,
init;
kwargs...,
)
end

"""
Expand Down
4 changes: 3 additions & 1 deletion src/treetensornetworks/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
function dmrg_x_solver(PH, init; kwargs...)
function dmrg_x_solver(
PH, init; normalize=nothing, region=nothing, half_sweep=nothing, reverse_step=nothing
)
H = contract(PH, ITensor(1.0))
D, U = eigen(H; ishermitian=true)
u = uniqueind(U, H)
Expand Down
22 changes: 16 additions & 6 deletions src/treetensornetworks/solvers/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@ Keyword arguments:
Overload of `KrylovKit.linsolve`.
"""
function linsolve(
A::AbstractTTN, b::AbstractTTN, x₀::AbstractTTN, a₀::Number=0, a₁::Number=1; kwargs...
A::AbstractTTN,
b::AbstractTTN,
x₀::AbstractTTN,
a₀::Number=0,
a₁::Number=1;
normalize,
region,
half_sweep,
)
function linsolve_solver(
P,
Expand All @@ -33,17 +40,20 @@ function linsolve(
solver_krylovdim=30,
solver_maxiter=100,
solver_verbosity=0,
kwargs...,
)
solver_kwargs = (;
ishermitian=ishermitian,
b = dag(only(proj_mps(P)))
x, info = KrylovKit.linsolve(
P,
b,
x₀,
a₀,
a₁;
ishermitian,
tol=solver_tol,
krylovdim=solver_krylovdim,
maxiter=solver_maxiter,
verbosity=solver_verbosity,
)
b = dag(only(proj_mps(P)))
x, info = KrylovKit.linsolve(P, b, x₀, a₀, a₁; solver_kwargs...)
return x, NamedTuple()
end

Expand Down
39 changes: 24 additions & 15 deletions src/treetensornetworks/solvers/tdvp.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function exponentiate_solver(; kwargs...)
function exponentiate_solver()
function solver(
H,
init;
Expand All @@ -10,10 +10,13 @@ function exponentiate_solver(; kwargs...)
solver_outputlevel=0,
solver_tol=1E-12,
substep,
normalize,
time_step,
kws...,
)
solver_kwargs = (;
psi, exp_info = KrylovKit.exponentiate(
H,
time_step,
init;
ishermitian,
issymmetric,
tol=solver_tol,
Expand All @@ -22,14 +25,12 @@ function exponentiate_solver(; kwargs...)
verbosity=solver_outputlevel,
eager=true,
)

psi, exp_info = KrylovKit.exponentiate(H, time_step, init; solver_kwargs...)
return psi, (; info=exp_info)
end
return solver
end

function applyexp_solver(; kwargs...)
function applyexp_solver()
function solver(
H,
init;
Expand All @@ -39,13 +40,13 @@ function applyexp_solver(; kwargs...)
solver_tol=1E-8,
substep,
time_step,
kws...,
normalize,
)
solver_kwargs = (; maxiter=solver_krylovdim, outputlevel=solver_outputlevel)

#applyexp tol is absolute, compute from tol_per_unit_time:
tol = abs(time_step) * tol_per_unit_time
psi, exp_info = applyexp(H, time_step, init; tol, solver_kwargs..., kws...)
psi, exp_info = applyexp(
H, time_step, init; tol, maxiter=solver_krylovdim, outputlevel=solver_outputlevel
)
return psi, (; info=exp_info)
end
return solver
Expand Down Expand Up @@ -84,7 +85,12 @@ function sub_time_steps(order)
end

function tdvp_sweep(
order::Int, nsite::Int, time_step::Number, graph::AbstractGraph; kwargs...
order::Int,
nsite::Int,
time_step::Number,
graph::AbstractGraph;
root_vertex=default_root_vertex(graph),
reverse_step=true,
)
sweep = []
for (substep, fac) in enumerate(sub_time_steps(order))
Expand All @@ -93,10 +99,11 @@ function tdvp_sweep(
direction(substep),
graph,
make_region;
root_vertex,
nsite,
region_args=(; substep, time_step=sub_time_step),
reverse_args=(; substep, time_step=-sub_time_step),
reverse_step=true,
reverse_step,
)
append!(sweep, half)
end
Expand All @@ -113,10 +120,12 @@ function tdvp(
nsteps=nothing,
order::Integer=2,
(sweep_observer!)=observer(),
root_vertex=default_root_vertex(init),
reverse_step=true,
kwargs...,
)
nsweeps = _compute_nsweeps(nsteps, t, time_step, order)
sweep_regions = tdvp_sweep(order, nsite, time_step, init; kwargs...)
sweep_regions = tdvp_sweep(order, nsite, time_step, init; root_vertex, reverse_step)

function sweep_time_printer(; outputlevel, sweep, kwargs...)
if outputlevel >= 1
Expand Down Expand Up @@ -156,7 +165,7 @@ Optional keyword arguments:
* `time_step::Number = t` - time step to use when evolving the state. Smaller time steps generally give more accurate results but can make the algorithm take more computational time to run.
* `nsteps::Integer` - evolve by the requested total time `t` by performing `nsteps` of the TDVP algorithm. More steps can result in more accurate results but require more computational time to run. (Note that only one of the `time_step` or `nsteps` parameters can be provided, not both.)
* `outputlevel::Int = 1` - larger outputlevel values resulting in printing more information and 0 means no output
* `observer` - object implementing the [Observer](@ref observer) interface which can perform measurements and stop early
* `observer` - object implementing the Observer interface which can perform measurements and stop early
* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations
"""
function tdvp(H, t::Number, init::AbstractTTN; solver_backend="exponentiate", kwargs...)
Expand All @@ -169,5 +178,5 @@ function tdvp(H, t::Number, init::AbstractTTN; solver_backend="exponentiate", kw
"solver_backend=$solver_backend not recognized (options are \"applyexp\" or \"exponentiate\")",
)
end
return tdvp(solver(; kwargs...), H, t, init; kwargs...)
return tdvp(solver(), H, t, init; kwargs...)
end
34 changes: 18 additions & 16 deletions src/treetensornetworks/solvers/update_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,29 @@ function insert_local_tensor(
psi::AbstractTTN,
phi::ITensor,
pos::Vector;
which_decomp=nothing,
normalize=false,
# factorize kwargs
maxdim=nothing,
mindim=nothing,
cutoff=nothing,
which_decomp=nothing,
eigen_perturbation=nothing,
kwargs...,
ortho=nothing,
)
spec = nothing
for (v, vnext) in IterTools.partition(pos, 2, 1)
e = edgetype(psi)(v, vnext)
indsTe = inds(psi[v])
L, phi, spec = factorize(
phi, indsTe; which_decomp, tags=tags(psi, e), eigen_perturbation, kwargs...
phi,
indsTe;
tags=tags(psi, e),
maxdim,
mindim,
cutoff,
which_decomp,
eigen_perturbation,
ortho,
)
psi[v] = L
eigen_perturbation = nothing # TODO: fix this
Expand Down Expand Up @@ -162,16 +174,15 @@ function local_update(
sweep,
sweep_regions,
sweep_step,
kwargs...,
solver_kwargs...,
)
psi = orthogonalize(psi, current_ortho(region))
psi, phi = extract_local_tensor(psi, region)

nsites = (region isa AbstractEdge) ? 0 : length(region)
PH = set_nsite(PH, nsites)
PH = position(PH, psi, region)

phi, info = solver(PH, phi; normalize, region, step_kwargs..., kwargs...)
phi, info = solver(PH, phi; normalize, region, step_kwargs..., solver_kwargs...)
if !(phi isa ITensor && info isa NamedTuple)
println("Solver returned the following types: $(typeof(phi)), $(typeof(info))")
error("In alternating_update, solver must return an ITensor and a NamedTuple")
Expand All @@ -185,16 +196,7 @@ function local_update(
#end

psi, spec = insert_local_tensor(
psi,
phi,
region;
eigen_perturbation=drho,
ortho,
normalize,
cutoff,
maxdim,
mindim,
kwargs...,
psi, phi, region; eigen_perturbation=drho, ortho, normalize, maxdim, mindim, cutoff
)

update!(
Expand Down
4 changes: 2 additions & 2 deletions test/test_treetensornetworks/test_solvers/test_contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ using Test
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5

# Test with less good initial guess MPS not equal to psi
psi_guess = truncate(psi; maxdim=2)
Hpsi = apply(H, psi; alg="fit", nsweeps=4, init_state=psi_guess)
psi_guess = truncate(psit; maxdim=2)
Hpsi = apply(H, psi; alg="fit", nsweeps=4, init=psi_guess)
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5

# Test with nsite=1
Expand Down
Loading

0 comments on commit 070996a

Please sign in to comment.