From ae84ba0a6fad8ef20915703913c7c34a1e0729ad Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 31 May 2024 12:12:37 -0400 Subject: [PATCH] Add tests for block sparse matrix multiplication --- .../wrappedabstractblocksparsearray.jl | 18 ++++++++++++++++++ .../lib/BlockSparseArrays/test/test_basics.jl | 12 ++++++++++++ 2 files changed, 30 insertions(+) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index 784b4e27c8..a98353e152 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -88,6 +88,24 @@ function Base.similar( return similar(arraytype, eltype(arraytype), axes) end +# Needed by `BlockArrays` matrix multiplication interface +# Fixes ambiguity error with `BlockArrays.jl`. +function Base.similar( + arraytype::Type{<:BlockSparseArrayLike}, + axes::Tuple{BlockedUnitRange,Vararg{AbstractUnitRange{Int}}}, +) + return similar(arraytype, eltype(arraytype), axes) +end + +# Needed by `BlockArrays` matrix multiplication interface +# Fixes ambiguity error with `BlockArrays.jl`. +function Base.similar( + arraytype::Type{<:BlockSparseArrayLike}, + axes::Tuple{AbstractUnitRange{Int},BlockedUnitRange,Vararg{AbstractUnitRange{Int}}}, +) + return similar(arraytype, eltype(arraytype), axes) +end + # Needed for disambiguation function Base.similar( arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{BlockedUnitRange}} diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index 62cc23d9db..565d349c33 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -266,6 +266,18 @@ include("TestBlockSparseArraysUtils.jl") @test a_dest isa BlockSparseArray{elt} @test block_nstored(a_dest) == 1 end + @testset "Matrix multiplication" begin + a1 = BlockSparseArray{elt}([2, 3], [2, 3]) + a1[Block(1, 2)] = randn(elt, size(@view(a1[Block(1, 2)]))) + a1[Block(2, 1)] = randn(elt, size(@view(a1[Block(2, 1)]))) + a2 = BlockSparseArray{elt}([2, 3], [2, 3]) + a2[Block(1, 2)] = randn(elt, size(@view(a2[Block(1, 2)]))) + a2[Block(2, 1)] = randn(elt, size(@view(a2[Block(2, 1)]))) + @test Array(a1 * a2) ≈ Array(a1) * Array(a2) + @test Array(a1' * a2) ≈ Array(a1') * Array(a2) + @test Array(a1 * a2') ≈ Array(a1) * Array(a2') + @test Array(a1' * a2') ≈ Array(a1') * Array(a2') + end @testset "TensorAlgebra" begin a1 = BlockSparseArray{elt}([2, 3], [2, 3]) a1[Block(1, 1)] = randn(elt, size(@view(a1[Block(1, 1)])))