From e586db174ce60eddff3c8e23fcaae9f55e425d06 Mon Sep 17 00:00:00 2001 From: Akira Kyle Date: Thu, 5 Dec 2024 14:43:29 -0700 Subject: [PATCH] Implement basis interface proposed in #40 --- src/QuantumInterface.jl | 71 +++++++++++++++++++++++++++++++++++++++-- src/abstract_types.jl | 41 ++++++++++++++---------- src/bases.jl | 16 ++++++++-- src/expect_variance.jl | 32 ++++++++++--------- src/julia_base.jl | 3 -- src/linalg.jl | 44 +++++++++++++++++++++++++ 6 files changed, 166 insertions(+), 41 deletions(-) diff --git a/src/QuantumInterface.jl b/src/QuantumInterface.jl index c1bc651..80c7b1a 100644 --- a/src/QuantumInterface.jl +++ b/src/QuantumInterface.jl @@ -7,23 +7,88 @@ module QuantumInterface """ basis(a) -Return the basis of an object. +Return the basis of a quantum object. -If it's ambiguous, e.g. if an operator has a different left and right basis, -an [`IncompatibleBases`](@ref) error is thrown. +If it's ambiguous, e.g. if an operator has a different +left and right basis, an [`IncompatibleBases`](@ref) error is thrown. + +See [`StateVector`](@ref) and [`AbstractOperator`](@ref) """ function basis end +""" + basis_l(a) + +Return the left basis of an operator. +""" +function basis_l end + +""" + basis_r(a) + +Return the right basis of an operator. +""" +function basis_r end + """ Exception that should be raised for an illegal algebraic operation. """ mutable struct IncompatibleBases <: Exception end +#function bases end + +function spinnumber end + +function cutoff end + +function offset end ## # Standard methods ## +""" + multiplicable(a, b) + +Check if any two subtypes of `StateVector` or `AbstractOperator`, +can be multiplied in the given order. +""" +function multiplicable end + +""" + check_multiplicable(a, b) + +Throw an [`IncompatibleBases`](@ref) error if the objects are +not multiplicable as determined by `multiplicable(a, b)`. + +If the macro `@compatiblebases` is used anywhere up the call stack, +this check is disabled. +""" +function check_multiplicable end + +""" + addible(a, b) + +Check if any two subtypes of `StateVector` or `AbstractOperator` + can be added together. + +Spcefically this checks whether the left basis of a is equal +to the left basis of b and whether the right basis of a is equal +to the right basis of b. +""" +function addible end + +""" + check_addible(a, b) + +Throw an [`IncompatibleBases`](@ref) error if the objects are +not addible as determined by `addible(a, b)`. + +If the macro `@compatiblebases` is used anywhere up the call stack, +this check is disabled. +""" +function check_addible end + function apply! end function dagger end diff --git a/src/abstract_types.jl b/src/abstract_types.jl index 0650290..7e63a53 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -1,35 +1,40 @@ """ -Abstract base class for all specialized bases. +Abstract type for all specialized bases of a Hilbert space. -The Basis class is meant to specify a basis of the Hilbert space of the -studied system. Besides basis specific information all subclasses must -implement a shape variable which indicates the dimension of the used -Hilbert space. For a spin-1/2 Hilbert space this would be the -vector `[2]`. A system composed of two spins would then have a -shape vector `[2 2]`. +The `Basis` type specifies an orthonormal basis for the Hilbert +space of the studied system. All subtypes must implement `Base.:(==)`, +and `Base.size`. `size` should return a tuple representing the total dimension +of the Hilbert space with any tensor product structure the basis has such that +`length(b::Basis) = prod(size(b))` gives the total Hilbert dimension -Composite systems can be defined with help of the [`CompositeBasis`](@ref) -class. +Composite systems can be defined with help of [`CompositeBasis`](@ref). + +All relevant properties of subtypes of `Basis` defined in `QuantumInterface` +should be accessed using their documented functions and should not +assume anything about the internal representation of instances of these +types (i.e. don't access the struct's fields directly). """ abstract type Basis end """ -Abstract base class for `Bra` and `Ket` states. +Abstract type for `Bra` and `Ket` states. -The state vector class stores the coefficients of an abstract state -in respect to a certain basis. These coefficients are stored in the -`data` field and the basis is defined in the `basis` -field. +The state vector class stores an abstract state with respect +to a certain basis. All subtypes must implement the `basis` +method which should this basis as a subtype of `Basis`. """ abstract type StateVector{B,T} end abstract type AbstractKet{B,T} <: StateVector{B,T} end abstract type AbstractBra{B,T} <: StateVector{B,T} end """ -Abstract base class for all operators. +Abstract type for all operators and super operators. -All deriving operator classes have to define the fields -`basis_l` and `basis_r` defining the left and right side bases. +All subtypes must implement the methods `basis_l` and +`basis_r` which return subtypes of `Basis` and +represent the left and right bases that the operator +maps between and thus is compatible with a `Bra` defined +in the left basis and a `Ket` defined in the right basis. For fast time evolution also at least the function `mul!(result::Ket,op::AbstractOperator,x::Ket,alpha,beta)` should be @@ -53,3 +58,5 @@ A_{br_1,br_2} = B_{bl_1,bl_2} S_{(bl_1,bl_2) ↔ (br_1,br_2)} ``` """ abstract type AbstractSuperOperator{B1,B2} end + +const AbstractQObjType = Union{<:StateVector,<:AbstractOperator} diff --git a/src/bases.jl b/src/bases.jl index e178a66..934acd5 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -7,7 +7,7 @@ Total dimension of the Hilbert space. """ -Base.length(b::Basis) = prod(b.shape) +Base.length(b::Basis) = prod(b.shape) # change to prod(size(b)) when downstream Bases are updated """ GenericBasis(N) @@ -24,7 +24,7 @@ end GenericBasis(N::Integer) = GenericBasis([N]) Base.:(==)(b1::GenericBasis, b2::GenericBasis) = equal_shape(b1.shape, b2.shape) - +Base.size(b::CompositeBasis) = b.shape """ CompositeBasis(b1, b2...) @@ -42,8 +42,11 @@ end CompositeBasis(bases) = CompositeBasis([length(b) for b ∈ bases], bases) CompositeBasis(bases::Basis...) = CompositeBasis((bases...,)) CompositeBasis(bases::Vector) = CompositeBasis((bases...,)) +#bases(b::CompositeBasis) = b.bases Base.:(==)(b1::T, b2::T) where T<:CompositeBasis = equal_shape(b1.shape, b2.shape) +Base.size(b::CompositeBasis) = length.(b.bases) +Base.getindex(b::CompositeBasis, i) = getindex(b.bases, i) ## # Common bases @@ -69,6 +72,9 @@ struct FockBasis{T} <: Basis end Base.:(==)(b1::FockBasis, b2::FockBasis) = (b1.N==b2.N && b1.offset==b2.offset) +Base.size(b::FockBasis) = (b.N - b.offset + 1,) +cutoff(b::FockBasis) = b.N +offset(b::FockBasis) = b.offset """ @@ -88,6 +94,7 @@ struct NLevelBasis{T} <: Basis end Base.:(==)(b1::NLevelBasis, b2::NLevelBasis) = b1.N == b2.N +Base.size(b::NLevelBasis) = (b.N,) """ NQubitBasis(num_qubits::Int) @@ -106,6 +113,7 @@ struct NQubitBasis{S,B} <: Basis end Base.:(==)(pb1::NQubitBasis, pb2::NQubitBasis) = length(pb1.bases) == length(pb2.bases) +Base.size(b::SpinBasis) = b.shape """ SpinBasis(n) @@ -132,7 +140,8 @@ SpinBasis(spinnumber::Rational) = SpinBasis{spinnumber}(spinnumber) SpinBasis(spinnumber) = SpinBasis(convert(Rational{Int}, spinnumber)) Base.:(==)(b1::SpinBasis, b2::SpinBasis) = b1.spinnumber==b2.spinnumber - +Base.size(b::SpinBasis) = (numerator(b.spinnumber*2 + 1),) +spinnumber(b::SpinBasis) = b.spinnumber """ SumBasis(b1, b2...) @@ -151,3 +160,4 @@ SumBasis(bases::Basis...) = SumBasis((bases...,)) Base.:(==)(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape) Base.:(==)(b1::SumBasis, b2::SumBasis) = false Base.length(b::SumBasis) = sum(b.shape) +# TODO how should `.bases` be accessed? `getindex` or a `sumbases` method? diff --git a/src/expect_variance.jl b/src/expect_variance.jl index 86f5e64..678fb5f 100644 --- a/src/expect_variance.jl +++ b/src/expect_variance.jl @@ -3,33 +3,35 @@ If an `index` is given, it assumes that `op` is defined in the subsystem specified by this number. """ -function expect(indices, op::AbstractOperator{B1,B2}, state::AbstractOperator{B3,B3}) where {B1,B2,B3<:CompositeBasis} - N = length(state.basis_l.shape) - indices_ = complement(N, indices) - expect(op, ptrace(state, indices_)) -end +expect(indices, op::AbstractOperator, state::AbstractOperator) = + expect(op, ptrace(state, complement(nsubsystems(state), indices))) + +expect(index::Integer, op::AbstractOperator, state::AbstractOperator) = expect([index], op, state) -expect(index::Integer, op::AbstractOperator{B1,B2}, state::AbstractOperator{B3,B3}) where {B1,B2,B3<:CompositeBasis} = expect([index], op, state) expect(op::AbstractOperator, states::Vector) = [expect(op, state) for state=states] + expect(indices, op::AbstractOperator, states::Vector) = [expect(indices, op, state) for state=states] -expect(op::AbstractOperator{B1,B2}, state::AbstractOperator{B2,B2}) where {B1,B2} = tr(op*state) +expect(op::AbstractOperator, state::AbstractOperator) = + (check_multiplicable(state, state); check_multiplicable(op,state); tr(op*state)) """ variance(index, op, state) If an `index` is given, it assumes that `op` is defined in the subsystem specified by this number """ -function variance(indices, op::AbstractOperator{B,B}, state::AbstractOperator{BC,BC}) where {B,BC<:CompositeBasis} - N = length(state.basis_l.shape) - indices_ = complement(N, indices) - variance(op, ptrace(state, indices_)) -end +variance(indices, op::AbstractOperator, state::AbstractOperator) = + variance(op, ptrace(state, complement(nsubsystems(state), indices))) + +variance(index::Integer, op::AbstractOperator, state::AbstractOperator) = variance([index], op, state) -variance(index::Integer, op::AbstractOperator{B,B}, state::AbstractOperator{BC,BC}) where {B,BC<:CompositeBasis} = variance([index], op, state) variance(op::AbstractOperator, states::Vector) = [variance(op, state) for state=states] + variance(indices, op::AbstractOperator, states::Vector) = [variance(indices, op, state) for state=states] -function variance(op::AbstractOperator{B,B}, state::AbstractOperator{B,B}) where B - expect(op*op, state) - expect(op, state)^2 +function variance(op::AbstractOperator, state::AbstractOperator) + check_multiplicable(op,op) + check_multiplicable(state,state) + check_multiplicable(op,state) + @compatiblebases expect(op*op, state) - expect(op, state)^2 end diff --git a/src/julia_base.jl b/src/julia_base.jl index 2d8e085..2eb11fa 100644 --- a/src/julia_base.jl +++ b/src/julia_base.jl @@ -14,7 +14,6 @@ addnumbererror() = throw(ArgumentError("Can't add or subtract a number and an op *(a::StateVector, b::Number) = b*a copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) # FIXME issue #12 length(a::StateVector) = length(a.basis)::Int # FIXME issue #12 -basis(a::StateVector) = a.basis # FIXME issue #12 adjoint(a::StateVector) = dagger(a) @@ -33,8 +32,6 @@ Base.broadcastable(x::StateVector) = x ## length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int # FIXME issue #12 -basis(a::AbstractOperator) = (check_samebases(a); a.basis_l) # FIXME issue #12 -basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) # FIXME issue #12 # Ensure scalar broadcasting Base.broadcastable(x::AbstractOperator) = Ref(x) diff --git a/src/linalg.jl b/src/linalg.jl index b48f458..0fd8a26 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -4,6 +4,50 @@ const BASES_CHECK = Ref(true) +""" + @compatiblebases + +Macro to skip checks for compatible bases. Useful for `*`, `expect` and similar +functions. +""" +macro compatiblebases(ex) + return quote + BASES_CHECK.x = false + local val = $(esc(ex)) + BASES_CHECK.x = true + val + end +end + +function check_addible(b1, b2) + if BASES_CHECK[] && !addible(b1, b2) + throw(IncompatibleBases()) + end +end + +function check_multiplicable(b1, b2) + if BASES_CHECK[] && !multiplicable(b1, b2) + throw(IncompatibleBases()) + end +end + +addible(a::AbstractQObjType, b::AbstractQObjType) = false +addible(a::AbstractBra, b::AbstractBra) = (basis(a) == basis(b)) +addible(a::AbstractKet, b::AbstractKet) = (basis(a) == basis(b)) +addible(a::AbstractOperator, b::AbstractOperator) = + (basis_l(a) == basis_l(b)) && (basis_r(a) == basis_r(b)) + +multiplicable(a::AbstractQObjType, b::AbstractQObjType) = false +multiplicable(a::AbstractBra, b::AbstractKet) = (basis(a) == basis(b)) +multiplicable(a::AbstractOperator, b::AbstractKet) = (basis_r(a) == basis(b)) +multiplicable(a::AbstractBra, b::AbstractOperator) = (basis(a) == basis_l(b)) +multiplicable(a::AbstractOperator, b::AbstractOperator) = (basis_r(a) == basis_l(b)) + +basis(a::StateVector) = throw(ArgumentError("basis() is not defined for this type of state vector: $(typeof(a)).")) +basis_l(a::AbstractOperator) = throw(ArgumentError("basis_l() is not defined for this type of operator: $(typeof(a)).")) +basis_r(a::AbstractOperator) = throw(ArgumentError("basis_r() is not defined for this type of operator: $(typeof(a)).")) +basis(a::AbstractOperator) = (basis_l(a) == basis_r(a); basis_l(a)) + ## # tensor, reduce, ptrace ##