Skip to content

Commit

Permalink
Merge branch 'main' into kmp5/feature/JLArrays_extension
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT authored Jun 26, 2024
2 parents d0b6a2f + 9c22961 commit d6e675d
Show file tree
Hide file tree
Showing 20 changed files with 339 additions and 319 deletions.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.35"
version = "0.3.37"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,29 +85,6 @@ function Base.axes(a::Adjoint{<:Any,<:AbstractBlockSparseMatrix})
return dual.(reverse(axes(a')))
end

# TODO: Delete this definition in favor of the one in
# GradedAxes once https://github.com/JuliaArrays/BlockArrays.jl/pull/405 is merged.
# TODO: Make a special definition for `BlockedVector{<:Block{1}}` in order
# to merge blocks.
function GradedAxes.blockedunitrange_getindices(
a::AbstractBlockedUnitRange, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}}
)
# Without converting `indices` to `Vector`,
# mapping `indices` outputs a `BlockVector`
# which is harder to reason about.
blocks = map(index -> a[index], Vector(indices))
# We pass `length.(blocks)` to `mortar` in order
# to pass block labels to the axes of the output,
# if they exist. This makes it so that
# `only(axes(a[indices])) isa `GradedUnitRange`
# if `a isa `GradedUnitRange`, for example.
# TODO: Remove `unlabel` once `BlockArray` axes
# type is generalized in BlockArrays.jl.
# TODO: Support using `BlockSparseVector`, need
# to make more `BlockSparseArray` constructors.
return BlockSparseArray(blocks, (blockedrange(length.(blocks)),))
end

# This definition is only needed since calls like
# `a[[Block(1), Block(2)]]` where `a isa AbstractGradedUnitRange`
# returns a `BlockSparseVector` instead of a `BlockVector`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test size(b) == (4, 4, 4, 4)
@test blocksize(b) == (2, 2, 2, 2)
@test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
# TODO: Fix this for `BlockedArray`.
@test_broken nstored(b) == 256
@test nstored(b) == 256
# TODO: Fix this for `BlockedArray`.
@test_broken block_nstored(b) == 16
for i in 1:ndims(a)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using BlockArrays:
AbstractBlockVector,
Block,
BlockRange,
BlockedOneTo,
BlockedUnitRange,
BlockVector,
BlockSlice,
Expand All @@ -19,19 +20,6 @@ using Dictionaries: Dictionary, Indices
using ..GradedAxes: blockedunitrange_getindices
using ..SparseArrayInterface: stored_indices

# GenericBlockSlice works around an issue that the indices of BlockSlice
# are restricted to Int element type.
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
struct GenericBlockSlice{B,T<:Integer,I<:AbstractUnitRange{T}} <: AbstractUnitRange{T}
block::B
indices::I
end
BlockArrays.Block(bs::GenericBlockSlice{<:Block}) = bs.block
for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe_length)
@eval Base.$f(S::GenericBlockSlice) = Base.$f(S.indices)
end
Base.getindex(S::GenericBlockSlice, i::Integer) = getindex(S.indices, i)

# BlockIndices works around an issue that the indices of BlockSlice
# are restricted to AbstractUnitRange{Int}.
struct BlockIndices{B,T<:Integer,I<:AbstractVector{T}} <: AbstractVector{T}
Expand All @@ -42,6 +30,63 @@ for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe
@eval Base.$f(S::BlockIndices) = Base.$f(S.indices)
end
Base.getindex(S::BlockIndices, i::Integer) = getindex(S.indices, i)
function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
# TODO: Check that `i.indices` is consistent with `S.indices`.
# It seems like this isn't handling the case where `i` is a
# subslice of a block correctly (i.e. it ignores `i.indices`).
@assert length(S.indices[Block(i)]) == length(i.indices)
return BlockSlice(S.blocks[Int(Block(i))], S.indices[Block(i)])
end
function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
# TODO: Check that `i.indices` is consistent with `S.indices`.
# TODO: Turn this into a `blockedunitrange_getindices` definition.
subblocks = S.blocks[Int.(i.block)]
subindices = mortar(
map(1:length(i.block)) do I
r = blocks(i.indices)[I]
return S.indices[first(r)]:S.indices[last(r)]
end,
)
return BlockIndices(subblocks, subindices)
end

# TODO: This is type piracy. This is used in `reindex` when making
# views of blocks of sliced block arrays, for example:
# ```julia
# a = BlockSparseArray{elt}(undef, ([2, 3], [2, 3]))
# b = @view a[[Block(1)[1:1], Block(2)[1:2]], [Block(1)[1:1], Block(2)[1:2]]]
# b[Block(1, 1)]
# ```
# Without this change, BlockArrays has the slicing behavior:
# ```julia
# julia> mortar([Block(1)[1:1], Block(2)[1:2]])[BlockSlice(Block(2), 2:3)]
# 2-element Vector{BlockIndex{1, Tuple{Int64}, Tuple{Int64}}}:
# Block(2)[1]
# Block(2)[2]
# ```
# while with this change it has the slicing behavior:
# ```julia
# julia> mortar([Block(1)[1:1], Block(2)[1:2]])[BlockSlice(Block(2), 2:3)]
# Block(2)[1:2]
# ```
# i.e. it preserves the types of the blocks better. Upstream this fix to
# BlockArrays.jl. Also consider overloading `reindex` so that it calls
# a custom `getindex` function to avoid type piracy in the meantime.
# Also fix this in BlockArrays:
# ```julia
# julia> mortar([Block(1)[1:1], Block(2)[1:2]])[Block(2)]
# 2-element Vector{BlockIndex{1, Tuple{Int64}, Tuple{Int64}}}:
# Block(2)[1]
# Block(2)[2]
# ```
function Base.getindex(
a::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
I::BlockSlice{<:Block{1}},
)
# Check that the block slice corresponds to the correct block.
@assert I.indices == only(axes(a))[Block(I)]
return blocks(a)[Int(Block(I))]
end

# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices)
Expand Down Expand Up @@ -185,15 +230,30 @@ function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
return r
end

using BlockArrays: BlockSlice
function blockrange(axis::AbstractUnitRange, r::BlockSlice)
return blockrange(axis, r.block)
# This handles changing the blocking, for example:
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = blockedrange([4, 4])
# a[I, I]
# TODO: Generalize to `AbstractBlockedUnitRange`.
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockedOneTo{<:Integer})
# TODO: Probably this is incorrect and should be something like:
# return findblock(axis, first(r)):findblock(axis, last(r))
return only(blockaxes(r))
end

# GenericBlockSlice works around an issue that the indices of BlockSlice
# are restricted to Int element type.
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
function blockrange(axis::AbstractUnitRange, r::GenericBlockSlice)
# This handles changing the blocking, for example:
# a = BlockSparseArray{Float64}([2, 2, 2, 2], [2, 2, 2, 2])
# I = BlockedVector([Block(4), Block(3), Block(2), Block(1)], [2, 2])
# a[I, I]
# TODO: Generalize to `AbstractBlockedUnitRange` and `AbstractBlockVector`.
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockVector{<:Integer})
# TODO: Probably this is incorrect and should be something like:
# return findblock(axis, first(r)):findblock(axis, last(r))
return only(blockaxes(r))
end

using BlockArrays: BlockSlice
function blockrange(axis::AbstractUnitRange, r::BlockSlice)
return blockrange(axis, r.block)
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ include("abstractblocksparsearray/abstractblocksparsearray.jl")
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
include("abstractblocksparsearray/abstractblocksparsevector.jl")
include("abstractblocksparsearray/view.jl")
include("abstractblocksparsearray/views.jl")
include("abstractblocksparsearray/arraylayouts.jl")
include("abstractblocksparsearray/sparsearrayinterface.jl")
include("abstractblocksparsearray/broadcast.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function Broadcast.BroadcastStyle(
<:Any,
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{BlockSlice{<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
<:Tuple{BlockSlice{<:Any,<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
},
},
)
Expand All @@ -25,8 +25,8 @@ function Broadcast.BroadcastStyle(
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{
BlockSlice{<:Any,<:AbstractBlockedUnitRange},
BlockSlice{<:Any,<:AbstractBlockedUnitRange},
BlockSlice{<:Any,<:Any,<:AbstractBlockedUnitRange},
BlockSlice{<:Any,<:Any,<:AbstractBlockedUnitRange},
Vararg{Any},
},
},
Expand All @@ -40,7 +40,7 @@ function Broadcast.BroadcastStyle(
<:Any,
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{Any,BlockSlice{<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
<:Tuple{Any,BlockSlice{<:Any,<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
},
},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ end

# TODO: Why isn't this calling `mapreduce` already?
function Base.iszero(a::BlockSparseArrayLike)
return sparse_iszero(a)
return sparse_iszero(blocks(a))
end

# TODO: Why isn't this calling `mapreduce` already?
function Base.isreal(a::BlockSparseArrayLike)
return sparse_isreal(a)
return sparse_isreal(blocks(a))
end
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,7 @@ end
function SparseArrayInterface.sparse_storage(a::AbstractBlockSparseArray)
return BlockSparseStorage(a)
end

function SparseArrayInterface.nstored(a::BlockSparseArrayLike)
return sum(nstored, sparse_storage(blocks(a)); init=zero(Int))
end

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using BlockArrays: Block, BlockSlices

function blocksparse_view(a, I...)
return Base.invoke(view, Tuple{AbstractArray,Vararg{Any}}, a, I...)
end

# These definitions circumvent some generic definitions in BlockArrays.jl:
# https://github.com/JuliaArrays/BlockArrays.jl/blob/master/src/views.jl
# which don't handle subslices of blocks properly.
function Base.view(
a::SubArray{<:Any,N,<:BlockSparseArrayLike,<:NTuple{N,BlockSlices}}, I::Block{N}
) where {N}
return blocksparse_view(a, I)
end
function Base.view(
a::SubArray{<:Any,N,<:BlockSparseArrayLike,<:NTuple{N,BlockSlices}}, I::Vararg{Block{1},N}
) where {N}
return blocksparse_view(a, I...)
end
function Base.view(
V::SubArray{<:Any,1,<:BlockSparseArrayLike,<:Tuple{BlockSlices}}, I::Block{1}
)
return blocksparse_view(a, I)
end
Loading

0 comments on commit d6e675d

Please sign in to comment.