From 45de76bca678843a84c7fd6676a0cf39a92ce34b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Mon, 13 May 2024 19:56:17 -0400 Subject: [PATCH 01/12] [ITensorMPS] Allow customizing apply --- src/lib/ITensorMPS/src/mpo.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/lib/ITensorMPS/src/mpo.jl b/src/lib/ITensorMPS/src/mpo.jl index 4758a468df..2ffb314a24 100644 --- a/src/lib/ITensorMPS/src/mpo.jl +++ b/src/lib/ITensorMPS/src/mpo.jl @@ -596,8 +596,12 @@ Equivalent to `replaceprime(contract(A, x; kwargs...), 2 => 1)`. See also [`contract`](@ref) for details about the arguments available. """ -function apply(A::MPO, ψ::MPS; kwargs...) - Aψ = contract(A, ψ; kwargs...) +function apply(A::MPO, ψ::MPS; alg=Algorithm"densitymatrix"(), kwargs...) + return apply(alg, A, ψ; kwargs...) +end + +function apply(alg::Algorithm, A::MPO, ψ::MPS; kwargs...) + Aψ = contract(alg, A, ψ; kwargs...) return replaceprime(Aψ, 1 => 0) end From af837ddf5784af817ddb234684a02baaf407b413 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Tue, 14 May 2024 08:53:03 -0400 Subject: [PATCH 02/12] Convert to Algorithm --- src/lib/ITensorMPS/src/mpo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/ITensorMPS/src/mpo.jl b/src/lib/ITensorMPS/src/mpo.jl index 2ffb314a24..a82d5ca666 100644 --- a/src/lib/ITensorMPS/src/mpo.jl +++ b/src/lib/ITensorMPS/src/mpo.jl @@ -597,7 +597,7 @@ Equivalent to `replaceprime(contract(A, x; kwargs...), 2 => 1)`. See also [`contract`](@ref) for details about the arguments available. """ function apply(A::MPO, ψ::MPS; alg=Algorithm"densitymatrix"(), kwargs...) - return apply(alg, A, ψ; kwargs...) + return apply(Algorithm(alg), A, ψ; kwargs...) end function apply(alg::Algorithm, A::MPO, ψ::MPS; kwargs...) From 3b0908b50f5996d74485e14c82ff293682181929 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 May 2024 09:40:06 -0400 Subject: [PATCH 03/12] Fix test failure --- .../ext/ITensorMPSChainRulesCoreExt/abstractmps.jl | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl b/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl index 482cfa8d68..36da20172a 100644 --- a/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl +++ b/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl @@ -1,7 +1,17 @@ using Adapt: adapt using ChainRulesCore: ChainRulesCore, HasReverseMode, NoTangent, RuleConfig using ITensors: - ITensor, apply, dag, hassameinds, inds, inner, itensor, mapprime, replaceprime, swapprime + ITensors, + ITensor, + apply, + dag, + hassameinds, + inds, + inner, + itensor, + mapprime, + replaceprime, + swapprime using ITensors.ITensorMPS: MPO, MPS, siteinds using NDTensors: datatype From 2ce94c94ecfdd28974917613c90233b31960fd21 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 May 2024 09:46:07 -0400 Subject: [PATCH 04/12] Wrap some test files in modules --- src/lib/ITensorMPS/test/base/test_autompo.jl | 2 ++ src/lib/ITensorMPS/test/base/test_mpo.jl | 2 ++ src/lib/ITensorMPS/test/base/test_mps.jl | 2 ++ src/lib/ITensorMPS/test/base/test_readwrite.jl | 2 ++ 4 files changed, 8 insertions(+) diff --git a/src/lib/ITensorMPS/test/base/test_autompo.jl b/src/lib/ITensorMPS/test/base/test_autompo.jl index bc1e8f0d1b..e7a82a2c78 100644 --- a/src/lib/ITensorMPS/test/base/test_autompo.jl +++ b/src/lib/ITensorMPS/test/base/test_autompo.jl @@ -1,3 +1,4 @@ +@eval module $(gensym()) using ITensors, Test, Random, JLD2 using NDTensors: scalartype @@ -1249,3 +1250,4 @@ end @test_nowarn H = MPO(os, sites) end end +end diff --git a/src/lib/ITensorMPS/test/base/test_mpo.jl b/src/lib/ITensorMPS/test/base/test_mpo.jl index 994ab6ebb5..07f22f969f 100644 --- a/src/lib/ITensorMPS/test/base/test_mpo.jl +++ b/src/lib/ITensorMPS/test/base/test_mpo.jl @@ -1,3 +1,4 @@ +@eval module $(gensym()) using Combinatorics using ITensors using NDTensors: scalartype @@ -853,3 +854,4 @@ end @test maxlinkdim(H) ≤ maxlinkdim(H₁) + maxlinkdim(H₂) end end +end diff --git a/src/lib/ITensorMPS/test/base/test_mps.jl b/src/lib/ITensorMPS/test/base/test_mps.jl index bfe39da013..a089053086 100644 --- a/src/lib/ITensorMPS/test/base/test_mps.jl +++ b/src/lib/ITensorMPS/test/base/test_mps.jl @@ -1,3 +1,4 @@ +@eval module $(gensym()) using Combinatorics using ITensors using ITensors: ITensorMPS @@ -2006,3 +2007,4 @@ end @test norm(M - Mt) ≈ 0 atol = 1e-12 end end +end diff --git a/src/lib/ITensorMPS/test/base/test_readwrite.jl b/src/lib/ITensorMPS/test/base/test_readwrite.jl index fe3bb64f61..8def21de55 100644 --- a/src/lib/ITensorMPS/test/base/test_readwrite.jl +++ b/src/lib/ITensorMPS/test/base/test_readwrite.jl @@ -1,3 +1,4 @@ +@eval module $(gensym()) using ITensors, HDF5, Test include(joinpath(@__DIR__, "utils", "util.jl")) @@ -36,3 +37,4 @@ include(joinpath(@__DIR__, "utils", "util.jl")) # rm("data.h5"; force=true) end +end From 82cf24befda513398f7c20a6fa68b227a4a19613 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 May 2024 09:49:20 -0400 Subject: [PATCH 05/12] Fix method rewrite issues in tests --- src/lib/ITensorMPS/test/base/test_fermions.jl | 4 ++-- src/lib/ITensorMPS/test/base/test_mps.jl | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lib/ITensorMPS/test/base/test_fermions.jl b/src/lib/ITensorMPS/test/base/test_fermions.jl index da31384d26..0468f1997c 100644 --- a/src/lib/ITensorMPS/test/base/test_fermions.jl +++ b/src/lib/ITensorMPS/test/base/test_fermions.jl @@ -260,11 +260,11 @@ import ITensors: Out, In # Reference state |110⟩ ψ110 = MPS(s, n -> n == 1 || n == 2 ? "1" : "0") - function ITensors.op(::OpName"CdagC", ::SiteType, s1::Index, s2::Index) + function ITensors.op(::OpName"CdagC3", ::SiteType, s1::Index, s2::Index) return op("Cdag", s1) * op("C", s2) end - os = [("CdagC", 1, 3)] + os = [("CdagC3", 1, 3)] Os = ops(os, s) # Results in -|110⟩ diff --git a/src/lib/ITensorMPS/test/base/test_mps.jl b/src/lib/ITensorMPS/test/base/test_mps.jl index a089053086..7962a90a26 100644 --- a/src/lib/ITensorMPS/test/base/test_mps.jl +++ b/src/lib/ITensorMPS/test/base/test_mps.jl @@ -1766,11 +1766,11 @@ end # Reference state |110⟩ ψ110 = MPS(s, n -> n == 1 || n == 2 ? "1" : "0") - function ITensors.op(::OpName"CdagC", ::SiteType, s1::Index, s2::Index) + function ITensors.op(::OpName"CdagC1", ::SiteType, s1::Index, s2::Index) return op("Cdag", s1) * op("C", s2) end - os = [("CdagC", 1, 3)] + os = [("CdagC1", 1, 3)] Os = ops(os, s) # Results in -|110⟩ @@ -1811,7 +1811,7 @@ end cutoff!(sweeps, 1E-12) energy, ψ0 = dmrg(H, ψ0, sweeps; outputlevel=0) - function ITensors.op(::OpName"CdagC", ::SiteType, s1::Index, s2::Index) + function ITensors.op(::OpName"CdagC2", ::SiteType, s1::Index, s2::Index) return op("Cdag", s1) * op("C", s2) end @@ -1822,7 +1822,7 @@ end end for i in 1:(N - 1), j in (i + 1):N - G1 = op("CdagC", s, i, j) + G1 = op("CdagC2", s, i, j) @disable_warn_order begin G2 = op("Cdag", s, i) From 2eeb1b8c4b9bc16f0f2584aee658dd3b8f54d334 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 May 2024 11:14:07 -0400 Subject: [PATCH 06/12] Try fixing tests --- .../ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl b/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl index 36da20172a..0b03369400 100644 --- a/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl +++ b/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl @@ -1,5 +1,5 @@ using Adapt: adapt -using ChainRulesCore: ChainRulesCore, HasReverseMode, NoTangent, RuleConfig +using ChainRulesCore: ChainRulesCore, HasReverseMode, NoTangent, RuleConfig, rrule_via_ad using ITensors: ITensors, ITensor, From ae96276247e796b45158d8fd48a18b96f3b27e80 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 May 2024 11:41:02 -0400 Subject: [PATCH 07/12] Try fixing tests --- .../ITensorMPSChainRulesCoreExt/abstractmps.jl | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl b/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl index 0b03369400..6c78a99576 100644 --- a/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl +++ b/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl @@ -1,18 +1,8 @@ using Adapt: adapt using ChainRulesCore: ChainRulesCore, HasReverseMode, NoTangent, RuleConfig, rrule_via_ad using ITensors: - ITensors, - ITensor, - apply, - dag, - hassameinds, - inds, - inner, - itensor, - mapprime, - replaceprime, - swapprime -using ITensors.ITensorMPS: MPO, MPS, siteinds + ITensors, ITensor, dag, hassameinds, inds, itensor, mapprime, replaceprime, swapprime +using ITensors.ITensorMPS: ITensorMPS, MPO, MPS, apply, inner, siteinds using NDTensors: datatype function ChainRulesCore.rrule( @@ -196,7 +186,7 @@ function ChainRulesCore.rrule( end y = typeof(x)(y_data) if !set_limits - y = ITensors.set_ortho_lims(y, ortho_lims(x)) + y = ITensorMPS.set_ortho_lims(y, ortho_lims(x)) end return y, map_pullback end From bfa83bba85071be6d2d71faaf0df73034bbdae27 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 May 2024 11:42:53 -0400 Subject: [PATCH 08/12] Try fixing tests --- src/lib/ITensorMPS/test/Ops/runtests.jl | 2 ++ src/lib/ITensorMPS/test/Ops/test_ops_mpo.jl | 2 ++ src/lib/ITensorMPS/test/Ops/test_trotter.jl | 2 ++ 3 files changed, 6 insertions(+) diff --git a/src/lib/ITensorMPS/test/Ops/runtests.jl b/src/lib/ITensorMPS/test/Ops/runtests.jl index c2fe5a8b6f..36cbca610b 100644 --- a/src/lib/ITensorMPS/test/Ops/runtests.jl +++ b/src/lib/ITensorMPS/test/Ops/runtests.jl @@ -1,3 +1,4 @@ +@eval module $(gensym()) using ITensors using Test @@ -14,3 +15,4 @@ ITensors.disable_threaded_blocksparse() @time include(filename) end end +end diff --git a/src/lib/ITensorMPS/test/Ops/test_ops_mpo.jl b/src/lib/ITensorMPS/test/Ops/test_ops_mpo.jl index f7b42b30a4..1348aae7af 100644 --- a/src/lib/ITensorMPS/test/Ops/test_ops_mpo.jl +++ b/src/lib/ITensorMPS/test/Ops/test_ops_mpo.jl @@ -1,3 +1,4 @@ +@eval module $(gensym()) using Test using ITensors using ITensors.Ops @@ -95,3 +96,4 @@ end @test norm(replaceprime(H' * H, 2 => 1) - H²) ≈ 0 atol = 1e-14 @test norm(H(H) - H²) ≈ 0 atol = 1e-14 end +end diff --git a/src/lib/ITensorMPS/test/Ops/test_trotter.jl b/src/lib/ITensorMPS/test/Ops/test_trotter.jl index 4ba1f054ee..54554df9c5 100644 --- a/src/lib/ITensorMPS/test/Ops/test_trotter.jl +++ b/src/lib/ITensorMPS/test/Ops/test_trotter.jl @@ -1,3 +1,4 @@ +@eval module $(gensym()) using Test using ITensors using ITensors.Ops @@ -36,3 +37,4 @@ end end end end +end From f5cdd6ed60b7c500a8fa9306d830e22c8e00fd9d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 May 2024 12:00:31 -0400 Subject: [PATCH 09/12] Try fixing tests --- .../ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl b/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl index 6c78a99576..ee6295a684 100644 --- a/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl +++ b/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/abstractmps.jl @@ -186,7 +186,7 @@ function ChainRulesCore.rrule( end y = typeof(x)(y_data) if !set_limits - y = ITensorMPS.set_ortho_lims(y, ortho_lims(x)) + y = ITensorMPS.set_ortho_lims(y, ITensorMPS.ortho_lims(x)) end return y, map_pullback end From 228d393f0a2e506186eff5305a2f39fb55922d41 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 May 2024 12:44:49 -0400 Subject: [PATCH 10/12] Try fixing tests --- .../ext/ITensorMPSChainRulesCoreExt/mpo.jl | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl b/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl index 7662daeb37..fc6df67eec 100644 --- a/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl +++ b/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl @@ -3,28 +3,32 @@ using ITensors: contract, hassameinds, inner, mapprime using ITensors.ITensorMPS: MPO, MPS, firstsiteinds, siteinds using LinearAlgebra: tr -function ChainRulesCore.rrule(::typeof(contract), x1::MPO, x2::MPO; kwargs...) - y = contract(x1, x2; kwargs...) +function ChainRulesCore.rrule( + ::typeof(contract), alg::Algorithm, x1::MPO, x2::MPO; kwargs... +) + y = contract(alg, x1, x2; kwargs...) function contract_pullback(ȳ) - x̄1 = contract(ȳ, dag(x2); kwargs...) - x̄2 = contract(dag(x1), ȳ; kwargs...) + x̄1 = contract(alg, ȳ, dag(x2); kwargs...) + x̄2 = contract(alg, dag(x1), ȳ; kwargs...) return (NoTangent(), x̄1, x̄2) end return y, contract_pullback end -function ChainRulesCore.rrule(::typeof(contract), x1::MPO, x2::MPS; kwargs...) - y = contract(x1, x2; kwargs...) +function ChainRulesCore.rrule( + ::typeof(contract), alg::Algorithm, x1::MPO, x2::MPS; kwargs... +) + y = contract(alg, x1, x2; kwargs...) function contract_pullback(ȳ) x̄1 = _contract(MPO, ȳ, dag(x2); kwargs...) - x̄2 = contract(dag(x1), ȳ; kwargs...) + x̄2 = contract(alg, dag(x1), ȳ; kwargs...) return (NoTangent(), x̄1, x̄2) end return y, contract_pullback end function ChainRulesCore.rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...) - return ChainRulesCore.rrule(contract, x1, x2; kwargs...) + return ChainRulesCore.rrule(contract, alg, x1, x2; kwargs...) end function ChainRulesCore.rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...) From 9f7e7c4fe68953b5578fb14ff5d10650af873098 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 May 2024 13:01:52 -0400 Subject: [PATCH 11/12] Load Algorithm into namespace --- src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl b/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl index fc6df67eec..c84134c6d9 100644 --- a/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl +++ b/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl @@ -1,5 +1,5 @@ using ChainRulesCore: ChainRulesCore, NoTangent -using ITensors: contract, hassameinds, inner, mapprime +using ITensors: Algorithm, contract, hassameinds, inner, mapprime using ITensors.ITensorMPS: MPO, MPS, firstsiteinds, siteinds using LinearAlgebra: tr From 14b7b4bf38a4f43c6d8018f336844b4996fc5f5a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 14 May 2024 13:49:34 -0400 Subject: [PATCH 12/12] Try fixing tests --- .../ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl b/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl index c84134c6d9..14cdd6c743 100644 --- a/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl +++ b/src/lib/ITensorMPS/ext/ITensorMPSChainRulesCoreExt/mpo.jl @@ -10,7 +10,7 @@ function ChainRulesCore.rrule( function contract_pullback(ȳ) x̄1 = contract(alg, ȳ, dag(x2); kwargs...) x̄2 = contract(alg, dag(x1), ȳ; kwargs...) - return (NoTangent(), x̄1, x̄2) + return (NoTangent(), NoTangent(), x̄1, x̄2) end return y, contract_pullback end @@ -22,14 +22,14 @@ function ChainRulesCore.rrule( function contract_pullback(ȳ) x̄1 = _contract(MPO, ȳ, dag(x2); kwargs...) x̄2 = contract(alg, dag(x1), ȳ; kwargs...) - return (NoTangent(), x̄1, x̄2) + return (NoTangent(), NoTangent(), x̄1, x̄2) end return y, contract_pullback end -function ChainRulesCore.rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...) - return ChainRulesCore.rrule(contract, alg, x1, x2; kwargs...) -end +## function ChainRulesCore.rrule(::typeof(*), x1::MPO, x2::MPO; alg, kwargs...) +## return ChainRulesCore.rrule(contract, alg, x1, x2; kwargs...) +## end function ChainRulesCore.rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...) y = +(x1, x2; kwargs...)