Skip to content

Commit

Permalink
Add tests for block sparse matrix multiplication
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed May 31, 2024
1 parent 80708b1 commit ae84ba0
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,24 @@ function Base.similar(
return similar(arraytype, eltype(arraytype), axes)
end

# Needed by `BlockArrays` matrix multiplication interface
# Fixes ambiguity error with `BlockArrays.jl`.
function Base.similar(
arraytype::Type{<:BlockSparseArrayLike},
axes::Tuple{BlockedUnitRange,Vararg{AbstractUnitRange{Int}}},
)
return similar(arraytype, eltype(arraytype), axes)
end

# Needed by `BlockArrays` matrix multiplication interface
# Fixes ambiguity error with `BlockArrays.jl`.
function Base.similar(
arraytype::Type{<:BlockSparseArrayLike},
axes::Tuple{AbstractUnitRange{Int},BlockedUnitRange,Vararg{AbstractUnitRange{Int}}},
)
return similar(arraytype, eltype(arraytype), axes)
end

# Needed for disambiguation
function Base.similar(
arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{BlockedUnitRange}}
Expand Down
12 changes: 12 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,18 @@ include("TestBlockSparseArraysUtils.jl")
@test a_dest isa BlockSparseArray{elt}
@test block_nstored(a_dest) == 1
end
@testset "Matrix multiplication" begin
a1 = BlockSparseArray{elt}([2, 3], [2, 3])
a1[Block(1, 2)] = randn(elt, size(@view(a1[Block(1, 2)])))
a1[Block(2, 1)] = randn(elt, size(@view(a1[Block(2, 1)])))
a2 = BlockSparseArray{elt}([2, 3], [2, 3])
a2[Block(1, 2)] = randn(elt, size(@view(a2[Block(1, 2)])))
a2[Block(2, 1)] = randn(elt, size(@view(a2[Block(2, 1)])))
@test Array(a1 * a2) Array(a1) * Array(a2)
@test Array(a1' * a2) Array(a1') * Array(a2)
@test Array(a1 * a2') Array(a1) * Array(a2')
@test Array(a1' * a2') Array(a1') * Array(a2')
end
@testset "TensorAlgebra" begin
a1 = BlockSparseArray{elt}([2, 3], [2, 3])
a1[Block(1, 1)] = randn(elt, size(@view(a1[Block(1, 1)])))
Expand Down

0 comments on commit ae84ba0

Please sign in to comment.