diff --git a/src/QuantumInterface.jl b/src/QuantumInterface.jl index 34efa04..89d3c9a 100644 --- a/src/QuantumInterface.jl +++ b/src/QuantumInterface.jl @@ -1,14 +1,39 @@ module QuantumInterface -import Base: ==, +, -, *, /, ^, length, one, exp, conj, conj!, transpose, copy -import LinearAlgebra: tr, ishermitian, norm, normalize, normalize! -import Base: show, summary -import SparseArrays: sparse, spzeros, AbstractSparseMatrix # TODO move to an extension +## +# Basis specific +## + +""" + basis(a) + +Return the basis of an object. + +If it's ambiguous, e.g. if an operator has a different left and right basis, +an [`IncompatibleBases`](@ref) error is thrown. +""" +function basis end + +""" +Exception that should be raised for an illegal algebraic operation. +""" +mutable struct IncompatibleBases <: Exception end + + +## +# Standard methods +## function apply! end function dagger end +""" + directsum(x, y, z...) + +Direct sum of the given objects. Alternatively, the unicode +symbol ⊕ (\\oplus) can be used. +""" function directsum end const ⊕ = directsum directsum() = GenericBasis(0) @@ -86,8 +111,9 @@ function squeeze end function wigner end -include("bases.jl") include("abstract_types.jl") +include("bases.jl") +include("show.jl") include("linalg.jl") include("tensor.jl") diff --git a/src/abstract_types.jl b/src/abstract_types.jl index f8667c9..0650290 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -1,3 +1,18 @@ +""" +Abstract base class for all specialized bases. + +The Basis class is meant to specify a basis of the Hilbert space of the +studied system. Besides basis specific information all subclasses must +implement a shape variable which indicates the dimension of the used +Hilbert space. For a spin-1/2 Hilbert space this would be the +vector `[2]`. A system composed of two spins would then have a +shape vector `[2 2]`. + +Composite systems can be defined with help of the [`CompositeBasis`](@ref) +class. +""" +abstract type Basis end + """ Abstract base class for `Bra` and `Ket` states. @@ -38,20 +53,3 @@ A_{br_1,br_2} = B_{bl_1,bl_2} S_{(bl_1,bl_2) ↔ (br_1,br_2)} ``` """ abstract type AbstractSuperOperator{B1,B2} end - -function summary(stream::IO, x::AbstractOperator) - print(stream, "$(typeof(x).name.name)(dim=$(length(x.basis_l))x$(length(x.basis_r)))\n") - if samebases(x) - print(stream, " basis: ") - show(stream, basis(x)) - else - print(stream, " basis left: ") - show(stream, x.basis_l) - print(stream, "\n basis right: ") - show(stream, x.basis_r) - end -end - -show(stream::IO, x::AbstractOperator) = summary(stream, x) - -traceout!(s::StateVector, i) = ptrace(s,i) diff --git a/src/bases.jl b/src/bases.jl index 6e4b077..042f905 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -1,18 +1,3 @@ -""" -Abstract base class for all specialized bases. - -The Basis class is meant to specify a basis of the Hilbert space of the -studied system. Besides basis specific information all subclasses must -implement a shape variable which indicates the dimension of the used -Hilbert space. For a spin-1/2 Hilbert space this would be the -vector `[2]`. A system composed of two spins would then have a -shape vector `[2 2]`. - -Composite systems can be defined with help of the [`CompositeBasis`](@ref) -class. -""" -abstract type Basis end - """ length(b::Basis) @@ -20,17 +5,6 @@ Total dimension of the Hilbert space. """ Base.length(b::Basis) = prod(b.shape) -""" - basis(a) - -Return the basis of an object. - -If it's ambiguous, e.g. if an operator has a different left and right basis, -an [`IncompatibleBases`](@ref) error is thrown. -""" -function basis end - - """ GenericBasis(N) @@ -137,11 +111,6 @@ function equal_bases(a, b) return true end -""" -Exception that should be raised for an illegal algebraic operation. -""" -mutable struct IncompatibleBases <: Exception end - const BASES_CHECK = Ref(true) """ @@ -366,9 +335,9 @@ SumBasis(shape, bases::Vector) = (tmp = (bases...,); SumBasis(shape, tmp)) SumBasis(bases::Vector) = SumBasis((bases...,)) SumBasis(bases::Basis...) = SumBasis((bases...,)) -==(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape) -==(b1::SumBasis, b2::SumBasis) = false -length(b::SumBasis) = sum(b.shape) +Base.:(==)(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape) +Base.:(==)(b1::SumBasis, b2::SumBasis) = false +Base.length(b::SumBasis) = sum(b.shape) """ directsum(b1::Basis, b2::Basis) @@ -393,62 +362,3 @@ function directsum(b1::SumBasis, b2::SumBasis) bases = [b1.bases...;b2.bases...] return SumBasis(shape, (bases...,)) end - -embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops) - -## -# show methods -## - -function show(stream::IO, x::GenericBasis) - if length(x.shape) == 1 - write(stream, "Basis(dim=$(x.shape[1]))") - else - s = replace(string(x.shape), " " => "") - write(stream, "Basis(shape=$s)") - end -end - -function show(stream::IO, x::CompositeBasis) - write(stream, "[") - for i in 1:length(x.bases) - show(stream, x.bases[i]) - if i != length(x.bases) - write(stream, " ⊗ ") - end - end - write(stream, "]") -end - -function show(stream::IO, x::SpinBasis) - d = denominator(x.spinnumber) - n = numerator(x.spinnumber) - if d == 1 - write(stream, "Spin($n)") - else - write(stream, "Spin($n/$d)") - end -end - -function show(stream::IO, x::FockBasis) - if iszero(x.offset) - write(stream, "Fock(cutoff=$(x.N))") - else - write(stream, "Fock(cutoff=$(x.N), offset=$(x.offset))") - end -end - -function show(stream::IO, x::NLevelBasis) - write(stream, "NLevel(N=$(x.N))") -end - -function show(stream::IO, x::SumBasis) - write(stream, "[") - for i in 1:length(x.bases) - show(stream, x.bases[i]) - if i != length(x.bases) - write(stream, " ⊕ ") - end - end - write(stream, "]") -end diff --git a/src/embed_permute.jl b/src/embed_permute.jl index c2cc4ca..ac63bc4 100644 --- a/src/embed_permute.jl +++ b/src/embed_permute.jl @@ -67,8 +67,8 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis, ops_sb = [x[2] for x in idxop_sb] for (idxsb, opsb) in zip(indices_sb, ops_sb) - (opsb.basis_l == basis_l.bases[idxsb]) || throw(IncompatibleBases()) - (opsb.basis_r == basis_r.bases[idxsb]) || throw(IncompatibleBases()) + (opsb.basis_l == basis_l.bases[idxsb]) || throw(IncompatibleBases()) # FIXME issue #12 + (opsb.basis_r == basis_r.bases[idxsb]) || throw(IncompatibleBases()) # FIXME issue #12 end S = length(operators) > 0 ? mapreduce(eltype, promote_type, operators) : Any @@ -83,6 +83,8 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis, return embed_op end +embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops) + permutesystems(a::AbstractOperator, perm) = arithmetic_unary_error("Permutations of subsystems", a) nsubsystems(s::AbstractKet) = nsubsystems(basis(s)) diff --git a/src/identityoperator.jl b/src/identityoperator.jl index 5959882..aa031ff 100644 --- a/src/identityoperator.jl +++ b/src/identityoperator.jl @@ -1,4 +1,4 @@ -one(x::Union{<:Basis,<:AbstractOperator}) = identityoperator(x) +Base.one(x::Union{<:Basis,<:AbstractOperator}) = identityoperator(x) """ identityoperator(a::Basis[, b::Basis]) @@ -22,4 +22,4 @@ identityoperator(::Type{T}, ::Type{Any}, b1::Basis, b2::Basis) where T<:Abstract identityoperator(b1::Basis, b2::Basis) = identityoperator(ComplexF64, b1, b2) """Prepare the identity superoperator over a given space.""" -function identitysuperoperator end \ No newline at end of file +function identitysuperoperator end diff --git a/src/julia_base.jl b/src/julia_base.jl index 9a0532d..126694c 100644 --- a/src/julia_base.jl +++ b/src/julia_base.jl @@ -1,3 +1,5 @@ +import Base: +, -, *, /, ^, length, exp, conj, conj!, adjoint, transpose, copy + # Common error messages arithmetic_unary_error(funcname, x::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this type of operator: $(typeof(x)).\nTry to convert to another operator type first with e.g. dense() or sparse().")) arithmetic_binary_error(funcname, a::AbstractOperator, b::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this combination of types of operators: $(typeof(a)), $(typeof(b)).\nTry to convert to a common operator type first with e.g. dense() or sparse().")) @@ -8,33 +10,33 @@ addnumbererror() = throw(ArgumentError("Can't add or subtract a number and an op # States ## --(a::T) where {T<:StateVector} = T(a.basis, -a.data) +-(a::T) where {T<:StateVector} = T(a.basis, -a.data) # FIXME issue #12 *(a::StateVector, b::Number) = b*a -copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) -length(a::StateVector) = length(a.basis)::Int -basis(a::StateVector) = a.basis +copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) # FIXME issue #12 +length(a::StateVector) = length(a.basis)::Int # FIXME issue #12 +basis(a::StateVector) = a.basis # FIXME issue #12 directsum(x::StateVector...) = reduce(directsum, x) +adjoint(a::StateVector) = dagger(a) + + # Array-like functions -Base.size(x::StateVector) = size(x.data) -@inline Base.axes(x::StateVector) = axes(x.data) +Base.size(x::StateVector) = size(x.data) # FIXME issue #12 +@inline Base.axes(x::StateVector) = axes(x.data) # FIXME issue #12 Base.ndims(x::StateVector) = 1 Base.ndims(::Type{<:StateVector}) = 1 -Base.eltype(x::StateVector) = eltype(x.data) +Base.eltype(x::StateVector) = eltype(x.data) # FIXME issue #12 # Broadcasting Base.broadcastable(x::StateVector) = x -Base.adjoint(a::StateVector) = dagger(a) - - ## # Operators ## -length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int -basis(a::AbstractOperator) = (check_samebases(a); a.basis_l) -basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) +length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int # FIXME issue #12 +basis(a::AbstractOperator) = (check_samebases(a); a.basis_l) # FIXME issue #12 +basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) # FIXME issue #12 # Ensure scalar broadcasting Base.broadcastable(x::AbstractOperator) = Ref(x) @@ -60,14 +62,17 @@ Operator exponential. """ exp(op::AbstractOperator) = throw(ArgumentError("exp() is not defined for this type of operator: $(typeof(op)).\nTry to convert to dense operator first with dense().")) -Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r)) +Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r)) # FIXME issue #12 function Base.size(op::AbstractOperator, i::Int) i < 1 && throw(ErrorException("dimension index is < 1")) i > 2 && return 1 - i==1 ? length(op.basis_l) : length(op.basis_r) + i==1 ? length(op.basis_l) : length(op.basis_r) # FIXME issue #12 end -Base.adjoint(a::AbstractOperator) = dagger(a) +adjoint(a::AbstractOperator) = dagger(a) + +transpose(a::AbstractOperator) = arithmetic_unary_error("Transpose", a) + conj(a::AbstractOperator) = arithmetic_unary_error("Complex conjugate", a) conj!(a::AbstractOperator) = conj(a::AbstractOperator) diff --git a/src/julia_linalg.jl b/src/julia_linalg.jl index d2f4d3d..3087d0a 100644 --- a/src/julia_linalg.jl +++ b/src/julia_linalg.jl @@ -1,3 +1,5 @@ +import LinearAlgebra: tr, ishermitian, norm, normalize, normalize! + """ ishermitian(op::AbstractOperator) @@ -17,7 +19,7 @@ tr(x::AbstractOperator) = arithmetic_unary_error("Trace", x) Norm of the given bra or ket state. """ -norm(x::StateVector) = norm(x.data) +norm(x::StateVector) = norm(x.data) # FIXME issue #12 """ normalize(x::StateVector) @@ -31,7 +33,7 @@ normalize(x::StateVector) = x/norm(x) In-place normalization of the given bra or ket so that `norm(x)` is one. """ -normalize!(x::StateVector) = (normalize!(x.data); x) +normalize!(x::StateVector) = (normalize!(x.data); x) # FIXME issue #12 """ normalize(op) diff --git a/src/linalg.jl b/src/linalg.jl index 8bb47cd..fb0779d 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -1,10 +1,10 @@ -samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool -samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool -check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r) -multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l) +samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool # FIXME issue #12 +samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool # FIXME issue #12 +check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r) # FIXME issue #12 +multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l) # FIXME issue #12 dagger(a::AbstractOperator) = arithmetic_unary_error("Hermitian conjugate", a) -transpose(a::AbstractOperator) = arithmetic_unary_error("Transpose", a) directsum(a::AbstractOperator...) = reduce(directsum, a) ptrace(a::AbstractOperator, index) = arithmetic_unary_error("Partial trace", a) _index_complement(b::CompositeBasis, indices) = complement(length(b.bases), indices) reduced(a, indices) = ptrace(a, _index_complement(basis(a), indices)) +traceout!(s::StateVector, i) = ptrace(s,i) diff --git a/src/show.jl b/src/show.jl new file mode 100644 index 0000000..38607b0 --- /dev/null +++ b/src/show.jl @@ -0,0 +1,69 @@ +import Base: show, summary + +function summary(stream::IO, x::AbstractOperator) + print(stream, "$(typeof(x).name.name)(dim=$(length(x.basis_l))x$(length(x.basis_r)))\n") + if samebases(x) + print(stream, " basis: ") + show(stream, basis(x)) + else + print(stream, " basis left: ") + show(stream, x.basis_l) + print(stream, "\n basis right: ") + show(stream, x.basis_r) + end +end + +show(stream::IO, x::AbstractOperator) = summary(stream, x) + +function show(stream::IO, x::GenericBasis) + if length(x.shape) == 1 + write(stream, "Basis(dim=$(x.shape[1]))") + else + s = replace(string(x.shape), " " => "") + write(stream, "Basis(shape=$s)") + end +end + +function show(stream::IO, x::CompositeBasis) + write(stream, "[") + for i in 1:length(x.bases) + show(stream, x.bases[i]) + if i != length(x.bases) + write(stream, " ⊗ ") + end + end + write(stream, "]") +end + +function show(stream::IO, x::SpinBasis) + d = denominator(x.spinnumber) + n = numerator(x.spinnumber) + if d == 1 + write(stream, "Spin($n)") + else + write(stream, "Spin($n/$d)") + end +end + +function show(stream::IO, x::FockBasis) + if iszero(x.offset) + write(stream, "Fock(cutoff=$(x.N))") + else + write(stream, "Fock(cutoff=$(x.N), offset=$(x.offset))") + end +end + +function show(stream::IO, x::NLevelBasis) + write(stream, "NLevel(N=$(x.N))") +end + +function show(stream::IO, x::SumBasis) + write(stream, "[") + for i in 1:length(x.bases) + show(stream, x.bases[i]) + if i != length(x.bases) + write(stream, " ⊕ ") + end + end + write(stream, "]") +end diff --git a/src/sparse.jl b/src/sparse.jl index 2ba8f5f..d6b301c 100644 --- a/src/sparse.jl +++ b/src/sparse.jl @@ -1,4 +1,4 @@ -# TODO make an extension? +import SparseArrays: sparse, spzeros, AbstractSparseMatrix # TODO move to an extension # dense(a::AbstractOperator) = arithmetic_unary_error("Conversion to dense", a) diff --git a/test/runtests.jl b/test/runtests.jl index 0bccf25..826fe33 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,6 +26,7 @@ end println("Starting tests with $(Threads.nthreads()) threads out of `Sys.CPU_THREADS = $(Sys.CPU_THREADS)`...") @doset "sortedindices" +@doset "bases" #VERSION >= v"1.9" && @doset "doctests" get(ENV,"JET_TEST","")=="true" && @doset "jet" VERSION >= v"1.9" && @doset "aqua" diff --git a/test/test_bases.jl b/test/test_bases.jl new file mode 100644 index 0000000..1d91673 --- /dev/null +++ b/test/test_bases.jl @@ -0,0 +1,55 @@ +using Test +using QuantumInterface: tensor, ⊗, ptrace, reduced, permutesystems, equal_bases, multiplicable +using QuantumInterface: GenericBasis, CompositeBasis, NLevelBasis, FockBasis + +@testset "basis" begin + +shape1 = [5] +shape2 = [2, 3] +shape3 = [6] + +b1 = GenericBasis(shape1) +b2 = GenericBasis(shape2) +b3 = GenericBasis(shape3) + +@test b1.shape == shape1 +@test b2.shape == shape2 +@test b1 != b2 +@test b1 != FockBasis(2) +@test b1 == b1 + +@test tensor(b1) == b1 +comp_b1 = tensor(b1, b2) +comp_uni = b1 ⊗ b2 +comp_b2 = tensor(b1, b1, b2) +@test comp_b1.shape == [prod(shape1), prod(shape2)] +@test comp_uni.shape == [prod(shape1), prod(shape2)] +@test comp_b2.shape == [prod(shape1), prod(shape1), prod(shape2)] + +@test b1^3 == CompositeBasis(b1, b1, b1) +@test (b1⊗b2)^2 == CompositeBasis(b1, b2, b1, b2) +@test_throws ArgumentError b1^(0) + +comp_b1_b2 = tensor(comp_b1, comp_b2) +@test comp_b1_b2.shape == [prod(shape1), prod(shape2), prod(shape1), prod(shape1), prod(shape2)] +@test comp_b1_b2 == CompositeBasis(b1, b2, b1, b1, b2) + +@test_throws ArgumentError tensor() +@test comp_b2.shape == tensor(b1, comp_b1).shape +@test comp_b2 == tensor(b1, comp_b1) + +@test_throws ArgumentError ptrace(comp_b1, [1, 2]) +@test ptrace(comp_b2, [1]) == ptrace(comp_b2, [2]) == comp_b1 == ptrace(comp_b2, 1) +@test ptrace(comp_b2, [1, 2]) == ptrace(comp_b1, [1]) +@test ptrace(comp_b2, [2, 3]) == ptrace(comp_b1, [2]) +@test ptrace(comp_b2, [2, 3]) == reduced(comp_b2, [1]) +@test_throws ArgumentError reduced(comp_b1, []) + +comp1 = tensor(b1, b2, b3) +comp2 = tensor(b2, b1, b3) +@test permutesystems(comp1, [2,1,3]) == comp2 + +@test !equal_bases([b1, b2], [b1, b3]) +@test !multiplicable(comp1, b1 ⊗ b2 ⊗ NLevelBasis(prod(b3.shape))) + +end # testset