Skip to content

Commit

Permalink
[AlgorithmSelection] Rename BackendSelection, introduce Backend type (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored May 5, 2024
1 parent fa6c706 commit 91e595d
Show file tree
Hide file tree
Showing 18 changed files with 111 additions and 63 deletions.
2 changes: 1 addition & 1 deletion NDTensors/src/blocksparse/contract.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using .AlgorithmSelection: Algorithm, @Algorithm_str
using .BackendSelection: Algorithm, @Algorithm_str

function contract(
tensor1::BlockSparseTensor,
Expand Down
6 changes: 5 additions & 1 deletion NDTensors/src/imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ using TimerOutputs
using TupleTools

for lib in [
:AlgorithmSelection,
:AllocateData,
:BackendSelection,
:BaseExtensions,
:UnspecifiedTypes,
:TypeParameterAccessors,
Expand Down Expand Up @@ -51,6 +51,10 @@ for lib in [
include("lib/$(lib)/src/$(lib).jl")
@eval using .$lib: $lib
end
# TODO: This is defined for backwards compatibility,
# delete this alias once downstream packages change over
# to using `BackendSelection`.
const AlgorithmSelection = BackendSelection

using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo

Expand Down

This file was deleted.

39 changes: 0 additions & 39 deletions NDTensors/src/lib/AlgorithmSelection/src/algorithm.jl

This file was deleted.

12 changes: 0 additions & 12 deletions NDTensors/src/lib/AlgorithmSelection/test/runtests.jl

This file was deleted.

9 changes: 9 additions & 0 deletions NDTensors/src/lib/BackendSelection/src/BackendSelection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module BackendSelection
include("abstractbackend.jl")
include("backend_types.jl")

# TODO: This is defined for backwards compatibility,
# delete this alias once downstream packages change over
# to using `BackendSelection`.
const AlgorithmSelection = BackendSelection
end
4 changes: 4 additions & 0 deletions NDTensors/src/lib/BackendSelection/src/abstractbackend.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
abstract type AbstractBackend end

backend_string(::AbstractBackend) = error("Not implemented")
parameters(::AbstractBackend) = error("Not implemented")
57 changes: 57 additions & 0 deletions NDTensors/src/lib/BackendSelection/src/backend_types.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
for type in (:Algorithm, :Backend)
@eval begin
"""
$($type)
A type representing a backend for a function.
For example, a function might have multiple backends
implementations, which internally are selected with a `$($type)` type.
This allows users to extend functionality with a new implementation but
use the same interface.
"""
struct $type{Back,Kwargs<:NamedTuple} <: AbstractBackend
kwargs::Kwargs
end

$type{Back}(kwargs::NamedTuple) where {Back} = $type{Back,typeof(kwargs)}(kwargs)
$type{Back}(; kwargs...) where {Back} = $type{Back}(NamedTuple(kwargs))
$type(s; kwargs...) = $type{Symbol(s)}(NamedTuple(kwargs))

$type(backend::$type) = backend

# TODO: Use `SetParameters`.
backend_string(::$type{Back}) where {Back} = string(Back)
parameters(backend::$type) = getfield(backend, :kwargs)

function Base.show(io::IO, backend::$type)
return print(io, "$type type ", backend_string(backend), ", ", parameters(backend))
end
Base.print(io::IO, backend::$type) =
print(io, backend_string(backend), ", ", parameters(backend))
end
end

# TODO: See if these can be moved inside of `@eval`.
"""
@Algorithm_str
A convenience macro for writing [`Algorithm`](@ref) types, typically used when
adding methods to a function that supports multiple algorithm
backends.
"""
macro Algorithm_str(s)
return :(Algorithm{$(Expr(:quote, Symbol(s)))})
end

"""
@Backend_str
A convenience macro for writing [`Backend`](@ref) types, typically used when
adding methods to a function that supports multiple
backends.
"""
macro Backend_str(s)
return :(Backend{$(Expr(:quote, Symbol(s)))})
end
28 changes: 28 additions & 0 deletions NDTensors/src/lib/BackendSelection/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
@eval module $(gensym())
using Test: @test, @testset
using NDTensors: NDTensors
using NDTensors.BackendSelection:
BackendSelection, Algorithm, Backend, @Algorithm_str, @Backend_str
# TODO: This is defined for backwards compatibility,
# delete this alias once downstream packages change over
# to using `BackendSelection`.
using NDTensors.AlgorithmSelection: AlgorithmSelection
@testset "BackendSelection" begin
# TODO: This is defined for backwards compatibility,
# delete this alias once downstream packages change over
# to using `BackendSelection`.
@test AlgorithmSelection === BackendSelection
for type in (Algorithm, Backend)
@testset "$type" begin
@test type("backend") isa type{:backend}
@test type(:backend) isa type{:backend}
backend = type("backend"; x=2, y=3)
@test backend isa type{:backend}
@test BackendSelection.parameters(backend) === (; x=2, y=3)
end
end
# Macro syntax.
@test Algorithm"backend"(; x=2, y=3) === Algorithm("backend"; x=2, y=3)
@test Backend"backend"(; x=2, y=3) === Backend("backend"; x=2, y=3)
end
end
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module BlockSparseArrays
using ..AlgorithmSelection: Algorithm, @Algorithm_str
using ..BackendSelection: Algorithm, @Algorithm_str
using BlockArrays:
AbstractBlockArray,
BlockArrays,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module LinearAlgebraExt
using ...AlgorithmSelection: Algorithm, @Algorithm_str
using ...BackendSelection: Algorithm, @Algorithm_str
using BlockArrays: BlockArrays, blockedrange, blocks
using ..BlockSparseArrays: nonzero_keys # TODO: Move to `SparseArraysExtensions` module, rename `SparseArrayDOK`.
using ..BlockSparseArrays: BlockSparseArrays, BlockSparseArray, nonzero_blockkeys
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ..AlgorithmSelection: Algorithm
using ..BackendSelection: Algorithm
using .BaseExtensions: BaseExtensions

function blockedperms(
Expand Down
2 changes: 1 addition & 1 deletion NDTensors/src/lib/TensorAlgebra/src/contract/contract.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ..AlgorithmSelection: Algorithm, @Algorithm_str
using ..BackendSelection: Algorithm, @Algorithm_str

# TODO: Add `contract!!` definitions as pass-throughs to `contract!`.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ..AlgorithmSelection: @Algorithm_str
using ..BackendSelection: @Algorithm_str
using LinearAlgebra: mul!

function contract!(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ..AlgorithmSelection: Algorithm
using ..BackendSelection: Algorithm

function output_labels(
f::typeof(contract),
Expand Down
2 changes: 1 addition & 1 deletion NDTensors/test/lib/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
using NDTensors: NDTensors
using Test: @testset
@testset "Test NDTensors lib $lib" for lib in [
"AlgorithmSelection",
"AllocateData",
"AMDGPUExtensions",
"BackendSelection",
"BaseExtensions",
"BlockSparseArrays",
"BroadcastMapConversion",
Expand Down

0 comments on commit 91e595d

Please sign in to comment.