[NDTensors] Unwrap module for dispatching on unwrapped types #1220

Merged 93 commits from kmp5VT's branch kmp5/enhancements/unwrap into main on Nov 1, 2023.

Commits
401b9d4
Create Unwrap folder and move most of the `iswrappedarray` function…
kmp5VT Oct 26, 2023
0afcc9c
Make a prototype for permutedims in NDTensors
kmp5VT Oct 26, 2023
fffd862
format
kmp5VT Oct 26, 2023
61fbe32
add permutedims! to `Unwrap`
kmp5VT Oct 26, 2023
6adee77
format
kmp5VT Oct 26, 2023
0af661d
permutedims!! calls Base.permutedims! with expose
kmp5VT Oct 26, 2023
993d512
Call base permutedims not ndtensors
kmp5VT Oct 26, 2023
7915a8e
Arrays calls permutedims without leaf_parenttype
kmp5VT Oct 26, 2023
144c57d
write a `LinearAlgebra.mul!` for `Exposed`
kmp5VT Oct 26, 2023
9b4fa78
format
kmp5VT Oct 26, 2023
4db2297
Remove file
kmp5VT Oct 26, 2023
86b9e3b
Fix metal mul wrapper
kmp5VT Oct 26, 2023
23f9ec1
Create `copyto` function for Exposed struct
kmp5VT Oct 26, 2023
72d00ac
Properly export `Exposed`
kmp5VT Oct 26, 2023
8f41fcf
format
kmp5VT Oct 26, 2023
1c2e2c5
Some fixes to metal
kmp5VT Oct 27, 2023
08e92b1
fix permutedims!! signature
kmp5VT Oct 27, 2023
4d38994
format
kmp5VT Oct 27, 2023
7728b8c
Fix permute call only call when dest is a reshapedarray
kmp5VT Oct 27, 2023
5a6d524
format
kmp5VT Oct 27, 2023
c58e034
Don't use NDTensors mtl because 64 bit precision is not supported
kmp5VT Oct 27, 2023
134909c
Allow blocksparse tests to support metal
kmp5VT Oct 27, 2023
1a90927
Create unexpose function
kmp5VT Oct 27, 2023
3f6dd75
Make parent function for Exposed
kmp5VT Oct 27, 2023
0f02339
Favor unexpose over .object and return unexposed
kmp5VT Oct 27, 2023
58473b3
Revert blocksparse changes
kmp5VT Oct 27, 2023
2606602
remove AbstractArray copyto.jl
kmp5VT Oct 27, 2023
6d0b353
Add back mul!! function
kmp5VT Oct 27, 2023
db6edd1
Include mul.jl in NDTensors.jl
kmp5VT Oct 27, 2023
39ac932
Call mul!! in case we need to mutate
kmp5VT Oct 27, 2023
8398394
format
kmp5VT Oct 27, 2023
469fad7
Merge branch 'main' into kmp5/enhancements/unwrap
kmp5VT Oct 27, 2023
57b330f
Remove NDTensors custom permutedims
kmp5VT Oct 27, 2023
e0a425c
Remove files from NDTensors
kmp5VT Oct 27, 2023
f69ece3
alphabetical order
kmp5VT Oct 27, 2023
c5ebc5d
Use linearalgebra with Expose
kmp5VT Oct 27, 2023
5cad7ae
import permutedims!
kmp5VT Oct 27, 2023
61163c6
leaf_parenttype -> unwrap_type
kmp5VT Oct 30, 2023
532a519
Per Matt's comment
kmp5VT Oct 30, 2023
aad36ed
format
kmp5VT Oct 30, 2023
a4838ca
Merge branch 'main' into kmp5/enhancements/unwrap
kmp5VT Oct 30, 2023
c81a23b
import unwrap_type
kmp5VT Oct 30, 2023
df9d194
Move metal examples test to end of test suite
kmp5VT Oct 30, 2023
6ae4816
remove `Base.` and `LinearAlgebra.`
kmp5VT Oct 30, 2023
acaa9c6
add comment
kmp5VT Oct 30, 2023
3b9817a
format
kmp5VT Oct 30, 2023
e32a674
formatting
kmp5VT Oct 30, 2023
9c8cfde
second p
kmp5VT Oct 30, 2023
475e150
Make transpose of expose
kmp5VT Oct 30, 2023
0666dc8
using not import
kmp5VT Oct 30, 2023
c4f618e
Unwrap.unwrap_type
kmp5VT Oct 30, 2023
abeb9fa
get copyto working
kmp5VT Oct 30, 2023
488bce8
Use full path name
kmp5VT Oct 30, 2023
57f8bed
Fix mul!
kmp5VT Oct 30, 2023
b477058
Add some functions to expose
kmp5VT Oct 30, 2023
d648f1e
typo
kmp5VT Oct 30, 2023
0249427
use NDTensors.mtl
kmp5VT Oct 30, 2023
c9cc991
typo
kmp5VT Oct 30, 2023
3e93d7e
Working linear algebra qr, ql, and SVD with Expose
kmp5VT Oct 30, 2023
5501ff6
Remove deleted file
kmp5VT Oct 30, 2023
e8e562a
remove leaf_parenttype and using
kmp5VT Oct 30, 2023
4eda671
If dev is metal use Float32
kmp5VT Oct 30, 2023
d8d0349
Don't do Float64 on metal
kmp5VT Oct 30, 2023
9c1fad3
use unwrap cpu
kmp5VT Oct 30, 2023
c5cdd65
format
kmp5VT Oct 30, 2023
f09689b
Rename util, abstractarray
kmp5VT Oct 31, 2023
faafbfb
Fix issue in linearalgebra
kmp5VT Oct 31, 2023
d66f907
Update get/set index for full stack
kmp5VT Oct 31, 2023
26daedf
Skip Float64 on mtl
kmp5VT Oct 31, 2023
1d383d6
format
kmp5VT Oct 31, 2023
fa039ef
format
kmp5VT Oct 31, 2023
5fcf642
Fix scalar indexing issues
kmp5VT Oct 31, 2023
d3a7e7c
When A, B and C are transposed, just change the order to `B * A = C`. There was…
kmp5VT Oct 31, 2023
56be813
format
kmp5VT Oct 31, 2023
57f9cb9
Merge remote-tracking branch 'origin/main' into kmp5/enhancements/unwrap
kmp5VT Oct 31, 2023
b96dd4e
Fix permutedims! for arraystorage
kmp5VT Oct 31, 2023
abe78eb
format
kmp5VT Oct 31, 2023
7df11ae
Remove `Base.` and others from inline function calls
kmp5VT Oct 31, 2023
aa8e3f9
Remove leaf_parenttype
kmp5VT Oct 31, 2023
d27ec70
Add todo comments
kmp5VT Oct 31, 2023
34cbaff
Remove deleted file
kmp5VT Oct 31, 2023
13aa27d
Add back expose in SVD
kmp5VT Oct 31, 2023
5b38f13
create a mul for all transpose
kmp5VT Nov 1, 2023
f3e91a2
per some comments
kmp5VT Nov 1, 2023
76ccbfa
Merge branch 'main' into kmp5/enhancements/unwrap
mtfishman Nov 1, 2023
50a452f
Merge commit '70967cd1c1f0c9236450485446c2a10ad627bc07' into kmp5/enh…
kmp5VT Nov 1, 2023
7fb5f91
Add some scalar index work arounds for CUDA.
kmp5VT Nov 1, 2023
d82057a
Expose version wasn't properly passing kwargs. adding that back in fi…
kmp5VT Nov 1, 2023
b9f2980
format
kmp5VT Nov 1, 2023
130d4f7
Merge branch 'kmp5/enhancements/unwrap' of github.com:kmp5VT/ITensors…
kmp5VT Nov 1, 2023
a7e2ea4
Use CPU
kmp5VT Nov 1, 2023
96a8334
I do need this function
kmp5VT Nov 1, 2023
816ae64
format
kmp5VT Nov 1, 2023
Files changed
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl

@@ -2,6 +2,7 @@ module NDTensorsCUDAExt

using NDTensors
using NDTensors.SetParameters
+using NDTensors.Unwrap
using Adapt
using Functors
using LinearAlgebra
26 changes: 21 additions & 5 deletions NDTensors/ext/NDTensorsCUDAExt/indexing.jl

@@ -1,8 +1,24 @@
-function Base.getindex(::Type{<:CuArray}, T::DenseTensor{<:Number})
-  return CUDA.@allowscalar data(T)[]
+function Base.getindex(E::Exposed{<:CuArray})
+  return CUDA.@allowscalar unexpose(E)[]
end

-function Base.setindex!(::Type{<:CuArray}, T::DenseTensor{<:Number}, x::Number)
-  CUDA.@allowscalar data(T)[] = x
-  return T
+function setindex!(E::Exposed{<:CuArray}, x::Number)
+  CUDA.@allowscalar unexpose(E)[] = x
+  return unexpose(E)
end
+
+function Base.getindex(E::Exposed{<:CuArray,<:Adjoint}, I...)
+  Ap = parent(E)
+  return expose(Ap)[I...]
+end
+
+function Base.copy(E::Exposed{<:CuArray,<:Base.ReshapedArray})
+  Ap = parent(E)
+  return copy(expose(Ap))
+end
+
+Base.any(f, E::Exposed{<:CuArray,<:NDTensors.Tensor}) = any(f, data(unexpose(E)))
+
+function Base.print_array(io::IO, E::Exposed{<:CuArray})
+  return Base.print_array(io, expose(NDTensors.cpu(E)))
+end
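For orientation, here is a minimal CPU-runnable sketch of the new call style (the zero-dimensional array is illustrative, not from the PR). Callers wrap an array with expose, scalar get/set go through the Exposed methods, and GPU extensions such as the one above intercept by dispatch:

using NDTensors.Unwrap: expose, unexpose

x = fill(1.0)   # zero-dimensional Array holding one element
E = expose(x)

E[]             # 1.0, via the generic getindex(::Exposed) fallback
E[] = 2.0       # setindex! on the wrapper mutates the underlying array
unexpose(E)[]   # 2.0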
2 changes: 1 addition & 1 deletion NDTensors/ext/NDTensorsCUDAExt/linearalgebra.jl

@@ -5,7 +5,7 @@ function NDTensors.svd_catch_error(A::CuMatrix; alg="JacobiAlgorithm")
    alg = CUDA.CUSOLVER.QRAlgorithm()
  end
  USV = try
-    svd(A; alg=alg)
+    svd(expose(A); alg=alg)
  catch
    return nothing
  end
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorsMetalExt/imports.jl

@@ -1,5 +1,6 @@
import NDTensors: mtl, set_ndims, set_eltype, set_eltype_if_unspecified
import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter

+using NDTensors.Unwrap: Exposed, unwrap_type, unexpose, expose
using Metal: DefaultStorageMode
using NDTensors: adapt
10 changes: 5 additions & 5 deletions NDTensors/ext/NDTensorsMetalExt/indexing.jl

@@ -1,8 +1,8 @@
-function Base.getindex(::Type{<:MtlArray}, T::DenseTensor{<:Number})
-  return Metal.@allowscalar data(T)[]
+function Base.getindex(E::Exposed{<:MtlArray})
+  return Metal.@allowscalar unexpose(E)[]
end

-function Base.setindex!(::Type{<:MtlArray}, T::DenseTensor{<:Number}, x::Number)
-  Metal.@allowscalar data(T)[] = x
-  return T
+function Base.setindex!(E::Exposed{<:MtlArray}, x::Number)
+  Metal.@allowscalar unexpose(E)[] = x
+  return unexpose(E)
end
36 changes: 25 additions & 11 deletions NDTensors/ext/NDTensorsMetalExt/linearalgebra.jl

@@ -1,16 +1,30 @@
-function NDTensors.qr(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix)
-  Q, R = NDTensors.qr(NDTensors.cpu(A))
-  return adapt(leaf_parenttype, Matrix(Q)), adapt(leaf_parenttype, R)
+function LinearAlgebra.qr(A::Exposed{<:MtlMatrix})
+  Q, R = qr(expose(NDTensors.cpu(A)))
+  return adapt(unwrap_type(A), Matrix(Q)), adapt(unwrap_type(A), R)
end

-function NDTensors.eigen(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix)
-  D, U = NDTensors.eigen(NDTensors.cpu(A))
-  return adapt(set_ndims(leaf_parenttype, ndims(D)), D), adapt(leaf_parenttype, U)
+function NDTensors.Unwrap.qr_positive(A::Exposed{<:MtlMatrix})
+  Q, R = qr_positive(expose(NDTensors.cpu(A)))
+  return adapt(unwrap_type(A), Matrix(Q)), adapt(unwrap_type(A), R)
end

-function NDTensors.svd(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix)
-  U, S, V = NDTensors.svd(NDTensors.cpu(A))
-  return adapt(leaf_parenttype, U),
-    adapt(set_ndims(leaf_parenttype, ndims(S)), S),
-    adapt(leaf_parenttype, V)
+function NDTensors.Unwrap.ql(A::Exposed{<:MtlMatrix})
+  Q, L = ql(expose(NDTensors.cpu(A)))
+  return adapt(unwrap_type(A), Matrix(Q)), adapt(unwrap_type(A), L)
end
+function NDTensors.Unwrap.ql_positive(A::Exposed{<:MtlMatrix})
+  Q, L = ql_positive(expose(NDTensors.cpu(A)))
+  return adapt(unwrap_type(A), Matrix(Q)), adapt(unwrap_type(A), L)
+end
+
+function LinearAlgebra.eigen(A::Exposed{<:MtlMatrix})
+  D, U = eigen(expose(NDTensors.cpu(A)))
+  return adapt(set_ndims(unwrap_type(A), ndims(D)), D), adapt(unwrap_type(A), U)
+end
+
+function LinearAlgebra.svd(A::Exposed{<:MtlMatrix}; kwargs...)
+  U, S, V = svd(expose(NDTensors.cpu(A)); kwargs...)
+  return adapt(unwrap_type(A), U),
+    adapt(set_ndims(unwrap_type(A), ndims(S)), S),
+    adapt(unwrap_type(A), V)
+end
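All of these Metal methods follow the same pattern: move the exposed matrix to the CPU, run the factorization there, and adapt the results back to the original storage type. A hedged sketch of that pattern for one factorization (the name gpu_svd_via_cpu is hypothetical; the real methods above dispatch on Exposed{<:MtlMatrix} and also fix the dimensionality of S with set_ndims):

using Adapt: adapt
using LinearAlgebra: svd
using NDTensors: cpu
using NDTensors.Unwrap: Exposed, expose, unwrap_type

function gpu_svd_via_cpu(A::Exposed)
  U, S, V = svd(expose(cpu(A)))   # factorize on the CPU
  # adapt moves each factor back to the device storage type of A:
  return adapt(unwrap_type(A), U), adapt(unwrap_type(A), S), adapt(unwrap_type(A), V)
end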
13 changes: 5 additions & 8 deletions NDTensors/ext/NDTensorsMetalExt/mul.jl

@@ -1,15 +1,12 @@
# This was calling generic matrix multiplication.
# TODO: Raise an issue with `Metal.jl`.
-function NDTensors.mul!!(
-  ::Type{<:MtlArray},
-  CM::Transpose,
-  ::Type{<:MtlArray},
-  AM::AbstractMatrix,
-  ::Type{<:MtlArray},
-  BM::AbstractMatrix,
+function LinearAlgebra.mul!(
+  CM::Exposed{<:MtlArray,<:Transpose},
+  AM::Exposed{<:MtlArray},
+  BM::Exposed{<:MtlArray},
  α,
  β,
)
  mul!(transpose(CM), transpose(BM), transpose(AM), α, β)
-  return CM
+  return unexpose(CM)
end
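The rewrite relies on the identity (A * B)ᵀ = Bᵀ * Aᵀ: when the destination is a Transpose wrapper, multiplying the transposed operands in reverse order writes into the plain parent array, which (per the commit history) keeps Metal on its native matmul instead of a generic fallback. A CPU check of the identity:

using LinearAlgebra: mul!

A, B = randn(3, 4), randn(4, 5)
C = transpose(zeros(5, 3))   # destination wrapped in Transpose, like CM above

# Writing Bᵀ * Aᵀ into the plain parent computes the same thing as A * B into C:
mul!(transpose(C), transpose(B), transpose(A), 1.0, 0.0)
C ≈ A * B   # true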
15 changes: 5 additions & 10 deletions NDTensors/ext/NDTensorsMetalExt/permutedims.jl

@@ -1,12 +1,7 @@
-function NDTensors.permutedims!(
-  ::Type{<:MtlArray},
-  Adest::Base.ReshapedArray{<:Any,<:Any,<:SubArray},
-  ::Type{<:MtlArray},
-  A,
-  perm,
+function permutedims!(
+  Edest::Exposed{<:MtlArray,<:Base.ReshapedArray}, Esrc::Exposed{<:MtlArray}, perm
)
-  Aperm = permutedims(A, perm)
-  Adest_parent = parent(Adest)
-  copyto!(Adest_parent, Aperm)
-  return Adest
+  Aperm = permutedims(Esrc, perm)
+  copyto!(expose(parent(Edest)), expose(Aperm))
+  return unexpose(Edest)
end
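The workaround materializes the permutation into a fresh array, then copies through the reshaped destination's parent rather than calling permutedims! in place on a ReshapedArray view (which, per the commit messages, is what Metal had trouble with). A CPU sketch of the same data flow, with illustrative shapes:

A = randn(2, 3, 4)
storage = zeros(24)
dest = reshape(view(storage, 1:24), 3, 2, 4)   # a ReshapedArray over a SubArray

Aperm = permutedims(A, (2, 1, 3))   # materialize the permutation first
copyto!(parent(dest), Aperm)        # then write through the parent view
dest == Aperm                       # true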
9 changes: 3 additions & 6 deletions NDTensors/src/NDTensors.jl

@@ -30,6 +30,8 @@ include("SortedSets/src/SortedSets.jl")
using .SortedSets
include("TagSets/src/TagSets.jl")
using .TagSets
+include("Unwrap/src/Unwrap.jl")
+using .Unwrap

using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo

@@ -51,16 +53,13 @@ include("algorithm.jl")
include("aliasstyle.jl")
include("abstractarray/set_types.jl")
include("abstractarray/to_shape.jl")
-include("abstractarray/iswrappedarray.jl")
include("abstractarray/iscu.jl")
include("abstractarray/similar.jl")
include("abstractarray/ndims.jl")
-include("abstractarray/copyto.jl")
-include("abstractarray/mul.jl")
include("abstractarray/append.jl")
include("abstractarray/permutedims.jl")
include("abstractarray/fill.jl")
+include("abstractarray/mul.jl")
-include("abstractarray/linearalgebra.jl")
include("array/set_types.jl")
include("array/permutedims.jl")
include("array/mul.jl")

@@ -75,8 +74,6 @@ include("tensor/tensor.jl")
include("dims.jl")
include("tensor/set_types.jl")
include("tensor/similar.jl")
-include("tensor/permutedims.jl")
-include("tensor/linearalgebra.jl")
include("adapt.jl")
include("tensoralgebra/generic_tensor_operations.jl")
include("tensoralgebra/contraction_logic.jl")
3 changes: 3 additions & 0 deletions NDTensors/src/Unwrap/README.md

@@ -0,0 +1,3 @@
+# Unwrap
+
+A module to unwrap complex array types to assist in the generic programming of array-type based functions.
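A brief usage sketch of the exported interface (the array here is illustrative):

using NDTensors.Unwrap: expose, unexpose, unwrap_type

A = transpose(randn(3, 3))   # a wrapped array: Transpose{Float64, Matrix{Float64}}

unwrap_type(A)               # Matrix{Float64}, the unwrapped leaf storage type
E = expose(A)                # an Exposed wrapper carrying both types for dispatch
unexpose(E) === A            # true: recover the original wrapped array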
4 changes: 4 additions & 0 deletions NDTensors/src/Unwrap/TODO.md

@@ -0,0 +1,4 @@
+Replace all `leaf_parenttype` calls by wrapping the arrays in this `expose` type
+
+Fix the issue Ryan found in MPS
+Make a GPUArrays extension that has generic GPU algorithms
25 changes: 25 additions & 0 deletions NDTensors/src/Unwrap/src/Unwrap.jl

@@ -0,0 +1,25 @@
+module Unwrap
+using SimpleTraits
+using LinearAlgebra
+using Base: ReshapedArray
+using Strided.StridedViews
+
+include("expose.jl")
+include("iswrappedarray.jl")
+
+include("import.jl")
+## TODO: Create functions which take the `Expose` type and launch functions
+## using that type.
+## Exposed-based functions:
+include("functions/abstractarray.jl")
+include("functions/copyto.jl")
+include("functions/linearalgebra.jl")
+include("functions/mul.jl")
+include("functions/permutedims.jl")
+
+export IsWrappedArray,
+  is_wrapped_array, parenttype, unwrap_type, expose, Exposed, unexpose, cpu
+
+## TODO: Write Exposed-based functions in the NDTensors extensions when necessary.
+
+end
7 changes: 7 additions & 0 deletions NDTensors/src/Unwrap/src/expose.jl

@@ -0,0 +1,7 @@
+struct Exposed{Unwrapped,Object}
+  object::Object
+end
+
+expose(object) = Exposed{unwrap_type(object),typeof(object)}(object)
+
+unexpose(E::Exposed) = E.object
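The two type parameters are what make the dispatch work: Unwrapped carries the leaf storage type and Object the (possibly wrapped) array type, so extensions can specialize on both at once, as in Exposed{<:CuArray,<:Adjoint} above. A hypothetical illustration (the function kind is not part of the module):

using LinearAlgebra: Transpose
using NDTensors.Unwrap: Exposed, expose

kind(::Exposed{<:Array,<:Transpose}) = "lazy transpose of a CPU Array"
kind(::Exposed{<:Array}) = "some other CPU Array"

kind(expose(transpose(randn(2, 2))))   # "lazy transpose of a CPU Array"
kind(expose(randn(2, 2)))              # "some other CPU Array"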
22 changes: 22 additions & 0 deletions NDTensors/src/Unwrap/src/functions/abstractarray.jl

@@ -0,0 +1,22 @@
+parent(E::Exposed) = parent(unexpose(E))
+
+transpose(E::Exposed) = transpose(unexpose(E))
+
+cpu(E::Exposed) = cpu(unexpose(E))
+
+getindex(E::Exposed) = unexpose(E)[]
+
+function setindex!(E::Exposed, x::Number)
+  unexpose(E)[] = x
+  return unexpose(E)
+end
+
+getindex(E::Exposed, I...) = unexpose(E)[I...]
+
+function copy(E::Exposed)
+  return copy(unexpose(E))
+end
+
+any(f, E::Exposed) = any(f, unexpose(E))
+
+print_array(io::IO, E::Exposed) = print_array(io, unexpose(E))
4 changes: 4 additions & 0 deletions NDTensors/src/Unwrap/src/functions/copyto.jl

@@ -0,0 +1,4 @@
+function copyto!(R::Exposed, T::Exposed)
+  copyto!(unexpose(R), unexpose(T))
+  return unexpose(R)
+end
29 changes: 29 additions & 0 deletions NDTensors/src/Unwrap/src/functions/linearalgebra.jl

@@ -0,0 +1,29 @@
+function qr(E::Exposed)
+  return qr(unexpose(E))
+end
+## These functions do not exist in `LinearAlgebra` but were defined
+## in NDTensors. Because Unwrap is imported before NDTensors,
+## one cannot import these functions from NDTensors, so instead
+## I define them here and extend them in NDTensors.
+## I have done the same thing for the function `cpu`.
+## Unwrap.qr_positive
+function qr_positive(E::Exposed)
+  return qr_positive(unexpose(E))
+end
+
+## Unwrap.ql
+function ql(E::Exposed)
+  return ql(unexpose(E))
+end
+## Unwrap.ql_positive
+function ql_positive(E::Exposed)
+  return ql_positive(unexpose(E))
+end
+
+function eigen(E::Exposed)
+  return eigen(unexpose(E))
+end
+
+function svd(E::Exposed; kwargs...)
+  return svd(unexpose(E); kwargs...)
+end
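The ownership scheme this comment describes, sketched with hypothetical stand-in modules (MiniUnwrap and MiniNDTensors are illustrative only, and the qr_positive body is a stand-in, not the real sign-fixing implementation):

module MiniUnwrap
struct Exposed{Unwrapped,Object}
  object::Object
end
unexpose(E::Exposed) = E.object
# The owner declares the generic method; concrete methods come from downstream:
qr_positive(E::Exposed) = qr_positive(unexpose(E))
end

module MiniNDTensors
import ..MiniUnwrap
using LinearAlgebra: qr
# Downstream extends the owner's function for plain matrices:
MiniUnwrap.qr_positive(A::AbstractMatrix) = qr(A)
end

E = MiniUnwrap.Exposed{Matrix{Float64},Matrix{Float64}}(randn(3, 3))
MiniUnwrap.qr_positive(E)   # unwraps, then hits the method from MiniNDTensors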
4 changes: 4 additions & 0 deletions NDTensors/src/Unwrap/src/functions/mul.jl

@@ -0,0 +1,4 @@
+function mul!(CM::Exposed, AM::Exposed, BM::Exposed, α, β)
+  mul!(unexpose(CM), unexpose(AM), unexpose(BM), α, β)
+  return unexpose(CM)
+end
13 changes: 13 additions & 0 deletions NDTensors/src/Unwrap/src/functions/permutedims.jl

@@ -0,0 +1,13 @@
+function permutedims(E::Exposed, perm)
+  return permutedims(unexpose(E), perm)
+end
+
+function permutedims!(Edest::Exposed, Esrc::Exposed, perm)
+  permutedims!(unexpose(Edest), unexpose(Esrc), perm)
+  return unexpose(Edest)
+end
+
+function permutedims!(Edest::Exposed, Esrc::Exposed, perm, f)
+  unexpose(Edest) .= f.(unexpose(Edest), permutedims(Esrc, perm))
+  return unexpose(Edest)
+end
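The four-argument form fuses the permutation with an elementwise combine f, e.g. accumulating a permuted array into the destination. A small CPU check (shapes and values illustrative):

using NDTensors.Unwrap: expose

src = reshape(collect(1.0:6.0), 2, 3)
dest = ones(3, 2)

permutedims!(expose(dest), expose(src), (2, 1), +)   # dest .= dest .+ permutedims(src)
dest == ones(3, 2) + permutedims(src, (2, 1))        # true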
13 changes: 13 additions & 0 deletions NDTensors/src/Unwrap/src/import.jl

@@ -0,0 +1,13 @@
+import Base:
+  permutedims,
+  permutedims!,
+  copy,
+  copyto!,
+  parent,
+  print_array,
+  transpose,
+  getindex,
+  setindex!,
+  any
+
+import LinearAlgebra: mul!, qr, eigen, svd
NDTensors/src/Unwrap/src/iswrappedarray.jl (moved from NDTensors/src/abstractarray/iswrappedarray.jl)

@@ -39,17 +39,20 @@ parenttype(::Type{<:StridedView{<:Any,<:Any,P}}) where {P} = P
# `SimpleTraits.jl` traits dispatch.
parenttype(array::AbstractArray) = parenttype(typeof(array))

-@traitfn function leaf_parenttype(
+## These functions will be used in place of `leaf_parenttype` but will be
+## called indirectly through the `expose` function.
+@traitfn function unwrap_type(
  arraytype::Type{ArrayT}
) where {ArrayT; IsWrappedArray{ArrayT}}
-  return leaf_parenttype(parenttype(arraytype))
+  return unwrap_type(parenttype(arraytype))
end

-@traitfn function leaf_parenttype(
+@traitfn function unwrap_type(
  arraytype::Type{ArrayT}
) where {ArrayT; !IsWrappedArray{ArrayT}}
  return arraytype
end

# For working with instances.
-leaf_parenttype(array::AbstractArray) = leaf_parenttype(typeof(array))
+unwrap_type(array::AbstractArray) = unwrap_type(typeof(array))
+unwrap_type(E::Exposed) = unwrap_type(unexpose(E))
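unwrap_type recurses through parenttype until it reaches a type that is not a wrapper, so nested wrappers peel all the way down to the leaf storage type. For example, on CPU types (assuming, as in NDTensors, that SubArray and Transpose are registered wrapper types):

using NDTensors.Unwrap: unwrap_type

v = view(randn(4, 4), 1:2, 1:2)   # SubArray wrapping a Matrix
t = transpose(v)                  # Transpose wrapping the SubArray

unwrap_type(t)          # Matrix{Float64}: two layers of wrappers peeled
unwrap_type(randn(2))   # Vector{Float64}: already a leaf type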
2 changes: 1 addition & 1 deletion NDTensors/src/abstractarray/append.jl

@@ -2,7 +2,7 @@
# Used to circumvent issues with some GPU backends like Metal
# not supporting `resize!`.
function append!!(collection, collections...)
-  return append!!(leaf_parenttype(collection), collection, collections...)
+  return append!!(unwrap_type(collection), collection, collections...)
end

function append!!(::Type, collection, collections...)
13 changes: 0 additions & 13 deletions NDTensors/src/abstractarray/copyto.jl

This file was deleted.

4 changes: 2 additions & 2 deletions NDTensors/src/abstractarray/fill.jl

@@ -2,15 +2,15 @@ function generic_randn(
  arraytype::Type{<:AbstractArray}, dim::Integer=0; rng=Random.default_rng()
)
  arraytype_specified = set_unspecified_parameters(
-    leaf_parenttype(arraytype), DefaultParameters()
+    unwrap_type(arraytype), DefaultParameters()
  )
  data = similar(arraytype_specified, dim)
  return randn!(rng, data)
end

function generic_zeros(arraytype::Type{<:AbstractArray}, dims...)
  arraytype_specified = set_unspecified_parameters(
-    leaf_parenttype(arraytype), DefaultParameters()
+    unwrap_type(arraytype), DefaultParameters()
  )
  ElT = eltype(arraytype_specified)
  return fill!(similar(arraytype_specified, dims...), zero(ElT))
2 changes: 1 addition & 1 deletion NDTensors/src/abstractarray/iscu.jl

@@ -2,5 +2,5 @@
# For `isgpu`, will require a `NDTensorsGPUArrayCoreExt`.
iscu(A::AbstractArray) = iscu(typeof(A))
function iscu(A::Type{<:AbstractArray})
-  return (leaf_parenttype(A) == A ? false : iscu(leaf_parenttype(A)))
+  return (unwrap_type(A) == A ? false : iscu(unwrap_type(A)))
end
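iscu bottoms out at the unwrapped type: if unwrapping changes nothing the answer is false, and otherwise it recurses so that a backend extension's method for the leaf type gets the final say (the CUDA method is an assumption here, e.g. something like iscu(::Type{<:CuArray}) = true in the CUDA extension):

using NDTensors: iscu

iscu(randn(3, 3))             # false: Matrix unwraps to itself
iscu(transpose(randn(3, 3)))  # false on CPU, but true if the leaf were a CuArray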