Skip to content

Commit

Permalink
[NDTensors] Catch broken case of DiagBlockSparse contraction (#1475)
Browse files Browse the repository at this point in the history
* [NDTensors] Catch broken case of DiagBlockSparse contraction

* [NDTensors] Bump to v0.3.14
  • Loading branch information
mtfishman authored May 30, 2024
1 parent d2c35ef commit 7a41171
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
5 changes: 5 additions & 0 deletions NDTensors/src/blocksparse/diagblocksparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,11 @@ function contract!(
labelsT2,
contraction_plan,
) where {ElR<:Number,NR}
if any(b -> !allequal(Tuple(b)), nzblocks(T2))
return error(
"When contracting a BlockSparse tensor with a DiagBlockSparse tensor, the DiagBlockSparse tensor must be block diagonal for the time being.",
)
end
already_written_to = Dict{Block{NR},Bool}()
indsR = inds(R)
indsT1 = inds(T1)
Expand Down
29 changes: 26 additions & 3 deletions NDTensors/test/test_diagblocksparse.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
@eval module $(gensym())
using Dictionaries: Dictionary
using NDTensors
using Test: @testset, @test

using NDTensors:
NDTensors,
Block,
BlockSparseTensor,
DiagBlockSparse,
Tensor,
blockoffsets,
contract,
nzblocks
using Random: randn!
using Test: @test, @test_throws, @testset
@testset "UniformDiagBlockSparseTensor basic functionality" begin
NeverAlias = NDTensors.NeverAlias
AllowAlias = NDTensors.AllowAlias
Expand All @@ -25,4 +33,19 @@ using Test: @testset, @test
@test conj(NeverAlias(), tensor)[1, 1] == conj(c)
@test conj(AllowAlias(), tensor)[1, 1] == conj(c)
end
@testset "DiagBlockSparse off-diagonal (eltype=$elt)" for elt in (
Float32, Float64, Complex{Float32}, Complex{Float64}
)
inds1 = ([1, 1], [1, 1])
inds2 = ([1, 1], [1, 1])
blocks = [(1, 2), (2, 1)]
a1 = BlockSparseTensor{elt}(blocks, inds1...)
for b in nzblocks(a1)
randn!(a1[b])
end
a2 = Tensor(DiagBlockSparse(one(elt), blockoffsets(a1)), inds2)
for (labels1, labels2) in (((1, -1), (-1, 2)), ((-1, -2), (-1, -2)))
@test_throws ErrorException contract(a1, labels1, a2, labels2)
end
end
end

0 comments on commit 7a41171

Please sign in to comment.