diff --git a/src/ITensors.jl b/src/ITensors.jl index dbc5a868f4..2f0c4a2f6d 100644 --- a/src/ITensors.jl +++ b/src/ITensors.jl @@ -52,22 +52,24 @@ module ITensors include("usings.jl") include("utils.jl") include("lib/ContractionSequenceOptimization/src/ContractionSequenceOptimization.jl") -# TODO: `using .ContractionSequenceOptimization: ContractionSequenceOptimization`. +# TODO: `using .ContractionSequenceOptimization: ContractionSequenceOptimization, ...`. using .ContractionSequenceOptimization include("lib/LazyApply/src/LazyApply.jl") -# TODO: `using .LazyApply: LazyApply`. +# TODO: `using .LazyApply: LazyApply, ...`. using .LazyApply include("lib/Ops/src/Ops.jl") -# TODO: `using .Ops: Ops`. +# TODO: `using .Ops: Ops, ...`. using .Ops +using .Ops: Ops, Op, Trotter import .Ops: sites, name +export Trotter include("exports.jl") include("imports.jl") include("global_variables.jl") # TODO: Move to `lib/LastVals/src/LastVals.jl`. include("lastval.jl") include("lib/SmallStrings/src/SmallStrings.jl") -using .SmallStrings: SmallStrings, IntChar, Tag, isint, isnull +using .SmallStrings: SmallStrings, IntChar, SmallString, Tag, isint, isnull include("readwrite.jl") # TODO: Move to `lib/Nots/src/Nots.jl`. include("not.jl") @@ -79,13 +81,16 @@ include("index.jl") include("set_operations.jl") include("indexset.jl") include("itensor.jl") +# TODO: Move to `lib/Names/src/Names.jl`. +include("name.jl") +export val +# TODO: Move to `lib/Vals/src/Vals.jl`. include("val.jl") export val -# include("lib/QuantumNumbers/src/QuantumNumbers.jl") -# using .QuantumNumbers: QN -# export QN +include("lib/QuantumNumbers/src/QuantumNumbers.jl") +using .QuantumNumbers: QN +export QN include("qn/flux.jl") -include("qn/qn.jl") include("oneitensor.jl") include("tensor_operations/tensor_algebra.jl") include("tensor_operations/matrix_algebra.jl") @@ -124,8 +129,7 @@ export OpName, ops, state, val -# TODO: Move to `lib/ITensorsSiteTypesExt/src/ITensorsSiteTypesExt.jl`. -include("ITensorsSiteTypesExt.jl") +include("lib/ITensorsSiteTypesExt/src/ITensorsSiteTypesExt.jl") include("broadcast.jl") include("tensor_operations/matrix_decomposition.jl") include("adapt.jl") @@ -135,24 +139,18 @@ include("qn/qnindex.jl") include("qn/qnindexset.jl") include("qn/qnitensor.jl") include("nullspace.jl") - -# TODO: Move to `lib/ITensorsOpsExt/src/ITensorsOpsExt.jl`? -include("lib/Ops/ops_itensor.jl") +include("lib/ITensorsOpsExt/src/ITensorsOpsExt.jl") include("fermions/fermions.jl") - include("lib/ITensorMPS/src/ITensorMPS.jl") -# TODO: `using .ITensorMPS: ITensorMPS, ...`. +# TODO: `using .ITensorMPS: ITensorMPS, ...` and +# explicit export list. @reexport using .ITensorMPS include("lib/ITensorsNamedDimsArraysExt/src/ITensorsNamedDimsArraysExt.jl") using .ITensorsNamedDimsArraysExt: ITensorsNamedDimsArraysExt - -# TODO: Move into `Ops`. -include("lib/Ops/trotter.jl") - include("lib/ITensorChainRules/src/ITensorChainRules.jl") include("lib/ITensorNetworkMaps/src/ITensorNetworkMaps.jl") include("lib/ITensorVisualizationCore/src/ITensorVisualizationCore.jl") -# TODO: `using .ITensorVisualizationCore: ITensorVisualizationCore`. +# TODO: `using .ITensorVisualizationCore: ITensorVisualizationCore, ...`. using .ITensorVisualizationCore using .ITensorVisualizationCore: @visualize, diff --git a/src/exports.jl b/src/exports.jl index 0cb076cc13..454295dc49 100644 --- a/src/exports.jl +++ b/src/exports.jl @@ -211,13 +211,6 @@ export Prod, coefficient, - # Ops/Ops.jl - Ops, - Op, - - # Ops/trotter.jl - Trotter, - # physics/autompo.jl AutoMPO, OpSum, diff --git a/src/fermions/fermions.jl b/src/fermions/fermions.jl index a6e7a2364c..301495c51f 100644 --- a/src/fermions/fermions.jl +++ b/src/fermions/fermions.jl @@ -1,3 +1,4 @@ +using .QuantumNumbers: QN, QNVal """ parity_sign(P) diff --git a/src/lib/ITensorMPS/src/opsum_to_mpo/opsum_to_mpo_generic.jl b/src/lib/ITensorMPS/src/opsum_to_mpo/opsum_to_mpo_generic.jl index 7a855ea213..aea3a3e47c 100644 --- a/src/lib/ITensorMPS/src/opsum_to_mpo/opsum_to_mpo_generic.jl +++ b/src/lib/ITensorMPS/src/opsum_to_mpo/opsum_to_mpo_generic.jl @@ -1,4 +1,5 @@ using NDTensors: using_auto_fermion +using ..Ops: Op # TODO: Deprecate. const AutoMPO = OpSum diff --git a/src/lib/ITensorsOpsExt/src/ITensorsOpsExt.jl b/src/lib/ITensorsOpsExt/src/ITensorsOpsExt.jl new file mode 100644 index 0000000000..48bef834ae --- /dev/null +++ b/src/lib/ITensorsOpsExt/src/ITensorsOpsExt.jl @@ -0,0 +1,79 @@ +module ITensorsOpsExt +using ITensors: ITensors, ITensor, Index, apply, dag, hascommoninds, swapprime +using ..LazyApply: LazyApply, Applied, Exp, Prod, Scaled, Sum, argument, coefficient +using LinearAlgebra: UniformScaling, I +using ..Ops: Op +using ..SiteTypes: SiteTypes, op + +function SiteTypes.op(I::UniformScaling, s::Index...) + return I.λ * op("Id", s...) +end + +function ITensors.ITensor(o::Op, s::Vector{<:Index}) + return op(o.which_op, map(n -> s[n], o.sites)...; o.params...) +end + +function ITensors.ITensor(o::Scaled, s::Vector{<:Index}) + c = coefficient(o) + if isreal(c) + c = real(c) + end + return c * ITensor(argument(o), s) +end + +function ITensors.ITensor(o::Prod, s::Vector{<:Index}) + T = ITensor(true) + for a in o.args[1] + Tₙ = ITensor(a, s) + # TODO: Implement this logic inside `apply` + if hascommoninds(T, Tₙ) + T = T(Tₙ) + else + T *= Tₙ + end + end + return T +end + +function ITensors.ITensor(o::Sum, s::Vector{<:Index}) + T = ITensor() + for a in o.args[1] + T += ITensor(a, s) + end + return T +end + +function ITensors.ITensor(o::Exp, s::Vector{<:Index}) + return exp(ITensor(argument(o), s)) +end + +function ITensors.ITensor(o::LazyApply.Adjoint, s::Vector{<:Index}) + return swapprime(dag(ITensor(o', s)), 0 => 1) +end + +function LazyApply.Sum{ITensor}(o::Sum, s::Vector{<:Index}) + return Applied(sum, (map(oₙ -> ITensor(oₙ, s), only(o.args)),)) +end + +function LazyApply.Prod{ITensor}(o::Prod, s::Vector{<:Index}) + return Applied(prod, (map(oₙ -> ITensor(oₙ, s), only(o.args)),)) +end + +function LazyApply.Prod{ITensor}(o::Scaled{C,Prod{Op}}, s::Vector{<:Index}) where {C} + t = Prod{ITensor}(argument(o), s) + t1 = coefficient(o) * only(t.args)[1] + return Applied(prod, (vcat([t1], only(t.args)[2:end]),)) +end + +function ITensors.apply(o::Prod{ITensor}, v::ITensor; kwargs...) + ov = v + for oₙ in reverse(only(o.args)) + ov = apply(oₙ, ov; kwargs...) + end + return ov +end + +function (o::Prod{ITensor})(v::ITensor; kwargs...) + return apply(o, v; kwargs...) +end +end diff --git a/src/ITensorsSiteTypesExt.jl b/src/lib/ITensorsSiteTypesExt/src/ITensorsSiteTypesExt.jl similarity index 100% rename from src/ITensorsSiteTypesExt.jl rename to src/lib/ITensorsSiteTypesExt/src/ITensorsSiteTypesExt.jl diff --git a/src/lib/Ops/src/Ops.jl b/src/lib/Ops/src/Ops.jl index 64e9c360ed..fcbff091ea 100644 --- a/src/lib/Ops/src/Ops.jl +++ b/src/lib/Ops/src/Ops.jl @@ -1,339 +1,4 @@ module Ops - -using ..LazyApply - -import Base: ==, +, -, *, /, convert, exp, show, adjoint, isless, hash - -export Op, OpSum, which_op, site, sites, params, Applied, expand - -##################################################################################### -# General functionality -# - -# Helper function to split a `Tuple` according to the function `f`. -# For example: -# -# julia> t = (1, "X", 1, 2, "Y", 2, "Z", 4) -# (1, "X", 1, 2, "Y", 2, "Z", 4) -# -# julia> split(x -> x isa AbstractString, t) -# [(1,), ("X", 1, 2), ("Y", 2), ("Z", 4)] -# -function split(f, t::Tuple) - n = findall(f, t) - nsplit = length(n) + 1 - s = Vector{Any}(undef, nsplit) - s[1] = t[1:(first(n) - 1)] - for i in 2:(nsplit - 1) - s[i] = t[n[i - 1]:(n[i] - 1)] - end - s[end] = t[last(n):end] - return s -end - -## XXX: Very long compile times: -## https://github.com/JuliaLang/julia/issues/45545 -## -## julia> using ITensors -## -## julia> @time ITensors.Ops.split(x -> x isa String, ("X", 1)) -## 7.588123 seconds (2.34 M allocations: 100.919 MiB, 1.71% gc time, 100.00% compilation time) -## ((), ("X", 1)) -## -## julia> @time ITensors.Ops.split(x -> x isa String, ("X", 1)) -## 0.042590 seconds (88.59 k allocations: 4.823 MiB, 19.13% gc time, 99.84% compilation time) -## ((), ("X", 1)) -## -## function split(f, t::Tuple) -## n = findall(f, t) -## ti = t[1:(first(n) - 1)] -## ts = ntuple(i -> t[n[i]:(n[i + 1] - 1)], length(n) - 1) -## tf = t[last(n):end] -## return ti, ts..., tf -## end - -struct Op - which_op - sites::Tuple - params::NamedTuple - function Op(which_op, site...; kwargs...) - return new(which_op, site, NamedTuple(kwargs)) - end -end - -which_op(o::Op) = o.which_op -name(o::Op) = which_op(o) -sites(o::Op) = o.sites -site(o::Op) = only(sites(o)) -params(o::Op) = o.params - -function (o1::Op == o2::Op) - return o1.which_op == o2.which_op && o1.sites == o2.sites && o1.params == o2.params -end - -function hash(o::Op, h::UInt) - return hash(which_op(o), hash(sites(o), hash(params(o), hash(:Op, h)))) -end - -# Version of `isless` defined for matrices -_isless(a, b) = isless(a, b) -_isless(a::AbstractMatrix, b::AbstractMatrix) = isless(hash(a), hash(b)) -_isless(a::AbstractString, b::AbstractMatrix) = true -_isless(a::AbstractMatrix, b::AbstractString) = !_isless(b, a) - -function isless(o1::Op, o2::Op) - if sites(o1) ≠ sites(o2) - return sites(o1) < sites(o2) - end - if which_op(o1) ≠ which_op(o2) - return _isless(which_op(o1), which_op(o2)) - end - return params(o1) < params(o2) -end - -function isless(o1::Prod{Op}, o2::Prod{Op}) - if length(o1) ≠ length(o2) - return length(o1) < length(o2) - end - for n in 1:length(o1) - if o1[n] ≠ o2[n] - return (o1[n] < o2[n]) - end - end - return false -end - -function isless(o1::Scaled{C1,Prod{Op}}, o2::Scaled{C2,Prod{Op}}) where {C1,C2} - if argument(o1) == argument(o2) - if coefficient(o1) ≈ coefficient(o2) - return false - else - c1 = coefficient(o1) - c2 = coefficient(o2) - #"lexicographic" ordering on complex numbers - return real(c1) < real(c2) || (real(c1) ≈ real(c2) && imag(c1) < imag(c2)) - end - end - return argument(o1) < argument(o2) -end - -## function Op(t::Tuple) -## which_op = first(t) -## site_params = Base.tail(t) -## if last(site_params) isa NamedTuple -## site = Base.front(site_params) -## params = last(site_params) -## else -## site = site_params -## params = (;) -## end -## return Op(which_op, site; params...) -## end - -## function Op(t::Tuple{WhichOp,NamedTuple,Vararg}) where {WhichOp} -## params = t[2] -## which_op = t[1] -## sites = t[3:end] -## return Op(which_op, sites...; params...) -## end - -function sites(a::Union{Sum,Prod}) - s = [] - for n in 1:length(a) - s = s ∪ sites(a[n]) - end - return map(identity, s) -end -sites(a::Scaled{C,<:Sum}) where {C} = sites(argument(a)) -sites(a::Scaled{C,<:Prod}) where {C} = sites(argument(a)) - -params(a::Scaled{C,<:Prod}) where {C} = params(only(argument(a))) - -which_op(a::Scaled{C,Op}) where {C} = which_op(argument(a)) -sites(a::Scaled{C,Op}) where {C} = sites(argument(a)) -params(a::Scaled{C,Op}) where {C} = params(argument(a)) - -# -# Op algebra -# - -function convert(::Type{Scaled{C1,Prod{Op}}}, o::Scaled{C2,Prod{Op}}) where {C1,C2} - c = convert(C1, coefficient(o)) - return c * argument(o) -end - -""" -An `OpSum` represents a sum of operator -terms. - -Often it is used to create matrix -product operator (`MPO`) approximation -of the sum of the terms in the `OpSum` oject. -Each term is a product of local operators -specified by names such as `"Sz"` or `"N"`, -times an optional coefficient which -can be real or complex. - -Which local operator names are available -is determined by the function `op` -associated with the `TagType` defined by -special Index tags, such as `"S=1/2"`, `"S=1"`, -`"Fermion"`, and `"Electron"`. -""" -const OpSum{C} = Sum{Scaled{C,Prod{Op}}} - -# This helps with in-place operations -OpSum() = OpSum{ComplexF64}() - -(o1::Op + o2::Op) = Applied(sum, ([o1, o2],)) -(o1::Op * o2::Op) = Applied(prod, ([o1, o2],)) --(o::Op) = -one(Int) * o -(o1::Op - o2::Op) = o1 + (-o2) - -(c::Number * o::Op) = Applied(*, (c, o)) -(o::Op * c::Number) = Applied(*, (c, o)) -(o::Op / c::Number) = Applied(*, (inv(c), o)) - -(c::Number * o::Prod{Op}) = Applied(*, (c, o)) -(o::Prod{Op} * c::Number) = Applied(*, (c, o)) -(o::Prod{Op} / c::Number) = Applied(*, (inv(c), o)) - -# 1.3 * Op("X", 1) + Op("X", 2) -# 1.3 * Op("X", 1) * Op("X", 2) + Op("X", 3) -(co1::Scaled{C} + o2::Op) where {C} = co1 + one(C) * o2 - -# Op("X", 1) + 1.3 * Op("X", 2) -(o1::Op + co2::Scaled{C}) where {C} = one(C) * o1 + co2 - -(o1::Op * o2::Sum) = Applied(sum, (map(a -> o1 * a, only(o2.args)),)) -(o1::Sum * o2::Op) = Applied(sum, (map(a -> a * o2, only(o1.args)),)) - -# 1.3 * Op("X", 1) + Op("X", 2) * Op("X", 3) -# 1.3 * Op("X", 1) * Op("X", 2) + Op("X", 3) * Op("X", 4) -(co1::Scaled{C} + o2::Prod{Op}) where {C} = co1 + one(C) * o2 - -# 1.3 * Op("X", 1) * Op("X", 2) -(co1::Scaled{C} * o2::Op) where {C} = co1 * (one(C) * o2) - -exp(o::Op) = Applied(exp, (o,)) - -adjoint(o::Op) = Applied(adjoint, (o,)) -adjoint(o::LazyApply.Adjoint{Op}) = only(o.args) - -(o1::Exp{Op} * o2::Op) = Applied(prod, ([o1, o2],)) - -# -# Tuple interface -# - -const OpSumLike{C} = Union{ - Sum{Op}, - Sum{Scaled{C,Op}}, - Sum{Prod{Op}}, - Sum{Scaled{C,Prod{Op}}}, - Prod{Op}, - Scaled{C,Prod{Op}}, -} - -const WhichOp = Union{AbstractString,AbstractMatrix{<:Number}} - -# Make a `Scaled{C,Prod{Op}}` from a `Tuple` input, -# for example: -# -# (1.2, "X", 1, "Y", 2) -> 1.2 * Op("X", 1) * Op("Y", 2) -# -function op_term(a::Tuple{Number,Vararg}) - c = first(a) - return c * op_term(Base.tail(a)) -end - -function op_site(which_op, params::NamedTuple, sites...) - return Op(which_op, sites...; params...) -end - -function op_site(which_op, sites_params...) - if last(sites_params) isa NamedTuple - sites = Base.front(sites_params) - params = last(sites_params) - return Op(which_op, sites...; params...) - end - return Op(which_op, sites_params...) -end - -function op_term(a::Tuple{Vararg}) - a_split = split(x -> x isa WhichOp, a) - @assert isempty(first(a_split)) - popfirst!(a_split) - o = op_site(first(a_split)...) - popfirst!(a_split) - for aₙ in a_split - o *= op_site(aₙ...) - end - return o -end - -function (o1::OpSumLike + o2::Tuple) - return o1 + op_term(o2) -end - -function (o1::Tuple + o2::OpSumLike) - return op_term(o1) + o2 -end - -function (o1::OpSumLike - o2::Tuple) - return o1 - op_term(o2) -end - -function (o1::Tuple - o2::OpSumLike) - return op_term(o1) - o2 -end - -function (o1::OpSumLike * o2::Tuple) - return o1 * op_term(o2) -end - -function (o1::Tuple * o2::OpSumLike) - return op_term(o1) * o2 -end - -function show(io::IO, ::MIME"text/plain", o::Op) - print(io, which_op(o)) - print(io, sites(o)) - if !isempty(params(o)) - print(io, params(o)) - end - return nothing -end -show(io::IO, o::Op) = show(io, MIME("text/plain"), o) - -function show(io::IO, ::MIME"text/plain", o::Prod{Op}) - for n in 1:length(o) - print(io, o[n]) - if n < length(o) - print(io, " ") - end - end - return nothing -end -show(io::IO, o::Prod{Op}) = show(io, MIME("text/plain"), o) - -function show(io::IO, m::MIME"text/plain", o::Scaled{C,O}) where {C,O<:Union{Op,Prod{Op}}} - c = coefficient(o) - if isreal(c) - c = real(c) - end - print(io, c) - print(io, " ") - show(io, m, argument(o)) - return nothing -end -show(io::IO, o::Scaled{C,Prod{Op}}) where {C} = show(io, MIME("text/plain"), o) - -function show(io::IO, ::MIME"text/plain", o::LazyApply.Adjoint{Op}) - print(io, o') - print(io, "'") - return nothing -end -show(io::IO, o::LazyApply.Adjoint{Op}) = show(io, MIME("text/plain"), o) - +include("op.jl") +include("trotter.jl") end diff --git a/src/lib/Ops/src/op.jl b/src/lib/Ops/src/op.jl new file mode 100644 index 0000000000..886bb969b9 --- /dev/null +++ b/src/lib/Ops/src/op.jl @@ -0,0 +1,335 @@ +using ..LazyApply + +import Base: ==, +, -, *, /, convert, exp, show, adjoint, isless, hash + +export Op, OpSum, which_op, site, sites, params, Applied, expand + +##################################################################################### +# General functionality +# + +# Helper function to split a `Tuple` according to the function `f`. +# For example: +# +# julia> t = (1, "X", 1, 2, "Y", 2, "Z", 4) +# (1, "X", 1, 2, "Y", 2, "Z", 4) +# +# julia> split(x -> x isa AbstractString, t) +# [(1,), ("X", 1, 2), ("Y", 2), ("Z", 4)] +# +function split(f, t::Tuple) + n = findall(f, t) + nsplit = length(n) + 1 + s = Vector{Any}(undef, nsplit) + s[1] = t[1:(first(n) - 1)] + for i in 2:(nsplit - 1) + s[i] = t[n[i - 1]:(n[i] - 1)] + end + s[end] = t[last(n):end] + return s +end + +## XXX: Very long compile times: +## https://github.com/JuliaLang/julia/issues/45545 +## +## julia> using ITensors +## +## julia> @time ITensors.Ops.split(x -> x isa String, ("X", 1)) +## 7.588123 seconds (2.34 M allocations: 100.919 MiB, 1.71% gc time, 100.00% compilation time) +## ((), ("X", 1)) +## +## julia> @time ITensors.Ops.split(x -> x isa String, ("X", 1)) +## 0.042590 seconds (88.59 k allocations: 4.823 MiB, 19.13% gc time, 99.84% compilation time) +## ((), ("X", 1)) +## +## function split(f, t::Tuple) +## n = findall(f, t) +## ti = t[1:(first(n) - 1)] +## ts = ntuple(i -> t[n[i]:(n[i + 1] - 1)], length(n) - 1) +## tf = t[last(n):end] +## return ti, ts..., tf +## end + +struct Op + which_op + sites::Tuple + params::NamedTuple + function Op(which_op, site...; kwargs...) + return new(which_op, site, NamedTuple(kwargs)) + end +end + +which_op(o::Op) = o.which_op +name(o::Op) = which_op(o) +sites(o::Op) = o.sites +site(o::Op) = only(sites(o)) +params(o::Op) = o.params + +function (o1::Op == o2::Op) + return o1.which_op == o2.which_op && o1.sites == o2.sites && o1.params == o2.params +end + +function hash(o::Op, h::UInt) + return hash(which_op(o), hash(sites(o), hash(params(o), hash(:Op, h)))) +end + +# Version of `isless` defined for matrices +_isless(a, b) = isless(a, b) +_isless(a::AbstractMatrix, b::AbstractMatrix) = isless(hash(a), hash(b)) +_isless(a::AbstractString, b::AbstractMatrix) = true +_isless(a::AbstractMatrix, b::AbstractString) = !_isless(b, a) + +function isless(o1::Op, o2::Op) + if sites(o1) ≠ sites(o2) + return sites(o1) < sites(o2) + end + if which_op(o1) ≠ which_op(o2) + return _isless(which_op(o1), which_op(o2)) + end + return params(o1) < params(o2) +end + +function isless(o1::Prod{Op}, o2::Prod{Op}) + if length(o1) ≠ length(o2) + return length(o1) < length(o2) + end + for n in 1:length(o1) + if o1[n] ≠ o2[n] + return (o1[n] < o2[n]) + end + end + return false +end + +function isless(o1::Scaled{C1,Prod{Op}}, o2::Scaled{C2,Prod{Op}}) where {C1,C2} + if argument(o1) == argument(o2) + if coefficient(o1) ≈ coefficient(o2) + return false + else + c1 = coefficient(o1) + c2 = coefficient(o2) + #"lexicographic" ordering on complex numbers + return real(c1) < real(c2) || (real(c1) ≈ real(c2) && imag(c1) < imag(c2)) + end + end + return argument(o1) < argument(o2) +end + +## function Op(t::Tuple) +## which_op = first(t) +## site_params = Base.tail(t) +## if last(site_params) isa NamedTuple +## site = Base.front(site_params) +## params = last(site_params) +## else +## site = site_params +## params = (;) +## end +## return Op(which_op, site; params...) +## end + +## function Op(t::Tuple{WhichOp,NamedTuple,Vararg}) where {WhichOp} +## params = t[2] +## which_op = t[1] +## sites = t[3:end] +## return Op(which_op, sites...; params...) +## end + +function sites(a::Union{Sum,Prod}) + s = [] + for n in 1:length(a) + s = s ∪ sites(a[n]) + end + return map(identity, s) +end +sites(a::Scaled{C,<:Sum}) where {C} = sites(argument(a)) +sites(a::Scaled{C,<:Prod}) where {C} = sites(argument(a)) + +params(a::Scaled{C,<:Prod}) where {C} = params(only(argument(a))) + +which_op(a::Scaled{C,Op}) where {C} = which_op(argument(a)) +sites(a::Scaled{C,Op}) where {C} = sites(argument(a)) +params(a::Scaled{C,Op}) where {C} = params(argument(a)) + +# +# Op algebra +# + +function convert(::Type{Scaled{C1,Prod{Op}}}, o::Scaled{C2,Prod{Op}}) where {C1,C2} + c = convert(C1, coefficient(o)) + return c * argument(o) +end + +""" +An `OpSum` represents a sum of operator +terms. + +Often it is used to create matrix +product operator (`MPO`) approximation +of the sum of the terms in the `OpSum` oject. +Each term is a product of local operators +specified by names such as `"Sz"` or `"N"`, +times an optional coefficient which +can be real or complex. + +Which local operator names are available +is determined by the function `op` +associated with the `TagType` defined by +special Index tags, such as `"S=1/2"`, `"S=1"`, +`"Fermion"`, and `"Electron"`. +""" +const OpSum{C} = Sum{Scaled{C,Prod{Op}}} + +# This helps with in-place operations +OpSum() = OpSum{ComplexF64}() + +(o1::Op + o2::Op) = Applied(sum, ([o1, o2],)) +(o1::Op * o2::Op) = Applied(prod, ([o1, o2],)) +-(o::Op) = -one(Int) * o +(o1::Op - o2::Op) = o1 + (-o2) + +(c::Number * o::Op) = Applied(*, (c, o)) +(o::Op * c::Number) = Applied(*, (c, o)) +(o::Op / c::Number) = Applied(*, (inv(c), o)) + +(c::Number * o::Prod{Op}) = Applied(*, (c, o)) +(o::Prod{Op} * c::Number) = Applied(*, (c, o)) +(o::Prod{Op} / c::Number) = Applied(*, (inv(c), o)) + +# 1.3 * Op("X", 1) + Op("X", 2) +# 1.3 * Op("X", 1) * Op("X", 2) + Op("X", 3) +(co1::Scaled{C} + o2::Op) where {C} = co1 + one(C) * o2 + +# Op("X", 1) + 1.3 * Op("X", 2) +(o1::Op + co2::Scaled{C}) where {C} = one(C) * o1 + co2 + +(o1::Op * o2::Sum) = Applied(sum, (map(a -> o1 * a, only(o2.args)),)) +(o1::Sum * o2::Op) = Applied(sum, (map(a -> a * o2, only(o1.args)),)) + +# 1.3 * Op("X", 1) + Op("X", 2) * Op("X", 3) +# 1.3 * Op("X", 1) * Op("X", 2) + Op("X", 3) * Op("X", 4) +(co1::Scaled{C} + o2::Prod{Op}) where {C} = co1 + one(C) * o2 + +# 1.3 * Op("X", 1) * Op("X", 2) +(co1::Scaled{C} * o2::Op) where {C} = co1 * (one(C) * o2) + +exp(o::Op) = Applied(exp, (o,)) + +adjoint(o::Op) = Applied(adjoint, (o,)) +adjoint(o::LazyApply.Adjoint{Op}) = only(o.args) + +(o1::Exp{Op} * o2::Op) = Applied(prod, ([o1, o2],)) + +# +# Tuple interface +# + +const OpSumLike{C} = Union{ + Sum{Op}, + Sum{Scaled{C,Op}}, + Sum{Prod{Op}}, + Sum{Scaled{C,Prod{Op}}}, + Prod{Op}, + Scaled{C,Prod{Op}}, +} + +const WhichOp = Union{AbstractString,AbstractMatrix{<:Number}} + +# Make a `Scaled{C,Prod{Op}}` from a `Tuple` input, +# for example: +# +# (1.2, "X", 1, "Y", 2) -> 1.2 * Op("X", 1) * Op("Y", 2) +# +function op_term(a::Tuple{Number,Vararg}) + c = first(a) + return c * op_term(Base.tail(a)) +end + +function op_site(which_op, params::NamedTuple, sites...) + return Op(which_op, sites...; params...) +end + +function op_site(which_op, sites_params...) + if last(sites_params) isa NamedTuple + sites = Base.front(sites_params) + params = last(sites_params) + return Op(which_op, sites...; params...) + end + return Op(which_op, sites_params...) +end + +function op_term(a::Tuple{Vararg}) + a_split = split(x -> x isa WhichOp, a) + @assert isempty(first(a_split)) + popfirst!(a_split) + o = op_site(first(a_split)...) + popfirst!(a_split) + for aₙ in a_split + o *= op_site(aₙ...) + end + return o +end + +function (o1::OpSumLike + o2::Tuple) + return o1 + op_term(o2) +end + +function (o1::Tuple + o2::OpSumLike) + return op_term(o1) + o2 +end + +function (o1::OpSumLike - o2::Tuple) + return o1 - op_term(o2) +end + +function (o1::Tuple - o2::OpSumLike) + return op_term(o1) - o2 +end + +function (o1::OpSumLike * o2::Tuple) + return o1 * op_term(o2) +end + +function (o1::Tuple * o2::OpSumLike) + return op_term(o1) * o2 +end + +function show(io::IO, ::MIME"text/plain", o::Op) + print(io, which_op(o)) + print(io, sites(o)) + if !isempty(params(o)) + print(io, params(o)) + end + return nothing +end +show(io::IO, o::Op) = show(io, MIME("text/plain"), o) + +function show(io::IO, ::MIME"text/plain", o::Prod{Op}) + for n in 1:length(o) + print(io, o[n]) + if n < length(o) + print(io, " ") + end + end + return nothing +end +show(io::IO, o::Prod{Op}) = show(io, MIME("text/plain"), o) + +function show(io::IO, m::MIME"text/plain", o::Scaled{C,O}) where {C,O<:Union{Op,Prod{Op}}} + c = coefficient(o) + if isreal(c) + c = real(c) + end + print(io, c) + print(io, " ") + show(io, m, argument(o)) + return nothing +end +show(io::IO, o::Scaled{C,Prod{Op}}) where {C} = show(io, MIME("text/plain"), o) + +function show(io::IO, ::MIME"text/plain", o::LazyApply.Adjoint{Op}) + print(io, o') + print(io, "'") + return nothing +end +show(io::IO, o::LazyApply.Adjoint{Op}) = show(io, MIME("text/plain"), o) diff --git a/src/lib/Ops/trotter.jl b/src/lib/Ops/src/trotter.jl similarity index 74% rename from src/lib/Ops/trotter.jl rename to src/lib/Ops/src/trotter.jl index 09b4433072..9c3b1a5641 100644 --- a/src/lib/Ops/trotter.jl +++ b/src/lib/Ops/src/trotter.jl @@ -1,3 +1,5 @@ +using ..LazyApply: Applied, Sum + abstract type ExpAlgorithm end struct Exact <: ExpAlgorithm end @@ -6,13 +8,13 @@ struct Trotter{Order} <: ExpAlgorithm nsteps::Int end Trotter{Order}() where {Order} = Trotter{Order}(1) -one(::Trotter{Order}) where {Order} = Trotter{Order}(1) +Base.one(::Trotter{Order}) where {Order} = Trotter{Order}(1) -function exp(o::Sum; alg::ExpAlgorithm=Exact()) +function Base.exp(o::Sum; alg::ExpAlgorithm=Exact()) return exp(alg, o) end -function exp(::Exact, o::Sum) +function Base.exp(::Exact, o::Sum) return Applied(prod, ([Applied(exp, (o,))],)) end @@ -27,7 +29,7 @@ function exp_one_step(trotter::Trotter{2}, o::Sum) return exp_o end -function exp(trotter::Trotter, o::Sum) +function Base.exp(trotter::Trotter, o::Sum) expδo = exp_one_step(one(trotter), o / trotter.nsteps) return expδo^trotter.nsteps end diff --git a/src/qn/qn.jl b/src/lib/QuantumNumbers/src/QuantumNumbers.jl similarity index 86% rename from src/qn/qn.jl rename to src/lib/QuantumNumbers/src/QuantumNumbers.jl index 55d27a686a..3af98cbaa1 100644 --- a/src/qn/qn.jl +++ b/src/lib/QuantumNumbers/src/QuantumNumbers.jl @@ -1,10 +1,13 @@ -using .SmallStrings: SmallString +module QuantumNumbers +using ..ITensors: ITensors, Arrow, name +using NDTensors: NDTensors +using ..SmallStrings: SmallString +using StaticArrays: MVector, SVector struct QNVal name::SmallString val::Int modulus::Int - function QNVal(name, v::Int, m::Int=1) am = abs(m) if am > 1 @@ -17,11 +20,11 @@ end QNVal(v::Int, m::Int=1) = QNVal("", v, m) QNVal() = QNVal("", 0, 0) -name(qv::QNVal) = qv.name -val(qv::QNVal) = qv.val +ITensors.name(qv::QNVal) = qv.name +ITensors.val(qv::QNVal) = qv.val modulus(qv::QNVal) = qv.modulus isactive(qv::QNVal) = modulus(qv) != 0 -(qv1::QNVal < qv2::QNVal) = (name(qv1) < name(qv2)) +Base.:(<)(qv1::QNVal, qv2::QNVal) = (name(qv1) < name(qv2)) function qn_mod(val::Int, modulus::Int) amod = abs(modulus) @@ -35,11 +38,11 @@ end Base.zero(::Type{QNVal}) = QNVal() -zero(qv::QNVal) = QNVal(name(qv), 0, modulus(qv)) +Base.zero(qv::QNVal) = QNVal(name(qv), 0, modulus(qv)) -(dir::Arrow * qv::QNVal) = QNVal(name(qv), Int(dir) * val(qv), modulus(qv)) +Base.:(*)(dir::Arrow, qv::QNVal) = QNVal(name(qv), Int(dir) * val(qv), modulus(qv)) -(qv::QNVal * dir::Arrow) = (dir * qv) +Base.:(*)(qv::QNVal, dir::Arrow) = (dir * qv) function pm(qv1::QNVal, qv2::QNVal, fac::Int) if name(qv1) != name(qv2) @@ -57,8 +60,8 @@ function pm(qv1::QNVal, qv2::QNVal, fac::Int) return QNVal(name(qv1), Base.mod(val(qv1) + fac * val(qv2), abs(m1)), m1) end -(qv1::QNVal + qv2::QNVal) = pm(qv1, qv2, +1) -(qv1::QNVal - qv2::QNVal) = pm(qv1, qv2, -1) +Base.:(+)(qv1::QNVal, qv2::QNVal) = pm(qv1, qv2, +1) +Base.:(-)(qv1::QNVal, qv2::QNVal) = pm(qv1, qv2, -1) const ZeroVal = QNVal() @@ -81,12 +84,10 @@ collection, its value is treated as zero. """ struct QN data::QNStorage - function QN() s = QNStorage(ntuple(_ -> ZeroVal, Val(maxQNs))) return new(s) end - QN(s::QNStorage) = new(s) end @@ -153,11 +154,11 @@ QN(val::Int, modulus::Int=1) = QN(("", val, modulus)) data(qn::QN) = qn.data -getindex(q::QN, n::Int) = getindex(data(q), n) +Base.getindex(q::QN, n::Int) = getindex(data(q), n) -length(qn::QN) = length(data(qn)) +Base.length(qn::QN) = length(data(qn)) -lastindex(qn::QN) = length(qn) +Base.lastindex(qn::QN) = length(qn) isactive(qn::QN) = isactive(qn[1]) @@ -168,12 +169,12 @@ function nactive(q::QN) return maxQNs end -function iterate(qn::QN, state::Int=1) +function Base.iterate(qn::QN, state::Int=1) (state > length(qn)) && return nothing return (qn[state], state + 1) end -keys(qn::QN) = keys(data(qn)) +Base.keys(qn::QN) = keys(data(qn)) """ val(q::QN,name) @@ -181,7 +182,7 @@ keys(qn::QN) = keys(data(qn)) Get the value within the QN q corresponding to the string `name` """ -function val(q::QN, name_) +function ITensors.val(q::QN, name_) sname = SmallString(name_) for n in 1:maxQNs name(q[n]) == sname && return val(q[n]) @@ -210,7 +211,7 @@ Returns a QN object containing the same names as q, but with all values set to zero. """ -function zero(qn::QN) +function Base.zero(qn::QN) mqn = MQNStorage(undef) for i in 1:length(mqn) mqn[i] = zero(qn[i]) @@ -218,7 +219,7 @@ function zero(qn::QN) return QN(mqn) end -function (dir::Arrow * qn::QN) +function Base.:(*)(dir::Arrow, qn::QN) mqn = MQNStorage(undef) for i in 1:length(mqn) mqn[i] = dir * qn[i] @@ -226,9 +227,9 @@ function (dir::Arrow * qn::QN) return QN(mqn) end -(qn::QN * dir::Arrow) = (dir * qn) +Base.:(*)(qn::QN, dir::Arrow) = (dir * qn) -function -(qn::QN) +function Base.:(-)(qn::QN) mqn = MQNStorage(undef) for i in 1:length(mqn) mqn[i] = -qn[i] @@ -236,9 +237,8 @@ function -(qn::QN) return QN(mqn) end -function (a::QN + b::QN) +function Base.:(+)(a::QN, b::QN) !isactive(b[1]) && return a - ma = MQNStorage(data(a)) @inbounds for nb in 1:maxQNs !isactive(b[nb]) && break @@ -263,7 +263,7 @@ function (a::QN + b::QN) return QN(QNStorage(ma)) end -(a::QN - b::QN) = (a + (-b)) +Base.:(-)(a::QN, b::QN) = (a + (-b)) function hasname(qn::QN, qv_find::QNVal) for qv in qn @@ -319,7 +319,7 @@ function isequal_assume_filled(qn1::QN, qn2::QN) return true end -function ==(qn1::QN, qn2::QN; assume_filled=false) +function Base.:(==)(qn1::QN, qn2::QN; assume_filled=false) if !assume_filled qn1, qn2 = fillqns(qn1, qn2) end @@ -335,11 +335,11 @@ function isless_assume_filled(qn1::QN, qn2::QN) return false end -function isless(qn1::QN, qn2::QN; assume_filled=false) +function Base.isless(qn1::QN, qn2::QN; assume_filled=false) return <(qn1, qn2; assume_filled=assume_filled) end -function <(qn1::QN, qn2::QN; assume_filled=false) +function Base.:(<)(qn1::QN, qn2::QN; assume_filled=false) if !assume_filled qn1, qn2 = fillqns(qn1, qn2) end @@ -362,7 +362,6 @@ end function removeqn(qn::QN, qn_name::String) ss_qn_name = SmallString(qn_name) - # Find the location of the QNVal to remove n_qn = nothing for n in 1:length(qn) @@ -374,7 +373,6 @@ function removeqn(qn::QN, qn_name::String) if isnothing(n_qn) return qn end - qn_data = data(qn) for j in n_qn:(length(qn) - 1) qn_data = setindex(qn_data, qn_data[j + 1], j) @@ -383,7 +381,7 @@ function removeqn(qn::QN, qn_name::String) return QN(qn_data) end -function show(io::IO, q::QN) +function Base.show(io::IO, q::QN) print(io, "QN(") Na = nactive(q) for n in 1:Na @@ -401,3 +399,4 @@ function show(io::IO, q::QN) end return print(io, ")") end +end diff --git a/src/name.jl b/src/name.jl new file mode 100644 index 0000000000..bacc04d3e7 --- /dev/null +++ b/src/name.jl @@ -0,0 +1 @@ +function name end