Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add device regions to non-banded routines #140

Merged
merged 35 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
d15b78f
Minor touch ups to set
neil-lindquist Oct 19, 2023
003269f
Add regions to device scale
neil-lindquist Oct 19, 2023
e4354de
Add regions to device copy
neil-lindquist Oct 19, 2023
b96b390
Add regions to device scale_row_col
neil-lindquist Oct 19, 2023
8adc605
Add regions to device add
neil-lindquist Oct 19, 2023
0a85dac
Refactor out logic to compute device ranges
neil-lindquist Oct 20, 2023
7f156c2
Refactor duplicate logic for constructing device regions
neil-lindquist Oct 20, 2023
2a3bb58
Refactor out more code duplication
neil-lindquist Oct 20, 2023
6561d7a
Add regions to device gemm
neil-lindquist Oct 23, 2023
d219186
Refactor generic regions code to handle scale_row_col and copy
neil-lindquist Oct 24, 2023
99d29b2
Add regions to device trmm and trsm
neil-lindquist Oct 24, 2023
506fbab
FIXUP bcast in region builder
neil-lindquist Oct 24, 2023
d1cdc79
Add device regions to herk and her2k
neil-lindquist Oct 24, 2023
b942591
Move device region setup code to internal_batch.hh
neil-lindquist Oct 24, 2023
6356d1d
Add device regions to syrk and syr2k
neil-lindquist Oct 24, 2023
6fad79f
Use the batch arrays in genorm
neil-lindquist Oct 19, 2023
c41f947
Start relaxing assumption that tile?b(0) is the largest mb/nb in norm…
neil-lindquist Oct 20, 2023
cafffa9
Add device regions to non-band norms
neil-lindquist Oct 26, 2023
cf8d66b
Move an argument for device regions into a template
neil-lindquist Oct 26, 2023
264df79
Fix non-band norms for variable block sizes
neil-lindquist Oct 27, 2023
b7cfdcb
Improve batch regions docs and remove some unneeded includes
neil-lindquist Oct 27, 2023
1a35bfa
Add device regions to gemmA
neil-lindquist Oct 27, 2023
4e174f1
Fix column norms and reduce norm overheads
neil-lindquist Oct 31, 2023
e284fed
Fix indexing mistake in norms
neil-lindquist Nov 1, 2023
ecb47d9
Better follow line limits
neil-lindquist Nov 1, 2023
1aa7064
Add device regions to trsmA
neil-lindquist Nov 2, 2023
201e4b8
Add device regions to he2hb_gemm
neil-lindquist Nov 2, 2023
0d2305d
Clean up some he2hb helper functions
neil-lindquist Nov 2, 2023
646d46a
Add device regions to he2hb trmm
neil-lindquist Nov 2, 2023
c96a391
Add device regions to he2hb_her2k_offdiag
neil-lindquist Nov 2, 2023
c20af46
Relax some uniform tile size assumptions
neil-lindquist Nov 7, 2023
6dc04d2
Improve docs
neil-lindquist Nov 16, 2023
4a2e569
Use RowCol instead of some bool flags
neil-lindquist Nov 16, 2023
fd21fc9
Cleanup various mistakes pointed out in review
neil-lindquist Nov 22, 2023
8da4956
Fix one more code alignment mistake
neil-lindquist Nov 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/slate/enums.hh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ typedef blas::Side Side;
typedef blas::Layout Layout;

using lapack::Equed;
using lapack::RowCol;
typedef lapack::Norm Norm;
typedef lapack::Direction Direction;

Expand Down
18 changes: 9 additions & 9 deletions include/slate/internal/device.hh
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,6 @@ void gescale(
scalar_t* A, int64_t lda,
blas::Queue& queue);

//------------------------------------------------------------------------------
template <typename scalar_t>
void tzscale(
Uplo uplo,
int64_t m, int64_t n,
blas::real_type<scalar_t> numer, blas::real_type<scalar_t> denom,
scalar_t** Aarray, int64_t lda,
int64_t batch_count, blas::Queue& queue);

//------------------------------------------------------------------------------
template <typename scalar_t, typename scalar_t2>
void gescale_row_col_batch(
Expand Down Expand Up @@ -175,6 +166,15 @@ void gescale(
scalar_t** Aarray, int64_t lda,
int64_t batch_count, blas::Queue& queue);

//------------------------------------------------------------------------------
template <typename scalar_t>
void tzscale(
Uplo uplo,
int64_t m, int64_t n,
blas::real_type<scalar_t> numer, blas::real_type<scalar_t> denom,
scalar_t** Aarray, int64_t lda,
int64_t batch_count, blas::Queue& queue);

//------------------------------------------------------------------------------
template <typename scalar_t>
void geadd(
Expand Down
4 changes: 4 additions & 0 deletions src/cuda/device_tzscale.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ __global__ void tzscale_kernel(
}
}

//==============================================================================
namespace batch {

//------------------------------------------------------------------------------
/// Batched routine for element-wise trapezoidal tile scale.
/// Sets upper or lower part of
Expand Down Expand Up @@ -169,5 +172,6 @@ void tzscale(
batch_count, queue );
}

} // namespace batch
} // namespace device
} // namespace slate
7 changes: 3 additions & 4 deletions src/he2hb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ void he2hb(
const bool set_hold = 1; // Do tileGetAndHold in the bcast

int64_t n = A.n();
int64_t nb_A = A.tileNb( 0 );
GridOrder grid_order;
int nprow, npcol, myrow, mycol;
// For the workspace matrix (W) change the GPU distribution to row cyclic,
Expand All @@ -82,8 +81,8 @@ void he2hb(
A.gridinfo( &grid_order, &nprow, &npcol, &myrow, &mycol );
assert( grid_order == GridOrder::Col ); // todo: update for Row

auto tileNb = slate::func::uniform_blocksize( n, nb_A );
auto tileRank = slate::func::process_2d_grid( GridOrder::Col, nprow, npcol );
auto tileNb = A.tileNbFunc();
auto tileRank = A.tileRankFunc();
int num_devices = blas::get_device_count();
auto tileDevice = slate::func::device_1d_grid( GridOrder::Col,
nprow, num_devices );
Expand Down Expand Up @@ -135,7 +134,7 @@ void he2hb(

lapack::Queue* queue = A.compute_queue( panel_device, queue_0 );

int64_t nb = A.tileNb(0);
int64_t nb = func::max_blocksize(A.nt(), A.tileNbFunc());
mgates3 marked this conversation as resolved.
Show resolved Hide resolved
size_t size_tau = (size_t) std::min( mlocal, nb );
size_t size_A = (size_t) blas::max( 1, mlocal ) * nb;
size_t hsize, dsize;
Expand Down
4 changes: 4 additions & 0 deletions src/hip/device_tzscale.hip.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ __global__ void tzscale_kernel(
}
}

//==============================================================================
namespace batch {

//------------------------------------------------------------------------------
/// Batched routine for element-wise trapezoidal tile scale.
/// Sets upper or lower part of
Expand Down Expand Up @@ -170,5 +173,6 @@ void tzscale(
batch_count, queue );
}

} // namespace batch
} // namespace device
} // namespace slate
2 changes: 1 addition & 1 deletion src/hip/device_tzscale.hip.cc.dep
Original file line number Diff line number Diff line change
@@ -1 +1 @@
95be14909da63e90e9d7d888f5b5e1bd src/cuda/device_tzscale.cu
2f13aaf1009fad8799225aa36791f3d8 src/cuda/device_tzscale.cu
250 changes: 249 additions & 1 deletion src/internal/internal_batch.hh
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@

//------------------------------------------------------------------------------
/// @file
/// Provides various helper functions for batched routines.
///
/// Provides simple precision-independent wrappers around MKL batch
/// routines. Eventually to be replaced by BLAS++ batch routines.
///
/// Provides routines to build the batch regions for device batched kernels.
#ifndef SLATE_INTERNAL_BATCH_HH
#define SLATE_INTERNAL_BATCH_HH

#include "slate/Exception.hh"
#include "slate/BaseMatrix.hh"

#include <blas.hh>

Expand Down Expand Up @@ -146,7 +151,250 @@ inline void cblas_gemm_batch(
}
#endif // BLAS_HAVE_MKL

} // namespace slate

// Utilities for computing device batch regions

//------------------------------------------------------------------------------
/// Computes the range of tiles with either the same mb or the same nb.
///
/// @param[in] dim
///     Whether to compute the row ranges or the column ranges
///
/// @param[in] A
///     The matrix to get tile sizes from
///
/// @return The ranges of uniform tile sizes
///
template<typename scalar_t>
std::vector<int64_t> device_regions_range( RowCol dim, BaseMatrix<scalar_t>& A )
{
    bool by_rows = (dim == RowCol::Row);

    int64_t num_tiles = by_rows ? A.mt() : A.nt();

    std::vector<int64_t> boundaries;
    int64_t prev_size = -1;
    for (int64_t k = 0; k < num_tiles; ++k) {
        int64_t size_k = by_rows ? A.tileMb( k ) : A.tileNb( k );
        if (size_k != prev_size) {
            // A new run of uniform tile sizes starts at tile k.
            prev_size = size_k;
            boundaries.push_back( k );
        }
    }
    // Close the final range.
    boundaries.push_back( num_tiles );
    return boundaries;
}

//------------------------------------------------------------------------------
/// Helper class to store the information on a device region.
///
/// @tparam store_diag
///     Whether the diagonal tiles may need to be special cased
///
/// @tparam mat_count
///     The number of matrices used by the kernel
///
template< bool store_diag, int mat_count >
struct device_regions_params {
    int64_t count = 0;           // number of tiles accumulated into this group
    int64_t mb = 0;              // row size shared by all tiles in the group
    int64_t nb = 0;              // column size shared by all tiles in the group
    int64_t ld[mat_count] = {0}; // leading dimension per matrix for this group

private:
    // Zero-size placeholder so that no storage is spent on is_diagonal
    // when store_diag is false.
    struct Empty {};
public:
    // Marks groups made of diagonal tiles; only a real bool when store_diag.
    std::conditional_t< store_diag, bool, Empty > is_diagonal = {};

    device_regions_params() = default;
};

//------------------------------------------------------------------------------
/// @copydoc device_regions_build(std::array< std::reference_wrapper<BaseMatrix<scalar_t>>, mat_count >, std::array< scalar_t**, mat_count >, int64_t, std::function<void(int64_t, int64_t, int64_t)>)
///
/// @params[in] irange
/// The ranges of tiles with a uniform number of rows
///
neil-lindquist marked this conversation as resolved.
Show resolved Hide resolved
/// @params[in] jrange
/// The ranges of tiles with a uniform number of columns
///
neil-lindquist marked this conversation as resolved.
Show resolved Hide resolved
template< bool store_diag, int mat_count, typename scalar_t, bool diag_same=!store_diag >
std::vector< device_regions_params<store_diag, mat_count> > device_regions_build(
std::array< std::reference_wrapper<BaseMatrix<scalar_t>>, mat_count > mats,
std::array< scalar_t**, mat_count > mats_array_host,
int64_t device,
std::function<void(int64_t, int64_t, int64_t)> extra_setup,
std::vector<int64_t>& irange,
std::vector<int64_t>& jrange)
{
// The first two arguments should be valid targets for brace-initialization
// reference_wrapper works around fact that C++ doesn't allow array of references

using Params = device_regions_params<store_diag, mat_count>;

auto& A = mats[0].get();

// Trapezoidal matrices always need special treatment for diagonal tiles
assert( !diag_same || A.uplo() == Uplo::General );

static_assert( diag_same || store_diag,
"Can't special case the diagonal when is_diagonal is not allocated" );

// Size 1 dimensions get broadcast to allow setting up GEMM et al.
// i_step[m]=0 results in only accessing row 0 of matrix m (likewise for j)
// The first matrix is always indexed normally since it determines the loops
int64_t i_step[mat_count];
int64_t j_step[mat_count];
i_step[0] = 1;
j_step[0] = 1;
for (int m = 1; m < mat_count; ++m) {
i_step[m] = (mats[ m ].get().mt() > 1);
j_step[m] = (mats[ m ].get().nt() > 1);
}

mgates3 marked this conversation as resolved.
Show resolved Hide resolved
int64_t batch_count = 0;
int64_t mt = A.mt();
std::vector<Params> group_params;
// loop over regions
for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
// Loop over the tiles in this region,
// save any that should be computed on this process & device
Params group;
group.mb = A.tileMb( irange[ ii ] );
group.nb = A.tileNb( jrange[ jj ] );
for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) {
mgates3 marked this conversation as resolved.
Show resolved Hide resolved
// This is a column major loop. So,
// * Lower matrices start at j+1
// * Upper matrices end at j
// * General matrices run the whole range
int64_t istart = std::max(irange[ ii ], (A.uplo() == Uplo::Lower ? j+1 : 0));
int64_t iend = std::min(irange[ ii+1 ], (A.uplo() == Uplo::Upper ? j : mt));
for (int64_t i = istart; i < iend; ++i) {
if ((diag_same || i != j)
&& A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) {

// Add tiles to current group
for (int m = 0; m < mat_count; ++m) {
auto Mij = mats[ m ].get()( i*i_step[m], j*j_step[m], device );
mats_array_host[ m ][ batch_count ] = Mij.data();
if (group.count == 0) {
group.ld[m] = Mij.stride();
}
else {
assert( group.ld[m] == Mij.stride() );
}
}
if (extra_setup) {
extra_setup( group_params.size(), i, j );
}
++group.count;
++batch_count;
}
} // for i
} // for j
// If any tiles in the region should be computed here, save the group
if (group.count > 0) {
group_params.push_back( group );
}

// If the diagonal tiles need special treatment, build those groups
if constexpr (store_diag && !diag_same) {
// Loop over the diagonal tiles in this region. If any should be
// computed on this process & device, save them.
group = Params();
group.is_diagonal = true;
group.mb = A.tileMb( irange[ ii ] );
group.nb = A.tileNb( jrange[ jj ] );
// Diagonal tiles only in the intersection of irange and jrange
int64_t ijstart = std::max(irange[ ii ], jrange[ jj ]);
int64_t ijend = std::min(irange[ ii+1 ], jrange[ jj+1 ]);
for (int64_t ij = ijstart; ij < ijend; ++ij) {
if (A.tileIsLocal( ij, ij )
&& device == A.tileDevice( ij, ij )) {

// Add tiles to current group
// This logic matches that of above
for (int m = 0; m < mat_count; ++m) {
auto Mij = mats[ m ].get()( ij*i_step[m], ij*j_step[m], device );
mats_array_host[ m ][ batch_count ] = Mij.data();
if (group.count == 0) {
group.ld[m] = Mij.stride();
}
else {
assert( group.ld[m] == Mij.stride() );
}
}
if (extra_setup) {
extra_setup( group_params.size(), ij, ij );
}
++group.count;
++batch_count;
}
} // for ij
// If any tiles in the region should be computed here, save the group
if (group.count > 0) {
group_params.push_back( group );
}
} // if store_diag && !diag_same
}} // for jj, ii
return group_params;
}

//------------------------------------------------------------------------------
/// Computes and populates the regions for the given matrices.
///
/// @tparam store_diag
///     Whether the diagonal tiles may need to be special cased
///
/// @tparam mat_count
///     The number of matrices used by the kernel
///
/// @tparam scalar_t
///     The type of the matrices
///
/// @tparam diag_same
///     Whether to include the diagonal tiles in the off-diagonal groups.
///     If false, store_diag must be true.
//------------------------------------------------------------------------------
/// @param[in] mats
///     An array of the matrices to build regions for
///
/// @param[in] mats_array_host
///     An array of the arrays to fill with pointers to device data
///
/// @param[in] device
///     The device to build regions for
///
/// @param[in] extra_setup
///     Callback that is called whenever a tile is added to a group.
///     The group index and the tile indices are passed as arguments
///
/// @return A list of batches with identical size.
///
template< bool store_diag, int mat_count, typename scalar_t, bool diag_same=!store_diag >
std::vector< device_regions_params<store_diag, mat_count> > device_regions_build(
    std::array< std::reference_wrapper<BaseMatrix<scalar_t>>, mat_count > mats,
    std::array< scalar_t**, mat_count > mats_array_host,
    int64_t device,
    std::function<void(int64_t, int64_t, int64_t)> extra_setup = {})
{
    auto& A = mats[ 0 ].get();

    // Partition the rows and columns into runs of matching mb's and nb's.
    auto row_ranges = device_regions_range( RowCol::Row, A );
    auto col_ranges = device_regions_range( RowCol::Col, A );

    // Delegate to the overload taking precomputed ranges.
    return device_regions_build< store_diag, mat_count, scalar_t, diag_same >(
            mats, mats_array_host, device, extra_setup,
            row_ranges, col_ranges );
}


} // namespace internal
} // namespace slate

#endif // SLATE_INTERNAL_BATCH_HH
Loading