Skip to content

Commit

Permalink
Merge branch 'main' into tupletools_fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
emstoudenmire committed Jun 11, 2024
2 parents 47e8d46 + de78e9f commit e4fef41
Show file tree
Hide file tree
Showing 24 changed files with 955 additions and 208 deletions.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.18"
version = "0.3.25"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
20 changes: 14 additions & 6 deletions NDTensors/ext/NDTensorscuTENSORExt/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,26 @@ using NDTensors.Expose: Exposed, unexpose
using cuTENSOR: cuTENSOR, CuArray, CuTensor

function NDTensors.contract!(
R::Exposed{<:CuArray,<:DenseTensor},
exposedR::Exposed{<:CuArray,<:DenseTensor},
labelsR,
T1::Exposed{<:CuArray,<:DenseTensor},
exposedT1::Exposed{<:CuArray,<:DenseTensor},
labelsT1,
T2::Exposed{<:CuArray,<:DenseTensor},
exposedT2::Exposed{<:CuArray,<:DenseTensor},
labelsT2,
α::Number=one(Bool),
β::Number=zero(Bool),
)
cuR = CuTensor(array(unexpose(R)), collect(labelsR))
cuT1 = CuTensor(array(unexpose(T1)), collect(labelsT1))
cuT2 = CuTensor(array(unexpose(T2)), collect(labelsT2))
R, T1, T2 = unexpose.((exposedR, exposedT1, exposedT2))
zoffR = iszero(array(R).offset)
arrayR = zoffR ? array(R) : copy(array(R))
arrayT1 = iszero(array(T1).offset) ? array(T1) : copy(array(T1))
arrayT2 = iszero(array(T2).offset) ? array(T2) : copy(array(T2))
cuR = CuTensor(arrayR, collect(labelsR))
cuT1 = CuTensor(arrayT1, collect(labelsT1))
cuT2 = CuTensor(arrayT2, collect(labelsT2))
cuTENSOR.mul!(cuR, cuT1, cuT2, α, β)
if !zoffR
array(R) .= cuR.data
end
return R
end
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
@eval module $(gensym())
using Compat: Returns
using Test: @test, @testset, @test_broken
using BlockArrays: Block, blocksize
using BlockArrays: Block, blockedrange, blocksize
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
using NDTensors.GradedAxes: GradedAxes, GradedUnitRange, UnitRangeDual, dual, gradedrange
using NDTensors.GradedAxes:
GradedAxes, GradedUnitRange, UnitRangeDual, blocklabels, dual, gradedrange
using NDTensors.LabelledNumbers: label
using NDTensors.SparseArrayInterface: nstored
using NDTensors.TensorAlgebra: fusedims, splitdims
Expand Down Expand Up @@ -68,30 +69,103 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a = BlockSparseArray{elt}(d1, d2, d1, d2)
blockdiagonal!(randn!, a)
m = fusedims(a, (1, 2), (3, 4))
@test axes(m, 1) isa GradedUnitRange
@test axes(m, 2) isa GradedUnitRange
# TODO: Once block merging is implemented, this should
# be the real test.
for ax in axes(m)
@test ax isa GradedUnitRange
# TODO: Current `fusedims` doesn't merge
# common sectors, need to fix.
@test_broken blocklabels(ax) == [U1(0), U1(1), U1(2)]
@test blocklabels(ax) == [U1(0), U1(1), U1(1), U1(2)]
end
for I in CartesianIndices(m)
if I CartesianIndex.([(1, 1), (4, 4)])
@test !iszero(m[I])
else
@test iszero(m[I])
end
end
@test a[1, 1, 1, 1] == m[1, 1]
@test a[2, 2, 2, 2] == m[4, 4]
# TODO: Current `fusedims` doesn't merge
# common sectors, need to fix.
@test_broken blocksize(m) == (3, 3)
@test blocksize(m) == (4, 4)
@test a == splitdims(m, (d1, d2), (d1, d2))
end
@testset "dual axes" begin
r = gradedrange([U1(0) => 2, U1(1) => 2])
a = BlockSparseArray{elt}(dual(r), r)
a[Block(1, 1)] = randn(size(a[Block(1, 1)]))
a[Block(2, 2)] = randn(size(a[Block(2, 2)]))
a_dense = Array(a)
@test eachindex(a) == CartesianIndices(size(a))
for I in eachindex(a)
@test a[I] == a_dense[I]
for ax in ((r, r), (dual(r), r), (r, dual(r)), (dual(r), dual(r)))
a = BlockSparseArray{elt}(ax...)
@views for b in [Block(1, 1), Block(2, 2)]
a[b] = randn(elt, size(a[b]))
end
# TODO: Define and use `isdual` here.
for dim in 1:ndims(a)
@test typeof(ax[dim]) === typeof(axes(a, dim))
end
@test @view(a[Block(1, 1)])[1, 1] == a[1, 1]
@test @view(a[Block(1, 1)])[2, 1] == a[2, 1]
@test @view(a[Block(1, 1)])[1, 2] == a[1, 2]
@test @view(a[Block(1, 1)])[2, 2] == a[2, 2]
@test @view(a[Block(2, 2)])[1, 1] == a[3, 3]
@test @view(a[Block(2, 2)])[2, 1] == a[4, 3]
@test @view(a[Block(2, 2)])[1, 2] == a[3, 4]
@test @view(a[Block(2, 2)])[2, 2] == a[4, 4]
@test @view(a[Block(1, 1)])[1:2, 1:2] == a[1:2, 1:2]
@test @view(a[Block(2, 2)])[1:2, 1:2] == a[3:4, 3:4]
a_dense = Array(a)
@test eachindex(a) == CartesianIndices(size(a))
for I in eachindex(a)
@test a[I] == a_dense[I]
end
@test axes(a') == dual.(reverse(axes(a)))
# TODO: Define and use `isdual` here.
@test typeof(axes(a', 1)) === typeof(dual(axes(a, 2)))
@test typeof(axes(a', 2)) === typeof(dual(axes(a, 1)))
@test isnothing(show(devnull, MIME("text/plain"), a))

# Check preserving dual in tensor algebra.
for b in (a + a, 2 * a, 3 * a - a)
@test Array(b) 2 * Array(a)
# TODO: Define and use `isdual` here.
for dim in 1:ndims(a)
@test typeof(axes(b, dim)) === typeof(axes(b, dim))
end
end

@test isnothing(show(devnull, MIME("text/plain"), @view(a[Block(1, 1)])))
@test @view(a[Block(1, 1)]) == a[Block(1, 1)]
end

# Test case when all axes are dual.
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
a = BlockSparseArray{elt}(dual(r), dual(r))
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
end
b = 2 * a
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)
for ax in axes(b)
@test ax isa UnitRangeDual
end
end

# Test case when all axes are dual
# from taking the adjoint.
for r in (gradedrange([U1(0) => 2, U1(1) => 2]), blockedrange([2, 2]))
a = BlockSparseArray{elt}(r, r)
@views for i in [Block(1, 1), Block(2, 2)]
a[i] = randn(elt, size(a[i]))
end
b = 2 * a'
@test block_nstored(b) == 2
@test Array(b) == 2 * Array(a)'
for ax in axes(b)
@test ax isa UnitRangeDual
end
end
@test axes(a') == dual.(reverse(axes(a)))
# TODO: Define and use `isdual` here.
@test axes(a', 1) isa UnitRangeDual
@test !(axes(a', 2) isa UnitRangeDual)
@test isnothing(show(devnull, MIME("text/plain"), a))
end
@testset "Matrix multiplication" begin
r = gradedrange([U1(0) => 2, U1(1) => 3])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using BlockArrays:
BlockRange,
BlockedUnitRange,
BlockVector,
BlockSlice,
block,
blockaxes,
blockedrange,
Expand All @@ -18,6 +19,30 @@ using Dictionaries: Dictionary, Indices
using ..GradedAxes: blockedunitrange_getindices
using ..SparseArrayInterface: stored_indices

# GenericBlockSlice works around an issue that the indices of BlockSlice
# are restricted to Int element type.
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
struct GenericBlockSlice{B,T<:Integer,I<:AbstractUnitRange{T}} <: AbstractUnitRange{T}
block::B
indices::I
end
BlockArrays.Block(bs::GenericBlockSlice{<:Block}) = bs.block
for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe_length)
@eval Base.$f(S::GenericBlockSlice) = Base.$f(S.indices)
end
Base.getindex(S::GenericBlockSlice, i::Integer) = getindex(S.indices, i)

# BlockIndices works around an issue that the indices of BlockSlice
# are restricted to AbstractUnitRange{Int}.
struct BlockIndices{B,T<:Integer,I<:AbstractVector{T}} <: AbstractVector{T}
blocks::B
indices::I
end
for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe_length)
@eval Base.$f(S::BlockIndices) = Base.$f(S.indices)
end
Base.getindex(S::BlockIndices, i::Integer) = getindex(S.indices, i)

# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices)
return error("Not implemented")
Expand All @@ -29,6 +54,40 @@ function sub_axis(a::AbstractUnitRange, indices::AbstractUnitRange)
return only(axes(blockedunitrange_getindices(a, indices)))
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:BlockRange{1}})
return sub_axis(a, indices.block)
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:Block{1}})
return sub_axis(a, Block(indices))
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::BlockSlice{<:BlockIndexRange{1}})
return sub_axis(a, indices.block)
end

function sub_axis(a::AbstractUnitRange, indices::BlockIndices)
return sub_axis(a, indices.blocks)
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::Block)
return only(axes(blockedunitrange_getindices(a, indices)))
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::BlockIndexRange)
return only(axes(blockedunitrange_getindices(a, indices)))
end

# TODO: Use `GradedAxes.blockedunitrange_getindices`.
# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices::AbstractVector{<:Block})
Expand Down Expand Up @@ -85,7 +144,7 @@ function cartesianindices(axes::Tuple, b::Block)
end

# Get the range within a block.
function blockindexrange(axis::AbstractUnitRange, r::UnitRange)
function blockindexrange(axis::AbstractUnitRange, r::AbstractUnitRange)
bi1 = findblockindex(axis, first(r))
bi2 = findblockindex(axis, last(r))
b = block(bi1)
Expand Down Expand Up @@ -115,8 +174,8 @@ function blockrange(axis::AbstractUnitRange, r::UnitRange)
end

function blockrange(axis::AbstractUnitRange, r::Int)
error("Slicing with integer values isn't supported.")
return findblock(axis, r)
## return findblock(axis, r)
return error("Slicing with integer values isn't supported.")
end

function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
Expand All @@ -131,10 +190,100 @@ function blockrange(axis::AbstractUnitRange, r::BlockSlice)
return blockrange(axis, r.block)
end

# GenericBlockSlice works around an issue that the indices of BlockSlice
# are restricted to Int element type.
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
function blockrange(axis::AbstractUnitRange, r::GenericBlockSlice)
return blockrange(axis, r.block)
end

function blockrange(a::AbstractUnitRange, r::BlockIndices)
return blockrange(a, r.blocks)
end

function blockrange(axis::AbstractUnitRange, r::Block{1})
return r:r
end

function blockrange(axis::AbstractUnitRange, r::BlockIndexRange)
return Block(r):Block(r)
end

function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:BlockIndexRange{1}})
return error("Slicing not implemented for range of type `$(typeof(r))`.")
end

function blockrange(
axis::AbstractUnitRange,
r::BlockVector{BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
)
return map(b -> Block(b), blocks(r))
end

# This handles slicing with `:`/`Colon()`.
function blockrange(axis::AbstractUnitRange, r::Base.Slice)
# TODO: Maybe use `BlockRange`, but that doesn't output
# the same thing.
return only(blockaxes(axis))
end

function blockrange(axis::AbstractUnitRange, r)
return error("Slicing not implemented for range of type `$(typeof(r))`.")
end

# This takes a range of indices `indices` of array `a`
# and maps it to the range of indices within block `block`.
function blockindices(a::AbstractArray, block::Block, indices::Tuple)
return blockindices(axes(a), block, indices)
end

function blockindices(axes::Tuple, block::Block, indices::Tuple)
return blockindices.(axes, Tuple(block), indices)
end

function blockindices(axis::AbstractUnitRange, block::Block, indices::AbstractUnitRange)
indices_within_block = intersect(indices, axis[block])
if iszero(length(indices_within_block))
# Falls outside of block
return 1:0
end
return only(blockindexrange(axis, indices_within_block).indices)
end

# This catches the case of `Vector{<:Block{1}}`.
# `BlockRange` gets wrapped in a `BlockSlice`, which is handled properly
# by the version with `indices::AbstractUnitRange`.
# TODO: This should get fixed in a better way inside of `BlockArrays`.
function blockindices(
axis::AbstractUnitRange, block::Block, indices::AbstractVector{<:Block{1}}
)
if block indices
# Falls outside of block
return 1:0
end
return Base.OneTo(length(axis[block]))
end

function blockindices(a::AbstractUnitRange, b::Block, r::BlockIndices)
return blockindices(a, b, r.blocks)
end

function blockindices(
a::AbstractUnitRange,
b::Block,
r::BlockVector{BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
)
# TODO: Change to iterate over `BlockRange(r)`
# once https://github.com/JuliaArrays/BlockArrays.jl/issues/404
# is fixed.
for bl in blocks(r)
if b == Block(bl)
return only(bl.indices)
end
end
return error("Block not found.")
end

function cartesianindices(a::AbstractArray, b::Block)
return cartesianindices(axes(a), b)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,18 @@ end

# Materialize a SubArray view.
function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes)
# TODO: Make more generic for GPU.
a_dest = BlockSparseArray{eltype(a)}(axes)
a_dest .= a
return a_dest
end

# Materialize a SubArray view.
function ArrayLayouts.sub_materialize(
layout::BlockLayout{<:SparseLayout}, a, axes::Tuple{Vararg{Base.OneTo}}
)
# TODO: Make more generic for GPU.
a_dest = Array{eltype(a)}(undef, length.(axes))
a_dest .= a
return a_dest
end
Loading

0 comments on commit e4fef41

Please sign in to comment.