diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index c8f8a87c5e..0e1281595a 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -6,10 +6,11 @@ using BlockArrays: BlockIndex, BlockIndexRange, BlockRange, + BlockSlice, + BlockVector, BlockedOneTo, BlockedUnitRange, - BlockVector, - BlockSlice, + BlockedVector, block, blockaxes, blockedrange, @@ -22,6 +23,26 @@ using Dictionaries: Dictionary, Indices using ..GradedAxes: blockedunitrange_getindices using ..SparseArrayInterface: SparseArrayInterface, nstored, stored_indices +# A return type for `blocks(array)` when `array` isn't blocked. +# Represents a vector with just that single block. +struct SingleBlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N} + array::Array +end +blocks_maybe_single(a) = blocks(a) +blocks_maybe_single(a::Array) = SingleBlockView(a) +function Base.getindex(a::SingleBlockView{<:Any,N}, index::Vararg{Int,N}) where {N} + @assert all(isone, index) + return a.array +end + +# A wrapper around a potentially blocked array that is not blocked. +struct NonBlockedArray{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N} + array::Array +end +Base.size(a::NonBlockedArray) = size(a.array) +Base.getindex(a::NonBlockedArray{<:Any,N}, I::Vararg{Integer,N}) where {N} = a.array[I...] +BlockArrays.blocks(a::NonBlockedArray) = SingleBlockView(a.array) + # BlockIndices works around an issue that the indices of BlockSlice # are restricted to AbstractUnitRange{Int}. struct BlockIndices{B,T<:Integer,I<:AbstractVector{T}} <: AbstractVector{T} @@ -39,6 +60,21 @@ function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}}) @assert length(S.indices[Block(i)]) == length(i.indices) return BlockSlice(S.blocks[Int(Block(i))], S.indices[Block(i)]) end +function Base.getindex( + S::BlockIndices{<:AbstractBlockVector{<:Block{1}}}, i::BlockSlice{<:Block{1}} +) + # TODO: Check for conistency of indices. + # Calling `mortar` on the indices wraps the multiple blocks into a single + # block, since the result shouldn't be blocked. + ## return BlockIndices(S.blocks[Block(i)], NonBlockedArray(S.indices[Block(i)])) + return NonBlockedArray(S.indices[Block(i)]) +end +function Base.getindex( + S::BlockIndices{<:BlockedVector{<:Block{1},<:BlockRange{1}}}, i::BlockSlice{<:Block{1}} +) + return i +end + function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}}) # TODO: Check that `i.indices` is consistent with `S.indices`. # TODO: Turn this into a `blockedunitrange_getindices` definition. @@ -94,6 +130,13 @@ function to_blockindices(a::BlockedOneTo{<:Integer}, I::UnitRange{<:Integer}) ) end +# This handles non-blocked slices. +# For example: +# a = BlockSparseArray{Float64}([2, 2, 2, 2]) +# I = BlockedVector(Block.(1:4), [2, 2]) +# @views a[I][Block(1)] +to_blockindices(a::Base.OneTo{<:Integer}, I::UnitRange{<:Integer}) = I + # TODO: This is type piracy. This is used in `reindex` when making # views of blocks of sliced block arrays, for example: # ```julia @@ -291,14 +334,17 @@ function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedOneTo{<:Integer}) return only(blockaxes(r)) end -# This handles changing the blocking, for example: +# This handles block merging: # a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2]) +# I = BlockedVector(Block.(1:4), [2, 2]) +# I = BlockVector(Block.(1:4), [2, 2]) # I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2]) +# I = BlockVector([Block(4), Block(3), Block(2), Block(1)], [2, 2]) # a[I, I] -# TODO: Generalize to `AbstractBlockedUnitRange` and `AbstractBlockVector`. -function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockVector{<:Integer}) - # TODO: Probably this is incorrect and should be something like: - # return findblock(axis, first(r)):findblock(axis, last(r)) +function blockrange(axis::BlockedOneTo{<:Integer}, r::AbstractBlockVector{<:Block{1}}) + for b in r + @assert b ∈ blockaxes(axis, 1) + end return only(blockaxes(r)) end @@ -337,6 +383,10 @@ function blockrange(axis::AbstractUnitRange, r::Base.Slice) return only(blockaxes(axis)) end +function blockrange(axis::AbstractUnitRange, r::NonBlockedArray) + return Block(1):Block(1) +end + function blockrange(axis::AbstractUnitRange, r) return error("Slicing not implemented for range of type `$(typeof(r))`.") end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl index 7d69e142ca..16973fcd9a 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/map.jl @@ -25,16 +25,6 @@ end # This is type piracy, try to avoid this, maybe requires defining `map`. ## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2) -struct SingleBlockView{T,N,Array<:AbstractArray{T,N}} <: AbstractArray{T,N} - array::Array -end -_blocks(a) = blocks(a) -_blocks(a::Array) = SingleBlockView(a) -function Base.getindex(a::SingleBlockView{<:Any,N}, index::Vararg{Int,N}) where {N} - @assert all(isone, index) - return a.array -end - reblock(a) = a # If the blocking of the slice doesn't match the blocking of the # parent array, reblock according to the blocking of the parent array. @@ -57,11 +47,11 @@ function SparseArrayInterface.sparse_map!( BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs) # TODO: Investigate why this doesn't work: # block_dest = @view a_dest[_block(BI_dest)] - block_dest = _blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...] + block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...] # TODO: Investigate why this doesn't work: # block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs)) block_srcs = ntuple(length(a_srcs)) do i - return _blocks(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...] + return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...] end subblock_dest = @view block_dest[BI_dest.indices...] subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs)) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl index d2b205007c..ca19ab70d6 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/views.jl @@ -1,4 +1,5 @@ -using BlockArrays: BlockArrays, Block, BlockIndexRange, blocklength, blocksize, viewblock +using BlockArrays: + BlockArrays, Block, BlockIndexRange, BlockedVector, blocklength, blocksize, viewblock # This splits `BlockIndexRange{N}` into # `NTuple{N,BlockIndexRange{1}}`. diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index b9de79031b..a85ee61221 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -52,17 +52,13 @@ end # a[BlockVector([Block(2), Block(1)], [2]), BlockVector([Block(2), Block(1)], [2])] # Permute and merge blocks. # TODO: This isn't merging blocks yet, that needs to be implemented that. -function blocksparse_to_indices(a, inds, I::Tuple{BlockVector{<:Block{1}},Vararg{Any}}) +function blocksparse_to_indices( + a, inds, I::Tuple{AbstractBlockVector{<:Block{1}},Vararg{Any}} +) I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1])) return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) end -# TODO: Should this be combined with the version above? -function blocksparse_to_indices(a, inds, I::Tuple{BlockedVector{<:Block{1}},Vararg{Any}}) - I1 = blockedunitrange_getindices(inds[1], I[1]) - return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...) -end - # TODO: Need to implement this! function block_merge end diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index 586d208e18..69cff3dfbf 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -418,13 +418,7 @@ include("TestBlockSparseArraysUtils.jl") @views for b in [Block(1, 1), Block(2, 2)] a[b] = randn(elt, size(a[b])) end - for I in ( - Block.(1:2), - [Block(1), Block(2)], - BlockVector([Block(1), Block(2)], [1, 1]), - # TODO: This should merge blocks. - BlockVector([Block(1), Block(2)], [2]), - ) + for I in (Block.(1:2), [Block(1), Block(2)]) b = @view a[I, I] for I in CartesianIndices(a) @test b[I] == a[I] @@ -439,12 +433,7 @@ include("TestBlockSparseArraysUtils.jl") # TODO: Use `blocksizes(a)[Int.(Tuple(b))...]` once available. a[b] = randn(elt, size(a[b])) end - for I in ( - [Block(2), Block(1)], - BlockVector([Block(2), Block(1)], [1, 1]), - # TODO: This should merge blocks. - BlockVector([Block(2), Block(1)], [2]), - ) + for I in ([Block(2), Block(1)],) b = @view a[I, I] @test b[Block(1, 1)] == a[Block(2, 2)] @test b[Block(2, 1)] == a[Block(1, 2)] diff --git a/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl b/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl index 6e978d142b..0ebcf83674 100644 --- a/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl +++ b/NDTensors/src/lib/GradedAxes/src/blockedunitrange.jl @@ -1,5 +1,6 @@ using BlockArrays: BlockArrays, + AbstractBlockVector, AbstractBlockedUnitRange, Block, BlockIndex, @@ -74,16 +75,14 @@ function blockedunitrange_getindices( end # TODO: Make sure this handles block labels (AbstractGradedUnitRange) correctly. +# TODO: Make a special case for `BlockedVector{<:Block{1},<:BlockRange{1}}`? +# For example: +# ```julia +# blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices)) +# return blockedrange(blocklengths) +# ``` function blockedunitrange_getindices( - a::AbstractBlockedUnitRange, indices::BlockedVector{<:Block{1},<:BlockRange{1}} -) - blocklengths = map(bs -> sum(b -> length(a[b]), bs), blocks(indices)) - return blockedrange(blocklengths) -end - -# TODO: Make sure this handles block labels (AbstractGradedUnitRange) correctly. -function blockedunitrange_getindices( - a::AbstractBlockedUnitRange, indices::BlockedVector{<:Block{1}} + a::AbstractBlockedUnitRange, indices::AbstractBlockVector{<:Block{1}} ) return mortar(map(bs -> mortar(map(b -> a[b], bs)), blocks(indices))) end