From 7a41171ff3399887db2249e9c00eba957b46c5d5 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Wed, 29 May 2024 22:42:14 -0400 Subject: [PATCH] [NDTensors] Catch broken case of DiagBlockSparse contraction (#1475) * [NDTensors] Catch broken case of DiagBlockSparse contraction * [NDTensors] Bump to v0.3.14 --- NDTensors/src/blocksparse/diagblocksparse.jl | 5 ++++ NDTensors/test/test_diagblocksparse.jl | 29 ++++++++++++++++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/NDTensors/src/blocksparse/diagblocksparse.jl b/NDTensors/src/blocksparse/diagblocksparse.jl index 76cb29e566..a68b7a769c 100644 --- a/NDTensors/src/blocksparse/diagblocksparse.jl +++ b/NDTensors/src/blocksparse/diagblocksparse.jl @@ -639,6 +639,11 @@ function contract!( labelsT2, contraction_plan, ) where {ElR<:Number,NR} + if any(b -> !allequal(Tuple(b)), nzblocks(T2)) + return error( + "When contracting a BlockSparse tensor with a DiagBlockSparse tensor, the DiagBlockSparse tensor must be block diagonal for the time being.", + ) + end already_written_to = Dict{Block{NR},Bool}() indsR = inds(R) indsT1 = inds(T1) diff --git a/NDTensors/test/test_diagblocksparse.jl b/NDTensors/test/test_diagblocksparse.jl index f47f4b5da0..2aa0e98c24 100644 --- a/NDTensors/test/test_diagblocksparse.jl +++ b/NDTensors/test/test_diagblocksparse.jl @@ -1,8 +1,16 @@ @eval module $(gensym()) using Dictionaries: Dictionary -using NDTensors -using Test: @testset, @test - +using NDTensors: + NDTensors, + Block, + BlockSparseTensor, + DiagBlockSparse, + Tensor, + blockoffsets, + contract, + nzblocks +using Random: randn! +using Test: @test, @test_throws, @testset @testset "UniformDiagBlockSparseTensor basic functionality" begin NeverAlias = NDTensors.NeverAlias AllowAlias = NDTensors.AllowAlias @@ -25,4 +33,19 @@ using Test: @testset, @test @test conj(NeverAlias(), tensor)[1, 1] == conj(c) @test conj(AllowAlias(), tensor)[1, 1] == conj(c) end +@testset "DiagBlockSparse off-diagonal (eltype=$elt)" for elt in ( + Float32, Float64, Complex{Float32}, Complex{Float64} +) + inds1 = ([1, 1], [1, 1]) + inds2 = ([1, 1], [1, 1]) + blocks = [(1, 2), (2, 1)] + a1 = BlockSparseTensor{elt}(blocks, inds1...) + for b in nzblocks(a1) + randn!(a1[b]) + end + a2 = Tensor(DiagBlockSparse(one(elt), blockoffsets(a1)), inds2) + for (labels1, labels2) in (((1, -1), (-1, 2)), ((-1, -2), (-1, -2))) + @test_throws ErrorException contract(a1, labels1, a2, labels2) + end +end end