From 72c2635c22a549489df80467a464170c8008b2aa Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Fri, 8 Dec 2023 12:33:12 -0500 Subject: [PATCH] [BlockSparseArrays] `broadcast`, `map`, `permutedims`, `block_reshape` (#1282) --- NDTensors/Project.toml | 4 ++ NDTensors/src/NDTensors.jl | 2 +- .../lib/BlockSparseArrays/examples/README.jl | 2 + .../BlockArraysExtensions.jl | 28 ++++++++- .../BlockArraysSparseArrayInterfaceExt.jl | 5 ++ .../src/BlockSparseArrays.jl | 4 ++ .../abstractblocksparsearray.jl | 33 ---------- .../abstractblocksparsearray/arraylayouts.jl | 1 + .../src/abstractblocksparsearray/broadcast.jl | 5 ++ .../src/abstractblocksparsearray/map.jl | 42 +++++++++++++ .../wrappedabstractblocksparsearray.jl | 63 +++++++++++++++++++ .../src/blocksparsearray/blocksparsearray.jl | 8 +++ .../blocksparsearrayinterface.jl | 31 +++++++-- .../blocksparsearrayinterface/blockzero.jl | 14 ++--- .../blocksparsearrayinterface/broadcast.jl | 41 ++++++++++++ .../lib/BlockSparseArrays/test/runtests.jl | 54 +++++++++++++++- .../lib/SparseArrayDOKs/src/sparsearraydok.jl | 4 ++ .../src/SparseArrayInterface.jl | 1 + .../src/abstractsparsearray/base.jl | 8 +-- .../src/abstractsparsearray/baseinterface.jl | 4 ++ .../src/abstractsparsearray/broadcast.jl | 2 +- .../src/abstractsparsearray/map.jl | 38 +++++------ .../wrappedabstractsparsearray.jl | 9 +++ .../src/sparsearrayinterface/base.jl | 3 +- .../src/sparsearrayinterface/indexing.jl | 21 ++++++- .../interface_optional.jl | 6 +- .../src/sparsearrayinterface/map.jl | 42 +++++++++++-- .../src/sparsearrayinterface/wrappers.jl | 6 +- .../lib/TensorAlgebra/src/TensorAlgebra.jl | 1 + Project.toml | 1 - 30 files changed, 398 insertions(+), 85 deletions(-) create mode 100644 NDTensors/src/lib/BlockSparseArrays/src/BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl create mode 100644 NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/broadcast.jl create mode 100644 NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl create mode 100644 NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl create mode 100644 NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/broadcast.jl create mode 100644 NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/wrappedabstractsparsearray.jl diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index 1fe90a90da..78878b7eac 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -4,6 +4,7 @@ authors = ["Matthew Fishman "] version = "0.2.22" [deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" @@ -16,6 +17,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" @@ -41,7 +43,9 @@ NDTensorsOctavianExt = "Octavian" NDTensorsTBLISExt = "TBLIS" [compat] +Accessors = "0.1.33" Adapt = "3.7" +ArrayLayouts = "1.4" BlockArrays = "0.16" Compat = "4.9" Dictionaries = "0.3.5" diff --git a/NDTensors/src/NDTensors.jl b/NDTensors/src/NDTensors.jl index 46f2964f67..052927677f 100644 --- a/NDTensors/src/NDTensors.jl +++ b/NDTensors/src/NDTensors.jl @@ -27,12 +27,12 @@ for lib in [ :BroadcastMapConversion, :Unwrap, :RankFactorization, + :GradedAxes, :TensorAlgebra, :SparseArrayInterface, :SparseArrayDOKs, :DiagonalArrays, :BlockSparseArrays, - :GradedAxes, :NamedDimsArrays, :SmallVectors, :SortedSets, diff --git a/NDTensors/src/lib/BlockSparseArrays/examples/README.jl b/NDTensors/src/lib/BlockSparseArrays/examples/README.jl index f0a9edd4bb..2bf91f512d 100644 --- a/NDTensors/src/lib/BlockSparseArrays/examples/README.jl +++ b/NDTensors/src/lib/BlockSparseArrays/examples/README.jl @@ -64,6 +64,8 @@ function main() @test permuted_b == permutedims(Array(b), (2, 1)) @test b + b ≈ Array(b) + Array(b) + @test b + b isa BlockSparseArray + @test block_nstored(b + b) == 2 scaled_b = 2b @test scaled_b ≈ 2Array(b) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 6e84b0dd22..6094e0eb13 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -1,5 +1,6 @@ -using BlockArrays: Block +using BlockArrays: AbstractBlockArray, Block, blockedrange using Dictionaries: Dictionary, Indices +using ..SparseArrayInterface: stored_indices # TODO: Use `Tuple` conversion once # BlockArrays.jl PR is merged. @@ -12,3 +13,28 @@ end function blocks_to_cartesianindices(d::Dictionary{<:Block}) return Dictionary(blocks_to_cartesianindices(eachindex(d)), d) end + +function block_reshape(a::AbstractBlockArray, dims::Tuple{Vararg{Vector{Int}}}) + return block_reshape(a, blockedrange.(dims)) +end + +function block_reshape(a::AbstractBlockArray, dims::Vararg{Vector{Int}}) + return block_reshape(a, dims) +end + +tuple_oneto(n) = ntuple(identity, n) + +function block_reshape(a::AbstractBlockArray, axes::Tuple{Vararg{BlockedUnitRange}}) + reshaped_blocks_a = reshape(blocks(a), blocklength.(axes)) + reshaped_a = similar(a, axes) + for I in stored_indices(reshaped_blocks_a) + block_size_I = map(i -> length(axes[i][Block(I[i])]), tuple_oneto(length(axes))) + # TODO: Better converter here. + reshaped_a[Block(Tuple(I))] = reshape(reshaped_blocks_a[I], block_size_I) + end + return reshaped_a +end + +function block_reshape(a::AbstractBlockArray, axes::Vararg{BlockedUnitRange}) + return block_reshape(a, axes) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl new file mode 100644 index 0000000000..a65f47947f --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl @@ -0,0 +1,5 @@ +using ..SparseArrayInterface: SparseArrayInterface, nstored + +function SparseArrayInterface.nstored(a::AbstractBlockArray) + return sum(b -> nstored(b), blocks(a)) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl index 2be0e157a3..e12b68a603 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl @@ -2,12 +2,16 @@ module BlockSparseArrays include("blocksparsearrayinterface/blocksparsearrayinterface.jl") include("blocksparsearrayinterface/linearalgebra.jl") include("blocksparsearrayinterface/blockzero.jl") +include("blocksparsearrayinterface/broadcast.jl") include("abstractblocksparsearray/abstractblocksparsearray.jl") +include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl") include("abstractblocksparsearray/abstractblocksparsematrix.jl") include("abstractblocksparsearray/abstractblocksparsevector.jl") include("abstractblocksparsearray/arraylayouts.jl") include("abstractblocksparsearray/linearalgebra.jl") +include("abstractblocksparsearray/broadcast.jl") include("blocksparsearray/defaults.jl") include("blocksparsearray/blocksparsearray.jl") include("BlockArraysExtensions/BlockArraysExtensions.jl") +include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl") end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl index adc73343fd..481693bbb1 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsearray.jl @@ -9,13 +9,7 @@ abstract type AbstractBlockSparseArray{T,N} <: AbstractBlockArray{T,N} end # Base `AbstractArray` interface Base.axes(::AbstractBlockSparseArray) = error("Not implemented") -# BlockArrays `AbstractBlockArray` interface -BlockArrays.blocks(::AbstractBlockSparseArray) = error("Not implemented") - -blocktype(a::AbstractBlockSparseArray) = eltype(blocks(a)) - blockstype(::Type{<:AbstractBlockSparseArray}) = error("Not implemented") -blocktype(arraytype::Type{<:AbstractBlockSparseArray}) = eltype(blockstype(arraytype)) # Base `AbstractArray` interface function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N} @@ -28,30 +22,3 @@ function Base.setindex!( blocksparse_setindex!(a, value, I...) return a end - -function Base.setindex!( - a::AbstractBlockSparseArray{<:Any,N}, value, I::BlockIndex{N} -) where {N} - blocksparse_setindex!(a, value, I) - return a -end - -function Base.setindex!(a::AbstractBlockSparseArray{<:Any,N}, value, I::Block{N}) where {N} - blocksparse_setindex!(a, value, I) - return a -end - -# `BlockArrays` interface -function BlockArrays.viewblock( - a::AbstractBlockSparseArray{<:Any,N}, I::Block{N,Int} -) where {N} - return blocksparse_viewblock(a, I) -end - -# Needed by `BlockArrays` matrix multiplication interface -function Base.similar( - arraytype::Type{<:AbstractBlockSparseArray{T}}, axes::Tuple{Vararg{BlockedUnitRange}} -) where {T} - # TODO: Make generic for GPU! - return BlockSparseArray{T}(undef, axes) -end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl index e0a420404c..7e768bc73a 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl @@ -3,6 +3,7 @@ using BlockArrays: BlockLayout using ..SparseArrayInterface: SparseLayout using LinearAlgebra: mul! +# TODO: Generalize to `BlockSparseArrayLike`. function ArrayLayouts.MemoryLayout(arraytype::Type{<:AbstractBlockSparseArray}) outer_layout = typeof(MemoryLayout(blockstype(arraytype))) inner_layout = typeof(MemoryLayout(blocktype(arraytype))) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/broadcast.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/broadcast.jl new file mode 100644 index 0000000000..0d1c942d18 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/broadcast.jl @@ -0,0 +1,5 @@ +using Base.Broadcast: Broadcast + +function Broadcast.BroadcastStyle(arraytype::Type{<:BlockSparseArrayLike}) + return BlockSparseArrayStyle{ndims(arraytype)}() +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl new file mode 100644 index 0000000000..f6c129660c --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl @@ -0,0 +1,42 @@ +using ArrayLayouts: LayoutArray + +# Map +function Base.map!(f, a_dest::AbstractArray, a_srcs::Vararg{BlockSparseArrayLike}) + blocksparse_map!(f, a_dest, a_srcs...) + return a_dest +end + +function Base.copy!(a_dest::AbstractArray, a_src::BlockSparseArrayLike) + blocksparse_copy!(a_dest, a_src) + return a_dest +end + +function Base.copyto!(a_dest::AbstractArray, a_src::BlockSparseArrayLike) + blocksparse_copyto!(a_dest, a_src) + return a_dest +end + +# Fix ambiguity error +function Base.copyto!(a_dest::LayoutArray, a_src::BlockSparseArrayLike) + blocksparse_copyto!(a_dest, a_src) + return a_dest +end + +function Base.permutedims!(a_dest::AbstractArray, a_src::BlockSparseArrayLike, perm) + blocksparse_permutedims!(a_dest, a_src, perm) + return a_dest +end + +function Base.mapreduce(f, op, as::Vararg{BlockSparseArrayLike}; kwargs...) + return blocksparse_mapreduce(f, op, as...; kwargs...) +end + +# TODO: Why isn't this calling `mapreduce` already? +function Base.iszero(a::BlockSparseArrayLike) + return blocksparse_iszero(a) +end + +# TODO: Why isn't this calling `mapreduce` already? +function Base.isreal(a::BlockSparseArrayLike) + return blocksparse_isreal(a) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl new file mode 100644 index 0000000000..b8b53f1342 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -0,0 +1,63 @@ +using Adapt: WrappedArray + +const WrappedAbstractBlockSparseArray{T,N,A} = WrappedArray{ + T,N,<:AbstractBlockSparseArray,<:AbstractBlockSparseArray{T,N} +} + +const BlockSparseArrayLike{T,N} = Union{ + <:AbstractBlockSparseArray{T,N},<:WrappedAbstractBlockSparseArray{T,N} +} + +# BlockArrays `AbstractBlockArray` interface +BlockArrays.blocks(a::BlockSparseArrayLike) = blocksparse_blocks(a) + +blocktype(a::BlockSparseArrayLike) = eltype(blocks(a)) + +# TODO: Use `parenttype` from `Unwrap`. +blockstype(arraytype::Type{<:WrappedAbstractBlockSparseArray}) = parenttype(arraytype) + +blocktype(arraytype::Type{<:BlockSparseArrayLike}) = eltype(blockstype(arraytype)) + +function Base.setindex!(a::BlockSparseArrayLike{<:Any,N}, value, I::BlockIndex{N}) where {N} + blocksparse_setindex!(a, value, I) + return a +end + +function Base.setindex!(a::BlockSparseArrayLike{<:Any,N}, value, I::Block{N}) where {N} + blocksparse_setindex!(a, value, I) + return a +end + +# Fix ambiguity error +function Base.setindex!(a::BlockSparseArrayLike{<:Any,1}, value, I::Block{1}) + blocksparse_setindex!(a, value, I) + return a +end + +# `BlockArrays` interface +# TODO: Is this needed if `blocks` is defined? +function BlockArrays.viewblock(a::BlockSparseArrayLike{<:Any,N}, I::Block{N,Int}) where {N} + return blocksparse_viewblock(a, I) +end + +# Needed by `BlockArrays` matrix multiplication interface +function Base.similar( + arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{BlockedUnitRange}} +) + return similar(arraytype, eltype(arraytype), axes) +end + +# Needed by `BlockArrays` matrix multiplication interface +function Base.similar( + arraytype::Type{<:BlockSparseArrayLike}, elt::Type, axes::Tuple{Vararg{BlockedUnitRange}} +) + # TODO: Make generic for GPU! Use `blocktype`. + return BlockSparseArray{elt}(undef, axes) +end + +function Base.similar( + a::BlockSparseArrayLike, elt::Type, axes::Tuple{Vararg{BlockedUnitRange}} +) + # TODO: Make generic for GPU! Use `blocktype`. + return BlockSparseArray{eltype(a)}(undef, axes) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl index e3e9f6d3a7..ef38447c7f 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blocksparsearray.jl @@ -102,3 +102,11 @@ BlockArrays.blocks(a::BlockSparseArray) = a.blocks # TODO: Use `SetParameters`. blockstype(::Type{<:BlockSparseArray{<:Any,<:Any,<:Any,B}}) where {B} = B + +# Base interface +function Base.similar( + a::AbstractBlockSparseArray, elt::Type, axes::Tuple{Vararg{BlockedUnitRange}} +) + # TODO: Preserve GPU data! + return BlockSparseArray{elt}(undef, axes) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 3d8b5b4939..d61750dd4a 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -1,5 +1,10 @@ using BlockArrays: Block, BlockIndex, block, blocks, blockcheckbounds, findblockindex -using ..SparseArrayInterface: nstored +using ..SparseArrayInterface: perm, iperm, nstored +using MappedArrays: mappedarray + +function blocksparse_blocks(a::AbstractArray) + return blocks(a) +end function blocksparse_getindex(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N} @boundscheck checkbounds(a, I...) @@ -24,7 +29,7 @@ function blocksparse_setindex!(a::AbstractArray{<:Any,N}, value, I::Block{N}) wh # TODO: Create a conversion function, say `CartesianIndex(Int.(Tuple(I)))`. i = I.n @boundscheck blockcheckbounds(a, i...) - blocks(a)[i...] = value + blocksparse_blocks(a)[i...] = value return a end @@ -32,9 +37,27 @@ function blocksparse_viewblock(a::AbstractArray{<:Any,N}, I::Block{N}) where {N} # TODO: Create a conversion function, say `CartesianIndex(Int.(Tuple(I)))`. i = I.n @boundscheck blockcheckbounds(a, i...) - return blocks(a)[i...] + return blocksparse_blocks(a)[i...] end function block_nstored(a::AbstractArray) - return nstored(blocks(a)) + return nstored(blocksparse_blocks(a)) +end + +# Base +function blocksparse_map!(f, a_dest::AbstractArray, a_srcs::AbstractArray...) + map!(f, blocks(a_dest), blocks.(a_srcs)...) + return a_dest +end + +# PermutedDimsArray +function blocksparse_blocks(a::PermutedDimsArray) + blocks_parent = blocksparse_blocks(parent(a)) + # Lazily permute each block + blocks_parent_mapped = mappedarray( + Base.Fix2(PermutedDimsArray, perm(a)), + Base.Fix2(PermutedDimsArray, iperm(a)), + blocks_parent, + ) + return PermutedDimsArray(blocks_parent_mapped, perm(a)) end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blockzero.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blockzero.jl index 707a399240..479f78c334 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blockzero.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blockzero.jl @@ -18,26 +18,22 @@ struct BlockZero{Axes} axes::Axes end -function (f::BlockZero)(a::AbstractArray, I::CartesianIndex) +function (f::BlockZero)(a::AbstractArray, I) return f(eltype(a), I) end -function (f::BlockZero)( - arraytype::Type{<:AbstractArray{<:Any,N}}, I::CartesianIndex{N} -) where {N} +function (f::BlockZero)(arraytype::Type{<:AbstractArray}, I) # TODO: Make sure this works for sparse or block sparse blocks, immutable # blocks, diagonal blocks, etc.! return fill!(arraytype(undef, block_size(f.axes, Block(Tuple(I)))), false) end # Fallback so that `SparseArray` with scalar elements works. -function (f::BlockZero)(blocktype::Type{<:Number}, I::CartesianIndex) +function (f::BlockZero)(blocktype::Type{<:Number}, I) return zero(blocktype) end # Fallback to Array if it is abstract -function (f::BlockZero)( - arraytype::Type{AbstractArray{T,N}}, I::CartesianIndex{N} -) where {T,N} - return fill!(Array{T,N}(undef, block_size(f.axes, Block(Tuple(I)))), zero(T)) +function (f::BlockZero)(arraytype::Type{AbstractArray{T,N}}, I) where {T,N} + return f(Array{T,N}, I) end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/broadcast.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/broadcast.jl new file mode 100644 index 0000000000..c07110cb74 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/broadcast.jl @@ -0,0 +1,41 @@ +using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted +using ..BroadcastMapConversion: map_function, map_args + +struct BlockSparseArrayStyle{N} <: AbstractArrayStyle{N} end + +# Define for new sparse array types. +# function Broadcast.BroadcastStyle(arraytype::Type{<:MyBlockSparseArray}) +# return BlockSparseArrayStyle{ndims(arraytype)}() +# end + +BlockSparseArrayStyle(::Val{N}) where {N} = BlockSparseArrayStyle{N}() +BlockSparseArrayStyle{M}(::Val{N}) where {M,N} = BlockSparseArrayStyle{N}() + +Broadcast.BroadcastStyle(a::BlockSparseArrayStyle, ::DefaultArrayStyle{0}) = a +function Broadcast.BroadcastStyle( + ::BlockSparseArrayStyle{N}, a::DefaultArrayStyle +) where {N} + return BroadcastStyle(DefaultArrayStyle{N}(), a) +end +function Broadcast.BroadcastStyle( + ::BlockSparseArrayStyle{N}, ::Broadcast.Style{Tuple} +) where {N} + return DefaultArrayStyle{N}() +end + +# TODO: Use `allocate_output`, share logic with `map`. +function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type) + # TODO: Is this a good definition? Probably should check that + # they have consistent axes. + return similar(first(map_args(bc)), elt) +end + +# Broadcasting implementation +function Base.copyto!( + dest::AbstractArray{<:Any,N}, bc::Broadcasted{BlockSparseArrayStyle{N}} +) where {N} + # convert to map + # flatten and only keep the AbstractArray arguments + blocksparse_map!(map_function(bc), dest, map_args(bc)...) + return dest +end diff --git a/NDTensors/src/lib/BlockSparseArrays/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/test/runtests.jl index 1586c42e47..eb11318548 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/runtests.jl @@ -1,7 +1,8 @@ @eval module $(gensym()) using BlockArrays: Block, BlockedUnitRange, blockedrange, blocklength, blocksize using LinearAlgebra: mul! -using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored +using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored, block_reshape +using NDTensors.SparseArrayInterface: nstored using Test: @test, @testset include("TestBlockSparseArraysUtils.jl") @testset "BlockSparseArrays (eltype=$elt)" for elt in @@ -37,6 +38,47 @@ include("TestBlockSparseArraysUtils.jl") end end end + @testset "Tensor algebra" begin + a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4])) + a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)]))) + a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)]))) + @test eltype(a) == elt + @test block_nstored(a) == 2 + @test nstored(a) == 2 * 4 + 3 * 3 + + b = 2 * a + @test Array(b) ≈ 2 * Array(a) + @test eltype(b) == elt + @test block_nstored(b) == 2 + @test nstored(b) == 2 * 4 + 3 * 3 + + b = a + a + @test Array(b) ≈ 2 * Array(a) + @test eltype(b) == elt + @test block_nstored(b) == 2 + @test nstored(b) == 2 * 4 + 3 * 3 + + x = BlockSparseArray{elt}(undef, ([3, 4], [2, 3])) + x[Block(1, 2)] = randn(elt, size(@view(x[Block(1, 2)]))) + x[Block(2, 1)] = randn(elt, size(@view(x[Block(2, 1)]))) + b = a .+ a .+ 3 .* PermutedDimsArray(x, (2, 1)) + @test Array(b) ≈ 2 * Array(a) + 3 * permutedims(Array(x), (2, 1)) + @test eltype(b) == elt + @test block_nstored(b) == 2 + @test nstored(b) == 2 * 4 + 3 * 3 + + b = permutedims(a, (2, 1)) + @test Array(b) ≈ permutedims(Array(a), (2, 1)) + @test eltype(b) == elt + @test block_nstored(b) == 2 + @test nstored(b) == 2 * 4 + 3 * 3 + + b = map(x -> 2x, a) + @test Array(b) ≈ 2 * Array(a) + @test eltype(b) == elt + @test block_nstored(b) == 2 + @test nstored(b) == 2 * 4 + 3 * 3 + end @testset "LinearAlgebra" begin a1 = BlockSparseArray{elt}([2, 3], [2, 3]) a1[Block(1, 1)] = randn(elt, size(@view(a1[Block(1, 1)]))) @@ -47,5 +89,15 @@ include("TestBlockSparseArraysUtils.jl") @test a_dest isa BlockSparseArray{elt} @test block_nstored(a_dest) == 1 end + @testset "block_reshape" begin + a = BlockSparseArray{elt}(undef, ([3, 4], [2, 3])) + a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)]))) + a[Block(2, 1)] = randn(elt, size(@view(a[Block(2, 1)]))) + b = block_reshape(a, [6, 8, 9, 12]) + @test reshape(a[Block(1, 2)], 9) == b[Block(3)] + @test reshape(a[Block(2, 1)], 8) == b[Block(2)] + @test block_nstored(b) == 2 + @test nstored(b) == 17 + end end end diff --git a/NDTensors/src/lib/SparseArrayDOKs/src/sparsearraydok.jl b/NDTensors/src/lib/SparseArrayDOKs/src/sparsearraydok.jl index b72cdb8358..da90718dc4 100644 --- a/NDTensors/src/lib/SparseArrayDOKs/src/sparsearraydok.jl +++ b/NDTensors/src/lib/SparseArrayDOKs/src/sparsearraydok.jl @@ -1,3 +1,4 @@ +using Accessors: @set using Dictionaries: Dictionary, set! using ..SparseArrayInterface: SparseArrayInterface, AbstractSparseArray, getindex_zero_function @@ -74,6 +75,9 @@ end Base.size(a::SparseArrayDOK) = a.dims SparseArrayInterface.getindex_zero_function(a::SparseArrayDOK) = a.zero +function SparseArrayInterface.set_getindex_zero_function(a::SparseArrayDOK, f) + return @set a.zero = f +end function SparseArrayInterface.setindex_notstored!( a::SparseArrayDOK{<:Any,N}, value, I::CartesianIndex{N} diff --git a/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl b/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl index c40db936ba..eeb1d3a970 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl @@ -13,6 +13,7 @@ include("sparsearrayinterface/wrappers.jl") include("sparsearrayinterface/zero.jl") include("sparsearrayinterface/SparseArrayInterfaceLinearAlgebraExt.jl") include("abstractsparsearray/abstractsparsearray.jl") +include("abstractsparsearray/wrappedabstractsparsearray.jl") include("abstractsparsearray/sparsearrayinterface.jl") include("abstractsparsearray/base.jl") include("abstractsparsearray/broadcast.jl") diff --git a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/base.jl b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/base.jl index 587571a060..8480b2fd88 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/base.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/base.jl @@ -1,18 +1,18 @@ using ..SparseArrayInterface: SparseArrayInterface # Base -function Base.:(==)(a1::AbstractSparseArray, a2::AbstractSparseArray) +function Base.:(==)(a1::SparseArrayLike, a2::SparseArrayLike) return SparseArrayInterface.sparse_isequal(a1, a2) end -function Base.reshape(a::AbstractSparseArray, dims::Tuple{Vararg{Int}}) +function Base.reshape(a::SparseArrayLike, dims::Tuple{Vararg{Int}}) return SparseArrayInterface.sparse_reshape(a, dims) end -function Base.zero(a::AbstractSparseArray) +function Base.zero(a::SparseArrayLike) return SparseArrayInterface.sparse_zero(a) end -function Base.one(a::AbstractSparseArray) +function Base.one(a::SparseArrayLike) return SparseArrayInterface.sparse_one(a) end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/baseinterface.jl b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/baseinterface.jl index deb262c827..8827b406a3 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/baseinterface.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/baseinterface.jl @@ -10,6 +10,10 @@ function Base.getindex(a::AbstractSparseArray, I...) return SparseArrayInterface.sparse_getindex(a, I...) end +function Base.isassigned(a::AbstractSparseArray, I::Integer...) + return SparseArrayInterface.sparse_isassigned(a, I...) +end + function Base.setindex!(a::AbstractSparseArray, I...) return SparseArrayInterface.sparse_setindex!(a, I...) end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/broadcast.jl b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/broadcast.jl index 152ddb2ab0..fdade0f775 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/broadcast.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/broadcast.jl @@ -1,4 +1,4 @@ # Broadcasting -function Broadcast.BroadcastStyle(arraytype::Type{<:AbstractSparseArray}) +function Broadcast.BroadcastStyle(arraytype::Type{<:SparseArrayLike}) return SparseArrayInterface.SparseArrayStyle{ndims(arraytype)}() end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/map.jl b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/map.jl index 9546786648..bc85aa1099 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/map.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/map.jl @@ -1,42 +1,42 @@ using ArrayLayouts: LayoutArray # Map -function Base.map!(f, dest::AbstractArray, src::AbstractSparseArray) - SparseArrayInterface.sparse_map!(f, dest, src) - return dest +function Base.map!(f, a_dest::AbstractArray, a_srcs::Vararg{SparseArrayLike}) + SparseArrayInterface.sparse_map!(f, a_dest, a_srcs...) + return a_dest end -function Base.copy!(dest::AbstractArray, src::AbstractSparseArray) - SparseArrayInterface.sparse_copy!(dest, src) - return dest +function Base.copy!(a_dest::AbstractArray, a_src::SparseArrayLike) + SparseArrayInterface.sparse_copy!(a_dest, a_src) + return a_dest end -function Base.copyto!(dest::AbstractArray, src::AbstractSparseArray) - SparseArrayInterface.sparse_copyto!(dest, src) - return dest +function Base.copyto!(a_dest::AbstractArray, a_src::SparseArrayLike) + SparseArrayInterface.sparse_copyto!(a_dest, a_src) + return a_dest end # Fix ambiguity error -function Base.copyto!(dest::LayoutArray, src::AbstractSparseArray) - SparseArrayInterface.sparse_copyto!(dest, src) - return dest +function Base.copyto!(a_dest::LayoutArray, a_src::SparseArrayLike) + SparseArrayInterface.sparse_copyto!(a_dest, a_src) + return a_dest end -function Base.permutedims!(dest::AbstractArray, src::AbstractSparseArray, perm) - SparseArrayInterface.sparse_permutedims!(dest, src, perm) - return dest +function Base.permutedims!(a_dest::AbstractArray, a_src::SparseArrayLike, perm) + SparseArrayInterface.sparse_permutedims!(a_dest, a_src, perm) + return a_dest end -function Base.mapreduce(f, op, a::AbstractSparseArray; kwargs...) - return SparseArrayInterface.sparse_mapreduce(f, op, a; kwargs...) +function Base.mapreduce(f, op, as::Vararg{SparseArrayLike}; kwargs...) + return SparseArrayInterface.sparse_mapreduce(f, op, as...; kwargs...) end # TODO: Why isn't this calling `mapreduce` already? -function Base.iszero(a::AbstractSparseArray) +function Base.iszero(a::SparseArrayLike) return SparseArrayInterface.sparse_iszero(a) end # TODO: Why isn't this calling `mapreduce` already? -function Base.isreal(a::AbstractSparseArray) +function Base.isreal(a::SparseArrayLike) return SparseArrayInterface.sparse_isreal(a) end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/wrappedabstractsparsearray.jl b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/wrappedabstractsparsearray.jl new file mode 100644 index 0000000000..e42ac91e8b --- /dev/null +++ b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/wrappedabstractsparsearray.jl @@ -0,0 +1,9 @@ +using Adapt: WrappedArray + +const WrappedAbstractSparseArray{T,N,A} = WrappedArray{ + T,N,<:AbstractSparseArray,<:AbstractSparseArray{T,N} +} + +const SparseArrayLike{T,N} = Union{ + <:AbstractSparseArray{T,N},<:WrappedAbstractSparseArray{T,N} +} diff --git a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/base.jl b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/base.jl index 5a97c8b132..474e3bd129 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/base.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/base.jl @@ -86,7 +86,6 @@ function sparse_one!(a::AbstractMatrix) return a end -# TODO: Make `sparse_one!`? function sparse_one(a::AbstractMatrix) a = sparse_zero(a) sparse_one!(a) @@ -108,7 +107,7 @@ end function sparse_reshape!(a_dest::AbstractArray, a_src::AbstractArray, dims) @assert length(a_src) == prod(dims) - sparse_fill!(a_dest, zero(eltype(a_src))) + sparse_zero!(a_dest) linear_inds = LinearIndices(a_src) dest_cartesian_inds = CartesianIndices(dims) for I in stored_indices(a_src) diff --git a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl index d267fd42f2..b66cef69cd 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl @@ -31,6 +31,10 @@ end MaybeStoredIndex(I, I_storage) = StoredIndex(I, StorageIndex(I_storage)) MaybeStoredIndex(I, I_storage::Nothing) = NotStoredIndex(I) +# Convert the index into an index into the storage. +# Return `NotStoredIndex(I)` if it isn't in the storage. +storage_index(a::AbstractArray, I...) = MaybeStoredIndex(a, I...) + function storage_indices(a::AbstractArray) return eachindex(sparse_storage(a)) end @@ -68,7 +72,7 @@ end # Implementation of element access function _sparse_getindex(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N} @boundscheck checkbounds(a, I) - return sparse_getindex(a, MaybeStoredIndex(a, I)) + return sparse_getindex(a, storage_index(a, I)) end # Handle trailing indices or linear indexing @@ -99,7 +103,7 @@ end # Implementation of element access function _sparse_setindex!(a::AbstractArray{<:Any,N}, value, I::CartesianIndex{N}) where {N} @boundscheck checkbounds(a, I) - sparse_setindex!(a, value, MaybeStoredIndex(a, I)) + sparse_setindex!(a, value, storage_index(a, I)) return a end @@ -142,6 +146,19 @@ function sparse_setindex!(a::AbstractArray, value, I::NotStoredIndex) return a end +# isassigned +function sparse_isassigned(a::AbstractArray, I::Integer...) + return sparse_isassigned(a, CartesianIndex(I)) +end +sparse_isassigned(a::AbstractArray, I::NotStoredIndex) = true +sparse_isassigned(a::AbstractArray, I::StoredIndex) = sparse_isassigned(a, StorageIndex(I)) +function sparse_isassigned(a::AbstractArray, I::StorageIndex) + return isassigned(sparse_storage(a), index(I)) +end +function sparse_isassigned(a::AbstractArray, I::CartesianIndex) + return sparse_isassigned(a, storage_index(a, I)) +end + # A set of indices into the storage of the sparse array. struct StorageIndices{I} i::I diff --git a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/interface_optional.jl b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/interface_optional.jl index 0e67b0977f..122995dd34 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/interface_optional.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/interface_optional.jl @@ -1,7 +1,11 @@ # Optional interface. -# Access a zero value. + +# Function for computing unstored zero values. getindex_zero_function(::AbstractArray) = Zero() +# Change the function for computing unstored values +set_getindex_zero_function(a::AbstractArray, f) = error("Not implemented") + function getindex_notstored(a::AbstractArray, I) return getindex_zero_function(a)(a, I) end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/map.jl b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/map.jl index ef0b5918d0..6d74d184dc 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/map.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/map.jl @@ -1,9 +1,44 @@ using Compat: allequal +# Represents a value that isn't stored +# Used to hijack dispatch +struct NotStoredValue{Value} + value::Value +end +value(v::NotStoredValue) = v.value +nstored(::NotStoredValue) = false +Base.:*(x::Number, y::NotStoredValue) = false +Base.:+(::NotStoredValue, ::NotStoredValue...) = false +Base.:+(x::Number, ::NotStoredValue...) = x +Base.iszero(::NotStoredValue) = true +Base.isreal(::NotStoredValue) = true +Base.conj(x::NotStoredValue) = conj(value(x)) + +notstored_index(a::AbstractArray) = NotStoredIndex(first(eachindex(a))) + +# Get some not-stored value +function get_notstored(a::AbstractArray) + return sparse_getindex(a, notstored_index(a)) +end + +function apply_notstored(f, as::Vararg{AbstractArray}) + return apply(f, NotStoredValue.(get_notstored.(as))...) +end + +function apply(f, xs::Vararg{NotStoredValue}) + return f(xs...) + #return try + # return f(xs...) + #catch + # f(value(x)) + #end +end + # Test if the function preserves zero values and therefore # preserves the sparsity structure. function preserves_zero(f, as...) - return iszero(f(map(a -> sparse_getindex(a, NotStoredIndex(first(eachindex(a)))), as)...)) + # return iszero(f(map(get_notstored, as)...)) + return iszero(apply_notstored(f, as...)) end # Map a subset of indices @@ -61,8 +96,7 @@ end # TODO: Define `sparse_mapreducedim!`. function sparse_mapreduce(f, op, a::AbstractArray; kwargs...) output = mapreduce(f, op, sparse_storage(a); kwargs...) - # TODO: Use more general `zero` value. - # TODO: Better way to check that zeros don't affect the output? - @assert op(output, f(zero(eltype(a)))) == output + f_notstored = apply_notstored(f, a) + @assert op(output, f_notstored) == output return output end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/wrappers.jl b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/wrappers.jl index 6dbf922c37..aa3503df76 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/wrappers.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/wrappers.jl @@ -1,4 +1,8 @@ perm(::PermutedDimsArray{<:Any,<:Any,P}) where {P} = P +iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,IP}) where {IP} = IP + +# TODO: Use `Base.PermutedDimsArrays.genperm` or +# https://github.com/jipolanco/StaticPermutations.jl? genperm(v, perm) = map(j -> v[j], perm) genperm(v::CartesianIndex, perm) = CartesianIndex(map(j -> Tuple(v)[j], perm)) @@ -14,5 +18,3 @@ function index_to_storage_index( ) where {N} return storage_index_to_index(parent(a), genperm(I, perm(a))) end - -# TODO: Add `SubArray`, `ReshapedArray`, `Diagonal`, etc. diff --git a/NDTensors/src/lib/TensorAlgebra/src/TensorAlgebra.jl b/NDTensors/src/lib/TensorAlgebra/src/TensorAlgebra.jl index 74ae00c662..3e57b7e66d 100644 --- a/NDTensors/src/lib/TensorAlgebra/src/TensorAlgebra.jl +++ b/NDTensors/src/lib/TensorAlgebra/src/TensorAlgebra.jl @@ -1,4 +1,5 @@ module TensorAlgebra +# TODO: Move inside `include`. using ..AlgorithmSelection: Algorithm, @Algorithm_str using LinearAlgebra: mul! diff --git a/Project.toml b/Project.toml index 3e833daa8e..6eb16ddb15 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "0.3.52" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" BitIntegers = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"