Skip to content

Commit

Permalink
Refactor truncate! to return truncation error from each bond
Browse files Browse the repository at this point in the history
This commit refactors the original implemenation. `truncate!` now
expects the user to pass a pointer to a vector of floats with as
many elements as there are bonds in the MPS. It will then store the
truncation error of each bond in the vector.

The corresponding test was updated. The package was re-tested
with the same results described in the PR (75165 passing and
33 broken for total of 75198 tests).

The docstring of `truncate!` was updated to reflect the new behavior.
  • Loading branch information
NuclearPowerNerd committed Jan 2, 2025
1 parent 5c887ce commit 00f1764
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
10 changes: 5 additions & 5 deletions src/abstractmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1672,7 +1672,7 @@ provided as keyword arguments.
Keyword arguments:
* `site_range`=1:N - only truncate the MPS bonds between these sites
* `truncation_error` - if provided, will store the truncation error from all SVDs performed in a single call to `truncate!`. This should be a `Ref` type, for example `truncation_error = Ref{Float64}()`. It should be initialized to some value (likely 0.0, e.g., `truncation_error[] = 0.0`).
* `truncation_errors` - if provided, will store the truncation error from each SVD performed in a single call to `truncate!`. This should be a `Ref` type, for example `truncation_errors = Ref{Vector{Float64}}()`. It should be initialized to some value (likely 0.0, e.g., `truncation_errors[] = zeros(nbonds)`).
"""
function truncate!(M::AbstractMPS; alg="frobenius", kwargs...)
return truncate!(Algorithm(alg), M; kwargs...)
Expand All @@ -1682,7 +1682,7 @@ function truncate!(
::Algorithm"frobenius",
M::AbstractMPS;
site_range=1:length(M),
truncation_error=nothing,
truncation_errors=nothing,
kwargs...,
)
N = length(M)
Expand All @@ -1692,12 +1692,12 @@ function truncate!(
orthogonalize!(M, last(site_range))

# Perform truncations in a right-to-left sweep
for j in reverse((first(site_range) + 1):last(site_range))
for (i,j) in enumerate(reverse((first(site_range) + 1):last(site_range)))
rinds = uniqueinds(M[j], M[j - 1])
ltags = tags(commonind(M[j], M[j - 1]))
U, S, V, spec = svd(M[j], rinds; lefttags=ltags, kwargs...)
if !isnothing(truncation_error)
truncation_error[] += spec.truncerr
if !isnothing(truncation_errors)
truncation_errors[][i] = spec.truncerr
end
M[j] = U
M[j - 1] *= (S * V)
Expand Down
12 changes: 7 additions & 5 deletions test/base/test_mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -756,12 +756,14 @@ end
@test linkdims(M) == [2, 4, 2, 2, 2, 2, 8, 4, 2]
end

@testset "truncate! with truncation_error" begin
@testset "truncate! with truncation_errors" begin
N = 10
nbonds = N - 1
M = basicRandomMPS(10; dim=10)
truncation_error = Ref{Float64}()
truncation_error[] = 0.0
truncate!(M, maxdim=3, cutoff=1E-3, truncation_error=truncation_error)
@test truncation_error[] > 0.0
truncation_errors = Ref{Vector{Float64}}()
truncation_errors[] = fill(-1.0, nbonds) # set to something other than zero for test.
truncate!(M, maxdim=3, cutoff=1E-3, truncation_errors=truncation_errors)
@test all(truncation_errors[] .>= 0.0)
end


Expand Down

0 comments on commit 00f1764

Please sign in to comment.