From d70b89ed5967988a120a4fe207f3bd95d1e352b9 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Sun, 9 Jun 2024 10:35:41 -0400 Subject: [PATCH] [BlockSparseArrays] Fix some bugs involving BlockSparseArrays with dual axes (#1488) * [BlockSparseArrays] Fix some bugs involving BlockSparseArrays with dual axes * [NDTensors] Bump to v0.3.23 --- NDTensors/Project.toml | 2 +- .../test/runtests.jl | 55 ++++++++++++++++++- .../BlockArraysExtensions.jl | 2 +- .../src/lib/GradedAxes/src/unitrangedual.jl | 18 +++++- 4 files changed, 71 insertions(+), 6 deletions(-) diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index fe455e9ecd..17a14ceb90 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -1,7 +1,7 @@ name = "NDTensors" uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" authors = ["Matthew Fishman "] -version = "0.3.22" +version = "0.3.23" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl index bed17b7928..fe177167dd 100644 --- a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl @@ -1,7 +1,7 @@ @eval module $(gensym()) using Compat: Returns using Test: @test, @testset, @test_broken -using BlockArrays: Block, blocksize +using BlockArrays: Block, blockedrange, blocksize using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored using NDTensors.GradedAxes: GradedAxes, GradedUnitRange, UnitRangeDual, blocklabels, dual, gradedrange @@ -73,6 +73,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) # be the real test. for ax in axes(m) @test ax isa GradedUnitRange + # TODO: Current `fusedims` doesn't merge + # common sectors, need to fix. @test_broken blocklabels(ax) == [U1(0), U1(1), U1(2)] @test blocklabels(ax) == [U1(0), U1(1), U1(1), U1(2)] end @@ -94,8 +96,13 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "dual axes" begin r = gradedrange([U1(0) => 2, U1(1) => 2]) a = BlockSparseArray{elt}(dual(r), r) - a[Block(1, 1)] = randn(elt, size(a[Block(1, 1)])) - a[Block(2, 2)] = randn(elt, size(a[Block(2, 2)])) + @views for b in [Block(1, 1), Block(2, 2)] + a[b] = randn(elt, size(a[b])) + end + # TODO: Define and use `isdual` here. + @test axes(a, 1) isa UnitRangeDual + @test axes(a, 2) isa GradedUnitRange + @test !(axes(a, 2) isa UnitRangeDual) a_dense = Array(a) @test eachindex(a) == CartesianIndices(size(a)) for I in eachindex(a) @@ -104,8 +111,50 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test axes(a') == dual.(reverse(axes(a))) # TODO: Define and use `isdual` here. @test axes(a', 1) isa UnitRangeDual + @test axes(a', 2) isa GradedUnitRange @test !(axes(a', 2) isa UnitRangeDual) @test isnothing(show(devnull, MIME("text/plain"), a)) + + # Check preserving dual in tensor algebra. + for b in (a + a, 2 * a, 3 * a - a) + @test Array(b) ≈ 2 * Array(a) + # TODO: Define and use `isdual` here. + @test axes(b, 1) isa UnitRangeDual + @test axes(b, 2) isa GradedUnitRange + @test !(axes(b, 2) isa UnitRangeDual) + end + + @test isnothing(show(devnull, MIME("text/plain"), @view(a[Block(1, 1)]))) + @test @view(a[Block(1, 1)]) == a[Block(1, 1)] + + # Test case when all axes are dual. + for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2])) + a = BlockSparseArray{elt}(dual(r), dual(r)) + @views for i in [Block(1, 1), Block(2, 2)] + a[i] = randn(elt, size(a[i])) + end + b = 2 * a + @test block_nstored(b) == 2 + @test Array(b) == 2 * Array(a) + for ax in axes(b) + @test ax isa UnitRangeDual + end + end + + # Test case when all axes are dual + # from taking the adjoint. + for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2])) + a = BlockSparseArray{elt}(r, r) + @views for i in [Block(1, 1), Block(2, 2)] + a[i] = randn(elt, size(a[i])) + end + b = 2 * a' + @test block_nstored(b) == 2 + @test Array(b) == 2 * Array(a)' + for ax in axes(b) + @test ax isa UnitRangeDual + end + end end @testset "Matrix multiplication" begin r = gradedrange([U1(0) => 2, U1(1) => 3]) diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 31fb71e104..43b8afab24 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -129,7 +129,7 @@ function cartesianindices(axes::Tuple, b::Block) end # Get the range within a block. -function blockindexrange(axis::AbstractUnitRange, r::UnitRange) +function blockindexrange(axis::AbstractUnitRange, r::AbstractUnitRange) bi1 = findblockindex(axis, first(r)) bi2 = findblockindex(axis, last(r)) b = block(bi1) diff --git a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl index 8f1a86f1b8..358542856c 100644 --- a/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/unitrangedual.jl @@ -68,8 +68,24 @@ end using NDTensors.LabelledNumbers: LabelledNumbers, label LabelledNumbers.label(a::UnitRangeDual) = dual(label(nondual(a))) -using BlockArrays: BlockArrays, blockaxes, blocklasts, findblock +using BlockArrays: BlockArrays, blockaxes, blocklasts, combine_blockaxes, findblock BlockArrays.blockaxes(a::UnitRangeDual) = blockaxes(nondual(a)) BlockArrays.blockfirsts(a::UnitRangeDual) = label_dual.(blockfirsts(nondual(a))) BlockArrays.blocklasts(a::UnitRangeDual) = label_dual.(blocklasts(nondual(a))) BlockArrays.findblock(a::UnitRangeDual, index::Integer) = findblock(nondual(a), index) +function BlockArrays.combine_blockaxes(a1::UnitRangeDual, a2::UnitRangeDual) + return dual(combine_blockaxes(dual(a1), dual(a2))) +end + +# This is needed when constructing `CartesianIndices` from +# a tuple of unit ranges that have this kind of dual unit range. +# TODO: See if we can find some more elegant way of constructing +# `CartesianIndices`, maybe by defining conversion of `LabelledInteger` +# to `Int`, defining a more general `convert` function, etc. +function Base.OrdinalRange{Int,Int}( + r::UnitRangeDual{<:LabelledInteger{Int},<:LabelledUnitRange{Int,UnitRange{Int}}} +) + # TODO: Implement this broadcasting operation and use it here. + # return Int.(r) + return unlabel(nondual(r)) +end