From 3239bc0ac59b2d9f8179ebd35816cf1215aa8726 Mon Sep 17 00:00:00 2001
From: Mark Gates
Date: Mon, 24 Oct 2022 23:45:50 -0400
Subject: [PATCH 1/4] arbitrary batch regions in geset

---
 src/internal/internal_geset.cc | 192 ++++++++++++++++++++------------
 1 file changed, 118 insertions(+), 74 deletions(-)

diff --git a/src/internal/internal_geset.cc b/src/internal/internal_geset.cc
index d08e0a12e..fec2a0201 100644
--- a/src/internal/internal_geset.cc
+++ b/src/internal/internal_geset.cc
@@ -92,112 +92,156 @@ void set(internal::TargetType<Target::Devices>,
 {
     using ij_tuple = typename BaseMatrix<scalar_t>::ij_tuple;
 
-    // Define index ranges for regions of matrix.
-    // Tiles in each region are all the same size.
-    int64_t irange[4][2] = {
-        { 0,        A.mt()-1 },
-        { A.mt()-1, A.mt()   },
-        { 0,        A.mt()-1 },
-        { A.mt()-1, A.mt()   }
-    };
-    int64_t jrange[4][2] = {
-        { 0,        A.nt()-1 },
-        { 0,        A.nt()-1 },
-        { A.nt()-1, A.nt()   },
-        { A.nt()-1, A.nt()   }
-    };
+    int64_t mt = A.mt();
+    int64_t nt = A.nt();
+
+    // Find ranges of matching mb's.
+    std::vector< int64_t > irange;
+    int64_t last_mb = -1;
+    for (int64_t i = 0; i < mt; ++i) {
+        int64_t mb = A.tileMb( i );
+        if (mb != last_mb) {
+            last_mb = mb;
+            irange.push_back( i );
+        }
+    }
+    irange.push_back( mt );
+
+    // Find ranges of matching nb's.
+    std::vector< int64_t > jrange;
+    int last_nb = -1;
+    for (int64_t j = 0; j < nt; ++j) {
+        int64_t nb = A.tileNb( j );
+        if (nb != last_nb) {
+            last_nb = nb;
+            jrange.push_back( j );
+        }
+    }
+    jrange.push_back( nt );
 
     #pragma omp taskgroup
     for (int device = 0; device < A.num_devices(); ++device) {
-        #pragma omp task slate_omp_default_none \
-            shared( A ) priority( priority ) \
-            firstprivate(device, irange, jrange, queue_index, offdiag_value, diag_value)
+        #pragma omp task priority( priority ) slate_omp_default_none \
+            shared( A, irange, jrange ) \
+            firstprivate( device, queue_index, offdiag_value, diag_value )
         {
+            // Get local tiles for writing.
             // temporarily, convert both into same layout
             // todo: this is in-efficient, because both matrices may have same layout already
             //       and possibly wrong, because an input matrix is being altered
             // todo: best, handle directly through the CUDA kernels
-            auto layout = Layout::ColMajor;
+            auto layout = LayoutConvert( Layout::ColMajor );
             std::set<ij_tuple> A_tiles_set;
             for (int64_t i = 0; i < A.mt(); ++i) {
                 for (int64_t j = 0; j < A.nt(); ++j) {
-                    if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) {
-                        A_tiles_set.insert({i, j});
+                    if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) {
+                        A_tiles_set.insert( { i, j } );
                     }
                 }
            }
-            A.tileGetForWriting(A_tiles_set, device, LayoutConvert(layout));
-
-            scalar_t** a_array_host = A.array_host(device, queue_index);
+            A.tileGetForWriting( A_tiles_set, device, layout );
+
+            scalar_t** a_array_host = A.array_host( device, queue_index );
+
+            // If offdiag == diag value, lump diag tiles with offdiag tiles
+            // in one batch.
+            bool diag_same = offdiag_value == diag_value;
 
+            // Build batch groups.
+            // group_index points to the start of each group.
+            int group = 0;
             int64_t batch_count = 0;
-            int64_t mb[8], nb[8], lda[8], group_count[8];
-            for (int q = 0; q < 4; ++q) {
-                group_count[q] = 0;
-                lda[q] = 0;
-                mb[q] = A.tileMb(irange[q][0]);
-                nb[q] = A.tileNb(jrange[q][0]);
-                for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) {
-                    for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) {
-                        if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) {
-                            if (i != j) {
-                                a_array_host[batch_count] = A(i, j, device).data();
-                                lda[q] = A(i, j, device).stride();
-                                ++group_count[q];
-                                ++batch_count;
-                            }
-                        }
-                    }
-                }
-            }
-            for (int q = 4; q < 8; ++q) {
-                group_count[q] = 0;
-                lda[q] = 0;
-                mb[q] = A.tileMb(irange[q-4][0]);
-                nb[q] = A.tileNb(jrange[q-4][0]);
-                for (int64_t i = irange[q-4][0]; i < irange[q-4][1]; ++i) {
-                    for (int64_t j = jrange[q-4][0]; j < jrange[q-4][1]; ++j) {
-                        if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) {
-                            if (i == j) {
-                                a_array_host[batch_count] = A(i, j, device).data();
-                                lda[q] = A(i, j, device).stride();
-                                ++group_count[q];
-                                ++batch_count;
-                            }
-                        }
-                    }
-                }
-            }
+            struct Params {
+                int64_t start, mb, nb, lda;
+            };
+            std::vector<Params> group_params;
+            for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
+            for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
+                bool first = true;
+                group_params.push_back( { batch_count, -1, -1, -1 } );
+                for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) {
+                for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) {
+                    if ((diag_same || i != j)
+                        && A.tileIsLocal( i, j )
+                        && device == A.tileDevice( i, j ))
+                    {
+                        auto Aij = A( i, j, device );
+                        a_array_host[ batch_count ] = Aij.data();
+                        if (first) {
+                            group_params[ group ].mb = Aij.mb();
+                            group_params[ group ].nb = Aij.nb();
+                            group_params[ group ].lda = Aij.stride();
+                            first = false;
+                        }
+                        else {
+                            assert( group_params[ group ].mb == Aij.mb() );
+                            assert( group_params[ group ].nb == Aij.nb() );
+                            assert( group_params[ group ].lda == Aij.stride() );
+                        }
+                        ++batch_count;
+                    }
+                }} // for j, i
+                ++group;
+            }} // for jj, ii
+
+            // Build batch groups for diagonal tiles,
+            // when offdiag_value != diag_value.
+            if (! diag_same) {
+                for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
+                for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
+                    bool first = true;
+                    group_params.push_back( { batch_count, -1, -1, -1 } );
+                    for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) {
+                    for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) {
+                        if (i == j
+                            && A.tileIsLocal( i, j )
+                            && device == A.tileDevice( i, j ))
+                        {
+                            auto Aij = A( i, j, device );
+                            a_array_host[ batch_count ] = Aij.data();
+                            if (first) {
+                                group_params[ group ].mb = Aij.mb();
+                                group_params[ group ].nb = Aij.nb();
+                                group_params[ group ].lda = Aij.stride();
+                                first = false;
+                            }
+                            else {
+                                assert( group_params[ group ].mb == Aij.mb() );
+                                assert( group_params[ group ].nb == Aij.nb() );
+                                assert( group_params[ group ].lda == Aij.stride() );
+                            }
+                            ++batch_count;
+                        }
+                    }} // for j, i
+                    ++group;
+                }} // for jj, ii
+            }
+            group_params.push_back( { batch_count, -1, -1, -1 } );
 
-            scalar_t** a_array_dev = A.array_device(device, queue_index);
-
-            blas::Queue* queue = A.compute_queue(device, queue_index);
+            blas::Queue* queue = A.compute_queue( device, queue_index );
+            scalar_t** a_array_dev = A.array_device( device, queue_index );
 
             blas::device_memcpy(
                 a_array_dev, a_array_host, batch_count,
                 blas::MemcpyKind::HostToDevice, *queue);
 
-            for (int q = 0; q < 4; ++q) {
-                if (group_count[q] > 0) {
-                    device::batch::geset(mb[q], nb[q],
-                        offdiag_value, offdiag_value, a_array_dev, lda[q],
-                        group_count[q], *queue);
-                    a_array_dev += group_count[q];
-                }
-            }
-            for (int q = 4; q < 8; ++q) {
-                if (group_count[q] > 0) {
-                    device::batch::geset(mb[q], nb[q],
-                        offdiag_value, diag_value, a_array_dev, lda[q],
-                        group_count[q], *queue);
-                    a_array_dev += group_count[q];
-                }
-            }
-            queue->sync();
-        }
-    }
+            for (size_t g = 0; g < group_params.size(); ++g) {
+                int64_t group_count
+                    = group_params[ g+1 ].start - group_params[ g ].start;
+                if (group_count > 0) {
+                    device::batch::geset(
+                        group_params[ g ].mb,
+                        group_params[ g ].nb,
+                        offdiag_value, offdiag_value, a_array_dev,
+                        group_params[ g ].lda,
+                        group_count, *queue );
+                    a_array_dev += group_count;
+                }
+            }
+            queue->sync();
+        } // end task
+    } // end for dev
 }
 
 //------------------------------------------------------------------------------
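The range detection introduced in this patch is easiest to see in isolation. Below is a minimal stand-alone sketch in plain C++, not SLATE code; find_ranges and the sample tile sizes are invented for illustration. Given the tile sizes along one dimension, it records the first index of every run of equal sizes plus a final sentinel, so [ranges[k], ranges[k+1]) is always a block of uniform tiles, and pairs of such row/column blocks become the batch groups.

    // Illustrative only; mirrors the "Find ranges of matching mb's" loops above.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    std::vector<int64_t> find_ranges( std::vector<int64_t> const& tile_sizes )
    {
        std::vector<int64_t> ranges;
        int64_t last = -1;
        for (int64_t i = 0; i < int64_t( tile_sizes.size() ); ++i) {
            if (tile_sizes[ i ] != last) {
                last = tile_sizes[ i ];
                ranges.push_back( i );  // start of a new run of equal sizes
            }
        }
        ranges.push_back( tile_sizes.size() );  // end sentinel
        return ranges;
    }

    int main()
    {
        // e.g., row tiles of 256 with a short 100-row tile at the end
        std::vector<int64_t> mb = { 256, 256, 256, 100 };
        for (int64_t i : find_ranges( mb ))
            printf( "%lld ", (long long) i );   // prints: 0 3 4
        printf( "\n" );
        return 0;
    }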
From 494eaa32447070cb9becd5a2533ac5668bf31768 Mon Sep 17 00:00:00 2001
From: Neil Lindquist
Date: Wed, 18 Oct 2023 09:38:59 -0400
Subject: [PATCH 2/4] Get batch regions working in geset

---
 src/internal/internal_geset.cc | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/internal/internal_geset.cc b/src/internal/internal_geset.cc
index fec2a0201..ec9b8bc92 100644
--- a/src/internal/internal_geset.cc
+++ b/src/internal/internal_geset.cc
@@ -121,7 +121,7 @@ void set(internal::TargetType<Target::Devices>,
 
     #pragma omp taskgroup
     for (int device = 0; device < A.num_devices(); ++device) {
-        #pragma omp task priority( priority ) slate_omp_default_none \
+        #pragma omp task priority( priority ) \
             shared( A, irange, jrange ) \
             firstprivate( device, queue_index, offdiag_value, diag_value )
         {
@@ -154,12 +154,13 @@ void set(internal::TargetType<Target::Devices>,
             int64_t batch_count = 0;
             struct Params {
                 int64_t start, mb, nb, lda;
+                scalar_t diag_value;
             };
             std::vector<Params> group_params;
             for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
             for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
                 bool first = true;
-                group_params.push_back( { batch_count, -1, -1, -1 } );
+                group_params.push_back( { batch_count, -1, -1, -1, offdiag_value } );
                 for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) {
                 for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) {
                     if ((diag_same || i != j)
@@ -191,7 +192,7 @@ void set(internal::TargetType<Target::Devices>,
                 for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
                 for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
                     bool first = true;
-                    group_params.push_back( { batch_count, -1, -1, -1 } );
+                    group_params.push_back( { batch_count, -1, -1, -1, diag_value } );
                     for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) {
                     for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) {
                         if (i == j
@@ -226,15 +227,15 @@ void set(internal::TargetType<Target::Devices>,
                 a_array_dev, a_array_host, batch_count,
                 blas::MemcpyKind::HostToDevice, *queue);
 
-            for (size_t g = 0; g < group_params.size(); ++g) {
+            for (size_t g = 0; g < group_params.size() - 1; ++g) {
                 int64_t group_count
                     = group_params[ g+1 ].start - group_params[ g ].start;
                 if (group_count > 0) {
                     device::batch::geset(
                         group_params[ g ].mb,
                         group_params[ g ].nb,
-                        offdiag_value, offdiag_value, a_array_dev,
-                        group_params[ g ].lda,
+                        offdiag_value, group_params[ g ].diag_value,
+                        a_array_dev, group_params[ g ].lda,
                         group_count, *queue );
                     a_array_dev += group_count;
                 }
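The loop-bound fix above follows from the sentinel convention of the first patch: each group records the offset where it starts in the pointer array, and one final sentinel entry records the total batch count, so group g holds start[g+1] - start[g] tiles and the launch loop has to stop at size() - 1. A small stand-alone illustration, not SLATE code, with made-up numbers:

    #include <cstdio>
    #include <vector>

    int main()
    {
        // start offsets per group; the last entry (12) is the sentinel = total batch_count
        std::vector<long long> start = { 0, 5, 5, 9, 12 };
        for (size_t g = 0; g < start.size() - 1; ++g) {
            long long count = start[ g+1 ] - start[ g ];
            printf( "group %zu: %lld tiles\n", g, count );  // 5, 0, 4, 3
        }
        return 0;
    }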
From 36b81e17ae79b2989ffc5ab58187b1bac0857e3c Mon Sep 17 00:00:00 2001
From: Neil Lindquist
Date: Wed, 18 Oct 2023 10:15:07 -0400
Subject: [PATCH 3/4] Improve regions construction in geset

---
 src/internal/internal_geset.cc | 92 ++++++++++++++++++------------------
 1 file changed, 44 insertions(+), 48 deletions(-)

diff --git a/src/internal/internal_geset.cc b/src/internal/internal_geset.cc
index ec9b8bc92..d82e957df 100644
--- a/src/internal/internal_geset.cc
+++ b/src/internal/internal_geset.cc
@@ -126,9 +126,8 @@ void set(internal::TargetType<Target::Devices>,
             firstprivate( device, queue_index, offdiag_value, diag_value )
         {
             // Get local tiles for writing.
-            // temporarily, convert both into same layout
-            // todo: this is in-efficient, because both matrices may have same layout already
-            //       and possibly wrong, because an input matrix is being altered
+            // convert to column major layout to simplify lda's
+            // todo: this is in-efficient because the diagonal is independant of layout
             // todo: best, handle directly through the CUDA kernels
             auto layout = LayoutConvert( Layout::ColMajor );
             std::set<ij_tuple> A_tiles_set;
@@ -149,18 +148,15 @@ void set(internal::TargetType<Target::Devices>,
             bool diag_same = offdiag_value == diag_value;
 
             // Build batch groups.
-            // group_index points to the start of each group.
-            int group = 0;
             int64_t batch_count = 0;
             struct Params {
-                int64_t start, mb, nb, lda;
+                int64_t count, mb, nb, lda;
                 scalar_t diag_value;
             };
             std::vector<Params> group_params;
             for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
             for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
-                bool first = true;
-                group_params.push_back( { batch_count, -1, -1, -1, offdiag_value } );
+                Params group = { 0, -1, -1, -1, offdiag_value };
                 for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) {
                 for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) {
                     if ((diag_same || i != j)
@@ -169,21 +165,23 @@ void set(internal::TargetType<Target::Devices>,
                     {
                         auto Aij = A( i, j, device );
                         a_array_host[ batch_count ] = Aij.data();
-                        if (first) {
-                            group_params[ group ].mb = Aij.mb();
-                            group_params[ group ].nb = Aij.nb();
-                            group_params[ group ].lda = Aij.stride();
-                            first = false;
+                        if (group.count == 0) {
+                            group.mb = Aij.mb();
+                            group.nb = Aij.nb();
+                            group.lda = Aij.stride();
                         }
                         else {
-                            assert( group_params[ group ].mb == Aij.mb() );
-                            assert( group_params[ group ].nb == Aij.nb() );
-                            assert( group_params[ group ].lda == Aij.stride() );
+                            assert( group.mb == Aij.mb() );
+                            assert( group.nb == Aij.nb() );
+                            assert( group.lda == Aij.stride() );
                         }
+                        ++group.count;
                         ++batch_count;
                     }
                 }} // for j, i
-                ++group;
+                if (group.count > 0) {
+                    group_params.push_back( group );
+                }
             }} // for jj, ii
 
             // Build batch groups for diagonal tiles,
@@ -191,34 +189,35 @@ void set(internal::TargetType<Target::Devices>,
             if (! diag_same) {
                 for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
                 for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
-                    bool first = true;
-                    group_params.push_back( { batch_count, -1, -1, -1, diag_value } );
-                    for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) {
-                    for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) {
-                        if (i == j
-                            && A.tileIsLocal( i, j )
-                            && device == A.tileDevice( i, j ))
+                    Params group = { 0, -1, -1, -1, diag_value };
+                    // Diagonal tiles only in the intersection of irange and jrange
+                    int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]);
+                    int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]);
+                    for (int64_t ij = ijstart; ij < ijend; ++ij) {
+                        if (A.tileIsLocal( ij, ij )
+                            && device == A.tileDevice( ij, ij ))
                         {
-                            auto Aij = A( i, j, device );
+                            auto Aij = A( ij, ij, device );
                             a_array_host[ batch_count ] = Aij.data();
-                            if (first) {
-                                group_params[ group ].mb = Aij.mb();
-                                group_params[ group ].nb = Aij.nb();
-                                group_params[ group ].lda = Aij.stride();
-                                first = false;
+                            if (group.count == 0) {
+                                group.mb = Aij.mb();
+                                group.nb = Aij.nb();
+                                group.lda = Aij.stride();
                             }
                             else {
-                                assert( group_params[ group ].mb == Aij.mb() );
-                                assert( group_params[ group ].nb == Aij.nb() );
-                                assert( group_params[ group ].lda == Aij.stride() );
+                                assert( group.mb == Aij.mb() );
+                                assert( group.nb == Aij.nb() );
+                                assert( group.lda == Aij.stride() );
                             }
+                            ++group.count;
                             ++batch_count;
                         }
-                    }} // for j, i
-                    ++group;
+                    } // for ij
+                    if (group.count > 0) {
+                        group_params.push_back( group );
+                    }
                 }} // for jj, ii
             }
-            group_params.push_back( { batch_count, -1, -1, -1 } );
 
             blas::Queue* queue = A.compute_queue( device, queue_index );
             scalar_t** a_array_dev = A.array_device( device, queue_index );
@@ -227,18 +226,15 @@ void set(internal::TargetType<Target::Devices>,
                 a_array_dev, a_array_host, batch_count,
                 blas::MemcpyKind::HostToDevice, *queue);
 
-            for (size_t g = 0; g < group_params.size() - 1; ++g) {
-                int64_t group_count
-                    = group_params[ g+1 ].start - group_params[ g ].start;
-                if (group_count > 0) {
-                    device::batch::geset(
-                        group_params[ g ].mb,
-                        group_params[ g ].nb,
-                        offdiag_value, group_params[ g ].diag_value,
-                        a_array_dev, group_params[ g ].lda,
-                        group_count, *queue );
-                    a_array_dev += group_count;
-                }
+            for (size_t g = 0; g < group_params.size(); ++g) {
+                int64_t group_count = group_params[ g ].count;
+                device::batch::geset(
+                    group_params[ g ].mb,
+                    group_params[ g ].nb,
+                    offdiag_value, group_params[ g ].diag_value,
+                    a_array_dev, group_params[ g ].lda,
+                    group_count, *queue );
+                a_array_dev += group_count;
             }
             queue->sync();
         } // end task
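The intersection logic this patch uses for diagonal tiles can be checked with a toy example. Stand-alone sketch, not SLATE code, with invented ranges: block (ii, jj) contains exactly the diagonal tiles ij with max(irange[ii], jrange[jj]) <= ij < min(irange[ii+1], jrange[jj+1]), so every diagonal tile lands in exactly one group and no i == j test is needed inside the loop.

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<long long> irange = { 0, 3, 4 };   // row-tile regions [0,3) and [3,4)
        std::vector<long long> jrange = { 0, 2, 4 };   // col-tile regions [0,2) and [2,4)
        for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
            for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
                long long ijstart = std::max( irange[ ii ], jrange[ jj ] );
                long long ijend   = std::min( irange[ ii+1 ], jrange[ jj+1 ] );
                for (long long ij = ijstart; ij < ijend; ++ij)
                    printf( "block (%zu, %zu) owns diagonal tile (%lld, %lld)\n",
                            ii, jj, ij, ij );
            }
        }
        return 0;
    }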
From 791135c80b3d8ca4e998271d827ecbe4d6c52827 Mon Sep 17 00:00:00 2001
From: Neil Lindquist
Date: Wed, 18 Oct 2023 14:57:46 -0400
Subject: [PATCH 4/4] arbitrary batch regions on tzset

---
 src/internal/internal_tzset.cc | 245 ++++++++++++++++++---------------
 1 file changed, 135 insertions(+), 110 deletions(-)

diff --git a/src/internal/internal_tzset.cc b/src/internal/internal_tzset.cc
index 94d0059c3..314c0241e 100644
--- a/src/internal/internal_tzset.cc
+++ b/src/internal/internal_tzset.cc
@@ -120,47 +120,51 @@ void set(
 {
     using ij_tuple = typename BaseTrapezoidMatrix<scalar_t>::ij_tuple;
 
-    // Define index ranges for regions of matrix.
-    // Tiles in each region are all the same size.
-    // Ranges begin : end are [ begin, end ), exclusive of end.
-    // 0 is interior            [ 0 : mt-1, 0 : nt-1 ]
-    // 1 is bottom row          [ mt-1,     0 : nt-1 ]
-    // 2 is right col           [ 0 : mt-1, nt-1     ]
-    // 3 is bottom-right tile   [ mt-1,     nt-1     ]
-    // 0-3 are for off-diagonal tiles.
-    // 4-7 are the same as 0-3, respectively, but for diagonal tiles.
-    int64_t irange[4][2] = {
-        { 0,        A.mt()-1 },
-        { A.mt()-1, A.mt()   },
-        { 0,        A.mt()-1 },
-        { A.mt()-1, A.mt()   }
-    };
-    int64_t jrange[4][2] = {
-        { 0,        A.nt()-1 },
-        { 0,        A.nt()-1 },
-        { A.nt()-1, A.nt()   },
-        { A.nt()-1, A.nt()   }
-    };
+    int64_t mt = A.mt();
+    int64_t nt = A.nt();
+
+    // Find ranges of matching mb's.
+    std::vector< int64_t > irange;
+    int64_t last_mb = -1;
+    for (int64_t i = 0; i < mt; ++i) {
+        int64_t mb = A.tileMb( i );
+        if (mb != last_mb) {
+            last_mb = mb;
+            irange.push_back( i );
+        }
+    }
+    irange.push_back( mt );
+
+    // Find ranges of matching nb's.
+    std::vector< int64_t > jrange;
+    int last_nb = -1;
+    for (int64_t j = 0; j < nt; ++j) {
+        int64_t nb = A.tileNb( j );
+        if (nb != last_nb) {
+            last_nb = nb;
+            jrange.push_back( j );
+        }
+    }
+    jrange.push_back( nt );
 
     #pragma omp taskgroup
     for (int device = 0; device < A.num_devices(); ++device) {
-        #pragma omp task slate_omp_default_none \
-            shared( A ) priority( priority ) \
-            firstprivate( device, irange, jrange, queue_index ) \
-            firstprivate( offdiag_value, diag_value )
+        #pragma omp task priority( priority ) \
+            shared( A, irange, jrange ) \
+            firstprivate( device, queue_index, offdiag_value, diag_value )
         {
             // temporarily, convert both into same layout
             // todo: this is in-efficient, because both matrices may have same layout already
             //       and possibly wrong, because an input matrix is being altered
             // todo: best, handle directly through the CUDA kernels
-            auto layout = Layout::ColMajor;
+            auto layout = LayoutConvert::ColMajor;
             std::set<ij_tuple> A_tiles_set;
 
             if (A.uplo() == Uplo::Lower) {
                 for (int64_t j = 0; j < A.nt(); ++j) {
                     for (int64_t i = j; i < A.mt(); ++i) { // lower trapezoid
-                        if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) {
-                            A_tiles_set.insert({i, j});
+                        if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) {
+                            A_tiles_set.insert( {i, j} );
                         }
                     }
                 }
@@ -168,91 +172,111 @@ void set(
             else { // upper
                 for (int64_t j = 0; j < A.nt(); ++j) {
                     for (int64_t i = 0; i <= j && i < A.mt(); ++i) { // upper trapezoid
-                        if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) {
-                            A_tiles_set.insert({i, j});
+                        if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) {
+                            A_tiles_set.insert( {i, j} );
                        }
                     }
                 }
             }
-            A.tileGetForWriting(A_tiles_set, device, LayoutConvert(layout));
+            A.tileGetForWriting( A_tiles_set, device, layout );
 
             scalar_t** a_array_host = A.array_host( device );
             scalar_t** a_array_dev = A.array_device( device );
 
+            // Build batch groups
             int64_t batch_count = 0;
-            int64_t mb[8], nb[8], lda[8], group_count[8];
-            for (int q = 0; q < 4; ++q) {
-                group_count[q] = 0;
-                lda[q] = 0;
-                mb[q] = A.tileMb(irange[q][0]);
-                nb[q] = A.tileNb(jrange[q][0]);
-                if (A.uplo() == Uplo::Lower) {
-                    for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) {
-                        for (int64_t i = std::max(j, irange[q][0]); i < irange[q][1]; ++i) {
-                            if (i != j
-                                && A.tileIsLocal(i, j)
-                                && device == A.tileDevice(i, j)) {
-
-                                a_array_host[batch_count] = A(i, j, device).data();
-                                lda[q] = A(i, j, device).stride();
-                                ++group_count[q];
-                                ++batch_count;
-                            }
-                        }
-                    }
-                }
-                else { // upper
-                    for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) {
-                        for (int64_t i = irange[q][0]; i < irange[q][1] && i <= j; ++i) {
-                            if (i != j
-                                && A.tileIsLocal(i, j)
-                                && device == A.tileDevice(i, j)) {
-
-                                a_array_host[batch_count] = A(i, j, device).data();
-                                lda[q] = A(i, j, device).stride();
-                                ++group_count[q];
-                                ++batch_count;
-                            }
-                        }
-                    }
-                }
-            }
-            for (int q = 4; q < 8; ++q) {
-                group_count[q] = 0;
-                lda[q] = 0;
-                mb[q] = A.tileMb(irange[q-4][0]);
-                nb[q] = A.tileNb(jrange[q-4][0]);
-                if (A.uplo() == Uplo::Lower) {
-                    for (int64_t j = jrange[q-4][0]; j < jrange[q-4][1]; ++j) {
-                        for (int64_t i = std::max(j, irange[q-4][0]); i < irange[q-4][1]; ++i) {
-                            if (i == j
-                                && A.tileIsLocal(i, j)
-                                && device == A.tileDevice(i, j)) {
-
-                                a_array_host[batch_count] = A(i, j, device).data();
-                                lda[q] = A(i, j, device).stride();
-                                ++group_count[q];
-                                ++batch_count;
-                            }
-                        }
-                    }
-                }
-                else { // upper
-                    for (int64_t j = jrange[q-4][0]; j < jrange[q-4][1]; ++j) {
-                        for (int64_t i = irange[q-4][0]; i < irange[q-4][1] && i <= j; ++i) {
-                            if (i == j
-                                && A.tileIsLocal(i, j)
-                                && device == A.tileDevice(i, j)) {
-
-                                a_array_host[batch_count] = A(i, j, device).data();
-                                lda[q] = A(i, j, device).stride();
-                                ++group_count[q];
-                                ++batch_count;
-                            }
-                        }
-                    }
-                }
-            }
+            struct Params {
+                int64_t count, mb, nb, lda;
+                bool is_diagonal;
+            };
+            std::vector<Params> group_params;
+            // Build batch groups for off-diagonal tiles,
+            for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
+            for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
+                Params group = { 0, -1, -1, -1, false };
+                if (A.uplo() == Uplo::Lower) {
+                    for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) {
+                    for (int64_t i = std::max(irange[ ii ], j); i < irange[ ii+1 ]; ++i) {
+                        if (i != j
+                            && A.tileIsLocal(i, j)
+                            && device == A.tileDevice(i, j)) {
+
+                            auto Aij = A( i, j, device );
+                            a_array_host[ batch_count ] = Aij.data();
+                            if (group.count == 0) {
+                                group.mb = Aij.mb();
+                                group.nb = Aij.nb();
+                                group.lda = Aij.stride();
+                            }
+                            else {
+                                assert( group.mb == Aij.mb() );
+                                assert( group.nb == Aij.nb() );
+                                assert( group.lda == Aij.stride() );
+                            }
+                            ++group.count;
+                            ++batch_count;
+                        }
+                    }} // for j,i
+                }
+                else { // A.uplo() == Uplo::Upper
+                    for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) {
+                    for (int64_t i = irange[ ii ]; i < irange[ ii+1 ] && i <= j; ++i) {
+                        if (i != j
+                            && A.tileIsLocal(i, j)
+                            && device == A.tileDevice(i, j)) {
+
+                            auto Aij = A( i, j, device );
+                            a_array_host[ batch_count ] = Aij.data();
+                            if (group.count == 0) {
+                                group.mb = Aij.mb();
+                                group.nb = Aij.nb();
+                                group.lda = Aij.stride();
+                            }
+                            else {
+                                assert( group.mb == Aij.mb() );
+                                assert( group.nb == Aij.nb() );
+                                assert( group.lda == Aij.stride() );
+                            }
+                            ++group.count;
+                            ++batch_count;
+                        }
+                    }} // for j,i
+                }
+                if (group.count > 0) {
+                    group_params.push_back( group );
+                }
+            }} // for jj,ii
+
+            // Build batch groups for diagonal tiles,
+            for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
+            for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
+                Params group = { 0, -1, -1, -1, true };
+                int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]);
+                int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]);
+                for (int64_t ij = ijstart; ij < ijend; ++ij) {
+                    if (A.tileIsLocal( ij, ij )
+                        && device == A.tileDevice( ij, ij ))
+                    {
+                        auto Aij = A( ij, ij, device );
+                        a_array_host[ batch_count ] = Aij.data();
+                        if (group.count == 0) {
+                            group.mb = Aij.mb();
+                            group.nb = Aij.nb();
+                            group.lda = Aij.stride();
+                        }
+                        else {
+                            assert( group.mb == Aij.mb() );
+                            assert( group.nb == Aij.nb() );
+                            assert( group.lda == Aij.stride() );
+                        }
+                        ++group.count;
+                        ++batch_count;
+                    }
+                } // for ij
+                if (group.count > 0) {
+                    group_params.push_back( group );
+                }
+            }} // for jj,ii
 
             blas::Queue* queue = A.compute_queue(device, queue_index);
 
@@ -260,27 +284,28 @@ void set(
                 a_array_dev, a_array_host, batch_count,
                 blas::MemcpyKind::HostToDevice, *queue);
 
-            for (int q = 0; q < 4; ++q) {
-                if (group_count[q] > 0) {
-                    device::batch::geset(
-                        mb[q], nb[q],
-                        offdiag_value, offdiag_value,
-                        a_array_dev, lda[q],
-                        group_count[q], *queue);
-                    a_array_dev += group_count[q];
-                }
-            }
-            for (int q = 4; q < 8; ++q) {
-                if (group_count[q] > 0) {
+            for (size_t g = 0; g < group_params.size(); ++g) {
+                int64_t group_count = group_params[ g ].count;
+
+                if (group_params[ g ].is_diagonal) {
                     device::batch::tzset(
-                        A.uplo(), mb[q], nb[q],
+                        A.uplo(),
+                        group_params[ g ].mb,
+                        group_params[ g ].nb,
                         offdiag_value, diag_value,
-                        a_array_dev, lda[q],
-                        group_count[q], *queue);
-                    a_array_dev += group_count[q];
+                        a_array_dev, group_params[ g ].lda,
+                        group_count, *queue);
+                }
+                else {
+                    device::batch::geset(
+                        group_params[ g ].mb,
+                        group_params[ g ].nb,
+                        offdiag_value, offdiag_value,
+                        a_array_dev, group_params[ g ].lda,
+                        group_count, *queue );
                 }
+                a_array_dev += group_count;
             }
-
             queue->sync();
         }
     }
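Both files now consume the grouped batch the same way: the host pointer array is filled group after group, so each launch uses the next count pointers and the working pointer advances by that count, dispatching the triangular kernel for diagonal groups and the general one otherwise. A simplified stand-alone sketch of that pattern; set_tiles is a hypothetical stand-in for device::batch::geset/tzset and the groups are made up:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct Group {
        int64_t count, mb, nb, lda;
        bool is_diagonal;
    };

    // Hypothetical stand-in for a batched device kernel launch.
    void set_tiles( Group const& g, double** tiles )
    {
        printf( "launch %s batch: %lld tiles of %lld x %lld, lda %lld\n",
                g.is_diagonal ? "diagonal" : "off-diagonal",
                (long long) g.count, (long long) g.mb, (long long) g.nb,
                (long long) g.lda );
    }

    int main()
    {
        std::vector<Group> groups = { { 7, 256, 256, 256, false },
                                      { 3, 256, 256, 256, true  } };
        std::vector<double*> tile_ptrs( 10, nullptr );  // 7 + 3 pointers, in group order
        double** a_array = tile_ptrs.data();
        for (auto const& g : groups) {
            set_tiles( g, a_array );  // uses a_array[0 .. g.count-1]
            a_array += g.count;       // advance to the next group's pointers
        }
        return 0;
    }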