Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPU functionality in NamedDimsArray #1286

Closed
wants to merge 132 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
132 commits
Select commit Hold shift + click to select a range
7453e04
Add FieldType module with Unallocated and Unspecified Types
kmp5VT Oct 18, 2023
af065e7
format
kmp5VT Oct 18, 2023
bd2a078
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Oct 18, 2023
5fa5fab
Merge branch 'main' into kmp5/feature/FieldTypes
mtfishman Oct 19, 2023
86872e1
Split up fieldtypes into Unallocated and Unspecified modules
kmp5VT Oct 19, 2023
071ae92
More UnallocatedFill implementation.
kmp5VT Oct 23, 2023
1a849c9
Start simplifying the UnallocatedZeros file
kmp5VT Oct 23, 2023
297aaa2
Fix typo
kmp5VT Oct 23, 2023
3fc526c
format
kmp5VT Oct 23, 2023
c9fb989
Update to the Unallocatedfill/zero code
kmp5VT Oct 23, 2023
e0e3652
format
kmp5VT Oct 23, 2023
7e58e30
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Oct 24, 2023
ae13d96
Updates to code. Still working on all of these
kmp5VT Oct 24, 2023
3839f88
Add some promote rules. Still working here
kmp5VT Oct 24, 2023
c1b834f
format
kmp5VT Oct 24, 2023
e9d8b06
Merge commit '66f5d391a6ef01834658ec9d5a9a8ba3d4b0f827' into kmp5/fea…
kmp5VT Nov 2, 2023
83bbd37
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 2, 2023
e2c59cb
Start working on Unallocated settype functions
kmp5VT Nov 2, 2023
145882e
format
kmp5VT Nov 2, 2023
8438d5d
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 8, 2023
b5ac4ce
Fix set_types
kmp5VT Nov 8, 2023
c136685
formatting
kmp5VT Nov 8, 2023
4d1b976
format
kmp5VT Nov 8, 2023
4e265ce
Update UnallocatedArrays
kmp5VT Nov 8, 2023
2b2a6b2
format
kmp5VT Nov 8, 2023
794d6ba
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 8, 2023
305c47b
CuArray name change, fix warning
kmp5VT Nov 8, 2023
1468c69
Some fixes
kmp5VT Nov 8, 2023
a25ec05
Add comment
kmp5VT Nov 8, 2023
a824fab
I don't need getindex function and `[I...]` still works
kmp5VT Nov 8, 2023
0fdb0a4
Convert is implemented wrong
kmp5VT Nov 8, 2023
fe051c5
Import setparameters
kmp5VT Nov 9, 2023
547fcd7
Working on set_parameter
kmp5VT Nov 9, 2023
cf1dff4
format
kmp5VT Nov 9, 2023
7963a54
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 9, 2023
568fe7d
Create set_alloctype constructor for UnallocatedArrays
kmp5VT Nov 9, 2023
52895c0
Merge branch 'kmp5/feature/FieldTypes' of github.com:kmp5VT/ITensors.…
kmp5VT Nov 9, 2023
8d0135d
Yes we should force Alloc to be the same dimensions as the `Fill`
kmp5VT Nov 9, 2023
9056599
Add set_alloctype and complex
kmp5VT Nov 9, 2023
c646598
format
kmp5VT Nov 9, 2023
2a9bc4f
We don't need to define norm
kmp5VT Nov 9, 2023
25b1ef4
Remove unecessary imports
kmp5VT Nov 9, 2023
c0de9dc
Remove set_eltype
kmp5VT Nov 9, 2023
2ad44e9
Remove line per matts comment
kmp5VT Nov 9, 2023
c56b3ac
Move set_alloctype constructors to their own files
kmp5VT Nov 9, 2023
fcd1997
Some updates
kmp5VT Nov 9, 2023
ab7e910
Add a unittest
kmp5VT Nov 9, 2023
e007458
format
kmp5VT Nov 9, 2023
bca8e77
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 9, 2023
df752fe
Remove @eval and fix the complex function
kmp5VT Nov 9, 2023
d254db8
Add complex testing
kmp5VT Nov 9, 2023
c4dc2af
format
kmp5VT Nov 9, 2023
4e6036b
Make a UnallocFillOrZero union variable
kmp5VT Nov 9, 2023
f87c2d3
format
kmp5VT Nov 9, 2023
aef5623
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 15, 2023
d635004
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 16, 2023
a0ceca4
Move abstractUnallocatedArray functions to file
kmp5VT Nov 20, 2023
8925888
Add UnallocatedArray.jl
kmp5VT Nov 20, 2023
2a694c0
Update set_types to use AbstractUnallocatedArrays
kmp5VT Nov 20, 2023
3027afa
Make other objects subtype of AbstractUnallocatedArray
kmp5VT Nov 20, 2023
85a9837
Add norm
kmp5VT Nov 20, 2023
3cbfdbe
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Nov 30, 2023
b3ab1d6
Merge branch 'kmp5/feature/FieldTypes' of github.com:kmp5VT/ITensors.…
kmp5VT Nov 30, 2023
be31025
Move unallocatedArrays and UnspecifiedTypes to lib folder
kmp5VT Nov 30, 2023
7fe90f2
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 3, 2023
d945bd4
Merge branch 'kmp5/feature/FieldTypes' of github.com:kmp5VT/ITensors.…
kmp5VT Dec 3, 2023
d64f3de
Working on UnallocatedArrays
kmp5VT Dec 3, 2023
ac5cb59
Remove old file
kmp5VT Dec 4, 2023
834287a
Updates to UnallocatedArrays system
kmp5VT Dec 4, 2023
4c66224
Remove import to be more explicit
kmp5VT Dec 4, 2023
19b29df
Cleanup/fixes
kmp5VT Dec 4, 2023
ba869d3
Remove import from module
kmp5VT Dec 4, 2023
864c1be
Clean up
kmp5VT Dec 4, 2023
7fe9457
Add tests to NDTensors
kmp5VT Dec 4, 2023
01a828f
format
kmp5VT Dec 4, 2023
af20c8f
additional updates
kmp5VT Dec 4, 2023
3dc4514
Add FillArrays to NDTensors test
kmp5VT Dec 4, 2023
fdc440f
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 4, 2023
a3dca2b
Remove type piracy code
kmp5VT Dec 4, 2023
41d3f75
Merge branch 'kmp5/feature/FieldTypes' of github.com:kmp5VT/ITensors.…
kmp5VT Dec 4, 2023
eedc125
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 4, 2023
3fcbe11
Moved shared function to defaultunallocatedarray.jl
kmp5VT Dec 4, 2023
f92e4d9
format
kmp5VT Dec 4, 2023
71031a3
force itensors to update packages, using old FillArrays
kmp5VT Dec 4, 2023
9ea0ce5
change order
kmp5VT Dec 4, 2023
43933b9
Rename
kmp5VT Dec 4, 2023
9b963fd
rename
kmp5VT Dec 4, 2023
c4c4a74
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 6, 2023
e3ccfef
Define mult_zero for UnallocatedZeros and make unit tests
kmp5VT Dec 7, 2023
ec266f8
Define transpose and adjoint for unallocatedarrays
kmp5VT Dec 7, 2023
57d52e1
Define broadcasted_zeros
kmp5VT Dec 7, 2023
b6a51e7
Make kron_zero functions
kmp5VT Dec 7, 2023
0c4c92e
format
kmp5VT Dec 7, 2023
3c18596
fix typo
kmp5VT Dec 7, 2023
60ebf27
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 8, 2023
1dbf77f
Merge commit '60ebf274959a371f52588a9c500f00b873591262' into kmp5/fea…
kmp5VT Dec 8, 2023
57b9002
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 8, 2023
d47e50a
Remove update
kmp5VT Dec 8, 2023
b5d3609
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 8, 2023
82519ed
Merge branch 'kmp5/feature/FieldTypes' of github.com:kmp5VT/ITensors.…
kmp5VT Dec 8, 2023
7653d91
Add some functions to fix
kmp5VT Dec 8, 2023
0699e89
add fill functions
kmp5VT Dec 8, 2023
4a1cef2
overload kron_fill not kron_zeros
kmp5VT Dec 8, 2023
40af2cc
Update functions needed
kmp5VT Dec 8, 2023
1d107a6
Add more tests and sort them
kmp5VT Dec 8, 2023
e05d5ff
format
kmp5VT Dec 8, 2023
410a8df
Add more tests
kmp5VT Dec 8, 2023
5e3601e
format
kmp5VT Dec 8, 2023
3ce2ce8
Merge branch 'main' into kmp5/feature/FieldTypes
kmp5VT Dec 11, 2023
4246757
Add more tests for UnallocatedFill
kmp5VT Dec 11, 2023
417e025
Move tests to braodcast
kmp5VT Dec 11, 2023
4e70bb3
Add more tests
kmp5VT Dec 11, 2023
52e21b0
Create a trival convert foreach UnallocatedArray
kmp5VT Dec 11, 2023
d874b1b
Remove promote_rules
kmp5VT Dec 12, 2023
611420b
Move using to each file
kmp5VT Dec 12, 2023
6011fbc
Create an unallocatedArray typedef
kmp5VT Dec 12, 2023
bcf89f1
Remove comment
kmp5VT Dec 12, 2023
a6b9ec7
Definitions to fix UnallocatedZeros addition
kmp5VT Dec 12, 2023
03655d0
Some additional fixes
kmp5VT Dec 12, 2023
e9ed6c5
Remove tensor tests
kmp5VT Dec 12, 2023
0c03fab
Test similar
kmp5VT Dec 12, 2023
2a04e0a
Remove NDTensor variables
kmp5VT Dec 12, 2023
24f185c
Remove comment
kmp5VT Dec 12, 2023
b64372d
Add comment
kmp5VT Dec 12, 2023
afebe2a
format
kmp5VT Dec 12, 2023
5bd1fd3
Update constructors and fix + for UnallocatedFill
kmp5VT Dec 12, 2023
a41a70c
Make only one constructor for now that requires an abstractarray type
kmp5VT Dec 13, 2023
4bd3028
Assue alloc defines constructor for type
kmp5VT Dec 13, 2023
cd42628
Remove comment
kmp5VT Dec 13, 2023
16daf49
remove comment
kmp5VT Dec 13, 2023
2693dfc
format
kmp5VT Dec 13, 2023
2d97f18
Start testing NamedDimsArray to see if its GPU compatable.
kmp5VT Dec 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NDTensors/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ for lib in [
:SmallVectors,
:SortedSets,
:TagSets,
:UnspecifiedTypes,
:UnallocatedArrays,
]
include("lib/$(lib)/src/$(lib).jl")
@eval using .$lib: $lib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,23 @@ using Test: @test, @testset, @test_broken
using NDTensors.NamedDimsArrays: named, unname
using NDTensors.TensorAlgebra: TensorAlgebra
using LinearAlgebra: qr
@testset "NamedDimsArraysTensorAlgebraExt contract (eltype=$(elt))" for elt in (
Float32, ComplexF32, Float64, ComplexF64
)

using NDTensors: NDTensors
using GPUArraysCore: @allowscalar
include(joinpath(pkgdir(NDTensors), "test", "NDTensorsTestUtils", "NDTensorsTestUtils.jl"))
using .NDTensorsTestUtils: devices_list
@testset "NamedDimsArraysTensorAlgebraExt contract (eltype=$(elt))" for dev in
devices_list(ARGS),
elt in (Float32, ComplexF32, Float64, ComplexF64)

i = named(2, "i")
j = named(2, "j")
k = named(2, "k")
na1 = randn(elt, i, j)
na2 = randn(elt, j, k)
na1 = dev(randn(elt, i, j))
na2 = dev(randn(elt, j, k))
na_dest = TensorAlgebra.contract(na1, na2)
@test eltype(na_dest) === elt
@test unname(na_dest, (i, k)) ≈ unname(na1) * unname(na2)
@allowscalar @test unname(na_dest, (i, k)) ≈ unname(na1) * unname(na2)
end
@testset "NamedDimsArraysTensorAlgebraExt QR (eltype=$(elt))" for elt in (
Float32, ComplexF32, Float64, ComplexF64
Expand Down
86 changes: 50 additions & 36 deletions NDTensors/src/lib/NamedDimsArrays/test/test_basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@ using NDTensors.NamedDimsArrays:
namedaxes,
namedsize,
unname
using NDTensors: NDTensors
using GPUArraysCore: @allowscalar
include(joinpath(pkgdir(NDTensors), "test", "NDTensorsTestUtils", "NDTensorsTestUtils.jl"))
using .NDTensorsTestUtils: devices_list

@testset "NamedDimsArrays $(@__FILE__)" begin
@testset "Basic functionality" begin
a = randn(3, 4)
@testset "Basic functionality on $dev" for dev in devices_list(ARGS)
a = dev(randn(3, 4))
na = named(a, ("i", "j"))
# TODO: Just call this `size`?
i, j = namedsize(na)
Expand All @@ -22,86 +26,96 @@ using NDTensors.NamedDimsArrays:
@test !isnamed(a)
@test isnamed(na)
@test dimnames(na) == ("i", "j")
@test na[1, 1] == a[1, 1]
na[1, 1] = 11
@test na[1, 1] == 11
@allowscalar begin
@test na[1, 1] == a[1, 1]
na[1, 1] = 11
@test na[1, 1] == 11
end
# TODO: Should `size` output `namedsize`?
@test size(na) == (3, 4)
@test namedsize(na) == (named(3, "i"), named(4, "j"))
@test length(na) == 12
# TODO: Should `axes` output `namedaxes`?
@test axes(na) == (1:3, 1:4)
@test namedaxes(na) == (named(1:3, "i"), named(1:4, "j"))
@test randn(named(3, "i"), named(4, "j")) isa NamedDimsArray
@test na["i" => 1, "j" => 2] == a[1, 2]
@test na["j" => 2, "i" => 1] == a[1, 2]
na["j" => 2, "i" => 1] = 12
@test na[1, 2] == 12
@test na[j => 1, i => 2] == a[2, 1]
@test na[aj => 1, ai => 2] == a[2, 1]
na[j => 1, i => 2] = 21
@test na[2, 1] == 21
na[aj => 1, ai => 2] = 2211
@test na[2, 1] == 2211
@test dev(randn(named(3, "i"), named(4, "j"))) isa NamedDimsArray
@allowscalar begin
@test na["i" => 1, "j" => 2] == a[1, 2]
@test na["j" => 2, "i" => 1] == a[1, 2]
na["j" => 2, "i" => 1] = 12
@test na[1, 2] == 12
@test na[j => 1, i => 2] == a[2, 1]
@test na[aj => 1, ai => 2] == a[2, 1]
na[j => 1, i => 2] = 21
@test na[2, 1] == 21
na[aj => 1, ai => 2] = 2211
@test na[2, 1] == 2211
end
na′ = align(na, ("j", "i"))
@test a == permutedims(unname(na′), (2, 1))
na′ = align(na, (j, i))
@test a == permutedims(unname(na′), (2, 1))
na′ = align(na, (aj, ai))
@test a == permutedims(unname(na′), (2, 1))
end
@testset "Shorthand constructors (eltype=$elt)" for elt in (
Float32, ComplexF32, Float64, ComplexF64
)

@testset "Shorthand constructors (eltype=$elt)" for dev in devices_list(ARGS),
elt in (Float32, ComplexF32, Float64, ComplexF64)

i, j = named.((2, 2), ("i", "j"))
value = rand(elt)
for na in (zeros(elt, i, j), zeros(elt, (i, j)))
for na in (dev(zeros(elt, i, j)), dev(zeros(elt, (i, j))))
@test eltype(na) === elt
@test na isa NamedDimsArray
@test dimnames(na) == ("i", "j")
@test iszero(na)
@allowscalar @test iszero(na)
end
for na in (fill(value, i, j), fill(value, (i, j)))
for na in (dev(fill(value, i, j)), dev(fill(value, (i, j))))
@test eltype(na) === elt
@test na isa NamedDimsArray
@test dimnames(na) == ("i", "j")
@test all(isequal(value), na)
@allowscalar @test all(isequal(value), na)
end
for na in (rand(elt, i, j), rand(elt, (i, j)))
for na in (dev(rand(elt, i, j)), dev(rand(elt, (i, j))))
@test eltype(na) === elt
@test na isa NamedDimsArray
@test dimnames(na) == ("i", "j")
@test !iszero(na)
@test all(x -> real(x) > 0, na)
@allowscalar begin
@test !iszero(na)
@test all(x -> real(x) > 0, na)
end
end
for na in (randn(elt, i, j), randn(elt, (i, j)))
for na in (dev(randn(elt, i, j)), dev(randn(elt, (i, j))))
@test eltype(na) === elt
@test na isa NamedDimsArray
@test dimnames(na) == ("i", "j")
@test !iszero(na)
@allowscalar @test !iszero(na)
end
end
@testset "Shorthand constructors (eltype=unspecified)" begin

@testset "Shorthand constructors (eltype=unspecified)" for dev in devices_list(ARGS)
i, j = named.((2, 2), ("i", "j"))
default_elt = Float64
for na in (zeros(i, j), zeros((i, j)))
for na in (dev(zeros(i, j)), dev(zeros((i, j))))
@test eltype(na) === default_elt
@test na isa NamedDimsArray
@test dimnames(na) == ("i", "j")
@test iszero(na)
@allowscalar @test iszero(na)
end
for na in (rand(i, j), rand((i, j)))
for na in (dev(rand(i, j)), dev(rand((i, j))))
@test eltype(na) === default_elt
@test na isa NamedDimsArray
@test dimnames(na) == ("i", "j")
@test !iszero(na)
@test all(x -> real(x) > 0, na)
@allowscalar begin
@test !iszero(na)
@test all(x -> real(x) > 0, na)
end
end
for na in (randn(i, j), randn((i, j)))
for na in (dev(randn(i, j)), dev(randn((i, j))))
@test eltype(na) === default_elt
@test na isa NamedDimsArray
@test dimnames(na) == ("i", "j")
@test !iszero(na)
@allowscalar @test !iszero(na)
end
end
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using NDTensors.SetParameters: Position, set_parameter
function allocate_output(
::typeof(contract),
alg::Algorithm,
Expand All @@ -13,7 +14,10 @@ function allocate_output(
# TODO: Define `output_type(contract, alg, labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2, α, β)`.
# TODO: Define `output_structure(contract, alg, labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2, α, β)`.
# TODO: Define `allocate(type, structure)`.
return Array{promote_type(eltype(a1), eltype(a2))}(undef, length.(axes_dest))
elt = promote_type(eltype(a1), eltype(a2))
@assert typeof(a1) == typeof(a2)
array = set_parameter(typeof(a1), Position{1}(), elt)
return array(undef, length.(axes_dest))
end

# TODO: Generalize to `output_structure`.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using NDTensors.Unwrap: expose
function contract!(
alg::Algorithm"matricize",
a_dest::AbstractArray,
Expand All @@ -12,7 +13,7 @@ function contract!(
a_dest_matricized = matricize(a_dest, biperm_dest)
a1_matricized = matricize(a1, biperm1)
a2_matricized = matricize(a2, biperm2)
mul!(a_dest_matricized, a1_matricized, a2_matricized, α, β)
mul!(expose(a_dest_matricized), expose(a1_matricized), expose(a2_matricized), α, β)
perm_dest = flatten(biperm_dest)
# TODO: Create a function `unmatricize` or `unfusedims`.
# unmatricize!(a_dest, a_dest_matricized, axes(a_dest), perm_dest)
Expand Down
4 changes: 4 additions & 0 deletions NDTensors/src/lib/UnallocatedArrays/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# UnallocatedArrays

A module defining a set of unallocated immutable lazy arrays which will be used to quickly construct
tensors and allocating as little data as possible.
11 changes: 11 additions & 0 deletions NDTensors/src/lib/UnallocatedArrays/src/UnallocatedArrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module UnallocatedArrays
include("abstractfill/abstractfill.jl")
include("abstractfill/set_types.jl")

include("unallocatedfill.jl")
include("unallocatedzeros.jl")
include("abstractunallocatedarray.jl")
include("set_types.jl")

export UnallocatedFill, UnallocatedZeros, alloctype, set_alloctype, allocate
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using FillArrays: AbstractFill
using NDTensors.SetParameters: Position, get_parameter, set_parameters
## Here are functions specifically defined for UnallocatedArrays
## not implemented by FillArrays
## TODO this might need a more generic name maybe like compute unit
function alloctype(A::AbstractFill)
return A.alloc
end

## TODO this fails if the parameter is a type
function alloctype(Atype::Type{<:AbstractFill})
return get_parameter(Atype, Position{4}())
end

allocate(A::AbstractFill) = alloctype(A)(parent(A))

set_eltype(T::Type{<:AbstractFill}, elt::Type) = set_parameters(T, Position{1}(), elt)
set_ndims(T::Type{<:AbstractFill}, n) = set_parameters(T, Position{2}(), n)
set_axestype(T::Type{<:AbstractFill}, ax::Type) = set_parameters(T, Position{3}(), ax)
75 changes: 75 additions & 0 deletions NDTensors/src/lib/UnallocatedArrays/src/abstractfill/set_types.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using FillArrays: AbstractFill
using NDTensors.SetParameters: SetParameters, Position
using NDTensors.UnspecifiedTypes: UnspecifiedZero
## TODO make unit tests for all of these functions
## TODO remove P4
# `SetParameters.jl` overloads.
SetParameters.get_parameter(::Type{<:AbstractFill{P1}}, ::Position{1}) where {P1} = P1
SetParameters.get_parameter(::Type{<:AbstractFill{<:Any,P2}}, ::Position{2}) where {P2} = P2
function SetParameters.get_parameter(
::Type{<:AbstractFill{<:Any,<:Any,P3}}, ::Position{3}
) where {P3}
return P3
end

## Setting paramaters
# right now I am just defining the necessary ones for my implementation still working on full implementation
# Set parameter 1
SetParameters.set_parameter(T::Type{<:AbstractFill}, ::Position{1}, P1) = T{P1}

# Set parameter 2
function SetParameters.set_parameter(
T::Type{<:AbstractFill{P1}}, ::Position{2}, P2
) where {P1}
return T{P2}
end

# Set parameter 3
function SetParameters.set_parameter(
T::Type{<:AbstractFill{P1,P2}}, ::Position{3}, P3
) where {P1,P2}
return T{P3}
end

## TODO define a specify_parameters function
## To quickly specify P1, P2, and P3
## default parameters
function SetParameters.default_parameter(::Type{<:AbstractFill}, ::Position{1})
return UnspecifiedTypes.UnspecifiedZero
end
SetParameters.default_parameter(::Type{<:AbstractFill}, ::Position{2}) = 0
SetParameters.default_parameter(::Type{<:AbstractFill}, ::Position{3}) = Tuple{}

SetParameters.nparameters(::Type{<:AbstractFill}) = Val(3)

# Set parameter 1
## Right now using AbstractArray
## TODO These are more difficult because T is technically defined so need some way to strip T of it {} types
# function set_parameter(
# T::Type{<:AbstractFill{<:Any,<:Any,P3}}, ::Position{1}, P1
# ) where {P3}
# return T{P1,<:Any,P3}
# end
# function set_parameter(
# T::Type{<:AbstractFill{<:Any,P2,P3}}, ::Position{1}, P1
# ) where {P2,P3}
# return T{P1,P2,P3}
# end
# function set_parameter(
# T::Type{<:AbstractFill{<:Any,<:Any,P3}}, ::Position{2}, P2
# ) where {P3}
# return T{<:Any,P2,P3}
# end
# function set_parameter(
# T::Type{<:AbstractFill{P1,<:Any,P3}}, ::Position{2}, P2
# ) where {P1,P3}
# return T
# end
# set_parameter(T::Type{<:AbstractFill}, ::Position{3}, P3) = T{<:Any,<:Any,P3}
# set_parameter(T::Type{<:AbstractFill{P1}}, ::Position{3}, P3) where {P1} = T{<:Any,P3}
# function set_parameter(T::Type{<:AbstractFill{<:Any,P2}}, ::Position{3}, P3) where {P2}
# return T{<:Any,P2,P3}
# end

# Set parameter 2
## using AbstractArray
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using FillArrays: FillArrays, getindex_value
using NDTensors.SetParameters: set_parameters
using Adapt: adapt

const UnallocatedArray{ElT,N,AxesT,AllocT} = Union{
<:UnallocatedFill{ElT,N,AxesT,AllocT},<:UnallocatedZeros{ElT,N,AxesT,AllocT}
}

@inline Base.axes(A::UnallocatedArray) = axes(parent(A))
Base.size(A::UnallocatedArray) = size(parent(A))
function FillArrays.getindex_value(A::UnallocatedArray)
return getindex_value(parent(A))
end

function Base.complex(A::UnallocatedArray)
return set_alloctype(
complex(parent(A)), set_parameters(alloctype(A), Position{1}(), complex(eltype(A)))
)
end

function Base.transpose(a::UnallocatedArray)
return set_alloctype(transpose(parent(a)), alloctype(a))
end

function Base.adjoint(a::UnallocatedArray)
return set_alloctype(adjoint(parent(a)), alloctype(a))
end

function Base.similar(a::UnallocatedArray)
return alloctype(a)(undef, size(a))
end

function set_alloctype(T::Type{<:UnallocatedArray}, alloc::Type{<:AbstractArray})
return set_parameters(T, Position{4}(), alloc)
end

## This overloads the definition defined in `FillArrays.jl`
for STYPE in (:AbstractArray, :AbstractFill)
@eval begin
@inline $STYPE{T}(F::UnallocatedArray{T}) where {T} = F
@inline $STYPE{T,N}(F::UnallocatedArray{T,N}) where {T,N} = F
end
end

## TODO fix this because reshape loses alloctype
#FillArrays.reshape(a::Union{<:UnallocatedFill, <:UnallocatedZeros}, dims) = set_alloctype(reshape(parent(a), dims), allocate(a))

# function Adapt.adapt_storage(to::Type{<:AbstractArray}, x::Union{<:UnallocatedFill, <:UnallocatedZeros})
# return set_alloctype(parent(x), to)
# end

# function Adapt.adapt_storage(to::Type{<:Number}, x::)
Loading
Loading