From d15b78f1071611569c99cac1ca51267d1ab8ef00 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 19 Oct 2023 11:38:05 -0400 Subject: [PATCH 01/35] Minor touch ups to set --- src/internal/internal_geset.cc | 8 +++----- src/internal/internal_tzset.cc | 9 +++------ 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/internal/internal_geset.cc b/src/internal/internal_geset.cc index c766904a2..2513de6f9 100644 --- a/src/internal/internal_geset.cc +++ b/src/internal/internal_geset.cc @@ -121,15 +121,14 @@ void set(internal::TargetType, #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { - #pragma omp task priority( priority ) \ - shared( A, irange, jrange ) \ + #pragma omp task priority( priority ) shared( A, irange, jrange ) \ firstprivate( device, queue_index, offdiag_value, diag_value ) { // Get local tiles for writing. // convert to column major layout to simplify lda's // todo: this is in-efficient because the diagonal is independant of layout // todo: best, handle directly through the CUDA kernels - auto layout = LayoutConvert( Layout::ColMajor ); + auto layout = LayoutConvert::ColMajor; std::set A_tiles_set; for (int64_t i = 0; i < A.mt(); ++i) { @@ -229,8 +228,7 @@ void set(internal::TargetType, for (size_t g = 0; g < group_params.size(); ++g) { int64_t group_count = group_params[ g ].count; device::batch::geset( - group_params[ g ].mb, - group_params[ g ].nb, + group_params[ g ].mb, group_params[ g ].nb, offdiag_value, group_params[ g ].diag_value, a_array_dev, group_params[ g ].lda, group_count, *queue ); diff --git a/src/internal/internal_tzset.cc b/src/internal/internal_tzset.cc index 3bce1f5b4..1d17dc0c4 100644 --- a/src/internal/internal_tzset.cc +++ b/src/internal/internal_tzset.cc @@ -149,8 +149,7 @@ void set( #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { - #pragma omp task priority( priority ) \ - shared( A, irange, jrange ) \ + #pragma omp task priority( priority ) shared( A, irange, jrange ) \ firstprivate( device, queue_index, offdiag_value, diag_value ) { // temporarily, convert both into same layout @@ -290,16 +289,14 @@ void set( if (group_params[ g ].is_diagonal) { device::batch::tzset( A.uplo(), - group_params[ g ].mb, - group_params[ g ].nb, + group_params[ g ].mb, group_params[ g ].nb, offdiag_value, diag_value, a_array_dev, group_params[ g ].lda, group_count, *queue); } else { device::batch::geset( - group_params[ g ].mb, - group_params[ g ].nb, + group_params[ g ].mb, group_params[ g ].nb, offdiag_value, offdiag_value, a_array_dev, group_params[ g ].lda, group_count, *queue ); From 003269f23aaa0526b18bb436b9e84cb6b57d14b8 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 19 Oct 2023 11:56:05 -0400 Subject: [PATCH 02/35] Add regions to device scale --- include/slate/internal/device.hh | 18 +-- src/cuda/device_tzscale.cu | 4 + src/hip/device_tzscale.hip.cc | 4 + src/hip/device_tzscale.hip.cc.dep | 2 +- src/internal/internal_gescale.cc | 112 +++++++++------- src/internal/internal_tzscale.cc | 214 ++++++++++++++++++------------ src/omptarget/device_tzscale.cc | 2 + 7 files changed, 215 insertions(+), 141 deletions(-) diff --git a/include/slate/internal/device.hh b/include/slate/internal/device.hh index 04258b1c1..a7092b7f0 100644 --- a/include/slate/internal/device.hh +++ b/include/slate/internal/device.hh @@ -130,15 +130,6 @@ void gescale( scalar_t* A, int64_t lda, blas::Queue& queue); -//------------------------------------------------------------------------------ -template 
-void tzscale( - Uplo uplo, - int64_t m, int64_t n, - blas::real_type numer, blas::real_type denom, - scalar_t** Aarray, int64_t lda, - int64_t batch_count, blas::Queue& queue); - //------------------------------------------------------------------------------ template void gescale_row_col_batch( @@ -175,6 +166,15 @@ void gescale( scalar_t** Aarray, int64_t lda, int64_t batch_count, blas::Queue& queue); +//------------------------------------------------------------------------------ +template +void tzscale( + Uplo uplo, + int64_t m, int64_t n, + blas::real_type numer, blas::real_type denom, + scalar_t** Aarray, int64_t lda, + int64_t batch_count, blas::Queue& queue); + //------------------------------------------------------------------------------ template void geadd( diff --git a/src/cuda/device_tzscale.cu b/src/cuda/device_tzscale.cu index 88451bd86..03df511f9 100644 --- a/src/cuda/device_tzscale.cu +++ b/src/cuda/device_tzscale.cu @@ -64,6 +64,9 @@ __global__ void tzscale_kernel( } } +//============================================================================== +namespace batch { + //------------------------------------------------------------------------------ /// Batched routine for element-wise trapezoidal tile scale. /// Sets upper or lower part of @@ -169,5 +172,6 @@ void tzscale( batch_count, queue ); } +} // namespace batch } // namespace device } // namespace slate diff --git a/src/hip/device_tzscale.hip.cc b/src/hip/device_tzscale.hip.cc index 1ef1e6148..3004b51cf 100644 --- a/src/hip/device_tzscale.hip.cc +++ b/src/hip/device_tzscale.hip.cc @@ -65,6 +65,9 @@ __global__ void tzscale_kernel( } } +//============================================================================== +namespace batch { + //------------------------------------------------------------------------------ /// Batched routine for element-wise trapezoidal tile scale. /// Sets upper or lower part of @@ -170,5 +173,6 @@ void tzscale( batch_count, queue ); } +} // namespace batch } // namespace device } // namespace slate diff --git a/src/hip/device_tzscale.hip.cc.dep b/src/hip/device_tzscale.hip.cc.dep index dc08bce84..3b674b8fe 100644 --- a/src/hip/device_tzscale.hip.cc.dep +++ b/src/hip/device_tzscale.hip.cc.dep @@ -1 +1 @@ -95be14909da63e90e9d7d888f5b5e1bd src/cuda/device_tzscale.cu +2f13aaf1009fad8799225aa36791f3d8 src/cuda/device_tzscale.cu diff --git a/src/internal/internal_gescale.cc b/src/internal/internal_gescale.cc index 4c4bba3d6..123c61f57 100644 --- a/src/internal/internal_gescale.cc +++ b/src/internal/internal_gescale.cc @@ -88,32 +88,43 @@ void scale(internal::TargetType, { using ij_tuple = typename BaseMatrix::ij_tuple; - // Define index ranges for regions of matrix. - // Tiles in each region are all the same size. - int64_t irange[4][2] = { - { 0, A.mt()-1 }, - { A.mt()-1, A.mt() }, - { 0, A.mt()-1 }, - { A.mt()-1, A.mt() } - }; - int64_t jrange[4][2] = { - { 0, A.nt()-1 }, - { 0, A.nt()-1 }, - { A.nt()-1, A.nt() }, - { A.nt()-1, A.nt() } - }; + int64_t mt = A.mt(); + int64_t nt = A.nt(); + + // Find ranges of matching mb's. + std::vector< int64_t > irange; + int64_t last_mb = -1; + for (int64_t i = 0; i < mt; ++i) { + int64_t mb = A.tileMb( i ); + if (mb != last_mb) { + last_mb = mb; + irange.push_back( i ); + } + } + irange.push_back( mt ); + + // Find ranges of matching nb's. 
+ std::vector< int64_t > jrange; + int last_nb = -1; + for (int64_t j = 0; j < nt; ++j) { + int64_t nb = A.tileNb( j ); + if (nb != last_nb) { + last_nb = nb; + jrange.push_back( j ); + } + } + jrange.push_back( nt ); #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { - #pragma omp task slate_omp_default_none \ - shared( A ) \ - firstprivate(device, irange, jrange, queue_index, denom, numer) priority(priority) + #pragma omp task priority( priority ) shared( A, irange, jrange ) \ + firstprivate( device, queue_index, denom, numer ) { // temporarily, convert both into same layout // todo: this is in-efficient, because both matrices may have same layout already // and possibly wrong, because an input matrix is being altered // todo: best, handle directly through the CUDA kernels - auto layout = Layout::ColMajor; + auto layout = LayoutConvert::ColMajor; std::set A_tiles_set; for (int64_t i = 0; i < A.mt(); ++i) { @@ -123,45 +134,56 @@ void scale(internal::TargetType, } } } - A.tileGetForWriting(A_tiles_set, device, LayoutConvert(layout)); + A.tileGetForWriting( A_tiles_set, device, layout ); - scalar_t** a_array_host = A.array_host(device); + scalar_t** a_array_host = A.array_host( device, queue_index ); int64_t batch_count = 0; - int64_t mb[8], nb[8], lda[8], group_count[8]; - for (int q = 0; q < 4; ++q) { - group_count[q] = 0; - lda[q] = 0; - mb[q] = A.tileMb(irange[q][0]); - nb[q] = A.tileNb(jrange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { - a_array_host[batch_count] = A(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ++group_count[q]; - ++batch_count; + struct Params { + int64_t count, mb, nb, lda; + }; + std::vector group_params; + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + Params group = { 0, -1, -1, -1 }; + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { + auto Aij = A( i, j, device ); + a_array_host[ batch_count ] = Aij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.lda = Aij.stride(); } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.lda == Aij.stride() ); + } + ++group.count; + ++batch_count; } + }} // for j, i + if (group.count > 0) { + group_params.push_back( group ); } - } + }} // for jj, ii - scalar_t** a_array_dev = A.array_device(device); - - blas::Queue* queue = A.compute_queue(device, queue_index); + blas::Queue* queue = A.compute_queue( device, queue_index ); + scalar_t** a_array_dev = A.array_device( device, queue_index ); blas::device_memcpy( a_array_dev, a_array_host, batch_count, blas::MemcpyKind::HostToDevice, *queue); - for (int q = 0; q < 4; ++q) { - if (group_count[q] > 0) { - device::batch::gescale( - mb[q], nb[q], - numer, denom, a_array_dev, lda[q], - group_count[q], *queue); - a_array_dev += group_count[q]; - } + for (size_t g = 0; g < group_params.size(); ++g) { + int64_t group_count = group_params[ g ].count; + device::batch::gescale( + group_params[ g ].mb, group_params[ g ].nb, + numer, denom, a_array_dev, group_params[ g ].lda, + group_count, *queue); + a_array_dev += group_count; } queue->sync(); diff --git a/src/internal/internal_tzscale.cc b/src/internal/internal_tzscale.cc index 
d352fb11c..fb99497c6 100644 --- a/src/internal/internal_tzscale.cc +++ b/src/internal/internal_tzscale.cc @@ -110,32 +110,43 @@ void scale(internal::TargetType, { using ij_tuple = typename BaseTrapezoidMatrix::ij_tuple; - // Define index ranges for regions of matrix. - // Tiles in each region are all the same size. - int64_t irange[4][2] = { - { 0, A.mt()-1 }, - { A.mt()-1, A.mt() }, - { 0, A.mt()-1 }, - { A.mt()-1, A.mt() } - }; - int64_t jrange[4][2] = { - { 0, A.nt()-1 }, - { 0, A.nt()-1 }, - { A.nt()-1, A.nt() }, - { A.nt()-1, A.nt() } - }; + int64_t mt = A.mt(); + int64_t nt = A.nt(); + + // Find ranges of matching mb's. + std::vector< int64_t > irange; + int64_t last_mb = -1; + for (int64_t i = 0; i < mt; ++i) { + int64_t mb = A.tileMb( i ); + if (mb != last_mb) { + last_mb = mb; + irange.push_back( i ); + } + } + irange.push_back( mt ); + + // Find ranges of matching nb's. + std::vector< int64_t > jrange; + int last_nb = -1; + for (int64_t j = 0; j < nt; ++j) { + int64_t nb = A.tileNb( j ); + if (nb != last_nb) { + last_nb = nb; + jrange.push_back( j ); + } + } + jrange.push_back( nt ); #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { - #pragma omp task slate_omp_default_none \ - shared( A ) priority( priority ) \ - firstprivate(device, irange, jrange, queue_index, numer, denom) + #pragma omp task priority( priority ) shared( A, irange, jrange ) \ + firstprivate( device, queue_index, numer, denom ) { // temporarily, convert both into same layout // todo: this is in-efficient, because both matrices may have same layout already // and possibly wrong, because an input matrix is being altered // todo: best, handle directly through the CUDA kernels - auto layout = Layout::ColMajor; + auto layout = LayoutConvert::ColMajor; std::set A_tiles_set; if (A.uplo() == Uplo::Lower) { @@ -156,97 +167,128 @@ void scale(internal::TargetType, } } } - A.tileGetForWriting(A_tiles_set, device, LayoutConvert(layout)); + A.tileGetForWriting( A_tiles_set, device, layout ); - scalar_t** a_array_host = A.array_host(device); + scalar_t** a_array_host = A.array_host( device, queue_index ); + // Build batch groups int64_t batch_count = 0; - int64_t mb[8], nb[8], lda[8], group_count[8]; - for (int q = 0; q < 4; ++q) { - group_count[q] = 0; - lda[q] = 0; - mb[q] = A.tileMb(irange[q][0]); - nb[q] = A.tileNb(jrange[q][0]); + struct Params { + int64_t count, mb, nb, lda; + bool is_diagonal; + }; + std::vector group_params; + // Build batch groups for off-diagonal tiles, + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + Params group = { 0, -1, -1, -1, false }; if (A.uplo() == Uplo::Lower) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - for (int64_t i = std::max(j, irange[q][0]); i < irange[q][1]; ++i) { - if (i != j && A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { - a_array_host[batch_count] = A(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ++group_count[q]; - ++batch_count; + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = std::max(irange[ ii ], j+1); i < irange[ ii+1 ]; ++i) { + if (A.tileIsLocal(i, j) + && device == A.tileDevice(i, j)) { + + auto Aij = A( i, j, device ); + a_array_host[ batch_count ] = Aij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.lda = Aij.stride(); } - } - } - } - else { // upper - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - for (int64_t i = irange[q][0]; i < irange[q][1] && i <= 
j; ++i) { - if (i != j && A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { - a_array_host[batch_count] = A(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ++group_count[q]; - ++batch_count; + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.lda == Aij.stride() ); } + ++group.count; + ++batch_count; } - } + }} // for j,i } - } - for (int q = 4; q < 8; ++q) { - group_count[q] = 0; - lda[q] = 0; - mb[q] = A.tileMb(irange[q-4][0]); - nb[q] = A.tileNb(jrange[q-4][0]); - if (A.uplo() == Uplo::Lower) { - for (int64_t j = jrange[q-4][0]; j < jrange[q-4][1]; ++j) { - for (int64_t i = std::max(j, irange[q-4][0]); i < irange[q-4][1]; ++i) { - if (i == j && A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { - a_array_host[batch_count] = A(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ++group_count[q]; - ++batch_count; + else { // A.uplo() == Uplo::Upper + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ] && i < j; ++i) { + if (A.tileIsLocal(i, j) + && device == A.tileDevice(i, j)) { + + auto Aij = A( i, j, device ); + a_array_host[ batch_count ] = Aij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.lda = Aij.stride(); } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.lda == Aij.stride() ); + } + ++group.count; + ++batch_count; } - } + }} // for j,i } - else { // upper - for (int64_t j = jrange[q-4][0]; j < jrange[q-4][1]; ++j) { - for (int64_t i = irange[q-4][0]; i < irange[q-4][1] && i <= j; ++i) { - if (i == j && A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { - a_array_host[batch_count] = A(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ++group_count[q]; - ++batch_count; - } + if (group.count > 0) { + group_params.push_back( group ); + } + }} // for jj,ii + + // Build batch groups for diagonal tiles, + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + Params group = { 0, -1, -1, -1, true }; + int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); + int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); + for (int64_t ij = ijstart; ij < ijend; ++ij) { + if (A.tileIsLocal( ij, ij ) + && device == A.tileDevice( ij, ij )) + { + auto Aij = A( ij, ij, device ); + a_array_host[ batch_count ] = Aij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.lda = Aij.stride(); } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.lda == Aij.stride() ); + } + ++group.count; + ++batch_count; } + } // for ij + if (group.count > 0) { + group_params.push_back( group ); } - } - scalar_t** a_array_dev = A.array_device(device); + }} // for jj,ii - blas::Queue* queue = A.compute_queue(device, queue_index); + blas::Queue* queue = A.compute_queue( device, queue_index ); + scalar_t** a_array_dev = A.array_device( device, queue_index ); blas::device_memcpy( a_array_dev, a_array_host, batch_count, blas::MemcpyKind::HostToDevice, *queue); - for (int q = 0; q < 4; ++q) { - if (group_count[q] > 0) { - device::batch::gescale(mb[q], nb[q], - numer, denom, a_array_dev, lda[q], - group_count[q], *queue); - a_array_dev += group_count[q]; + for (size_t g = 0; g < group_params.size(); ++g) { + int64_t group_count = group_params[ g ].count; + + if (group_params[ g ].is_diagonal) { + device::batch::tzscale( + A.uplo(), + group_params[ g ].mb, group_params[ g 
].nb, + numer, denom, a_array_dev, group_params[ g ].lda, + group_count, *queue); } - } - for (int q = 4; q < 8; ++q) { - if (group_count[q] > 0) { - device::tzscale(A.uplo(), mb[q], nb[q], - numer, denom, a_array_dev, lda[q], - group_count[q], *queue); - a_array_dev += group_count[q]; + else { + device::batch::gescale( + group_params[ g ].mb, group_params[ g ].nb, + numer, denom, a_array_dev, group_params[ g ].lda, + group_count, *queue); } + a_array_dev += group_count; } - queue->sync(); } } diff --git a/src/omptarget/device_tzscale.cc b/src/omptarget/device_tzscale.cc index 4a762b239..28c42490c 100644 --- a/src/omptarget/device_tzscale.cc +++ b/src/omptarget/device_tzscale.cc @@ -12,6 +12,7 @@ namespace slate { namespace device { +namespace batch { //------------------------------------------------------------------------------ /// Batched routine for element-wise trapezoidal tile scale. @@ -123,5 +124,6 @@ void tzscale( std::complex** Aarray, int64_t lda, int64_t batch_count, blas::Queue& queue); +} // namespace batch } // namespace device } // namespace slate From e4354debbbead7629e481633494ce880f881b12c Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 19 Oct 2023 13:14:34 -0400 Subject: [PATCH 03/35] Add regions to device copy --- src/internal/internal_gecopy.cc | 130 +++++++++++-------- src/internal/internal_tzcopy.cc | 218 ++++++++++++++++++++------------ 2 files changed, 213 insertions(+), 135 deletions(-) diff --git a/src/internal/internal_gecopy.cc b/src/internal/internal_gecopy.cc index a7c669bc8..c11a62187 100644 --- a/src/internal/internal_gecopy.cc +++ b/src/internal/internal_gecopy.cc @@ -145,27 +145,37 @@ void copy(internal::TargetType, bool call_tile_tick = tile_release_strategy == TileReleaseStrategy::Internal || tile_release_strategy == TileReleaseStrategy::All; - // Define index ranges for regions of matrix. - // Tiles in each region are all the same size. - int64_t irange[4][2] = { - { 0, B.mt()-1 }, - { B.mt()-1, B.mt() }, - { 0, B.mt()-1 }, - { B.mt()-1, B.mt() } - }; - int64_t jrange[4][2] = { - { 0, B.nt()-1 }, - { 0, B.nt()-1 }, - { B.nt()-1, B.nt() }, - { B.nt()-1, B.nt() } - }; + int64_t mt = A.mt(); + int64_t nt = A.nt(); + + // Find ranges of matching mb's. + std::vector< int64_t > irange; + int64_t last_mb = -1; + for (int64_t i = 0; i < mt; ++i) { + int64_t mb = A.tileMb( i ); + if (mb != last_mb) { + last_mb = mb; + irange.push_back( i ); + } + } + irange.push_back( mt ); + + // Find ranges of matching nb's. 
+ std::vector< int64_t > jrange; + int last_nb = -1; + for (int64_t j = 0; j < nt; ++j) { + int64_t nb = A.tileNb( j ); + if (nb != last_nb) { + last_nb = nb; + jrange.push_back( j ); + } + } + jrange.push_back( nt ); #pragma omp taskgroup for (int device = 0; device < B.num_devices(); ++device) { - #pragma omp task slate_omp_default_none \ - shared( A, B ) \ - firstprivate( device, irange, jrange, queue_index, call_tile_tick ) \ - priority(priority) + #pragma omp task priority( priority ) shared( A, B, irange, jrange ) \ + firstprivate( device, queue_index, call_tile_tick ) { std::set A_tiles_set; for (int64_t i = 0; i < B.mt(); ++i) { @@ -195,26 +205,40 @@ void copy(internal::TargetType, dst_scalar_t** b_array_host = B.array_host(device, queue_index); int64_t batch_count = 0; - int64_t mb[4], nb[4], lda[4], ldb[4], group_count[4]; - for (int q = 0; q < 4; ++q) { - group_count[q] = 0; - lda[q] = 0; - ldb[q] = 0; - mb[q] = B.tileMb(irange[q][0]); - nb[q] = B.tileNb(jrange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (B.tileIsLocal(i, j) && device == B.tileDevice(i, j)) { - a_array_host[batch_count] = A(i, j, device).data(); - b_array_host[batch_count] = B(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ldb[q] = B(i, j, device).stride(); - ++group_count[q]; - ++batch_count; + struct Params { + int64_t count, mb, nb, lda, ldb; + }; + std::vector group_params; + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + Params group = { 0, -1, -1, -1, -1 }; + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { + auto Aij = A( i, j, device ); + a_array_host[ batch_count ] = Aij.data(); + auto Bij = B( i, j, device ); + b_array_host[ batch_count ] = Bij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.lda = Aij.stride(); + group.ldb = Bij.stride(); } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.lda == Aij.stride() ); + assert( group.ldb == Bij.stride() ); + } + ++group.count; + ++batch_count; } + }} // for j, i + if (group.count > 0) { + group_params.push_back( group ); } - } + }} // for jj, ii // Usually the output matrix (B) provides all the batch arrays. // Here we are using A, because of the different types. 
@@ -239,25 +263,25 @@ void copy(internal::TargetType, is_conj = (A.op() == Op::ConjTrans || B.op() == Op::ConjTrans); } - for (int q = 0; q < 4; ++q) { - if (group_count[q] > 0) { - if (is_trans) { - device::transpose_batch( - is_conj, - nb[q], mb[q], - a_array_dev, lda[q], - b_array_dev, ldb[q], - group_count[q], *queue); - } - else { - device::gecopy(mb[q], nb[q], - a_array_dev, lda[q], - b_array_dev, ldb[q], - group_count[q], *queue); - } - a_array_dev += group_count[q]; - b_array_dev += group_count[q]; + for (size_t g = 0; g < group_params.size(); ++g) { + int64_t group_count = group_params[ g ].count; + if (is_trans) { + device::transpose_batch( + is_conj, + group_params[ g ].mb, group_params[ g ].nb, + a_array_dev, group_params[ g ].lda, + b_array_dev, group_params[ g ].ldb, + group_count, *queue); + } + else { + device::gecopy( + group_params[ g ].mb, group_params[ g ].nb, + a_array_dev, group_params[ g ].lda, + b_array_dev, group_params[ g ].ldb, + group_count, *queue); } + a_array_dev += group_count; + b_array_dev += group_count; } queue->sync(); diff --git a/src/internal/internal_tzcopy.cc b/src/internal/internal_tzcopy.cc index ce9b086c6..bde7de8d4 100644 --- a/src/internal/internal_tzcopy.cc +++ b/src/internal/internal_tzcopy.cc @@ -112,34 +112,37 @@ void copy(internal::TargetType, slate_error_if(A.uplo() != B.uplo()); bool lower = (B.uplo() == Uplo::Lower); - // Define index ranges for regions of matrix. - // Tiles in each region are all the same size. - int64_t irange[6][2] = { - // off-diagonal - { 0, B.mt()-1 }, - { B.mt()-1, B.mt() }, - { 0, B.mt()-1 }, - { B.mt()-1, B.mt() }, - // diagonal - { 0, std::min(B.mt(), B.nt())-1 }, - { std::min(B.mt(), B.nt())-1, std::min(B.mt(), B.nt()) } - }; - int64_t jrange[6][2] = { - // off-diagonal - { 0, B.nt()-1 }, - { 0, B.nt()-1 }, - { B.nt()-1, B.nt() }, - { B.nt()-1, B.nt() }, - // diagonal - { 0, std::min(B.mt(), B.nt())-1 }, - { std::min(B.mt(), B.nt())-1, std::min(B.mt(), B.nt()) } - }; + int64_t mt = A.mt(); + int64_t nt = A.nt(); + + // Find ranges of matching mb's. + std::vector< int64_t > irange; + int64_t last_mb = -1; + for (int64_t i = 0; i < mt; ++i) { + int64_t mb = A.tileMb( i ); + if (mb != last_mb) { + last_mb = mb; + irange.push_back( i ); + } + } + irange.push_back( mt ); + + // Find ranges of matching nb's. 
+ std::vector< int64_t > jrange; + int last_nb = -1; + for (int64_t j = 0; j < nt; ++j) { + int64_t nb = A.tileNb( j ); + if (nb != last_nb) { + last_nb = nb; + jrange.push_back( j ); + } + } + jrange.push_back( nt ); #pragma omp taskgroup for (int device = 0; device < B.num_devices(); ++device) { - #pragma omp task slate_omp_default_none \ - shared( A, B ) priority( priority ) \ - firstprivate(device, irange, jrange, lower, queue_index) + #pragma omp task priority( priority ) shared( A, B, irange, jrange ) \ + firstprivate(device, lower, queue_index) { std::set A_tiles, B_diag_tiles; for (int64_t i = 0; i < B.mt(); ++i) { @@ -172,52 +175,104 @@ void copy(internal::TargetType, src_scalar_t** a_array_host = A.array_host(device, queue_index); dst_scalar_t** b_array_host = B.array_host(device, queue_index); + // Build batch groups int64_t batch_count = 0; - int64_t mb[6], nb[6], lda[6], ldb[6], group_count[6]; - // off-diagonal blocks - for (int q = 0; q < 4; ++q) { - group_count[q] = 0; - lda[q] = 0; - ldb[q] = 0; - mb[q] = B.tileMb(irange[q][0]); - nb[q] = B.tileNb(jrange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (B.tileIsLocal(i, j) && - device == B.tileDevice(i, j) && - ( ( lower && i > j) || - (! lower && i < j) ) ) - { - a_array_host[batch_count] = A(i, j, device).data(); - b_array_host[batch_count] = B(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ldb[q] = B(i, j, device).stride(); - ++group_count[q]; + struct Params { + int64_t count, mb, nb, lda, ldb; + bool is_diagonal; + }; + std::vector group_params; + // Build batch groups for off-diagonal tiles, + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + Params group = { 0, -1, -1, -1, -1, false }; + if (A.uplo() == Uplo::Lower) { + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = std::max(irange[ ii ], j+1); i < irange[ ii+1 ]; ++i) { + if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { + auto Aij = A( i, j, device ); + a_array_host[ batch_count ] = Aij.data(); + auto Bij = B( i, j, device ); + b_array_host[ batch_count ] = Bij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.lda = Aij.stride(); + group.ldb = Bij.stride(); + } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.lda == Aij.stride() ); + assert( group.ldb == Bij.stride() ); + } + ++group.count; ++batch_count; } - } + }} // for j,i } - } - // diagonal blocks - for (int q = 4; q < 6; ++q) { - group_count[q] = 0; - lda[q] = 0; - ldb[q] = 0; - mb[q] = B.tileMb(irange[q][0]); - nb[q] = B.tileNb(jrange[q][0]); - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (B.tileIsLocal(j, j) && - device == B.tileDevice(j, j) ) - { - a_array_host[batch_count] = A(j, j, device).data(); - b_array_host[batch_count] = B(j, j, device).data(); - lda[q] = A(j, j, device).stride(); - ldb[q] = B(j, j, device).stride(); - ++group_count[q]; + else { // A.uplo() == Uplo::Upper + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ] && i < j; ++i) { + if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { + auto Aij = A( i, j, device ); + a_array_host[ batch_count ] = Aij.data(); + auto Bij = B( i, j, device ); + b_array_host[ batch_count ] = Bij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.lda = Aij.stride(); + 
group.ldb = Bij.stride(); + } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.lda == Aij.stride() ); + assert( group.ldb == Bij.stride() ); + } + ++group.count; + ++batch_count; + } + }} // for j,i + } + if (group.count > 0) { + group_params.push_back( group ); + } + }} // for jj,ii + + // Build batch groups for diagonal tiles, + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + Params group = { 0, -1, -1, -1, -1, true }; + int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); + int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); + for (int64_t ij = ijstart; ij < ijend; ++ij) { + if (A.tileIsLocal( ij, ij ) && device == A.tileDevice( ij, ij )) { + auto Aij = A( ij, ij, device ); + a_array_host[ batch_count ] = Aij.data(); + auto Bij = B( ij, ij, device ); + b_array_host[ batch_count ] = Bij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.lda = Aij.stride(); + group.ldb = Bij.stride(); + } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.lda == Aij.stride() ); + assert( group.ldb == Bij.stride() ); + } + ++group.count; ++batch_count; } + } // for ij + if (group.count > 0) { + group_params.push_back( group ); } - } + }} // for jj,ii // Usually the output matrix (B) provides all the batch arrays. // Here we are using A, because of the differen types. @@ -236,26 +291,25 @@ void copy(internal::TargetType, blas::MemcpyKind::HostToDevice, *queue); - for (int q = 0; q < 4; ++q) { - if (group_count[q] > 0) { - device::gecopy(mb[q], nb[q], - a_array_dev, lda[q], - b_array_dev, ldb[q], - group_count[q], *queue); - a_array_dev += group_count[q]; - b_array_dev += group_count[q]; + for (size_t g = 0; g < group_params.size(); ++g) { + int64_t group_count = group_params[ g ].count; + if (group_params[ g ].is_diagonal) { + device::tzcopy( + B.uplo(), + group_params[ g ].mb, group_params[ g ].nb, + a_array_dev, group_params[ g ].lda, + b_array_dev, group_params[ g ].ldb, + group_count, *queue); } - } - for (int q = 4; q < 6; ++q) { - if (group_count[q] > 0) { - device::tzcopy(B.uplo(), - mb[q], nb[q], - a_array_dev, lda[q], - b_array_dev, ldb[q], - group_count[q], *queue); - a_array_dev += group_count[q]; - b_array_dev += group_count[q]; + else { + device::gecopy( + group_params[ g ].mb, group_params[ g ].nb, + a_array_dev, group_params[ g ].lda, + b_array_dev, group_params[ g ].ldb, + group_count, *queue); } + a_array_dev += group_count; + b_array_dev += group_count; } queue->sync(); From b96b3902681b5a1442d7d967e5ff07a8c97a8508 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 19 Oct 2023 13:39:25 -0400 Subject: [PATCH 04/35] Add regions to device scale_row_col --- src/internal/internal_gescale_row_col.cc | 131 ++++++++++++++--------- 1 file changed, 83 insertions(+), 48 deletions(-) diff --git a/src/internal/internal_gescale_row_col.cc b/src/internal/internal_gescale_row_col.cc index df3ec708f..3af2dbeee 100644 --- a/src/internal/internal_gescale_row_col.cc +++ b/src/internal/internal_gescale_row_col.cc @@ -67,26 +67,47 @@ void scale_row_col( { using ij_tuple = typename BaseMatrix::ij_tuple; - // Define index ranges for regions of matrix. - // Tiles in each region are all the same size. 
- int64_t irange[4][2] = { - { 0, A.mt()-1 }, - { A.mt()-1, A.mt() }, - { 0, A.mt()-1 }, - { A.mt()-1, A.mt() } - }; - int64_t jrange[4][2] = { - { 0, A.nt()-1 }, - { 0, A.nt()-1 }, - { A.nt()-1, A.nt() }, - { A.nt()-1, A.nt() } - }; + int64_t mt = A.mt(); + int64_t nt = A.nt(); + + // Find ranges of matching mb's. + std::vector< int64_t > irange, range_ioffset; + { + int64_t last_mb = -1; + int64_t ioffset = 0; + for (int64_t i = 0; i < mt; ++i) { + int64_t mb = A.tileMb( i ); + if (mb != last_mb) { + last_mb = mb; + irange.push_back( i ); + range_ioffset.push_back( ioffset ); + ioffset += mb; + } + } + irange.push_back( mt ); + } + + // Find ranges of matching nb's. + std::vector< int64_t > jrange, range_joffset; + { + int last_nb = -1; + int64_t joffset = 0; + for (int64_t j = 0; j < nt; ++j) { + int64_t nb = A.tileNb( j ); + if (nb != last_nb) { + last_nb = nb; + jrange.push_back( j ); + range_joffset.push_back( joffset ); + joffset += nb; + } + } + jrange.push_back( nt ); + } #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { - #pragma omp task slate_omp_default_none \ - shared( R, C, A ) \ - firstprivate( equed, device, irange, jrange ) + #pragma omp task shared( R, C, A, irange, jrange, range_ioffset, range_joffset ) \ + firstprivate( equed, device ) { bool want_row = equed == Equed::Both || equed == Equed::Row; bool want_col = equed == Equed::Both || equed == Equed::Col; @@ -126,37 +147,50 @@ void scale_row_col( } A.tileGetForWriting( A_tiles_set, device, layout ); - scalar_t** a_array_host = A.array_host( device ); + scalar_t** a_array_host = A.array_host( device, queue_index ); int64_t batch_count = 0; - int64_t mb[4], nb[4], lda[4], group_count[4]; - for (int q = 0; q < 4; ++q) { - group_count[q] = 0; - lda[q] = 0; - mb[q] = A.tileMb( irange[q][0] ); - nb[q] = A.tileNb( jrange[q][0] ); - int ii = A.tileMb( 0 ) * irange[q][0]; - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - int jj = A.tileNb( 0 ) * jrange[q][0]; - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { + struct Params { + int64_t count, mb, nb, lda; + }; + std::vector group_params; + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + Params group = { 0, -1, -1, -1 }; + int64_t joffset = range_joffset[ jj ]; + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + int64_t ioffset = range_ioffset[ ii ]; + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { auto Aij = A( i, j, device ); if (want_row) - r_array_host[ batch_count ] = &dR[ ii ]; + r_array_host[ batch_count ] = &dR[ ioffset ]; if (want_col) - c_array_host[ batch_count ] = &dC[ jj ]; + c_array_host[ batch_count ] = &dC[ joffset ]; a_array_host[ batch_count ] = Aij.data(); - lda[q] = Aij.stride(); - ++group_count[q]; + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.lda = Aij.stride(); + } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.lda == Aij.stride() ); + } + ++group.count; ++batch_count; } - jj += A.tileNb( j ); - } - ii += A.tileMb( i ); + ioffset += A.tileMb( i ); + } // for i + joffset += A.tileNb( j ); + } // for j + if (group.count > 0) { + group_params.push_back( group ); } - } + }} // for jj, ii - scalar_t** a_array_dev = A.array_device( device ); + scalar_t** a_array_dev = A.array_device( device, queue_index ); blas::device_memcpy< 
scalar_t* >( &a_array_dev[ 0 ], &a_array_host[ 0 ], batch_count, *queue); @@ -177,16 +211,17 @@ void scale_row_col( scalar_t2** r_array_data = r_array_dev.data(); scalar_t2** c_array_data = c_array_dev.data(); - for (int q = 0; q < 4; ++q) { - if (group_count[q] > 0) { - device::gescale_row_col_batch( - equed, mb[q], nb[q], - r_array_data, c_array_data, a_array_dev, lda[q], - group_count[q], *queue); - r_array_data += group_count[q]; - c_array_data += group_count[q]; - a_array_dev += group_count[q]; - } + for (size_t g = 0; g < group_params.size(); ++g) { + int64_t group_count = group_params[ g ].count; + device::gescale_row_col_batch( + equed, + group_params[ g ].mb, group_params[ g ].nb, + r_array_data, c_array_data, + a_array_dev, group_params[ g ].lda, + group_count, *queue); + r_array_data += group_count; + c_array_data += group_count; + a_array_dev += group_count; } // Clear the DevVectors, freeing device memory From 8adc605675d0122eaa8ca754118943ed05cbb279 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 19 Oct 2023 15:17:13 -0400 Subject: [PATCH 05/35] Add regions to device add --- src/internal/internal_geadd.cc | 111 ++++++++++------ src/internal/internal_tzadd.cc | 231 +++++++++++++++++++-------------- 2 files changed, 202 insertions(+), 140 deletions(-) diff --git a/src/internal/internal_geadd.cc b/src/internal/internal_geadd.cc index 0db1b3a10..e08ed5ee5 100644 --- a/src/internal/internal_geadd.cc +++ b/src/internal/internal_geadd.cc @@ -125,26 +125,37 @@ void add(internal::TargetType, bool call_tile_tick = tile_release_strategy == TileReleaseStrategy::Internal || tile_release_strategy == TileReleaseStrategy::All; - // Define index ranges for regions of matrix. - // Tiles in each region are all the same size. - int64_t irange[4][2] = { - { 0, B.mt()-1 }, - { B.mt()-1, B.mt() }, - { 0, B.mt()-1 }, - { B.mt()-1, B.mt() } - }; - int64_t jrange[4][2] = { - { 0, B.nt()-1 }, - { 0, B.nt()-1 }, - { B.nt()-1, B.nt() }, - { B.nt()-1, B.nt() } - }; + int64_t mt = A.mt(); + int64_t nt = A.nt(); + + // Find ranges of matching mb's. + std::vector< int64_t > irange; + int64_t last_mb = -1; + for (int64_t i = 0; i < mt; ++i) { + int64_t mb = A.tileMb( i ); + if (mb != last_mb) { + last_mb = mb; + irange.push_back( i ); + } + } + irange.push_back( mt ); + + // Find ranges of matching nb's. 
+ std::vector< int64_t > jrange; + int last_nb = -1; + for (int64_t j = 0; j < nt; ++j) { + int64_t nb = A.tileNb( j ); + if (nb != last_nb) { + last_nb = nb; + jrange.push_back( j ); + } + } + jrange.push_back( nt ); #pragma omp taskgroup for (int device = 0; device < B.num_devices(); ++device) { - #pragma omp task shared(A, B) \ - firstprivate( device, irange, jrange, queue_index, beta, alpha, call_tile_tick) \ - priority(priority) + #pragma omp task priority( priority ) shared( A, B, irange, jrange ) \ + firstprivate( device, queue_index, beta, alpha, call_tile_tick ) { // temporarily, convert both into same layout // todo: this is in-efficient, because both matrices may have same layout already @@ -182,26 +193,40 @@ void add(internal::TargetType, scalar_t** b_array_host = a_array_host + batch_size; int64_t batch_count = 0; - int64_t mb[4], nb[4], lda[4], ldb[4], group_count[4]; - for (int q = 0; q < 4; ++q) { - group_count[q] = 0; - lda[q] = 0; - ldb[q] = 0; - mb[q] = B.tileMb(irange[q][0]); - nb[q] = B.tileNb(jrange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (B.tileIsLocal(i, j) && device == B.tileDevice(i, j)) { - a_array_host[batch_count] = A(i, j, device).data(); - b_array_host[batch_count] = B(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ldb[q] = B(i, j, device).stride(); - ++group_count[q]; - ++batch_count; + struct Params { + int64_t count, mb, nb, lda, ldb; + }; + std::vector group_params; + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + Params group = { 0, -1, -1, -1, -1 }; + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { + auto Aij = A( i, j, device ); + a_array_host[ batch_count ] = Aij.data(); + auto Bij = B( i, j, device ); + b_array_host[ batch_count ] = Bij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.lda = Aij.stride(); + group.ldb = Bij.stride(); + } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.lda == Aij.stride() ); + assert( group.ldb == Bij.stride() ); } + ++group.count; + ++batch_count; } + }} // for j, i + if (group.count > 0) { + group_params.push_back( group ); } - } + }} // for jj, ii slate_assert(batch_count == batch_size); scalar_t** a_array_dev = B.array_device(device, queue_index); @@ -214,15 +239,15 @@ void add(internal::TargetType, blas::MemcpyKind::HostToDevice, *queue); - for (int q = 0; q < 4; ++q) { - if (group_count[q] > 0) { - device::batch::geadd(mb[q], nb[q], - alpha, a_array_dev, lda[q], - beta, b_array_dev, ldb[q], - group_count[q], *queue); - a_array_dev += group_count[q]; - b_array_dev += group_count[q]; - } + for (size_t g = 0; g < group_params.size(); ++g) { + int64_t group_count = group_params[ g ].count; + device::batch::geadd( + group_params[ g ].mb, group_params[ g ].nb, + alpha, a_array_dev, group_params[ g ].lda, + beta, b_array_dev, group_params[ g ].ldb, + group_count, *queue); + a_array_dev += group_count; + b_array_dev += group_count; } queue->sync(); diff --git a/src/internal/internal_tzadd.cc b/src/internal/internal_tzadd.cc index e567d4083..e976b9fa9 100644 --- a/src/internal/internal_tzadd.cc +++ b/src/internal/internal_tzadd.cc @@ -141,25 +141,37 @@ void add(internal::TargetType, bool call_tile_tick = tile_release_strategy == 
TileReleaseStrategy::Internal || tile_release_strategy == TileReleaseStrategy::All; - // Define index ranges for regions of matrix. - // Tiles in each region are all the same size. - int64_t irange[4][2] = { - { 0, B.mt()-1 }, - { B.mt()-1, B.mt() }, - { 0, B.mt()-1 }, - { B.mt()-1, B.mt() } - }; - int64_t jrange[4][2] = { - { 0, B.nt()-1 }, - { 0, B.nt()-1 }, - { B.nt()-1, B.nt() }, - { B.nt()-1, B.nt() } - }; + int64_t mt = A.mt(); + int64_t nt = A.nt(); + + // Find ranges of matching mb's. + std::vector< int64_t > irange; + int64_t last_mb = -1; + for (int64_t i = 0; i < mt; ++i) { + int64_t mb = A.tileMb( i ); + if (mb != last_mb) { + last_mb = mb; + irange.push_back( i ); + } + } + irange.push_back( mt ); + + // Find ranges of matching nb's. + std::vector< int64_t > jrange; + int last_nb = -1; + for (int64_t j = 0; j < nt; ++j) { + int64_t nb = A.tileNb( j ); + if (nb != last_nb) { + last_nb = nb; + jrange.push_back( j ); + } + } + jrange.push_back( nt ); #pragma omp taskgroup for (int device = 0; device < B.num_devices(); ++device) { - #pragma omp task shared(A, B) priority(priority) \ - firstprivate(device, irange, jrange, queue_index, alpha, beta) + #pragma omp task priority( priority ) shared( A, B, irange, jrange ) \ + firstprivate(device, queue_index, alpha, beta) { // temporarily, convert both into same layout // todo: this is in-efficient, because both matrices may have same layout already @@ -204,85 +216,110 @@ void add(internal::TargetType, } int64_t batch_size = A_tiles_set.size(); - scalar_t** a_array_host = B.array_host(device); + scalar_t** a_array_host = B.array_host( device, queue_index ); scalar_t** b_array_host = a_array_host + batch_size; + // Build batch groups int64_t batch_count = 0; - int64_t mb[8], nb[8], lda[8], ldb[8], group_count[8]; - for (int q = 0; q < 4; ++q) { - group_count[q] = 0; - lda[q] = 0; - ldb[q] = 0; - mb[q] = B.tileMb(irange[q][0]); - nb[q] = B.tileNb(jrange[q][0]); - if (B.uplo() == Uplo::Lower) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - for (int64_t i = std::max(j, irange[q][0]); i < irange[q][1]; ++i) { - if (i != j && B.tileIsLocal(i, j) && device == B.tileDevice(i, j)) { - a_array_host[batch_count] = A(i, j, device).data(); - b_array_host[batch_count] = B(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ldb[q] = B(i, j, device).stride(); - ++group_count[q]; - ++batch_count; + struct Params { + int64_t count, mb, nb, lda, ldb; + bool is_diagonal; + }; + std::vector group_params; + // Build batch groups for off-diagonal tiles, + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + Params group = { 0, -1, -1, -1, -1, false }; + if (A.uplo() == Uplo::Lower) { + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = std::max(irange[ ii ], j+1); i < irange[ ii+1 ]; ++i) { + if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { + auto Aij = A( i, j, device ); + a_array_host[ batch_count ] = Aij.data(); + auto Bij = B( i, j, device ); + b_array_host[ batch_count ] = Bij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.lda = Aij.stride(); + group.ldb = Bij.stride(); } - } - } - } - else { // upper - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - for (int64_t i = irange[q][0]; i < irange[q][1] && i <= j; ++i) { - if (i != j && B.tileIsLocal(i, j) && device == B.tileDevice(i, j)) { - a_array_host[batch_count] = A(i, j, device).data(); - b_array_host[batch_count] = B(i, j, device).data(); - 
lda[q] = A(i, j, device).stride(); - ldb[q] = B(i, j, device).stride(); - ++group_count[q]; - ++batch_count; + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.lda == Aij.stride() ); + assert( group.ldb == Bij.stride() ); } + ++group.count; + ++batch_count; } - } + }} // for j,i } - } - for (int q = 4; q < 8; ++q) { - group_count[q] = 0; - lda[q] = 0; - ldb[q] = 0; - mb[q] = B.tileMb(irange[q-4][0]); - nb[q] = B.tileNb(jrange[q-4][0]); - if (B.uplo() == Uplo::Lower) { - for (int64_t j = jrange[q-4][0]; j < jrange[q-4][1]; ++j) { - for (int64_t i = std::max(j, irange[q-4][0]); i < irange[q-4][1]; ++i) { - if (i == j && B.tileIsLocal(i, j) && device == B.tileDevice(i, j)) { - a_array_host[batch_count] = A(i, j, device).data(); - b_array_host[batch_count] = B(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ldb[q] = B(i, j, device).stride(); - ++group_count[q]; - ++batch_count; + else { // A.uplo() == Uplo::Upper + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ] && i < j; ++i) { + if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { + auto Aij = A( i, j, device ); + a_array_host[ batch_count ] = Aij.data(); + auto Bij = B( i, j, device ); + b_array_host[ batch_count ] = Bij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.lda = Aij.stride(); + group.ldb = Bij.stride(); } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.lda == Aij.stride() ); + assert( group.ldb == Bij.stride() ); + } + ++group.count; + ++batch_count; } - } + }} // for j,i } - else { //upper - for (int64_t j = jrange[q-4][0]; j < jrange[q-4][1]; ++j) { - for (int64_t i = irange[q-4][0]; i < irange[q-4][1] && i <= j; ++i) { - if (i == j && B.tileIsLocal(i, j) && device == B.tileDevice(i, j)) { - a_array_host[batch_count] = A(i, j, device).data(); - b_array_host[batch_count] = B(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ldb[q] = B(i, j, device).stride(); - ++group_count[q]; - ++batch_count; - } + if (group.count > 0) { + group_params.push_back( group ); + } + }} // for jj,ii + + // Build batch groups for diagonal tiles, + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + Params group = { 0, -1, -1, -1, -1, true }; + int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); + int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); + for (int64_t ij = ijstart; ij < ijend; ++ij) { + if (A.tileIsLocal( ij, ij ) && device == A.tileDevice( ij, ij )) { + auto Aij = A( ij, ij, device ); + a_array_host[ batch_count ] = Aij.data(); + auto Bij = B( ij, ij, device ); + b_array_host[ batch_count ] = Bij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.lda = Aij.stride(); + group.ldb = Bij.stride(); } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.lda == Aij.stride() ); + assert( group.ldb == Bij.stride() ); + } + ++group.count; + ++batch_count; } + } // for ij + if (group.count > 0) { + group_params.push_back( group ); } - } - + }} // for jj,ii slate_assert(batch_count == batch_size); - scalar_t** a_array_dev = B.array_device(device); + scalar_t** a_array_dev = B.array_device( device, queue_index ); scalar_t** b_array_dev = a_array_dev + batch_size; blas::Queue* queue = A.compute_queue(device, queue_index); @@ -292,25 +329,25 @@ void add(internal::TargetType, 
blas::MemcpyKind::HostToDevice, *queue); - for (int q = 0; q < 4; ++q) { - if (group_count[q] > 0) { - device::batch::geadd(mb[q], nb[q], - alpha, a_array_dev, lda[q], - beta, b_array_dev, ldb[q], - group_count[q], *queue); - a_array_dev += group_count[q]; - b_array_dev += group_count[q]; + for (size_t g = 0; g < group_params.size(); ++g) { + int64_t group_count = group_params[ g ].count; + if (group_params[ g ].is_diagonal) { + device::tzadd( + B.uplo(), + group_params[ g ].mb, group_params[ g ].nb, + alpha, a_array_dev, group_params[ g ].lda, + beta, b_array_dev, group_params[ g ].ldb, + group_count, *queue); } - } - for (int q = 4; q < 8; ++q) { - if (group_count[q] > 0) { - device::tzadd(B.uplo(), mb[q], nb[q], - alpha, a_array_dev, lda[q], - beta, b_array_dev, ldb[q], - group_count[q], *queue); - a_array_dev += group_count[q]; - b_array_dev += group_count[q]; + else { + device::batch::geadd( + group_params[ g ].mb, group_params[ g ].nb, + alpha, a_array_dev, group_params[ g ].lda, + beta, b_array_dev, group_params[ g ].ldb, + group_count, *queue); } + a_array_dev += group_count; + b_array_dev += group_count; } queue->sync(); From 0a85dacd2207780137b81d07a81a3f9222514f67 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Fri, 20 Oct 2023 09:46:40 -0400 Subject: [PATCH 06/35] Refactor out logic to compute device ranges --- src/internal/internal_geadd.cc | 30 +++-------------- src/internal/internal_gecopy.cc | 30 +++-------------- src/internal/internal_gescale.cc | 30 +++-------------- src/internal/internal_gescale_row_col.cc | 42 ++++++++---------------- src/internal/internal_geset.cc | 30 +++-------------- src/internal/internal_tzadd.cc | 30 +++-------------- src/internal/internal_tzcopy.cc | 30 +++-------------- src/internal/internal_tzscale.cc | 30 +++-------------- src/internal/internal_tzset.cc | 30 +++-------------- src/internal/internal_util.hh | 30 +++++++++++++++++ 10 files changed, 75 insertions(+), 237 deletions(-) diff --git a/src/internal/internal_geadd.cc b/src/internal/internal_geadd.cc index e08ed5ee5..7dda6ee7a 100644 --- a/src/internal/internal_geadd.cc +++ b/src/internal/internal_geadd.cc @@ -10,6 +10,7 @@ #include "slate/Matrix.hh" #include "internal/Tile_lapack.hh" #include "slate/types.hh" +#include "internal/internal_util.hh" namespace slate { namespace internal { @@ -125,32 +126,9 @@ void add(internal::TargetType, bool call_tile_tick = tile_release_strategy == TileReleaseStrategy::Internal || tile_release_strategy == TileReleaseStrategy::All; - int64_t mt = A.mt(); - int64_t nt = A.nt(); - - // Find ranges of matching mb's. - std::vector< int64_t > irange; - int64_t last_mb = -1; - for (int64_t i = 0; i < mt; ++i) { - int64_t mb = A.tileMb( i ); - if (mb != last_mb) { - last_mb = mb; - irange.push_back( i ); - } - } - irange.push_back( mt ); - - // Find ranges of matching nb's. - std::vector< int64_t > jrange; - int last_nb = -1; - for (int64_t j = 0; j < nt; ++j) { - int64_t nb = A.tileNb( j ); - if (nb != last_nb) { - last_nb = nb; - jrange.push_back( j ); - } - } - jrange.push_back( nt ); + // Find ranges of matching mb's and ranges of matching nb's. 
+ std::vector< int64_t > irange = device_regions_range( true, A ); + std::vector< int64_t > jrange = device_regions_range( false, A ); #pragma omp taskgroup for (int device = 0; device < B.num_devices(); ++device) { diff --git a/src/internal/internal_gecopy.cc b/src/internal/internal_gecopy.cc index c11a62187..d9d30165f 100644 --- a/src/internal/internal_gecopy.cc +++ b/src/internal/internal_gecopy.cc @@ -11,6 +11,7 @@ #include "slate/Tile_blas.hh" #include "slate/Tile_aux.hh" #include "slate/types.hh" +#include "internal/internal_util.hh" namespace slate { @@ -145,32 +146,9 @@ void copy(internal::TargetType, bool call_tile_tick = tile_release_strategy == TileReleaseStrategy::Internal || tile_release_strategy == TileReleaseStrategy::All; - int64_t mt = A.mt(); - int64_t nt = A.nt(); - - // Find ranges of matching mb's. - std::vector< int64_t > irange; - int64_t last_mb = -1; - for (int64_t i = 0; i < mt; ++i) { - int64_t mb = A.tileMb( i ); - if (mb != last_mb) { - last_mb = mb; - irange.push_back( i ); - } - } - irange.push_back( mt ); - - // Find ranges of matching nb's. - std::vector< int64_t > jrange; - int last_nb = -1; - for (int64_t j = 0; j < nt; ++j) { - int64_t nb = A.tileNb( j ); - if (nb != last_nb) { - last_nb = nb; - jrange.push_back( j ); - } - } - jrange.push_back( nt ); + // Find ranges of matching mb's and ranges of matching nb's. + std::vector< int64_t > irange = device_regions_range( true, A ); + std::vector< int64_t > jrange = device_regions_range( false, A ); #pragma omp taskgroup for (int device = 0; device < B.num_devices(); ++device) { diff --git a/src/internal/internal_gescale.cc b/src/internal/internal_gescale.cc index 123c61f57..73fd59710 100644 --- a/src/internal/internal_gescale.cc +++ b/src/internal/internal_gescale.cc @@ -10,6 +10,7 @@ #include "slate/Matrix.hh" #include "internal/Tile_lapack.hh" #include "slate/types.hh" +#include "internal/internal_util.hh" namespace slate { namespace internal { @@ -88,32 +89,9 @@ void scale(internal::TargetType, { using ij_tuple = typename BaseMatrix::ij_tuple; - int64_t mt = A.mt(); - int64_t nt = A.nt(); - - // Find ranges of matching mb's. - std::vector< int64_t > irange; - int64_t last_mb = -1; - for (int64_t i = 0; i < mt; ++i) { - int64_t mb = A.tileMb( i ); - if (mb != last_mb) { - last_mb = mb; - irange.push_back( i ); - } - } - irange.push_back( mt ); - - // Find ranges of matching nb's. - std::vector< int64_t > jrange; - int last_nb = -1; - for (int64_t j = 0; j < nt; ++j) { - int64_t nb = A.tileNb( j ); - if (nb != last_nb) { - last_nb = nb; - jrange.push_back( j ); - } - } - jrange.push_back( nt ); + // Find ranges of matching mb's and ranges of matching nb's. + std::vector< int64_t > irange = device_regions_range( true, A ); + std::vector< int64_t > jrange = device_regions_range( false, A ); #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { diff --git a/src/internal/internal_gescale_row_col.cc b/src/internal/internal_gescale_row_col.cc index 3af2dbeee..64c0c6270 100644 --- a/src/internal/internal_gescale_row_col.cc +++ b/src/internal/internal_gescale_row_col.cc @@ -11,6 +11,7 @@ #include "slate/Matrix.hh" #include "slate/types.hh" #include "tile/scale_row_col.hh" +#include "internal/internal_util.hh" namespace slate { @@ -67,41 +68,24 @@ void scale_row_col( { using ij_tuple = typename BaseMatrix::ij_tuple; - int64_t mt = A.mt(); - int64_t nt = A.nt(); + // Find ranges of matching mb's and ranges of matching nb's. 
+ std::vector< int64_t > irange = device_regions_range( true, A ); + std::vector< int64_t > jrange = device_regions_range( false, A ); - // Find ranges of matching mb's. - std::vector< int64_t > irange, range_ioffset; + // Compute global offsets of each block + std::vector< int64_t > range_ioffset (irange.size()-1); + std::vector< int64_t > range_joffset (jrange.size()-1); { - int64_t last_mb = -1; int64_t ioffset = 0; - for (int64_t i = 0; i < mt; ++i) { - int64_t mb = A.tileMb( i ); - if (mb != last_mb) { - last_mb = mb; - irange.push_back( i ); - range_ioffset.push_back( ioffset ); - ioffset += mb; - } + for (size_t i = 0; i < range_ioffset.size(); ++i) { + range_ioffset[ i ] = ioffset; + ioffset += A.tileMb( irange[ i ] ) * (irange[ i+1 ] - irange[ i ]); } - irange.push_back( mt ); - } - - // Find ranges of matching nb's. - std::vector< int64_t > jrange, range_joffset; - { - int last_nb = -1; int64_t joffset = 0; - for (int64_t j = 0; j < nt; ++j) { - int64_t nb = A.tileNb( j ); - if (nb != last_nb) { - last_nb = nb; - jrange.push_back( j ); - range_joffset.push_back( joffset ); - joffset += nb; - } + for (size_t j = 0; j < range_joffset.size(); ++j) { + range_joffset[ j ] = joffset; + joffset += A.tileNb( jrange[ j ] ) * (jrange[ j+1 ] - jrange[ j ]); } - jrange.push_back( nt ); } #pragma omp taskgroup diff --git a/src/internal/internal_geset.cc b/src/internal/internal_geset.cc index 2513de6f9..2cae91098 100644 --- a/src/internal/internal_geset.cc +++ b/src/internal/internal_geset.cc @@ -10,6 +10,7 @@ #include "slate/Matrix.hh" #include "internal/Tile_lapack.hh" #include "slate/types.hh" +#include "internal/internal_util.hh" namespace slate { namespace internal { @@ -92,32 +93,9 @@ void set(internal::TargetType, { using ij_tuple = typename BaseMatrix::ij_tuple; - int64_t mt = A.mt(); - int64_t nt = A.nt(); - - // Find ranges of matching mb's. - std::vector< int64_t > irange; - int64_t last_mb = -1; - for (int64_t i = 0; i < mt; ++i) { - int64_t mb = A.tileMb( i ); - if (mb != last_mb) { - last_mb = mb; - irange.push_back( i ); - } - } - irange.push_back( mt ); - - // Find ranges of matching nb's. - std::vector< int64_t > jrange; - int last_nb = -1; - for (int64_t j = 0; j < nt; ++j) { - int64_t nb = A.tileNb( j ); - if (nb != last_nb) { - last_nb = nb; - jrange.push_back( j ); - } - } - jrange.push_back( nt ); + // Find ranges of matching mb's and ranges of matching nb's. + std::vector< int64_t > irange = device_regions_range( true, A ); + std::vector< int64_t > jrange = device_regions_range( false, A ); #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { diff --git a/src/internal/internal_tzadd.cc b/src/internal/internal_tzadd.cc index e976b9fa9..a20fa88a7 100644 --- a/src/internal/internal_tzadd.cc +++ b/src/internal/internal_tzadd.cc @@ -10,6 +10,7 @@ #include "slate/Matrix.hh" #include "internal/Tile_lapack.hh" #include "slate/types.hh" +#include "internal/internal_util.hh" namespace slate { namespace internal { @@ -141,32 +142,9 @@ void add(internal::TargetType, bool call_tile_tick = tile_release_strategy == TileReleaseStrategy::Internal || tile_release_strategy == TileReleaseStrategy::All; - int64_t mt = A.mt(); - int64_t nt = A.nt(); - - // Find ranges of matching mb's. - std::vector< int64_t > irange; - int64_t last_mb = -1; - for (int64_t i = 0; i < mt; ++i) { - int64_t mb = A.tileMb( i ); - if (mb != last_mb) { - last_mb = mb; - irange.push_back( i ); - } - } - irange.push_back( mt ); - - // Find ranges of matching nb's. 
- std::vector< int64_t > jrange; - int last_nb = -1; - for (int64_t j = 0; j < nt; ++j) { - int64_t nb = A.tileNb( j ); - if (nb != last_nb) { - last_nb = nb; - jrange.push_back( j ); - } - } - jrange.push_back( nt ); + // Find ranges of matching mb's and ranges of matching nb's. + std::vector< int64_t > irange = device_regions_range( true, A ); + std::vector< int64_t > jrange = device_regions_range( false, A ); #pragma omp taskgroup for (int device = 0; device < B.num_devices(); ++device) { diff --git a/src/internal/internal_tzcopy.cc b/src/internal/internal_tzcopy.cc index bde7de8d4..9069f42ae 100644 --- a/src/internal/internal_tzcopy.cc +++ b/src/internal/internal_tzcopy.cc @@ -11,6 +11,7 @@ #include "slate/Tile_blas.hh" #include "slate/Tile_aux.hh" #include "slate/types.hh" +#include "internal/internal_util.hh" namespace slate { namespace internal { @@ -112,32 +113,9 @@ void copy(internal::TargetType, slate_error_if(A.uplo() != B.uplo()); bool lower = (B.uplo() == Uplo::Lower); - int64_t mt = A.mt(); - int64_t nt = A.nt(); - - // Find ranges of matching mb's. - std::vector< int64_t > irange; - int64_t last_mb = -1; - for (int64_t i = 0; i < mt; ++i) { - int64_t mb = A.tileMb( i ); - if (mb != last_mb) { - last_mb = mb; - irange.push_back( i ); - } - } - irange.push_back( mt ); - - // Find ranges of matching nb's. - std::vector< int64_t > jrange; - int last_nb = -1; - for (int64_t j = 0; j < nt; ++j) { - int64_t nb = A.tileNb( j ); - if (nb != last_nb) { - last_nb = nb; - jrange.push_back( j ); - } - } - jrange.push_back( nt ); + // Find ranges of matching mb's and ranges of matching nb's. + std::vector< int64_t > irange = device_regions_range( true, A ); + std::vector< int64_t > jrange = device_regions_range( false, A ); #pragma omp taskgroup for (int device = 0; device < B.num_devices(); ++device) { diff --git a/src/internal/internal_tzscale.cc b/src/internal/internal_tzscale.cc index fb99497c6..eec8ac66e 100644 --- a/src/internal/internal_tzscale.cc +++ b/src/internal/internal_tzscale.cc @@ -10,6 +10,7 @@ #include "slate/Matrix.hh" #include "internal/Tile_lapack.hh" #include "slate/types.hh" +#include "internal/internal_util.hh" namespace slate { namespace internal { @@ -110,32 +111,9 @@ void scale(internal::TargetType, { using ij_tuple = typename BaseTrapezoidMatrix::ij_tuple; - int64_t mt = A.mt(); - int64_t nt = A.nt(); - - // Find ranges of matching mb's. - std::vector< int64_t > irange; - int64_t last_mb = -1; - for (int64_t i = 0; i < mt; ++i) { - int64_t mb = A.tileMb( i ); - if (mb != last_mb) { - last_mb = mb; - irange.push_back( i ); - } - } - irange.push_back( mt ); - - // Find ranges of matching nb's. - std::vector< int64_t > jrange; - int last_nb = -1; - for (int64_t j = 0; j < nt; ++j) { - int64_t nb = A.tileNb( j ); - if (nb != last_nb) { - last_nb = nb; - jrange.push_back( j ); - } - } - jrange.push_back( nt ); + // Find ranges of matching mb's and ranges of matching nb's. 
+ std::vector< int64_t > irange = device_regions_range( true, A ); + std::vector< int64_t > jrange = device_regions_range( false, A ); #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { diff --git a/src/internal/internal_tzset.cc b/src/internal/internal_tzset.cc index 1d17dc0c4..9dc12fa62 100644 --- a/src/internal/internal_tzset.cc +++ b/src/internal/internal_tzset.cc @@ -10,6 +10,7 @@ #include "slate/Matrix.hh" #include "internal/Tile_lapack.hh" #include "slate/types.hh" +#include "internal/internal_util.hh" namespace slate { namespace internal { @@ -120,32 +121,9 @@ void set( { using ij_tuple = typename BaseTrapezoidMatrix::ij_tuple; - int64_t mt = A.mt(); - int64_t nt = A.nt(); - - // Find ranges of matching mb's. - std::vector< int64_t > irange; - int64_t last_mb = -1; - for (int64_t i = 0; i < mt; ++i) { - int64_t mb = A.tileMb( i ); - if (mb != last_mb) { - last_mb = mb; - irange.push_back( i ); - } - } - irange.push_back( mt ); - - // Find ranges of matching nb's. - std::vector< int64_t > jrange; - int last_nb = -1; - for (int64_t j = 0; j < nt; ++j) { - int64_t nb = A.tileNb( j ); - if (nb != last_nb) { - last_nb = nb; - jrange.push_back( j ); - } - } - jrange.push_back( nt ); + // Find ranges of matching mb's and ranges of matching nb's. + std::vector< int64_t > irange = device_regions_range( true, A ); + std::vector< int64_t > jrange = device_regions_range( false, A ); #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { diff --git a/src/internal/internal_util.hh b/src/internal/internal_util.hh index 11b080067..3da8320cd 100644 --- a/src/internal/internal_util.hh +++ b/src/internal/internal_util.hh @@ -110,6 +110,36 @@ slate::Matrix alloc_basis(slate::BaseMatrix& A, int64_t n, } +// Utilities for device batch regions + +//------------------------------------------------------------------------------ +/// Computes the range of tiles with either the same mb or the same nb +/// +/// @param[in] want_rows +/// If true, compute the row-ranges. Else, compute the column-ranges. +/// +/// @param[in] A +/// The matrix to get tile sizes from +/// +/// @return The ranges of uniform tile sizes +/// +template +std::vector device_regions_range( bool want_rows, BaseMatrix& A ) +{ + int64_t kt = want_rows ? A.mt() : A.nt(); + + std::vector< int64_t > range; + int64_t last = -1; + for (int64_t k = 0; k < kt; ++k) { + int64_t kb = want_rows ? 
A.tileMb( k ) : A.tileNb( k ); + if (kb != last) { + last = kb; + range.push_back( k ); + } + } + range.push_back( kt ); + return range; +} } // namespace internal } // namespace slate From 7f156c216810de66fec5c0ad5069c6e4dbd7139a Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Fri, 20 Oct 2023 14:51:37 -0400 Subject: [PATCH 07/35] Refactor duplicate logic for constructing device regions --- src/internal/internal_geadd.cc | 53 +----- src/internal/internal_gescale.cc | 48 +---- src/internal/internal_geset.cc | 102 +++-------- src/internal/internal_tzadd.cc | 121 ++----------- src/internal/internal_tzscale.cc | 111 +----------- src/internal/internal_tzset.cc | 113 +----------- src/internal/internal_util.hh | 302 +++++++++++++++++++++++++++++++ 7 files changed, 373 insertions(+), 477 deletions(-) diff --git a/src/internal/internal_geadd.cc b/src/internal/internal_geadd.cc index 7dda6ee7a..dea47b805 100644 --- a/src/internal/internal_geadd.cc +++ b/src/internal/internal_geadd.cc @@ -126,13 +126,10 @@ void add(internal::TargetType, bool call_tile_tick = tile_release_strategy == TileReleaseStrategy::Internal || tile_release_strategy == TileReleaseStrategy::All; - // Find ranges of matching mb's and ranges of matching nb's. - std::vector< int64_t > irange = device_regions_range( true, A ); - std::vector< int64_t > jrange = device_regions_range( false, A ); - #pragma omp taskgroup for (int device = 0; device < B.num_devices(); ++device) { - #pragma omp task priority( priority ) shared( A, B, irange, jrange ) \ + #pragma omp task slate_omp_default_none priority( priority ) \ + shared( A, B ) \ firstprivate( device, queue_index, beta, alpha, call_tile_tick ) { // temporarily, convert both into same layout @@ -170,42 +167,10 @@ void add(internal::TargetType, scalar_t** a_array_host = B.array_host(device, queue_index); scalar_t** b_array_host = a_array_host + batch_size; - int64_t batch_count = 0; - struct Params { - int64_t count, mb, nb, lda, ldb; - }; - std::vector group_params; - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - Params group = { 0, -1, -1, -1, -1 }; - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { - if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { - auto Aij = A( i, j, device ); - a_array_host[ batch_count ] = Aij.data(); - auto Bij = B( i, j, device ); - b_array_host[ batch_count ] = Bij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - group.ldb = Bij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - assert( group.ldb == Bij.stride() ); - } - ++group.count; - ++batch_count; - } - }} // for j, i - if (group.count > 0) { - group_params.push_back( group ); - } - }} // for jj, ii - slate_assert(batch_count == batch_size); + auto group_params = device_regions_build( + {A, B}, + {a_array_host, b_array_host}, + device ); scalar_t** a_array_dev = B.array_device(device, queue_index); scalar_t** b_array_dev = a_array_dev + batch_size; @@ -213,7 +178,7 @@ void add(internal::TargetType, blas::Queue* queue = B.compute_queue(device, queue_index); blas::device_memcpy(a_array_dev, a_array_host, - batch_count*2, + batch_size*2, blas::MemcpyKind::HostToDevice, *queue); @@ -221,8 +186,8 @@ void add(internal::TargetType, int64_t group_count = group_params[ g ].count; device::batch::geadd( group_params[ g ].mb, 
group_params[ g ].nb, - alpha, a_array_dev, group_params[ g ].lda, - beta, b_array_dev, group_params[ g ].ldb, + alpha, a_array_dev, group_params[ g ].ld[0], + beta, b_array_dev, group_params[ g ].ld[1], group_count, *queue); a_array_dev += group_count; b_array_dev += group_count; diff --git a/src/internal/internal_gescale.cc b/src/internal/internal_gescale.cc index 73fd59710..f16f9a427 100644 --- a/src/internal/internal_gescale.cc +++ b/src/internal/internal_gescale.cc @@ -89,14 +89,10 @@ void scale(internal::TargetType, { using ij_tuple = typename BaseMatrix::ij_tuple; - // Find ranges of matching mb's and ranges of matching nb's. - std::vector< int64_t > irange = device_regions_range( true, A ); - std::vector< int64_t > jrange = device_regions_range( false, A ); - #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { - #pragma omp task priority( priority ) shared( A, irange, jrange ) \ - firstprivate( device, queue_index, denom, numer ) + #pragma omp task slate_omp_default_none priority( priority ) \ + shared( A ) firstprivate( device, queue_index, denom, numer ) { // temporarily, convert both into same layout // todo: this is in-efficient, because both matrices may have same layout already @@ -114,52 +110,26 @@ void scale(internal::TargetType, } A.tileGetForWriting( A_tiles_set, device, layout ); + int64_t batch_size = A_tiles_set.size(); scalar_t** a_array_host = A.array_host( device, queue_index ); - int64_t batch_count = 0; - struct Params { - int64_t count, mb, nb, lda; - }; - std::vector group_params; - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - Params group = { 0, -1, -1, -1 }; - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { - if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { - auto Aij = A( i, j, device ); - a_array_host[ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - } - ++group.count; - ++batch_count; - } - }} // for j, i - if (group.count > 0) { - group_params.push_back( group ); - } - }} // for jj, ii + auto group_params = device_regions_build( + {A}, + {a_array_host}, + device ); blas::Queue* queue = A.compute_queue( device, queue_index ); scalar_t** a_array_dev = A.array_device( device, queue_index ); blas::device_memcpy( - a_array_dev, a_array_host, batch_count, + a_array_dev, a_array_host, batch_size, blas::MemcpyKind::HostToDevice, *queue); for (size_t g = 0; g < group_params.size(); ++g) { int64_t group_count = group_params[ g ].count; device::batch::gescale( group_params[ g ].mb, group_params[ g ].nb, - numer, denom, a_array_dev, group_params[ g ].lda, + numer, denom, a_array_dev, group_params[ g ].ld[0], group_count, *queue); a_array_dev += group_count; } diff --git a/src/internal/internal_geset.cc b/src/internal/internal_geset.cc index 2cae91098..9cba4f683 100644 --- a/src/internal/internal_geset.cc +++ b/src/internal/internal_geset.cc @@ -99,8 +99,8 @@ void set(internal::TargetType, #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { - #pragma omp task priority( priority ) shared( A, irange, jrange ) \ - firstprivate( device, queue_index, offdiag_value, diag_value ) + #pragma omp task slate_omp_default_none priority( priority ) \ + shared( A ) firstprivate( 
device, queue_index, offdiag_value, diag_value ) { // Get local tiles for writing. // convert to column major layout to simplify lda's @@ -118,98 +118,40 @@ void set(internal::TargetType, } A.tileGetForWriting( A_tiles_set, device, layout ); + int64_t batch_size = A_tiles_set.size(); scalar_t** a_array_host = A.array_host( device, queue_index ); // If offdiag == diag value, lump diag tiles with offdiag tiles // in one batch. bool diag_same = offdiag_value == diag_value; - // Build batch groups. - int64_t batch_count = 0; - struct Params { - int64_t count, mb, nb, lda; - scalar_t diag_value; - }; - std::vector group_params; - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - Params group = { 0, -1, -1, -1, offdiag_value }; - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { - if ((diag_same || i != j) - && A.tileIsLocal( i, j ) - && device == A.tileDevice( i, j )) - { - auto Aij = A( i, j, device ); - a_array_host[ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - } - ++group.count; - ++batch_count; - } - }} // for j, i - if (group.count > 0) { - group_params.push_back( group ); - } - }} // for jj, ii - - // Build batch groups for diagonal tiles, - // when offdiag_value != diag_value. - if (! diag_same) { - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - Params group = { 0, -1, -1, -1, diag_value }; - // Diagonal tiles only in the intersection of irange and jrange - int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); - int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); - for (int64_t ij = ijstart; ij < ijend; ++ij) { - if (A.tileIsLocal( ij, ij ) - && device == A.tileDevice( ij, ij )) - { - auto Aij = A( ij, ij, device ); - a_array_host[ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - } - ++group.count; - ++batch_count; - } - } // for ij - if (group.count > 0) { - group_params.push_back( group ); - } - }} // for jj, ii - } - + auto group_params = device_regions_build( + {A}, + {a_array_host}, + device, diag_same ); blas::Queue* queue = A.compute_queue( device, queue_index ); scalar_t** a_array_dev = A.array_device( device, queue_index ); blas::device_memcpy( - a_array_dev, a_array_host, batch_count, + a_array_dev, a_array_host, batch_size, blas::MemcpyKind::HostToDevice, *queue); for (size_t g = 0; g < group_params.size(); ++g) { int64_t group_count = group_params[ g ].count; - device::batch::geset( - group_params[ g ].mb, group_params[ g ].nb, - offdiag_value, group_params[ g ].diag_value, - a_array_dev, group_params[ g ].lda, - group_count, *queue ); + if (group_params[ g ].is_diagonal) { + device::batch::geset( + group_params[ g ].mb, group_params[ g ].nb, + offdiag_value, diag_value, + a_array_dev, group_params[ g ].ld[0], + group_count, *queue ); + } + else { + device::batch::geset( + group_params[ g ].mb, group_params[ g ].nb, + offdiag_value, offdiag_value, + a_array_dev, group_params[ g ].ld[0], + group_count, *queue ); + } a_array_dev += group_count; } queue->sync(); diff --git 
a/src/internal/internal_tzadd.cc b/src/internal/internal_tzadd.cc index a20fa88a7..edbe1b00f 100644 --- a/src/internal/internal_tzadd.cc +++ b/src/internal/internal_tzadd.cc @@ -142,14 +142,10 @@ void add(internal::TargetType, bool call_tile_tick = tile_release_strategy == TileReleaseStrategy::Internal || tile_release_strategy == TileReleaseStrategy::All; - // Find ranges of matching mb's and ranges of matching nb's. - std::vector< int64_t > irange = device_regions_range( true, A ); - std::vector< int64_t > jrange = device_regions_range( false, A ); - #pragma omp taskgroup for (int device = 0; device < B.num_devices(); ++device) { - #pragma omp task priority( priority ) shared( A, B, irange, jrange ) \ - firstprivate(device, queue_index, alpha, beta) + #pragma omp task slate_omp_default_none priority( priority ) \ + shared( A, B ) firstprivate(device, queue_index, alpha, beta, call_tile_tick) { // temporarily, convert both into same layout // todo: this is in-efficient, because both matrices may have same layout already @@ -197,105 +193,10 @@ void add(internal::TargetType, scalar_t** a_array_host = B.array_host( device, queue_index ); scalar_t** b_array_host = a_array_host + batch_size; - // Build batch groups - int64_t batch_count = 0; - struct Params { - int64_t count, mb, nb, lda, ldb; - bool is_diagonal; - }; - std::vector group_params; - // Build batch groups for off-diagonal tiles, - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - Params group = { 0, -1, -1, -1, -1, false }; - if (A.uplo() == Uplo::Lower) { - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = std::max(irange[ ii ], j+1); i < irange[ ii+1 ]; ++i) { - if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { - auto Aij = A( i, j, device ); - a_array_host[ batch_count ] = Aij.data(); - auto Bij = B( i, j, device ); - b_array_host[ batch_count ] = Bij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - group.ldb = Bij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - assert( group.ldb == Bij.stride() ); - } - ++group.count; - ++batch_count; - } - }} // for j,i - } - else { // A.uplo() == Uplo::Upper - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = irange[ ii ]; i < irange[ ii+1 ] && i < j; ++i) { - if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { - auto Aij = A( i, j, device ); - a_array_host[ batch_count ] = Aij.data(); - auto Bij = B( i, j, device ); - b_array_host[ batch_count ] = Bij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - group.ldb = Bij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - assert( group.ldb == Bij.stride() ); - } - ++group.count; - ++batch_count; - } - }} // for j,i - } - if (group.count > 0) { - group_params.push_back( group ); - } - }} // for jj,ii - - // Build batch groups for diagonal tiles, - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - Params group = { 0, -1, -1, -1, -1, true }; - int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); - int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); - for (int64_t ij = ijstart; ij < ijend; ++ij) { - if (A.tileIsLocal( ij, ij ) && device == A.tileDevice( ij, ij )) { - auto Aij = 
A( ij, ij, device ); - a_array_host[ batch_count ] = Aij.data(); - auto Bij = B( ij, ij, device ); - b_array_host[ batch_count ] = Bij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - group.ldb = Bij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - assert( group.ldb == Bij.stride() ); - } - ++group.count; - ++batch_count; - } - } // for ij - if (group.count > 0) { - group_params.push_back( group ); - } - }} // for jj,ii - slate_assert(batch_count == batch_size); + auto group_params = device_regions_build<2, scalar_t>( + {A, B}, + {a_array_host, b_array_host}, + device ); scalar_t** a_array_dev = B.array_device( device, queue_index ); scalar_t** b_array_dev = a_array_dev + batch_size; @@ -303,7 +204,7 @@ void add(internal::TargetType, blas::Queue* queue = A.compute_queue(device, queue_index); blas::device_memcpy(a_array_dev, a_array_host, - batch_count*2, + batch_size*2, blas::MemcpyKind::HostToDevice, *queue); @@ -313,15 +214,15 @@ void add(internal::TargetType, device::tzadd( B.uplo(), group_params[ g ].mb, group_params[ g ].nb, - alpha, a_array_dev, group_params[ g ].lda, - beta, b_array_dev, group_params[ g ].ldb, + alpha, a_array_dev, group_params[ g ].ld[0], + beta, b_array_dev, group_params[ g ].ld[1], group_count, *queue); } else { device::batch::geadd( group_params[ g ].mb, group_params[ g ].nb, - alpha, a_array_dev, group_params[ g ].lda, - beta, b_array_dev, group_params[ g ].ldb, + alpha, a_array_dev, group_params[ g ].ld[0], + beta, b_array_dev, group_params[ g ].ld[1], group_count, *queue); } a_array_dev += group_count; diff --git a/src/internal/internal_tzscale.cc b/src/internal/internal_tzscale.cc index eec8ac66e..234e1ce09 100644 --- a/src/internal/internal_tzscale.cc +++ b/src/internal/internal_tzscale.cc @@ -111,14 +111,10 @@ void scale(internal::TargetType, { using ij_tuple = typename BaseTrapezoidMatrix::ij_tuple; - // Find ranges of matching mb's and ranges of matching nb's. 
- std::vector< int64_t > irange = device_regions_range( true, A ); - std::vector< int64_t > jrange = device_regions_range( false, A ); - #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { - #pragma omp task priority( priority ) shared( A, irange, jrange ) \ - firstprivate( device, queue_index, numer, denom ) + #pragma omp task slate_omp_default_none priority( priority ) \ + shared( A ) firstprivate( device, queue_index, numer, denom ) { // temporarily, convert both into same layout // todo: this is in-efficient, because both matrices may have same layout already @@ -147,106 +143,19 @@ void scale(internal::TargetType, } A.tileGetForWriting( A_tiles_set, device, layout ); + int64_t batch_size = A_tiles_set.size(); scalar_t** a_array_host = A.array_host( device, queue_index ); - // Build batch groups - int64_t batch_count = 0; - struct Params { - int64_t count, mb, nb, lda; - bool is_diagonal; - }; - std::vector group_params; - // Build batch groups for off-diagonal tiles, - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - Params group = { 0, -1, -1, -1, false }; - if (A.uplo() == Uplo::Lower) { - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = std::max(irange[ ii ], j+1); i < irange[ ii+1 ]; ++i) { - if (A.tileIsLocal(i, j) - && device == A.tileDevice(i, j)) { - - auto Aij = A( i, j, device ); - a_array_host[ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - } - ++group.count; - ++batch_count; - } - }} // for j,i - } - else { // A.uplo() == Uplo::Upper - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = irange[ ii ]; i < irange[ ii+1 ] && i < j; ++i) { - if (A.tileIsLocal(i, j) - && device == A.tileDevice(i, j)) { - - auto Aij = A( i, j, device ); - a_array_host[ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - } - ++group.count; - ++batch_count; - } - }} // for j,i - } - if (group.count > 0) { - group_params.push_back( group ); - } - }} // for jj,ii - - // Build batch groups for diagonal tiles, - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - Params group = { 0, -1, -1, -1, true }; - int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); - int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); - for (int64_t ij = ijstart; ij < ijend; ++ij) { - if (A.tileIsLocal( ij, ij ) - && device == A.tileDevice( ij, ij )) - { - auto Aij = A( ij, ij, device ); - a_array_host[ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - } - ++group.count; - ++batch_count; - } - } // for ij - if (group.count > 0) { - group_params.push_back( group ); - } - }} // for jj,ii + auto group_params = device_regions_build<1, scalar_t>( + {A}, + {a_array_host}, + device ); blas::Queue* queue = A.compute_queue( device, queue_index ); scalar_t** a_array_dev = A.array_device( device, queue_index ); blas::device_memcpy( - a_array_dev, 
a_array_host, batch_count, + a_array_dev, a_array_host, batch_size, blas::MemcpyKind::HostToDevice, *queue); for (size_t g = 0; g < group_params.size(); ++g) { @@ -256,13 +165,13 @@ void scale(internal::TargetType, device::batch::tzscale( A.uplo(), group_params[ g ].mb, group_params[ g ].nb, - numer, denom, a_array_dev, group_params[ g ].lda, + numer, denom, a_array_dev, group_params[ g ].ld[0], group_count, *queue); } else { device::batch::gescale( group_params[ g ].mb, group_params[ g ].nb, - numer, denom, a_array_dev, group_params[ g ].lda, + numer, denom, a_array_dev, group_params[ g ].ld[0], group_count, *queue); } a_array_dev += group_count; diff --git a/src/internal/internal_tzset.cc b/src/internal/internal_tzset.cc index 9dc12fa62..f80d0142d 100644 --- a/src/internal/internal_tzset.cc +++ b/src/internal/internal_tzset.cc @@ -121,14 +121,10 @@ void set( { using ij_tuple = typename BaseTrapezoidMatrix::ij_tuple; - // Find ranges of matching mb's and ranges of matching nb's. - std::vector< int64_t > irange = device_regions_range( true, A ); - std::vector< int64_t > jrange = device_regions_range( false, A ); - #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { - #pragma omp task priority( priority ) shared( A, irange, jrange ) \ - firstprivate( device, queue_index, offdiag_value, diag_value ) + #pragma omp task slate_omp_default_none priority( priority ) \ + shared( A ) firstprivate( device, queue_index, offdiag_value, diag_value ) { // temporarily, convert both into same layout // todo: this is in-efficient, because both matrices may have same layout already @@ -157,108 +153,19 @@ void set( } A.tileGetForWriting( A_tiles_set, device, layout ); + int64_t batch_size = A_tiles_set.size(); scalar_t** a_array_host = A.array_host( device ); scalar_t** a_array_dev = A.array_device( device ); - // Build batch groups - int64_t batch_count = 0; - struct Params { - int64_t count, mb, nb, lda; - bool is_diagonal; - }; - std::vector group_params; - // Build batch groups for off-diagonal tiles, - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - Params group = { 0, -1, -1, -1, false }; - if (A.uplo() == Uplo::Lower) { - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = std::max(irange[ ii ], j); i < irange[ ii+1 ]; ++i) { - if (i != j - && A.tileIsLocal(i, j) - && device == A.tileDevice(i, j)) { - - auto Aij = A( i, j, device ); - a_array_host[ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - } - ++group.count; - ++batch_count; - } - }} // for j,i - } - else { // A.uplo() == Uplo::Upper - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = irange[ ii ]; i < irange[ ii+1 ] && i <= j; ++i) { - if (i != j - && A.tileIsLocal(i, j) - && device == A.tileDevice(i, j)) { - - auto Aij = A( i, j, device ); - a_array_host[ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - } - ++group.count; - ++batch_count; - } - }} // for j,i - } - if (group.count > 0) { - group_params.push_back( group ); - } - }} // for jj,ii - - // Build batch groups for diagonal tiles, - for (size_t jj = 0; 
jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - Params group = { 0, -1, -1, -1, true }; - int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); - int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); - for (int64_t ij = ijstart; ij < ijend; ++ij) { - if (A.tileIsLocal( ij, ij ) - && device == A.tileDevice( ij, ij )) - { - auto Aij = A( ij, ij, device ); - a_array_host[ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - } - ++group.count; - ++batch_count; - } - } // for ij - if (group.count > 0) { - group_params.push_back( group ); - } - }} // for jj,ii + auto group_params = device_regions_build<1, scalar_t>( + {A}, + {a_array_host}, + device ); blas::Queue* queue = A.compute_queue(device, queue_index); blas::device_memcpy( - a_array_dev, a_array_host, batch_count, + a_array_dev, a_array_host, batch_size, blas::MemcpyKind::HostToDevice, *queue); for (size_t g = 0; g < group_params.size(); ++g) { @@ -269,14 +176,14 @@ void set( A.uplo(), group_params[ g ].mb, group_params[ g ].nb, offdiag_value, diag_value, - a_array_dev, group_params[ g ].lda, + a_array_dev, group_params[ g ].ld[0], group_count, *queue); } else { device::batch::geset( group_params[ g ].mb, group_params[ g ].nb, offdiag_value, offdiag_value, - a_array_dev, group_params[ g ].lda, + a_array_dev, group_params[ g ].ld[0], group_count, *queue ); } a_array_dev += group_count; diff --git a/src/internal/internal_util.hh b/src/internal/internal_util.hh index 3da8320cd..4a78667e0 100644 --- a/src/internal/internal_util.hh +++ b/src/internal/internal_util.hh @@ -11,6 +11,7 @@ #include "slate/internal/mpi.hh" #include "slate/Matrix.hh" +#include "slate/BaseTrapezoidMatrix.hh" #include #include @@ -141,6 +142,307 @@ std::vector device_regions_range( bool want_rows, BaseMatrix& return range; } +//------------------------------------------------------------------------------ +/// Helper class to store the information on a device region +/// +/// @tparam has_diag +/// Wheather the diagonal tiles need to be special cased +/// +/// @tparam mat_count +/// The number of matrices used by the kernel +/// +template< bool has_diag, int mat_count > +struct device_regions_params { + int64_t count, mb, nb; + int64_t ld[mat_count]; + +private: + // When has_diag is false, we don't want to allocate any memory for is_diagonal + struct Empty {}; +public: + std::conditional_t< has_diag, bool, Empty > is_diagonal; + + device_regions_params() + : count(0), mb(0), nb(0) + { + for (int i = 0; i < mat_count; ++i) { + ld[i] = 0; + } + if constexpr (has_diag) { + is_diagonal = false; + } + } +}; + +//------------------------------------------------------------------------------ +/// Computes and populates the regions for the given matrices. +/// +/// @params[in] mats +/// An array of the matrices to build regions for +/// +/// @params[in] mats_array_host +/// An array of the arrays to fill with pointers to device data +/// +/// @params[in] device +/// The device to build regions for +/// +/// @params[in] diag_same +/// Whether to treat the diagonal tiles as normal tiles in spite of has_diag +/// Ignored when has_diag is false. 
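The regions that this builder produces are the cross product of the row ranges and column ranges returned by device_regions_range. As a concrete illustration, here is a standalone sketch (illustrative only, not part of the patch) that mimics the range computation on a made-up tiling:

// Illustrative sketch, not part of the patch: mirrors what
// device_regions_range( true, A ) computes from A.tileMb( i ).
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    // E.g., a 3250 x 3250 matrix with 1000 x 1000 tiles has tile row
    // heights 1000, 1000, 1000, 250.
    std::vector< int64_t > tile_mb = { 1000, 1000, 1000, 250 };

    // Collapse runs of equal sizes into half-open ranges.
    std::vector< int64_t > irange;
    int64_t last_mb = -1;
    for (size_t i = 0; i < tile_mb.size(); ++i) {
        if (tile_mb[ i ] != last_mb) {
            last_mb = tile_mb[ i ];
            irange.push_back( i );
        }
    }
    irange.push_back( tile_mb.size() );

    // Prints "irange = 0 3 4": tile rows [0, 3) share mb = 1000 and
    // row [3, 4) has mb = 250. Crossed with the analogous jrange, each
    // (row range) x (column range) pair is one region of uniformly
    // sized tiles, i.e., one batched kernel launch.
    std::printf( "irange =" );
    for (int64_t i : irange)
        std::printf( " %lld", (long long) i );
    std::printf( "\n" );
    return 0;
}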
+template< bool has_diag, int mat_count, typename scalar_t> +std::vector< device_regions_params > device_regions_build( + std::array< std::reference_wrapper>, mat_count > mats, + std::array< scalar_t**, mat_count> mats_array_host, + int64_t device, + bool diag_same = has_diag) +{ + // The first two arguments should be valid targets for brace-initialization + // reference_wrapper works around fact that C++ doesn't allow array of references + + using Params = device_regions_params; + + auto& A = mats[0].get(); + + // Find ranges of matching mb's and ranges of matching nb's. + std::vector< int64_t > irange = device_regions_range( true, A ); + std::vector< int64_t > jrange = device_regions_range( false, A ); + + int64_t batch_count = 0; + std::vector group_params; + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + Params group; + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if ((!has_diag || diag_same || i != j) + && A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { + + auto Aij = A( i, j, device ); + mats_array_host[ 0 ][ batch_count ] = Aij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.ld[0] = Aij.stride(); + } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.ld[0] == Aij.stride() ); + } + for (int m = 1; m < mat_count; ++m) { + auto Mij = mats[ m ].get()( i, j, device ); + mats_array_host[ m ][ batch_count ] = Mij.data(); + if (group.count == 0) { + group.ld[m] = Mij.stride(); + } + else { + assert( group.ld[m] == Mij.stride() ); + } + } + ++group.count; + ++batch_count; + } + }} // for j, i + if (group.count > 0) { + group_params.push_back( group ); + } + + // If the diagonal needs special treatment, build the diagonal regions + if constexpr (has_diag) { + if (!diag_same) { + group = Params(); + group.is_diagonal = true; + // Diagonal tiles only in the intersection of irange and jrange + int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); + int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); + for (int64_t ij = ijstart; ij < ijend; ++ij) { + if (A.tileIsLocal( ij, ij ) + && device == A.tileDevice( ij, ij )) + { + auto Aij = A( ij, ij, device ); + mats_array_host[ 0 ][ batch_count ] = Aij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.ld[0] = Aij.stride(); + } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.ld[0] == Aij.stride() ); + } + for (int m = 1; m < mat_count; ++m) { + auto Mij = mats[ m ].get()( ij, ij, device ); + mats_array_host[ m ][ batch_count ] = Mij.data(); + if (group.count == 0) { + group.ld[m] = Mij.stride(); + } + else { + assert( group.ld[m] == Mij.stride() ); + } + } + ++group.count; + ++batch_count; + } + } // for ij + if (group.count > 0) { + group_params.push_back( group ); + } + } + } + }} // for jj, ii + return group_params; +} + +//------------------------------------------------------------------------------ +/// Computes and populates the regions for the given matrices. 
+/// +/// @params[in] mats +/// An array of the matrices to build regions for +/// +/// @params[in] mats_array_host +/// An array of the arrays to fill with pointers to device data +/// +/// @params[in] device +/// The device to build regions for +/// +/// @params[in] diag_same +/// Whether to treat the diagonal tiles as normal tiles in spite of has_diag +/// Ignored when has_diag is false. +template< int mat_count, typename scalar_t> +std::vector< device_regions_params > device_regions_build( + std::array< std::reference_wrapper>, mat_count > mats, + std::array< scalar_t**, mat_count> mats_array_host, + int64_t device) +{ + // The first two arguments should be valid targets for brace-initialization + // reference_wrapper works around fact that C++ doesn't allow array of references + + using Params = device_regions_params; + + auto& A = mats[0].get(); + + // Find ranges of matching mb's and ranges of matching nb's. + std::vector< int64_t > irange = device_regions_range( true, A ); + std::vector< int64_t > jrange = device_regions_range( false, A ); + + int64_t batch_count = 0; + std::vector group_params; + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + Params group; + if (A.uplo() == Uplo::Lower) { + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = std::max(irange[ ii ], j+1); i < irange[ ii+1 ]; ++i) { + if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { + + auto Aij = A( i, j, device ); + mats_array_host[ 0 ][ batch_count ] = Aij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.ld[0] = Aij.stride(); + } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.ld[0] == Aij.stride() ); + } + for (int m = 1; m < mat_count; ++m) { + auto Mij = mats[ m ].get()( i, j, device ); + mats_array_host[ m ][ batch_count ] = Mij.data(); + if (group.count == 0) { + group.ld[m] = Mij.stride(); + } + else { + assert( group.ld[m] == Mij.stride() ); + } + } + ++group.count; + ++batch_count; + } + }} // for j,i + } + else { // A.uplo() == Uplo::Upper + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ] && i < j; ++i) { + if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { + + auto Aij = A( i, j, device ); + mats_array_host[ 0 ][ batch_count ] = Aij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + group.nb = Aij.nb(); + group.ld[0] = Aij.stride(); + } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.ld[0] == Aij.stride() ); + } + for (int m = 1; m < mat_count; ++m) { + auto Mij = mats[ m ].get()( i, j, device ); + mats_array_host[ m ][ batch_count ] = Mij.data(); + if (group.count == 0) { + group.ld[m] = Mij.stride(); + } + else { + assert( group.ld[m] == Mij.stride() ); + } + } + ++group.count; + ++batch_count; + } + }} // for j, i + } + if (group.count > 0) { + group_params.push_back( group ); + } + + // Build the diagonal regions + group = Params(); + group.is_diagonal = true; + // Diagonal tiles only in the intersection of irange and jrange + int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); + int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); + for (int64_t ij = ijstart; ij < ijend; ++ij) { + if (A.tileIsLocal( ij, ij ) + && device == A.tileDevice( ij, ij )) + { + auto Aij = A( ij, ij, device ); + mats_array_host[ 0 ][ batch_count ] = Aij.data(); + if (group.count == 0) { + group.mb = Aij.mb(); + 
group.nb = Aij.nb(); + group.ld[0] = Aij.stride(); + } + else { + assert( group.mb == Aij.mb() ); + assert( group.nb == Aij.nb() ); + assert( group.ld[0] == Aij.stride() ); + } + for (int m = 1; m < mat_count; ++m) { + auto Mij = mats[ m ].get()( ij, ij, device ); + mats_array_host[ m ][ batch_count ] = Mij.data(); + if (group.count == 0) { + group.ld[m] = Mij.stride(); + } + else { + assert( group.ld[m] == Mij.stride() ); + } + } + ++group.count; + ++batch_count; + } + } // for ij + if (group.count > 0) { + group_params.push_back( group ); + } + }} // for jj, ii + return group_params; +} + + } // namespace internal } // namespace slate From 2a3bb5823f27bc7750326b752822969a4442fb86 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Fri, 20 Oct 2023 15:45:18 -0400 Subject: [PATCH 08/35] Refactor out more code duplication --- src/internal/internal_tzadd.cc | 2 +- src/internal/internal_tzscale.cc | 2 +- src/internal/internal_tzset.cc | 2 +- src/internal/internal_util.hh | 254 +++++++------------------------ 4 files changed, 54 insertions(+), 206 deletions(-) diff --git a/src/internal/internal_tzadd.cc b/src/internal/internal_tzadd.cc index edbe1b00f..0cf703c67 100644 --- a/src/internal/internal_tzadd.cc +++ b/src/internal/internal_tzadd.cc @@ -193,7 +193,7 @@ void add(internal::TargetType, scalar_t** a_array_host = B.array_host( device, queue_index ); scalar_t** b_array_host = a_array_host + batch_size; - auto group_params = device_regions_build<2, scalar_t>( + auto group_params = device_regions_build( {A, B}, {a_array_host, b_array_host}, device ); diff --git a/src/internal/internal_tzscale.cc b/src/internal/internal_tzscale.cc index 234e1ce09..ab9747f1e 100644 --- a/src/internal/internal_tzscale.cc +++ b/src/internal/internal_tzscale.cc @@ -146,7 +146,7 @@ void scale(internal::TargetType, int64_t batch_size = A_tiles_set.size(); scalar_t** a_array_host = A.array_host( device, queue_index ); - auto group_params = device_regions_build<1, scalar_t>( + auto group_params = device_regions_build( {A}, {a_array_host}, device ); diff --git a/src/internal/internal_tzset.cc b/src/internal/internal_tzset.cc index f80d0142d..eee300295 100644 --- a/src/internal/internal_tzset.cc +++ b/src/internal/internal_tzset.cc @@ -157,7 +157,7 @@ void set( scalar_t** a_array_host = A.array_host( device ); scalar_t** a_array_dev = A.array_device( device ); - auto group_params = device_regions_build<1, scalar_t>( + auto group_params = device_regions_build( {A}, {a_array_host}, device ); diff --git a/src/internal/internal_util.hh b/src/internal/internal_util.hh index 4a78667e0..591252d9e 100644 --- a/src/internal/internal_util.hh +++ b/src/internal/internal_util.hh @@ -146,7 +146,7 @@ std::vector device_regions_range( bool want_rows, BaseMatrix& /// Helper class to store the information on a device region /// /// @tparam has_diag -/// Wheather the diagonal tiles need to be special cased +/// Wheather the diagonal tiles may need to be special cased /// /// @tparam mat_count /// The number of matrices used by the kernel @@ -189,12 +189,13 @@ public: /// @params[in] diag_same /// Whether to treat the diagonal tiles as normal tiles in spite of has_diag /// Ignored when has_diag is false. 
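Below, the diag_same default flips to true, and the separate Lower/Upper/General traversals collapse into a single loop whose bounds depend on A.uplo(). As a sanity check of that bound formula, here is a standalone sketch (illustrative only, not part of the patch) that enumerates the visited tiles of a 3 x 3 grid for each case, using a single row range [0, mt):

// Illustrative sketch, not part of the patch: exercises the bound
// formula of the unified loop. Diagonal tiles of trapezoidal matrices
// are excluded here; they are picked up by the is_diagonal groups.
#include <algorithm>
#include <cstdint>
#include <cstdio>

enum class Uplo { Lower, Upper, General };

int main()
{
    const int64_t mt = 3, nt = 3;
    for (Uplo uplo : { Uplo::Lower, Uplo::Upper, Uplo::General }) {
        std::printf( "uplo %d:", (int) uplo );
        for (int64_t j = 0; j < nt; ++j) {
            // Lower starts below the diagonal, Upper stops above it,
            // General covers the whole column of tiles.
            int64_t istart = std::max( int64_t( 0 ),
                                       uplo == Uplo::Lower ? j+1 : 0 );
            int64_t iend   = std::min( mt,
                                       uplo == Uplo::Upper ? j : mt );
            for (int64_t i = istart; i < iend; ++i)
                std::printf( " (%lld,%lld)", (long long) i, (long long) j );
        }
        std::printf( "\n" );
        // Lower   -> (1,0) (2,0) (2,1)
        // Upper   -> (0,1) (0,2) (1,2)
        // General -> all nine tiles
    }
    return 0;
}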
+/// template< bool has_diag, int mat_count, typename scalar_t> std::vector< device_regions_params > device_regions_build( - std::array< std::reference_wrapper>, mat_count > mats, - std::array< scalar_t**, mat_count> mats_array_host, + std::array< std::reference_wrapper>, mat_count > mats, + std::array< scalar_t**, mat_count > mats_array_host, int64_t device, - bool diag_same = has_diag) + bool diag_same = true) { // The first two arguments should be valid targets for brace-initialization // reference_wrapper works around fact that C++ doesn't allow array of references @@ -207,148 +208,33 @@ std::vector< device_regions_params > device_regions_build( std::vector< int64_t > irange = device_regions_range( true, A ); std::vector< int64_t > jrange = device_regions_range( false, A ); - int64_t batch_count = 0; - std::vector group_params; - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - Params group; - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { - if ((!has_diag || diag_same || i != j) - && A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { - - auto Aij = A( i, j, device ); - mats_array_host[ 0 ][ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.ld[0] = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.ld[0] == Aij.stride() ); - } - for (int m = 1; m < mat_count; ++m) { - auto Mij = mats[ m ].get()( i, j, device ); - mats_array_host[ m ][ batch_count ] = Mij.data(); - if (group.count == 0) { - group.ld[m] = Mij.stride(); - } - else { - assert( group.ld[m] == Mij.stride() ); - } - } - ++group.count; - ++batch_count; - } - }} // for j, i - if (group.count > 0) { - group_params.push_back( group ); - } + // Trapezoidal matrices always need special treatment for diagonal tiles + diag_same &= A.uplo() == Uplo::General; - // If the diagonal needs special treatment, build the diagonal regions - if constexpr (has_diag) { - if (!diag_same) { - group = Params(); - group.is_diagonal = true; - // Diagonal tiles only in the intersection of irange and jrange - int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); - int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); - for (int64_t ij = ijstart; ij < ijend; ++ij) { - if (A.tileIsLocal( ij, ij ) - && device == A.tileDevice( ij, ij )) - { - auto Aij = A( ij, ij, device ); - mats_array_host[ 0 ][ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.ld[0] = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.ld[0] == Aij.stride() ); - } - for (int m = 1; m < mat_count; ++m) { - auto Mij = mats[ m ].get()( ij, ij, device ); - mats_array_host[ m ][ batch_count ] = Mij.data(); - if (group.count == 0) { - group.ld[m] = Mij.stride(); - } - else { - assert( group.ld[m] == Mij.stride() ); - } - } - ++group.count; - ++batch_count; - } - } // for ij - if (group.count > 0) { - group_params.push_back( group ); - } - } - } - }} // for jj, ii - return group_params; -} - -//------------------------------------------------------------------------------ -/// Computes and populates the regions for the given matrices. 
-/// -/// @params[in] mats -/// An array of the matrices to build regions for -/// -/// @params[in] mats_array_host -/// An array of the arrays to fill with pointers to device data -/// -/// @params[in] device -/// The device to build regions for -/// -/// @params[in] diag_same -/// Whether to treat the diagonal tiles as normal tiles in spite of has_diag -/// Ignored when has_diag is false. -template< int mat_count, typename scalar_t> -std::vector< device_regions_params > device_regions_build( - std::array< std::reference_wrapper>, mat_count > mats, - std::array< scalar_t**, mat_count> mats_array_host, - int64_t device) -{ - // The first two arguments should be valid targets for brace-initialization - // reference_wrapper works around fact that C++ doesn't allow array of references - - using Params = device_regions_params; - - auto& A = mats[0].get(); - - // Find ranges of matching mb's and ranges of matching nb's. - std::vector< int64_t > irange = device_regions_range( true, A ); - std::vector< int64_t > jrange = device_regions_range( false, A ); + // Can't treat diagonals special when we can't store the diagonal status + assert( diag_same || has_diag ); + diag_same |= !has_diag; // Ensure the compiler can propagate this assertion int64_t batch_count = 0; + int64_t mt = A.mt(); std::vector group_params; for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { for (size_t ii = 0; ii < irange.size() - 1; ++ii) { Params group; - if (A.uplo() == Uplo::Lower) { - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = std::max(irange[ ii ], j+1); i < irange[ ii+1 ]; ++i) { - if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { - - auto Aij = A( i, j, device ); - mats_array_host[ 0 ][ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.ld[0] = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.ld[0] == Aij.stride() ); - } - for (int m = 1; m < mat_count; ++m) { + group.mb = A.tileMb( irange[ ii ] ); + group.nb = A.tileNb( jrange[ jj ] ); + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + // Lower matrices start at j+1 + // Upper matrices end at j + // General matrices run the whole range + int istart = std::max(irange[ ii ], (A.uplo() == Uplo::Lower ? j+1 : 0)); + int iend = std::min(irange[ ii+1 ], (A.uplo() == Uplo::Upper ? 
j : mt)); + for (int64_t i = istart; i < iend; ++i) { + if ((!has_diag || diag_same || i != j) + && A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { + + // Add tiles to current group + for (int m = 0; m < mat_count; ++m) { auto Mij = mats[ m ].get()( i, j, device ); mats_array_host[ m ][ batch_count ] = Mij.data(); if (group.count == 0) { @@ -361,27 +247,29 @@ std::vector< device_regions_params > device_regions_build( ++group.count; ++batch_count; } - }} // for j,i + } // for i + } // for j + if (group.count > 0) { + group_params.push_back( group ); } - else { // A.uplo() == Uplo::Upper - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = irange[ ii ]; i < irange[ ii+1 ] && i < j; ++i) { - if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { - - auto Aij = A( i, j, device ); - mats_array_host[ 0 ][ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.ld[0] = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.ld[0] == Aij.stride() ); - } - for (int m = 1; m < mat_count; ++m) { - auto Mij = mats[ m ].get()( i, j, device ); + + // If the diagonal tiles need special treatment, build those groups + if constexpr (has_diag) if (!diag_same) { + group = Params(); + group.is_diagonal = true; + group.mb = A.tileMb( irange[ ii ] ); + group.nb = A.tileNb( jrange[ jj ] ); + // Diagonal tiles only in the intersection of irange and jrange + int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); + int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); + for (int64_t ij = ijstart; ij < ijend; ++ij) { + if (A.tileIsLocal( ij, ij ) + && device == A.tileDevice( ij, ij )) { + + // Add tiles to current group + // This logic matches that of above + for (int m = 0; m < mat_count; ++m) { + auto Mij = mats[ m ].get()( ij, ij, device ); mats_array_host[ m ][ batch_count ] = Mij.data(); if (group.count == 0) { group.ld[m] = Mij.stride(); @@ -393,51 +281,11 @@ std::vector< device_regions_params > device_regions_build( ++group.count; ++batch_count; } - }} // for j, i - } - if (group.count > 0) { - group_params.push_back( group ); - } - - // Build the diagonal regions - group = Params(); - group.is_diagonal = true; - // Diagonal tiles only in the intersection of irange and jrange - int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); - int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); - for (int64_t ij = ijstart; ij < ijend; ++ij) { - if (A.tileIsLocal( ij, ij ) - && device == A.tileDevice( ij, ij )) - { - auto Aij = A( ij, ij, device ); - mats_array_host[ 0 ][ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.ld[0] = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.ld[0] == Aij.stride() ); - } - for (int m = 1; m < mat_count; ++m) { - auto Mij = mats[ m ].get()( ij, ij, device ); - mats_array_host[ m ][ batch_count ] = Mij.data(); - if (group.count == 0) { - group.ld[m] = Mij.stride(); - } - else { - assert( group.ld[m] == Mij.stride() ); - } - } - ++group.count; - ++batch_count; + } // for ij + if (group.count > 0) { + group_params.push_back( group ); } - } // for ij - if (group.count > 0) { - group_params.push_back( group ); - } + } // if has_diag && !diag_same }} // for jj, ii return group_params; } From 6561d7a1c3dba5e408d72b09a7c3a34a0ce0b5a7 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Mon, 23 Oct 2023 12:17:26 -0400 
Subject: [PATCH 09/35] Add regions to device gemm --- src/internal/internal_gemm.cc | 214 ++++++---------------------------- src/internal/internal_util.hh | 13 ++- 2 files changed, 47 insertions(+), 180 deletions(-) diff --git a/src/internal/internal_gemm.cc b/src/internal/internal_gemm.cc index 9c76d512e..b348b7b9b 100644 --- a/src/internal/internal_gemm.cc +++ b/src/internal/internal_gemm.cc @@ -8,6 +8,7 @@ #include "slate/Tile_blas.hh" #include "internal/internal.hh" #include "internal/internal_batch.hh" +#include "internal/internal_util.hh" namespace slate { namespace internal { @@ -479,136 +480,18 @@ void gemm(internal::TargetType, int64_t batch_size = C_tiles_set.size(); - // interior, excluding bottom row and right column - std::vector a_array00; - std::vector b_array00; - std::vector c_array00; - a_array00.reserve( batch_size ); - b_array00.reserve( batch_size ); - c_array00.reserve( batch_size ); - - int64_t lda00 = 0; - int64_t ldb00 = 0; - int64_t ldc00 = 0; - int64_t mb00 = C.tileMb(0); - int64_t nb00 = C.tileNb(0); - int64_t kb = A.tileNb(0); - for (int64_t i = 0; i < C.mt()-1; ++i) { - for (int64_t j = 0; j < C.nt()-1; ++j) { - if (C.tileIsLocal(i, j)) { - if (device == C.tileDevice(i, j)) { - a_array00.push_back( A(i, 0, device).data() ); - b_array00.push_back( B(0, j, device).data() ); - c_array00.push_back( C(i, j, device).data() ); - lda00 = A(i, 0, device).stride(); - ldb00 = B(0, j, device).stride(); - ldc00 = C(i, j, device).stride(); - } - } - } - } - - // bottom row - std::vector a_array10; - std::vector b_array10; - std::vector c_array10; - a_array10.reserve( batch_size ); - b_array10.reserve( batch_size ); - c_array10.reserve( batch_size ); - - int64_t lda10 = 0; - int64_t ldb10 = 0; - int64_t ldc10 = 0; - int64_t mb10 = C.tileMb(C.mt()-1); - int64_t nb10 = C.tileNb(0); - // same kb as above - { - int64_t i = C.mt()-1; - for (int64_t j = 0; j < C.nt()-1; ++j) { - if (C.tileIsLocal(i, j)) { - if (device == C.tileDevice(i, j)) { - a_array10.push_back( A(i, 0, device).data() ); - b_array10.push_back( B(0, j, device).data() ); - c_array10.push_back( C(i, j, device).data() ); - lda10 = A(i, 0, device).stride(); - ldb10 = B(0, j, device).stride(); - ldc10 = C(i, j, device).stride(); - } - } - } - } - - // right column - std::vector a_array01; - std::vector b_array01; - std::vector c_array01; - a_array01.reserve( batch_size ); - b_array01.reserve( batch_size ); - c_array01.reserve( batch_size ); - - int64_t lda01 = 0; - int64_t ldb01 = 0; - int64_t ldc01 = 0; - int64_t mb01 = C.tileMb(0); - int64_t nb01 = C.tileNb(C.nt()-1); - // same kb as above - { - int64_t j = C.nt()-1; - for (int64_t i = 0; i < C.mt()-1; ++i) { - if (C.tileIsLocal(i, j)) { - if (device == C.tileDevice(i, j)) { - a_array01.push_back( A(i, 0, device).data() ); - b_array01.push_back( B(0, j, device).data() ); - c_array01.push_back( C(i, j, device).data() ); - lda01 = A(i, 0, device).stride(); - ldb01 = B(0, j, device).stride(); - ldc01 = C(i, j, device).stride(); - } - } - } - } + scalar_t** a_array_host = C.array_host(device, queue_index); + scalar_t** b_array_host = a_array_host + batch_size; + scalar_t** c_array_host = b_array_host + batch_size; - // bottom-right corner - std::vector a_array11; - std::vector b_array11; - std::vector c_array11; - - int64_t lda11 = 0; - int64_t ldb11 = 0; - int64_t ldc11 = 0; - int64_t mb11 = C.tileMb(C.mt()-1); - int64_t nb11 = C.tileNb(C.nt()-1); - // same kb as above - { - int i = C.mt()-1; - int j = C.nt()-1; - if (C.tileIsLocal(i, j)) { - if (device == 
C.tileDevice(i, j)) { - a_array11.push_back( A(i, 0, device).data() ); - b_array11.push_back( B(0, j, device).data() ); - c_array11.push_back( C(i, j, device).data() ); - lda11 = A(i, 0, device).stride(); - ldb11 = B(0, j, device).stride(); - ldc11 = C(i, j, device).stride(); - } - } - } + // C comes first since we do computation for a local C + auto group_params = device_regions_build( + {C, A, B}, + {c_array_host, a_array_host, b_array_host}, + device ); if (C.op() != Op::NoTrans) { - // swap A <=> B; swap m <=> n swap(opA, opB); - swap(a_array00, b_array00); - swap(a_array10, b_array10); - swap(a_array01, b_array01); - swap(a_array11, b_array11); - swap(lda00, ldb00); - swap(lda10, ldb10); - swap(lda01, ldb01); - swap(lda11, ldb11); - swap(mb00, nb00); - swap(mb10, nb10); - swap(mb01, nb01); - swap(mb11, nb11); } { @@ -618,71 +501,44 @@ void gemm(internal::TargetType, std::vector opB_(1, opB); std::vector alpha_(1, alpha); std::vector beta_(1, beta); - std::vector k(1, kb); + std::vector k(1, A.tileNb(0)); // info size 0 disables slow checks in batched BLAS++. std::vector info; blas::Queue* queue = C.compute_queue(device, queue_index); assert(queue != nullptr); - if (c_array00.size() > 0) { - std::vector m(1, mb00); - std::vector n(1, nb00); - std::vector ldda(1, lda00); - std::vector lddb(1, ldb00); - std::vector lddc(1, ldc00); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, a_array00, ldda, - b_array00, lddb, - beta_, c_array00, lddc, - c_array00.size(), info, *queue); - } + for (size_t g = 0; g < group_params.size(); ++g) { - if (c_array10.size() > 0) { - std::vector m(1, mb10); - std::vector n(1, nb10); - std::vector ldda(1, lda10); - std::vector lddb(1, ldb10); - std::vector lddc(1, ldc10); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, a_array10, ldda, - b_array10, lddb, - beta_, c_array10, lddc, - c_array10.size(), info, *queue); - } + int64_t group_count = group_params[ g ].count; - if (c_array01.size() > 0) { - std::vector m(1, mb01); - std::vector n(1, nb01); - std::vector ldda(1, lda01); - std::vector lddb(1, ldb01); - std::vector lddc(1, ldc01); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, a_array01, ldda, - b_array01, lddb, - beta_, c_array01, lddc, - c_array01.size(), info, *queue); - } + std::vector m(1, group_params[ g ].mb); + std::vector n(1, group_params[ g ].nb); + std::vector ldda(1, group_params[ g ].ld[1]); + std::vector lddb(1, group_params[ g ].ld[2]); + std::vector lddc(1, group_params[ g ].ld[0]); + + std::vector a_array(a_array_host, a_array_host+group_count); + std::vector b_array(b_array_host, b_array_host+group_count); + std::vector c_array(c_array_host, c_array_host+group_count); + + if (C.op() != Op::NoTrans) { + swap(m, n); + swap(a_array, b_array); + swap(ldda, lddb); + } - if (c_array11.size() > 0) { - std::vector m(1, mb11); - std::vector n(1, nb11); - std::vector ldda(1, lda11); - std::vector lddb(1, ldb11); - std::vector lddc(1, ldc11); blas::batch::gemm( layout, opA_, opB_, m, n, k, - alpha_, a_array11, ldda, - b_array11, lddb, - beta_, c_array11, lddc, - c_array11.size(), info, *queue); + alpha_, a_array, ldda, + b_array, lddb, + beta_, c_array, lddc, + group_count, info, *queue); + + a_array_host += group_count; + b_array_host += group_count; + c_array_host += group_count; } queue->sync(); diff --git a/src/internal/internal_util.hh b/src/internal/internal_util.hh index 591252d9e..f2f69d077 100644 --- a/src/internal/internal_util.hh +++ b/src/internal/internal_util.hh @@ -215,6 +215,17 @@ 
std::vector< device_regions_params > device_regions_build( assert( diag_same || has_diag ); diag_same |= !has_diag; // Ensure the compiler can propagate this assertion + // Single dimensions are always indexed as 0. This allows setting up GEMM et al. + // The first matrix is always indexed normally since it determines the loops + int64_t i_step[mat_count]; + int64_t j_step[mat_count]; + i_step[0] = 1; + j_step[0] = 1; + for (int m = 1; m < mat_count; ++m) { + i_step[m] = (mats[ m ].get().mt() > 1); + j_step[m] = (mats[ m ].get().nt() > 1); + } + int64_t batch_count = 0; int64_t mt = A.mt(); std::vector group_params; @@ -235,7 +246,7 @@ std::vector< device_regions_params > device_regions_build( // Add tiles to current group for (int m = 0; m < mat_count; ++m) { - auto Mij = mats[ m ].get()( i, j, device ); + auto Mij = mats[ m ].get()( i*i_step[m], j*j_step[m], device ); mats_array_host[ m ][ batch_count ] = Mij.data(); if (group.count == 0) { group.ld[m] = Mij.stride(); From d219186f4a0a14f305c5f498a7d6de03e788d97a Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Tue, 24 Oct 2023 09:40:12 -0400 Subject: [PATCH 10/35] Refactor generic regions code to handle scale_row_col and copy --- src/internal/internal_gecopy.cc | 70 +++++------- src/internal/internal_gescale_row_col.cc | 94 ++++++---------- src/internal/internal_tzcopy.cc | 131 +++++------------------ src/internal/internal_util.hh | 45 ++++++-- 4 files changed, 122 insertions(+), 218 deletions(-) diff --git a/src/internal/internal_gecopy.cc b/src/internal/internal_gecopy.cc index d9d30165f..9bc6f34d4 100644 --- a/src/internal/internal_gecopy.cc +++ b/src/internal/internal_gecopy.cc @@ -146,14 +146,10 @@ void copy(internal::TargetType, bool call_tile_tick = tile_release_strategy == TileReleaseStrategy::Internal || tile_release_strategy == TileReleaseStrategy::All; - // Find ranges of matching mb's and ranges of matching nb's. - std::vector< int64_t > irange = device_regions_range( true, A ); - std::vector< int64_t > jrange = device_regions_range( false, A ); - #pragma omp taskgroup for (int device = 0; device < B.num_devices(); ++device) { - #pragma omp task priority( priority ) shared( A, B, irange, jrange ) \ - firstprivate( device, queue_index, call_tile_tick ) + #pragma omp task slate_omp_default_none priority( priority ) \ + shared( A, B ) firstprivate( device, queue_index, call_tile_tick ) { std::set A_tiles_set; for (int64_t i = 0; i < B.mt(); ++i) { @@ -182,41 +178,29 @@ void copy(internal::TargetType, src_scalar_t** a_array_host = A.array_host(device, queue_index); dst_scalar_t** b_array_host = B.array_host(device, queue_index); + // Because A and B may be different types and C++ doesn't easily + // support iterating over tuples. 
We manually handle A + std::vector lda; int64_t batch_count = 0; - struct Params { - int64_t count, mb, nb, lda, ldb; - }; - std::vector group_params; - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - Params group = { 0, -1, -1, -1, -1 }; - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { - if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { - auto Aij = A( i, j, device ); - a_array_host[ batch_count ] = Aij.data(); - auto Bij = B( i, j, device ); - b_array_host[ batch_count ] = Bij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - group.ldb = Bij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - assert( group.ldb == Bij.stride() ); - } - ++group.count; - ++batch_count; - } - }} // for j, i - if (group.count > 0) { - group_params.push_back( group ); + std::function + setup_A = [&] (int64_t group, int64_t i, int64_t j) { + auto Aij = A( i, j, device ); + a_array_host[ batch_count ] = Aij.data(); + if (lda.size() == size_t(group)) { + lda.push_back( Aij.stride() ); + } + else { + assert(lda.size() > size_t(group)); + assert(lda[group] == Aij.stride()); } - }} // for jj, ii + ++batch_count; + }; + auto group_params = device_regions_build( + {B}, + {b_array_host}, + device, + true, + setup_A ); // Usually the output matrix (B) provides all the batch arrays. // Here we are using A, because of the different types. @@ -247,15 +231,15 @@ void copy(internal::TargetType, device::transpose_batch( is_conj, group_params[ g ].mb, group_params[ g ].nb, - a_array_dev, group_params[ g ].lda, - b_array_dev, group_params[ g ].ldb, + a_array_dev, lda[ g ], + b_array_dev, group_params[ g ].ld[0], group_count, *queue); } else { device::gecopy( group_params[ g ].mb, group_params[ g ].nb, - a_array_dev, group_params[ g ].lda, - b_array_dev, group_params[ g ].ldb, + a_array_dev, lda[ g ], + b_array_dev, group_params[ g ].ld[0], group_count, *queue); } a_array_dev += group_count; diff --git a/src/internal/internal_gescale_row_col.cc b/src/internal/internal_gescale_row_col.cc index 64c0c6270..7affbd9ac 100644 --- a/src/internal/internal_gescale_row_col.cc +++ b/src/internal/internal_gescale_row_col.cc @@ -68,29 +68,9 @@ void scale_row_col( { using ij_tuple = typename BaseMatrix::ij_tuple; - // Find ranges of matching mb's and ranges of matching nb's. 
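// Illustrative note (not part of this hunk): the per-range offset bookkeeping
// removed below is superseded by plain per-tile offsets added further down,
//     ioffsets[ i ] = sum of A.tileMb( 0 .. i-1 ),
// so the row-scale pointer for tile (i, j) is simply &dR[ ioffsets[ i ] ],
// while the region grouping itself is delegated to device_regions_build.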
- std::vector< int64_t > irange = device_regions_range( true, A ); - std::vector< int64_t > jrange = device_regions_range( false, A ); - - // Compute global offsets of each block - std::vector< int64_t > range_ioffset (irange.size()-1); - std::vector< int64_t > range_joffset (jrange.size()-1); - { - int64_t ioffset = 0; - for (size_t i = 0; i < range_ioffset.size(); ++i) { - range_ioffset[ i ] = ioffset; - ioffset += A.tileMb( irange[ i ] ) * (irange[ i+1 ] - irange[ i ]); - } - int64_t joffset = 0; - for (size_t j = 0; j < range_joffset.size(); ++j) { - range_joffset[ j ] = joffset; - joffset += A.tileNb( jrange[ j ] ) * (jrange[ j+1 ] - jrange[ j ]); - } - } - #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { - #pragma omp task shared( R, C, A, irange, jrange, range_ioffset, range_joffset ) \ + #pragma omp task shared( R, C, A ) \ firstprivate( equed, device ) { bool want_row = equed == Equed::Both || equed == Equed::Row; @@ -116,6 +96,24 @@ void scale_row_col( c_array_dev .resize( A.batchArraySize(), device, *queue ); } + std::vector< int64_t > ioffsets, joffsets; + if (want_row) { + ioffsets.reserve(A.mt()); + int64_t offset = 0; + for (int64_t i = 0; i < A.mt(); ++i) { + ioffsets.push_back( offset ); + offset += A.tileMb( i ); + } + } + if (want_col) { + joffsets.reserve(A.nt()); + int64_t offset = 0; + for (int64_t j = 0; j < A.nt(); ++j) { + joffsets.push_back( offset ); + offset += A.tileNb( j ); + } + } + // temporarily, convert both into same layout // todo: this is in-efficient, because both matrices may have same layout already // and possibly wrong, because an input matrix is being altered @@ -134,45 +132,21 @@ void scale_row_col( scalar_t** a_array_host = A.array_host( device, queue_index ); int64_t batch_count = 0; - struct Params { - int64_t count, mb, nb, lda; + std::function + store_rc = [&](int64_t group, int64_t i, int64_t j) { + if (want_row) + r_array_host[ batch_count ] = &dR[ ioffsets[ i ] ]; + if (want_col) + c_array_host[ batch_count ] = &dC[ joffsets[ j ] ]; + ++batch_count; }; - std::vector group_params; - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - Params group = { 0, -1, -1, -1 }; - int64_t joffset = range_joffset[ jj ]; - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - int64_t ioffset = range_ioffset[ ii ]; - for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { - if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { - auto Aij = A( i, j, device ); - if (want_row) - r_array_host[ batch_count ] = &dR[ ioffset ]; - if (want_col) - c_array_host[ batch_count ] = &dC[ joffset ]; - a_array_host[ batch_count ] = Aij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - } - ++group.count; - ++batch_count; - } - ioffset += A.tileMb( i ); - } // for i - joffset += A.tileNb( j ); - } // for j - if (group.count > 0) { - group_params.push_back( group ); - } - }} // for jj, ii + auto group_params = device_regions_build( + {A}, + {a_array_host}, + device, + true, + store_rc ); + scalar_t** a_array_dev = A.array_device( device, queue_index ); @@ -201,7 +175,7 @@ void scale_row_col( equed, group_params[ g ].mb, group_params[ g ].nb, r_array_data, c_array_data, - a_array_dev, group_params[ g ].lda, + a_array_dev, group_params[ g ].ld[0], group_count, *queue); r_array_data += 
group_count; c_array_data += group_count; diff --git a/src/internal/internal_tzcopy.cc b/src/internal/internal_tzcopy.cc index 9069f42ae..7ea8079ba 100644 --- a/src/internal/internal_tzcopy.cc +++ b/src/internal/internal_tzcopy.cc @@ -113,14 +113,10 @@ void copy(internal::TargetType, slate_error_if(A.uplo() != B.uplo()); bool lower = (B.uplo() == Uplo::Lower); - // Find ranges of matching mb's and ranges of matching nb's. - std::vector< int64_t > irange = device_regions_range( true, A ); - std::vector< int64_t > jrange = device_regions_range( false, A ); - #pragma omp taskgroup for (int device = 0; device < B.num_devices(); ++device) { - #pragma omp task priority( priority ) shared( A, B, irange, jrange ) \ - firstprivate(device, lower, queue_index) + #pragma omp task slate_omp_default_none priority( priority ) \ + shared( A, B ) firstprivate(device, lower, queue_index) { std::set A_tiles, B_diag_tiles; for (int64_t i = 0; i < B.mt(); ++i) { @@ -153,104 +149,29 @@ void copy(internal::TargetType, src_scalar_t** a_array_host = A.array_host(device, queue_index); dst_scalar_t** b_array_host = B.array_host(device, queue_index); - // Build batch groups + // Because A and B may be different types and C++ doesn't easily + // support iterating over tuples. We manually handle A + std::vector lda; int64_t batch_count = 0; - struct Params { - int64_t count, mb, nb, lda, ldb; - bool is_diagonal; - }; - std::vector group_params; - // Build batch groups for off-diagonal tiles, - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - Params group = { 0, -1, -1, -1, -1, false }; - if (A.uplo() == Uplo::Lower) { - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = std::max(irange[ ii ], j+1); i < irange[ ii+1 ]; ++i) { - if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { - auto Aij = A( i, j, device ); - a_array_host[ batch_count ] = Aij.data(); - auto Bij = B( i, j, device ); - b_array_host[ batch_count ] = Bij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - group.ldb = Bij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - assert( group.ldb == Bij.stride() ); - } - ++group.count; - ++batch_count; - } - }} // for j,i - } - else { // A.uplo() == Uplo::Upper - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = irange[ ii ]; i < irange[ ii+1 ] && i < j; ++i) { - if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { - auto Aij = A( i, j, device ); - a_array_host[ batch_count ] = Aij.data(); - auto Bij = B( i, j, device ); - b_array_host[ batch_count ] = Bij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - group.ldb = Bij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - assert( group.ldb == Bij.stride() ); - } - ++group.count; - ++batch_count; - } - }} // for j,i - } - if (group.count > 0) { - group_params.push_back( group ); + std::function + setup_A = [&] (int64_t group, int64_t i, int64_t j) { + auto Aij = A( i, j, device ); + a_array_host[ batch_count ] = Aij.data(); + if (lda.size() == size_t(group)) { + lda.push_back( Aij.stride() ); } - }} // for jj,ii - - // Build batch groups for diagonal tiles, - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - 
Params group = { 0, -1, -1, -1, -1, true }; - int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); - int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); - for (int64_t ij = ijstart; ij < ijend; ++ij) { - if (A.tileIsLocal( ij, ij ) && device == A.tileDevice( ij, ij )) { - auto Aij = A( ij, ij, device ); - a_array_host[ batch_count ] = Aij.data(); - auto Bij = B( ij, ij, device ); - b_array_host[ batch_count ] = Bij.data(); - if (group.count == 0) { - group.mb = Aij.mb(); - group.nb = Aij.nb(); - group.lda = Aij.stride(); - group.ldb = Bij.stride(); - } - else { - assert( group.mb == Aij.mb() ); - assert( group.nb == Aij.nb() ); - assert( group.lda == Aij.stride() ); - assert( group.ldb == Bij.stride() ); - } - ++group.count; - ++batch_count; - } - } // for ij - if (group.count > 0) { - group_params.push_back( group ); + else { + assert(lda.size() > size_t(group)); + assert(lda[group] == Aij.stride()); } - }} // for jj,ii + ++batch_count; + }; + auto group_params = device_regions_build( + {B}, + {b_array_host}, + device, + true, + setup_A ); // Usually the output matrix (B) provides all the batch arrays. // Here we are using A, because of the differen types. @@ -275,15 +196,15 @@ void copy(internal::TargetType, device::tzcopy( B.uplo(), group_params[ g ].mb, group_params[ g ].nb, - a_array_dev, group_params[ g ].lda, - b_array_dev, group_params[ g ].ldb, + a_array_dev, lda[ g ], + b_array_dev, group_params[ g ].ld[0], group_count, *queue); } else { device::gecopy( group_params[ g ].mb, group_params[ g ].nb, - a_array_dev, group_params[ g ].lda, - b_array_dev, group_params[ g ].ldb, + a_array_dev, lda[ g ], + b_array_dev, group_params[ g ].ld[0], group_count, *queue); } a_array_dev += group_count; diff --git a/src/internal/internal_util.hh b/src/internal/internal_util.hh index f2f69d077..72d86cb90 100644 --- a/src/internal/internal_util.hh +++ b/src/internal/internal_util.hh @@ -143,7 +143,7 @@ std::vector device_regions_range( bool want_rows, BaseMatrix& } //------------------------------------------------------------------------------ -/// Helper class to store the information on a device region +/// Helper class to store the information on a device region. /// /// @tparam has_diag /// Wheather the diagonal tiles may need to be special cased @@ -177,25 +177,36 @@ public: //------------------------------------------------------------------------------ /// Computes and populates the regions for the given matrices. /// -/// @params[in] mats +/// @tparam has_diag +/// Wheather the diagonal tiles may need to be special cased +/// +/// @tparam mat_count +/// The number of matrices used by the kernel +/// +/// @param[in] mats /// An array of the matrices to build regions for /// -/// @params[in] mats_array_host +/// @param[in] mats_array_host /// An array of the arrays to fill with pointers to device data /// -/// @params[in] device +/// @param[in] device /// The device to build regions for /// -/// @params[in] diag_same +/// @param[in] diag_same /// Whether to treat the diagonal tiles as normal tiles in spite of has_diag /// Ignored when has_diag is false. /// +/// @param[in] extra_setup +/// Callback that is called whenever a tile is added to a group. 
+/// The group index and the tile indices are passed as arguments +/// template< bool has_diag, int mat_count, typename scalar_t> std::vector< device_regions_params > device_regions_build( std::array< std::reference_wrapper>, mat_count > mats, std::array< scalar_t**, mat_count > mats_array_host, int64_t device, - bool diag_same = true) + bool diag_same = true, + std::function extra_setup = {}) { // The first two arguments should be valid targets for brace-initialization // reference_wrapper works around fact that C++ doesn't allow array of references @@ -215,7 +226,8 @@ std::vector< device_regions_params > device_regions_build( assert( diag_same || has_diag ); diag_same |= !has_diag; // Ensure the compiler can propagate this assertion - // Single dimensions are always indexed as 0. This allows setting up GEMM et al. + // Size 1 dimensions get broadcast to allow setting up GEMM et al. + // i_step[m]=0 results in only accessing row 0 of matrix m (likewise for j) // The first matrix is always indexed normally since it determines the loops int64_t i_step[mat_count]; int64_t j_step[mat_count]; @@ -231,13 +243,16 @@ std::vector< device_regions_params > device_regions_build( std::vector group_params; for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + // Loop over the tiles in this region. If any should be computed on this + // process & device, save them. Params group; group.mb = A.tileMb( irange[ ii ] ); group.nb = A.tileNb( jrange[ jj ] ); for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - // Lower matrices start at j+1 - // Upper matrices end at j - // General matrices run the whole range + // This is a column major loop. So, + // * Lower matrices start at j+1 + // * Upper matrices end at j + // * General matrices run the whole range int istart = std::max(irange[ ii ], (A.uplo() == Uplo::Lower ? j+1 : 0)); int iend = std::min(irange[ ii+1 ], (A.uplo() == Uplo::Upper ? j : mt)); for (int64_t i = istart; i < iend; ++i) { @@ -255,17 +270,23 @@ std::vector< device_regions_params > device_regions_build( assert( group.ld[m] == Mij.stride() ); } } + if (extra_setup) { + extra_setup( group_params.size(), i, j ); + } ++group.count; ++batch_count; } } // for i } // for j + // If any tiles in the region should be computed here, save the group if (group.count > 0) { group_params.push_back( group ); } // If the diagonal tiles need special treatment, build those groups if constexpr (has_diag) if (!diag_same) { + // Loop over the diagonal tiles in this region. If any should be + // computed on this process & device, save them. 
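            // Illustrative sketch (not part of this hunk): callers that pass
            // diag_same = false later dispatch on this flag, roughly
            //
            //     for (size_t g = 0; g < group_params.size(); ++g) {
            //         if (group_params[ g ].is_diagonal)
            //             ;  // symmetric/triangular kernel, e.g. blas::batch::herk
            //         else
            //             ;  // general kernel, e.g. blas::batch::gemm
            //     }
            //
            // as done in the herk/her2k patches later in this series.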
group = Params(); group.is_diagonal = true; group.mb = A.tileMb( irange[ ii ] ); @@ -289,10 +310,14 @@ std::vector< device_regions_params > device_regions_build( assert( group.ld[m] == Mij.stride() ); } } + if (extra_setup) { + extra_setup( group_params.size(), ij, ij ); + } ++group.count; ++batch_count; } } // for ij + // If any tiles in the region should be computed here, save the group if (group.count > 0) { group_params.push_back( group ); } From 99d29b2115699c1f0b3ecb15b8eb2130e074404f Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Tue, 24 Oct 2023 09:38:26 -0400 Subject: [PATCH 11/35] Add regions to device trmm and trsm --- src/internal/internal_trmm.cc | 121 +++++++++------------------------- src/internal/internal_trsm.cc | 121 +++++++++------------------------- 2 files changed, 60 insertions(+), 182 deletions(-) diff --git a/src/internal/internal_trmm.cc b/src/internal/internal_trmm.cc index 61b4f8bd9..8f01a116e 100644 --- a/src/internal/internal_trmm.cc +++ b/src/internal/internal_trmm.cc @@ -8,6 +8,7 @@ #include "slate/types.hh" #include "slate/Tile_blas.hh" #include "internal/internal.hh" +#include "internal/internal_util.hh" namespace slate { namespace internal { @@ -225,78 +226,14 @@ void trmm(internal::TargetType, A.tileGetForReading(0, 0, device, LayoutConvert(layout)); B.tileGetForWriting(B_tiles_set, device, LayoutConvert(layout)); - // interior col or row - std::vector a_array0; - std::vector b_array0; - a_array0.reserve( batch_size ); - b_array0.reserve( batch_size ); - - // bottom-right tile - // todo: replace batch trmm with plain trmm - std::vector a_array1; - std::vector b_array1; - - int64_t lda0 = 0; - int64_t ldb0 = 0; - int64_t lda1 = 0; - int64_t ldb1 = 0; - - int64_t mb0 = B.tileMb(0); - int64_t nb0 = B.tileNb(0); - int64_t mb1 = B.tileMb(B.mt()-1); - int64_t nb1 = B.tileNb(B.nt()-1); - - if (side == Side::Right) { - for (int64_t i = 0; i < B.mt()-1; ++i) { - if (B.tileIsLocal(i, 0) - && device == B.tileDevice(i, 0)) - { - a_array0.push_back( A(0, 0, device).data() ); - b_array0.push_back( B(i, 0, device).data() ); - lda0 = A(0, 0, device).stride(); - ldb0 = B(i, 0, device).stride(); - } - } - { - int64_t i = B.mt()-1; - if (B.tileIsLocal(i, 0) - && device == B.tileDevice(i, 0)) - { - a_array1.push_back( A(0, 0, device).data() ); - b_array1.push_back( B(i, 0, device).data() ); - lda1 = A(0, 0, device).stride(); - ldb1 = B(i, 0, device).stride(); - } - } - } - else { - for (int64_t j = 0; j < B.nt()-1; ++j) { - if (B.tileIsLocal(0, j) - && device == B.tileDevice(0, j)) - { - a_array0.push_back( A(0, 0, device).data() ); - b_array0.push_back( B(0, j, device).data() ); - lda0 = A(0, 0, device).stride(); - ldb0 = B(0, j, device).stride(); - } - } - { - int64_t j = B.nt()-1; - if (B.tileIsLocal(0, j) - && device == B.tileDevice(0, j)) - { - a_array1.push_back( A(0, 0, device).data() ); - b_array1.push_back( B(0, j, device).data() ); - lda1 = A(0, 0, device).stride(); - ldb1 = B(0, j, device).stride(); - } - } - } + scalar_t** a_array_host = B.array_host(device, queue_index); + scalar_t** b_array_host = a_array_host + batch_size; - if (B.op() != Op::NoTrans) { - swap(mb0, nb0); - swap(mb1, nb1); - } + // B comes first since we do computation for a local B + auto group_params = device_regions_build( + {B, A}, + {b_array_host, a_array_host}, + device ); { trace::Block trace_block("blas::batch::trmm"); @@ -306,35 +243,37 @@ void trmm(internal::TargetType, std::vector opA_(1, opA ); std::vector diag_(1, diagA); std::vector alpha_(1, alpha); + // info size 0 disables 
slow checks in batched BLAS++. std::vector info; blas::Queue* queue = B.compute_queue(device, queue_index); assert(queue != nullptr); - if (a_array0.size() > 0) { - std::vector m(1, mb0); - std::vector n(1, nb0); - std::vector lda(1, lda0); - std::vector ldb(1, ldb0); - blas::batch::trmm( - layout, side_, uplo_, opA_, diag_, - m, n, - alpha_, a_array0, lda, - b_array0, ldb, - a_array0.size(), info, *queue); - } + for (size_t g = 0; g < group_params.size(); ++g) { + + int64_t group_count = group_params[ g ].count; + + std::vector m(1, group_params[ g ].mb); + std::vector n(1, group_params[ g ].nb); + std::vector ldda(1, group_params[ g ].ld[1]); + std::vector lddb(1, group_params[ g ].ld[0]); + + std::vector a_array(a_array_host, a_array_host+group_count); + std::vector b_array(b_array_host, b_array_host+group_count); + + if (B.op() != Op::NoTrans) { + swap(m, n); + } - if (a_array1.size() > 0) { - std::vector m(1, mb1); - std::vector n(1, nb1); - std::vector lda(1, lda1); - std::vector ldb(1, ldb1); blas::batch::trmm( layout, side_, uplo_, opA_, diag_, m, n, - alpha_, a_array1, lda, - b_array1, ldb, - a_array1.size(), info, *queue); + alpha_, a_array, ldda, + b_array, lddb, + group_count, info, *queue); + + a_array_host += group_count; + b_array_host += group_count; } queue->sync(); diff --git a/src/internal/internal_trsm.cc b/src/internal/internal_trsm.cc index 971e3afa2..bc4dffcf4 100644 --- a/src/internal/internal_trsm.cc +++ b/src/internal/internal_trsm.cc @@ -9,6 +9,7 @@ #include "slate/Tile_blas.hh" #include "internal/internal.hh" #include "internal/internal_batch.hh" +#include "internal/internal_util.hh" namespace slate { namespace internal { @@ -214,78 +215,14 @@ void trsm(internal::TargetType, A.tileGetForReading(0, 0, device, LayoutConvert(layout)); B.tileGetForWriting(B_tiles_set, device, LayoutConvert(layout)); - // interior col or row - std::vector a_array0; - std::vector b_array0; - a_array0.reserve( batch_size ); - b_array0.reserve( batch_size ); - - // bottom-right tile - // todo: replace batch trsm with plain trsm - std::vector a_array1; - std::vector b_array1; - - int64_t lda0 = 0; - int64_t ldb0 = 0; - int64_t lda1 = 0; - int64_t ldb1 = 0; - - int64_t mb0 = B.tileMb(0); - int64_t nb0 = B.tileNb(0); - int64_t mb1 = B.tileMb(B.mt()-1); - int64_t nb1 = B.tileNb(B.nt()-1); - - if (side == Side::Right) { - for (int64_t i = 0; i < B.mt()-1; ++i) { - if (B.tileIsLocal(i, 0) - && device == B.tileDevice(i, 0)) - { - a_array0.push_back( A(0, 0, device).data() ); - b_array0.push_back( B(i, 0, device).data() ); - lda0 = A(0, 0, device).stride(); - ldb0 = B(i, 0, device).stride(); - } - } - { - int64_t i = B.mt()-1; - if (B.tileIsLocal(i, 0) - && device == B.tileDevice(i, 0)) - { - a_array1.push_back( A(0, 0, device).data() ); - b_array1.push_back( B(i, 0, device).data() ); - lda1 = A(0, 0, device).stride(); - ldb1 = B(i, 0, device).stride(); - } - } - } - else { - for (int64_t j = 0; j < B.nt()-1; ++j) { - if (B.tileIsLocal(0, j) - && device == B.tileDevice(0, j)) - { - a_array0.push_back( A(0, 0, device).data() ); - b_array0.push_back( B(0, j, device).data() ); - lda0 = A(0, 0, device).stride(); - ldb0 = B(0, j, device).stride(); - } - } - { - int64_t j = B.nt()-1; - if (B.tileIsLocal(0, j) - && device == B.tileDevice(0, j)) - { - a_array1.push_back( A(0, 0, device).data() ); - b_array1.push_back( B(0, j, device).data() ); - lda1 = A(0, 0, device).stride(); - ldb1 = B(0, j, device).stride(); - } - } - } + scalar_t** a_array_host = B.array_host(device, queue_index); + scalar_t** 
b_array_host = a_array_host + batch_size; - if (B.op() != Op::NoTrans) { - swap(mb0, nb0); - swap(mb1, nb1); - } + // B comes first since we do computation for a local B + auto group_params = device_regions_build( + {B, A}, + {b_array_host, a_array_host}, + device ); { trace::Block trace_block("blas::batch::trsm"); @@ -295,35 +232,37 @@ void trsm(internal::TargetType, std::vector opA_(1, opA ); std::vector diag_(1, diagA); std::vector alpha_(1, alpha); + // info size 0 disables slow checks in batched BLAS++. std::vector info; blas::Queue* queue = B.compute_queue(device, queue_index); assert(queue != nullptr); - if (a_array0.size() > 0) { - std::vector m(1, mb0); - std::vector n(1, nb0); - std::vector lda(1, lda0); - std::vector ldb(1, ldb0); - blas::batch::trsm( - layout, side_, uplo_, opA_, diag_, - m, n, - alpha_, a_array0, lda, - b_array0, ldb, - a_array0.size(), info, *queue); - } + for (size_t g = 0; g < group_params.size(); ++g) { + + int64_t group_count = group_params[ g ].count; + + std::vector m(1, group_params[ g ].mb); + std::vector n(1, group_params[ g ].nb); + std::vector ldda(1, group_params[ g ].ld[1]); + std::vector lddb(1, group_params[ g ].ld[0]); + + std::vector a_array(a_array_host, a_array_host+group_count); + std::vector b_array(b_array_host, b_array_host+group_count); + + if (B.op() != Op::NoTrans) { + swap(m, n); + } - if (a_array1.size() > 0) { - std::vector m(1, mb1); - std::vector n(1, nb1); - std::vector lda(1, lda1); - std::vector ldb(1, ldb1); blas::batch::trsm( layout, side_, uplo_, opA_, diag_, m, n, - alpha_, a_array1, lda, - b_array1, ldb, - a_array1.size(), info, *queue); + alpha_, a_array, ldda, + b_array, lddb, + group_count, info, *queue); + + a_array_host += group_count; + b_array_host += group_count; } queue->sync(); From 506fbabee7cc6010bdf770694c1e0f4dee88db03 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Tue, 24 Oct 2023 12:09:46 -0400 Subject: [PATCH 12/35] FIXUP bcast in region builder --- src/internal/internal_util.hh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/internal/internal_util.hh b/src/internal/internal_util.hh index 72d86cb90..4491ba17b 100644 --- a/src/internal/internal_util.hh +++ b/src/internal/internal_util.hh @@ -301,7 +301,7 @@ std::vector< device_regions_params > device_regions_build( // Add tiles to current group // This logic matches that of above for (int m = 0; m < mat_count; ++m) { - auto Mij = mats[ m ].get()( ij, ij, device ); + auto Mij = mats[ m ].get()( ij*i_step[m], ij*j_step[m], device ); mats_array_host[ m ][ batch_count ] = Mij.data(); if (group.count == 0) { group.ld[m] = Mij.stride(); From d1cdc79c778db7c32d99a7a16d796724c9967765 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Tue, 24 Oct 2023 13:58:46 -0400 Subject: [PATCH 13/35] Add device regions to herk and her2k --- src/internal/internal_her2k.cc | 418 ++++++++------------------------- src/internal/internal_herk.cc | 273 ++++++--------------- 2 files changed, 168 insertions(+), 523 deletions(-) diff --git a/src/internal/internal_her2k.cc b/src/internal/internal_her2k.cc index 8f4d4ec2b..a8d17a17b 100644 --- a/src/internal/internal_her2k.cc +++ b/src/internal/internal_her2k.cc @@ -9,6 +9,7 @@ #include "slate/Tile_blas.hh" #include "internal/internal.hh" #include "internal/internal_batch.hh" +#include "internal/internal_util.hh" namespace slate { namespace internal { @@ -538,23 +539,20 @@ void her2k(internal::TargetType, Op opB = (opA == Op::NoTrans ? 
Op::ConjTrans : Op::NoTrans); - std::set A_tiles_gemm, B_tiles_gemm, C_tiles_gemm; - std::set A_tiles_her2k, B_tiles_her2k, C_tiles_her2k; + std::set A_tiles_set, B_tiles_set, C_tiles_set; for (int64_t j = 0; j < C.nt(); ++j) { for (int64_t i = j; i < C.mt(); ++i) { // lower if (C.tileIsLocal(i, j) && device == C.tileDevice(i, j)) { + + A_tiles_set.insert({j, 0}); + B_tiles_set.insert({j, 0}); + C_tiles_set.insert({i, j}); if (i == j) { - A_tiles_her2k.insert({j, 0}); - B_tiles_her2k.insert({j, 0}); - C_tiles_her2k.insert({i, j}); } else { - A_tiles_gemm.insert({i, 0}); - A_tiles_gemm.insert({j, 0}); - B_tiles_gemm.insert({i, 0}); - B_tiles_gemm.insert({j, 0}); - C_tiles_gemm.insert({i, j}); + A_tiles_set.insert({i, 0}); + B_tiles_set.insert({i, 0}); } } } @@ -563,351 +561,131 @@ void her2k(internal::TargetType, #pragma omp taskgroup { #pragma omp task slate_omp_default_none \ - shared( A, A_tiles_gemm ) \ + shared( A, A_tiles_set ) \ firstprivate( device, layout ) { - A.tileGetForReading(A_tiles_gemm, device, LayoutConvert(layout)); + A.tileGetForReading(A_tiles_set, device, LayoutConvert(layout)); } #pragma omp task slate_omp_default_none \ - shared( B, B_tiles_gemm ) \ + shared( B, B_tiles_set ) \ firstprivate( device, layout ) { - B.tileGetForReading(B_tiles_gemm, device, LayoutConvert(layout)); + B.tileGetForReading(B_tiles_set, device, LayoutConvert(layout)); } #pragma omp task slate_omp_default_none \ - shared( C, C_tiles_gemm ) \ + shared( C, C_tiles_set ) \ firstprivate( device, layout ) { - C.tileGetForWriting(C_tiles_gemm, device, LayoutConvert(layout)); - } - } - - int64_t batch_size_gemm = C_tiles_gemm.size(); - - //---------------------------------------- - // A * B^T - // interior - std::vector a_array_gemm00; - std::vector b_array_gemm00; - std::vector c_array_gemm00; - a_array_gemm00.reserve( batch_size_gemm ); - b_array_gemm00.reserve( batch_size_gemm ); - c_array_gemm00.reserve( batch_size_gemm ); - - int64_t lda00 = 0; - int64_t ldb00 = 0; - int64_t ldc00 = 0; - int64_t mb00 = C.tileMb(0); - int64_t nb00 = C.tileNb(0); - int64_t kb = A.tileNb(0); - for (int64_t j = 0; j < C.nt()-1; ++j) { - // strictly lower - for (int64_t i = j+1; i < C.mt()-1; ++i) { - if (C.tileIsLocal(i, j) - && device == C.tileDevice(i, j)) - { - a_array_gemm00.push_back( A(i, 0, device).data() ); - b_array_gemm00.push_back( B(j, 0, device).data() ); - c_array_gemm00.push_back( C(i, j, device).data() ); - lda00 = A(i, 0, device).stride(); - ldb00 = B(j, 0, device).stride(); - ldc00 = C(i, j, device).stride(); - } - } - } - - // bottom row - std::vector a_array_gemm10; - std::vector b_array_gemm10; - std::vector c_array_gemm10; - a_array_gemm10.reserve( batch_size_gemm ); - b_array_gemm10.reserve( batch_size_gemm ); - c_array_gemm10.reserve( batch_size_gemm ); - - int64_t lda10 = 0; - int64_t ldb10 = 0; - int64_t ldc10 = 0; - int64_t mb10 = C.tileMb(C.mt()-1); - int64_t nb10 = C.tileNb(0); - // same kb as above - { - int64_t i = C.mt()-1; - for (int64_t j = 0; j < C.nt()-1; ++j) { - if (C.tileIsLocal(i, j) - && device == C.tileDevice(i, j)) - { - a_array_gemm10.push_back( A(i, 0, device).data() ); - b_array_gemm10.push_back( B(j, 0, device).data() ); - c_array_gemm10.push_back( C(i, j, device).data() ); - lda10 = A(i, 0, device).stride(); - ldb10 = B(j, 0, device).stride(); - ldc10 = C(i, j, device).stride(); - } + C.tileGetForWriting(C_tiles_set, device, LayoutConvert(layout)); } } - if (C.op() != Op::NoTrans) { - // swap A <=> B; swap m <=> n - swap(opA, opB); - swap(a_array_gemm00, 
b_array_gemm00); - swap(a_array_gemm10, b_array_gemm10); - swap(lda00, ldb00); - swap(lda10, ldb10); - swap(mb00, nb00); - swap(mb10, nb10); - } - - std::vector opA_(1, opA); - std::vector opB_(1, opB); - std::vector k(1, kb); - std::vector info; - - blas::Queue* queue = C.compute_queue(device, queue_index); - - { - trace::Block trace_block("blas::batch::gemm"); - - std::vector alpha_(1, alpha); - std::vector beta_(1, scalar_t(beta)); - - if (c_array_gemm00.size() > 0) { - std::vector m(1, mb00); - std::vector n(1, nb00); - std::vector ldda(1, lda00); - std::vector lddb(1, ldb00); - std::vector lddc(1, ldc00); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, a_array_gemm00, ldda, - b_array_gemm00, lddb, - beta_, c_array_gemm00, lddc, - c_array_gemm00.size(), info, *queue); - } + int64_t batch_size = C_tiles_set.size(); - if (c_array_gemm10.size() > 0) { - std::vector m(1, mb10); - std::vector n(1, nb10); - std::vector ldda(1, lda10); - std::vector lddb(1, ldb10); - std::vector lddc(1, ldc10); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, a_array_gemm10, ldda, - b_array_gemm10, lddb, - beta_, c_array_gemm10, lddc, - c_array_gemm10.size(), info, *queue); - } - } + scalar_t** a_array_host = C.array_host(device, queue_index); + scalar_t** b_array_host = a_array_host + batch_size; + scalar_t** c_array_host = b_array_host + batch_size; - //---------------------------------------- - // B * A^T - // ai => bi, bj => aj, set beta = 1 + // There are only 3 batch arrays + std::vector t_array_vect( 2*batch_size ); + scalar_t** at_array_host = t_array_vect.data(); + scalar_t** bt_array_host = at_array_host + batch_size; - a_array_gemm00.clear(); - b_array_gemm00.clear(); - a_array_gemm10.clear(); - b_array_gemm10.clear(); + // Use transposed A and B to broadcast correctly + auto AT = conj_transpose(A); + auto BT = conj_transpose(B); - // interior - for (int64_t j = 0; j < C.nt()-1; ++j) { - // strictly lower - for (int64_t i = j+1; i < C.mt()-1; ++i) { - if (C.tileIsLocal(i, j) - && device == C.tileDevice(i, j)) - { - a_array_gemm00.push_back( A(j, 0, device).data() ); - b_array_gemm00.push_back( B(i, 0, device).data() ); - lda00 = A(j, 0, device).stride(); - ldb00 = B(i, 0, device).stride(); - } - } - } + // C comes first since we do computation for a local C + auto group_params = device_regions_build( + {C, A, AT, BT, B}, + {c_array_host, a_array_host, at_array_host, b_array_host, bt_array_host}, + device ); - // bottom row - { - int i = C.mt()-1; - for (int64_t j = 0; j < C.nt()-1; ++j) { - if (C.tileIsLocal(i, j) - && device == C.tileDevice(i, j)) - { - a_array_gemm10.push_back( A(j, 0, device).data() ); - b_array_gemm10.push_back( B(i, 0, device).data() ); - lda10 = A(j, 0, device).stride(); - ldb10 = B(i, 0, device).stride(); - } - } - } if (C.op() != Op::NoTrans) { - // swap A <=> B; swap m <=> n - //swap(opA, opB); // already done above - swap(a_array_gemm00, b_array_gemm00); - swap(a_array_gemm10, b_array_gemm10); - swap(lda00, ldb00); - swap(lda10, ldb10); - //swap(mb00, nb00); // already done above - //swap(mb10, nb10); // already done above - } - - { - trace::Block trace_block("blas::batch::gemm"); - - std::vector conj_alpha_(1, conj(alpha)); - std::vector one_( 1, one ); - - if (c_array_gemm00.size() > 0) { - std::vector m(1, mb00); - std::vector n(1, nb00); - std::vector ldda(1, lda00); - std::vector lddb(1, ldb00); - std::vector lddc(1, ldc00); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - conj_alpha_, b_array_gemm00, lddb, - a_array_gemm00, 
ldda, - one_, c_array_gemm00, lddc, - c_array_gemm00.size(), info, *queue); - } - - if (c_array_gemm10.size() > 0) { - std::vector m(1, mb10); - std::vector n(1, nb10); - std::vector ldda(1, lda10); - std::vector lddb(1, ldb10); - std::vector lddc(1, ldc10); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - conj_alpha_, b_array_gemm10, lddb, - a_array_gemm10, ldda, - one_, c_array_gemm10, lddc, - c_array_gemm10.size(), info, *queue); - } + swap(opA, opB); } - #pragma omp taskgroup { - #pragma omp task slate_omp_default_none \ - shared( A, A_tiles_her2k ) \ - firstprivate( device, layout ) - { - A.tileGetForReading(A_tiles_her2k, device, LayoutConvert(layout)); - } - #pragma omp task slate_omp_default_none \ - shared( B, B_tiles_her2k ) \ - firstprivate( device, layout ) - { - B.tileGetForReading(B_tiles_her2k, device, LayoutConvert(layout)); - } - #pragma omp task slate_omp_default_none \ - shared( C, C_tiles_her2k ) \ - firstprivate( device, layout ) - { - C.tileGetForWriting(C_tiles_her2k, device, LayoutConvert(layout)); - } - } - - int64_t batch_size_her2k = C_tiles_her2k.size(); - - // diagonal - std::vector a_array_her2k_0; - std::vector b_array_her2k_0; - std::vector c_array_her2k_0; - a_array_her2k_0.reserve( batch_size_her2k ); - b_array_her2k_0.reserve( batch_size_her2k ); - c_array_her2k_0.reserve( batch_size_her2k ); - - int64_t lda_her2k_0 = 0; - int64_t ldb_her2k_0 = 0; - int64_t ldc_her2k_0 = 0; - - int64_t nb_her2k_0 = C.tileNb(0); - - for (int64_t j = 0; j < C.nt()-1; ++j) { - if (C.tileIsLocal(j, j) - && device == C.tileDevice(j, j)) - { - a_array_her2k_0.push_back( A(j, 0, device).data() ); - b_array_her2k_0.push_back( B(j, 0, device).data() ); - c_array_her2k_0.push_back( C(j, j, device).data() ); - lda_her2k_0 = A(j, 0, device).stride(); - ldb_her2k_0 = B(j, 0, device).stride(); - ldc_her2k_0 = C(j, j, device).stride(); - } - } + trace::Block trace_block("blas::batch::her2k"); - // bottom-right corner - // todo: replace batch with plain call - std::vector a_array_her2k_1; - std::vector b_array_her2k_1; - std::vector c_array_her2k_1; + std::vector opA_(1, opA); + std::vector opB_(1, opB); + std::vector k(1, A.tileNb(0)); + std::vector info; - int64_t lda_her2k_1 = 0; - int64_t ldb_her2k_1 = 0; - int64_t ldc_her2k_1 = 0; + std::vector alpha_s(1, alpha); + std::vector conj_alpha_s(1, conj(alpha)); + auto& alpha_her2k = (C.op() == Op::NoTrans) ? 
alpha_s : conj_alpha_s; + std::vector beta_s(1, scalar_t(beta)); + std::vector beta_r(1, real_t(beta)); + std::vector one_( 1, one ); + std::vector uplo(1, C.uploPhysical()); - int64_t nb_her2k_1 = C.tileNb(C.nt()-1); + blas::Queue* queue = C.compute_queue(device, queue_index); - { - int j = C.nt()-1; - if (C.tileIsLocal(j, j) - && device == C.tileDevice(j, j)) - { - a_array_her2k_1.push_back( A(j, 0, device).data() ); - b_array_her2k_1.push_back( B(j, 0, device).data() ); - c_array_her2k_1.push_back( C(j, j, device).data() ); - lda_her2k_1 = A(j, 0, device).stride(); - ldb_her2k_1 = B(j, 0, device).stride(); - ldc_her2k_1 = C(j, j, device).stride(); - } - } + for (size_t g = 0; g < group_params.size(); ++g) { - { - trace::Block trace_block("blas::batch::her2k"); + int64_t group_count = group_params[ g ].count; - std::vector uplo(1, C.uploPhysical()); + std::vector n(1, group_params[ g ].nb); + std::vector ldda(1, group_params[ g ].ld[1]); + std::vector lddb(1, group_params[ g ].ld[3]); + std::vector lddc(1, group_params[ g ].ld[0]); + std::vector a_array(a_array_host, a_array_host+group_count); + std::vector b_array(b_array_host, b_array_host+group_count); + std::vector c_array(c_array_host, c_array_host+group_count); - if (C.op() != Op::NoTrans) { - alpha = conj(alpha); - } + if (group_params[ g ].is_diagonal) { + blas::batch::her2k( + layout, uplo, opA_, + n, k, + alpha_her2k, a_array, ldda, + b_array, lddb, + beta_r, c_array, lddc, + group_count, info, *queue); + } + else { + std::vector m(1, group_params[ g ].mb); + std::vector lddat(1, group_params[ g ].ld[2]); + std::vector lddbt(1, group_params[ g ].ld[4]); + std::vector at_array(at_array_host, at_array_host+group_count); + std::vector bt_array(bt_array_host, bt_array_host+group_count); + + if (C.op() != Op::NoTrans) { + swap(m, n); + swap(a_array, b_array); + swap(at_array, bt_array); + swap(ldda, lddb); + swap(lddat, lddbt); + } - std::vector alpha_(1, alpha); - std::vector beta_(1, beta); - - if (c_array_her2k_0.size() > 0) { - std::vector n(1, nb_her2k_0); - std::vector ldda(1, lda_her2k_0); - std::vector lddb(1, ldb_her2k_0); - std::vector lddc(1, ldc_her2k_0); - blas::batch::her2k( - layout, uplo, opA_, - n, k, - alpha_, a_array_her2k_0, ldda, - b_array_her2k_0, lddb, - beta_, c_array_her2k_0, lddc, - c_array_her2k_0.size(), info, *queue); + blas::batch::gemm( + layout, opA_, opB_, + m, n, k, + alpha_s, a_array, ldda, + b_array, lddb, + beta_s, c_array, lddc, + group_count, info, *queue); + + blas::batch::gemm( + layout, opA_, opB_, + m, n, k, + conj_alpha_s, bt_array, lddbt, + at_array, lddat, + one_, c_array, lddc, + group_count, info, *queue); + } + a_array_host += group_count; + at_array_host += group_count; + b_array_host += group_count; + bt_array_host += group_count; + c_array_host += group_count; } - if (c_array_her2k_1.size() > 0) { - std::vector n(1, nb_her2k_1); - std::vector ldda(1, lda_her2k_1); - std::vector lddb(1, ldb_her2k_1); - std::vector lddc(1, ldc_her2k_1); - blas::batch::her2k( - layout, uplo, opA_, - n, k, - alpha_, a_array_her2k_1, ldda, - b_array_her2k_1, lddb, - beta_, c_array_her2k_1, lddc, - c_array_her2k_1.size(), info, *queue); - } + queue->sync(); } - queue->sync(); - if (call_tile_tick) { for (int64_t j = 0; j < C.nt(); ++j) { for (int64_t i = j; i < C.mt(); ++i) { // lower diff --git a/src/internal/internal_herk.cc b/src/internal/internal_herk.cc index 9c6f2c300..2ee78b5ab 100644 --- a/src/internal/internal_herk.cc +++ b/src/internal/internal_herk.cc @@ -9,6 +9,7 @@ #include 
"slate/Tile_blas.hh" #include "internal/internal.hh" #include "internal/internal_batch.hh" +#include "internal/internal_util.hh" namespace slate { namespace internal { @@ -461,20 +462,15 @@ void herk(internal::TargetType, Op opB = (opA == Op::NoTrans ? Op::ConjTrans : Op::NoTrans); - std::set A_tiles_gemm, C_tiles_gemm; - std::set A_tiles_herk, C_tiles_herk; + std::set A_tiles_set, C_tiles_set; for (int64_t j = 0; j < C.nt(); ++j) { for (int64_t i = j; i < C.mt(); ++i) { // lower if (C.tileIsLocal(i, j) && device == C.tileDevice(i, j)) { - if (i == j) { - A_tiles_herk.insert({j, 0}); - C_tiles_herk.insert({j, j}); - } - else { - A_tiles_gemm.insert({i, 0}); - A_tiles_gemm.insert({j, 0}); - C_tiles_gemm.insert({i, j}); + A_tiles_set.insert({j, 0}); + C_tiles_set.insert({i, j}); + if (i != j) { + A_tiles_set.insert({i, 0}); } } } @@ -483,228 +479,99 @@ void herk(internal::TargetType, #pragma omp taskgroup { #pragma omp task slate_omp_default_none \ - shared( A, A_tiles_gemm ) \ + shared( A, A_tiles_set ) \ firstprivate(device, layout) { - A.tileGetForReading(A_tiles_gemm, device, LayoutConvert(layout)); + A.tileGetForReading(A_tiles_set, device, LayoutConvert(layout)); } #pragma omp task slate_omp_default_none \ - shared( C, C_tiles_gemm ) \ + shared( C, C_tiles_set ) \ firstprivate(device, layout) { - C.tileGetForWriting(C_tiles_gemm, device, LayoutConvert(layout)); + C.tileGetForWriting(C_tiles_set, device, LayoutConvert(layout)); } } - int64_t batch_size_gemm = C_tiles_gemm.size(); - - // interior - std::vector a_array_gemm00; - std::vector b_array_gemm00; - std::vector c_array_gemm00; - a_array_gemm00.reserve( batch_size_gemm ); - b_array_gemm00.reserve( batch_size_gemm ); - c_array_gemm00.reserve( batch_size_gemm ); - - int64_t lda00 = 0; - int64_t ldb00 = 0; - int64_t ldc00 = 0; - int64_t mb00 = C.tileMb(0); - int64_t nb00 = C.tileNb(0); - int64_t kb = A.tileNb(0); - for (int64_t j = 0; j < C.nt()-1; ++j) { - // strictly lower - for (int64_t i = j+1; i < C.mt()-1; ++i) { - if (C.tileIsLocal(i, j)) { - if (device == C.tileDevice(i, j)) { - a_array_gemm00.push_back( A(i, 0, device).data() ); - b_array_gemm00.push_back( A(j, 0, device).data() ); - c_array_gemm00.push_back( C(i, j, device).data() ); - lda00 = A(i, 0, device).stride(); - ldb00 = A(j, 0, device).stride(); - ldc00 = C(i, j, device).stride(); - } - } - } - } + int64_t batch_size = C_tiles_set.size(); - // bottom row - std::vector a_array_gemm10; - std::vector b_array_gemm10; - std::vector c_array_gemm10; - a_array_gemm10.reserve( batch_size_gemm ); - b_array_gemm10.reserve( batch_size_gemm ); - c_array_gemm10.reserve( batch_size_gemm ); - - int64_t lda10 = 0; - int64_t ldb10 = 0; - int64_t ldc10 = 0; - int64_t mb10 = C.tileMb(C.mt()-1); - int64_t nb10 = C.tileNb(0); - // same kb as above - { - int64_t i = C.mt()-1; - for (int64_t j = 0; j < C.nt()-1; ++j) { - if (C.tileIsLocal(i, j)) { - if (device == C.tileDevice(i, j)) { - a_array_gemm10.push_back( A(i, 0, device).data() ); - b_array_gemm10.push_back( A(j, 0, device).data() ); - c_array_gemm10.push_back( C(i, j, device).data() ); - lda10 = A(i, 0, device).stride(); - ldb10 = A(j, 0, device).stride(); - ldc10 = C(i, j, device).stride(); - } - } - } - } + scalar_t** a_array_host = C.array_host(device, queue_index); + scalar_t** b_array_host = a_array_host + batch_size; + scalar_t** c_array_host = b_array_host + batch_size; + + // Use transposed A to broadcast down the rows correctly + auto AT = conj_transpose(A); + + // C comes first since we do computation for a local C + 
auto group_params = device_regions_build( + {C, A, AT}, + {c_array_host, a_array_host, b_array_host}, + device ); if (C.op() != Op::NoTrans) { - // swap A <=> B; swap m <=> n swap(opA, opB); - swap(a_array_gemm00, b_array_gemm00); - swap(a_array_gemm10, b_array_gemm10); - swap(lda00, ldb00); - swap(lda10, ldb10); - swap(mb00, nb00); - swap(mb10, nb10); } - std::vector opA_(1, opA); - std::vector opB_(1, opB); - std::vector k(1, kb); - std::vector info; - - blas::Queue* queue = C.compute_queue(device, queue_index); - { - trace::Block trace_block("blas::batch::gemm"); - - std::vector alpha_(1, scalar_t(alpha)); - std::vector beta_ (1, scalar_t(beta)); - - if (c_array_gemm00.size() > 0) { - std::vector m(1, mb00); - std::vector n(1, nb00); - std::vector ldda(1, lda00); - std::vector lddb(1, ldb00); - std::vector lddc(1, ldc00); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, a_array_gemm00, ldda, - b_array_gemm00, lddb, - beta_, c_array_gemm00, lddc, - c_array_gemm00.size(), info, *queue); - } + trace::Block trace_block("blas::batch::herk"); - if (c_array_gemm10.size() > 0) { - std::vector m(1, mb10); - std::vector n(1, nb10); - std::vector ldda(1, lda10); - std::vector lddb(1, ldb10); - std::vector lddc(1, ldc10); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, a_array_gemm10, ldda, - b_array_gemm10, lddb, - beta_, c_array_gemm10, lddc, - c_array_gemm10.size(), info, *queue); - } - } + std::vector opA_(1, opA); + std::vector opB_(1, opB); + std::vector k(1, A.tileNb(0)); + std::vector info; - #pragma omp taskgroup - { - #pragma omp task slate_omp_default_none \ - shared( A, A_tiles_herk ) \ - firstprivate(device, layout) - { - A.tileGetForReading(A_tiles_herk, device, LayoutConvert(layout)); - } - #pragma omp task slate_omp_default_none \ - shared( C, C_tiles_herk ) \ - firstprivate(device, layout) - { - C.tileGetForWriting(C_tiles_herk, device, LayoutConvert(layout)); - } - } + std::vector alpha_r(1, alpha); + std::vector beta_r (1, beta); + std::vector alpha_s(1, scalar_t(alpha)); + std::vector beta_s (1, scalar_t(beta)); + std::vector uplo(1, C.uploPhysical()); - int64_t batch_size_herk = C_tiles_herk.size(); + blas::Queue* queue = C.compute_queue(device, queue_index); - // diagonal - std::vector a_array_herk0; - std::vector c_array_herk0; - a_array_herk0.reserve( batch_size_herk ); - c_array_herk0.reserve( batch_size_herk ); + for (size_t g = 0; g < group_params.size(); ++g) { - int64_t lda_herk_0 = 0; - int64_t ldc_herk_0 = 0; - int64_t nb_herk_0 = C.tileNb(0); - for (int64_t j = 0; j < C.nt()-1; ++j) { - if (C.tileIsLocal(j, j) - && device == C.tileDevice(j, j)) - { - a_array_herk0.push_back( A(j, 0, device).data() ); - c_array_herk0.push_back( C(j, j, device).data() ); - lda_herk_0 = A(j, 0, device).stride(); - ldc_herk_0 = C(j, j, device).stride(); - } - } + int64_t group_count = group_params[ g ].count; - // bottom-right corner - // todo: replace batch herk with plain herk - std::vector a_array_herk1; - std::vector c_array_herk1; + std::vector n(1, group_params[ g ].nb); + std::vector ldda(1, group_params[ g ].ld[1]); + std::vector lddc(1, group_params[ g ].ld[0]); + std::vector a_array(a_array_host, a_array_host+group_count); + std::vector c_array(c_array_host, c_array_host+group_count); - int64_t lda_herk_1 = 0; - int64_t ldc_herk_1 = 0; - int64_t nb_herk_1 = C.tileNb(C.nt()-1); - { - int i = C.mt()-1; - int j = C.nt()-1; - if (C.tileIsLocal(i, j)) { - if (device == C.tileDevice(i, j)) { - a_array_herk1.push_back( A(j, 0, device).data() ); - 
c_array_herk1.push_back( C(j, j, device).data() ); - lda_herk_1 = A(j, 0, device).stride(); - ldc_herk_1 = C(j, j, device).stride(); + if (group_params[ g ].is_diagonal) { + blas::batch::herk( + layout, uplo, opA_, + n, k, + alpha_r, a_array, ldda, + beta_r, c_array, lddc, + group_count, info, *queue); } - } - } + else { + std::vector m(1, group_params[ g ].mb); + std::vector lddb(1, group_params[ g ].ld[2]); - { - trace::Block trace_block("blas::batch::herk"); + std::vector b_array(b_array_host, b_array_host+group_count); - std::vector uplo(1, C.uploPhysical()); - std::vector alpha_(1, alpha); - std::vector beta_ (1, beta); - - if (c_array_herk0.size() > 0) { - std::vector n(1, nb_herk_0); - std::vector ldda(1, lda_herk_0); - std::vector lddc(1, ldc_herk_0); - blas::batch::herk( - layout, uplo, opA_, - n, k, - alpha_, a_array_herk0, ldda, - beta_, c_array_herk0, lddc, - c_array_herk0.size(), info, *queue); - } + if (C.op() != Op::NoTrans) { + swap(m, n); + swap(a_array, b_array); + swap(ldda, lddb); + } - if (c_array_herk1.size() > 0) { - std::vector n(1, nb_herk_1); - std::vector ldda(1, lda_herk_1); - std::vector lddc(1, ldc_herk_1); - blas::batch::herk( - layout, uplo, opA_, - n, k, - alpha_, a_array_herk1, ldda, - beta_, c_array_herk1, lddc, - c_array_herk1.size(), info, *queue); + blas::batch::gemm( + layout, opA_, opB_, + m, n, k, + alpha_s, a_array, ldda, + b_array, lddb, + beta_s, c_array, lddc, + group_count, info, *queue); + } + a_array_host += group_count; + b_array_host += group_count; + c_array_host += group_count; } - } - queue->sync(); + queue->sync(); + } if (tile_release_strategy == TileReleaseStrategy::Internal || tile_release_strategy == TileReleaseStrategy::All) { From b942591a3c927aa396d6f9f0c158657977fa2926 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Tue, 24 Oct 2023 14:09:37 -0400 Subject: [PATCH 14/35] Move device region setup code to internal_batch.hh --- src/internal/internal_batch.hh | 223 ++++++++++++++++++++++- src/internal/internal_geadd.cc | 1 - src/internal/internal_gecopy.cc | 1 - src/internal/internal_gemm.cc | 1 - src/internal/internal_gescale.cc | 1 - src/internal/internal_gescale_row_col.cc | 1 - src/internal/internal_geset.cc | 1 - src/internal/internal_her2k.cc | 1 - src/internal/internal_herk.cc | 1 - src/internal/internal_trmm.cc | 2 +- src/internal/internal_trsm.cc | 1 - src/internal/internal_tzadd.cc | 1 - src/internal/internal_tzcopy.cc | 1 - src/internal/internal_tzscale.cc | 1 - src/internal/internal_tzset.cc | 1 - src/internal/internal_util.hh | 216 ---------------------- 16 files changed, 223 insertions(+), 231 deletions(-) diff --git a/src/internal/internal_batch.hh b/src/internal/internal_batch.hh index 2a52a154c..9805a067a 100644 --- a/src/internal/internal_batch.hh +++ b/src/internal/internal_batch.hh @@ -5,12 +5,17 @@ //------------------------------------------------------------------------------ /// @file +/// Provides various helper functions for batched routines. +/// /// Provides simple precision-independent wrappers around MKL batch /// routines. Eventually to be replaced by BLAS++ batch routines. +/// +/// Provides routines to build the batch regions for device batched kernels. 
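/// A rough usage sketch (illustrative only; names follow the kernels updated
/// earlier in this series, e.g. internal::gemm). The output matrix is passed
/// first: it drives the tile loops (locality, uplo, tile sizes) and its stride
/// lands in ld[0]; matrices with a single row or column of tiles (such as the
/// k-panel in gemm/herk) are broadcast across the region loops. An optional
/// extra_setup callback can record per-tile data, as in the gecopy and
/// scale_row_col patches.
///
///     scalar_t** a_host = C.array_host( device, queue_index );
///     scalar_t** b_host = a_host + batch_size;
///     scalar_t** c_host = b_host + batch_size;
///     auto group_params = device_regions_build(
///             {C, A, B}, {c_host, a_host, b_host}, device );
///     for (size_t g = 0; g < group_params.size(); ++g) {
///         // one batched BLAS call per group of uniformly sized tiles,
///         // using group_params[ g ].mb, .nb, .ld[*], .count
///     }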
#ifndef SLATE_INTERNAL_BATCH_HH #define SLATE_INTERNAL_BATCH_HH #include "slate/Exception.hh" +#include "slate/BaseMatrix.hh" #include @@ -146,7 +151,223 @@ inline void cblas_gemm_batch( } #endif // BLAS_HAVE_MKL -} // namespace slate + +// Utilities for computing device batch regions + +//------------------------------------------------------------------------------ +/// Computes the range of tiles with either the same mb or the same nb +/// +/// @param[in] want_rows +/// If true, compute the row-ranges. Else, compute the column-ranges. +/// +/// @param[in] A +/// The matrix to get tile sizes from +/// +/// @return The ranges of uniform tile sizes +/// +template +std::vector device_regions_range( bool want_rows, BaseMatrix& A ) +{ + int64_t kt = want_rows ? A.mt() : A.nt(); + + std::vector< int64_t > range; + int64_t last = -1; + for (int64_t k = 0; k < kt; ++k) { + int64_t kb = want_rows ? A.tileMb( k ) : A.tileNb( k ); + if (kb != last) { + last = kb; + range.push_back( k ); + } + } + range.push_back( kt ); + return range; +} + +//------------------------------------------------------------------------------ +/// Helper class to store the information on a device region. +/// +/// @tparam has_diag +/// Wheather the diagonal tiles may need to be special cased +/// +/// @tparam mat_count +/// The number of matrices used by the kernel +/// +template< bool has_diag, int mat_count > +struct device_regions_params { + int64_t count, mb, nb; + int64_t ld[mat_count]; + +private: + // When has_diag is false, we don't want to allocate any memory for is_diagonal + struct Empty {}; +public: + std::conditional_t< has_diag, bool, Empty > is_diagonal; + + device_regions_params() + : count(0), mb(0), nb(0) + { + for (int i = 0; i < mat_count; ++i) { + ld[i] = 0; + } + if constexpr (has_diag) { + is_diagonal = false; + } + } +}; + +//------------------------------------------------------------------------------ +/// Computes and populates the regions for the given matrices. +/// +/// @tparam has_diag +/// Wheather the diagonal tiles may need to be special cased +/// +/// @tparam mat_count +/// The number of matrices used by the kernel +/// +/// @param[in] mats +/// An array of the matrices to build regions for +/// +/// @param[in] mats_array_host +/// An array of the arrays to fill with pointers to device data +/// +/// @param[in] device +/// The device to build regions for +/// +/// @param[in] diag_same +/// Whether to treat the diagonal tiles as normal tiles in spite of has_diag +/// Ignored when has_diag is false. +/// +/// @param[in] extra_setup +/// Callback that is called whenever a tile is added to a group. +/// The group index and the tile indices are passed as arguments +/// +template< bool has_diag, int mat_count, typename scalar_t> +std::vector< device_regions_params > device_regions_build( + std::array< std::reference_wrapper>, mat_count > mats, + std::array< scalar_t**, mat_count > mats_array_host, + int64_t device, + bool diag_same = true, + std::function extra_setup = {}) +{ + // The first two arguments should be valid targets for brace-initialization + // reference_wrapper works around fact that C++ doesn't allow array of references + + using Params = device_regions_params; + + auto& A = mats[0].get(); + + // Find ranges of matching mb's and ranges of matching nb's. 
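    // For example (illustrative values), a matrix with row tile sizes
    // { 256, 256, 256, 100 } yields irange = { 0, 3, 4 }: one region of
    // three 256-row tiles followed by one region with the 100-row remainder.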
+ std::vector< int64_t > irange = device_regions_range( true, A ); + std::vector< int64_t > jrange = device_regions_range( false, A ); + + // Trapezoidal matrices always need special treatment for diagonal tiles + diag_same &= A.uplo() == Uplo::General; + + // Can't treat diagonals special when we can't store the diagonal status + assert( diag_same || has_diag ); + diag_same |= !has_diag; // Ensure the compiler can propagate this assertion + + // Size 1 dimensions get broadcast to allow setting up GEMM et al. + // i_step[m]=0 results in only accessing row 0 of matrix m (likewise for j) + // The first matrix is always indexed normally since it determines the loops + int64_t i_step[mat_count]; + int64_t j_step[mat_count]; + i_step[0] = 1; + j_step[0] = 1; + for (int m = 1; m < mat_count; ++m) { + i_step[m] = (mats[ m ].get().mt() > 1); + j_step[m] = (mats[ m ].get().nt() > 1); + } + + int64_t batch_count = 0; + int64_t mt = A.mt(); + std::vector group_params; + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + // Loop over the tiles in this region. If any should be computed on this + // process & device, save them. + Params group; + group.mb = A.tileMb( irange[ ii ] ); + group.nb = A.tileNb( jrange[ jj ] ); + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + // This is a column major loop. So, + // * Lower matrices start at j+1 + // * Upper matrices end at j + // * General matrices run the whole range + int istart = std::max(irange[ ii ], (A.uplo() == Uplo::Lower ? j+1 : 0)); + int iend = std::min(irange[ ii+1 ], (A.uplo() == Uplo::Upper ? j : mt)); + for (int64_t i = istart; i < iend; ++i) { + if ((!has_diag || diag_same || i != j) + && A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { + + // Add tiles to current group + for (int m = 0; m < mat_count; ++m) { + auto Mij = mats[ m ].get()( i*i_step[m], j*j_step[m], device ); + mats_array_host[ m ][ batch_count ] = Mij.data(); + if (group.count == 0) { + group.ld[m] = Mij.stride(); + } + else { + assert( group.ld[m] == Mij.stride() ); + } + } + if (extra_setup) { + extra_setup( group_params.size(), i, j ); + } + ++group.count; + ++batch_count; + } + } // for i + } // for j + // If any tiles in the region should be computed here, save the group + if (group.count > 0) { + group_params.push_back( group ); + } + + // If the diagonal tiles need special treatment, build those groups + if constexpr (has_diag) if (!diag_same) { + // Loop over the diagonal tiles in this region. If any should be + // computed on this process & device, save them. 
+ group = Params(); + group.is_diagonal = true; + group.mb = A.tileMb( irange[ ii ] ); + group.nb = A.tileNb( jrange[ jj ] ); + // Diagonal tiles only in the intersection of irange and jrange + int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); + int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); + for (int64_t ij = ijstart; ij < ijend; ++ij) { + if (A.tileIsLocal( ij, ij ) + && device == A.tileDevice( ij, ij )) { + + // Add tiles to current group + // This logic matches that of above + for (int m = 0; m < mat_count; ++m) { + auto Mij = mats[ m ].get()( ij*i_step[m], ij*j_step[m], device ); + mats_array_host[ m ][ batch_count ] = Mij.data(); + if (group.count == 0) { + group.ld[m] = Mij.stride(); + } + else { + assert( group.ld[m] == Mij.stride() ); + } + } + if (extra_setup) { + extra_setup( group_params.size(), ij, ij ); + } + ++group.count; + ++batch_count; + } + } // for ij + // If any tiles in the region should be computed here, save the group + if (group.count > 0) { + group_params.push_back( group ); + } + } // if has_diag && !diag_same + }} // for jj, ii + return group_params; +} + } // namespace internal +} // namespace slate #endif // SLATE_INTERNAL_BATCH_HH diff --git a/src/internal/internal_geadd.cc b/src/internal/internal_geadd.cc index dea47b805..3bbd249e1 100644 --- a/src/internal/internal_geadd.cc +++ b/src/internal/internal_geadd.cc @@ -10,7 +10,6 @@ #include "slate/Matrix.hh" #include "internal/Tile_lapack.hh" #include "slate/types.hh" -#include "internal/internal_util.hh" namespace slate { namespace internal { diff --git a/src/internal/internal_gecopy.cc b/src/internal/internal_gecopy.cc index 9bc6f34d4..4738e1b5d 100644 --- a/src/internal/internal_gecopy.cc +++ b/src/internal/internal_gecopy.cc @@ -11,7 +11,6 @@ #include "slate/Tile_blas.hh" #include "slate/Tile_aux.hh" #include "slate/types.hh" -#include "internal/internal_util.hh" namespace slate { diff --git a/src/internal/internal_gemm.cc b/src/internal/internal_gemm.cc index b348b7b9b..a73e25f42 100644 --- a/src/internal/internal_gemm.cc +++ b/src/internal/internal_gemm.cc @@ -8,7 +8,6 @@ #include "slate/Tile_blas.hh" #include "internal/internal.hh" #include "internal/internal_batch.hh" -#include "internal/internal_util.hh" namespace slate { namespace internal { diff --git a/src/internal/internal_gescale.cc b/src/internal/internal_gescale.cc index f16f9a427..f46d92278 100644 --- a/src/internal/internal_gescale.cc +++ b/src/internal/internal_gescale.cc @@ -10,7 +10,6 @@ #include "slate/Matrix.hh" #include "internal/Tile_lapack.hh" #include "slate/types.hh" -#include "internal/internal_util.hh" namespace slate { namespace internal { diff --git a/src/internal/internal_gescale_row_col.cc b/src/internal/internal_gescale_row_col.cc index 7affbd9ac..75d8be0f3 100644 --- a/src/internal/internal_gescale_row_col.cc +++ b/src/internal/internal_gescale_row_col.cc @@ -11,7 +11,6 @@ #include "slate/Matrix.hh" #include "slate/types.hh" #include "tile/scale_row_col.hh" -#include "internal/internal_util.hh" namespace slate { diff --git a/src/internal/internal_geset.cc b/src/internal/internal_geset.cc index 9cba4f683..5c264ab3f 100644 --- a/src/internal/internal_geset.cc +++ b/src/internal/internal_geset.cc @@ -10,7 +10,6 @@ #include "slate/Matrix.hh" #include "internal/Tile_lapack.hh" #include "slate/types.hh" -#include "internal/internal_util.hh" namespace slate { namespace internal { diff --git a/src/internal/internal_her2k.cc b/src/internal/internal_her2k.cc index a8d17a17b..65bf31f1b 100644 --- 
a/src/internal/internal_her2k.cc +++ b/src/internal/internal_her2k.cc @@ -9,7 +9,6 @@ #include "slate/Tile_blas.hh" #include "internal/internal.hh" #include "internal/internal_batch.hh" -#include "internal/internal_util.hh" namespace slate { namespace internal { diff --git a/src/internal/internal_herk.cc b/src/internal/internal_herk.cc index 2ee78b5ab..568c5354c 100644 --- a/src/internal/internal_herk.cc +++ b/src/internal/internal_herk.cc @@ -9,7 +9,6 @@ #include "slate/Tile_blas.hh" #include "internal/internal.hh" #include "internal/internal_batch.hh" -#include "internal/internal_util.hh" namespace slate { namespace internal { diff --git a/src/internal/internal_trmm.cc b/src/internal/internal_trmm.cc index 8f01a116e..205a281c4 100644 --- a/src/internal/internal_trmm.cc +++ b/src/internal/internal_trmm.cc @@ -8,7 +8,7 @@ #include "slate/types.hh" #include "slate/Tile_blas.hh" #include "internal/internal.hh" -#include "internal/internal_util.hh" +#include "internal/internal_batch.hh" namespace slate { namespace internal { diff --git a/src/internal/internal_trsm.cc b/src/internal/internal_trsm.cc index bc4dffcf4..b12359d1d 100644 --- a/src/internal/internal_trsm.cc +++ b/src/internal/internal_trsm.cc @@ -9,7 +9,6 @@ #include "slate/Tile_blas.hh" #include "internal/internal.hh" #include "internal/internal_batch.hh" -#include "internal/internal_util.hh" namespace slate { namespace internal { diff --git a/src/internal/internal_tzadd.cc b/src/internal/internal_tzadd.cc index 0cf703c67..328ad0632 100644 --- a/src/internal/internal_tzadd.cc +++ b/src/internal/internal_tzadd.cc @@ -10,7 +10,6 @@ #include "slate/Matrix.hh" #include "internal/Tile_lapack.hh" #include "slate/types.hh" -#include "internal/internal_util.hh" namespace slate { namespace internal { diff --git a/src/internal/internal_tzcopy.cc b/src/internal/internal_tzcopy.cc index 7ea8079ba..5132cf5fe 100644 --- a/src/internal/internal_tzcopy.cc +++ b/src/internal/internal_tzcopy.cc @@ -11,7 +11,6 @@ #include "slate/Tile_blas.hh" #include "slate/Tile_aux.hh" #include "slate/types.hh" -#include "internal/internal_util.hh" namespace slate { namespace internal { diff --git a/src/internal/internal_tzscale.cc b/src/internal/internal_tzscale.cc index ab9747f1e..0d836417e 100644 --- a/src/internal/internal_tzscale.cc +++ b/src/internal/internal_tzscale.cc @@ -10,7 +10,6 @@ #include "slate/Matrix.hh" #include "internal/Tile_lapack.hh" #include "slate/types.hh" -#include "internal/internal_util.hh" namespace slate { namespace internal { diff --git a/src/internal/internal_tzset.cc b/src/internal/internal_tzset.cc index eee300295..981c4dda1 100644 --- a/src/internal/internal_tzset.cc +++ b/src/internal/internal_tzset.cc @@ -10,7 +10,6 @@ #include "slate/Matrix.hh" #include "internal/Tile_lapack.hh" #include "slate/types.hh" -#include "internal/internal_util.hh" namespace slate { namespace internal { diff --git a/src/internal/internal_util.hh b/src/internal/internal_util.hh index 4491ba17b..11b080067 100644 --- a/src/internal/internal_util.hh +++ b/src/internal/internal_util.hh @@ -11,7 +11,6 @@ #include "slate/internal/mpi.hh" #include "slate/Matrix.hh" -#include "slate/BaseTrapezoidMatrix.hh" #include #include @@ -111,221 +110,6 @@ slate::Matrix alloc_basis(slate::BaseMatrix& A, int64_t n, } -// Utilities for device batch regions - -//------------------------------------------------------------------------------ -/// Computes the range of tiles with either the same mb or the same nb -/// -/// @param[in] want_rows -/// If true, compute the 
row-ranges. Else, compute the column-ranges. -/// -/// @param[in] A -/// The matrix to get tile sizes from -/// -/// @return The ranges of uniform tile sizes -/// -template -std::vector device_regions_range( bool want_rows, BaseMatrix& A ) -{ - int64_t kt = want_rows ? A.mt() : A.nt(); - - std::vector< int64_t > range; - int64_t last = -1; - for (int64_t k = 0; k < kt; ++k) { - int64_t kb = want_rows ? A.tileMb( k ) : A.tileNb( k ); - if (kb != last) { - last = kb; - range.push_back( k ); - } - } - range.push_back( kt ); - return range; -} - -//------------------------------------------------------------------------------ -/// Helper class to store the information on a device region. -/// -/// @tparam has_diag -/// Wheather the diagonal tiles may need to be special cased -/// -/// @tparam mat_count -/// The number of matrices used by the kernel -/// -template< bool has_diag, int mat_count > -struct device_regions_params { - int64_t count, mb, nb; - int64_t ld[mat_count]; - -private: - // When has_diag is false, we don't want to allocate any memory for is_diagonal - struct Empty {}; -public: - std::conditional_t< has_diag, bool, Empty > is_diagonal; - - device_regions_params() - : count(0), mb(0), nb(0) - { - for (int i = 0; i < mat_count; ++i) { - ld[i] = 0; - } - if constexpr (has_diag) { - is_diagonal = false; - } - } -}; - -//------------------------------------------------------------------------------ -/// Computes and populates the regions for the given matrices. -/// -/// @tparam has_diag -/// Wheather the diagonal tiles may need to be special cased -/// -/// @tparam mat_count -/// The number of matrices used by the kernel -/// -/// @param[in] mats -/// An array of the matrices to build regions for -/// -/// @param[in] mats_array_host -/// An array of the arrays to fill with pointers to device data -/// -/// @param[in] device -/// The device to build regions for -/// -/// @param[in] diag_same -/// Whether to treat the diagonal tiles as normal tiles in spite of has_diag -/// Ignored when has_diag is false. -/// -/// @param[in] extra_setup -/// Callback that is called whenever a tile is added to a group. -/// The group index and the tile indices are passed as arguments -/// -template< bool has_diag, int mat_count, typename scalar_t> -std::vector< device_regions_params > device_regions_build( - std::array< std::reference_wrapper>, mat_count > mats, - std::array< scalar_t**, mat_count > mats_array_host, - int64_t device, - bool diag_same = true, - std::function extra_setup = {}) -{ - // The first two arguments should be valid targets for brace-initialization - // reference_wrapper works around fact that C++ doesn't allow array of references - - using Params = device_regions_params; - - auto& A = mats[0].get(); - - // Find ranges of matching mb's and ranges of matching nb's. - std::vector< int64_t > irange = device_regions_range( true, A ); - std::vector< int64_t > jrange = device_regions_range( false, A ); - - // Trapezoidal matrices always need special treatment for diagonal tiles - diag_same &= A.uplo() == Uplo::General; - - // Can't treat diagonals special when we can't store the diagonal status - assert( diag_same || has_diag ); - diag_same |= !has_diag; // Ensure the compiler can propagate this assertion - - // Size 1 dimensions get broadcast to allow setting up GEMM et al. 
- // i_step[m]=0 results in only accessing row 0 of matrix m (likewise for j) - // The first matrix is always indexed normally since it determines the loops - int64_t i_step[mat_count]; - int64_t j_step[mat_count]; - i_step[0] = 1; - j_step[0] = 1; - for (int m = 1; m < mat_count; ++m) { - i_step[m] = (mats[ m ].get().mt() > 1); - j_step[m] = (mats[ m ].get().nt() > 1); - } - - int64_t batch_count = 0; - int64_t mt = A.mt(); - std::vector group_params; - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - // Loop over the tiles in this region. If any should be computed on this - // process & device, save them. - Params group; - group.mb = A.tileMb( irange[ ii ] ); - group.nb = A.tileNb( jrange[ jj ] ); - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - // This is a column major loop. So, - // * Lower matrices start at j+1 - // * Upper matrices end at j - // * General matrices run the whole range - int istart = std::max(irange[ ii ], (A.uplo() == Uplo::Lower ? j+1 : 0)); - int iend = std::min(irange[ ii+1 ], (A.uplo() == Uplo::Upper ? j : mt)); - for (int64_t i = istart; i < iend; ++i) { - if ((!has_diag || diag_same || i != j) - && A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { - - // Add tiles to current group - for (int m = 0; m < mat_count; ++m) { - auto Mij = mats[ m ].get()( i*i_step[m], j*j_step[m], device ); - mats_array_host[ m ][ batch_count ] = Mij.data(); - if (group.count == 0) { - group.ld[m] = Mij.stride(); - } - else { - assert( group.ld[m] == Mij.stride() ); - } - } - if (extra_setup) { - extra_setup( group_params.size(), i, j ); - } - ++group.count; - ++batch_count; - } - } // for i - } // for j - // If any tiles in the region should be computed here, save the group - if (group.count > 0) { - group_params.push_back( group ); - } - - // If the diagonal tiles need special treatment, build those groups - if constexpr (has_diag) if (!diag_same) { - // Loop over the diagonal tiles in this region. If any should be - // computed on this process & device, save them. 
- group = Params(); - group.is_diagonal = true; - group.mb = A.tileMb( irange[ ii ] ); - group.nb = A.tileNb( jrange[ jj ] ); - // Diagonal tiles only in the intersection of irange and jrange - int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); - int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); - for (int64_t ij = ijstart; ij < ijend; ++ij) { - if (A.tileIsLocal( ij, ij ) - && device == A.tileDevice( ij, ij )) { - - // Add tiles to current group - // This logic matches that of above - for (int m = 0; m < mat_count; ++m) { - auto Mij = mats[ m ].get()( ij*i_step[m], ij*j_step[m], device ); - mats_array_host[ m ][ batch_count ] = Mij.data(); - if (group.count == 0) { - group.ld[m] = Mij.stride(); - } - else { - assert( group.ld[m] == Mij.stride() ); - } - } - if (extra_setup) { - extra_setup( group_params.size(), ij, ij ); - } - ++group.count; - ++batch_count; - } - } // for ij - // If any tiles in the region should be computed here, save the group - if (group.count > 0) { - group_params.push_back( group ); - } - } // if has_diag && !diag_same - }} // for jj, ii - return group_params; -} - } // namespace internal } // namespace slate From 6356d1d1919be55c6898d433172c6d32ac0b6faf Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Tue, 24 Oct 2023 15:05:39 -0400 Subject: [PATCH 15/35] Add device regions to syrk and syr2k --- src/internal/internal_syr2k.cc | 411 ++++++++------------------------- src/internal/internal_syrk.cc | 265 ++++++--------------- 2 files changed, 163 insertions(+), 513 deletions(-) diff --git a/src/internal/internal_syr2k.cc b/src/internal/internal_syr2k.cc index 467334ff0..f5f90a2a9 100644 --- a/src/internal/internal_syr2k.cc +++ b/src/internal/internal_syr2k.cc @@ -518,23 +518,20 @@ void syr2k(internal::TargetType, Op opB = (opA == Op::NoTrans ? 
Op::Trans : Op::NoTrans); - std::set A_tiles_gemm, B_tiles_gemm, C_tiles_gemm; - std::set A_tiles_syr2k, B_tiles_syr2k, C_tiles_syr2k; + std::set A_tiles_set, B_tiles_set, C_tiles_set; for (int64_t j = 0; j < C.nt(); ++j) { for (int64_t i = j; i < C.mt(); ++i) { // lower if (C.tileIsLocal(i, j) && device == C.tileDevice(i, j)) { + + A_tiles_set.insert({j, 0}); + B_tiles_set.insert({j, 0}); + C_tiles_set.insert({i, j}); if (i == j) { - A_tiles_syr2k.insert({j, 0}); - B_tiles_syr2k.insert({j, 0}); - C_tiles_syr2k.insert({i, j}); } else { - A_tiles_gemm.insert({i, 0}); - A_tiles_gemm.insert({j, 0}); - B_tiles_gemm.insert({i, 0}); - B_tiles_gemm.insert({j, 0}); - C_tiles_gemm.insert({i, j}); + A_tiles_set.insert({i, 0}); + B_tiles_set.insert({i, 0}); } } } @@ -543,342 +540,128 @@ void syr2k(internal::TargetType, #pragma omp taskgroup { #pragma omp task slate_omp_default_none \ - shared( A, A_tiles_gemm ) \ - firstprivate(device, layout) + shared( A, A_tiles_set ) \ + firstprivate( device, layout ) { - A.tileGetForReading(A_tiles_gemm, device, LayoutConvert(layout)); + A.tileGetForReading(A_tiles_set, device, LayoutConvert(layout)); } #pragma omp task slate_omp_default_none \ - shared( B, B_tiles_gemm ) \ - firstprivate(device, layout) + shared( B, B_tiles_set ) \ + firstprivate( device, layout ) { - B.tileGetForReading(B_tiles_gemm, device, LayoutConvert(layout)); + B.tileGetForReading(B_tiles_set, device, LayoutConvert(layout)); } #pragma omp task slate_omp_default_none \ - shared( C, C_tiles_gemm ) \ - firstprivate(device, layout) + shared( C, C_tiles_set ) \ + firstprivate( device, layout ) { - C.tileGetForWriting(C_tiles_gemm, device, LayoutConvert(layout)); - } - } - - int64_t batch_size_gemm = C_tiles_gemm.size(); - - //---------------------------------------- - // A * B^T - // interior - std::vector a_array_gemm00; - std::vector b_array_gemm00; - std::vector c_array_gemm00; - a_array_gemm00.reserve( batch_size_gemm ); - b_array_gemm00.reserve( batch_size_gemm ); - c_array_gemm00.reserve( batch_size_gemm ); - - int64_t lda00 = 0; - int64_t ldb00 = 0; - int64_t ldc00 = 0; - int64_t mb00 = C.tileMb(0); - int64_t nb00 = C.tileNb(0); - int64_t kb = A.tileNb(0); - for (int64_t j = 0; j < C.nt()-1; ++j) { - // strictly lower - for (int64_t i = j+1; i < C.mt()-1; ++i) { - if (C.tileIsLocal(i, j) - && device == C.tileDevice(i, j)) - { - a_array_gemm00.push_back( A(i, 0, device).data() ); - b_array_gemm00.push_back( B(j, 0, device).data() ); - c_array_gemm00.push_back( C(i, j, device).data() ); - lda00 = A(i, 0, device).stride(); - ldb00 = B(j, 0, device).stride(); - ldc00 = C(i, j, device).stride(); - } - } - } - - // bottom row - std::vector a_array_gemm10; - std::vector b_array_gemm10; - std::vector c_array_gemm10; - a_array_gemm10.reserve( batch_size_gemm ); - b_array_gemm10.reserve( batch_size_gemm ); - c_array_gemm10.reserve( batch_size_gemm ); - - int64_t lda10 = 0; - int64_t ldb10 = 0; - int64_t ldc10 = 0; - int64_t mb10 = C.tileMb(C.mt()-1); - int64_t nb10 = C.tileNb(0); - // same kb as above - { - int64_t i = C.mt()-1; - for (int64_t j = 0; j < C.nt()-1; ++j) { - if (C.tileIsLocal(i, j) - && device == C.tileDevice(i, j)) - { - a_array_gemm10.push_back( A(i, 0, device).data() ); - b_array_gemm10.push_back( B(j, 0, device).data() ); - c_array_gemm10.push_back( C(i, j, device).data() ); - lda10 = A(i, 0, device).stride(); - ldb10 = B(j, 0, device).stride(); - ldc10 = C(i, j, device).stride(); - } + C.tileGetForWriting(C_tiles_set, device, LayoutConvert(layout)); } } - if (C.op() != 
Op::NoTrans) { - // swap A <=> B; swap m <=> n - swap(opA, opB); - swap(a_array_gemm00, b_array_gemm00); - swap(a_array_gemm10, b_array_gemm10); - swap(lda00, ldb00); - swap(lda10, ldb10); - swap(mb00, nb00); - swap(mb10, nb10); - } - - std::vector opA_(1, opA); - std::vector opB_(1, opB); - std::vector k(1, kb); - std::vector alpha_(1, alpha); - std::vector beta_(1, beta); - std::vector info; - - blas::Queue* queue = C.compute_queue(device, queue_index); - - { - trace::Block trace_block("blas::batch::gemm"); - - if (c_array_gemm00.size() > 0) { - std::vector m(1, mb00); - std::vector n(1, nb00); - std::vector ldda(1, lda00); - std::vector lddb(1, ldb00); - std::vector lddc(1, ldc00); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, a_array_gemm00, ldda, - b_array_gemm00, lddb, - beta_, c_array_gemm00, lddc, - c_array_gemm00.size(), info, *queue); - } + int64_t batch_size = C_tiles_set.size(); - if (c_array_gemm10.size() > 0) { - std::vector m(1, mb10); - std::vector n(1, nb10); - std::vector ldda(1, lda10); - std::vector lddb(1, ldb10); - std::vector lddc(1, ldc10); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, a_array_gemm10, ldda, - b_array_gemm10, lddb, - beta_, c_array_gemm10, lddc, - c_array_gemm10.size(), info, *queue); - } - } + scalar_t** a_array_host = C.array_host(device, queue_index); + scalar_t** b_array_host = a_array_host + batch_size; + scalar_t** c_array_host = b_array_host + batch_size; - //---------------------------------------- - // B * A^T - // ai => bi, bj => aj, set beta = 1 + // There are only 3 batch arrays + std::vector t_array_vect( 2*batch_size ); + scalar_t** at_array_host = t_array_vect.data(); + scalar_t** bt_array_host = at_array_host + batch_size; - a_array_gemm00.clear(); - b_array_gemm00.clear(); - a_array_gemm10.clear(); - b_array_gemm10.clear(); + // Use transposed A and B to broadcast correctly + auto AT = transpose(A); + auto BT = transpose(B); - // interior - for (int64_t j = 0; j < C.nt()-1; ++j) { - // strictly lower - for (int64_t i = j+1; i < C.mt()-1; ++i) { - if (C.tileIsLocal(i, j) - && device == C.tileDevice(i, j)) - { - a_array_gemm00.push_back( A(j, 0, device).data() ); - b_array_gemm00.push_back( B(i, 0, device).data() ); - lda00 = A(j, 0, device).stride(); - ldb00 = B(i, 0, device).stride(); - } - } - } + // C comes first since we do computation for a local C + auto group_params = device_regions_build( + {C, A, AT, BT, B}, + {c_array_host, a_array_host, at_array_host, b_array_host, bt_array_host}, + device ); - // bottom row - { - int i = C.mt()-1; - for (int64_t j = 0; j < C.nt()-1; ++j) { - if (C.tileIsLocal(i, j) - && device == C.tileDevice(i, j)) - { - a_array_gemm10.push_back( A(j, 0, device).data() ); - b_array_gemm10.push_back( B(i, 0, device).data() ); - lda10 = A(j, 0, device).stride(); - ldb10 = B(i, 0, device).stride(); - } - } - } if (C.op() != Op::NoTrans) { - // swap A <=> B; swap m <=> n - //swap(opA, opB); // already done above - swap(a_array_gemm00, b_array_gemm00); - swap(a_array_gemm10, b_array_gemm10); - swap(lda00, ldb00); - swap(lda10, ldb10); - //swap(mb00, nb00); // already done above - //swap(mb10, nb10); // already done above - } - - { - trace::Block trace_block("blas::batch::gemm"); - std::vector one_( 1, one ); - - if (c_array_gemm00.size() > 0) { - std::vector m(1, mb00); - std::vector n(1, nb00); - std::vector ldda(1, lda00); - std::vector lddb(1, ldb00); - std::vector lddc(1, ldc00); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, b_array_gemm00, lddb, - 
a_array_gemm00, ldda, - one_, c_array_gemm00, lddc, - c_array_gemm00.size(), info, *queue); - } - - if (c_array_gemm10.size() > 0) { - std::vector m(1, mb10); - std::vector n(1, nb10); - std::vector ldda(1, lda10); - std::vector lddb(1, ldb10); - std::vector lddc(1, ldc10); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, b_array_gemm10, lddb, - a_array_gemm10, ldda, - one_, c_array_gemm10, lddc, - c_array_gemm10.size(), info, *queue); - } + swap(opA, opB); } - #pragma omp taskgroup { - #pragma omp task slate_omp_default_none \ - shared( A, A_tiles_syr2k ) \ - firstprivate(device, layout) - { - A.tileGetForReading(A_tiles_syr2k, device, LayoutConvert(layout)); - } - #pragma omp task slate_omp_default_none \ - shared( B, B_tiles_syr2k ) \ - firstprivate(device, layout) - { - B.tileGetForReading(B_tiles_syr2k, device, LayoutConvert(layout)); - } - #pragma omp task slate_omp_default_none \ - shared( C, C_tiles_syr2k ) \ - firstprivate(device, layout) - { - C.tileGetForWriting(C_tiles_syr2k, device, LayoutConvert(layout)); - } - } - - int64_t batch_size_syr2k = C_tiles_syr2k.size(); - - // diagonal - std::vector a_array_syr2k_0; - std::vector b_array_syr2k_0; - std::vector c_array_syr2k_0; - a_array_syr2k_0.reserve( batch_size_syr2k ); - b_array_syr2k_0.reserve( batch_size_syr2k ); - c_array_syr2k_0.reserve( batch_size_syr2k ); - - int64_t lda_syr2k_0 = 0; - int64_t ldb_syr2k_0 = 0; - int64_t ldc_syr2k_0 = 0; + trace::Block trace_block("blas::batch::her2k"); - int64_t nb_syr2k_0 = C.tileNb(0); + std::vector opA_(1, opA); + std::vector opB_(1, opB); + std::vector k(1, A.tileNb(0)); + std::vector info; - for (int64_t j = 0; j < C.nt()-1; ++j) { - if (C.tileIsLocal(j, j) - && device == C.tileDevice(j, j)) - { - a_array_syr2k_0.push_back( A(j, 0, device).data() ); - b_array_syr2k_0.push_back( B(j, 0, device).data() ); - c_array_syr2k_0.push_back( C(j, j, device).data() ); - lda_syr2k_0 = A(j, 0, device).stride(); - ldb_syr2k_0 = B(j, 0, device).stride(); - ldc_syr2k_0 = C(j, j, device).stride(); - } - } + std::vector alpha_(1, alpha); + std::vector beta_(1, beta); + std::vector one_( 1, one ); + std::vector uplo(1, C.uploPhysical()); - // bottom-right corner - // todo: replace batch with plain call - std::vector a_array_syr2k_1; - std::vector b_array_syr2k_1; - std::vector c_array_syr2k_1; + blas::Queue* queue = C.compute_queue(device, queue_index); - int64_t lda_syr2k_1 = 0; - int64_t ldb_syr2k_1 = 0; - int64_t ldc_syr2k_1 = 0; + for (size_t g = 0; g < group_params.size(); ++g) { - int64_t nb_syr2k_1 = C.tileNb(C.nt()-1); + int64_t group_count = group_params[ g ].count; - { - int i = C.mt()-1; - int j = C.nt()-1; - if (C.tileIsLocal(i, j) - && device == C.tileDevice(i, j)) - { - a_array_syr2k_1.push_back( A(j, 0, device).data() ); - b_array_syr2k_1.push_back( B(j, 0, device).data() ); - c_array_syr2k_1.push_back( C(j, j, device).data() ); - lda_syr2k_1 = A(j, 0, device).stride(); - ldb_syr2k_1 = B(j, 0, device).stride(); - ldc_syr2k_1 = C(j, j, device).stride(); - } - } + std::vector n(1, group_params[ g ].nb); + std::vector ldda(1, group_params[ g ].ld[1]); + std::vector lddb(1, group_params[ g ].ld[3]); + std::vector lddc(1, group_params[ g ].ld[0]); + std::vector a_array(a_array_host, a_array_host+group_count); + std::vector b_array(b_array_host, b_array_host+group_count); + std::vector c_array(c_array_host, c_array_host+group_count); - { - trace::Block trace_block("blas::batch::syr2k"); - - std::vector uplo(1, C.uploPhysical()); + if (group_params[ g ].is_diagonal) { + 
blas::batch::syr2k( + layout, uplo, opA_, + n, k, + alpha_, a_array, ldda, + b_array, lddb, + beta_, c_array, lddc, + group_count, info, *queue); + } + else { + std::vector m(1, group_params[ g ].mb); + std::vector lddat(1, group_params[ g ].ld[2]); + std::vector lddbt(1, group_params[ g ].ld[4]); + std::vector at_array(at_array_host, at_array_host+group_count); + std::vector bt_array(bt_array_host, bt_array_host+group_count); + + if (C.op() != Op::NoTrans) { + swap(m, n); + swap(a_array, b_array); + swap(at_array, bt_array); + swap(ldda, lddb); + swap(lddat, lddbt); + } - if (c_array_syr2k_0.size() > 0) { - std::vector n(1, nb_syr2k_0); - std::vector ldda(1, lda_syr2k_0); - std::vector lddb(1, ldb_syr2k_0); - std::vector lddc(1, ldc_syr2k_0); - blas::batch::syr2k( - layout, uplo, opA_, - n, k, - alpha_, a_array_syr2k_0, ldda, - b_array_syr2k_0, lddb, - beta_, c_array_syr2k_0, lddc, - c_array_syr2k_0.size(), info, *queue); + blas::batch::gemm( + layout, opA_, opB_, + m, n, k, + alpha_, a_array, ldda, + b_array, lddb, + beta_, c_array, lddc, + group_count, info, *queue); + + blas::batch::gemm( + layout, opA_, opB_, + m, n, k, + alpha_, bt_array, lddbt, + at_array, lddat, + one_, c_array, lddc, + group_count, info, *queue); + } + a_array_host += group_count; + at_array_host += group_count; + b_array_host += group_count; + bt_array_host += group_count; + c_array_host += group_count; } - if (c_array_syr2k_1.size() > 0) { - std::vector n(1, nb_syr2k_1); - std::vector ldda(1, lda_syr2k_1); - std::vector lddb(1, ldb_syr2k_1); - std::vector lddc(1, ldc_syr2k_1); - blas::batch::syr2k( - layout, uplo, opA_, - n, k, - alpha_, a_array_syr2k_1, ldda, - b_array_syr2k_1, lddb, - beta_, c_array_syr2k_1, lddc, - c_array_syr2k_1.size(), info, *queue); - } + queue->sync(); } - queue->sync(); - if (call_tile_tick) { for (int64_t j = 0; j < C.nt(); ++j) { for (int64_t i = j; i < C.mt(); ++i) { // lower diff --git a/src/internal/internal_syrk.cc b/src/internal/internal_syrk.cc index e289fe31d..0cf9dc585 100644 --- a/src/internal/internal_syrk.cc +++ b/src/internal/internal_syrk.cc @@ -458,20 +458,15 @@ void syrk(internal::TargetType, Op opB = (opA == Op::NoTrans ? 
Op::Trans : Op::NoTrans); - std::set A_tiles_gemm, C_tiles_gemm; - std::set A_tiles_syrk, C_tiles_syrk; + std::set A_tiles_set, C_tiles_set; for (int64_t j = 0; j < C.nt(); ++j) { for (int64_t i = j; i < C.mt(); ++i) { // lower if (C.tileIsLocal(i, j) && device == C.tileDevice(i, j)) { - if (i == j) { - A_tiles_syrk.insert({j, 0}); - C_tiles_syrk.insert({j, j}); - } - else { - A_tiles_gemm.insert({i, 0}); - A_tiles_gemm.insert({j, 0}); - C_tiles_gemm.insert({i, j}); + A_tiles_set.insert({j, 0}); + C_tiles_set.insert({i, j}); + if (i != j) { + A_tiles_set.insert({i, 0}); } } } @@ -480,226 +475,98 @@ void syrk(internal::TargetType, #pragma omp taskgroup { #pragma omp task slate_omp_default_none \ - shared( A, A_tiles_gemm ) \ + shared( A, A_tiles_set ) \ firstprivate(device, layout) { - A.tileGetForReading(A_tiles_gemm, device, LayoutConvert(layout)); + A.tileGetForReading(A_tiles_set, device, LayoutConvert(layout)); } #pragma omp task slate_omp_default_none \ - shared( C, C_tiles_gemm ) \ + shared( C, C_tiles_set ) \ firstprivate(device, layout) { - C.tileGetForWriting(C_tiles_gemm, device, LayoutConvert(layout)); + C.tileGetForWriting(C_tiles_set, device, LayoutConvert(layout)); } } - int64_t batch_size_gemm = C_tiles_gemm.size(); - - // interior - std::vector a_array_gemm00; - std::vector b_array_gemm00; - std::vector c_array_gemm00; - a_array_gemm00.reserve( batch_size_gemm ); - b_array_gemm00.reserve( batch_size_gemm ); - c_array_gemm00.reserve( batch_size_gemm ); - - int64_t lda00 = 0; - int64_t ldb00 = 0; - int64_t ldc00 = 0; - int64_t mb00 = C.tileMb(0); - int64_t nb00 = C.tileNb(0); - int64_t kb = A.tileNb(0); - for (int64_t j = 0; j < C.nt()-1; ++j) { - // strictly lower - for (int64_t i = j+1; i < C.mt()-1; ++i) { - if (C.tileIsLocal(i, j)) { - if (device == C.tileDevice(i, j)) { - a_array_gemm00.push_back( A(i, 0, device).data() ); - b_array_gemm00.push_back( A(j, 0, device).data() ); - c_array_gemm00.push_back( C(i, j, device).data() ); - lda00 = A(i, 0, device).stride(); - ldb00 = A(j, 0, device).stride(); - ldc00 = C(i, j, device).stride(); - } - } - } - } + int64_t batch_size = C_tiles_set.size(); - // bottom row - std::vector a_array_gemm10; - std::vector b_array_gemm10; - std::vector c_array_gemm10; - a_array_gemm10.reserve( batch_size_gemm ); - b_array_gemm10.reserve( batch_size_gemm ); - c_array_gemm10.reserve( batch_size_gemm ); - - int64_t lda10 = 0; - int64_t ldb10 = 0; - int64_t ldc10 = 0; - int64_t mb10 = C.tileMb(C.mt()-1); - int64_t nb10 = C.tileNb(0); - // same kb as above - { - int64_t i = C.mt()-1; - for (int64_t j = 0; j < C.nt()-1; ++j) { - if (C.tileIsLocal(i, j)) { - if (device == C.tileDevice(i, j)) { - a_array_gemm10.push_back( A(i, 0, device).data() ); - b_array_gemm10.push_back( A(j, 0, device).data() ); - c_array_gemm10.push_back( C(i, j, device).data() ); - lda10 = A(i, 0, device).stride(); - ldb10 = A(j, 0, device).stride(); - ldc10 = C(i, j, device).stride(); - } - } - } - } + scalar_t** a_array_host = C.array_host(device, queue_index); + scalar_t** b_array_host = a_array_host + batch_size; + scalar_t** c_array_host = b_array_host + batch_size; + + // Use transposed A to broadcast down the rows correctly + auto AT = transpose(A); + + // C comes first since we do computation for a local C + auto group_params = device_regions_build( + {C, A, AT}, + {c_array_host, a_array_host, b_array_host}, + device ); if (C.op() != Op::NoTrans) { - // swap A <=> B; swap m <=> n swap(opA, opB); - swap(a_array_gemm00, b_array_gemm00); - swap(a_array_gemm10, 
b_array_gemm10); - swap(lda00, ldb00); - swap(lda10, ldb10); - swap(mb00, nb00); - swap(mb10, nb10); } - std::vector opA_(1, opA); - std::vector opB_(1, opB); - std::vector k(1, kb); - std::vector alpha_(1, scalar_t(alpha)); - std::vector beta_ (1, scalar_t(beta)); - std::vector info; - - blas::Queue* queue = C.compute_queue(device, queue_index); - { - trace::Block trace_block("blas::batch::gemm"); - - if (c_array_gemm00.size() > 0) { - std::vector m(1, mb00); - std::vector n(1, nb00); - std::vector ldda(1, lda00); - std::vector lddb(1, ldb00); - std::vector lddc(1, ldc00); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, a_array_gemm00, ldda, - b_array_gemm00, lddb, - beta_, c_array_gemm00, lddc, - c_array_gemm00.size(), info, *queue); - } + trace::Block trace_block("blas::batch::herk"); - if (c_array_gemm10.size() > 0) { - std::vector m(1, mb10); - std::vector n(1, nb10); - std::vector ldda(1, lda10); - std::vector lddb(1, ldb10); - std::vector lddc(1, ldc10); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, a_array_gemm10, ldda, - b_array_gemm10, lddb, - beta_, c_array_gemm10, lddc, - c_array_gemm10.size(), info, *queue); - } - } + std::vector opA_(1, opA); + std::vector opB_(1, opB); + std::vector k(1, A.tileNb(0)); + std::vector info; - #pragma omp taskgroup - { - #pragma omp task slate_omp_default_none \ - shared( A, A_tiles_syrk ) \ - firstprivate(device, layout) - { - A.tileGetForReading(A_tiles_syrk, device, LayoutConvert(layout)); - } - #pragma omp task slate_omp_default_none \ - shared( C, C_tiles_syrk ) \ - firstprivate(device, layout) - { - C.tileGetForWriting(C_tiles_syrk, device, LayoutConvert(layout)); - } - } + std::vector alpha_(1, scalar_t(alpha)); + std::vector beta_ (1, scalar_t(beta)); + std::vector uplo(1, C.uploPhysical()); - int64_t batch_size_syrk = C_tiles_syrk.size(); + blas::Queue* queue = C.compute_queue(device, queue_index); - // diagonal - std::vector a_array_syrk0; - std::vector c_array_syrk0; - a_array_syrk0.reserve( batch_size_syrk ); - c_array_syrk0.reserve( batch_size_syrk ); + for (size_t g = 0; g < group_params.size(); ++g) { - int64_t lda_syrk_0 = 0; - int64_t ldc_syrk_0 = 0; - int64_t nb_syrk_0 = C.tileNb(0); - for (int64_t j = 0; j < C.nt()-1; ++j) { - if (C.tileIsLocal(j, j) - && device == C.tileDevice(j, j)) - { - a_array_syrk0.push_back( A(j, 0, device).data() ); - c_array_syrk0.push_back( C(j, j, device).data() ); - lda_syrk_0 = A(j, 0, device).stride(); - ldc_syrk_0 = C(j, j, device).stride(); - } - } + int64_t group_count = group_params[ g ].count; - // bottom-right corner - // todo: replace batch syrk with plain syrk - std::vector a_array_syrk1; - std::vector c_array_syrk1; + std::vector n(1, group_params[ g ].nb); + std::vector ldda(1, group_params[ g ].ld[1]); + std::vector lddc(1, group_params[ g ].ld[0]); + std::vector a_array(a_array_host, a_array_host+group_count); + std::vector c_array(c_array_host, c_array_host+group_count); - int64_t lda_syrk_1 = 0; - int64_t ldc_syrk_1 = 0; - int64_t nb_syrk_1 = C.tileNb(C.nt()-1); - { - int i = C.mt()-1; - int j = C.nt()-1; - if (C.tileIsLocal(i, j)) { - if (device == C.tileDevice(i, j)) { - a_array_syrk1.push_back( A(j, 0, device).data() ); - c_array_syrk1.push_back( C(j, j, device).data() ); - lda_syrk_1 = A(j, 0, device).stride(); - ldc_syrk_1 = C(j, j, device).stride(); + if (group_params[ g ].is_diagonal) { + blas::batch::syrk( + layout, uplo, opA_, + n, k, + alpha_, a_array, ldda, + beta_, c_array, lddc, + group_count, info, *queue); } - } - } + else { + 
std::vector m(1, group_params[ g ].mb); + std::vector lddb(1, group_params[ g ].ld[2]); - { - trace::Block trace_block("blas::batch::syrk"); + std::vector b_array(b_array_host, b_array_host+group_count); - std::vector uplo(1, C.uploPhysical()); + if (C.op() != Op::NoTrans) { + swap(m, n); + swap(a_array, b_array); + swap(ldda, lddb); + } - if (c_array_syrk0.size() > 0) { - std::vector n(1, nb_syrk_0); - std::vector ldda(1, lda_syrk_0); - std::vector lddc(1, ldc_syrk_0); - blas::batch::syrk( - layout, uplo, opA_, - n, k, - alpha_, a_array_syrk0, ldda, - beta_, c_array_syrk0, lddc, - c_array_syrk0.size(), info, *queue); + blas::batch::gemm( + layout, opA_, opB_, + m, n, k, + alpha_, a_array, ldda, + b_array, lddb, + beta_, c_array, lddc, + group_count, info, *queue); + } + a_array_host += group_count; + b_array_host += group_count; + c_array_host += group_count; } - if (c_array_syrk1.size() > 0) { - std::vector n(1, nb_syrk_1); - std::vector ldda(1, lda_syrk_1); - std::vector lddc(1, ldc_syrk_1); - blas::batch::syrk( - layout, uplo, opA_, - n, k, - alpha_, a_array_syrk1, ldda, - beta_, c_array_syrk1, lddc, - c_array_syrk1.size(), info, *queue); - } + queue->sync(); } - queue->sync(); - if (call_tile_tick) { // both off-diagonal batch gemm and diagonal syrks are done for (int64_t j = 0; j < C.nt(); ++j) { From 6fad79fd1b7861e7d48d9ff7ce0e874141eb2588 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 19 Oct 2023 16:17:56 -0400 Subject: [PATCH 16/35] Use the batch arrays in genorm --- src/internal/internal_genorm.cc | 61 +++++++++++---------------------- src/norm.cc | 6 ++-- 2 files changed, 23 insertions(+), 44 deletions(-) diff --git a/src/internal/internal_genorm.cc b/src/internal/internal_genorm.cc index d2bf6aff0..0e2de7c85 100644 --- a/src/internal/internal_genorm.cc +++ b/src/internal/internal_genorm.cc @@ -374,11 +374,7 @@ void norm( assert(A.num_devices() > 0); - std::vector > a_host_arrays(A.num_devices()); - std::vector > vals_host_arrays(A.num_devices()); - - std::vector a_dev_arrays(A.num_devices()); - std::vector vals_dev_arrays(A.num_devices()); + std::vector> vals_host_arrays( A.num_devices() ); // devices_values used for max and Frobenius norms. std::vector devices_values; @@ -415,20 +411,6 @@ void norm( slate_not_implemented("The NormScope isn't yet supported."); } - // TODO: Why are we doing this? - // Use the batch arrays in the matrix class. - for (int device = 0; device < A.num_devices(); ++device) { - - int64_t num_tiles = A.getMaxDeviceTiles(device); - - a_host_arrays[device].resize(num_tiles); - vals_host_arrays[device].resize(num_tiles*ldv); - - blas::Queue* queue = A.compute_queue(device, queue_index); - a_dev_arrays[device] = blas::device_malloc(num_tiles, *queue); - vals_dev_arrays[device] = blas::device_malloc(num_tiles*ldv, *queue); - } - // Define index ranges for regions of matrix. // Tiles in each region are all the same size. int64_t irange[4][2] = { @@ -447,8 +429,7 @@ void norm( #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { #pragma omp task slate_omp_default_none \ - priority( priority ) shared( A, devices_values ) \ - shared(a_host_arrays, a_dev_arrays, vals_host_arrays, vals_dev_arrays) \ + priority( priority ) shared( A, devices_values, vals_host_arrays ) \ firstprivate(device, irange, jrange, queue_index, ldv, scope, in_norm, layout) { std::set A_tiles_set; @@ -463,8 +444,7 @@ void norm( A.tileGetForReading(A_tiles_set, device, LayoutConvert(layout)); // Setup batched arguments. 
- scalar_t** a_host_array = a_host_arrays[device].data(); - scalar_t** a_dev_array = a_dev_arrays[device]; + scalar_t** a_array_host = A.array_host( device, queue_index ); int64_t batch_count = 0; int64_t mb[4], nb[4], lda[4], group_count[4]; @@ -478,7 +458,7 @@ void norm( if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) { - a_host_array[batch_count] = A(i, j, device).data(); + a_array_host[batch_count] = A(i, j, device).data(); lda[q] = A(i, j, device).stride(); ++group_count[q]; ++batch_count; @@ -487,34 +467,37 @@ void norm( } } - real_t* vals_host_array = vals_host_arrays[device].data(); - real_t* vals_dev_array = vals_dev_arrays[device]; + scalar_t** a_array_dev = A.array_device(device, queue_index); + + int64_t num_tiles = A_tiles_set.size(); + vals_host_arrays[ device ].resize( num_tiles*ldv ); + real_t* vals_host_array = vals_host_arrays[ device ].data(); + blas::Queue* queue = A.compute_queue( device, queue_index ); + real_t* vals_dev_array = blas::device_malloc( num_tiles*ldv, *queue ); // Batched call to compute partial results for each tile. { trace::Block trace_block("slate::device::genorm"); - blas::Queue* queue = A.compute_queue(device, queue_index); - blas::device_memcpy(a_dev_array, a_host_array, + blas::device_memcpy(a_array_dev, a_array_host, batch_count, blas::MemcpyKind::HostToDevice, *queue); + real_t* vals_dev_array_group = vals_dev_array; for (int q = 0; q < 4; ++q) { if (group_count[q] > 0) { device::genorm(in_norm, scope, mb[q], nb[q], - a_dev_array, lda[q], - vals_dev_array, ldv, + a_array_dev, lda[q], + vals_dev_array_group, ldv, group_count[q], *queue); - a_dev_array += group_count[q]; - vals_dev_array += group_count[q] * ldv; + a_array_dev += group_count[q]; + vals_dev_array_group += group_count[q] * ldv; } } - vals_dev_array = vals_dev_arrays[device]; - blas::device_memcpy(vals_host_array, vals_dev_array, batch_count*ldv, blas::MemcpyKind::DeviceToHost, @@ -538,15 +521,11 @@ void norm( } } } + // Free device workspace + blas::device_free(vals_dev_array, *queue); } } - for (int device = 0; device < A.num_devices(); ++device) { - blas::Queue* queue = A.compute_queue(device, queue_index); - blas::device_free(a_dev_arrays[device], *queue); - blas::device_free(vals_dev_arrays[device], *queue); - } - if (scope == NormScope::Matrix) { // Reduction over devices to local result. 
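The broader pattern this patch moves genorm onto is to let the matrix own the per-device pointer arrays rather than allocating them inside each norm routine. A minimal sketch of that flow (not the exact SLATE code; it assumes allocateBatchArrays() has already been called, as in the norm.cc hunk below, and that batch_count local tiles were gathered for this device):

    // Pointer arrays preallocated by the matrix for this device and queue.
    scalar_t** a_array_host = A.array_host( device, queue_index );
    scalar_t** a_array_dev  = A.array_device( device, queue_index );

    // Fill a_array_host with one pointer per local tile, then mirror it to
    // the device before launching the batched kernels.
    blas::Queue* queue = A.compute_queue( device, queue_index );
    blas::device_memcpy( a_array_dev, a_array_host, batch_count,
                         blas::MemcpyKind::HostToDevice, *queue );

Only the vals workspace remains task-local: it is allocated with blas::device_malloc and freed with blas::device_free at the end of the task.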
@@ -645,7 +624,7 @@ void norm( } } else { - slate_not_implemented("The NormScope isn't yet supported."); + slate_not_implemented("The Norm isn't yet supported."); } } else { diff --git a/src/norm.cc b/src/norm.cc index dbef90048..3bb4ffe15 100644 --- a/src/norm.cc +++ b/src/norm.cc @@ -43,10 +43,10 @@ norm( else if (A.op() == Op::Trans) A = transpose(A); - // TODO update internal to use these batch arrays - // They're currently just used when transposing tiles if (target == Target::Devices) { - A.allocateBatchArrays(); + const int64_t batch_size_default = 0; + const int64_t num_queues = 1; + A.allocateBatchArrays( batch_size_default, num_queues ); A.reserveDeviceWorkspace(); } From c41f9476bdd8f9550ecf710f6f338ba59c0a1800 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Fri, 20 Oct 2023 09:09:45 -0400 Subject: [PATCH 17/35] Start relaxing assumption that tile?b(0) is the largest mb/nb in norm routines --- src/internal/internal_gbnorm.cc | 12 ++++++++++-- src/internal/internal_genorm.cc | 15 +++++++++------ src/internal/internal_hbnorm.cc | 5 ++++- src/internal/internal_henorm.cc | 5 ++++- src/internal/internal_synorm.cc | 5 ++++- src/internal/internal_trnorm.cc | 12 ++++++++++-- 6 files changed, 41 insertions(+), 13 deletions(-) diff --git a/src/internal/internal_gbnorm.cc b/src/internal/internal_gbnorm.cc index 7807b23ec..50d7493c4 100644 --- a/src/internal/internal_gbnorm.cc +++ b/src/internal/internal_gbnorm.cc @@ -300,6 +300,10 @@ void norm( // todo: relax this assumption, a few cases need to be adjusted only const Layout layout = Layout::ColMajor; + if (scope != NormScope::Matrix) { + slate_not_implemented("The NormScope isn't yet supported."); + } + assert(A.num_devices() > 0); std::vector > a_host_arrays(A.num_devices()); @@ -324,10 +328,14 @@ void norm( devices_values.resize(A.num_devices()); } else if (in_norm == Norm::One) { - ldv = A.tileNb(0); + for (int64_t j = 0; j < A.nt(); ++j) { + ldv = std::max( ldv, A.tileNb(j) ); + } } else if (in_norm == Norm::Inf) { - ldv = A.tileMb(0); + for (int64_t i = 0; i < A.mt(); ++i) { + ldv = std::max( ldv, A.tileMb(i) ); + } } else if (in_norm == Norm::Fro) { ldv = 2; diff --git a/src/internal/internal_genorm.cc b/src/internal/internal_genorm.cc index 0e2de7c85..2922240b8 100644 --- a/src/internal/internal_genorm.cc +++ b/src/internal/internal_genorm.cc @@ -386,12 +386,14 @@ void norm( devices_values.resize(A.num_devices()); } else if (in_norm == Norm::One) { - // todo: this assumes all tiles with uniform nb - ldv = A.tileNb(0); + for (int64_t j = 0; j < A.nt(); ++j) { + ldv = std::max( ldv, A.tileNb(j) ); + } } else if (in_norm == Norm::Inf) { - // todo: this assumes all tiles with uniform mb - ldv = A.tileMb(0); + for (int64_t i = 0; i < A.mt(); ++i) { + ldv = std::max( ldv, A.tileMb(i) ); + } } else if (in_norm == Norm::Fro) { ldv = 2; @@ -400,8 +402,9 @@ void norm( } else if (scope == NormScope::Columns) { if (in_norm == Norm::Max) { - // todo: this assumes all tiles with uniform nb - ldv = A.tileNb(0); + for (int64_t j = 0; j < A.nt(); ++j) { + ldv = std::max( ldv, A.tileNb(j) ); + } } else { slate_not_implemented("The NormScope isn't yet supported."); diff --git a/src/internal/internal_hbnorm.cc b/src/internal/internal_hbnorm.cc index 4eb68fb88..157689557 100644 --- a/src/internal/internal_hbnorm.cc +++ b/src/internal/internal_hbnorm.cc @@ -386,7 +386,10 @@ void norm( devices_values.resize(A.num_devices()); } else if (in_norm == Norm::One || in_norm == Norm::Inf) { - ldv = 2*A.tileNb(0); + for (int64_t j = 0; j < A.nt(); ++j) { + 
ldv = std::max( ldv, A.tileNb(j) ); + } + ldv *= 2; } else if (in_norm == Norm::Fro) { ldv = 2; diff --git a/src/internal/internal_henorm.cc b/src/internal/internal_henorm.cc index bab0dca4f..2b2d5e6aa 100644 --- a/src/internal/internal_henorm.cc +++ b/src/internal/internal_henorm.cc @@ -357,7 +357,10 @@ void norm( devices_values.resize(A.num_devices()); } else if (in_norm == Norm::One || in_norm == Norm::Inf) { - ldv = 2*A.tileNb(0); + for (int64_t j = 0; j < A.nt(); ++j) { + ldv = std::max( ldv, A.tileNb(j) ); + } + ldv *= 2; } else if (in_norm == Norm::Fro) { ldv = 2; diff --git a/src/internal/internal_synorm.cc b/src/internal/internal_synorm.cc index 5e7bbdd66..f6d2f6b32 100644 --- a/src/internal/internal_synorm.cc +++ b/src/internal/internal_synorm.cc @@ -359,7 +359,10 @@ void norm(internal::TargetType, devices_values.resize(A.num_devices()); } else if (in_norm == Norm::One || in_norm == Norm::Inf) { - ldv = 2*A.tileNb(0); + for (int64_t j = 0; j < A.nt(); ++j) { + ldv = std::max( ldv, A.tileNb(j) ); + } + ldv *= 2; } else if (in_norm == Norm::Fro) { ldv = 2; diff --git a/src/internal/internal_trnorm.cc b/src/internal/internal_trnorm.cc index a6e5cdcb2..67e1cdff6 100644 --- a/src/internal/internal_trnorm.cc +++ b/src/internal/internal_trnorm.cc @@ -356,6 +356,10 @@ void norm( const Layout layout = Layout::ColMajor; using ij_tuple = typename BaseMatrix::ij_tuple; + if (scope != NormScope::Matrix) { + slate_not_implemented("The NormScope isn't yet supported."); + } + assert(A.num_devices() > 0); std::vector > a_host_arrays(A.num_devices()); @@ -373,10 +377,14 @@ void norm( devices_values.resize(A.num_devices()); } else if (in_norm == Norm::One) { - ldv = A.tileNb(0); + for (int64_t j = 0; j < A.nt(); ++j) { + ldv = std::max( ldv, A.tileNb(j) ); + } } else if (in_norm == Norm::Inf) { - ldv = A.tileMb(0); + for (int64_t i = 0; i < A.mt(); ++i) { + ldv = std::max( ldv, A.tileMb(i) ); + } } else if (in_norm == Norm::Fro) { ldv = 2; From cafffa9a485d494ba2bf8d86531594b2f28e6e1a Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 26 Oct 2023 10:23:12 -0400 Subject: [PATCH 18/35] Add device regions to non-band norms --- src/internal/internal_genorm.cc | 164 +++++++++------------ src/internal/internal_henorm.cc | 230 +++++++++++------------------ src/internal/internal_synorm.cc | 236 ++++++++++++------------------ src/internal/internal_trnorm.cc | 250 ++++++++++++-------------------- 4 files changed, 339 insertions(+), 541 deletions(-) diff --git a/src/internal/internal_genorm.cc b/src/internal/internal_genorm.cc index 2922240b8..002bbf494 100644 --- a/src/internal/internal_genorm.cc +++ b/src/internal/internal_genorm.cc @@ -414,26 +414,11 @@ void norm( slate_not_implemented("The NormScope isn't yet supported."); } - // Define index ranges for regions of matrix. - // Tiles in each region are all the same size. 
- int64_t irange[4][2] = { - { 0, A.mt()-1 }, - { A.mt()-1, A.mt() }, - { 0, A.mt()-1 }, - { A.mt()-1, A.mt() } - }; - int64_t jrange[4][2] = { - { 0, A.nt()-1 }, - { 0, A.nt()-1 }, - { A.nt()-1, A.nt() }, - { A.nt()-1, A.nt() } - }; - #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { #pragma omp task slate_omp_default_none \ priority( priority ) shared( A, devices_values, vals_host_arrays ) \ - firstprivate(device, irange, jrange, queue_index, ldv, scope, in_norm, layout) + firstprivate(device, queue_index, ldv, scope, in_norm, layout) { std::set A_tiles_set; @@ -447,28 +432,13 @@ void norm( A.tileGetForReading(A_tiles_set, device, LayoutConvert(layout)); // Setup batched arguments. + int64_t batch_size = A_tiles_set.size(); scalar_t** a_array_host = A.array_host( device, queue_index ); - int64_t batch_count = 0; - int64_t mb[4], nb[4], lda[4], group_count[4]; - for (int q = 0; q < 4; ++q) { - group_count[q] = 0; - lda[q] = 0; - mb[q] = A.tileMb(irange[q][0]); - nb[q] = A.tileNb(jrange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(i, j) && - device == A.tileDevice(i, j)) - { - a_array_host[batch_count] = A(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ++group_count[q]; - ++batch_count; - } - } - } - } + auto group_params = device_regions_build( + {A}, + {a_array_host}, + device ); scalar_t** a_array_dev = A.array_device(device, queue_index); @@ -484,25 +454,26 @@ void norm( blas::device_memcpy(a_array_dev, a_array_host, - batch_count, + batch_size, blas::MemcpyKind::HostToDevice, *queue); real_t* vals_dev_array_group = vals_dev_array; - for (int q = 0; q < 4; ++q) { - if (group_count[q] > 0) { - device::genorm(in_norm, scope, - mb[q], nb[q], - a_array_dev, lda[q], - vals_dev_array_group, ldv, - group_count[q], *queue); - a_array_dev += group_count[q]; - vals_dev_array_group += group_count[q] * ldv; - } + for (size_t g = 0; g < group_params.size(); ++g) { + int64_t group_count = group_params[ g ].count; + + device::genorm( + in_norm, scope, + group_params[ g ].mb, group_params[ g ].nb, + a_array_dev, group_params[ g ].ld[0], + vals_dev_array_group, ldv, + group_count, *queue ); + a_array_dev += group_count; + vals_dev_array_group += group_count * ldv; } blas::device_memcpy(vals_host_array, vals_dev_array, - batch_count*ldv, + batch_size*ldv, blas::MemcpyKind::DeviceToHost, *queue); @@ -513,10 +484,10 @@ void norm( // Reduction over tiles to device result. 
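                // Each tile contributes ldv entries to vals_host_array, so tile k's
                // partial results start at vals_host_array[ k*ldv ]: a single value
                // for Norm::Max, per-column sums for Norm::One, per-row sums for
                // Norm::Inf, and a (scale, sumsq) pair for Norm::Fro.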
if (in_norm == Norm::Max) { devices_values[device] = - lapack::lange(in_norm, 1, batch_count, vals_host_array, 1); + lapack::lange(in_norm, 1, batch_size, vals_host_array, 1); } else if (in_norm == Norm::Fro) { - for (int64_t k = 0; k < batch_count; ++k) { + for (int64_t k = 0; k < batch_size; ++k) { combine_sumsq(devices_values[2*device + 0], devices_values[2*device + 1], vals_host_array[2*k + 0], @@ -538,53 +509,55 @@ void norm( devices_values.data(), 1); } else if (in_norm == Norm::One) { + auto irange = device_regions_range( true, A ); + auto jrange = device_regions_range( false, A ); for (int device = 0; device < A.num_devices(); ++device) { real_t* vals_host_array = vals_host_arrays[device].data(); int64_t batch_count = 0; - for (int q = 0; q < 4; ++q) { - int64_t nb = A.tileNb(jrange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(i, j) && - device == A.tileDevice(i, j)) - { - blas::axpy( - nb, 1.0, - &vals_host_array[batch_count*ldv], 1, - &values[j*ldv], 1); - ++batch_count; - } + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + int64_t nb = A.tileNb( jj ); + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { + // TODO this is broken for nonuniform block sizes + blas::axpy( + nb, 1.0, + &vals_host_array[batch_count*ldv], 1, + &values[j*ldv], 1); + ++batch_count; } - } - } + }} // for j,i + }} // for jj,ii } } else if (in_norm == Norm::Inf) { + auto irange = device_regions_range( true, A ); + auto jrange = device_regions_range( false, A ); for (int device = 0; device < A.num_devices(); ++device) { real_t* vals_host_array = vals_host_arrays[device].data(); int64_t batch_count = 0; - for (int q = 0; q < 4; ++q) { - int64_t mb = A.tileMb(irange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(i, j) && - device == A.tileDevice(i, j)) - { - blas::axpy( - mb, 1.0, - &vals_host_array[batch_count*ldv], 1, - &values[i*ldv], 1); - ++batch_count; - } + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + int64_t mb = A.tileMb( ii ); + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { + // TODO this is broken for nonuniform block sizes + blas::axpy( + mb, 1.0, + &vals_host_array[batch_count*ldv], 1, + &values[i*ldv], 1); + ++batch_count; } - } - } + }} // for j,i + }} // for jj,ii } } else if (in_norm == Norm::Fro) { @@ -600,6 +573,9 @@ void norm( else if (scope == NormScope::Columns) { if (in_norm == Norm::Max) { + auto irange = device_regions_range( true, A ); + auto jrange = device_regions_range( false, A ); + // Reduction over devices to local result. 
// todo: re-arrange loops to be able to issue omp tasks for (int device = 0; device < A.num_devices(); ++device) { @@ -607,23 +583,19 @@ void norm( real_t* vals_host_array = vals_host_arrays[device].data(); int64_t batch_count = 0; - for (int q = 0; q < 4; ++q) { - int64_t nb = A.tileNb(jrange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(i, j) && - device == A.tileDevice(i, j)) - { - for (int k = 0; k < nb; ++k) { - values[j*ldv + k] = - max_nan(vals_host_array[batch_count*ldv + k], - values[j*ldv + k]); - } - ++batch_count; - } + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + int64_t nb = A.tileNb( jj ); + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + for (int k = 0; k < nb; ++k) { + values[j*ldv + k] = + max_nan(vals_host_array[batch_count*ldv + k], + values[j*ldv + k]); } - } - } + ++batch_count; + }} // for j,i + }} // for jj,ii } } else { diff --git a/src/internal/internal_henorm.cc b/src/internal/internal_henorm.cc index 2b2d5e6aa..454a57b59 100644 --- a/src/internal/internal_henorm.cc +++ b/src/internal/internal_henorm.cc @@ -329,7 +329,7 @@ void norm( { using real_t = blas::real_type; - // norms assumes column major + // norms assume column major // todo: relax this assumption, a few cases need to be adjusted only const Layout layout = Layout::ColMajor; using ij_tuple = typename BaseMatrix::ij_tuple; @@ -342,10 +342,8 @@ void norm( assert(A.num_devices() > 0); - std::vector > a_host_arrays(A.num_devices()); std::vector > vals_host_arrays(A.num_devices()); - std::vector a_dev_arrays(A.num_devices()); std::vector vals_dev_arrays(A.num_devices()); // devices_values used for max and Frobenius norms. @@ -370,43 +368,18 @@ void norm( for (int device = 0; device < A.num_devices(); ++device) { int64_t num_tiles = A.getMaxDeviceTiles(device); - a_host_arrays[device].resize(num_tiles); vals_host_arrays[device].resize(num_tiles*ldv); blas::Queue* queue = A.comm_queue(device); - a_dev_arrays[device] = blas::device_malloc(num_tiles, *queue); vals_dev_arrays[device] = blas::device_malloc(num_tiles*ldv, *queue); } - // Define index ranges for regions of matrix. - // Tiles in each region are all the same size. - int64_t irange[6][2] = { - // off-diagonal - { 0, A.mt()-1 }, - { A.mt()-1, A.mt() }, - { 0, A.mt()-1 }, - { A.mt()-1, A.mt() }, - // diagonal - { 0, std::min(A.mt(), A.nt())-1 }, - { std::min(A.mt(), A.nt())-1, std::min(A.mt(), A.nt()) } - }; - int64_t jrange[6][2] = { - // off-diagonal - { 0, A.nt()-1 }, - { 0, A.nt()-1 }, - { A.nt()-1, A.nt() }, - { A.nt()-1, A.nt() }, - // diagonal - { 0, std::min(A.mt(), A.nt())-1 }, - { std::min(A.mt(), A.nt())-1, std::min(A.mt(), A.nt()) } - }; - #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { #pragma omp task slate_omp_default_none \ shared( A, devices_values ) \ - shared(vals_host_arrays, a_host_arrays, a_dev_arrays, vals_dev_arrays) \ - firstprivate(device, layout, lower, irange, jrange, queue_index, in_norm, ldv) \ + shared(vals_host_arrays, vals_dev_arrays) \ + firstprivate(device, layout, lower, queue_index, in_norm, ldv) \ priority(priority) { std::set A_tiles_set; @@ -425,49 +398,15 @@ void norm( A.tileGetForReading(A_tiles_set, device, LayoutConvert(layout)); // Setup batched arguments. 
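            // device_regions_build is used here with has_diag enabled, so diagonal
            // tiles are gathered into their own groups (is_diagonal == true). The
            // batched loop below can then call device::henorm on the diagonal groups
            // and device::synormOffdiag / device::genorm on the strictly
            // off-diagonal groups.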
- scalar_t** a_host_array = a_host_arrays[device].data(); - scalar_t** a_dev_array = a_dev_arrays[device]; + int64_t batch_size = A_tiles_set.size(); + scalar_t** a_array_host = A.array_host( device, queue_index ); - int64_t batch_count = 0; - int64_t mb[6], nb[6], lda[6], group_count[6]; - // off-diagonal blocks - for (int q = 0; q < 4; ++q) { - group_count[q] = 0; - lda[q] = 0; - mb[q] = A.tileMb(irange[q][0]); - nb[q] = A.tileNb(jrange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(i, j) && - device == A.tileDevice(i, j) && - ( ( lower && i > j) || - (! lower && i < j) )) - { - a_host_array[batch_count] = A(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ++group_count[q]; - ++batch_count; - } - } - } - } - // diagonal blocks - for (int q = 4; q < 6; ++q) { - group_count[q] = 0; - lda[q] = 0; - mb[q] = A.tileMb(jrange[q][0]); - nb[q] = A.tileNb(jrange[q][0]); - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(j, j) && - device == A.tileDevice(j, j)) - { - a_host_array[batch_count] = A(j, j, device).data(); - lda[q] = A(j, j, device).stride(); - ++group_count[q]; - ++batch_count; - } - } - } + auto group_params = device_regions_build( + {A}, + {a_array_host}, + device ); + + scalar_t** a_array_dev = A.array_device(device, queue_index); real_t* vals_host_array = vals_host_arrays[device].data(); real_t* vals_dev_array = vals_dev_arrays[device]; @@ -478,50 +417,47 @@ void norm( blas::Queue* queue = A.compute_queue(device, queue_index); - blas::device_memcpy(a_dev_array, a_host_array, - batch_count, + blas::device_memcpy(a_array_dev, a_array_host, + batch_size, blas::MemcpyKind::HostToDevice, *queue); - - // off-diagonal blocks (same as synorm) - for (int q = 0; q < 4; ++q) { - if (group_count[q] > 0) { + real_t* vals_dev_array_group = vals_dev_array; + for (size_t g = 0; g < group_params.size(); ++g) { + int64_t group_count = group_params[ g ].count; + + if (group_params[ g ].is_diagonal) { + device::henorm( + in_norm, A.uploPhysical(), group_params[ g ].nb, + a_array_dev, group_params[ g ].ld[0], + vals_dev_array_group, ldv, + group_count, *queue ); + } + else { if (in_norm == Norm::One || in_norm == Norm::Inf) { - device::synormOffdiag(in_norm, - mb[q], nb[q], - a_dev_array, lda[q], - vals_dev_array, ldv, - group_count[q], *queue); + device::synormOffdiag( + in_norm, + group_params[ g ].mb, group_params[ g ].nb, + a_array_dev, group_params[ g ].ld[0], + vals_dev_array_group, ldv, + group_count, *queue ); } else { - device::genorm(in_norm, NormScope::Matrix, - mb[q], nb[q], - a_dev_array, lda[q], - vals_dev_array, ldv, - group_count[q], *queue); + device::genorm( + in_norm, NormScope::Matrix, + group_params[ g ].mb, group_params[ g ].nb, + a_array_dev, group_params[ g ].ld[0], + vals_dev_array_group, ldv, + group_count, *queue ); } - a_dev_array += group_count[q]; - vals_dev_array += group_count[q] * ldv; - } - } - // diagonal blocks - for (int q = 4; q < 6; ++q) { - if (group_count[q] > 0) { - device::henorm(in_norm, A.uploPhysical(), - nb[q], - a_dev_array, lda[q], - vals_dev_array, ldv, - group_count[q], *queue); - a_dev_array += group_count[q]; - vals_dev_array += group_count[q] * ldv; } + a_array_dev += group_count; + vals_dev_array_group += group_count * ldv; + queue->sync(); } - vals_dev_array = vals_dev_arrays[device]; - blas::device_memcpy( vals_host_array, vals_dev_array, - batch_count*ldv, + batch_size*ldv, blas::MemcpyKind::DeviceToHost, 
*queue); @@ -531,14 +467,17 @@ void norm( // Reduction over tiles to device result. if (in_norm == Norm::Max) { devices_values[device] = - lapack::lange(in_norm, 1, batch_count, vals_host_array, 1); + lapack::lange(in_norm, 1, batch_size, vals_host_array, 1); } else if (in_norm == Norm::Fro) { int64_t cnt = 0; - for (int q = 0; q < 6; ++q) { + for (size_t g = 0; g < group_params.size(); ++g) { + int64_t group_count = group_params[ g ].count; + // double for symmetric entries in off-diagonal blocks - real_t mult = (q < 4 ? 2.0 : 1.0); - for (int64_t k = 0; k < group_count[q]; ++k) { + real_t mult = (group_params[ g ].is_diagonal) ? 1.0 : 2.0; + + for (int64_t k = 0; k < group_count; ++k) { combine_sumsq( devices_values[2*device + 0], devices_values[2*device + 1], @@ -554,7 +493,6 @@ void norm( for (int device = 0; device < A.num_devices(); ++device) { blas::Queue* queue = A.compute_queue(device, queue_index); - blas::device_free(a_dev_arrays[device], *queue); blas::device_free(vals_dev_arrays[device], *queue); } @@ -565,52 +503,56 @@ void norm( devices_values.data(), 1); } else if (in_norm == Norm::One || in_norm == Norm::Inf) { + auto irange = device_regions_range( true, A ); + auto jrange = device_regions_range( false, A ); + int64_t nb0 = A.tileNb(0); + assert(A.tileNb(0) == A.tileMb(0)); + assert(A.n() == A.m()); + for (int device = 0; device < A.num_devices(); ++device) { + real_t* vals_host_array = vals_host_arrays[device].data(); int64_t batch_count = 0; - // off-diagonal blocks - int64_t nb0 = A.tileNb(0); - for (int q = 0; q < 4; ++q) { - int64_t mb = A.tileMb(irange[q][0]); - int64_t nb = A.tileNb(jrange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(i, j) && - device == A.tileDevice(i, j) && - ( ( lower && i > j) || - (! 
lower && i < j) )) - { - // col sums - blas::axpy( - nb, 1.0, - &vals_host_array[batch_count*ldv], 1, - &values[j*nb0], 1); - // row sums - blas::axpy( - mb, 1.0, - &vals_host_array[batch_count*ldv + nb], 1, - &values[i*nb0], 1); - ++batch_count; - } + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + int64_t mb = A.tileMb( irange[ ii ] ); + int64_t nb = A.tileNb( jrange[ jj ] ); + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j ) + && ((A.uplo() == Uplo::Lower && i > j) || + (A.uplo() == Uplo::Upper && i < j))) { + + // TODO this is broken for nonuniform block sizes + // col sums + blas::axpy( + nb, 1.0, + &vals_host_array[batch_count*ldv], 1, + &values[j*nb0], 1); + // row sums + blas::axpy( + mb, 1.0, + &vals_host_array[batch_count*ldv + nb], 1, + &values[i*nb0], 1); + ++batch_count; } - } - } - // diagonal blocks - for (int q = 4; q < 6; ++q) { - int64_t nb = A.tileNb(jrange[q][0]); - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(j, j) && - device == A.tileDevice(j, j)) + }} // for j,i + + int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); + int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); + for (int64_t ij = ijstart; ij < ijend; ++ij) { + if (A.tileIsLocal(ij, ij) && + device == A.tileDevice(ij, ij)) { blas::axpy( nb, 1.0, &vals_host_array[batch_count*ldv], 1, - &values[j*nb0], 1); + &values[ij*nb0], 1); ++batch_count; } } - } + }} // for jj,ii } } else if (in_norm == Norm::Fro) { diff --git a/src/internal/internal_synorm.cc b/src/internal/internal_synorm.cc index f6d2f6b32..4fabc4ad5 100644 --- a/src/internal/internal_synorm.cc +++ b/src/internal/internal_synorm.cc @@ -344,10 +344,8 @@ void norm(internal::TargetType, assert(A.num_devices() > 0); - std::vector > a_host_arrays(A.num_devices()); std::vector > vals_host_arrays(A.num_devices()); - std::vector a_dev_arrays(A.num_devices()); std::vector vals_dev_arrays(A.num_devices()); // devices_values used for max and Frobenius norms. @@ -372,43 +370,18 @@ void norm(internal::TargetType, for (int device = 0; device < A.num_devices(); ++device) { int64_t num_tiles = A.getMaxDeviceTiles(device); - a_host_arrays[device].resize(num_tiles); vals_host_arrays[device].resize(num_tiles*ldv); blas::Queue* queue = A.comm_queue(device); - a_dev_arrays[device] = blas::device_malloc(num_tiles, *queue); vals_dev_arrays[device] = blas::device_malloc(num_tiles*ldv, *queue); } - // Define index ranges for regions of matrix. - // Tiles in each region are all the same size. 
- int64_t irange[6][2] = { - // off-diagonal - { 0, A.mt()-1 }, - { A.mt()-1, A.mt() }, - { 0, A.mt()-1 }, - { A.mt()-1, A.mt() }, - // diagonal - { 0, std::min(A.mt(), A.nt())-1 }, - { std::min(A.mt(), A.nt())-1, std::min(A.mt(), A.nt()) } - }; - int64_t jrange[6][2] = { - // off-diagonal - { 0, A.nt()-1 }, - { 0, A.nt()-1 }, - { A.nt()-1, A.nt() }, - { A.nt()-1, A.nt() }, - // diagonal - { 0, std::min(A.mt(), A.nt())-1 }, - { std::min(A.mt(), A.nt())-1, std::min(A.mt(), A.nt()) } - }; - #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { #pragma omp task slate_omp_default_none \ shared( A, devices_values ) \ - shared(vals_host_arrays, vals_dev_arrays, a_host_arrays, a_dev_arrays) \ - firstprivate(device, lower, irange, jrange, queue_index, in_norm, ldv, layout) \ + shared(vals_host_arrays, vals_dev_arrays) \ + firstprivate(device, lower, queue_index, in_norm, ldv, layout) \ priority(priority) { std::set A_tiles_set; @@ -428,49 +401,15 @@ void norm(internal::TargetType, A.tileGetForReading(A_tiles_set, device, LayoutConvert(layout)); // Setup batched arguments. - scalar_t** a_host_array = a_host_arrays[device].data(); - scalar_t** a_dev_array = a_dev_arrays[device]; + int64_t batch_size = A_tiles_set.size(); + scalar_t** a_array_host = A.array_host( device, queue_index ); - int64_t batch_count = 0; - int64_t mb[6], nb[6], lda[6], group_count[6]; - // off-diagonal blocks - for (int q = 0; q < 4; ++q) { - group_count[q] = 0; - lda[q] = 0; - mb[q] = A.tileMb(irange[q][0]); - nb[q] = A.tileNb(jrange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(i, j) && - device == A.tileDevice(i, j) && - ( ( lower && i > j) || - (! lower && i < j) )) - { - a_host_array[batch_count] = A(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ++group_count[q]; - ++batch_count; - } - } - } - } - // diagonal blocks - for (int q = 4; q < 6; ++q) { - group_count[q] = 0; - lda[q] = 0; - mb[q] = A.tileMb(jrange[q][0]); - nb[q] = A.tileNb(jrange[q][0]); - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(j, j) && - device == A.tileDevice(j, j)) - { - a_host_array[batch_count] = A(j, j, device).data(); - lda[q] = A(j, j, device).stride(); - ++group_count[q]; - ++batch_count; - } - } - } + auto group_params = device_regions_build( + {A}, + {a_array_host}, + device ); + + scalar_t** a_array_dev = A.array_device(device, queue_index); real_t* vals_host_array = vals_host_arrays[device].data(); real_t* vals_dev_array = vals_dev_arrays[device]; @@ -481,51 +420,50 @@ void norm(internal::TargetType, blas::Queue* queue = A.compute_queue(device, queue_index); - blas::device_memcpy(a_dev_array, a_host_array, - batch_count, + blas::device_memcpy(a_array_dev, a_array_host, + batch_size, blas::MemcpyKind::HostToDevice, *queue); - // off-diagonal blocks - for (int q = 0; q < 4; ++q) { - if (group_count[q] > 0) { + real_t* vals_dev_array_group = vals_dev_array; + for (size_t g = 0; g < group_params.size(); ++g) { + int64_t group_count = group_params[ g ].count; + + if (group_params[ g ].is_diagonal) { + device::synorm( + in_norm, A.uploPhysical(), group_params[ g ].nb, + a_array_dev, group_params[ g ].ld[0], + vals_dev_array_group, ldv, + group_count, *queue ); + } + else { if (in_norm == Norm::One || in_norm == Norm::Inf) { - device::synormOffdiag(in_norm, - mb[q], nb[q], - a_dev_array, lda[q], - vals_dev_array, ldv, - group_count[q], *queue); + device::synormOffdiag( + in_norm, 
+ group_params[ g ].mb, group_params[ g ].nb, + a_array_dev, group_params[ g ].ld[0], + vals_dev_array_group, ldv, + group_count, *queue ); } else { - device::genorm(in_norm, NormScope::Matrix, - mb[q], nb[q], - a_dev_array, lda[q], - vals_dev_array, ldv, - group_count[q], *queue); + device::genorm( + in_norm, NormScope::Matrix, + group_params[ g ].mb, group_params[ g ].nb, + a_array_dev, group_params[ g ].ld[0], + vals_dev_array_group, ldv, + group_count, *queue ); } - a_dev_array += group_count[q]; - vals_dev_array += group_count[q] * ldv; - } - } - // diagonal blocks - for (int q = 4; q < 6; ++q) { - if (group_count[q] > 0) { - device::synorm(in_norm, A.uploPhysical(), - nb[q], - a_dev_array, lda[q], - vals_dev_array, ldv, - group_count[q], *queue); - a_dev_array += group_count[q]; - vals_dev_array += group_count[q] * ldv; } + a_array_dev += group_count; + vals_dev_array_group += group_count * ldv; + queue->sync(); } - vals_dev_array = vals_dev_arrays[device]; - - blas::device_memcpy(vals_host_array, vals_dev_array, - batch_count*ldv, - blas::MemcpyKind::DeviceToHost, - *queue); + blas::device_memcpy( + vals_host_array, vals_dev_array, + batch_size*ldv, + blas::MemcpyKind::DeviceToHost, + *queue); queue->sync(); } @@ -533,14 +471,17 @@ void norm(internal::TargetType, // Reduction over tiles to device result. if (in_norm == Norm::Max) { devices_values[device] = - lapack::lange(in_norm, 1, batch_count, vals_host_array, 1); + lapack::lange(in_norm, 1, batch_size, vals_host_array, 1); } else if (in_norm == Norm::Fro) { int64_t cnt = 0; - for (int q = 0; q < 6; ++q) { + for (size_t g = 0; g < group_params.size(); ++g) { + int64_t group_count = group_params[ g ].count; + // double for symmetric entries in off-diagonal blocks - real_t mult = (q < 4 ? 2.0 : 1.0); - for (int64_t k = 0; k < group_count[q]; ++k) { + real_t mult = (group_params[ g ].is_diagonal) ? 1.0 : 2.0; + + for (int64_t k = 0; k < group_count; ++k) { combine_sumsq( devices_values[2*device + 0], devices_values[2*device + 1], @@ -556,8 +497,7 @@ void norm(internal::TargetType, for (int device = 0; device < A.num_devices(); ++device) { blas::Queue* queue = A.compute_queue(device, queue_index); - blas::device_free(a_dev_arrays[device], *queue); - blas::device_free((void*)vals_dev_arrays[device], *queue); + blas::device_free(vals_dev_arrays[device], *queue); } // Reduction over devices to local result. @@ -567,52 +507,56 @@ void norm(internal::TargetType, devices_values.data(), 1); } else if (in_norm == Norm::One || in_norm == Norm::Inf) { + auto irange = device_regions_range( true, A ); + auto jrange = device_regions_range( false, A ); + int64_t nb0 = A.tileNb(0); + assert(A.tileNb(0) == A.tileMb(0)); + assert(A.n() == A.m()); + for (int device = 0; device < A.num_devices(); ++device) { + real_t* vals_host_array = vals_host_arrays[device].data(); int64_t batch_count = 0; - // off-diagonal blocks - int64_t nb0 = A.tileNb(0); - for (int q = 0; q < 4; ++q) { - int64_t mb = A.tileMb(irange[q][0]); - int64_t nb = A.tileNb(jrange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(i, j) && - device == A.tileDevice(i, j) && - ( ( lower && i > j) || - (! 
lower && i < j) )) - { - // col sums - blas::axpy( - nb, 1.0, - &vals_host_array[batch_count*ldv], 1, - &values[j*nb0], 1); - // row sums - blas::axpy( - mb, 1.0, - &vals_host_array[batch_count*ldv + nb], 1, - &values[i*nb0], 1); - ++batch_count; - } + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + int64_t mb = A.tileMb( irange[ ii ] ); + int64_t nb = A.tileNb( jrange[ jj ] ); + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j ) + && ((A.uplo() == Uplo::Lower && i > j) || + (A.uplo() == Uplo::Upper && i < j))) { + + // TODO this is broken for nonuniform block sizes + // col sums + blas::axpy( + nb, 1.0, + &vals_host_array[batch_count*ldv], 1, + &values[j*nb0], 1); + // row sums + blas::axpy( + mb, 1.0, + &vals_host_array[batch_count*ldv + nb], 1, + &values[i*nb0], 1); + ++batch_count; } - } - } - // diagonal blocks - for (int q = 4; q < 6; ++q) { - int64_t nb = A.tileNb(jrange[q][0]); - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(j, j) && - device == A.tileDevice(j, j)) + }} // for j,i + + int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); + int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); + for (int64_t ij = ijstart; ij < ijend; ++ij) { + if (A.tileIsLocal(ij, ij) && + device == A.tileDevice(ij, ij)) { blas::axpy( nb, 1.0, &vals_host_array[batch_count*ldv], 1, - &values[j*nb0], 1); + &values[ij*nb0], 1); ++batch_count; } } - } + }} // for jj,ii } } else if (in_norm == Norm::Fro) { diff --git a/src/internal/internal_trnorm.cc b/src/internal/internal_trnorm.cc index 67e1cdff6..9e34821d2 100644 --- a/src/internal/internal_trnorm.cc +++ b/src/internal/internal_trnorm.cc @@ -362,10 +362,8 @@ void norm( assert(A.num_devices() > 0); - std::vector > a_host_arrays(A.num_devices()); std::vector > vals_host_arrays(A.num_devices()); - std::vector a_dev_arrays(A.num_devices()); std::vector vals_dev_arrays(A.num_devices()); // devices_values used for max and Frobenius norms. @@ -394,43 +392,18 @@ void norm( for (int device = 0; device < A.num_devices(); ++device) { int64_t num_tiles = A.getMaxDeviceTiles(device); - a_host_arrays[device].resize(num_tiles); vals_host_arrays[device].resize(num_tiles*ldv); blas::Queue* queue = A.comm_queue(device); - a_dev_arrays[device] = blas::device_malloc(num_tiles, *queue); vals_dev_arrays[device] = blas::device_malloc(num_tiles*ldv, *queue); } - // Define index ranges for regions of matrix. - // Tiles in each region are all the same size. 
- int64_t irange[6][2] = { - // off-diagonal - { 0, A.mt()-1 }, - { A.mt()-1, A.mt() }, - { 0, A.mt()-1 }, - { A.mt()-1, A.mt() }, - // diagonal - { 0, std::min(A.mt(), A.nt())-1 }, - { std::min(A.mt(), A.nt())-1, std::min(A.mt(), A.nt()) } - }; - int64_t jrange[6][2] = { - // off-diagonal - { 0, A.nt()-1 }, - { 0, A.nt()-1 }, - { A.nt()-1, A.nt() }, - { A.nt()-1, A.nt() }, - // diagonal - { 0, std::min(A.mt(), A.nt())-1 }, - { std::min(A.mt(), A.nt())-1, std::min(A.mt(), A.nt()) } - }; - #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { #pragma omp task slate_omp_default_none \ shared( A, devices_values ) \ - shared(vals_host_arrays, vals_dev_arrays, a_host_arrays, a_dev_arrays) \ - firstprivate(device, irange, jrange, queue_index, in_norm, ldv, layout) \ + shared(vals_host_arrays, vals_dev_arrays) \ + firstprivate(device, queue_index, in_norm, ldv, layout) \ priority(priority) { std::set A_tiles_set; @@ -449,49 +422,15 @@ void norm( A.tileGetForReading(A_tiles_set, device, LayoutConvert(layout)); // Setup batched arguments. - scalar_t** a_host_array = a_host_arrays[device].data(); - scalar_t** a_dev_array = a_dev_arrays[device]; + int64_t batch_size = A_tiles_set.size(); + scalar_t** a_array_host = A.array_host( device, queue_index ); - int64_t batch_count = 0; - int64_t mb[6], nb[6], lda[6], group_count[6]; - // off-diagonal blocks - for (int q = 0; q < 4; ++q) { - group_count[q] = 0; - lda[q] = 0; - mb[q] = A.tileMb(irange[q][0]); - nb[q] = A.tileNb(jrange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(i, j) && - device == A.tileDevice(i, j) && - ( (A.uplo() == Uplo::Lower && i > j) || - (A.uplo() == Uplo::Upper && i < j) )) - { - a_host_array[batch_count] = A(i, j, device).data(); - lda[q] = A(i, j, device).stride(); - ++group_count[q]; - ++batch_count; - } - } - } - } - // diagonal blocks - for (int q = 4; q < 6; ++q) { - group_count[q] = 0; - lda[q] = 0; - mb[q] = A.tileMb(jrange[q][0]); - nb[q] = A.tileNb(jrange[q][0]); - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(j, j) && - device == A.tileDevice(j, j)) - { - a_host_array[batch_count] = A(j, j, device).data(); - lda[q] = A(j, j, device).stride(); - ++group_count[q]; - ++batch_count; - } - } - } + auto group_params = device_regions_build( + {A}, + {a_array_host}, + device ); + + scalar_t** a_array_dev = A.array_device(device, queue_index); real_t* vals_host_array = vals_host_arrays[device].data(); real_t* vals_dev_array = vals_dev_arrays[device]; @@ -502,40 +441,39 @@ void norm( blas::Queue* queue = A.compute_queue(device, queue_index); - blas::device_memcpy(a_dev_array, a_host_array, - batch_count, + blas::device_memcpy(a_array_dev, a_array_host, + batch_size, blas::MemcpyKind::HostToDevice, *queue); + queue->sync(); - // off-diagonal blocks - for (int q = 0; q < 4; ++q) { - if (group_count[q] > 0) { - device::genorm(in_norm, NormScope::Matrix, - mb[q], nb[q], - a_dev_array, lda[q], - vals_dev_array, ldv, - group_count[q], *queue); - a_dev_array += group_count[q]; - vals_dev_array += group_count[q] * ldv; + real_t* vals_dev_array_group = vals_dev_array; + for (size_t g = 0; g < group_params.size(); ++g) { + int64_t group_count = group_params[ g ].count; + + if (group_params[ g ].is_diagonal) { + device::trnorm( + in_norm, A.uplo(), A.diag(), + group_params[ g ].mb, group_params[ g ].nb, + a_array_dev, group_params[ g ].ld[0], + vals_dev_array_group, ldv, + group_count, 
*queue ); } - } - // diagonal blocks - for (int q = 4; q < 6; ++q) { - if (group_count[q] > 0) { - device::trnorm(in_norm, A.uplo(), A.diag(), - mb[q], nb[q], - a_dev_array, lda[q], - vals_dev_array, ldv, - group_count[q], *queue); - a_dev_array += group_count[q]; - vals_dev_array += group_count[q] * ldv; + else { + device::genorm( + in_norm, NormScope::Matrix, + group_params[ g ].mb, group_params[ g ].nb, + a_array_dev, group_params[ g ].ld[0], + vals_dev_array_group, ldv, + group_count, *queue ); } + a_array_dev += group_count; + vals_dev_array_group += group_count * ldv; + queue->sync(); } - vals_dev_array = vals_dev_arrays[device]; - blas::device_memcpy(vals_host_array, vals_dev_array, - batch_count*ldv, + batch_size*ldv, blas::MemcpyKind::DeviceToHost, *queue); @@ -545,10 +483,10 @@ void norm( // Reduction over tiles to device result. if (in_norm == Norm::Max) { devices_values[device] = - lapack::lange(in_norm, 1, batch_count, vals_host_array, 1); + lapack::lange(in_norm, 1, batch_size, vals_host_array, 1); } else if (in_norm == Norm::Fro) { - for (int64_t k = 0; k < batch_count; ++k) { + for (int64_t k = 0; k < batch_size; ++k) { combine_sumsq(devices_values[2*device + 0], devices_values[2*device + 1], vals_host_array[2*k + 0], @@ -561,7 +499,6 @@ void norm( for (int device = 0; device < A.num_devices(); ++device) { blas::Queue* queue = A.compute_queue(device, queue_index); - blas::device_free(a_dev_arrays[device], *queue); blas::device_free(vals_dev_arrays[device], *queue); } @@ -572,85 +509,88 @@ void norm( devices_values.data(), 1); } else if (in_norm == Norm::One) { + auto irange = device_regions_range( true, A ); + auto jrange = device_regions_range( false, A ); + for (int device = 0; device < A.num_devices(); ++device) { + real_t* vals_host_array = vals_host_arrays[device].data(); int64_t batch_count = 0; - // off-diagonal blocks - for (int q = 0; q < 4; ++q) { - int64_t nb = A.tileNb(jrange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(i, j) && - device == A.tileDevice(i, j) && - ( (A.uplo() == Uplo::Lower && i > j) || - (A.uplo() == Uplo::Upper && i < j) )) - { - blas::axpy( - nb, 1.0, - &vals_host_array[batch_count*ldv], 1, - &values[j*ldv], 1); - ++batch_count; - } - } - } - } - // diagonal blocks - for (int q = 4; q < 6; ++q) { - int64_t nb = A.tileNb(jrange[q][0]); - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(j, j) && - device == A.tileDevice(j, j)) - { + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + int64_t nb = A.tileNb( jj ); + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j ) + && ((A.uplo() == Uplo::Lower && i > j) || + (A.uplo() == Uplo::Upper && i < j))) { + + // TODO this is broken for nonuniform block sizes blas::axpy( nb, 1.0, &vals_host_array[batch_count*ldv], 1, &values[j*ldv], 1); ++batch_count; } + }} // for j,i + + int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); + int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); + for (int64_t ij = ijstart; ij < ijend; ++ij) { + if (A.tileIsLocal(ij, ij) && device == A.tileDevice(ij, ij)) { + + // TODO this is broken for nonuniform block sizes + blas::axpy( + nb, 1.0, + &vals_host_array[batch_count*ldv], 1, + &values[ij*ldv], 1); + ++batch_count; + } } - } + }} // for jj,ii } } else if 
(in_norm == Norm::Inf) { + auto irange = device_regions_range( true, A ); + auto jrange = device_regions_range( false, A ); + for (int device = 0; device < A.num_devices(); ++device) { + real_t* vals_host_array = vals_host_arrays[device].data(); int64_t batch_count = 0; - // off-diagonal blocks - for (int q = 0; q < 4; ++q) { - int64_t mb = A.tileMb(irange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) { - if (A.tileIsLocal(i, j) && - device == A.tileDevice(i, j) && - ( (A.uplo() == Uplo::Lower && i > j) || - (A.uplo() == Uplo::Upper && i < j) )) - { - blas::axpy( - mb, 1.0, - &vals_host_array[batch_count*ldv], 1, - &values[i*ldv], 1); - ++batch_count; - } - } - } - } - // diagonal blocks - for (int q = 4; q < 6; ++q) { - int64_t mb = A.tileMb(irange[q][0]); - for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) { - if (A.tileIsLocal(i, i) && - device == A.tileDevice(i, i)) - { + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + int64_t nb = A.tileMb( jj ); + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j ) + && ((A.uplo() == Uplo::Lower && i > j) || + (A.uplo() == Uplo::Upper && i < j))) { + blas::axpy( - mb, 1.0, + nb, 1.0, &vals_host_array[batch_count*ldv], 1, &values[i*ldv], 1); ++batch_count; } + }} // for j,i + + int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); + int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); + for (int64_t ij = ijstart; ij < ijend; ++ij) { + if (A.tileIsLocal(ij, ij) && + device == A.tileDevice(ij, ij)) { + + blas::axpy( + nb, 1.0, + &vals_host_array[batch_count*ldv], 1, + &values[ij*ldv], 1); + ++batch_count; + } } - } + }} // for jj,ii } } else if (in_norm == Norm::Fro) { From cf8d66b9ed2b0748a58640079cd4b7bc55fd5916 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 26 Oct 2023 15:14:56 -0400 Subject: [PATCH 19/35] Move an argument for device regions into a template --- src/internal/internal_batch.hh | 43 ++++++++++++------------ src/internal/internal_gecopy.cc | 1 - src/internal/internal_gescale_row_col.cc | 1 - src/internal/internal_geset.cc | 9 +++-- src/internal/internal_tzcopy.cc | 1 - 5 files changed, 29 insertions(+), 26 deletions(-) diff --git a/src/internal/internal_batch.hh b/src/internal/internal_batch.hh index 9805a067a..7969ca894 100644 --- a/src/internal/internal_batch.hh +++ b/src/internal/internal_batch.hh @@ -186,22 +186,22 @@ std::vector device_regions_range( bool want_rows, BaseMatrix& //------------------------------------------------------------------------------ /// Helper class to store the information on a device region. 
/// -/// @tparam has_diag +/// @tparam store_diag /// Wheather the diagonal tiles may need to be special cased /// /// @tparam mat_count /// The number of matrices used by the kernel /// -template< bool has_diag, int mat_count > +template< bool store_diag, int mat_count > struct device_regions_params { int64_t count, mb, nb; int64_t ld[mat_count]; private: - // When has_diag is false, we don't want to allocate any memory for is_diagonal + // When store_diag is false, we don't want to allocate any memory for is_diagonal struct Empty {}; public: - std::conditional_t< has_diag, bool, Empty > is_diagonal; + std::conditional_t< store_diag, bool, Empty > is_diagonal; device_regions_params() : count(0), mb(0), nb(0) @@ -209,7 +209,7 @@ public: for (int i = 0; i < mat_count; ++i) { ld[i] = 0; } - if constexpr (has_diag) { + if constexpr (store_diag) { is_diagonal = false; } } @@ -218,12 +218,19 @@ public: //------------------------------------------------------------------------------ /// Computes and populates the regions for the given matrices. /// -/// @tparam has_diag +/// @tparam store_diag /// Wheather the diagonal tiles may need to be special cased /// /// @tparam mat_count /// The number of matrices used by the kernel /// +/// @tparam scalar_t +/// The type of the matrices +/// +/// @param[in] diag_same +/// Whether to treat the diagonal tiles as normal tiles. +/// If false, store_diag must be true +/// /// @param[in] mats /// An array of the matrices to build regions for /// @@ -233,26 +240,21 @@ public: /// @param[in] device /// The device to build regions for /// -/// @param[in] diag_same -/// Whether to treat the diagonal tiles as normal tiles in spite of has_diag -/// Ignored when has_diag is false. -/// /// @param[in] extra_setup /// Callback that is called whenever a tile is added to a group. /// The group index and the tile indices are passed as arguments /// -template< bool has_diag, int mat_count, typename scalar_t> -std::vector< device_regions_params > device_regions_build( +template< bool store_diag, int mat_count, typename scalar_t, bool diag_same=!store_diag > +std::vector< device_regions_params > device_regions_build( std::array< std::reference_wrapper>, mat_count > mats, std::array< scalar_t**, mat_count > mats_array_host, int64_t device, - bool diag_same = true, std::function extra_setup = {}) { // The first two arguments should be valid targets for brace-initialization // reference_wrapper works around fact that C++ doesn't allow array of references - using Params = device_regions_params; + using Params = device_regions_params; auto& A = mats[0].get(); @@ -261,11 +263,10 @@ std::vector< device_regions_params > device_regions_build( std::vector< int64_t > jrange = device_regions_range( false, A ); // Trapezoidal matrices always need special treatment for diagonal tiles - diag_same &= A.uplo() == Uplo::General; + assert( !diag_same || A.uplo() == Uplo::General ); - // Can't treat diagonals special when we can't store the diagonal status - assert( diag_same || has_diag ); - diag_same |= !has_diag; // Ensure the compiler can propagate this assertion + static_assert( diag_same || store_diag, + "Can't special case the diagonal when is_diagonal is not allocated" ); // Size 1 dimensions get broadcast to allow setting up GEMM et al. // i_step[m]=0 results in only accessing row 0 of matrix m (likewise for j) @@ -297,7 +298,7 @@ std::vector< device_regions_params > device_regions_build( int istart = std::max(irange[ ii ], (A.uplo() == Uplo::Lower ? 
j+1 : 0)); int iend = std::min(irange[ ii+1 ], (A.uplo() == Uplo::Upper ? j : mt)); for (int64_t i = istart; i < iend; ++i) { - if ((!has_diag || diag_same || i != j) + if ((diag_same || i != j) && A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { // Add tiles to current group @@ -325,7 +326,7 @@ std::vector< device_regions_params > device_regions_build( } // If the diagonal tiles need special treatment, build those groups - if constexpr (has_diag) if (!diag_same) { + if constexpr (store_diag && !diag_same) { // Loop over the diagonal tiles in this region. If any should be // computed on this process & device, save them. group = Params(); @@ -362,7 +363,7 @@ std::vector< device_regions_params > device_regions_build( if (group.count > 0) { group_params.push_back( group ); } - } // if has_diag && !diag_same + } // if store_diag && !diag_same }} // for jj, ii return group_params; } diff --git a/src/internal/internal_gecopy.cc b/src/internal/internal_gecopy.cc index 4738e1b5d..2f334c30a 100644 --- a/src/internal/internal_gecopy.cc +++ b/src/internal/internal_gecopy.cc @@ -198,7 +198,6 @@ void copy(internal::TargetType, {B}, {b_array_host}, device, - true, setup_A ); // Usually the output matrix (B) provides all the batch arrays. diff --git a/src/internal/internal_gescale_row_col.cc b/src/internal/internal_gescale_row_col.cc index 75d8be0f3..f092b9caa 100644 --- a/src/internal/internal_gescale_row_col.cc +++ b/src/internal/internal_gescale_row_col.cc @@ -143,7 +143,6 @@ void scale_row_col( {A}, {a_array_host}, device, - true, store_rc ); diff --git a/src/internal/internal_geset.cc b/src/internal/internal_geset.cc index 5c264ab3f..5d044b4e9 100644 --- a/src/internal/internal_geset.cc +++ b/src/internal/internal_geset.cc @@ -124,10 +124,15 @@ void set(internal::TargetType, // in one batch. bool diag_same = offdiag_value == diag_value; - auto group_params = device_regions_build( + auto group_params = diag_same + ? device_regions_build( {A}, {a_array_host}, - device, diag_same ); + device ) + : device_regions_build( + {A}, + {a_array_host}, + device ); blas::Queue* queue = A.compute_queue( device, queue_index ); scalar_t** a_array_dev = A.array_device( device, queue_index ); diff --git a/src/internal/internal_tzcopy.cc b/src/internal/internal_tzcopy.cc index 5132cf5fe..a848b29fa 100644 --- a/src/internal/internal_tzcopy.cc +++ b/src/internal/internal_tzcopy.cc @@ -169,7 +169,6 @@ void copy(internal::TargetType, {B}, {b_array_host}, device, - true, setup_A ); // Usually the output matrix (B) provides all the batch arrays. 
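Note: the patches above and below converge on a single call pattern around device_regions_build: fill the host-side tile-pointer array, group the local tiles on a device into regions of uniform tile size and stride, then issue one batched call per group while walking the device-side pointer array. The following is a minimal sketch of that pattern, not code from the patch series; apply_per_group and my_batched_kernel are placeholder names, the template arguments follow the signature documented in internal_batch.hh above, and it assumes the batch pointer arrays have been allocated and the tiles are already resident on the device.

    #include "slate/Matrix.hh"
    #include "internal/internal_batch.hh"

    namespace slate {
    namespace internal {

    // Sketch only: one batched launch per device region of matrix A.
    template <typename scalar_t, typename Kernel>
    void apply_per_group(
        Matrix<scalar_t>& A, int device, int queue_index,
        Kernel&& my_batched_kernel )
    {
        // Host-side array of tile pointers, filled by device_regions_build.
        scalar_t** a_array_host = A.array_host( device, queue_index );

        // Group A's local tiles on this device into regions of uniform
        // mb, nb, and leading dimension (store_diag = false, mat_count = 1).
        auto group_params = device_regions_build<false, 1, scalar_t>(
            {A}, {a_array_host}, device );

        // Total number of tiles across all groups.
        int64_t batch_size = 0;
        for (auto const& g : group_params)
            batch_size += g.count;

        // Copy the pointer array to the device, then launch one call per
        // group, advancing the device pointer array past each group's tiles.
        blas::Queue* queue = A.compute_queue( device, queue_index );
        scalar_t** a_array_dev = A.array_device( device, queue_index );
        blas::device_memcpy( a_array_dev, a_array_host, batch_size,
                             blas::MemcpyKind::HostToDevice, *queue );

        for (size_t g = 0; g < group_params.size(); ++g) {
            int64_t group_count = group_params[ g ].count;
            my_batched_kernel( group_params[ g ].mb, group_params[ g ].nb,
                               a_array_dev, group_params[ g ].ld[ 0 ],
                               group_count, *queue );
            a_array_dev += group_count;
        }
        queue->sync();
    }

    } // namespace internal
    } // namespace slate

This is the same structure the geset, gecopy, and norm kernels use above; they differ only in how many matrices share the groups (mat_count) and whether diagonal tiles are split into their own groups (store_diag / diag_same).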
From 264df7934eb06b29ffbec9c3d195a77287948264 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Fri, 27 Oct 2023 09:08:28 -0400 Subject: [PATCH 20/35] Fix non-band norms for variable block sizes --- src/internal/internal_genorm.cc | 12 +++++++----- src/internal/internal_henorm.cc | 11 ++++++----- src/internal/internal_synorm.cc | 10 +++++----- src/internal/internal_trnorm.cc | 20 ++++++++++---------- src/internal/internal_util.hh | 15 +++++++++++++++ 5 files changed, 43 insertions(+), 25 deletions(-) diff --git a/src/internal/internal_genorm.cc b/src/internal/internal_genorm.cc index 002bbf494..4e87e0734 100644 --- a/src/internal/internal_genorm.cc +++ b/src/internal/internal_genorm.cc @@ -5,6 +5,7 @@ #include "slate/internal/device.hh" #include "internal/internal_batch.hh" +#include "internal/internal_util.hh" #include "internal/internal.hh" #include "slate/internal/util.hh" #include "slate/Matrix.hh" @@ -511,6 +512,7 @@ void norm( else if (in_norm == Norm::One) { auto irange = device_regions_range( true, A ); auto jrange = device_regions_range( false, A ); + auto joffsets = tile_offsets( false, A ); for (int device = 0; device < A.num_devices(); ++device) { @@ -523,11 +525,10 @@ void norm( for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { - // TODO this is broken for nonuniform block sizes blas::axpy( nb, 1.0, &vals_host_array[batch_count*ldv], 1, - &values[j*ldv], 1); + &values[ joffsets[j] ], 1); ++batch_count; } }} // for j,i @@ -537,6 +538,7 @@ void norm( else if (in_norm == Norm::Inf) { auto irange = device_regions_range( true, A ); auto jrange = device_regions_range( false, A ); + auto ioffsets = tile_offsets( true, A ); for (int device = 0; device < A.num_devices(); ++device) { @@ -549,11 +551,10 @@ void norm( for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { - // TODO this is broken for nonuniform block sizes blas::axpy( mb, 1.0, &vals_host_array[batch_count*ldv], 1, - &values[i*ldv], 1); + &values[ ioffsets[i] ], 1); ++batch_count; } }} // for j,i @@ -575,6 +576,7 @@ void norm( if (in_norm == Norm::Max) { auto irange = device_regions_range( true, A ); auto jrange = device_regions_range( false, A ); + auto joffsets = tile_offsets( false, A ); // Reduction over devices to local result. 
// todo: re-arrange loops to be able to issue omp tasks @@ -591,7 +593,7 @@ void norm( for (int k = 0; k < nb; ++k) { values[j*ldv + k] = max_nan(vals_host_array[batch_count*ldv + k], - values[j*ldv + k]); + values[ joffsets[j] + k]); } ++batch_count; }} // for j,i diff --git a/src/internal/internal_henorm.cc b/src/internal/internal_henorm.cc index 454a57b59..9571f5712 100644 --- a/src/internal/internal_henorm.cc +++ b/src/internal/internal_henorm.cc @@ -6,6 +6,7 @@ #include "slate/internal/device.hh" #include "internal/internal.hh" #include "internal/internal_batch.hh" +#include "internal/internal_util.hh" #include "slate/internal/util.hh" #include "slate/HermitianMatrix.hh" #include "internal/Tile_lapack.hh" @@ -421,6 +422,7 @@ void norm( batch_size, blas::MemcpyKind::HostToDevice, *queue); + real_t* vals_dev_array_group = vals_dev_array; for (size_t g = 0; g < group_params.size(); ++g) { int64_t group_count = group_params[ g ].count; @@ -505,7 +507,7 @@ void norm( else if (in_norm == Norm::One || in_norm == Norm::Inf) { auto irange = device_regions_range( true, A ); auto jrange = device_regions_range( false, A ); - int64_t nb0 = A.tileNb(0); + auto ioffsets = tile_offsets( true, A ); assert(A.tileNb(0) == A.tileMb(0)); assert(A.n() == A.m()); @@ -524,17 +526,16 @@ void norm( && ((A.uplo() == Uplo::Lower && i > j) || (A.uplo() == Uplo::Upper && i < j))) { - // TODO this is broken for nonuniform block sizes // col sums blas::axpy( nb, 1.0, &vals_host_array[batch_count*ldv], 1, - &values[j*nb0], 1); + &values[ ioffsets[j] ], 1); // row sums blas::axpy( mb, 1.0, &vals_host_array[batch_count*ldv + nb], 1, - &values[i*nb0], 1); + &values[ ioffsets[i] ], 1); ++batch_count; } }} // for j,i @@ -548,7 +549,7 @@ void norm( blas::axpy( nb, 1.0, &vals_host_array[batch_count*ldv], 1, - &values[ij*nb0], 1); + &values[ ioffsets[ij] ], 1); ++batch_count; } } diff --git a/src/internal/internal_synorm.cc b/src/internal/internal_synorm.cc index 4fabc4ad5..418d99b66 100644 --- a/src/internal/internal_synorm.cc +++ b/src/internal/internal_synorm.cc @@ -5,6 +5,7 @@ #include "slate/internal/device.hh" #include "internal/internal_batch.hh" +#include "internal/internal_util.hh" #include "internal/internal.hh" #include "slate/internal/util.hh" #include "slate/SymmetricMatrix.hh" @@ -509,7 +510,7 @@ void norm(internal::TargetType, else if (in_norm == Norm::One || in_norm == Norm::Inf) { auto irange = device_regions_range( true, A ); auto jrange = device_regions_range( false, A ); - int64_t nb0 = A.tileNb(0); + auto ioffsets = tile_offsets( true, A ); assert(A.tileNb(0) == A.tileMb(0)); assert(A.n() == A.m()); @@ -528,17 +529,16 @@ void norm(internal::TargetType, && ((A.uplo() == Uplo::Lower && i > j) || (A.uplo() == Uplo::Upper && i < j))) { - // TODO this is broken for nonuniform block sizes // col sums blas::axpy( nb, 1.0, &vals_host_array[batch_count*ldv], 1, - &values[j*nb0], 1); + &values[ ioffsets[j] ], 1); // row sums blas::axpy( mb, 1.0, &vals_host_array[batch_count*ldv + nb], 1, - &values[i*nb0], 1); + &values[ ioffsets[i] ], 1); ++batch_count; } }} // for j,i @@ -552,7 +552,7 @@ void norm(internal::TargetType, blas::axpy( nb, 1.0, &vals_host_array[batch_count*ldv], 1, - &values[ij*nb0], 1); + &values[ ioffsets[ij] ], 1); ++batch_count; } } diff --git a/src/internal/internal_trnorm.cc b/src/internal/internal_trnorm.cc index 9e34821d2..679c906ae 100644 --- a/src/internal/internal_trnorm.cc +++ b/src/internal/internal_trnorm.cc @@ -5,6 +5,7 @@ #include "slate/internal/device.hh" #include 
"internal/internal_batch.hh" +#include "internal/internal_util.hh" #include "internal/internal.hh" #include "slate/internal/util.hh" #include "slate/TrapezoidMatrix.hh" @@ -445,7 +446,6 @@ void norm( batch_size, blas::MemcpyKind::HostToDevice, *queue); - queue->sync(); real_t* vals_dev_array_group = vals_dev_array; for (size_t g = 0; g < group_params.size(); ++g) { @@ -511,6 +511,7 @@ void norm( else if (in_norm == Norm::One) { auto irange = device_regions_range( true, A ); auto jrange = device_regions_range( false, A ); + auto joffsets = tile_offsets( false, A ); for (int device = 0; device < A.num_devices(); ++device) { @@ -526,11 +527,10 @@ void norm( && ((A.uplo() == Uplo::Lower && i > j) || (A.uplo() == Uplo::Upper && i < j))) { - // TODO this is broken for nonuniform block sizes blas::axpy( nb, 1.0, &vals_host_array[batch_count*ldv], 1, - &values[j*ldv], 1); + &values[ joffsets[j] ], 1); ++batch_count; } }} // for j,i @@ -540,11 +540,10 @@ void norm( for (int64_t ij = ijstart; ij < ijend; ++ij) { if (A.tileIsLocal(ij, ij) && device == A.tileDevice(ij, ij)) { - // TODO this is broken for nonuniform block sizes blas::axpy( nb, 1.0, &vals_host_array[batch_count*ldv], 1, - &values[ij*ldv], 1); + &values[ joffsets[ij] ], 1); ++batch_count; } } @@ -554,6 +553,7 @@ void norm( else if (in_norm == Norm::Inf) { auto irange = device_regions_range( true, A ); auto jrange = device_regions_range( false, A ); + auto ioffsets = tile_offsets( true, A ); for (int device = 0; device < A.num_devices(); ++device) { @@ -562,7 +562,7 @@ void norm( int64_t batch_count = 0; for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - int64_t nb = A.tileMb( jj ); + int64_t mb = A.tileMb( ii ); for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j ) @@ -570,9 +570,9 @@ void norm( (A.uplo() == Uplo::Upper && i < j))) { blas::axpy( - nb, 1.0, + mb, 1.0, &vals_host_array[batch_count*ldv], 1, - &values[i*ldv], 1); + &values[ ioffsets[i] ], 1); ++batch_count; } }} // for j,i @@ -584,9 +584,9 @@ void norm( device == A.tileDevice(ij, ij)) { blas::axpy( - nb, 1.0, + mb, 1.0, &vals_host_array[batch_count*ldv], 1, - &values[ij*ldv], 1); + &values[ ioffsets[ij] ], 1); ++batch_count; } } diff --git a/src/internal/internal_util.hh b/src/internal/internal_util.hh index 11b080067..2fd909b1a 100644 --- a/src/internal/internal_util.hh +++ b/src/internal/internal_util.hh @@ -109,6 +109,21 @@ slate::Matrix alloc_basis(slate::BaseMatrix& A, int64_t n, return V; } +template +std::vector tile_offsets(bool want_rows, slate::BaseMatrix& A) +{ + int64_t kt = want_rows ? A.mt() : A.nt(); + + std::vector< int64_t > offset_list; + offset_list.reserve( kt ); + + int64_t offset = 0; + for (int64_t k = 0; k < kt; ++k) { + offset_list.push_back( offset ); + offset += want_rows ? 
A.tileMb( k ) : A.tileNb( k ); + } + return offset_list; +} } // namespace internal From b7cfdcbd5d446ddb75f3e1f0e7e3f1fc4e1e7d72 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Fri, 27 Oct 2023 10:41:21 -0400 Subject: [PATCH 21/35] Improve batch regions docs and remove some unneeded includes --- src/internal/internal_batch.hh | 7 ++++--- src/internal/internal_gebr.cc | 1 - src/internal/internal_hebr.cc | 1 - src/internal/internal_hemm.cc | 1 - src/internal/internal_symm.cc | 1 - 5 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/internal/internal_batch.hh b/src/internal/internal_batch.hh index 7969ca894..bf10f6b96 100644 --- a/src/internal/internal_batch.hh +++ b/src/internal/internal_batch.hh @@ -228,7 +228,7 @@ public: /// The type of the matrices /// /// @param[in] diag_same -/// Whether to treat the diagonal tiles as normal tiles. +/// Whether to include the diagonal tiles in the off-diagonal groups /// If false, store_diag must be true /// /// @param[in] mats @@ -283,10 +283,11 @@ std::vector< device_regions_params > device_regions_build int64_t batch_count = 0; int64_t mt = A.mt(); std::vector group_params; + // loop over regions for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - // Loop over the tiles in this region. If any should be computed on this - // process & device, save them. + // Loop over the tiles in this region, + // save any that should be computed on this process & device Params group; group.mb = A.tileMb( irange[ ii ] ); group.nb = A.tileNb( jrange[ jj ] ); diff --git a/src/internal/internal_gebr.cc b/src/internal/internal_gebr.cc index 04ba34835..72f72a70c 100644 --- a/src/internal/internal_gebr.cc +++ b/src/internal/internal_gebr.cc @@ -4,7 +4,6 @@ // the terms of the BSD 3-Clause license. See the accompanying LICENSE file. #include "slate/internal/device.hh" -#include "internal/internal_batch.hh" #include "internal/internal.hh" #include "slate/internal/util.hh" #include "slate/Matrix.hh" diff --git a/src/internal/internal_hebr.cc b/src/internal/internal_hebr.cc index 9d51d9521..984aa6f16 100644 --- a/src/internal/internal_hebr.cc +++ b/src/internal/internal_hebr.cc @@ -4,7 +4,6 @@ // the terms of the BSD 3-Clause license. See the accompanying LICENSE file. 
#include "slate/internal/device.hh" -#include "internal/internal_batch.hh" #include "internal/internal.hh" #include "slate/internal/util.hh" #include "slate/HermitianMatrix.hh" diff --git a/src/internal/internal_hemm.cc b/src/internal/internal_hemm.cc index 81dba32d2..93b5b60c5 100644 --- a/src/internal/internal_hemm.cc +++ b/src/internal/internal_hemm.cc @@ -8,7 +8,6 @@ #include "slate/types.hh" #include "slate/Tile_blas.hh" #include "internal/internal.hh" -#include "internal/internal_batch.hh" namespace slate { namespace internal { diff --git a/src/internal/internal_symm.cc b/src/internal/internal_symm.cc index 3245e67cc..24aacfc83 100644 --- a/src/internal/internal_symm.cc +++ b/src/internal/internal_symm.cc @@ -8,7 +8,6 @@ #include "slate/types.hh" #include "slate/Tile_blas.hh" #include "internal/internal.hh" -#include "internal/internal_batch.hh" namespace slate { namespace internal { From 1a35bfaac6803888728641c320458fc63672591b Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Fri, 27 Oct 2023 16:09:23 -0400 Subject: [PATCH 22/35] Add device regions to gemmA --- src/internal/internal_gemmA.cc | 355 +++++---------------------------- 1 file changed, 52 insertions(+), 303 deletions(-) diff --git a/src/internal/internal_gemmA.cc b/src/internal/internal_gemmA.cc index e32903863..1459a9d8f 100644 --- a/src/internal/internal_gemmA.cc +++ b/src/internal/internal_gemmA.cc @@ -390,7 +390,7 @@ void gemmA(internal::TargetType, } } - int64_t batch_size = C_tiles_set.size(); + int64_t batch_size = A_tiles_set.size(); if (batch_size > 0) { #pragma omp taskgroup @@ -415,329 +415,78 @@ void gemmA(internal::TargetType, } } - // interior, first column, and excluding bottom row - std::vector a_array0_; - std::vector b_array0_; - std::vector c_array0_; - a_array0_.reserve( batch_size ); - b_array0_.reserve( batch_size ); - c_array0_.reserve( batch_size ); - - int64_t lda0_ = 0; - int64_t ldb0_ = 0; - int64_t ldc0_ = 0; - int64_t mb0_ = C.tileMb( 0 ); - int64_t nb0_ = C.tileNb( 0 ); - int64_t kb = A.tileNb( 0 ); - { - if (A.nt() > 1) { - int j = 0; - for (int64_t i = 0; i < A.mt()-1; ++i) { - if (A.tileIsLocal( i, j )) { - if (device == A.tileDevice( i, j )) { - a_array0_.push_back( - A( i, j, device ).data() ); - b_array0_.push_back( - B( j, 0, device ).data() ); - c_array0_.push_back( - C( i, 0, device ).data() ); - lda0_ = A( i, j, device ).stride(); - ldb0_ = B( j, 0, device ).stride(); - ldc0_ = C( i, 0, device ).stride(); - } - } - } - } - } - - // bottom row, first column - scalar_t* a_array1_ = nullptr; - scalar_t* b_array1_ = nullptr; - scalar_t* c_array1_ = nullptr; - - int64_t lda1_ = 0; - int64_t ldb1_ = 0; - int64_t ldc1_ = 0; - int64_t mb1_ = C.tileMb( C.mt()-1 ); - int64_t nb1_ = C.tileNb( 0 ); - // same kb as above - { - if (A.nt() > 1) { - int64_t i = A.mt()-1; - int j = 0; - if (A.tileIsLocal( i, j )) { - if (device == A.tileDevice( i, j )) { - a_array1_ = A( i, j, device ).data(); - b_array1_ = B( j, 0, device ).data(); - c_array1_ = C( i, 0, device ).data(); - lda1_ = A( i, j, device ).stride(); - ldb1_ = B( j, 0, device ).stride(); - ldc1_ = C( i, 0, device ).stride(); - } - } - } - } - - // interior, excluding first column, bottom row, - // and right column - std::vector< std::vector< scalar_t* > > a_array00j; - std::vector< std::vector< scalar_t* > > b_array00j; - std::vector< std::vector< scalar_t* > > c_array00j; - - int64_t lda00 = 0; - int64_t ldb00 = 0; - int64_t ldc00 = 0; - int64_t mb00 = C.tileMb( 0 ); - int64_t nb00 = C.tileNb( 0 ); - int64_t a00i_batch_size = A.mt() - 
1; - if (A.nt() > 1) { - a_array00j.reserve( A.nt() - 2 ); - b_array00j.reserve( A.nt() - 2 ); - c_array00j.reserve( A.nt() - 2 ); - for (int64_t j = 1; j < A.nt()-1; ++j) { - std::vector a_tmp; - std::vector b_tmp; - std::vector c_tmp; - a_tmp.reserve( a00i_batch_size ); - b_tmp.reserve( a00i_batch_size ); - c_tmp.reserve( a00i_batch_size ); - for (int64_t i = 0; i < A.mt()-1; ++i) { - if (A.tileIsLocal( i, j )) { - if (device == A.tileDevice( i, j )) { - a_tmp.push_back( A( i, j, device ).data() ); - b_tmp.push_back( B( j, 0, device ).data() ); - c_tmp.push_back( C( i, 0, device ).data() ); - lda00 = A( i, j, device ).stride(); - ldb00 = B( j, 0, device ).stride(); - ldc00 = C( i, 0, device ).stride(); - } - } - } - if (a_tmp.size() > 0) { - a_array00j.push_back( std::move(a_tmp) ); - b_array00j.push_back( std::move(b_tmp) ); - c_array00j.push_back( std::move(c_tmp) ); - } - } - } - - // bottom row, excluding first and last columns - std::vector< scalar_t* > a_array10j; - std::vector< scalar_t* > b_array10j; - std::vector< scalar_t* > c_array10j; - - int64_t lda10 = 0; - int64_t ldb10 = 0; - int64_t ldc10 = 0; - int64_t mb10 = C.tileMb( C.mt()-1 ); - int64_t nb10 = C.tileNb( 0 ); - // same kb as above - if (A.nt() > 1) { - a_array10j.reserve( A.nt() - 2 ); - b_array10j.reserve( A.nt() - 2 ); - c_array10j.reserve( A.nt() - 2 ); - { - int64_t i = A.mt()-1; - for (int64_t j = 1; j < A.nt()-1; ++j) { - if (A.tileIsLocal( i, j )) { - if (device == A.tileDevice( i, j )) { - a_array10j.push_back( - A( i, j, device ).data() ); - b_array10j.push_back( - B( j, 0, device ).data() ); - c_array10j.push_back( - C( i, 0, device ).data() ); - lda10 = A( i, j, device ).stride(); - ldb10 = B( j, 0, device ).stride(); - ldc10 = C( i, 0, device ).stride(); - } - } - } - } - } - - // right column - std::vector a_array01; - std::vector b_array01; - std::vector c_array01; - a_array01.reserve( batch_size ); - b_array01.reserve( batch_size ); - c_array01.reserve( batch_size ); - - int64_t lda01 = 0; - int64_t ldb01 = 0; - int64_t ldc01 = 0; - int64_t mb01 = C.tileMb( 0 ); - int64_t nb01 = C.tileNb( C.nt()-1 ); - int64_t kb1 = A.tileNb( A.nt()-1 ); - { - int64_t j = A.nt()-1; - for (int64_t i = 0; i < A.mt()-1; ++i) { - if (A.tileIsLocal( i, j )) { - if (device == A.tileDevice( i, j )) { - a_array01.push_back( A( i, j, device ).data() ); - b_array01.push_back( B( j, 0, device ).data() ); - c_array01.push_back( C( i, 0, device ).data() ); - lda01 = A( i, j, device ).stride(); - ldb01 = B( j, 0, device ).stride(); - ldc01 = C( i, 0, device ).stride(); - } - } - } - } - - // bottom-right corner - scalar_t* a_array11 = nullptr; - scalar_t* b_array11 = nullptr; - scalar_t* c_array11 = nullptr; - - int64_t lda11 = 0; - int64_t ldb11 = 0; - int64_t ldc11 = 0; - int64_t mb11 = C.tileMb( C.mt()-1 ); - int64_t nb11 = C.tileNb( C.nt()-1 ); - // same kb1 as above - { - int64_t i = A.mt()-1; - int64_t j = A.nt()-1; - if (A.tileIsLocal( i, j )) { - if (device == A.tileDevice( i, j )) { - a_array11 = A( i, j, device ).data(); - b_array11 = B( j, 0, device ).data(); - c_array11 = C( i, 0, device ).data(); - lda11 = A( i, j, device ).stride(); - ldb11 = B( j, 0, device ).stride(); - ldc11 = C( i, 0, device ).stride(); - } - } - } + // Use A's batched arrays since C's may be too small + scalar_t** a_array_host = A.array_host(device, queue_index); + scalar_t** b_array_host = a_array_host + batch_size; + scalar_t** c_array_host = b_array_host + batch_size; if (C.op() != Op::NoTrans) { // swap A <=> B; swap m <=> n swap( opA, opB ); 
- swap( a_array0_, b_array0_ ); - swap( a_array1_, b_array1_ ); - swap( a_array00j, b_array00j ); - swap( a_array10j, b_array10j ); - swap( a_array01, b_array01 ); - swap( a_array11, b_array11 ); - swap( lda0_, ldb0_ ); - swap( lda1_, ldb1_ ); - swap( lda00, ldb00 ); - swap( lda10, ldb10 ); - swap( lda01, ldb01 ); - swap( lda11, ldb11 ); - swap( mb0_, nb0_ ); - swap( mb1_, nb1_ ); - swap( mb00, nb00 ); - swap( mb10, nb10 ); - swap( mb01, nb01 ); - swap( mb11, nb11 ); } - { + std::vector opA_(1, opA); + std::vector opB_(1, opB); + std::vector alpha_(1, alpha); + std::vector beta_(1, beta); + // info size 0 disables slow checks in batched BLAS++. + std::vector info; + + for (int64_t j = 0; j < A.nt(); ++j) { + auto A_j = A.sub( 0, A.mt()-1, j, j ); + auto B_j = B.sub( j, j, 0, 0 ); + // A comes first since we do computation for a local A + auto group_params = device_regions_build( + {A_j, B_j, C}, + {a_array_host, b_array_host, c_array_host}, + device ); + trace::Block trace_block("blas::batch::gemm"); - std::vector opA_( 1, opA ); - std::vector opB_( 1, opB ); - std::vector alpha_( 1, alpha ); - std::vector beta0_( 1, beta ); - std::vector beta1_( 1, beta ); - std::vector k( 1, kb ); - std::vector k1( 1, kb1 ); - // info size 0 disables slow checks in batched BLAS++. - std::vector info; - - if (c_array0_.size() > 0) { - std::vector m( 1, mb0_ ); - std::vector n( 1, nb0_ ); - std::vector ldda( 1, lda0_ ); - std::vector lddb( 1, ldb0_ ); - std::vector lddc( 1, ldc0_ ); - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, a_array0_, ldda, - b_array0_, lddb, - beta0_, c_array0_, lddc, - c_array0_.size(), info, *queue ); + std::vector k(1, A.tileNb(j)); - beta0_[ 0 ] = one; - } + for (size_t g = 0; g < group_params.size(); ++g) { - if (c_array1_ != nullptr) { - blas::gemm( - layout, opA, opB, - mb1_, nb1_, kb, - alpha, a_array1_, lda1_, - b_array1_, ldb1_, - beta1_[0], c_array1_, ldc1_, - *queue ); + int64_t group_count = group_params[ g ].count; - beta1_[ 0 ] = one; - } + std::vector m(1, group_params[ g ].mb); + std::vector n(1, C.tileNb(0)); + std::vector ldda(1, group_params[ g ].ld[0]); + std::vector lddb(1, group_params[ g ].ld[1]); + std::vector lddc(1, group_params[ g ].ld[2]); - if (c_array00j.size() > 0) { - std::vector m( 1, mb00 ); - std::vector n( 1, nb00 ); - std::vector ldda( 1, lda00 ); - std::vector lddb( 1, ldb00 ); - std::vector lddc( 1, ldc00 ); - for (size_t j = 0; j < c_array00j.size(); ++j) { - blas::batch::gemm( - layout, opA_, opB_, - m, n, k, - alpha_, a_array00j[ j ], ldda, - b_array00j[ j ], lddb, - beta0_, c_array00j[ j ], lddc, - c_array00j[ j ].size(), info, *queue ); - beta0_[ 0 ] = one; - } - } + std::vector a_array(a_array_host, a_array_host+group_count); + std::vector b_array(b_array_host, b_array_host+group_count); + std::vector c_array(c_array_host, c_array_host+group_count); - if (c_array10j.size() > 0) { - std::vector m( 1, mb10 ); - std::vector n( 1, nb10 ); - std::vector ldda( 1, lda10 ); - std::vector lddb( 1, ldb10 ); - std::vector lddc( 1, ldc10 ); - for (size_t j = 0; j < c_array10j.size(); ++j) { - blas::gemm( - layout, opA, opB, - mb10, nb10, kb, - alpha, a_array10j[ j ], lda10, - b_array10j[ j ], ldb10, - beta1_[ 0 ], c_array10j[ j ], ldc10, - *queue ); - beta1_[ 0 ] = one; + if (C.op() != Op::NoTrans) { + swap(m, n); + swap(a_array, b_array); + swap(ldda, lddb); } - } - if (c_array01.size() > 0) { - std::vector m( 1, mb01 ); - std::vector n( 1, nb01 ); - std::vector ldda( 1, lda01 ); - std::vector lddb( 1, ldb01 ); - std::vector lddc( 
1, ldc01 ); blas::batch::gemm( layout, opA_, opB_, - m, n, k1, - alpha_, a_array01, ldda, - b_array01, lddb, - beta0_, c_array01, lddc, - c_array01.size(), info, *queue ); + m, n, k, + alpha_, a_array, ldda, + b_array, lddb, + beta_, c_array, lddc, + group_count, info, *queue); + + a_array_host += group_count; + b_array_host += group_count; + c_array_host += group_count; } - if (c_array11 != nullptr) { - blas::gemm( - layout, opA, opB, - mb11, nb11, kb1, - alpha, a_array11, lda11, - b_array11, ldb11, - beta1_[ 0 ], c_array11, ldc11, - *queue ); + // Only scale C once + // TODO relax assumption on the distribution + if (group_params.size() > 0) { + beta_[0] = one; } + } + { + trace::Block trace_block("blas::batch::gemm"); queue->sync(); } From 4e174f124a3884b59e0bbe94f5668a6258a23c69 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Mon, 30 Oct 2023 21:57:48 -0400 Subject: [PATCH 23/35] Fix column norms and reduce norm overheads --- src/internal/internal_batch.hh | 56 ++++++++++++++++++-- src/internal/internal_genorm.cc | 60 +++++++++++----------- src/internal/internal_henorm.cc | 90 +++++++++++++++------------------ src/internal/internal_synorm.cc | 85 +++++++++++++++---------------- src/internal/internal_trnorm.cc | 60 +++++++++------------- 5 files changed, 187 insertions(+), 164 deletions(-) diff --git a/src/internal/internal_batch.hh b/src/internal/internal_batch.hh index bf10f6b96..be71a04d5 100644 --- a/src/internal/internal_batch.hh +++ b/src/internal/internal_batch.hh @@ -249,7 +249,9 @@ std::vector< device_regions_params > device_regions_build std::array< std::reference_wrapper>, mat_count > mats, std::array< scalar_t**, mat_count > mats_array_host, int64_t device, - std::function extra_setup = {}) + std::function extra_setup, + std::vector& irange, + std::vector& jrange) { // The first two arguments should be valid targets for brace-initialization // reference_wrapper works around fact that C++ doesn't allow array of references @@ -258,10 +260,6 @@ std::vector< device_regions_params > device_regions_build auto& A = mats[0].get(); - // Find ranges of matching mb's and ranges of matching nb's. - std::vector< int64_t > irange = device_regions_range( true, A ); - std::vector< int64_t > jrange = device_regions_range( false, A ); - // Trapezoidal matrices always need special treatment for diagonal tiles assert( !diag_same || A.uplo() == Uplo::General ); @@ -369,6 +367,54 @@ std::vector< device_regions_params > device_regions_build return group_params; } +//------------------------------------------------------------------------------ +/// Computes and populates the regions for the given matrices. +/// +/// irange and jrange are computed internally +/// +/// @tparam store_diag +/// Wheather the diagonal tiles may need to be special cased +/// +/// @tparam mat_count +/// The number of matrices used by the kernel +/// +/// @tparam scalar_t +/// The type of the matrices +/// +/// @param[in] diag_same +/// Whether to include the diagonal tiles in the off-diagonal groups +/// If false, store_diag must be true +/// +/// @param[in] mats +/// An array of the matrices to build regions for +/// +/// @param[in] mats_array_host +/// An array of the arrays to fill with pointers to device data +/// +/// @param[in] device +/// The device to build regions for +/// +/// @param[in] extra_setup +/// Callback that is called whenever a tile is added to a group. 
+/// The group index and the tile indices are passed as arguments +/// +template< bool store_diag, int mat_count, typename scalar_t, bool diag_same=!store_diag > +std::vector< device_regions_params > device_regions_build( + std::array< std::reference_wrapper>, mat_count > mats, + std::array< scalar_t**, mat_count > mats_array_host, + int64_t device, + std::function extra_setup = {}) +{ + // Find ranges of matching mb's and ranges of matching nb's. + auto irange = device_regions_range( true, mats[0].get() ); + auto jrange = device_regions_range( false, mats[0].get() ); + + return device_regions_build< store_diag, mat_count, scalar_t, diag_same >( + mats, mats_array_host, device, extra_setup, + irange, jrange ); +} + + } // namespace internal } // namespace slate diff --git a/src/internal/internal_genorm.cc b/src/internal/internal_genorm.cc index 4e87e0734..9aa6f6377 100644 --- a/src/internal/internal_genorm.cc +++ b/src/internal/internal_genorm.cc @@ -380,6 +380,11 @@ void norm( // devices_values used for max and Frobenius norms. std::vector devices_values; + // Find ranges of matching mb's and ranges of matching nb's to avoid + // repeatedly recomputing them + auto irange = device_regions_range( true, A ); + auto jrange = device_regions_range( false, A ); + int64_t ldv = 0; if (scope == NormScope::Matrix) { if (in_norm == Norm::Max) { @@ -387,13 +392,13 @@ void norm( devices_values.resize(A.num_devices()); } else if (in_norm == Norm::One) { - for (int64_t j = 0; j < A.nt(); ++j) { - ldv = std::max( ldv, A.tileNb(j) ); + for (size_t j = 0; j < jrange.size()-1; ++j) { + ldv = std::max( ldv, A.tileNb( jrange[j] ) ); } } else if (in_norm == Norm::Inf) { - for (int64_t i = 0; i < A.mt(); ++i) { - ldv = std::max( ldv, A.tileMb(i) ); + for (size_t i = 0; i < irange.size()-1; ++i) { + ldv = std::max( ldv, A.tileMb( irange[i] ) ); } } else if (in_norm == Norm::Fro) { @@ -403,8 +408,8 @@ void norm( } else if (scope == NormScope::Columns) { if (in_norm == Norm::Max) { - for (int64_t j = 0; j < A.nt(); ++j) { - ldv = std::max( ldv, A.tileNb(j) ); + for (size_t j = 0; j < jrange.size()-1; ++j) { + ldv = std::max( ldv, A.tileNb( jrange[j] ) ); } } else { @@ -417,8 +422,8 @@ void norm( #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { - #pragma omp task slate_omp_default_none \ - priority( priority ) shared( A, devices_values, vals_host_arrays ) \ + #pragma omp task slate_omp_default_none priority( priority ) \ + shared( A, devices_values, vals_host_arrays, irange, jrange ) \ firstprivate(device, queue_index, ldv, scope, in_norm, layout) { std::set A_tiles_set; @@ -439,15 +444,16 @@ void norm( auto group_params = device_regions_build( {A}, {a_array_host}, - device ); + device, + {}, + irange, jrange ); scalar_t** a_array_dev = A.array_device(device, queue_index); - int64_t num_tiles = A_tiles_set.size(); - vals_host_arrays[ device ].resize( num_tiles*ldv ); + vals_host_arrays[ device ].resize( batch_size*ldv ); real_t* vals_host_array = vals_host_arrays[ device ].data(); blas::Queue* queue = A.compute_queue( device, queue_index ); - real_t* vals_dev_array = blas::device_malloc( num_tiles*ldv, *queue ); + real_t* vals_dev_array = blas::device_malloc( batch_size*ldv, *queue ); // Batched call to compute partial results for each tile. 
{ @@ -455,9 +461,9 @@ void norm( blas::device_memcpy(a_array_dev, a_array_host, - batch_size, - blas::MemcpyKind::HostToDevice, - *queue); + batch_size, + blas::MemcpyKind::HostToDevice, + *queue); real_t* vals_dev_array_group = vals_dev_array; for (size_t g = 0; g < group_params.size(); ++g) { @@ -474,9 +480,9 @@ void norm( } blas::device_memcpy(vals_host_array, vals_dev_array, - batch_size*ldv, - blas::MemcpyKind::DeviceToHost, - *queue); + batch_size*ldv, + blas::MemcpyKind::DeviceToHost, + *queue); queue->sync(); } @@ -510,8 +516,6 @@ void norm( devices_values.data(), 1); } else if (in_norm == Norm::One) { - auto irange = device_regions_range( true, A ); - auto jrange = device_regions_range( false, A ); auto joffsets = tile_offsets( false, A ); for (int device = 0; device < A.num_devices(); ++device) { @@ -536,8 +540,6 @@ void norm( } } else if (in_norm == Norm::Inf) { - auto irange = device_regions_range( true, A ); - auto jrange = device_regions_range( false, A ); auto ioffsets = tile_offsets( true, A ); for (int device = 0; device < A.num_devices(); ++device) { @@ -574,8 +576,6 @@ void norm( else if (scope == NormScope::Columns) { if (in_norm == Norm::Max) { - auto irange = device_regions_range( true, A ); - auto jrange = device_regions_range( false, A ); auto joffsets = tile_offsets( false, A ); // Reduction over devices to local result. @@ -590,12 +590,14 @@ void norm( int64_t nb = A.tileNb( jj ); for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { - for (int k = 0; k < nb; ++k) { - values[j*ldv + k] = - max_nan(vals_host_array[batch_count*ldv + k], - values[ joffsets[j] + k]); + if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { + for (int k = 0; k < nb; ++k) { + values[j*ldv + k] = + max_nan(vals_host_array[batch_count*ldv + k], + values[ joffsets[j] + k]); + } + ++batch_count; } - ++batch_count; }} // for j,i }} // for jj,ii } diff --git a/src/internal/internal_henorm.cc b/src/internal/internal_henorm.cc index 9571f5712..e3895c6ca 100644 --- a/src/internal/internal_henorm.cc +++ b/src/internal/internal_henorm.cc @@ -350,14 +350,18 @@ void norm( // devices_values used for max and Frobenius norms. 
std::vector devices_values; + // Find ranges of matching mb's and ranges of matching nb's to avoid + // repeatedly recomputing them + auto ijrange = device_regions_range( true, A ); + int64_t ldv = 0; if (in_norm == Norm::Max) { ldv = 1; devices_values.resize(A.num_devices()); } else if (in_norm == Norm::One || in_norm == Norm::Inf) { - for (int64_t j = 0; j < A.nt(); ++j) { - ldv = std::max( ldv, A.tileNb(j) ); + for (size_t j = 0; j < ijrange.size()-1; ++j) { + ldv = std::max( ldv, A.tileNb( ijrange[j] ) ); } ldv *= 2; } @@ -366,20 +370,11 @@ void norm( devices_values.resize(A.num_devices() * 2); } - for (int device = 0; device < A.num_devices(); ++device) { - int64_t num_tiles = A.getMaxDeviceTiles(device); - - vals_host_arrays[device].resize(num_tiles*ldv); - - blas::Queue* queue = A.comm_queue(device); - vals_dev_arrays[device] = blas::device_malloc(num_tiles*ldv, *queue); - } - #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { #pragma omp task slate_omp_default_none \ shared( A, devices_values ) \ - shared(vals_host_arrays, vals_dev_arrays) \ + shared( vals_host_arrays, vals_dev_arrays, ijrange ) \ firstprivate(device, layout, lower, queue_index, in_norm, ldv) \ priority(priority) { @@ -405,23 +400,25 @@ void norm( auto group_params = device_regions_build( {A}, {a_array_host}, - device ); + device, + {}, + ijrange, ijrange ); scalar_t** a_array_dev = A.array_device(device, queue_index); - real_t* vals_host_array = vals_host_arrays[device].data(); - real_t* vals_dev_array = vals_dev_arrays[device]; + vals_host_arrays[ device ].resize( batch_size*ldv ); + real_t* vals_host_array = vals_host_arrays[ device ].data(); + blas::Queue* queue = A.compute_queue( device, queue_index ); + real_t* vals_dev_array = blas::device_malloc( batch_size*ldv, *queue ); // Batched call to compute partial results for each tile. { trace::Block trace_block("slate::device::henorm"); - blas::Queue* queue = A.compute_queue(device, queue_index); - blas::device_memcpy(a_array_dev, a_array_host, - batch_size, - blas::MemcpyKind::HostToDevice, - *queue); + batch_size, + blas::MemcpyKind::HostToDevice, + *queue); real_t* vals_dev_array_group = vals_dev_array; for (size_t g = 0; g < group_params.size(); ++g) { @@ -457,11 +454,10 @@ void norm( queue->sync(); } - blas::device_memcpy( - vals_host_array, vals_dev_array, - batch_size*ldv, - blas::MemcpyKind::DeviceToHost, - *queue); + blas::device_memcpy(vals_host_array, vals_dev_array, + batch_size*ldv, + blas::MemcpyKind::DeviceToHost, + *queue); queue->sync(); } @@ -489,15 +485,12 @@ void norm( } } } + // Free device workspace + blas::device_free(vals_dev_array, *queue); } } // end omp taskgroup - for (int device = 0; device < A.num_devices(); ++device) { - blas::Queue* queue = A.compute_queue(device, queue_index); - blas::device_free(vals_dev_arrays[device], *queue); - } - // Reduction over devices to local result. 
if (in_norm == Norm::Max) { *values = lapack::lange(in_norm, @@ -505,10 +498,7 @@ void norm( devices_values.data(), 1); } else if (in_norm == Norm::One || in_norm == Norm::Inf) { - auto irange = device_regions_range( true, A ); - auto jrange = device_regions_range( false, A ); auto ioffsets = tile_offsets( true, A ); - assert(A.tileNb(0) == A.tileMb(0)); assert(A.n() == A.m()); for (int device = 0; device < A.num_devices(); ++device) { @@ -516,12 +506,12 @@ void norm( real_t* vals_host_array = vals_host_arrays[device].data(); int64_t batch_count = 0; - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - int64_t mb = A.tileMb( irange[ ii ] ); - int64_t nb = A.tileNb( jrange[ jj ] ); - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + for (size_t jj = 0; jj < ijrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < ijrange.size() - 1; ++ii) { + int64_t mb = A.tileMb( ijrange[ ii ] ); + int64_t nb = A.tileNb( ijrange[ jj ] ); + for (int64_t j = ijrange[ jj ]; j < ijrange[ jj+1 ]; ++j) { + for (int64_t i = ijrange[ ii ]; i < ijrange[ ii+1 ]; ++i) { if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j ) && ((A.uplo() == Uplo::Lower && i > j) || (A.uplo() == Uplo::Upper && i < j))) { @@ -540,17 +530,17 @@ void norm( } }} // for j,i - int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); - int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); - for (int64_t ij = ijstart; ij < ijend; ++ij) { - if (A.tileIsLocal(ij, ij) && - device == A.tileDevice(ij, ij)) - { - blas::axpy( - nb, 1.0, - &vals_host_array[batch_count*ldv], 1, - &values[ ioffsets[ij] ], 1); - ++batch_count; + if (ii == jj) { + for (int64_t ij = ijrange[ ii ]; ij < ijrange[ ii+1 ]; ++ij) { + if (A.tileIsLocal(ij, ij) && + device == A.tileDevice(ij, ij)) + { + blas::axpy( + nb, 1.0, + &vals_host_array[batch_count*ldv], 1, + &values[ ioffsets[ij] ], 1); + ++batch_count; + } } } }} // for jj,ii diff --git a/src/internal/internal_synorm.cc b/src/internal/internal_synorm.cc index 418d99b66..f1fd24985 100644 --- a/src/internal/internal_synorm.cc +++ b/src/internal/internal_synorm.cc @@ -352,14 +352,18 @@ void norm(internal::TargetType, // devices_values used for max and Frobenius norms. 
std::vector devices_values; + // Find ranges of matching mb's and ranges of matching nb's to avoid + // repeatedly recomputing them + auto ijrange = device_regions_range( true, A ); + int64_t ldv = 0; if (in_norm == Norm::Max) { ldv = 1; devices_values.resize(A.num_devices()); } else if (in_norm == Norm::One || in_norm == Norm::Inf) { - for (int64_t j = 0; j < A.nt(); ++j) { - ldv = std::max( ldv, A.tileNb(j) ); + for (size_t j = 0; j < ijrange.size()-1; ++j) { + ldv = std::max( ldv, A.tileNb( ijrange[j] ) ); } ldv *= 2; } @@ -368,20 +372,11 @@ void norm(internal::TargetType, devices_values.resize(A.num_devices() * 2); } - for (int device = 0; device < A.num_devices(); ++device) { - int64_t num_tiles = A.getMaxDeviceTiles(device); - - vals_host_arrays[device].resize(num_tiles*ldv); - - blas::Queue* queue = A.comm_queue(device); - vals_dev_arrays[device] = blas::device_malloc(num_tiles*ldv, *queue); - } - #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { #pragma omp task slate_omp_default_none \ shared( A, devices_values ) \ - shared(vals_host_arrays, vals_dev_arrays) \ + shared( vals_host_arrays, vals_dev_arrays, ijrange ) \ firstprivate(device, lower, queue_index, in_norm, ldv, layout) \ priority(priority) { @@ -408,23 +403,25 @@ void norm(internal::TargetType, auto group_params = device_regions_build( {A}, {a_array_host}, - device ); + device, + {}, + ijrange, ijrange ); scalar_t** a_array_dev = A.array_device(device, queue_index); - real_t* vals_host_array = vals_host_arrays[device].data(); - real_t* vals_dev_array = vals_dev_arrays[device]; + vals_host_arrays[ device ].resize( batch_size*ldv ); + real_t* vals_host_array = vals_host_arrays[ device ].data(); + blas::Queue* queue = A.compute_queue( device, queue_index ); + real_t* vals_dev_array = blas::device_malloc( batch_size*ldv, *queue ); // Batched call to compute partial results for each tile. 
{ trace::Block trace_block("slate::device::synorm"); - blas::Queue* queue = A.compute_queue(device, queue_index); - blas::device_memcpy(a_array_dev, a_array_host, - batch_size, - blas::MemcpyKind::HostToDevice, - *queue); + batch_size, + blas::MemcpyKind::HostToDevice, + *queue); real_t* vals_dev_array_group = vals_dev_array; for (size_t g = 0; g < group_params.size(); ++g) { @@ -460,11 +457,10 @@ void norm(internal::TargetType, queue->sync(); } - blas::device_memcpy( - vals_host_array, vals_dev_array, - batch_size*ldv, - blas::MemcpyKind::DeviceToHost, - *queue); + blas::device_memcpy(vals_host_array, vals_dev_array, + batch_size*ldv, + blas::MemcpyKind::DeviceToHost, + *queue); queue->sync(); } @@ -492,6 +488,8 @@ void norm(internal::TargetType, } } } + // Free device workspace + blas::device_free(vals_dev_array, *queue); } } // end omp taskgroup @@ -508,10 +506,7 @@ void norm(internal::TargetType, devices_values.data(), 1); } else if (in_norm == Norm::One || in_norm == Norm::Inf) { - auto irange = device_regions_range( true, A ); - auto jrange = device_regions_range( false, A ); auto ioffsets = tile_offsets( true, A ); - assert(A.tileNb(0) == A.tileMb(0)); assert(A.n() == A.m()); for (int device = 0; device < A.num_devices(); ++device) { @@ -519,12 +514,12 @@ void norm(internal::TargetType, real_t* vals_host_array = vals_host_arrays[device].data(); int64_t batch_count = 0; - for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { - for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - int64_t mb = A.tileMb( irange[ ii ] ); - int64_t nb = A.tileNb( jrange[ jj ] ); - for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { - for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + for (size_t jj = 0; jj < ijrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < ijrange.size() - 1; ++ii) { + int64_t mb = A.tileMb( ijrange[ ii ] ); + int64_t nb = A.tileNb( ijrange[ jj ] ); + for (int64_t j = ijrange[ jj ]; j < ijrange[ jj+1 ]; ++j) { + for (int64_t i = ijrange[ ii ]; i < ijrange[ ii+1 ]; ++i) { if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j ) && ((A.uplo() == Uplo::Lower && i > j) || (A.uplo() == Uplo::Upper && i < j))) { @@ -543,17 +538,17 @@ void norm(internal::TargetType, } }} // for j,i - int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]); - int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]); - for (int64_t ij = ijstart; ij < ijend; ++ij) { - if (A.tileIsLocal(ij, ij) && - device == A.tileDevice(ij, ij)) - { - blas::axpy( - nb, 1.0, - &vals_host_array[batch_count*ldv], 1, - &values[ ioffsets[ij] ], 1); - ++batch_count; + if (ii == jj) { + for (int64_t ij = ijrange[ ii ]; ij < ijrange[ ii+1 ]; ++ij) { + if (A.tileIsLocal(ij, ij) && + device == A.tileDevice(ij, ij)) + { + blas::axpy( + nb, 1.0, + &vals_host_array[batch_count*ldv], 1, + &values[ ioffsets[ij] ], 1); + ++batch_count; + } } } }} // for jj,ii diff --git a/src/internal/internal_trnorm.cc b/src/internal/internal_trnorm.cc index 679c906ae..0e6742827 100644 --- a/src/internal/internal_trnorm.cc +++ b/src/internal/internal_trnorm.cc @@ -370,19 +370,24 @@ void norm( // devices_values used for max and Frobenius norms. 
std::vector devices_values; + // Find ranges of matching mb's and ranges of matching nb's to avoid + // repeatedly recomputing them + auto irange = device_regions_range( true, A ); + auto jrange = device_regions_range( false, A ); + int64_t ldv = 0; if (in_norm == Norm::Max) { ldv = 1; devices_values.resize(A.num_devices()); } else if (in_norm == Norm::One) { - for (int64_t j = 0; j < A.nt(); ++j) { - ldv = std::max( ldv, A.tileNb(j) ); + for (size_t j = 0; j < jrange.size()-1; ++j) { + ldv = std::max( ldv, A.tileNb( jrange[j] ) ); } } else if (in_norm == Norm::Inf) { - for (int64_t i = 0; i < A.mt(); ++i) { - ldv = std::max( ldv, A.tileMb(i) ); + for (size_t i = 0; i < irange.size()-1; ++i) { + ldv = std::max( ldv, A.tileMb( irange[i] ) ); } } else if (in_norm == Norm::Fro) { @@ -390,20 +395,11 @@ void norm( devices_values.resize(A.num_devices() * 2); } - for (int device = 0; device < A.num_devices(); ++device) { - int64_t num_tiles = A.getMaxDeviceTiles(device); - - vals_host_arrays[device].resize(num_tiles*ldv); - - blas::Queue* queue = A.comm_queue(device); - vals_dev_arrays[device] = blas::device_malloc(num_tiles*ldv, *queue); - } - #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { #pragma omp task slate_omp_default_none \ shared( A, devices_values ) \ - shared(vals_host_arrays, vals_dev_arrays) \ + shared( vals_host_arrays, vals_dev_arrays, irange, jrange ) \ firstprivate(device, queue_index, in_norm, ldv, layout) \ priority(priority) { @@ -429,23 +425,25 @@ void norm( auto group_params = device_regions_build( {A}, {a_array_host}, - device ); + device, + {}, + irange, jrange ); scalar_t** a_array_dev = A.array_device(device, queue_index); - real_t* vals_host_array = vals_host_arrays[device].data(); - real_t* vals_dev_array = vals_dev_arrays[device]; + vals_host_arrays[ device ].resize( batch_size*ldv ); + real_t* vals_host_array = vals_host_arrays[ device ].data(); + blas::Queue* queue = A.compute_queue( device, queue_index ); + real_t* vals_dev_array = blas::device_malloc( batch_size*ldv, *queue ); // Batched call to compute partial results for each tile. { trace::Block trace_block("slate::device::trnorm"); - blas::Queue* queue = A.compute_queue(device, queue_index); - blas::device_memcpy(a_array_dev, a_array_host, - batch_size, - blas::MemcpyKind::HostToDevice, - *queue); + batch_size, + blas::MemcpyKind::HostToDevice, + *queue); real_t* vals_dev_array_group = vals_dev_array; for (size_t g = 0; g < group_params.size(); ++g) { @@ -469,13 +467,12 @@ void norm( } a_array_dev += group_count; vals_dev_array_group += group_count * ldv; - queue->sync(); } blas::device_memcpy(vals_host_array, vals_dev_array, - batch_size*ldv, - blas::MemcpyKind::DeviceToHost, - *queue); + batch_size*ldv, + blas::MemcpyKind::DeviceToHost, + *queue); queue->sync(); } @@ -493,15 +490,12 @@ void norm( vals_host_array[2*k + 1]); } } + // Free device workspace + blas::device_free(vals_dev_array, *queue); } } // end omp taskgroup - for (int device = 0; device < A.num_devices(); ++device) { - blas::Queue* queue = A.compute_queue(device, queue_index); - blas::device_free(vals_dev_arrays[device], *queue); - } - // Reduction over devices to local result. 
if (in_norm == Norm::Max) { *values = lapack::lange(in_norm, @@ -509,8 +503,6 @@ void norm( devices_values.data(), 1); } else if (in_norm == Norm::One) { - auto irange = device_regions_range( true, A ); - auto jrange = device_regions_range( false, A ); auto joffsets = tile_offsets( false, A ); for (int device = 0; device < A.num_devices(); ++device) { @@ -551,8 +543,6 @@ void norm( } } else if (in_norm == Norm::Inf) { - auto irange = device_regions_range( true, A ); - auto jrange = device_regions_range( false, A ); auto ioffsets = tile_offsets( true, A ); for (int device = 0; device < A.num_devices(); ++device) { From e284fedcac530b79d3c888cfac21c1ee26ff284d Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Wed, 1 Nov 2023 15:36:25 +0000 Subject: [PATCH 24/35] Fix indexing mistake in norms --- src/internal/internal_genorm.cc | 6 +++--- src/internal/internal_trnorm.cc | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/internal/internal_genorm.cc b/src/internal/internal_genorm.cc index 9aa6f6377..ab5626a13 100644 --- a/src/internal/internal_genorm.cc +++ b/src/internal/internal_genorm.cc @@ -525,7 +525,7 @@ void norm( int64_t batch_count = 0; for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - int64_t nb = A.tileNb( jj ); + int64_t nb = A.tileNb( jrange[jj] ); for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { @@ -549,7 +549,7 @@ void norm( int64_t batch_count = 0; for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - int64_t mb = A.tileMb( ii ); + int64_t mb = A.tileMb( irange[ii] ); for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { @@ -587,7 +587,7 @@ void norm( int64_t batch_count = 0; for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - int64_t nb = A.tileNb( jj ); + int64_t nb = A.tileNb( jrange[jj] ); for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { diff --git a/src/internal/internal_trnorm.cc b/src/internal/internal_trnorm.cc index 0e6742827..f4b7c1dbc 100644 --- a/src/internal/internal_trnorm.cc +++ b/src/internal/internal_trnorm.cc @@ -512,7 +512,7 @@ void norm( int64_t batch_count = 0; for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - int64_t nb = A.tileNb( jj ); + int64_t nb = A.tileNb( jrange[jj] ); for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j ) @@ -552,7 +552,7 @@ void norm( int64_t batch_count = 0; for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { for (size_t ii = 0; ii < irange.size() - 1; ++ii) { - int64_t mb = A.tileMb( ii ); + int64_t mb = A.tileMb( irange[ii] ); for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j ) From ecb47d92eec9e84a290488ad1edb79f3da0108f7 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Wed, 1 Nov 2023 15:50:50 -0400 Subject: [PATCH 25/35] Better follow line limts --- 
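Note for this patch: the call sites re-wrapped below all go through device_regions_build, which after the earlier internal_batch.hh change comes in two forms. A minimal sketch of both, assuming a single matrix A and its host pointer array; template arguments and the surrounding OpenMP task setup are omitted here, so see internal_batch.hh earlier in this series for the exact signatures:

    // Convenience overload: the mb/nb ranges are recomputed on every call.
    auto group_params = device_regions_build(
        {A}, {a_array_host}, device );

    // Overload taking precomputed ranges, used by the norm kernels so the
    // ranges are built once, outside the per-device tasks.
    auto irange = device_regions_range( true,  A );   // row-tile boundaries
    auto jrange = device_regions_range( false, A );   // column-tile boundaries
    auto group_params2 = device_regions_build(
        {A}, {a_array_host}, device,
        {},                 // no extra_setup callback
        irange, jrange );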
src/internal/internal_geadd.cc | 6 +++--- src/internal/internal_gecopy.cc | 8 ++++---- src/internal/internal_gemm.cc | 6 +++--- src/internal/internal_gemmA.cc | 6 +++--- src/internal/internal_genorm.cc | 10 +++++----- src/internal/internal_gescale.cc | 4 +--- src/internal/internal_gescale_row_col.cc | 5 +---- src/internal/internal_geset.cc | 8 ++------ src/internal/internal_henorm.cc | 10 +++++----- src/internal/internal_her2k.cc | 6 +++--- src/internal/internal_herk.cc | 6 +++--- src/internal/internal_synorm.cc | 10 +++++----- src/internal/internal_syr2k.cc | 6 +++--- src/internal/internal_syrk.cc | 6 +++--- src/internal/internal_trmm.cc | 6 +++--- src/internal/internal_trnorm.cc | 10 +++++----- src/internal/internal_trsm.cc | 6 +++--- src/internal/internal_tzadd.cc | 6 +++--- src/internal/internal_tzcopy.cc | 8 ++++---- src/internal/internal_tzscale.cc | 4 +--- src/internal/internal_tzset.cc | 4 +--- 21 files changed, 64 insertions(+), 77 deletions(-) diff --git a/src/internal/internal_geadd.cc b/src/internal/internal_geadd.cc index 3bbd249e1..09d50afb8 100644 --- a/src/internal/internal_geadd.cc +++ b/src/internal/internal_geadd.cc @@ -167,9 +167,9 @@ void add(internal::TargetType, scalar_t** b_array_host = a_array_host + batch_size; auto group_params = device_regions_build( - {A, B}, - {a_array_host, b_array_host}, - device ); + {A, B}, + {a_array_host, b_array_host}, + device ); scalar_t** a_array_dev = B.array_device(device, queue_index); scalar_t** b_array_dev = a_array_dev + batch_size; diff --git a/src/internal/internal_gecopy.cc b/src/internal/internal_gecopy.cc index 2f334c30a..204ea5f76 100644 --- a/src/internal/internal_gecopy.cc +++ b/src/internal/internal_gecopy.cc @@ -195,10 +195,10 @@ void copy(internal::TargetType, ++batch_count; }; auto group_params = device_regions_build( - {B}, - {b_array_host}, - device, - setup_A ); + {B}, + {b_array_host}, + device, + setup_A ); // Usually the output matrix (B) provides all the batch arrays. // Here we are using A, because of the different types. 
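The irange/jrange vectors appearing throughout these hunks are simply boundary indices where the tile size changes, closed by the tile count. A self-contained toy (not SLATE code) showing how the two-level region loops in the norm reductions walk them:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main()
    {
        // Boundaries for a 5x5 tile grid: rows 0-3 share one mb, row 4 is a
        // cleanup tile; columns likewise.
        std::vector<int64_t> irange = { 0, 4, 5 };
        std::vector<int64_t> jrange = { 0, 4, 5 };

        for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
            for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
                // Every tile in region (ii, jj) has the same dimensions, so one
                // batched launch (or one set of loop bounds) covers all of them.
                for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j)
                    for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i)
                        std::printf( "region (%zu,%zu): tile (%lld,%lld)\n",
                                     ii, jj, (long long) i, (long long) j );
            }
        }
        return 0;
    }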
diff --git a/src/internal/internal_gemm.cc b/src/internal/internal_gemm.cc index a73e25f42..0fb97f710 100644 --- a/src/internal/internal_gemm.cc +++ b/src/internal/internal_gemm.cc @@ -485,9 +485,9 @@ void gemm(internal::TargetType, // C comes first since we do computation for a local C auto group_params = device_regions_build( - {C, A, B}, - {c_array_host, a_array_host, b_array_host}, - device ); + {C, A, B}, + {c_array_host, a_array_host, b_array_host}, + device ); if (C.op() != Op::NoTrans) { swap(opA, opB); diff --git a/src/internal/internal_gemmA.cc b/src/internal/internal_gemmA.cc index 1459a9d8f..a50de745a 100644 --- a/src/internal/internal_gemmA.cc +++ b/src/internal/internal_gemmA.cc @@ -437,9 +437,9 @@ void gemmA(internal::TargetType, auto B_j = B.sub( j, j, 0, 0 ); // A comes first since we do computation for a local A auto group_params = device_regions_build( - {A_j, B_j, C}, - {a_array_host, b_array_host, c_array_host}, - device ); + {A_j, B_j, C}, + {a_array_host, b_array_host, c_array_host}, + device ); trace::Block trace_block("blas::batch::gemm"); diff --git a/src/internal/internal_genorm.cc b/src/internal/internal_genorm.cc index ab5626a13..55b5534d9 100644 --- a/src/internal/internal_genorm.cc +++ b/src/internal/internal_genorm.cc @@ -442,11 +442,11 @@ void norm( scalar_t** a_array_host = A.array_host( device, queue_index ); auto group_params = device_regions_build( - {A}, - {a_array_host}, - device, - {}, - irange, jrange ); + {A}, + {a_array_host}, + device, + {}, + irange, jrange ); scalar_t** a_array_dev = A.array_device(device, queue_index); diff --git a/src/internal/internal_gescale.cc b/src/internal/internal_gescale.cc index f46d92278..0a0a44005 100644 --- a/src/internal/internal_gescale.cc +++ b/src/internal/internal_gescale.cc @@ -113,9 +113,7 @@ void scale(internal::TargetType, scalar_t** a_array_host = A.array_host( device, queue_index ); auto group_params = device_regions_build( - {A}, - {a_array_host}, - device ); + {A}, {a_array_host}, device ); blas::Queue* queue = A.compute_queue( device, queue_index ); diff --git a/src/internal/internal_gescale_row_col.cc b/src/internal/internal_gescale_row_col.cc index f092b9caa..6291f0222 100644 --- a/src/internal/internal_gescale_row_col.cc +++ b/src/internal/internal_gescale_row_col.cc @@ -140,10 +140,7 @@ void scale_row_col( ++batch_count; }; auto group_params = device_regions_build( - {A}, - {a_array_host}, - device, - store_rc ); + {A}, {a_array_host}, device, store_rc ); scalar_t** a_array_dev = A.array_device( device, queue_index ); diff --git a/src/internal/internal_geset.cc b/src/internal/internal_geset.cc index 5d044b4e9..1671ca87b 100644 --- a/src/internal/internal_geset.cc +++ b/src/internal/internal_geset.cc @@ -126,13 +126,9 @@ void set(internal::TargetType, auto group_params = diag_same ? 
device_regions_build( - {A}, - {a_array_host}, - device ) + {A}, {a_array_host}, device ) : device_regions_build( - {A}, - {a_array_host}, - device ); + {A}, {a_array_host}, device ); blas::Queue* queue = A.compute_queue( device, queue_index ); scalar_t** a_array_dev = A.array_device( device, queue_index ); diff --git a/src/internal/internal_henorm.cc b/src/internal/internal_henorm.cc index e3895c6ca..6c955096c 100644 --- a/src/internal/internal_henorm.cc +++ b/src/internal/internal_henorm.cc @@ -398,11 +398,11 @@ void norm( scalar_t** a_array_host = A.array_host( device, queue_index ); auto group_params = device_regions_build( - {A}, - {a_array_host}, - device, - {}, - ijrange, ijrange ); + {A}, + {a_array_host}, + device, + {}, + ijrange, ijrange ); scalar_t** a_array_dev = A.array_device(device, queue_index); diff --git a/src/internal/internal_her2k.cc b/src/internal/internal_her2k.cc index 65bf31f1b..1f6b195fd 100644 --- a/src/internal/internal_her2k.cc +++ b/src/internal/internal_her2k.cc @@ -596,9 +596,9 @@ void her2k(internal::TargetType, // C comes first since we do computation for a local C auto group_params = device_regions_build( - {C, A, AT, BT, B}, - {c_array_host, a_array_host, at_array_host, b_array_host, bt_array_host}, - device ); + {C, A, AT, BT, B}, + {c_array_host, a_array_host, at_array_host, b_array_host, bt_array_host}, + device ); if (C.op() != Op::NoTrans) { diff --git a/src/internal/internal_herk.cc b/src/internal/internal_herk.cc index 568c5354c..957263909 100644 --- a/src/internal/internal_herk.cc +++ b/src/internal/internal_herk.cc @@ -502,9 +502,9 @@ void herk(internal::TargetType, // C comes first since we do computation for a local C auto group_params = device_regions_build( - {C, A, AT}, - {c_array_host, a_array_host, b_array_host}, - device ); + {C, A, AT}, + {c_array_host, a_array_host, b_array_host}, + device ); if (C.op() != Op::NoTrans) { swap(opA, opB); diff --git a/src/internal/internal_synorm.cc b/src/internal/internal_synorm.cc index f1fd24985..4e16277f0 100644 --- a/src/internal/internal_synorm.cc +++ b/src/internal/internal_synorm.cc @@ -401,11 +401,11 @@ void norm(internal::TargetType, scalar_t** a_array_host = A.array_host( device, queue_index ); auto group_params = device_regions_build( - {A}, - {a_array_host}, - device, - {}, - ijrange, ijrange ); + {A}, + {a_array_host}, + device, + {}, + ijrange, ijrange ); scalar_t** a_array_dev = A.array_device(device, queue_index); diff --git a/src/internal/internal_syr2k.cc b/src/internal/internal_syr2k.cc index f5f90a2a9..28c607da1 100644 --- a/src/internal/internal_syr2k.cc +++ b/src/internal/internal_syr2k.cc @@ -576,9 +576,9 @@ void syr2k(internal::TargetType, // C comes first since we do computation for a local C auto group_params = device_regions_build( - {C, A, AT, BT, B}, - {c_array_host, a_array_host, at_array_host, b_array_host, bt_array_host}, - device ); + {C, A, AT, BT, B}, + {c_array_host, a_array_host, at_array_host, b_array_host, bt_array_host}, + device ); if (C.op() != Op::NoTrans) { diff --git a/src/internal/internal_syrk.cc b/src/internal/internal_syrk.cc index 0cf9dc585..f71c956b0 100644 --- a/src/internal/internal_syrk.cc +++ b/src/internal/internal_syrk.cc @@ -499,9 +499,9 @@ void syrk(internal::TargetType, // C comes first since we do computation for a local C auto group_params = device_regions_build( - {C, A, AT}, - {c_array_host, a_array_host, b_array_host}, - device ); + {C, A, AT}, + {c_array_host, a_array_host, b_array_host}, + device ); if (C.op() != Op::NoTrans) { 
swap(opA, opB); diff --git a/src/internal/internal_trmm.cc b/src/internal/internal_trmm.cc index 205a281c4..ea3400d41 100644 --- a/src/internal/internal_trmm.cc +++ b/src/internal/internal_trmm.cc @@ -231,9 +231,9 @@ void trmm(internal::TargetType, // B comes first since we do computation for a local B auto group_params = device_regions_build( - {B, A}, - {b_array_host, a_array_host}, - device ); + {B, A}, + {b_array_host, a_array_host}, + device ); { trace::Block trace_block("blas::batch::trmm"); diff --git a/src/internal/internal_trnorm.cc b/src/internal/internal_trnorm.cc index f4b7c1dbc..e00ecca07 100644 --- a/src/internal/internal_trnorm.cc +++ b/src/internal/internal_trnorm.cc @@ -423,11 +423,11 @@ void norm( scalar_t** a_array_host = A.array_host( device, queue_index ); auto group_params = device_regions_build( - {A}, - {a_array_host}, - device, - {}, - irange, jrange ); + {A}, + {a_array_host}, + device, + {}, + irange, jrange ); scalar_t** a_array_dev = A.array_device(device, queue_index); diff --git a/src/internal/internal_trsm.cc b/src/internal/internal_trsm.cc index b12359d1d..9019639e8 100644 --- a/src/internal/internal_trsm.cc +++ b/src/internal/internal_trsm.cc @@ -219,9 +219,9 @@ void trsm(internal::TargetType, // B comes first since we do computation for a local B auto group_params = device_regions_build( - {B, A}, - {b_array_host, a_array_host}, - device ); + {B, A}, + {b_array_host, a_array_host}, + device ); { trace::Block trace_block("blas::batch::trsm"); diff --git a/src/internal/internal_tzadd.cc b/src/internal/internal_tzadd.cc index 328ad0632..519e035d2 100644 --- a/src/internal/internal_tzadd.cc +++ b/src/internal/internal_tzadd.cc @@ -193,9 +193,9 @@ void add(internal::TargetType, scalar_t** b_array_host = a_array_host + batch_size; auto group_params = device_regions_build( - {A, B}, - {a_array_host, b_array_host}, - device ); + {A, B}, + {a_array_host, b_array_host}, + device ); scalar_t** a_array_dev = B.array_device( device, queue_index ); scalar_t** b_array_dev = a_array_dev + batch_size; diff --git a/src/internal/internal_tzcopy.cc b/src/internal/internal_tzcopy.cc index a848b29fa..e2a903aa9 100644 --- a/src/internal/internal_tzcopy.cc +++ b/src/internal/internal_tzcopy.cc @@ -166,10 +166,10 @@ void copy(internal::TargetType, ++batch_count; }; auto group_params = device_regions_build( - {B}, - {b_array_host}, - device, - setup_A ); + {B}, + {b_array_host}, + device, + setup_A ); // Usually the output matrix (B) provides all the batch arrays. // Here we are using A, because of the differen types. 
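For reference, the launch pattern these call sites feed into looks roughly like the following. This is a sketch assembled from the diffs in this series, with device::batch::some_kernel standing in for the real geset/gescale/norm kernels (it is a placeholder, not a SLATE symbol) and batch_size assumed to have been computed from the local tile set earlier:

    scalar_t** a_array_host = A.array_host( device, queue_index );
    scalar_t** a_array_dev  = A.array_device( device, queue_index );
    auto group_params = device_regions_build(
        {A}, {a_array_host}, device );

    blas::Queue* queue = A.compute_queue( device, queue_index );
    blas::device_memcpy( a_array_dev, a_array_host,
                         batch_size,                  // number of local tiles
                         blas::MemcpyKind::HostToDevice, *queue );

    // One batched launch per group of equally-sized tiles, advancing the
    // device pointer array past each group.
    for (size_t g = 0; g < group_params.size(); ++g) {
        int64_t group_count = group_params[ g ].count;
        device::batch::some_kernel(                   // placeholder name
            group_params[ g ].mb, group_params[ g ].nb,
            a_array_dev, group_params[ g ].ld[ 0 ],
            group_count, *queue );
        a_array_dev += group_count;
    }
    queue->sync();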
diff --git a/src/internal/internal_tzscale.cc b/src/internal/internal_tzscale.cc index 0d836417e..c06424521 100644 --- a/src/internal/internal_tzscale.cc +++ b/src/internal/internal_tzscale.cc @@ -146,9 +146,7 @@ void scale(internal::TargetType, scalar_t** a_array_host = A.array_host( device, queue_index ); auto group_params = device_regions_build( - {A}, - {a_array_host}, - device ); + {A}, {a_array_host}, device ); blas::Queue* queue = A.compute_queue( device, queue_index ); diff --git a/src/internal/internal_tzset.cc b/src/internal/internal_tzset.cc index 981c4dda1..2c73a1ae8 100644 --- a/src/internal/internal_tzset.cc +++ b/src/internal/internal_tzset.cc @@ -157,9 +157,7 @@ void set( scalar_t** a_array_dev = A.array_device( device ); auto group_params = device_regions_build( - {A}, - {a_array_host}, - device ); + {A}, {a_array_host}, device ); blas::Queue* queue = A.compute_queue(device, queue_index); From 1aa706483e86c96f9fa110af65101790f993f367 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 2 Nov 2023 08:48:31 -0400 Subject: [PATCH 26/35] Add device regions to trsmA --- src/internal/internal_trsmA.cc | 183 +++++++++++++++++---------------- 1 file changed, 96 insertions(+), 87 deletions(-) diff --git a/src/internal/internal_trsmA.cc b/src/internal/internal_trsmA.cc index 07b3d84c4..35f71f024 100644 --- a/src/internal/internal_trsmA.cc +++ b/src/internal/internal_trsmA.cc @@ -225,78 +225,87 @@ void trsmA(internal::TargetType, B.tileGetForWriting( B_tiles_set, device, LayoutConvert( layout ) ); - // interior col or row - std::vector a_array0; - std::vector b_array0; - a_array0.reserve( batch_size ); - b_array0.reserve( batch_size ); - - // bottom-right tile - // todo: replace batch trsm with plain trsm - std::vector a_array1; - std::vector b_array1; - - int64_t lda0 = 0; - int64_t ldb0 = 0; - int64_t lda1 = 0; - int64_t ldb1 = 0; - - int64_t mb0 = B.tileMb(0); - int64_t nb0 = B.tileNb(0); - int64_t mb1 = B.tileMb(B.mt()-1); - int64_t nb1 = B.tileNb(B.nt()-1); - - auto A00d = A( 0, 0, device ); - auto dAdata = A00d.data(); - lda1 = lda0 = A00d.stride(); + scalar_t** a_array_host = A.array_host(device, queue_index); + scalar_t** b_array_host = a_array_host + batch_size; + // Varient of device_regions_build to handle trsmA + using Params = device_regions_params; + + int64_t batch_count = 0; + std::vector group_params; if (side == Side::Right) { - // TODO loop over B_tiles_set instead of looking for again. - for (int64_t i = 0; i < B.mt()-1; ++i) { - if (B.tileExists( i, 0, device )) - { - auto Bi0d = B( i, 0, device ); - a_array0.push_back( dAdata ); - b_array0.push_back( Bi0d.data() ); - ldb0 = Bi0d.stride(); - } - } - { - int64_t i = B.mt()-1; - if (B.tileExists( i, 0, device )) - { - auto Bi0d = B( i, 0, device ); - a_array1.push_back( dAdata ); - b_array1.push_back( Bi0d.data() ); - ldb1 = Bi0d.stride(); + // Find ranges of matching mb's and ranges of matching nb's. 
+ auto irange = device_regions_range( true, B ); + + // loop over regions + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + // Loop over the tiles in this region, + // save any that should be computed on this process & device + Params group; + group.mb = B.tileMb( irange[ ii ] ); + group.nb = B.tileNb( 0 ); + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if (B.tileExists( i, 0, device )) { + + // Add tiles to current group + auto Aij = A( 0, 0, device ); + a_array_host[ batch_count ] = Aij.data(); + auto Bij = B( i, 0, device ); + b_array_host[ batch_count ] = Bij.data(); + if (group.count == 0) { + group.ld[0] = Aij.stride(); + group.ld[1] = Bij.stride(); + } + else { + assert( group.ld[0] == Aij.stride() ); + assert( group.ld[1] == Bij.stride() ); + } + ++group.count; + ++batch_count; + } + } // for i + // If any tiles in the region should be computed here, save the group + if (group.count > 0) { + group_params.push_back( group ); } - } + } // for ii } else { - for (int64_t j = 0; j < B.nt()-1; ++j) { - if (B.tileExists( 0, j, device )) - { - auto B0jd = B( 0, j, device ); - a_array0.push_back( dAdata ); - b_array0.push_back( B0jd.data() ); - ldb0 = B0jd.stride(); + // Find ranges of matching mb's and ranges of matching nb's. + auto jrange = device_regions_range( false, B ); + + // loop over regions + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + // Loop over the tiles in this region, + // save any that should be computed on this process & device + Params group; + group.mb = B.tileMb( 0 ); + group.nb = B.tileNb( jrange[ jj ] ); + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + if (B.tileExists( 0, j, device )) { + + // Add tiles to current group + auto Aij = A( 0, 0, device ); + a_array_host[ batch_count ] = Aij.data(); + auto Bij = B( 0, j, device ); + b_array_host[ batch_count ] = Bij.data(); + if (group.count == 0) { + group.ld[0] = Aij.stride(); + group.ld[1] = Bij.stride(); + } + else { + assert( group.ld[0] == Aij.stride() ); + assert( group.ld[1] == Bij.stride() ); + } + ++group.count; + ++batch_count; + } + } // for i + // If any tiles in the region should be computed here, save the group + if (group.count > 0) { + group_params.push_back( group ); } - } - { - int64_t j = B.nt()-1; - if (B.tileExists( 0, j, device )) - { - auto B0jd = B( 0, j, device ); - a_array1.push_back( dAdata ); - b_array1.push_back( B0jd.data() ); - ldb1 = B0jd.stride(); - } - } - } - - if (B.op() != Op::NoTrans) { - swap( mb0, nb0 ); - swap( mb1, nb1 ); + } // for ii } { @@ -311,35 +320,35 @@ void trsmA(internal::TargetType, blas::Queue* queue = A.compute_queue( device, queue_index ); assert( queue != nullptr ); + queue->sync(); - if (a_array0.size() > 0) { - std::vector m( 1, mb0 ); - std::vector n( 1, nb0 ); - std::vector lda( 1, lda0 ); - std::vector ldb( 1, ldb0 ); + for (size_t g = 0; g < group_params.size(); ++g) { - blas::batch::trsm( - layout, side_, uplo_, opA_, diag_, - m, n, - alpha_, a_array0, lda, - b_array0, ldb, - a_array0.size(), info, *queue); - } + int64_t group_count = group_params[ g ].count; - if (a_array1.size() > 0) { - std::vector m(1, mb1); - std::vector n(1, nb1); - std::vector lda(1, lda1); - std::vector ldb(1, ldb1); + std::vector m(1, group_params[ g ].mb); + std::vector n(1, group_params[ g ].nb); + std::vector ldda(1, group_params[ g ].ld[0]); + std::vector lddb(1, group_params[ g ].ld[1]); + + std::vector a_array(a_array_host, a_array_host+group_count); + std::vector b_array(b_array_host, b_array_host+group_count); + + if (B.op() != 
Op::NoTrans) { + swap(m, n); + } blas::batch::trsm( layout, side_, uplo_, opA_, diag_, m, n, - alpha_, a_array1, lda, - b_array1, ldb, - a_array1.size(), info, *queue); - } + alpha_, a_array, ldda, + b_array, lddb, + group_count, info, *queue); + queue->sync(); + a_array_host += group_count; + b_array_host += group_count; + } queue->sync(); } From 201e4b8c123a334e90f29f64db2a7747a2b9f2dd Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 2 Nov 2023 11:09:08 -0400 Subject: [PATCH 27/35] Add device regions to he2hb_gemm --- src/internal/internal_he2hb_gemm.cc | 182 +++++++++++++--------------- 1 file changed, 83 insertions(+), 99 deletions(-) diff --git a/src/internal/internal_he2hb_gemm.cc b/src/internal/internal_he2hb_gemm.cc index 0d3fd596d..46c769832 100644 --- a/src/internal/internal_he2hb_gemm.cc +++ b/src/internal/internal_he2hb_gemm.cc @@ -120,6 +120,8 @@ void he2hb_gemm( Op opA = A.op(); Op opB = B.op(); + scalar_t** host_work = A.array_host(device, queue_index); + for (int64_t k = 0; k < B.mt(); ++k) { std::set A_tiles_set, B_tiles_set, C_tiles_set; for (int64_t i = 0; i < A.mt(); ++i) { @@ -154,124 +156,106 @@ void he2hb_gemm( int64_t batch_size = C_tiles_set.size(); - int64_t i_interior = A.mt(); - int64_t i_last = 0; - int64_t mt = C.mt(); - - // check if there are cleanup tiles - if (C.tileMb( mt-1 ) != C.tileMb( 0 )) { - i_interior = A.mt() - 1; - i_last = 1; - } - - // interior - std::vector a_array00; - std::vector b_array00; - std::vector c_array00; - a_array00.reserve( batch_size ); - b_array00.reserve( batch_size ); - c_array00.reserve( batch_size ); - - int64_t lda00 = 0; - int64_t ldb00 = 0; - int64_t ldc00 = 0; - int64_t mb00 = C.tileMb( 0 ); - int64_t nb00 = C.tileNb( 0 ); - int64_t kb = A.tileNb( 0 ); - for (int64_t i = 0; i < i_interior; ++i) { - if (A.tileRank( i, k ) == panel_rank - && device == C.tileDevice( i, 0 )) { - a_array00.push_back( A( i, k, device ).data() ); - b_array00.push_back( B( k, 0, device ).data() ); - c_array00.push_back( C( i, 0, device ).data() ); - lda00 = A( i, k, device ).stride(); - ldb00 = B( k, 0, device ).stride(); - ldc00 = C( i, 0, device ).stride(); - } - } - - // if mod( n, nb ) != 0, this is for the last tile - std::vector a_array11; - std::vector b_array11; - std::vector c_array11; - a_array11.reserve( batch_size ); - b_array11.reserve( batch_size ); - c_array11.reserve( batch_size ); - - int64_t lda11 = 0; - int64_t ldb11 = 0; - int64_t ldc11 = 0; - int64_t mb11 = C.tileMb( C.mt()-1 ); - int64_t nb11 = C.tileNb( C.nt()-1 ); - // same kb as above - { - int i = C.mt()-1; - if ((A.tileRank( i, k ) == panel_rank) && (i_last == 1)) { - if (device == C.tileDevice( i, 0 )) { - a_array11.push_back( A( i, k, device ).data() ); - b_array11.push_back( B( k, 0, device ).data() ); - c_array11.push_back( C( i, 0, device ).data() ); - lda11 = A( i, k, device ).stride(); - ldb11 = B( k, 0, device ).stride(); - ldc11 = C( i, 0, device ).stride(); + scalar_t** a_array_host = host_work; + scalar_t** b_array_host = a_array_host + batch_size; + scalar_t** c_array_host = b_array_host + batch_size; + + // Varient of device_regions_build to handle trsmA + using Params = device_regions_params; + + // Find ranges of matching mb's and ranges of matching nb's. 
+ auto irange = device_regions_range( true, C ); + + // loop over regions + int64_t batch_count = 0; + std::vector group_params; + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + // Loop over the tiles in this region, + // save any that should be computed on this process & device + Params group; + group.mb = C.tileMb( irange[ ii ] ); + group.nb = C.tileNb( 0 ); + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if (A.tileRank( i, k ) == panel_rank + && device == C.tileDevice( i, 0 )) { + + // Add tiles to current group + auto Aij = A( i, k, device ); + a_array_host[ batch_count ] = Aij.data(); + auto Bij = B( k, 0, device ); + b_array_host[ batch_count ] = Bij.data(); + auto Cij = C( i, 0, device ); + c_array_host[ batch_count ] = Cij.data(); + if (group.count == 0) { + group.ld[0] = Aij.stride(); + group.ld[1] = Bij.stride(); + group.ld[2] = Cij.stride(); + } + else { + // default(none) doesn't allow asserts + //assert( group.ld[0] == Aij.stride() ); + //assert( group.ld[1] == Bij.stride() ); + //assert( group.ld[2] == Bij.stride() ); + } + ++group.count; + ++batch_count; } + } // for i + // If any tiles in the region should be computed here, save the group + if (group.count > 0) { + group_params.push_back( group ); } - } + } // for ii if (C.op() != Op::NoTrans) { // swap A <=> B; swap m <=> n swap( opA, opB ); - swap( a_array00, b_array00 ); - swap( lda00, ldb00 ); - swap( mb00, nb00 ); } { trace::Block trace_block( "blas::batch::he2hb_gemm" ); - std::vector opA_( 1, opA ); - std::vector opB_( 1, opB ); - std::vector alpha_( 1, alpha ); - std::vector beta_( 1, beta ); - std::vector kb_( 1, kb ); + std::vector opA_(1, opA); + std::vector opB_(1, opB); + std::vector alpha_(1, alpha); + std::vector beta_(1, beta); + std::vector kb(1, A.tileNb(k)); // info size 0 disables slow checks in batched BLAS++. std::vector info; - blas::Queue* queue = C.compute_queue( device, queue_index ); - // assert conflicts with default(none) in old gcc. 
- //assert( queue != nullptr ); + blas::Queue* queue = C.compute_queue(device, queue_index); - if (c_array00.size() > 0) { - std::vector m( 1, mb00 ); - std::vector n( 1, nb00 ); - std::vector ldda( 1, lda00 ); - std::vector lddb( 1, ldb00 ); - std::vector lddc( 1, ldc00 ); - blas::batch::gemm( - layout, opA_, opB_, - //m, n, kb_, - m, n, lddb, - alpha_, a_array00, ldda, - b_array00, lddb, - beta_, c_array00, lddc, - c_array00.size(), info, *queue ); - } + for (size_t g = 0; g < group_params.size(); ++g) { - if (c_array11.size() > 0) { - std::vector m( 1, mb11 ); - std::vector n( 1, nb11 ); - std::vector ldda( 1, lda11 ); - std::vector lddb( 1, ldb11 ); - std::vector lddc( 1, ldc11 ); + int64_t group_count = group_params[ g ].count; + + std::vector m(1, group_params[ g ].mb); + std::vector n(1, group_params[ g ].nb); + std::vector ldda(1, group_params[ g ].ld[0]); + std::vector lddb(1, group_params[ g ].ld[1]); + std::vector lddc(1, group_params[ g ].ld[2]); + + std::vector a_array(a_array_host, a_array_host+group_count); + std::vector b_array(b_array_host, b_array_host+group_count); + std::vector c_array(c_array_host, c_array_host+group_count); + + if (C.op() != Op::NoTrans) { + swap(m, n); + swap(a_array, b_array); + swap(ldda, lddb); + } blas::batch::gemm( layout, opA_, opB_, - //m, n, kb_, - m, n, lddb, - alpha_, a_array11, ldda, - b_array11, lddb, - beta_, c_array11, lddc, - c_array11.size(), info, *queue ); + m, n, kb, + alpha_, a_array, ldda, + b_array, lddb, + beta_, c_array, lddc, + group_count, info, *queue); + + a_array_host += group_count; + b_array_host += group_count; + c_array_host += group_count; } queue->sync(); } From 0d2305d82c55972132b39024067a0859101d36c4 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 2 Nov 2023 12:37:42 -0400 Subject: [PATCH 28/35] Clean up some he2hb helper functions --- .../internal_he2hb_her2k_offdiag_ranks.cc | 39 ++-- src/internal/internal_he2hb_trmm.cc | 198 ++++++++---------- 2 files changed, 107 insertions(+), 130 deletions(-) diff --git a/src/internal/internal_he2hb_her2k_offdiag_ranks.cc b/src/internal/internal_he2hb_her2k_offdiag_ranks.cc index 3c7f457a4..982edf55d 100644 --- a/src/internal/internal_he2hb_her2k_offdiag_ranks.cc +++ b/src/internal/internal_he2hb_her2k_offdiag_ranks.cc @@ -161,7 +161,7 @@ void he2hb_her2k_offdiag_ranks( } } else { // i == j - // Diagonal tiles dealt with above. + // Diagonal tiles dealt with elsewhere in he2hb. // assert conflicts with default(none) in old gcc. //assert( ! 
C.tileIsLocal( i, j ) ); } @@ -169,23 +169,6 @@ void he2hb_her2k_offdiag_ranks( } int64_t batch_size = C_tiles_set.size(); - int64_t j_interior = nt; - int64_t j_last = 0; - if (C.tileNb( nt-1 ) != C.tileNb( 0 )) { - j_interior = C.nt() - 1; - j_last = 1; - } - - int64_t m_indices = panel_rank_rows.size(); - int64_t i_interior = m_indices; - int64_t i_last = 0; - int64_t i0 = panel_rank_rows[ 0 ]; - int64_t i1 = panel_rank_rows[ m_indices-1 ]; - if (C.tileMb( i0 ) != C.tileMb( i1 )) { - i_interior = m_indices - 1; - i_last = 1; - } - #pragma omp taskgroup { #pragma omp task slate_omp_default_none \ @@ -208,6 +191,24 @@ void he2hb_her2k_offdiag_ranks( } } + + int64_t j_interior = nt; + int64_t j_last = 0; + if (C.tileNb( nt-1 ) != C.tileNb( 0 )) { + j_interior = C.nt() - 1; + j_last = 1; + } + + int64_t m_indices = panel_rank_rows.size(); + int64_t i_interior = m_indices; + int64_t i_last = 0; + int64_t i0 = panel_rank_rows[ 0 ]; + int64_t i1 = panel_rank_rows[ m_indices-1 ]; + if (C.tileMb( i0 ) != C.tileMb( i1 )) { + i_interior = m_indices - 1; + i_last = 1; + } + // interior std::vector a_array00; std::vector b_array00; @@ -249,7 +250,7 @@ void he2hb_her2k_offdiag_ranks( } } else { // i == j - // Diagonal tiles dealt with above. + // Diagonal tiles dealt with elsewhere in he2hb. // assert conflicts with default(none) in old gcc. //assert( ! C.tileIsLocal( i, j ) ); } diff --git a/src/internal/internal_he2hb_trmm.cc b/src/internal/internal_he2hb_trmm.cc index 91862b5f6..fe9c1807f 100644 --- a/src/internal/internal_he2hb_trmm.cc +++ b/src/internal/internal_he2hb_trmm.cc @@ -11,6 +11,27 @@ namespace slate { namespace internal { +template +bool need_Bi0(HermitianMatrix AH, + int mpi_rank, + int64_t i, + std::vector& panel_rank_rows) +{ + for (int64_t j : panel_rank_rows) { + if (i >= j) { // lower + if (AH.tileRank( i, j ) == mpi_rank) { + return true; + } + } + else { + if (AH.tileRank( j, i ) == mpi_rank) { + return true; + } + } + } + return false; +} + //------------------------------------------------------------------------------ /// Triangular matrix multiply. Compute B = B A /// AH is a Hermitian matrix. It needed here just to check if the rank is an @@ -23,6 +44,9 @@ namespace internal { /// T = A[ 0:A.mb(), 0:A.mb() ] is upper triangular, /// Bi = Bi[ 0:B.mb(), 0:A.mb() ]. Call trmm Bi = Bi T. /// Dispatches to target implementations. +/// +/// panel_rank_rows contains the local row-indices of B +/// /// @ingroup heev_internal /// template @@ -58,39 +82,34 @@ void he2hb_trmm( const Layout layout = Layout::ColMajor; const LayoutConvert layoutc = LayoutConvert( layout ); + if (panel_rank_rows.size() == 0) { + return; + } + auto A0 = A.sub( 0, 0, 0, 0 ); + int64_t mb = A0.tileMb( 0 ); + int64_t nb = A0.tileNb( 0 ); + bool trapezoid = (mb < nb); + if (trapezoid) { + A0 = A0.slice( 0, mb-1, 0, mb-1 ); // first mb-by-mb part + } #pragma omp taskgroup for (int64_t i = 0; i < B.mt(); ++i) { #pragma omp task slate_omp_default_none \ shared( A0, AH, B, panel_rank_rows ) \ - firstprivate( one, i, mpi_rank, layoutc ) \ + firstprivate( one, i, mpi_rank, layoutc, mb, trapezoid ) \ priority( priority ) { - int rank_lower = -1; - int rank_upper = -1; - for (int64_t j : panel_rank_rows) { - if (i >= j) { // lower - rank_lower = AH.tileRank( i, j ); - } - else { // upper - rank_upper = AH.tileRank( j, i ); - } - } // If I contributed to Bi, multiply by A. 
- if (rank_upper == mpi_rank || rank_lower == mpi_rank) { + if (need_Bi0( AH, mpi_rank, i, panel_rank_rows )) { // Bi = Bi * A auto Bi = B.sub( i, i, 0, 0 ); - int64_t mb = A0.tileMb( 0 ); - int64_t nb = A0.tileNb( 0 ); - bool trapezoid = (mb < nb); - B.tileGetForWriting( i, 0, layoutc ); if (trapezoid) { auto B00 = Bi( 0, 0 ); int64_t mb1 = B00.mb(); - A0 = A0.slice( 0, mb-1, 0, mb-1 ); // first mb-by-mb part Bi = Bi.slice( 0, mb1-1, 0, mb-1 ); // first mb1-by-mb part } @@ -126,6 +145,10 @@ void he2hb_trmm( const Layout layout = Layout::ColMajor; const LayoutConvert layoutc = LayoutConvert( layout ); + if (panel_rank_rows.size() == 0) { + return; + } + #pragma omp taskgroup for (int device = 0; device < B.num_devices(); ++device) { #pragma omp task slate_omp_default_none \ @@ -134,23 +157,11 @@ void he2hb_trmm( priority( priority ) { std::set B_tiles_set, A0_tiles_set; - int rank_lower = -1; - int rank_upper = -1; for (int64_t i = 0; i < B.mt(); ++i) { - for (int64_t j : panel_rank_rows) { - if (i >= j) { // lower - rank_lower = AH.tileRank( i, j ); - } - else { // upper - rank_upper = AH.tileRank( j, i ); - } - } - - if (rank_upper == mpi_rank || rank_lower == mpi_rank) { - if (device == B.tileDevice( i, 0 )) { - B_tiles_set.insert( { i, 0 } ); - } + if (need_Bi0( AH, mpi_rank, i, panel_rank_rows ) + && device == B.tileDevice( i, 0 )) { + B_tiles_set.insert( { i, 0 } ); } } @@ -189,79 +200,55 @@ void he2hb_trmm( int64_t mb1 = B.tileMb( B.mt()-1 ); int64_t nb1 = B.tileNb( B.mt()-1 ); - rank_lower = -1; - rank_upper = -1; + int64_t mb = A0.tileMb( 0 ); + int64_t nb = A0.tileNb( 0 ); + bool trapezoid = (mb < nb); + if (trapezoid) { + A0 = A0.slice( 0, mb-1, 0, mb-1 ); // first mb-by-mb part + } + auto T = TriangularMatrix( Uplo::Upper, Diag::NonUnit, A0 ); for (int64_t i = 0; i < i_interior; ++i) { - for (int64_t j : panel_rank_rows) { - if (i >= j) { // lower - rank_lower = AH.tileRank( i, j ); - } - else { // upper - rank_upper = AH.tileRank( j, i ); - } - } - A0 = A.sub( 0, 0, 0, 0 ); - int64_t mb = A0.tileMb( 0 ); - int64_t nb = A0.tileNb( 0 ); - auto Bi = B.sub( i, i, 0, 0 ); - bool trapezoid = (mb < nb); - - if (trapezoid) { - auto B00 = Bi( 0, 0 ); - mb1 = B00.mb(); - A0 = A0.slice( 0, mb-1, 0, mb-1 ); // first mb-by-mb part - Bi = Bi.slice( 0, mb1-1, 0, mb-1 ); // first mb1-by-mb part - } - auto T = TriangularMatrix( Uplo::Upper, Diag::NonUnit, A0 ); - - if (rank_upper == mpi_rank || rank_lower == mpi_rank) { - if (device == B.tileDevice( i, 0 )) { - a_array0.push_back( T( 0, 0, device ).data() ); - b_array0.push_back( Bi( 0, 0, device ).data() ); - //b_array0.push_back( B( i, 0, device ).data() ); - lda0 = A0( 0, 0, device ).stride(); - ldb0 = Bi( 0, 0, device ).stride(); - mb0 = Bi.tileMb( 0 ); - nb0 = Bi.tileNb( 0 ); + if (need_Bi0( AH, mpi_rank, i, panel_rank_rows ) + && device == B.tileDevice( i, 0 )) { + + auto Bi = B.sub( i, i, 0, 0 ); + + if (trapezoid) { + auto B00 = Bi( 0, 0 ); + mb1 = B00.mb(); + Bi = Bi.slice( 0, mb1-1, 0, mb-1 ); // first mb1-by-mb part } + + a_array0.push_back( T( 0, 0, device ).data() ); + b_array0.push_back( Bi( 0, 0, device ).data() ); + //b_array0.push_back( B( i, 0, device ).data() ); + lda0 = A0( 0, 0, device ).stride(); + ldb0 = Bi( 0, 0, device ).stride(); + mb0 = Bi.tileMb( 0 ); + nb0 = Bi.tileNb( 0 ); } } if (i_last == 1) { int64_t i = B.mt()-1; - rank_lower = -1; - rank_upper = -1; - for (int64_t j : panel_rank_rows) { - if (i >= j) { // lower - rank_lower = AH.tileRank( i, j ); - } - else { // upper - rank_upper = AH.tileRank( j, i ); - } - 
} - A0 = A.sub( 0, 0, 0, 0 ); - int64_t mb = A0.tileMb( 0 ); - int64_t nb = A0.tileNb( 0 ); - auto Bi = B.sub( i, i, 0, 0 ); - bool trapezoid = (mb < nb); - - if (trapezoid) { - auto B00 = Bi( 0, 0 ); - mb1 = B00.mb(); - A0 = A0.slice( 0, mb-1, 0, mb-1 ); // first mb-by-mb part - Bi = Bi.slice( 0, mb1-1, 0, mb-1 ); // first mb1-by-mb part - } - auto T = TriangularMatrix( Uplo::Upper, Diag::NonUnit, A0 ); - if (rank_upper == mpi_rank || rank_lower == mpi_rank) { - if (device == B.tileDevice( i, 0 )) { - a_array1.push_back( T( 0, 0, device ).data() ); - b_array1.push_back( Bi( 0, 0, device ).data() ); - lda1 = T( 0, 0, device ).stride(); - ldb1 = Bi( 0, 0, device ).stride(); - mb1 = Bi.tileMb( 0 ); - nb1 = Bi.tileNb( 0 ); + if (need_Bi0( AH, mpi_rank, i, panel_rank_rows ) + && device == B.tileDevice( i, 0 )) { + + auto Bi = B.sub( i, i, 0, 0 ); + + if (trapezoid) { + auto B00 = Bi( 0, 0 ); + mb1 = B00.mb(); + Bi = Bi.slice( 0, mb1-1, 0, mb-1 ); // first mb1-by-mb part } + + a_array1.push_back( T( 0, 0, device ).data() ); + b_array1.push_back( Bi( 0, 0, device ).data() ); + lda1 = T( 0, 0, device ).stride(); + ldb1 = Bi( 0, 0, device ).stride(); + mb1 = Bi.tileMb( 0 ); + nb1 = Bi.tileNb( 0 ); } } @@ -314,23 +301,12 @@ void he2hb_trmm( } // todo: release tiles in top-level routine. - // rank_lower = -1; - // rank_upper = -1; // for (int64_t i = 0; i < B.mt(); ++i) { - // for (int64_t j : panel_rank_rows) { - // if (i >= j) { // lower - // rank_lower = AH.tileRank( i, j ); - // } - // else { // upper - // rank_upper = AH.tileRank( j, i ); - // } - // } + // if (need_Bi0( AH, mpi_rank, i, panel_rank_rows ) + // && device == B.tileDevice( i, 0 )) { // - // if (rank_upper == mpi_rank || rank_lower == mpi_rank) { - // if (device == B.tileDevice( i, 0 )) { - // B.tileRelease( i, 0, device ); - // B.tileTick( i, 0 ); - // } + // B.tileRelease( i, 0, device ); + // B.tileTick( i, 0 ); // } // } } From 646d46a3664d5fc98f0a4de43a94ee1b66ac6823 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 2 Nov 2023 13:03:08 -0400 Subject: [PATCH 29/35] Add device regions to he2hb trmm --- src/internal/internal_he2hb_trmm.cc | 146 +++++++++++++--------------- 1 file changed, 66 insertions(+), 80 deletions(-) diff --git a/src/internal/internal_he2hb_trmm.cc b/src/internal/internal_he2hb_trmm.cc index fe9c1807f..0f7b210ca 100644 --- a/src/internal/internal_he2hb_trmm.cc +++ b/src/internal/internal_he2hb_trmm.cc @@ -7,6 +7,7 @@ #include "slate/HermitianMatrix.hh" #include "slate/types.hh" #include "internal/internal.hh" +#include "internal/internal_batch.hh" namespace slate { namespace internal { @@ -165,14 +166,6 @@ void he2hb_trmm( } } - int64_t i_interior = B.mt(); - int64_t i_last = 0; - int64_t mt = B.mt(); - if (B.tileMb( mt-1 ) != B.tileMb( 0 )) { - i_interior = B.mt() - 1; - i_last = 1; - } - int64_t batch_size = B_tiles_set.size(); if (batch_size > 0) { @@ -190,16 +183,6 @@ void he2hb_trmm( std::vector a_array1; std::vector b_array1; - int64_t lda0 = 0; - int64_t ldb0 = 0; - int64_t lda1 = 0; - int64_t ldb1 = 0; - - int64_t mb0 = B.tileMb( 0 ); - int64_t nb0 = B.tileNb( 0 ); - int64_t mb1 = B.tileMb( B.mt()-1 ); - int64_t nb1 = B.tileNb( B.mt()-1 ); - int64_t mb = A0.tileMb( 0 ); int64_t nb = A0.tileNb( 0 ); bool trapezoid = (mb < nb); @@ -208,49 +191,57 @@ void he2hb_trmm( } auto T = TriangularMatrix( Uplo::Upper, Diag::NonUnit, A0 ); - for (int64_t i = 0; i < i_interior; ++i) { - if (need_Bi0( AH, mpi_rank, i, panel_rank_rows ) - && device == B.tileDevice( i, 0 )) { - - auto Bi = B.sub( i, i, 0, 0 ); 
- - if (trapezoid) { - auto B00 = Bi( 0, 0 ); - mb1 = B00.mb(); - Bi = Bi.slice( 0, mb1-1, 0, mb-1 ); // first mb1-by-mb part - } - - a_array0.push_back( T( 0, 0, device ).data() ); - b_array0.push_back( Bi( 0, 0, device ).data() ); - //b_array0.push_back( B( i, 0, device ).data() ); - lda0 = A0( 0, 0, device ).stride(); - ldb0 = Bi( 0, 0, device ).stride(); - mb0 = Bi.tileMb( 0 ); - nb0 = Bi.tileNb( 0 ); - } - } - - if (i_last == 1) { - int64_t i = B.mt()-1; - if (need_Bi0( AH, mpi_rank, i, panel_rank_rows ) - && device == B.tileDevice( i, 0 )) { - - auto Bi = B.sub( i, i, 0, 0 ); - - if (trapezoid) { - auto B00 = Bi( 0, 0 ); - mb1 = B00.mb(); - Bi = Bi.slice( 0, mb1-1, 0, mb-1 ); // first mb1-by-mb part + scalar_t** t_array_host = B.array_host(device, queue_index); + scalar_t** b_array_host = t_array_host + batch_size; + + // Varient of device_regions_build to handle trsmA + using Params = device_regions_params; + + // Find ranges of matching mb's and ranges of matching nb's. + auto irange = device_regions_range( true, B ); + + // loop over regions + int64_t batch_count = 0; + std::vector group_params; + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + // Loop over the tiles in this region, + // save any that should be computed on this process & device + Params group; + group.mb = B.tileMb( irange[ ii ] ); + group.nb = T.tileMb( 0 ); + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if (need_Bi0( AH, mpi_rank, i, panel_rank_rows ) + && device == B.tileDevice( i, 0 )) { + + // Add tiles to current group + auto Bi = B.sub( i, i, 0, 0 ); + if (trapezoid) { + auto B00 = Bi( 0, 0 ); + int64_t mb1 = B00.mb(); + Bi = Bi.slice( 0, mb1-1, 0, mb-1 ); // first mb1-by-mb part + } + + auto Tij = T( 0, 0, device ); + t_array_host[ batch_count ] = Tij.data(); + auto Bij = Bi( 0, 0, device ); + b_array_host[ batch_count ] = Bij.data(); + if (group.count == 0) { + group.ld[0] = Tij.stride(); + group.ld[1] = Bij.stride(); + } + else { + //assert( group.ld[0] == Tij.stride() ); + //assert( group.ld[1] == Bij.stride() ); + } + ++group.count; + ++batch_count; } - - a_array1.push_back( T( 0, 0, device ).data() ); - b_array1.push_back( Bi( 0, 0, device ).data() ); - lda1 = T( 0, 0, device ).stride(); - ldb1 = Bi( 0, 0, device ).stride(); - mb1 = Bi.tileMb( 0 ); - nb1 = Bi.tileNb( 0 ); + } // for i + // If any tiles in the region should be computed here, save the group + if (group.count > 0) { + group_params.push_back( group ); } - } + } // for ii { trace::Block trace_block( "blas::batch::he2hb_trmm" ); @@ -270,31 +261,26 @@ void he2hb_trmm( std::vector alpha_( 1, alpha ); std::vector info; - if (b_array0.size() > 0) { - std::vector m( 1, mb0 ); - std::vector n( 1, nb0 ); - std::vector ldb( 1, ldb0 ); - std::vector lda( 1, lda0 ); - blas::batch::trmm( - layout, side_, uplo_, opA_, diag_, - m, n, - alpha_, a_array0, lda, - b_array0, ldb, - a_array0.size(), info, *queue ); - } + for (size_t g = 0; g < group_params.size(); ++g) { + + int64_t group_count = group_params[ g ].count; + std::vector m( 1, group_params[g].mb ); + std::vector n( 1, group_params[g].nb ); + std::vector ldda( 1, group_params[g].ld[0] ); + std::vector lddb( 1, group_params[g].ld[1] ); + + std::vector t_array(t_array_host, t_array_host+group_count); + std::vector b_array(b_array_host, b_array_host+group_count); - if (b_array1.size() > 0) { - std::vector m( 1, mb1 ); - std::vector n( 1, nb1 ); - std::vector lda( 1, lda1 ); - std::vector ldb( 1, ldb1 ); blas::batch::trmm( layout, side_, uplo_, opA_, diag_, m, n, - //m, lda, - 
alpha_, a_array1, lda, - b_array1, ldb, - a_array1.size(), info, *queue ); + alpha_, t_array, ldda, + b_array, lddb, + group_count, info, *queue ); + + t_array_host += group_count; + b_array_host += group_count; } queue->sync(); From c96a391940efecc529163955949eaae175e7ebec Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 2 Nov 2023 14:19:20 -0400 Subject: [PATCH 30/35] Add device regions to he2hb_her2k_offdiag --- .../internal_he2hb_her2k_offdiag_ranks.cc | 364 ++++++------------ src/internal/internal_her2k.cc | 2 +- 2 files changed, 121 insertions(+), 245 deletions(-) diff --git a/src/internal/internal_he2hb_her2k_offdiag_ranks.cc b/src/internal/internal_he2hb_her2k_offdiag_ranks.cc index 982edf55d..506ee9e8c 100644 --- a/src/internal/internal_he2hb_her2k_offdiag_ranks.cc +++ b/src/internal/internal_he2hb_her2k_offdiag_ranks.cc @@ -6,6 +6,7 @@ #include "slate/Matrix.hh" #include "slate/types.hh" #include "internal/internal.hh" +#include "internal/internal_batch.hh" namespace slate { namespace internal { @@ -191,214 +192,123 @@ void he2hb_her2k_offdiag_ranks( } } + scalar_t** a_array_host = C.array_host(device, queue_index); + scalar_t** b_array_host = a_array_host + batch_size; + scalar_t** c_array_host = b_array_host + batch_size; - int64_t j_interior = nt; - int64_t j_last = 0; - if (C.tileNb( nt-1 ) != C.tileNb( 0 )) { - j_interior = C.nt() - 1; - j_last = 1; - } + using Params = device_regions_params; - int64_t m_indices = panel_rank_rows.size(); - int64_t i_interior = m_indices; - int64_t i_last = 0; - int64_t i0 = panel_rank_rows[ 0 ]; - int64_t i1 = panel_rank_rows[ m_indices-1 ]; - if (C.tileMb( i0 ) != C.tileMb( i1 )) { - i_interior = m_indices - 1; - i_last = 1; - } + // Find ranges of matching mb's and ranges of matching nb's. + auto jrange = device_regions_range( false, C ); - // interior - std::vector a_array00; - std::vector b_array00; - std::vector c_array00; - a_array00.reserve( batch_size ); - b_array00.reserve( batch_size ); - c_array00.reserve( batch_size ); - - int64_t lda00 = 0; - int64_t ldb00 = 0; - int64_t ldc00 = 0; - int64_t mb00 = C.tileMb( i0 ); - int64_t nb00 = C.tileNb( 0 ); - int64_t kb = A.tileNb( 0 ); - - for (int64_t j = 0; j < j_interior; ++j) { - for (int64_t i_ = 0; i_ < i_interior; ++i_) { - int64_t i = panel_rank_rows[ i_ ]; - if (i > j) { - if (C.tileIsLocal( i, j ) - && device == C.tileDevice( i, j )) { - a_array00.push_back( A( i, 0, device ).data() ); - b_array00.push_back( B( j, 0, device ).data() ); - c_array00.push_back( C( i, j, device ).data() ); - lda00 = A( i, 0, device ).stride(); - ldb00 = B( j, 0, device ).stride(); - ldc00 = C( i, j, device ).stride(); - } - } - else if (i < j) { - if (C.tileIsLocal( j, i ) - && device == C.tileDevice( j, i )) { - a_array00.push_back( B( j, 0, device ).data() ); - b_array00.push_back( A( i, 0, device ).data() ); - c_array00.push_back( C( j, i, device ).data() ); - lda00 = B( j, 0, device ).stride(); - ldb00 = A( i, 0, device ).stride(); - ldc00 = C( j, i, device ).stride(); - } - } - else { // i == j - // Diagonal tiles dealt with elsewhere in he2hb. - // assert conflicts with default(none) in old gcc. - //assert( ! 
C.tileIsLocal( i, j ) ); - } + std::vector< int64_t > irange; + int64_t last_ij = -1; + for (int64_t k = 0; k < int64_t(panel_rank_rows.size()); ++k) { + int64_t kb = panel_rank_rows[ k ]; + if (kb != last_ij) { + last_ij = kb; + irange.push_back( k ); } } - - // last column if there is a clean-up tile - std::vector a_array01; - std::vector b_array01; - std::vector c_array01; - a_array01.reserve( batch_size ); - b_array01.reserve( batch_size ); - c_array01.reserve( batch_size ); - - int64_t lda01 = 0; - int64_t ldb01 = 0; - int64_t ldc01 = 0; - int64_t mb01 = C.tileMb( i0 ); - int64_t nb01 = C.tileNb( nt-1 ); - - if (j_last == 1) { - //for (int64_t j = 0; j < nt; ++j) { - int64_t j = C.nt()-1; - //for (int64_t i : panel_rank_rows) { - for (int64_t i_ = 0; i_ < i_interior; ++i_) { + irange.push_back( panel_rank_rows.size() ); + + int64_t batch_count = 0; + std::vector group_params; + // loop over regions + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + // Loop over the tiles in this region, + // save any that should be computed on this process & device + // Two groups are needed to handle the different sizes + Params group; + group.mb = C.tileMb( panel_rank_rows[ irange[ ii ] ] ); + group.nb = C.tileNb( jrange[ jj ] ); + + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i_ = irange[ ii ]; i_ < irange[ ii+1 ]; ++i_) { int64_t i = panel_rank_rows[ i_ ]; - if (i > j) { - if (C.tileIsLocal( i, j ) - && device == C.tileDevice( i, j )) { - a_array01.push_back( A( i, 0, device ).data() ); - b_array01.push_back( B( j, 0, device ).data() ); - c_array01.push_back( C( i, j, device ).data() ); - lda01 = A( i, 0, device ).stride(); - ldb01 = B( j, 0, device ).stride(); - ldc01 = C( i, j, device ).stride(); + + if ((i > j) + && C.tileIsLocal( i, j ) && device == C.tileDevice( i, j )) { + + // Add tiles to current group + auto Aij = A( i, 0, device ); + auto Bij = B( j, 0, device ); + auto Cij = C( i, j, device ); + + a_array_host[ batch_count ] = Aij.data(); + b_array_host[ batch_count ] = Bij.data(); + c_array_host[ batch_count ] = Cij.data(); + + if (group.count == 0) { + group.ld[0] = Aij.stride(); + group.ld[1] = Bij.stride(); + group.ld[2] = Cij.stride(); } - } - else if (i < j) { - if (C.tileIsLocal( j, i ) - && device == C.tileDevice( j, i )) { - a_array01.push_back( B( j, 0, device ).data() ); - b_array01.push_back( A( i, 0, device ).data() ); - c_array01.push_back( C( j, i, device ).data() ); - mb01 = C.tileNb( nt-1 ); - nb01 = C.tileMb( i0 ); - lda01 = B( j, 0, device ).stride(); - ldb01 = A( i, 0, device ).stride(); - ldc01 = C( j, i, device ).stride(); + else { + // assert( group.ld[0] == Aij.stride() ); + // assert( group.ld[1] == Bij.stride() ); + // assert( group.ld[2] == Cij.stride() ); } + + ++group.count; + ++batch_count; } - else { // i == j - // assert conflicts with default(none) in old gcc. - //assert( ! 
C.tileIsLocal( i, j ) ); + }} // for j, i + + // If mb != nb, we need to start a new group for the upper + // triangular logic + // If the problem is square, we can use a single group for + // better parallelism + if (group.mb != group.nb) { + // If any tiles in the region should be computed here, save the group + if (group.count > 0) { + group_params.push_back( group ); } + + std::swap( group.mb, group.nb ); + group.count = 0; + group.ld[0] = 0; + group.ld[1] = 0; + group.ld[2] = 0; } - } - // last row if there is a clean-up tile - std::vector a_array10; - std::vector b_array10; - std::vector c_array10; - a_array10.reserve( batch_size ); - b_array10.reserve( batch_size ); - c_array10.reserve( batch_size ); - - int64_t lda10 = 0; - int64_t ldb10 = 0; - int64_t ldc10 = 0; - int64_t mb10 = C.tileMb( i1 ); - int64_t nb10 = C.tileNb( 0 ); - - if (i_last == 1) { - int64_t i = i1; - for (int64_t j = 0; j < j_interior; ++j) { - if (i > j) { - if (C.tileIsLocal( i, j ) - && device == C.tileDevice( i, j )) { - a_array10.push_back( A( i, 0, device ).data() ); - b_array10.push_back( B( j, 0, device ).data() ); - c_array10.push_back( C( i, j, device ).data() ); - lda10 = A( i, 0, device ).stride(); - ldb10 = B( j, 0, device ).stride(); - ldc10 = C( i, j, device ).stride(); + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + for (int64_t i_ = irange[ ii ]; i_ < irange[ ii+1 ]; ++i_) { + int64_t i = panel_rank_rows[ i_ ]; + + if ((i < j) + && C.tileIsLocal( j, i ) && device == C.tileDevice( j, i )) { + + // Add tiles to current group + auto Aij = B( j, 0, device ); + auto Bij = A( i, 0, device ); + auto Cij = C( j, i, device ); + + a_array_host[ batch_count ] = Aij.data(); + b_array_host[ batch_count ] = Bij.data(); + c_array_host[ batch_count ] = Cij.data(); + + if (group.count == 0) { + group.ld[0] = Aij.stride(); + group.ld[1] = Bij.stride(); + group.ld[2] = Cij.stride(); } - } - else if (i < j) { - if (C.tileIsLocal( j, i ) - && device == C.tileDevice( j, i )) { - a_array10.push_back( B( j, 0, device ).data() ); - b_array10.push_back( A( i, 0, device ).data() ); - c_array10.push_back( C( j, i, device ).data() ); - mb10 = C.tileNb( 0 ); - nb10 = C.tileMb( i1 ); - lda10 = A( i, 0, device ).stride(); - ldb10 = B( j, 0, device ).stride(); - ldc10 = C( j, i, device ).stride(); + else { + // assert( group.ld[0] == Aij.stride() ); + // assert( group.ld[1] == Bij.stride() ); + // assert( group.ld[2] == Cij.stride() ); } - } - else { // i == j - // assert conflicts with default(none) in old gcc. - //assert( ! 
C.tileIsLocal( i, j ) ); - } - } - } - // bottom-right corner - std::vector a_array11; - std::vector b_array11; - std::vector c_array11; - - int64_t lda11 = 0; - int64_t ldb11 = 0; - int64_t ldc11 = 0; - int64_t mb11 = C.tileMb( i1 ); - int64_t nb11 = C.tileNb( nt-1 ); - - if (i_last == 1 && j_last == 1) { - int64_t i = i1; - int64_t j = nt-1; - if (i > j) { - if (C.tileIsLocal( i, j ) - && device == C.tileDevice( i, j)) { - a_array11.push_back( A( i, 0, device ).data() ); - b_array11.push_back( B( j, 0, device ).data() ); - c_array11.push_back( C( i, j, device ).data() ); - lda11 = A( i, 0, device ).stride(); - ldb11 = B( j, 0, device ).stride(); - ldc11 = C( i, j, device ).stride(); - } - } - else if (i < j) { - if (C.tileIsLocal( j, i ) - && device == C.tileDevice( j, i )) { - a_array11.push_back( B( j, 0, device ).data() ); - b_array11.push_back( A( i, 0, device ).data() ); - c_array11.push_back( C( j, i, device ).data() ); - mb11 = C.tileNb( nt-1 ); - nb11 = C.tileMb( i1 ); - lda11 = A( i, 0, device ).stride(); - ldb11 = B( j, 0, device ).stride(); - ldc11 = C( j, i, device ).stride(); + ++group.count; + ++batch_count; } + }} // for j, i + // If any tiles in the region should be computed here, save the group + if (group.count > 0) { + group_params.push_back( group ); } - else { // i == j - // assert conflicts with default(none) in old gcc. - //assert( ! C.tileIsLocal( i, j ) ); - } - } + }} // for jj, ii { trace::Block trace_block( "blas::batch::gemm" ); @@ -407,7 +317,7 @@ void he2hb_her2k_offdiag_ranks( std::vector opB_( 1, opB ); std::vector alpha_( 1, alpha ); std::vector beta_( 1, beta ); - std::vector kb_( 1, kb ); + std::vector kb_( 1, A.tileNb(0) ); // info size 0 disables slow checks in batched BLAS++. std::vector info; @@ -415,64 +325,30 @@ void he2hb_her2k_offdiag_ranks( // assert conflicts with default(none) in old gcc. 
//assert( queue != nullptr ); - if (c_array00.size() > 0) { - std::vector m( 1, mb00 ); - std::vector n( 1, nb00 ); - std::vector ldda( 1, lda00 ); - std::vector lddb( 1, ldb00 ); - std::vector lddc( 1, ldc00 ); - blas::batch::gemm( - layout, opA_, opB_, - m, n, kb_, - alpha_, a_array00, ldda, - b_array00, lddb, - beta_, c_array00, lddc, - c_array00.size(), info, *queue ); - } + for (size_t g = 0; g < group_params.size(); ++g) { - if (c_array01.size() > 0) { - std::vector m( 1, mb01 ); - std::vector n( 1, nb01 ); - std::vector ldda( 1, lda01 ); - std::vector lddb( 1, ldb01 ); - std::vector lddc( 1, ldc01 ); - blas::batch::gemm( - layout, opA_, opB_, - m, n, kb_, - alpha_, a_array01, ldda, - b_array01, lddb, - beta_, c_array01, lddc, - c_array01.size(), info, *queue ); - } + int64_t group_count = group_params[ g ].count; - if (c_array10.size() > 0) { - std::vector m( 1, mb10 ); - std::vector n( 1, nb10 ); - std::vector ldda( 1, lda10 ); - std::vector lddb( 1, ldb10 ); - std::vector lddc( 1, ldc10 ); - blas::batch::gemm( - layout, opA_, opB_, - m, n, kb_, - alpha_, a_array10, ldda, - b_array10, lddb, - beta_, c_array10, lddc, - c_array10.size(), info, *queue ); - } + std::vector m( 1, group_params[g].mb ); + std::vector n( 1, group_params[g].nb ); + std::vector ldda( 1, group_params[g].ld[0] ); + std::vector lddb( 1, group_params[g].ld[1] ); + std::vector lddc( 1, group_params[g].ld[2] ); + + std::vector a_array(a_array_host, a_array_host+group_count); + std::vector b_array(b_array_host, b_array_host+group_count); + std::vector c_array(c_array_host, c_array_host+group_count); - if (c_array11.size() > 0) { - std::vector m( 1, mb11 ); - std::vector n( 1, nb11 ); - std::vector ldda( 1, lda11 ); - std::vector lddb( 1, ldb11 ); - std::vector lddc( 1, ldc11 ); blas::batch::gemm( layout, opA_, opB_, m, n, kb_, - alpha_, a_array11, ldda, - b_array11, lddb, - beta_, c_array11, lddc, - c_array11.size(), info, *queue ); + alpha_, a_array, ldda, + b_array, lddb, + beta_, c_array, lddc, + group_count, info, *queue ); + a_array_host += group_count; + b_array_host += group_count; + c_array_host += group_count; } queue->sync(); } diff --git a/src/internal/internal_her2k.cc b/src/internal/internal_her2k.cc index 1f6b195fd..6cc382c5e 100644 --- a/src/internal/internal_her2k.cc +++ b/src/internal/internal_her2k.cc @@ -672,7 +672,7 @@ void her2k(internal::TargetType, m, n, k, conj_alpha_s, bt_array, lddbt, at_array, lddat, - one_, c_array, lddc, + one_, c_array, lddc, group_count, info, *queue); } a_array_host += group_count; From c20af461a731953d162319614921f7267341ad7c Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Tue, 7 Nov 2023 14:41:48 -0500 Subject: [PATCH 31/35] Relax some uniform tile size assumptions --- src/he2hb.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/he2hb.cc b/src/he2hb.cc index 6f63be04a..c55f15f18 100644 --- a/src/he2hb.cc +++ b/src/he2hb.cc @@ -72,7 +72,6 @@ void he2hb( const bool set_hold = 1; // Do tileGetAndHold in the bcast int64_t n = A.n(); - int64_t nb_A = A.tileNb( 0 ); GridOrder grid_order; int nprow, npcol, myrow, mycol; // For the workspace matrix (W) change the GPU distribution to row cyclic, @@ -82,8 +81,8 @@ void he2hb( A.gridinfo( &grid_order, &nprow, &npcol, &myrow, &mycol ); assert( grid_order == GridOrder::Col ); // todo: update for Row - auto tileNb = slate::func::uniform_blocksize( n, nb_A ); - auto tileRank = slate::func::process_2d_grid( GridOrder::Col, nprow, npcol ); + auto tileNb = A.tileNbFunc(); + auto tileRank = 
A.tileRankFunc(); int num_devices = blas::get_device_count(); auto tileDevice = slate::func::device_1d_grid( GridOrder::Col, nprow, num_devices ); @@ -135,7 +134,7 @@ void he2hb( lapack::Queue* queue = A.compute_queue( panel_device, queue_0 ); - int64_t nb = A.tileNb(0); + int64_t nb = func::max_blocksize(A.nt(), A.tileNbFunc()); size_t size_tau = (size_t) std::min( mlocal, nb ); size_t size_A = (size_t) blas::max( 1, mlocal ) * nb; size_t hsize, dsize; From 6dc04d2e16aa45e3c6aa5a7859c541d6fe4b94e2 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 16 Nov 2023 14:50:39 -0500 Subject: [PATCH 32/35] Improve docs --- src/internal/internal_batch.hh | 42 +++++++++------------------------- src/internal/internal_util.hh | 10 ++++++++ 2 files changed, 21 insertions(+), 31 deletions(-) diff --git a/src/internal/internal_batch.hh b/src/internal/internal_batch.hh index be71a04d5..d80f6279e 100644 --- a/src/internal/internal_batch.hh +++ b/src/internal/internal_batch.hh @@ -155,7 +155,7 @@ inline void cblas_gemm_batch( // Utilities for computing device batch regions //------------------------------------------------------------------------------ -/// Computes the range of tiles with either the same mb or the same nb +/// Computes the range of tiles with either the same mb or the same nb. /// /// @param[in] want_rows /// If true, compute the row-ranges. Else, compute the column-ranges. @@ -216,33 +216,13 @@ public: }; //------------------------------------------------------------------------------ -/// Computes and populates the regions for the given matrices. -/// -/// @tparam store_diag -/// Wheather the diagonal tiles may need to be special cased +/// @copydoc device_regions_build(std::array< std::reference_wrapper>, mat_count >, std::array< scalar_t**, mat_count >, int64_t, std::function) /// -/// @tparam mat_count -/// The number of matrices used by the kernel +/// @params[in] irange +/// The ranges of tiles with a uniform number of rows /// -/// @tparam scalar_t -/// The type of the matrices -/// -/// @param[in] diag_same -/// Whether to include the diagonal tiles in the off-diagonal groups -/// If false, store_diag must be true -/// -/// @param[in] mats -/// An array of the matrices to build regions for -/// -/// @param[in] mats_array_host -/// An array of the arrays to fill with pointers to device data -/// -/// @param[in] device -/// The device to build regions for -/// -/// @param[in] extra_setup -/// Callback that is called whenever a tile is added to a group. -/// The group index and the tile indices are passed as arguments +/// @params[in] jrange +/// The ranges of tiles with a uniform number of columns /// template< bool store_diag, int mat_count, typename scalar_t, bool diag_same=!store_diag > std::vector< device_regions_params > device_regions_build( @@ -370,10 +350,8 @@ std::vector< device_regions_params > device_regions_build //------------------------------------------------------------------------------ /// Computes and populates the regions for the given matrices. 
/// -/// irange and jrange are computed internally -/// /// @tparam store_diag -/// Wheather the diagonal tiles may need to be special cased +/// Whether the diagonal tiles may need to be special cased /// /// @tparam mat_count /// The number of matrices used by the kernel @@ -381,10 +359,10 @@ std::vector< device_regions_params > device_regions_build /// @tparam scalar_t /// The type of the matrices /// -/// @param[in] diag_same +/// @tparam[in] diag_same /// Whether to include the diagonal tiles in the off-diagonal groups /// If false, store_diag must be true -/// +//------------------------------------------------------------------------------ /// @param[in] mats /// An array of the matrices to build regions for /// @@ -398,6 +376,8 @@ std::vector< device_regions_params > device_regions_build /// Callback that is called whenever a tile is added to a group. /// The group index and the tile indices are passed as arguments /// +/// @return A list of batches with identical size. +/// template< bool store_diag, int mat_count, typename scalar_t, bool diag_same=!store_diag > std::vector< device_regions_params > device_regions_build( std::array< std::reference_wrapper>, mat_count > mats, diff --git a/src/internal/internal_util.hh b/src/internal/internal_util.hh index 2fd909b1a..7c2a39fe5 100644 --- a/src/internal/internal_util.hh +++ b/src/internal/internal_util.hh @@ -109,6 +109,16 @@ slate::Matrix alloc_basis(slate::BaseMatrix& A, int64_t n, return V; } +//------------------------------------------------------------------------------ +/// Computes the global index for each tile +/// +/// @param[in] want_rows +/// Whether to compute the row or column indices +/// +/// @param[in] A +/// The matrix to get tile sizes from +/// +/// @return a vector mapping tile indices to global indices template std::vector tile_offsets(bool want_rows, slate::BaseMatrix& A) { From 4a2e569b6784bf1324cb52adad78642e6bc05c41 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Thu, 16 Nov 2023 15:56:40 -0500 Subject: [PATCH 33/35] Use RowCol instead of some bool flags --- include/slate/enums.hh | 1 + src/internal/internal_batch.hh | 21 +++++++++---------- src/internal/internal_genorm.cc | 10 ++++----- src/internal/internal_geset.cc | 4 ---- src/internal/internal_he2hb_gemm.cc | 2 +- .../internal_he2hb_her2k_offdiag_ranks.cc | 2 +- src/internal/internal_he2hb_trmm.cc | 2 +- src/internal/internal_henorm.cc | 4 ++-- src/internal/internal_synorm.cc | 4 ++-- src/internal/internal_trnorm.cc | 8 +++---- src/internal/internal_trsmA.cc | 4 ++-- src/internal/internal_util.hh | 6 ++++-- 12 files changed, 33 insertions(+), 35 deletions(-) diff --git a/include/slate/enums.hh b/include/slate/enums.hh index 41abc8f69..994dd30e7 100644 --- a/include/slate/enums.hh +++ b/include/slate/enums.hh @@ -21,6 +21,7 @@ typedef blas::Side Side; typedef blas::Layout Layout; using lapack::Equed; +using lapack::RowCol; typedef lapack::Norm Norm; typedef lapack::Direction Direction; diff --git a/src/internal/internal_batch.hh b/src/internal/internal_batch.hh index d80f6279e..0a663bddb 100644 --- a/src/internal/internal_batch.hh +++ b/src/internal/internal_batch.hh @@ -157,8 +157,8 @@ inline void cblas_gemm_batch( //------------------------------------------------------------------------------ /// Computes the range of tiles with either the same mb or the same nb. /// -/// @param[in] want_rows -/// If true, compute the row-ranges. Else, compute the column-ranges. 
+/// @param[in] dim +/// Whether to compute the row ranges or the column ranges /// /// @param[in] A /// The matrix to get tile sizes from @@ -166,8 +166,10 @@ inline void cblas_gemm_batch( /// @return The ranges of uniform tile sizes /// template -std::vector device_regions_range( bool want_rows, BaseMatrix& A ) +std::vector device_regions_range( RowCol dim, BaseMatrix& A ) { + bool want_rows = dim == RowCol::Row; + int64_t kt = want_rows ? A.mt() : A.nt(); std::vector< int64_t > range; @@ -204,11 +206,8 @@ public: std::conditional_t< store_diag, bool, Empty > is_diagonal; device_regions_params() - : count(0), mb(0), nb(0) + : count(0), mb(0), nb(0), ld{0} { - for (int i = 0; i < mat_count; ++i) { - ld[i] = 0; - } if constexpr (store_diag) { is_diagonal = false; } @@ -274,8 +273,8 @@ std::vector< device_regions_params > device_regions_build // * Lower matrices start at j+1 // * Upper matrices end at j // * General matrices run the whole range - int istart = std::max(irange[ ii ], (A.uplo() == Uplo::Lower ? j+1 : 0)); - int iend = std::min(irange[ ii+1 ], (A.uplo() == Uplo::Upper ? j : mt)); + int64_t istart = std::max(irange[ ii ], (A.uplo() == Uplo::Lower ? j+1 : 0)); + int64_t iend = std::min(irange[ ii+1 ], (A.uplo() == Uplo::Upper ? j : mt)); for (int64_t i = istart; i < iend; ++i) { if ((diag_same || i != j) && A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) { @@ -386,8 +385,8 @@ std::vector< device_regions_params > device_regions_build std::function extra_setup = {}) { // Find ranges of matching mb's and ranges of matching nb's. - auto irange = device_regions_range( true, mats[0].get() ); - auto jrange = device_regions_range( false, mats[0].get() ); + auto irange = device_regions_range( RowCol::Row, mats[0].get() ); + auto jrange = device_regions_range( RowCol::Col, mats[0].get() ); return device_regions_build< store_diag, mat_count, scalar_t, diag_same >( mats, mats_array_host, device, extra_setup, diff --git a/src/internal/internal_genorm.cc b/src/internal/internal_genorm.cc index 55b5534d9..7a83c5ef7 100644 --- a/src/internal/internal_genorm.cc +++ b/src/internal/internal_genorm.cc @@ -382,8 +382,8 @@ void norm( // Find ranges of matching mb's and ranges of matching nb's to avoid // repeatedly recomputing them - auto irange = device_regions_range( true, A ); - auto jrange = device_regions_range( false, A ); + auto irange = device_regions_range( RowCol::Row, A ); + auto jrange = device_regions_range( RowCol::Col, A ); int64_t ldv = 0; if (scope == NormScope::Matrix) { @@ -516,7 +516,7 @@ void norm( devices_values.data(), 1); } else if (in_norm == Norm::One) { - auto joffsets = tile_offsets( false, A ); + auto joffsets = tile_offsets( RowCol::Col, A ); for (int device = 0; device < A.num_devices(); ++device) { @@ -540,7 +540,7 @@ void norm( } } else if (in_norm == Norm::Inf) { - auto ioffsets = tile_offsets( true, A ); + auto ioffsets = tile_offsets( RowCol::Row, A ); for (int device = 0; device < A.num_devices(); ++device) { @@ -576,7 +576,7 @@ void norm( else if (scope == NormScope::Columns) { if (in_norm == Norm::Max) { - auto joffsets = tile_offsets( false, A ); + auto joffsets = tile_offsets( RowCol::Col, A ); // Reduction over devices to local result. 
// todo: re-arrange loops to be able to issue omp tasks diff --git a/src/internal/internal_geset.cc b/src/internal/internal_geset.cc index 1671ca87b..377ec50ec 100644 --- a/src/internal/internal_geset.cc +++ b/src/internal/internal_geset.cc @@ -92,10 +92,6 @@ void set(internal::TargetType, { using ij_tuple = typename BaseMatrix::ij_tuple; - // Find ranges of matching mb's and ranges of matching nb's. - std::vector< int64_t > irange = device_regions_range( true, A ); - std::vector< int64_t > jrange = device_regions_range( false, A ); - #pragma omp taskgroup for (int device = 0; device < A.num_devices(); ++device) { #pragma omp task slate_omp_default_none priority( priority ) \ diff --git a/src/internal/internal_he2hb_gemm.cc b/src/internal/internal_he2hb_gemm.cc index 46c769832..c32874089 100644 --- a/src/internal/internal_he2hb_gemm.cc +++ b/src/internal/internal_he2hb_gemm.cc @@ -164,7 +164,7 @@ void he2hb_gemm( using Params = device_regions_params; // Find ranges of matching mb's and ranges of matching nb's. - auto irange = device_regions_range( true, C ); + auto irange = device_regions_range( RowCol::Row, C ); // loop over regions int64_t batch_count = 0; diff --git a/src/internal/internal_he2hb_her2k_offdiag_ranks.cc b/src/internal/internal_he2hb_her2k_offdiag_ranks.cc index 506ee9e8c..b93d69677 100644 --- a/src/internal/internal_he2hb_her2k_offdiag_ranks.cc +++ b/src/internal/internal_he2hb_her2k_offdiag_ranks.cc @@ -199,7 +199,7 @@ void he2hb_her2k_offdiag_ranks( using Params = device_regions_params; // Find ranges of matching mb's and ranges of matching nb's. - auto jrange = device_regions_range( false, C ); + auto jrange = device_regions_range( RowCol::Col, C ); std::vector< int64_t > irange; int64_t last_ij = -1; diff --git a/src/internal/internal_he2hb_trmm.cc b/src/internal/internal_he2hb_trmm.cc index 0f7b210ca..ca89bee37 100644 --- a/src/internal/internal_he2hb_trmm.cc +++ b/src/internal/internal_he2hb_trmm.cc @@ -198,7 +198,7 @@ void he2hb_trmm( using Params = device_regions_params; // Find ranges of matching mb's and ranges of matching nb's. 
- auto irange = device_regions_range( true, B ); + auto irange = device_regions_range( RowCol::Row, B ); // loop over regions int64_t batch_count = 0; diff --git a/src/internal/internal_henorm.cc b/src/internal/internal_henorm.cc index 6c955096c..7b4eecd14 100644 --- a/src/internal/internal_henorm.cc +++ b/src/internal/internal_henorm.cc @@ -352,7 +352,7 @@ void norm( // Find ranges of matching mb's and ranges of matching nb's to avoid // repeatedly recomputing them - auto ijrange = device_regions_range( true, A ); + auto ijrange = device_regions_range( RowCol::Row, A ); int64_t ldv = 0; if (in_norm == Norm::Max) { @@ -498,7 +498,7 @@ void norm( devices_values.data(), 1); } else if (in_norm == Norm::One || in_norm == Norm::Inf) { - auto ioffsets = tile_offsets( true, A ); + auto ioffsets = tile_offsets( RowCol::Row, A ); assert(A.n() == A.m()); for (int device = 0; device < A.num_devices(); ++device) { diff --git a/src/internal/internal_synorm.cc b/src/internal/internal_synorm.cc index 4e16277f0..9ee06ad3d 100644 --- a/src/internal/internal_synorm.cc +++ b/src/internal/internal_synorm.cc @@ -354,7 +354,7 @@ void norm(internal::TargetType, // Find ranges of matching mb's and ranges of matching nb's to avoid // repeatedly recomputing them - auto ijrange = device_regions_range( true, A ); + auto ijrange = device_regions_range( RowCol::Row, A ); int64_t ldv = 0; if (in_norm == Norm::Max) { @@ -506,7 +506,7 @@ void norm(internal::TargetType, devices_values.data(), 1); } else if (in_norm == Norm::One || in_norm == Norm::Inf) { - auto ioffsets = tile_offsets( true, A ); + auto ioffsets = tile_offsets( RowCol::Row, A ); assert(A.n() == A.m()); for (int device = 0; device < A.num_devices(); ++device) { diff --git a/src/internal/internal_trnorm.cc b/src/internal/internal_trnorm.cc index e00ecca07..72e203c44 100644 --- a/src/internal/internal_trnorm.cc +++ b/src/internal/internal_trnorm.cc @@ -372,8 +372,8 @@ void norm( // Find ranges of matching mb's and ranges of matching nb's to avoid // repeatedly recomputing them - auto irange = device_regions_range( true, A ); - auto jrange = device_regions_range( false, A ); + auto irange = device_regions_range( RowCol::Row, A ); + auto jrange = device_regions_range( RowCol::Col, A ); int64_t ldv = 0; if (in_norm == Norm::Max) { @@ -503,7 +503,7 @@ void norm( devices_values.data(), 1); } else if (in_norm == Norm::One) { - auto joffsets = tile_offsets( false, A ); + auto joffsets = tile_offsets( RowCol::Col, A ); for (int device = 0; device < A.num_devices(); ++device) { @@ -543,7 +543,7 @@ void norm( } } else if (in_norm == Norm::Inf) { - auto ioffsets = tile_offsets( true, A ); + auto ioffsets = tile_offsets( RowCol::Row, A ); for (int device = 0; device < A.num_devices(); ++device) { diff --git a/src/internal/internal_trsmA.cc b/src/internal/internal_trsmA.cc index 35f71f024..4c32dd816 100644 --- a/src/internal/internal_trsmA.cc +++ b/src/internal/internal_trsmA.cc @@ -235,7 +235,7 @@ void trsmA(internal::TargetType, std::vector group_params; if (side == Side::Right) { // Find ranges of matching mb's and ranges of matching nb's. - auto irange = device_regions_range( true, B ); + auto irange = device_regions_range( RowCol::Row, B ); // loop over regions for (size_t ii = 0; ii < irange.size() - 1; ++ii) { @@ -272,7 +272,7 @@ void trsmA(internal::TargetType, } else { // Find ranges of matching mb's and ranges of matching nb's. 
- auto jrange = device_regions_range( false, B ); + auto jrange = device_regions_range( RowCol::Col, B ); // loop over regions for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { diff --git a/src/internal/internal_util.hh b/src/internal/internal_util.hh index 7c2a39fe5..268fd39e2 100644 --- a/src/internal/internal_util.hh +++ b/src/internal/internal_util.hh @@ -112,7 +112,7 @@ slate::Matrix alloc_basis(slate::BaseMatrix& A, int64_t n, //------------------------------------------------------------------------------ /// Computes the global index for each tile /// -/// @param[in] want_rows +/// @param[in] dim /// Whether to compute the row or column indices /// /// @param[in] A @@ -120,8 +120,10 @@ slate::Matrix alloc_basis(slate::BaseMatrix& A, int64_t n, /// /// @return a vector mapping tile indices to global indices template -std::vector tile_offsets(bool want_rows, slate::BaseMatrix& A) +std::vector tile_offsets( RowCol dim, slate::BaseMatrix& A ) { + bool want_rows = dim == RowCol::Row; + int64_t kt = want_rows ? A.mt() : A.nt(); std::vector< int64_t > offset_list; From fd21fc99fd41e51091c684b65d7d8c9f72dbf8a6 Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Wed, 22 Nov 2023 10:21:13 -0500 Subject: [PATCH 34/35] Cleanup various mistakes pointed out in review --- src/internal/internal_gescale_row_col.cc | 15 +++------------ src/internal/internal_he2hb_gemm.cc | 2 +- src/internal/internal_he2hb_trmm.cc | 7 ++++++- src/internal/internal_syr2k.cc | 6 +++--- src/internal/internal_trsmA.cc | 4 +--- 5 files changed, 14 insertions(+), 20 deletions(-) diff --git a/src/internal/internal_gescale_row_col.cc b/src/internal/internal_gescale_row_col.cc index 6291f0222..8bf2f68a1 100644 --- a/src/internal/internal_gescale_row_col.cc +++ b/src/internal/internal_gescale_row_col.cc @@ -5,6 +5,7 @@ #include "slate/internal/device.hh" #include "internal/internal_batch.hh" +#include "internal/internal_util.hh" #include "internal/internal.hh" #include "internal/DevVector.hh" #include "slate/internal/util.hh" @@ -97,20 +98,10 @@ void scale_row_col( std::vector< int64_t > ioffsets, joffsets; if (want_row) { - ioffsets.reserve(A.mt()); - int64_t offset = 0; - for (int64_t i = 0; i < A.mt(); ++i) { - ioffsets.push_back( offset ); - offset += A.tileMb( i ); - } + ioffsets = tile_offsets( RowCol::Row, A ); } if (want_col) { - joffsets.reserve(A.nt()); - int64_t offset = 0; - for (int64_t j = 0; j < A.nt(); ++j) { - joffsets.push_back( offset ); - offset += A.tileNb( j ); - } + joffsets = tile_offsets( RowCol::Col, A ); } // temporarily, convert both into same layout diff --git a/src/internal/internal_he2hb_gemm.cc b/src/internal/internal_he2hb_gemm.cc index c32874089..784a742b0 100644 --- a/src/internal/internal_he2hb_gemm.cc +++ b/src/internal/internal_he2hb_gemm.cc @@ -160,7 +160,7 @@ void he2hb_gemm( scalar_t** b_array_host = a_array_host + batch_size; scalar_t** c_array_host = b_array_host + batch_size; - // Varient of device_regions_build to handle trsmA + // Variant of device_regions_build to handle he2hb_gemm using Params = device_regions_params; // Find ranges of matching mb's and ranges of matching nb's. diff --git a/src/internal/internal_he2hb_trmm.cc b/src/internal/internal_he2hb_trmm.cc index ca89bee37..3e2e79dad 100644 --- a/src/internal/internal_he2hb_trmm.cc +++ b/src/internal/internal_he2hb_trmm.cc @@ -12,6 +12,11 @@ namespace slate { namespace internal { +//------------------------------------------------------------------------------ +/// Determines whether this process contributes to B(i, 0). 
+/// Specifically, it checks whether there is a j in panel_rank_rows such that +/// AH(i, j) is local (taking into account the symmetric storage.) +/// template bool need_Bi0(HermitianMatrix AH, int mpi_rank, @@ -194,7 +199,7 @@ void he2hb_trmm( scalar_t** t_array_host = B.array_host(device, queue_index); scalar_t** b_array_host = t_array_host + batch_size; - // Varient of device_regions_build to handle trsmA + // Variant of device_regions_build to handle he2hb_trmm using Params = device_regions_params; // Find ranges of matching mb's and ranges of matching nb's. diff --git a/src/internal/internal_syr2k.cc b/src/internal/internal_syr2k.cc index 28c607da1..3878dc66f 100644 --- a/src/internal/internal_syr2k.cc +++ b/src/internal/internal_syr2k.cc @@ -640,7 +640,7 @@ void syr2k(internal::TargetType, layout, opA_, opB_, m, n, k, alpha_, a_array, ldda, - b_array, lddb, + b_array, lddb, beta_, c_array, lddc, group_count, info, *queue); @@ -648,8 +648,8 @@ void syr2k(internal::TargetType, layout, opA_, opB_, m, n, k, alpha_, bt_array, lddbt, - at_array, lddat, - one_, c_array, lddc, + at_array, lddat, + one_, c_array, lddc, group_count, info, *queue); } a_array_host += group_count; diff --git a/src/internal/internal_trsmA.cc b/src/internal/internal_trsmA.cc index 4c32dd816..d637bab6b 100644 --- a/src/internal/internal_trsmA.cc +++ b/src/internal/internal_trsmA.cc @@ -228,7 +228,7 @@ void trsmA(internal::TargetType, scalar_t** a_array_host = A.array_host(device, queue_index); scalar_t** b_array_host = a_array_host + batch_size; - // Varient of device_regions_build to handle trsmA + // Variant of device_regions_build to handle trsmA using Params = device_regions_params; int64_t batch_count = 0; @@ -320,7 +320,6 @@ void trsmA(internal::TargetType, blas::Queue* queue = A.compute_queue( device, queue_index ); assert( queue != nullptr ); - queue->sync(); for (size_t g = 0; g < group_params.size(); ++g) { @@ -344,7 +343,6 @@ void trsmA(internal::TargetType, alpha_, a_array, ldda, b_array, lddb, group_count, info, *queue); - queue->sync(); a_array_host += group_count; b_array_host += group_count; From 8da49561df1de43bcbe776bc9018608550506b4b Mon Sep 17 00:00:00 2001 From: Neil Lindquist Date: Wed, 22 Nov 2023 10:28:27 -0500 Subject: [PATCH 35/35] Fix one more code alignment mistake --- src/internal/internal_syr2k.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/internal/internal_syr2k.cc b/src/internal/internal_syr2k.cc index 3878dc66f..f41c16baf 100644 --- a/src/internal/internal_syr2k.cc +++ b/src/internal/internal_syr2k.cc @@ -617,8 +617,8 @@ void syr2k(internal::TargetType, layout, uplo, opA_, n, k, alpha_, a_array, ldda, - b_array, lddb, - beta_, c_array, lddc, + b_array, lddb, + beta_, c_array, lddc, group_count, info, *queue); } else {