Skip to content

Commit

Permalink
add device array reshape (#713)
Browse files Browse the repository at this point in the history
* add device array reshape

* add device array reshape unit test

* fix error handling
  • Loading branch information
omlins authored Dec 24, 2024
1 parent 44a050e commit 2c93eda
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/device/gcn/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,20 @@ end
nothing
end
end

## reshape

function Base.reshape(a::ROCDeviceArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M,A}
if prod(dims) != length(a)
throw(nothing) # standing for: throw(DimensionMismatch("new dimensions (argument `dims`) must be consistent with array size (`size(a)`)"))
end
if N == M && dims == size(a)
return a
end
_derived_array(a, T, dims)
end

# create a derived device array (reinterpreted or reshaped) that's still a ROCDeviceArray
@inline function _derived_array(a::ROCDeviceArray{<:Any,<:Any,A}, ::Type{T}, dims::Dims{N}) where {T, N, A}
return ROCDeviceArray{T,N,A}(dims, a.ptr)
end
23 changes: 23 additions & 0 deletions test/device/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,26 @@ end
o2 = occursin("4×4 device array wrapper Adjoint", adj_repr)
@test o1 || o2
end

@testset "reshape" begin
function kernel(array)
i = (workgroupIdx().x-1) * workgroupDim().x + workitemIdx().x
j = (workgroupIdx().y-1) * workgroupDim().y + workitemIdx().y

_array2d = reshape(array, 10, 10)
_array2d[i,j] = i + (j-1)*size(_array2d,1)

return
end

array = zeros(Int64, 100)
array_dev = ROCArray(array)

array2d = reshape(array, 10, 10)
for i in 1:size(array2d,1), j in 1:size(array2d,2)
array2d[i,j] = i + (j-1)*size(array2d,1)
end

@roc groupsize=(10, 10) kernel(array_dev)
@test array == Array(array_dev)
end

0 comments on commit 2c93eda

Please sign in to comment.