From 00f1764fa6ea21b7824ba66121a05fc9037d6725 Mon Sep 17 00:00:00 2001 From: Bobby Date: Wed, 1 Jan 2025 20:56:28 -0500 Subject: [PATCH] Refactor truncate! to return truncation error from each bond This commit refactors the original implemenation. `truncate!` now expects the user to pass a pointer to a vector of floats with as many elements as there are bonds in the MPS. It will then store the truncation error of each bond in the vector. The corresponding test was updated. The package was re-tested with the same results described in the PR (75165 passing and 33 broken for total of 75198 tests). The docstring of `truncate!` was updated to reflect the new behavior. --- src/abstractmps.jl | 10 +++++----- test/base/test_mps.jl | 12 +++++++----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/abstractmps.jl b/src/abstractmps.jl index 1580efe..318e473 100644 --- a/src/abstractmps.jl +++ b/src/abstractmps.jl @@ -1672,7 +1672,7 @@ provided as keyword arguments. Keyword arguments: * `site_range`=1:N - only truncate the MPS bonds between these sites -* `truncation_error` - if provided, will store the truncation error from all SVDs performed in a single call to `truncate!`. This should be a `Ref` type, for example `truncation_error = Ref{Float64}()`. It should be initialized to some value (likely 0.0, e.g., `truncation_error[] = 0.0`). +* `truncation_errors` - if provided, will store the truncation error from each SVD performed in a single call to `truncate!`. This should be a `Ref` type, for example `truncation_errors = Ref{Vector{Float64}}()`. It should be initialized to some value (likely 0.0, e.g., `truncation_errors[] = zeros(nbonds)`). """ function truncate!(M::AbstractMPS; alg="frobenius", kwargs...) return truncate!(Algorithm(alg), M; kwargs...) @@ -1682,7 +1682,7 @@ function truncate!( ::Algorithm"frobenius", M::AbstractMPS; site_range=1:length(M), - truncation_error=nothing, + truncation_errors=nothing, kwargs..., ) N = length(M) @@ -1692,12 +1692,12 @@ function truncate!( orthogonalize!(M, last(site_range)) # Perform truncations in a right-to-left sweep - for j in reverse((first(site_range) + 1):last(site_range)) + for (i,j) in enumerate(reverse((first(site_range) + 1):last(site_range))) rinds = uniqueinds(M[j], M[j - 1]) ltags = tags(commonind(M[j], M[j - 1])) U, S, V, spec = svd(M[j], rinds; lefttags=ltags, kwargs...) - if !isnothing(truncation_error) - truncation_error[] += spec.truncerr + if !isnothing(truncation_errors) + truncation_errors[][i] = spec.truncerr end M[j] = U M[j - 1] *= (S * V) diff --git a/test/base/test_mps.jl b/test/base/test_mps.jl index ff75ef1..87c1452 100644 --- a/test/base/test_mps.jl +++ b/test/base/test_mps.jl @@ -756,12 +756,14 @@ end @test linkdims(M) == [2, 4, 2, 2, 2, 2, 8, 4, 2] end - @testset "truncate! with truncation_error" begin + @testset "truncate! with truncation_errors" begin + N = 10 + nbonds = N - 1 M = basicRandomMPS(10; dim=10) - truncation_error = Ref{Float64}() - truncation_error[] = 0.0 - truncate!(M, maxdim=3, cutoff=1E-3, truncation_error=truncation_error) - @test truncation_error[] > 0.0 + truncation_errors = Ref{Vector{Float64}}() + truncation_errors[] = fill(-1.0, nbonds) # set to something other than zero for test. + truncate!(M, maxdim=3, cutoff=1E-3, truncation_errors=truncation_errors) + @test all(truncation_errors[] .>= 0.0) end