Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Jan 16, 2025
1 parent d8a546d commit f523144
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 15 deletions.
6 changes: 6 additions & 0 deletions src/abstractblocksparsearray/abstractblocksparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ function Base.setindex!(a::AbstractBlockSparseArray{<:Any,0}, value)
return a
end

# Catch zero-dimensional case to avoid scalar indexing.
function Base.setindex!(a::AbstractBlockSparseArray{<:Any,0}, value, ::Block{0})
blocks(a)[] = value
return a
end

function Base.setindex!(
a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Block{1},N}
) where {N}
Expand Down
3 changes: 2 additions & 1 deletion src/abstractblocksparsearray/map.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using ArrayLayouts: LayoutArray
using BlockArrays: blockisequal
using DerivableInterfaces: @interface, AbstractArrayInterface, interface
using GPUArraysCore: @allowscalar
using LinearAlgebra: Adjoint, Transpose
using SparseArraysBase: SparseArraysBase, SparseArrayStyle

Expand Down Expand Up @@ -55,7 +56,7 @@ function map_zero_dim! end
@interface ::AbstractArrayInterface function map_zero_dim!(
f, a_dest::AbstractArray, a_srcs::AbstractArray...
)
a_dest[] = f.(map(a_src -> a_src[], a_srcs)...)
@allowscalar a_dest[] = f.(map(a_src -> a_src[], a_srcs)...)
return a_dest
end

Expand Down
5 changes: 2 additions & 3 deletions src/blocksparsearrayinterface/blocksparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ using BlockArrays:
blocks,
findblockindex
using DerivableInterfaces: DerivableInterfaces, @interface, DefaultArrayInterface
using GPUArraysCore: @allowscalar
using LinearAlgebra: Adjoint, Transpose
using SparseArraysBase:
AbstractSparseArrayInterface,
Expand Down Expand Up @@ -66,7 +65,7 @@ end
@interface ::AbstractBlockSparseArrayInterface function Base.getindex(
a::AbstractArray{<:Any,0}
)
return @allowscalar a[Block()[]]
return a[Block()[]]
end

# a[1:2, 1:2]
Expand Down Expand Up @@ -157,7 +156,7 @@ end
)
a_b = blocks(a)[]
# `value[]` handles scalars and 0-dimensional arrays.
@allowscalar a_b[] = value[]
a_b[] = value[]
# Set the block, required if it is structurally zero.
blocks(a)[] = a_b
return a
Expand Down
22 changes: 11 additions & 11 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,17 +183,17 @@ arrayts = (Array, JLArray)
@test blocksize(a) == ()
@test blocksizes(a) == fill(())
@test iszero(blockstoredlength(a))
@test iszero(a[])
@test iszero([CartesianIndex()])
@test iszero(@allowscalar(a[]))
@test iszero(@allowscalar(a[CartesianIndex()]))
@test a[Block()] == dev(fill(0))
@test iszero(a[Block()][])
@test iszero(a[Block()[]])
@test iszero(@allowscalar(a[Block()][]))
@test iszero(@allowscalar(a[Block()[]]))
@test Array(a) isa Array{elt,0}
@test Array(a) == fill(0)
for b in (
(b = copy(a); b[] = 2; b),
(b = copy(a); b[CartesianIndex()] = 2; b),
(b = copy(a); b[CartesianIndex()] = 2; b),
(b = copy(a); @allowscalar(b[] = 2); b),
(b = copy(a); @allowscalar(b[CartesianIndex()] = 2); b),
(b = copy(a); @allowscalar(b[Block()[]] = 2); b),
# Regression test for https://github.com/ITensor/BlockSparseArrays.jl/issues/27.
(b = copy(a); b[Block()] = dev(fill(2)); b),
)
Expand All @@ -202,11 +202,11 @@ arrayts = (Array, JLArray)
@test blocksize(b) == ()
@test blocksizes(b) == fill(())
@test isone(blockstoredlength(b))
@test b[] == 2
@test b[CartesianIndex()] == 2
@test @allowscalar(b[]) == 2
@test @allowscalar(b[CartesianIndex()]) == 2
@test b[Block()] == dev(fill(2))
@test b[Block()][] == 2
@test b[Block()[]] == 2
@test @allowscalar(b[Block()][]) == 2
@test @allowscalar(b[Block()[]]) == 2
@test Array(b) isa Array{elt,0}
@test Array(b) == fill(2)
end
Expand Down

0 comments on commit f523144

Please sign in to comment.