[NDTensors] cuTENSOR extension (#1395)
kmp5VT authored May 6, 2024
1 parent 0e5d3f8 commit 89a428a
Showing 15 changed files with 195 additions and 107 deletions.
6 changes: 4 additions & 2 deletions NDTensors/Project.toml
@@ -32,20 +32,22 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
TBLIS = "48530278-0828-4a49-9772-0f3830dfa1e9"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"

[extensions]
NDTensorsAMDGPUExt = "AMDGPU"
NDTensorsCUDAExt = "CUDA"
NDTensorscuTENSORExt = "cuTENSOR"
NDTensorsHDF5Ext = "HDF5"
NDTensorsMetalExt = "Metal"
NDTensorsOctavianExt = "Octavian"
NDTensorsTBLISExt = "TBLIS"
NDTensorsAMDGPUExt = "AMDGPU"

[compat]
Accessors = "0.1.33"
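The [weakdeps]/[extensions] entries above mean the cuTENSOR glue code loads only when cuTENSOR is present in the active environment. A minimal sketch of checking that via Julia's (>= 1.9) package-extension mechanism; the check itself is illustrative and not part of this commit:

using NDTensors           # loads without any GPU code
using CUDA, cuTENSOR      # triggers loading of NDTensorscuTENSORExt

ext = Base.get_extension(NDTensors, :NDTensorscuTENSORExt)
@assert !isnothing(ext)   # the extension module is now active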
3 changes: 3 additions & 0 deletions NDTensors/ext/NDTensorscuTENSORExt/NDTensorscuTENSORExt.jl
@@ -0,0 +1,3 @@
module NDTensorscuTENSORExt
include("contract.jl")
end
20 changes: 20 additions & 0 deletions NDTensors/ext/NDTensorscuTENSORExt/contract.jl
@@ -0,0 +1,20 @@
using NDTensors: NDTensors, DenseTensor, array
using NDTensors.Expose: Exposed, unexpose
using cuTENSOR: cuTENSOR, CuArray, CuTensor

function NDTensors.contract!(
R::Exposed{<:CuArray,<:DenseTensor},
labelsR,
T1::Exposed{<:CuArray,<:DenseTensor},
labelsT1,
T2::Exposed{<:CuArray,<:DenseTensor},
labelsT2,
α::Number=one(Bool),
β::Number=zero(Bool),
)
cuR = CuTensor(array(unexpose(R)), collect(labelsR))
cuT1 = CuTensor(array(unexpose(T1)), collect(labelsT1))
cuT2 = CuTensor(array(unexpose(T2)), collect(labelsT2))
cuTENSOR.mul!(cuR, cuT1, cuT2, α, β)
return R
end
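A hedged usage sketch of the cuTENSOR API the method above wraps (assumes a functional CUDA device; the shapes and mode labels are illustrative):

using CUDA, cuTENSOR
using cuTENSOR: CuTensor

# Contract over the shared label 'j': c[i,k] = sum_j a[i,j] * b[j,k]
a = CuTensor(CUDA.rand(Float64, 2, 3), ['i', 'j'])
b = CuTensor(CUDA.rand(Float64, 3, 4), ['j', 'k'])
c = CuTensor(CUDA.zeros(Float64, 2, 4), ['i', 'k'])
cuTENSOR.mul!(c, a, b, true, false)  # c = 1 * (a * b) + 0 * c, matching α and β above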
7 changes: 4 additions & 3 deletions NDTensors/src/blocksparse/contract_generic.jl
@@ -104,6 +104,7 @@ function contract!(
return R
end

using NDTensors.Expose: expose
# Function barrier to improve type stability,
# since `Folds`/`FLoops` is not type stable:
# https://discourse.julialang.org/t/type-instability-in-floop-reduction/68598
@@ -139,11 +140,11 @@ function _contract!(
)

contract!(
R[blockR],
expose(R[blockR]),
labelsR,
tensor1[blocktensor1],
expose(tensor1[blocktensor1]),
labelstensor1,
tensor2[blocktensor2],
expose(tensor2[blocktensor2]),
labelstensor2,
α,
β,
6 changes: 4 additions & 2 deletions NDTensors/src/blocksparse/contract_sequential.jl
@@ -55,7 +55,7 @@ function contract!(
R, labelsR, tensor1, labelstensor1, tensor2, labelstensor2, contraction_plan, executor
)
end

using NDTensors.Expose: expose
###########################################################################
# Old version
# TODO: DELETE, keeping around for now for testing/benchmarking.
@@ -97,7 +97,9 @@ function contract!(
# Overwrite the block of R
β = zero(ElR)
end
contract!(Rblock, labelsR, T1block, labelsT1, T2block, labelsT2, α, β)
contract!(
expose(Rblock), labelsR, expose(T1block), labelsT1, expose(T2block), labelsT2, α, β
)
end
return R
end
12 changes: 11 additions & 1 deletion NDTensors/src/blocksparse/contract_threads.jl
@@ -1,3 +1,4 @@
using NDTensors.Expose: expose
# TODO: This seems to be faster than the newer version using `Folds.jl`
# in `contract_folds.jl`, investigate why.
function contract_blocks!(
@@ -115,7 +116,16 @@ function contract!(
ElR, labelsR, blockR, indsR, labelsT1, blockT1, indsT1, labelsT2, blockT2, indsT2
)

contract!(blockR, labelsR, blockT1, labelsT1, blockT2, labelsT2, α, β)
contract!(
expose(blockR),
labelsR,
expose(blockT1),
labelsT1,
expose(blockT2),
labelsT2,
α,
β,
)
# Now keep adding to the block, since it has
# been written to
# R .= α .* (T1 * T2) .+ R
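The comment above records the accumulation step R .= α .* (T1 * T2) .+ R used once a block has already been written. A scalar sanity check of the general convention R = α * (T1 * T2) + β * R assumed throughout contract! (notation mine, analogous to BLAS gemm!):

R = [1.0;;]    # 1x1 matrices keep the example trivial
T1 = [2.0;;]
T2 = [3.0;;]
α, β = 0.5, 2.0
R .= α .* (T1 * T2) .+ β .* R
@assert R == [5.0;;]   # 0.5 * 6 + 2.0 * 1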
22 changes: 22 additions & 0 deletions NDTensors/src/diag/tensoralgebra/contract.jl
@@ -58,6 +58,28 @@ function _contract!!(
return R
end

function contract!(
output_tensor::Exposed{<:AbstractArray,<:DiagTensor},
labelsoutput_tensor,
tensor1::Exposed,
labelstensor1,
tensor2::Exposed,
labelstensor2,
α::Number=one(Bool),
β::Number=zero(Bool),
)
@assert isone(α)
@assert iszero(β)
return contract!(
unexpose(output_tensor),
labelsoutput_tensor,
unexpose(tensor1),
labelstensor1,
unexpose(tensor2),
labelstensor2,
)
end

function contract!(
R::DiagTensor{ElR,NR},
labelsR,
3 changes: 3 additions & 0 deletions NDTensors/src/lib/TensorAlgebra/test/Project.toml
@@ -4,3 +4,6 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"

[compat]
TensorOperations = "4.1.1"
113 changes: 59 additions & 54 deletions NDTensors/src/lib/TensorAlgebra/test/test_basics.jl
@@ -8,7 +8,6 @@ using NDTensors.TensorAlgebra:
using NDTensors: NDTensors
include(joinpath(pkgdir(NDTensors), "test", "NDTensorsTestUtils", "NDTensorsTestUtils.jl"))
using .NDTensorsTestUtils: default_rtol
using TensorOperations: TensorOperations
using Test: @test, @test_broken, @testset
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "BlockedPermutation" begin
@@ -111,62 +110,67 @@ end
@test eltype(a_split) === elt
@test a_split ≈ reshape(a, (2, 3, 20))
end
@testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts
dims = (2, 3, 4, 5, 6, 7, 8, 9, 10)
labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i)
for (d1s, d2s, d_dests) in (
((1, 2), (1, 2), ()),
((1, 2), (2, 1), ()),
((1, 2), (2, 3), (1, 3)),
((1, 2), (2, 3), (3, 1)),
((2, 1), (2, 3), (3, 1)),
((1, 2, 3), (2, 3, 4), (1, 4)),
((1, 2, 3), (2, 3, 4), (4, 1)),
((3, 2, 1), (4, 2, 3), (4, 1)),
((1, 2, 3), (3, 4), (1, 2, 4)),
((1, 2, 3), (3, 4), (4, 1, 2)),
((1, 2, 3), (3, 4), (2, 4, 1)),
((3, 1, 2), (3, 4), (2, 4, 1)),
((3, 2, 1), (4, 3), (2, 4, 1)),
((1, 2, 3, 4, 5, 6), (4, 5, 6, 7, 8, 9), (1, 2, 3, 7, 8, 9)),
((2, 4, 5, 1, 6, 3), (6, 4, 9, 8, 5, 7), (1, 7, 2, 8, 3, 9)),
)
a1 = randn(elt1, map(i -> dims[i], d1s))
labels1 = map(i -> labels[i], d1s)
a2 = randn(elt2, map(i -> dims[i], d2s))
labels2 = map(i -> labels[i], d2s)
labels_dest = map(i -> labels[i], d_dests)

# Don't specify destination labels
a_dest, labels_dest′ = TensorAlgebra.contract(a1, labels1, a2, labels2)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest′, a1, labels1, a2, labels2
## Right now the TensorOperations version is downgraded to `v0.7` when using cuTENSOR;
## we are waiting for TensorOperations to support the breaking changes in cuTENSOR 2.x.
if !("cutensor" ∈ ARGS)
using TensorOperations: TensorOperations
@testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts
dims = (2, 3, 4, 5, 6, 7, 8, 9, 10)
labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i)
for (d1s, d2s, d_dests) in (
((1, 2), (1, 2), ()),
((1, 2), (2, 1), ()),
((1, 2), (2, 3), (1, 3)),
((1, 2), (2, 3), (3, 1)),
((2, 1), (2, 3), (3, 1)),
((1, 2, 3), (2, 3, 4), (1, 4)),
((1, 2, 3), (2, 3, 4), (4, 1)),
((3, 2, 1), (4, 2, 3), (4, 1)),
((1, 2, 3), (3, 4), (1, 2, 4)),
((1, 2, 3), (3, 4), (4, 1, 2)),
((1, 2, 3), (3, 4), (2, 4, 1)),
((3, 1, 2), (3, 4), (2, 4, 1)),
((3, 2, 1), (4, 3), (2, 4, 1)),
((1, 2, 3, 4, 5, 6), (4, 5, 6, 7, 8, 9), (1, 2, 3, 7, 8, 9)),
((2, 4, 5, 1, 6, 3), (6, 4, 9, 8, 5, 7), (1, 7, 2, 8, 3, 9)),
)
@test a_dest ≈ a_dest_tensoroperations
a1 = randn(elt1, map(i -> dims[i], d1s))
labels1 = map(i -> labels[i], d1s)
a2 = randn(elt2, map(i -> dims[i], d2s))
labels2 = map(i -> labels[i], d2s)
labels_dest = map(i -> labels[i], d_dests)

# Specify destination labels
a_dest = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest, a1, labels1, a2, labels2
)
@test a_dest ≈ a_dest_tensoroperations
# Don't specify destination labels
a_dest, labels_dest′ = TensorAlgebra.contract(a1, labels1, a2, labels2)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest′, a1, labels1, a2, labels2
)
@test a_dest ≈ a_dest_tensoroperations

# Specify α and β
elt_dest = promote_type(elt1, elt2)
# TODO: Using random `α`, `β` causes
# random test failures; investigate why.
α = elt_dest(1.2) # randn(elt_dest)
β = elt_dest(2.4) # randn(elt_dest)
a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests))
a_dest = copy(a_dest_init)
TensorAlgebra.contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest, a1, labels1, a2, labels2
)
## Here we loosened the tolerance because of some floating-point roundoff issues
## with Float32 numbers.
@test a_dest ≈ α * a_dest_tensoroperations + β * a_dest_init rtol =
10 * default_rtol(elt_dest)
# Specify destination labels
a_dest = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest, a1, labels1, a2, labels2
)
@test a_dest ≈ a_dest_tensoroperations

# Specify α and β
elt_dest = promote_type(elt1, elt2)
# TODO: Using random `α`, `β` causes
# random test failures; investigate why.
α = elt_dest(1.2) # randn(elt_dest)
β = elt_dest(2.4) # randn(elt_dest)
a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests))
a_dest = copy(a_dest_init)
TensorAlgebra.contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
a_dest_tensoroperations = TensorOperations.tensorcontract(
labels_dest, a1, labels1, a2, labels2
)
## Here we loosened the tolerance because of some floating-point roundoff issues
## with Float32 numbers.
@test a_dest ≈ α * a_dest_tensoroperations + β * a_dest_init rtol =
10 * default_rtol(elt_dest)
end
end
end
@testset "qr (eltype=$elt)" for elt in elts
@@ -182,4 +186,5 @@
@test a ≈ a′
end
end

end
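As a standalone reference for the test pattern above, a minimal check (labels chosen here for illustration) that TensorAlgebra.contract agrees with TensorOperations.tensorcontract on a plain matrix product:

using NDTensors.TensorAlgebra: TensorAlgebra
using TensorOperations: TensorOperations

a1 = randn(2, 3)
a2 = randn(3, 4)
a12 = TensorAlgebra.contract((:i, :k), a1, (:i, :j), a2, (:j, :k))
a12_reference = TensorOperations.tensorcontract((:i, :k), a1, (:i, :j), a2, (:j, :k))
@assert a12 ≈ a12_reference
@assert a12 ≈ a1 * a2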
36 changes: 32 additions & 4 deletions NDTensors/src/tensoroperations/generic_tensor_operations.jl
@@ -116,6 +116,7 @@ function contract(
return output_tensor
end

using NDTensors.Expose: Exposed, expose, unexpose
# Overload this function for immutable storage types
function _contract!!(
output_tensor::Tensor,
@@ -129,23 +130,50 @@ function _contract!!(
)
if α ≠ 1 || β ≠ 0
contract!(
output_tensor,
expose(output_tensor),
labelsoutput_tensor,
tensor1,
expose(tensor1),
labelstensor1,
tensor2,
expose(tensor2),
labelstensor2,
α,
β,
)
else
contract!(
output_tensor, labelsoutput_tensor, tensor1, labelstensor1, tensor2, labelstensor2
expose(output_tensor),
labelsoutput_tensor,
expose(tensor1),
labelstensor1,
expose(tensor2),
labelstensor2,
)
end
return output_tensor
end

function contract!(
output_tensor::Exposed,
labelsoutput_tensor,
tensor1::Exposed,
labelstensor1,
tensor2::Exposed,
labelstensor2,
α::Number=one(Bool),
β::Number=zero(Bool),
)
return contract!(
unexpose(output_tensor),
labelsoutput_tensor,
unexpose(tensor1),
labelstensor1,
unexpose(tensor2),
labelstensor2,
α,
β,
)
end

# Is this generic for all storage types?
function contract!!(
output_tensor::Tensor,
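A rough sketch of the dispatch idea behind NDTensors.Expose, assuming Exposed carries the leaf storage type as its first type parameter (this mirrors the Exposed{<:CuArray,<:DenseTensor} signature in the cuTENSOR extension; it is not the actual implementation):

struct Exposed{S,T}
  object::T
end
expose(x::AbstractArray) = Exposed{typeof(x),typeof(x)}(x)
unexpose(e::Exposed) = e.object

kernel(::Exposed) = :generic          # fallback, like the method added above
kernel(::Exposed{<:Matrix}) = :blas   # storage-specific, like the CuArray method

@assert kernel(expose(rand(2, 2))) == :blas
@assert kernel(expose(rand(2))) == :generic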
5 changes: 5 additions & 0 deletions NDTensors/test/ITensors/TestITensorDMRG/TestITensorDMRG.jl
@@ -14,6 +14,11 @@ reference_energies = Dict([
])

is_broken(dev, elt::Type, conserve_qns::Val) = false
## Currently there is an issue in blocksparse cuTENSOR that seems to be related to using `@view`; I am still working to fix it.
## For some reason ComplexF64 works, which makes `@test_broken` throw an error.
function is_broken(dev::typeof(cu), elt::Type, conserve_qns::Val{true})
return ("cutensor" ARGS && elt != ComplexF64)
end

include("dmrg.jl")

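A hypothetical, self-contained illustration of how an is_broken flag feeds @test_broken (the demo functions and values are made up; the real driver lives in dmrg.jl):

using Test: @test, @test_broken

is_broken_demo(elt::Type) = elt != ComplexF64   # mimics the cutensor case above
# The "broken" configuration returns a wrong energy, so @test_broken records
# an expected failure rather than an unexpected pass.
run_dmrg_demo(elt::Type) = is_broken_demo(elt) ? zero(real(elt)) : real(elt)(-0.25)

for elt in (Float64, ComplexF64)
  if is_broken_demo(elt)
    @test_broken run_dmrg_demo(elt) ≈ -0.25
  else
    @test run_dmrg_demo(elt) ≈ -0.25
  end
end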
9 changes: 8 additions & 1 deletion NDTensors/test/NDTensorsTestUtils/device_list.jl
@@ -16,14 +16,21 @@ if "metal" in ARGS || "all" in ARGS
Pkg.add("Metal")
using Metal
end
if "cutensor" in ARGS || "all" in ARGS
if in("TensorOperations", map(v -> v.name, values(Pkg.dependencies())))
Pkg.rm("TensorOperations")
end
Pkg.add("cuTENSOR")
using CUDA, cuTENSOR
end

function devices_list(test_args)
devs = Vector{Function}(undef, 0)
if isempty(test_args) || "base" in test_args
push!(devs, NDTensors.cpu)
end

if "cuda" in test_args || "all" in test_args
if "cuda" in test_args || "cutensor" in test_args || "all" in test_args
if CUDA.functional()
push!(devs, NDTensors.CUDAExtensions.cu)
else
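With the branch above, the cutensor test argument both swaps TensorOperations for cuTENSOR and selects the CUDA device. A sketch of invoking the suite that way, using the standard Pkg.test keyword (the invocation itself is not part of this diff):

using Pkg
Pkg.test("NDTensors"; test_args=["cutensor"])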
2 changes: 2 additions & 0 deletions NDTensors/test/Project.toml
@@ -24,7 +24,9 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[compat]
Metal = "1.1.0"
cuTENSOR = "2.0"