Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement truncation_error keyword arg for truncate! #99

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

NuclearPowerNerd
Copy link

The purpose of this commit is to allow the user to access the truncation error that is internally calculated in a call to truncate!. This PR implements #96.

I have implemented this by allowing the user to pass a Ref object to the call to truncate!. The result of each SVD performed during a call to truncate! is then accumulated in that Ref.

I also added a new test for this functionality. I then re-ran all tests. There were 75165 passing and 33 broken for a total of 75198 tests. I also updated the docstring for truncate! to reflect this new keyword argument.

Here is an example of the new functionality.

A few differences with what was suggested in #96:

  • I accumulate the error rather than storing the truncation error for each bond individually.
  • I did not use the ! in my keyword argument. Was that just for stylistic reasons or is there some other convention for doing that (I'm aware of the convention in function names but not in variable names)?
truncation_error = Ref{Float64}()
truncation_error[] = 0.0
truncate!(someMPS, maxdim=4, cutoff=1E-5, truncation_error=truncation_error)
truncation_error[]

@NuclearPowerNerd
Copy link
Author

One thing I noticed while playing around with this new feature is that sometimes the truncation error exceeds the cutoff. But I thought it should be that truncation_error[] <= cutoff.

The following MWE illustrates this. I'd be curious what you think. If you want I can open a separate issue about this if/once this has been merged.

d = 2
N = 7
n = d^N
sinds = siteinds(d, N)
x = collect(LinRange(0, 1, n))
y = @. 2 * x + 3 + sin(2π * x)
χ = 8

truncation_error = Ref{Float64}()

cutoff_ = 1E-4
Y = MPS(y, sinds, maxdim=χ, cutoff=0.0)
truncation_error[] = 0.0
truncate!(Y, maxdim=χ, cutoff=cutoff_, truncation_error=truncation_error)

truncation_error[]
truncation_error[] > cutoff_  # returns true!

cutoff_ = 1E-5
Y = MPS(y, sinds, maxdim=χ, cutoff=0.0)
truncation_error[] = 0.0
truncate!(Y, maxdim=χ, cutoff=cutoff_, truncation_error=truncation_error)

truncation_error[]
truncation_error[] > cutoff_ # returns false, as expected

@mtfishman
Copy link
Member

The cutoff refers to the truncation of each SVD performed, if multiple SVDs are performed the total error could add up to a value larger than the cutoff. It could make sense to store a truncation error for each bond of the MPS. That's one reason why I'm a bit hesitant about this PR, since I'd prefer to think about this more generally in terms of what kinds of other information we might want to output and how it should get output.

@NuclearPowerNerd
Copy link
Author

The cutoff refers to the truncation of each SVD performed, if multiple SVDs are performed the total error could add up to a value larger than the cutoff.

Ah, duh. Yeah so I will push an update here shortly and make it store per bond truncation error which is the more sensible thing to do. Thanks for the feedback.

This commit refactors the original implemenation. `truncate!` now
expects the user to pass a pointer to a vector of floats with as
many elements as there are bonds in the MPS. It will then store the
truncation error of each bond in the vector.

The corresponding test was updated. The package was re-tested
with the same results described in the PR (75165 passing and
33 broken for total of 75198 tests).

The docstring of `truncate!` was updated to reflect the new behavior.
@NuclearPowerNerd
Copy link
Author

Any other comments or changes you'd like me to make? Or to discuss? I don't mind if a different system is opted for but I do think it is useful to be able to inspect the truncation error.

nbonds = N - 1
M = basicRandomMPS(10; dim=10)
truncation_errors = Ref{Vector{Float64}}()
truncation_errors[] = fill(-1.0, nbonds) # set to something other than zero for test.
Copy link
Member

@mtfishman mtfishman Jan 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to design the code so that users don't have to initialize the list ahead of time. We could add a line:

if isnothing(truncation_errors!)
  truncation_errors![] = Vector{real(scalartype(M))}(undef, length(M) - 1)
end

at the top of truncate!(...) (note the change to truncation_errors! based on my comment above).

src/abstractmps.jl Outdated Show resolved Hide resolved
src/abstractmps.jl Outdated Show resolved Hide resolved
src/abstractmps.jl Outdated Show resolved Hide resolved
Change new kwarg name to `truncation_errors!` from `truncation_errors`
Remove usage of `enumerate`
Update docstring

Co-authored-by: Matt Fishman <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants