Skip to content

Commit

Permalink
Implement basis interface proposed in qojulia#40
Browse files Browse the repository at this point in the history
  • Loading branch information
akirakyle committed Dec 9, 2024
1 parent 1672e2f commit e586db1
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 41 deletions.
71 changes: 68 additions & 3 deletions src/QuantumInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,88 @@ module QuantumInterface
"""
basis(a)
Return the basis of an object.
Return the basis of a quantum object.
If it's ambiguous, e.g. if an operator has a different left and right basis,
an [`IncompatibleBases`](@ref) error is thrown.
If it's ambiguous, e.g. if an operator has a different
left and right basis, an [`IncompatibleBases`](@ref) error is thrown.
See [`StateVector`](@ref) and [`AbstractOperator`](@ref)
"""
function basis end

"""
basis_l(a)
Return the left basis of an operator.
"""
function basis_l end

"""
basis_r(a)
Return the right basis of an operator.
"""
function basis_r end

"""
Exception that should be raised for an illegal algebraic operation.
"""
mutable struct IncompatibleBases <: Exception end

#function bases end

function spinnumber end

function cutoff end

function offset end

##
# Standard methods
##

"""
multiplicable(a, b)
Check if any two subtypes of `StateVector` or `AbstractOperator`,
can be multiplied in the given order.
"""
function multiplicable end

"""
check_multiplicable(a, b)
Throw an [`IncompatibleBases`](@ref) error if the objects are
not multiplicable as determined by `multiplicable(a, b)`.
If the macro `@compatiblebases` is used anywhere up the call stack,
this check is disabled.
"""
function check_multiplicable end

"""
addible(a, b)
Check if any two subtypes of `StateVector` or `AbstractOperator`
can be added together.
Spcefically this checks whether the left basis of a is equal
to the left basis of b and whether the right basis of a is equal
to the right basis of b.
"""
function addible end

"""
check_addible(a, b)
Throw an [`IncompatibleBases`](@ref) error if the objects are
not addible as determined by `addible(a, b)`.
If the macro `@compatiblebases` is used anywhere up the call stack,
this check is disabled.
"""
function check_addible end

function apply! end

function dagger end
Expand Down
41 changes: 24 additions & 17 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,40 @@
"""
Abstract base class for all specialized bases.
Abstract type for all specialized bases of a Hilbert space.
The Basis class is meant to specify a basis of the Hilbert space of the
studied system. Besides basis specific information all subclasses must
implement a shape variable which indicates the dimension of the used
Hilbert space. For a spin-1/2 Hilbert space this would be the
vector `[2]`. A system composed of two spins would then have a
shape vector `[2 2]`.
The `Basis` type specifies an orthonormal basis for the Hilbert
space of the studied system. All subtypes must implement `Base.:(==)`,
and `Base.size`. `size` should return a tuple representing the total dimension
of the Hilbert space with any tensor product structure the basis has such that
`length(b::Basis) = prod(size(b))` gives the total Hilbert dimension
Composite systems can be defined with help of the [`CompositeBasis`](@ref)
class.
Composite systems can be defined with help of [`CompositeBasis`](@ref).
All relevant properties of subtypes of `Basis` defined in `QuantumInterface`
should be accessed using their documented functions and should not
assume anything about the internal representation of instances of these
types (i.e. don't access the struct's fields directly).
"""
abstract type Basis end

"""
Abstract base class for `Bra` and `Ket` states.
Abstract type for `Bra` and `Ket` states.
The state vector class stores the coefficients of an abstract state
in respect to a certain basis. These coefficients are stored in the
`data` field and the basis is defined in the `basis`
field.
The state vector class stores an abstract state with respect
to a certain basis. All subtypes must implement the `basis`
method which should this basis as a subtype of `Basis`.
"""
abstract type StateVector{B,T} end
abstract type AbstractKet{B,T} <: StateVector{B,T} end
abstract type AbstractBra{B,T} <: StateVector{B,T} end

"""
Abstract base class for all operators.
Abstract type for all operators and super operators.
All deriving operator classes have to define the fields
`basis_l` and `basis_r` defining the left and right side bases.
All subtypes must implement the methods `basis_l` and
`basis_r` which return subtypes of `Basis` and
represent the left and right bases that the operator
maps between and thus is compatible with a `Bra` defined
in the left basis and a `Ket` defined in the right basis.
For fast time evolution also at least the function
`mul!(result::Ket,op::AbstractOperator,x::Ket,alpha,beta)` should be
Expand All @@ -53,3 +58,5 @@ A_{br_1,br_2} = B_{bl_1,bl_2} S_{(bl_1,bl_2) ↔ (br_1,br_2)}
```
"""
abstract type AbstractSuperOperator{B1,B2} end

const AbstractQObjType = Union{<:StateVector,<:AbstractOperator}
16 changes: 13 additions & 3 deletions src/bases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Total dimension of the Hilbert space.
"""
Base.length(b::Basis) = prod(b.shape)
Base.length(b::Basis) = prod(b.shape) # change to prod(size(b)) when downstream Bases are updated

"""
GenericBasis(N)
Expand All @@ -24,7 +24,7 @@ end
GenericBasis(N::Integer) = GenericBasis([N])

Base.:(==)(b1::GenericBasis, b2::GenericBasis) = equal_shape(b1.shape, b2.shape)

Base.size(b::CompositeBasis) = b.shape

"""
CompositeBasis(b1, b2...)
Expand All @@ -42,8 +42,11 @@ end
CompositeBasis(bases) = CompositeBasis([length(b) for b bases], bases)
CompositeBasis(bases::Basis...) = CompositeBasis((bases...,))
CompositeBasis(bases::Vector) = CompositeBasis((bases...,))
#bases(b::CompositeBasis) = b.bases

Base.:(==)(b1::T, b2::T) where T<:CompositeBasis = equal_shape(b1.shape, b2.shape)
Base.size(b::CompositeBasis) = length.(b.bases)
Base.getindex(b::CompositeBasis, i) = getindex(b.bases, i)

##
# Common bases
Expand All @@ -69,6 +72,9 @@ struct FockBasis{T} <: Basis
end

Base.:(==)(b1::FockBasis, b2::FockBasis) = (b1.N==b2.N && b1.offset==b2.offset)
Base.size(b::FockBasis) = (b.N - b.offset + 1,)
cutoff(b::FockBasis) = b.N
offset(b::FockBasis) = b.offset


"""
Expand All @@ -88,6 +94,7 @@ struct NLevelBasis{T} <: Basis
end

Base.:(==)(b1::NLevelBasis, b2::NLevelBasis) = b1.N == b2.N
Base.size(b::NLevelBasis) = (b.N,)

"""
NQubitBasis(num_qubits::Int)
Expand All @@ -106,6 +113,7 @@ struct NQubitBasis{S,B} <: Basis
end

Base.:(==)(pb1::NQubitBasis, pb2::NQubitBasis) = length(pb1.bases) == length(pb2.bases)
Base.size(b::SpinBasis) = b.shape

"""
SpinBasis(n)
Expand All @@ -132,7 +140,8 @@ SpinBasis(spinnumber::Rational) = SpinBasis{spinnumber}(spinnumber)
SpinBasis(spinnumber) = SpinBasis(convert(Rational{Int}, spinnumber))

Base.:(==)(b1::SpinBasis, b2::SpinBasis) = b1.spinnumber==b2.spinnumber

Base.size(b::SpinBasis) = (numerator(b.spinnumber*2 + 1),)
spinnumber(b::SpinBasis) = b.spinnumber

"""
SumBasis(b1, b2...)
Expand All @@ -151,3 +160,4 @@ SumBasis(bases::Basis...) = SumBasis((bases...,))
Base.:(==)(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape)
Base.:(==)(b1::SumBasis, b2::SumBasis) = false
Base.length(b::SumBasis) = sum(b.shape)
# TODO how should `.bases` be accessed? `getindex` or a `sumbases` method?
32 changes: 17 additions & 15 deletions src/expect_variance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,35 @@
If an `index` is given, it assumes that `op` is defined in the subsystem specified by this number.
"""
function expect(indices, op::AbstractOperator{B1,B2}, state::AbstractOperator{B3,B3}) where {B1,B2,B3<:CompositeBasis}
N = length(state.basis_l.shape)
indices_ = complement(N, indices)
expect(op, ptrace(state, indices_))
end
expect(indices, op::AbstractOperator, state::AbstractOperator) =
expect(op, ptrace(state, complement(nsubsystems(state), indices)))

expect(index::Integer, op::AbstractOperator, state::AbstractOperator) = expect([index], op, state)

expect(index::Integer, op::AbstractOperator{B1,B2}, state::AbstractOperator{B3,B3}) where {B1,B2,B3<:CompositeBasis} = expect([index], op, state)
expect(op::AbstractOperator, states::Vector) = [expect(op, state) for state=states]

expect(indices, op::AbstractOperator, states::Vector) = [expect(indices, op, state) for state=states]

expect(op::AbstractOperator{B1,B2}, state::AbstractOperator{B2,B2}) where {B1,B2} = tr(op*state)
expect(op::AbstractOperator, state::AbstractOperator) =
(check_multiplicable(state, state); check_multiplicable(op,state); tr(op*state))

"""
variance(index, op, state)
If an `index` is given, it assumes that `op` is defined in the subsystem specified by this number
"""
function variance(indices, op::AbstractOperator{B,B}, state::AbstractOperator{BC,BC}) where {B,BC<:CompositeBasis}
N = length(state.basis_l.shape)
indices_ = complement(N, indices)
variance(op, ptrace(state, indices_))
end
variance(indices, op::AbstractOperator, state::AbstractOperator) =
variance(op, ptrace(state, complement(nsubsystems(state), indices)))

variance(index::Integer, op::AbstractOperator, state::AbstractOperator) = variance([index], op, state)

variance(index::Integer, op::AbstractOperator{B,B}, state::AbstractOperator{BC,BC}) where {B,BC<:CompositeBasis} = variance([index], op, state)
variance(op::AbstractOperator, states::Vector) = [variance(op, state) for state=states]

variance(indices, op::AbstractOperator, states::Vector) = [variance(indices, op, state) for state=states]

function variance(op::AbstractOperator{B,B}, state::AbstractOperator{B,B}) where B
expect(op*op, state) - expect(op, state)^2
function variance(op::AbstractOperator, state::AbstractOperator)
check_multiplicable(op,op)
check_multiplicable(state,state)
check_multiplicable(op,state)
@compatiblebases expect(op*op, state) - expect(op, state)^2
end
3 changes: 0 additions & 3 deletions src/julia_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ addnumbererror() = throw(ArgumentError("Can't add or subtract a number and an op
*(a::StateVector, b::Number) = b*a
copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) # FIXME issue #12
length(a::StateVector) = length(a.basis)::Int # FIXME issue #12
basis(a::StateVector) = a.basis # FIXME issue #12
adjoint(a::StateVector) = dagger(a)


Expand All @@ -33,8 +32,6 @@ Base.broadcastable(x::StateVector) = x
##

length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int # FIXME issue #12
basis(a::AbstractOperator) = (check_samebases(a); a.basis_l) # FIXME issue #12
basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) # FIXME issue #12

# Ensure scalar broadcasting
Base.broadcastable(x::AbstractOperator) = Ref(x)
Expand Down
44 changes: 44 additions & 0 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,50 @@

const BASES_CHECK = Ref(true)

"""
@compatiblebases
Macro to skip checks for compatible bases. Useful for `*`, `expect` and similar
functions.
"""
macro compatiblebases(ex)
return quote
BASES_CHECK.x = false
local val = $(esc(ex))
BASES_CHECK.x = true
val
end
end

function check_addible(b1, b2)
if BASES_CHECK[] && !addible(b1, b2)
throw(IncompatibleBases())
end
end

function check_multiplicable(b1, b2)
if BASES_CHECK[] && !multiplicable(b1, b2)
throw(IncompatibleBases())
end
end

addible(a::AbstractQObjType, b::AbstractQObjType) = false
addible(a::AbstractBra, b::AbstractBra) = (basis(a) == basis(b))
addible(a::AbstractKet, b::AbstractKet) = (basis(a) == basis(b))
addible(a::AbstractOperator, b::AbstractOperator) =
(basis_l(a) == basis_l(b)) && (basis_r(a) == basis_r(b))

multiplicable(a::AbstractQObjType, b::AbstractQObjType) = false
multiplicable(a::AbstractBra, b::AbstractKet) = (basis(a) == basis(b))
multiplicable(a::AbstractOperator, b::AbstractKet) = (basis_r(a) == basis(b))
multiplicable(a::AbstractBra, b::AbstractOperator) = (basis(a) == basis_l(b))
multiplicable(a::AbstractOperator, b::AbstractOperator) = (basis_r(a) == basis_l(b))

basis(a::StateVector) = throw(ArgumentError("basis() is not defined for this type of state vector: $(typeof(a))."))
basis_l(a::AbstractOperator) = throw(ArgumentError("basis_l() is not defined for this type of operator: $(typeof(a))."))
basis_r(a::AbstractOperator) = throw(ArgumentError("basis_r() is not defined for this type of operator: $(typeof(a))."))
basis(a::AbstractOperator) = (basis_l(a) == basis_r(a); basis_l(a))

##
# tensor, reduce, ptrace
##
Expand Down

0 comments on commit e586db1

Please sign in to comment.