Remove tile life from remaining routines #150

Merged
merged 13 commits on Dec 7, 2023
23 changes: 12 additions & 11 deletions include/slate/BaseMatrix.hh
@@ -472,9 +472,9 @@ public:

/// Decrements the number of times the tile {i, j} is received
/// through MPI.
-void tileDecrementReceiveCount(int64_t i, int64_t j)
+void tileDecrementReceiveCount( int64_t i, int64_t j, int64_t release_count = 1 )
{
-storage_->tileDecrementReceiveCount( globalIndex( i, j ) );
+storage_->tileDecrementReceiveCount( globalIndex( i, j ), release_count );
}

void tileErase( int64_t i, int64_t j, int device=HostNum );
@@ -626,9 +626,9 @@ public:
void releaseLocalWorkspace();
void releaseLocalWorkspace( std::set<ij_tuple>& tile_set );

-void releaseRemoteWorkspaceTile( int64_t i, int64_t j );
-void releaseRemoteWorkspace();
-void releaseRemoteWorkspace( std::set<ij_tuple>& tile_set );
+void releaseRemoteWorkspaceTile( int64_t i, int64_t j, int64_t release_count = 1 );
+void releaseRemoteWorkspace( int64_t release_count = 1 );
+void releaseRemoteWorkspace( std::set<ij_tuple>& tile_set, int64_t release_count = 1 );

/// Removes all temporary host and device workspace tiles from matrix.
/// WARNING: currently, this clears the entire parent matrix,
@@ -3952,15 +3952,16 @@ void BaseMatrix<scalar_t>::releaseLocalWorkspace(
/// reaches zero, the tile is erased. Otherwise, tile is not erased.
///
template <typename scalar_t>
-void BaseMatrix<scalar_t>::releaseRemoteWorkspaceTile(int64_t i, int64_t j)
+void BaseMatrix<scalar_t>::releaseRemoteWorkspaceTile(
+int64_t i, int64_t j, int64_t release_count )
{
if (! tileIsLocal( i, j )) { // erase remote tiles
// This lock ensures that no other thread is trying to
// remove this tile from the map of tiles.
LockGuard guard( storage_->getTilesMapLock() );

if (tileExists( i, j, AnyDevice )) {
-tileDecrementReceiveCount( i, j );
+tileDecrementReceiveCount( i, j, release_count );
if (tileReceiveCount( i, j ) <= 0) {
tileRelease( i, j, AllDevices );
}
@@ -3973,11 +3974,11 @@ void BaseMatrix<scalar_t>::releaseRemoteWorkspaceTile(int64_t i, int64_t j)
/// including host, if not on hold or modified.
///
template <typename scalar_t>
-void BaseMatrix<scalar_t>::releaseRemoteWorkspace()
+void BaseMatrix<scalar_t>::releaseRemoteWorkspace( int64_t release_count )
{
for (int64_t j = 0; j < nt(); ++j) {
for (int64_t i = 0; i < mt(); ++i) {
-releaseRemoteWorkspaceTile( i, j );
+releaseRemoteWorkspaceTile( i, j, release_count );
}
}
}
@@ -3991,12 +3992,12 @@ void BaseMatrix<scalar_t>::releaseRemoteWorkspace()
///
template <typename scalar_t>
void BaseMatrix<scalar_t>::releaseRemoteWorkspace(
-std::set<ij_tuple>& tile_set)
+std::set<ij_tuple>& tile_set, int64_t release_count )
{
for (auto ij : tile_set) {
int64_t i = std::get<0>( ij );
int64_t j = std::get<1>( ij );
-releaseRemoteWorkspaceTile( i, j );
+releaseRemoteWorkspaceTile( i, j, release_count );
}
}

4 changes: 2 additions & 2 deletions include/slate/internal/MatrixStorage.hh
@@ -385,10 +385,10 @@ public:

//--------------------------------------------------------------------------
/// Decrement tile's receive counter.
-void tileDecrementReceiveCount(ij_tuple ij)
+void tileDecrementReceiveCount( ij_tuple ij, int64_t release_count = 1 )
{
LockGuard guard( getTilesMapLock() );
-tiles_.at( ij )->receiveCount()--;
+tiles_.at( ij )->receiveCount() -= release_count;
}

/// Ensures the tile node exists and increments both the tile life and
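Note: the two header changes above implement one mechanism. Each remote workspace tile carries a receive count (how many local uses are still expected), and the new release_count parameter retires several expected uses in a single guarded decrement instead of a loop of single decrements. A minimal, self-contained sketch of that idea, with toy types rather than SLATE's actual classes:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <mutex>
#include <utility>

using ij_tuple = std::pair<int64_t, int64_t>;

class ToyStorage {
public:
    void insert( ij_tuple ij, int64_t receive_count )
    {
        std::lock_guard<std::mutex> guard( tiles_mutex_ );
        tiles_[ ij ] = receive_count;
    }

    // Mirrors the shape of releaseRemoteWorkspaceTile: one lock protects
    // the find-decrement-erase sequence against concurrent cleanup tasks.
    void release( ij_tuple ij, int64_t release_count = 1 )
    {
        std::lock_guard<std::mutex> guard( tiles_mutex_ );
        auto it = tiles_.find( ij );
        if (it == tiles_.end())
            return;                   // tile already erased
        it->second -= release_count;  // bulk decrement, as in this PR
        if (it->second <= 0)
            tiles_.erase( it );       // stand-in for tileRelease
    }

    bool exists( ij_tuple ij )
    {
        std::lock_guard<std::mutex> guard( tiles_mutex_ );
        return tiles_.count( ij ) != 0;
    }

private:
    std::mutex tiles_mutex_;
    std::map<ij_tuple, int64_t> tiles_;
};

int main()
{
    ToyStorage storage;
    storage.insert( {0, 1}, 3 );   // tile (0,1) expects 3 uses
    storage.release( {0, 1}, 3 );  // retire all 3 in one call
    std::cout << ( storage.exists( {0, 1} ) ? "alive\n" : "released\n" );
}
```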
26 changes: 24 additions & 2 deletions src/gbmm.cc
@@ -42,10 +42,18 @@ void gbmm(
const Layout layout = Layout::ColMajor;

const scalar_t one = 1.0;
+const int64_t priority_0 = 0;
+const int64_t queue_0 = 0;

// Options
int64_t lookahead = get_option<int64_t>( opts, Option::Lookahead, 1 );

+// Use only TileReleaseStrategy::Slate for gbmm.
+// Internal gbmm routine called here won't release
+// any tiles. This routine will clean up tiles.
+Options opts2 = opts;
+opts2[ Option::TileReleaseStrategy ] = TileReleaseStrategy::Slate;

// OpenMP needs pointer types, but vectors are exception safe
std::vector<uint8_t> bcast_vector(A.nt());
std::vector<uint8_t> gemm_vector(A.nt());
@@ -124,7 +132,7 @@ void gbmm(
alpha, A.sub(i_begin, i_end-1, 0, 0),
B.sub(0, 0, 0, B.nt()-1),
beta, C.sub(i_begin, i_end-1, 0, C.nt()-1),
-layout);
+layout, priority_0, queue_0, opts2 );

if (beta != one) {
// Scale block rows of C below the bandwidth of A:
@@ -189,7 +197,21 @@ void gbmm(
alpha, A.sub(i_begin, i_end-1, k, k),
B.sub(k, k, 0, B.nt()-1),
one, C.sub(i_begin, i_end-1, 0, C.nt()-1),
-layout);
+layout, priority_0, queue_0, opts2 );
}

+#pragma omp task depend(in:gemm[k])
+{
+auto A_colblock = A.sub(i_begin, i_end-1, k, k);
+auto B_rowblock = B.sub(k, k, 0, B.nt()-1);
+
+// Erase remote tiles on all devices including host
+A_colblock.releaseRemoteWorkspace();
+B_rowblock.releaseRemoteWorkspace();
+
+// Erase local workspace on devices.
+A_colblock.releaseLocalWorkspace();
+B_rowblock.releaseLocalWorkspace();
+}
}
}
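The new block in gbmm.cc above attaches a cleanup task to the same dependency slot as the gemm update. A hedged sketch of just that scheduling pattern (printf placeholders stand in for the internal::gemm call and the workspace releases): the task reading gemm[k] cannot start until the task writing gemm[k] finishes, so tiles are released only after the update that consumed them. Compile with -fopenmp.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    const int64_t nt = 4;
    // OpenMP depend() clauses need addresses, hence the vector-backed pointer.
    std::vector<uint8_t> gemm_vector( nt );
    uint8_t* gemm = gemm_vector.data();

    #pragma omp parallel
    #pragma omp master
    for (int64_t k = 0; k < nt; ++k) {
        #pragma omp task depend(out:gemm[k])
        {
            // In gbmm: the internal::gemm update for block column k.
            std::printf( "update for k = %lld\n", (long long) k );
        }
        #pragma omp task depend(in:gemm[k])
        {
            // Runs strictly after the update; in gbmm this is where
            // releaseRemoteWorkspace() and releaseLocalWorkspace() go.
            std::printf( "release workspace for k = %lld\n", (long long) k );
        }
    }
    return 0;
}
```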
36 changes: 32 additions & 4 deletions src/gbtrf.cc
@@ -35,6 +35,7 @@ int64_t gbtrf(
const int priority_0 = 0;
const int priority_1 = 1;
const int tag_0 = 0;
+const int queue_0 = 0;
// Assumes column major
const Layout layout = Layout::ColMajor;

@@ -47,6 +48,12 @@ int64_t gbtrf(
max_panel_threads = get_option<int64_t>( opts, Option::MaxPanelThreads,
max_panel_threads );

+// Use only TileReleaseStrategy::Slate for gbtrf.
+// Internal routines called here won't release any
+// tiles. This routine will clean up tiles.
+Options opts2 = opts;
+opts2[ Option::TileReleaseStrategy ] = TileReleaseStrategy::Slate;

int64_t info = 0;
int64_t A_nt = A.nt();
int64_t A_mt = A.mt();
@@ -146,7 +153,8 @@ int64_t gbtrf(
// solve A(k, k) A(k, j) = A(k, j)
internal::trsm<Target::HostTask>(
Side::Left,
-one, std::move( Tkk ), A.sub(k, k, j, j), priority_1 );
+one, std::move( Tkk ), A.sub(k, k, j, j),
+priority_1, layout, queue_0, opts2 );

// send A(k, j) across column A(k+1:mt-1, j)
A.tileBcast(k, j, A.sub(k+1, i_end-1, j, j), layout, tag_j);
@@ -156,7 +164,7 @@ int64_t gbtrf(
-one, A.sub(k+1, i_end-1, k, k),
A.sub(k, k, j, j),
one, A.sub(k+1, i_end-1, j, j),
-layout, priority_1 );
+layout, priority_1, queue_0, opts2 );
}
}
// Update trailing submatrix, normal priority.
@@ -181,7 +189,8 @@ int64_t gbtrf(
internal::trsm<Target::HostTask>(
Side::Left,
one, std::move( Tkk ),
-A.sub(k, k, k+1+lookahead, j_end-1));
+A.sub(k, k, k+1+lookahead, j_end-1),
+priority_0, layout, queue_0, opts2 );

// send A(k, kl+1:j_end-1) across A(k+1:mt-1, kl+1:nt-1)
BcastList bcast_list_A;
@@ -196,9 +205,28 @@ int64_t gbtrf(
-one, A.sub(k+1, i_end-1, k, k),
A.sub(k, k, k+1+lookahead, j_end-1),
one, A.sub(k+1, i_end-1, k+1+lookahead, j_end-1),
-layout);
+layout, priority_0, queue_0, opts2 );
}
}

+#pragma omp task depend(inout:column[k])
+{
+auto left_panel = A.sub( k, i_end-1, k, k );
+auto top_panel = A.sub( k, k, k+1, j_end-1 );
+
+// Erase remote tiles on all devices, including host
+left_panel.releaseRemoteWorkspace();
+top_panel.releaseRemoteWorkspace();
+
+// Update the origin tiles before their
+// workspace copies on devices are erased.
+left_panel.tileUpdateAllOrigin();
+top_panel.tileUpdateAllOrigin();
+
+// Erase local workspace on devices.
+left_panel.releaseLocalWorkspace();
+top_panel.releaseLocalWorkspace();
+}
}

#pragma omp taskwait
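The gbtrf cleanup task above calls tileUpdateAllOrigin() before releaseLocalWorkspace(), and the ordering is the point. A toy model of why (invented types; SLATE's real coherence tracking is more involved): a modified device copy must be flushed to the origin tile before the copy is discarded, or the panel update would be lost.

```cpp
#include <cassert>
#include <optional>

// One tile: an origin buffer plus an optional device workspace copy.
struct ToyTile {
    int origin = 0;
    std::optional<int> device;
    bool device_modified = false;

    // Like tileUpdateAllOrigin(): flush a modified copy back to the origin.
    void update_origin()
    {
        if (device && device_modified) {
            origin = *device;
            device_modified = false;
        }
    }

    // Like releaseLocalWorkspace(): drop the device copy. Dropping a
    // still-modified copy would silently lose the computed update.
    void release_workspace()
    {
        assert( ! device_modified );
        device.reset();
    }
};

int main()
{
    ToyTile t;
    t.device = 42;           // a device kernel produced a new value
    t.device_modified = true;
    t.update_origin();       // flush first...
    t.release_workspace();   // ...then it is safe to discard the copy
    assert( t.origin == 42 );
}
```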
3 changes: 3 additions & 0 deletions src/gemmC.cc
@@ -176,7 +176,10 @@ void gemmC(
B.sub(k, k, 0, B.nt()-1),
one, std::move( C ),
layout, priority_0, queue_0, opts2 );
+}

+#pragma omp task depend(in:gemm[k])
+{
auto A_colblock = A.sub(0, A.mt()-1, k, k);
auto B_rowblock = B.sub(k, k, 0, B.nt()-1);

52 changes: 27 additions & 25 deletions src/getrf.cc
@@ -30,7 +30,6 @@ int64_t getrf(

// Constants
const scalar_t one = 1.0;
-const int life_1 = 1;
const int priority_0 = 0;
const int priority_1 = 1;
const int queue_0 = 0;
@@ -45,6 +44,11 @@ int64_t getrf(
max_panel_threads = get_option<int64_t>( opts, Option::MaxPanelThreads,
max_panel_threads );

+// Use only TileReleaseStrategy::Slate for getrf.
+// Internal routines won't release any tiles. getrf will clean up tiles.
+Options opts2 = Options( opts );
+opts2[ Option::TileReleaseStrategy ] = TileReleaseStrategy::Slate;

// Host can use Col/RowMajor for row swapping,
// RowMajor is slightly more efficient.
// Layout host_layout = Layout::RowMajor;
@@ -63,8 +67,6 @@ int64_t getrf(
int64_t min_mt_nt = std::min(A.mt(), A.nt());
pivots.resize(min_mt_nt);

-bool is_shared = target == Target::Devices && lookahead > 0;

// OpenMP needs pointer types, but vectors are exception safe
std::vector< uint8_t > column_vector(A_nt);
uint8_t* column = column_vector.data();
@@ -110,7 +112,7 @@ int64_t getrf(
bcast_list_A.push_back({i, k, {A.sub(i, i, k+1, A_nt-1)}});
}
A.template listBcast<target>(
-bcast_list_A, Layout::ColMajor, tag_k, life_1, is_shared );
+bcast_list_A, Layout::ColMajor, tag_k );

// Root broadcasts the pivot to all ranks.
// todo: Panel ranks send the pivots to the right.
@@ -142,7 +144,7 @@ int64_t getrf(
internal::trsm<target>(
Side::Left,
one, std::move( Tkk ), A.sub(k, k, j, j),
-priority_1, Layout::ColMajor, queue_jk1 );
+priority_1, Layout::ColMajor, queue_jk1, opts2 );

// send A(k, j) across column A(k+1:mt-1, j)
// todo: trsm still operates in ColMajor
@@ -153,7 +155,7 @@ int64_t getrf(
-one, A.sub(k+1, A_mt-1, k, k),
A.sub(k, k, j, j),
one, A.sub(k+1, A_mt-1, j, j),
-target_layout, priority_1, queue_jk1 );
+target_layout, priority_1, queue_jk1, opts2 );
}
}
// pivot to the left
@@ -199,7 +201,7 @@ int64_t getrf(
Side::Left,
one, std::move( Tkk ),
A.sub(k, k, k+1+lookahead, A_nt-1),
-priority_0, Layout::ColMajor, queue_1 );
+priority_0, Layout::ColMajor, queue_1, opts2 );

// send A(k, kl+1:A_nt-1) across A(k+1:mt-1, kl+1:nt-1)
BcastList bcast_list_A;
@@ -216,26 +218,26 @@ int64_t getrf(
-one, A.sub(k+1, A_mt-1, k, k),
A.sub(k, k, k+1+lookahead, A_nt-1),
one, A.sub(k+1, A_mt-1, k+1+lookahead, A_nt-1),
-target_layout, priority_0, queue_1 );
+target_layout, priority_0, queue_1, opts2 );
}
}
-if (is_shared) {
-#pragma omp task depend(inout:column[k])
-{
-for (int64_t i = k+1; i < A_mt; ++i) {
-if (A.tileIsLocal(i, k)) {
-A.tileUpdateOrigin(i, k);
-
-std::set<int> dev_set;
-A.sub(i, i, k+1, A_nt-1).getLocalDevices(&dev_set);
-
-for (auto device : dev_set) {
-A.tileUnsetHold(i, k, device);
-A.tileRelease(i, k, device);
-}
-}
-}
-}
-}
+#pragma omp task depend(inout:column[k])
+{
+auto left_panel = A.sub( k, A_mt-1, k, k );
+auto top_panel = A.sub( k, k, k+1, A_nt-1 );
+
+// Erase remote tiles on all devices including host
+left_panel.releaseRemoteWorkspace();
+top_panel.releaseRemoteWorkspace();
+
+// Update the origin tiles before their
+// workspace copies on devices are erased.
+left_panel.tileUpdateAllOrigin();
+top_panel.tileUpdateAllOrigin();
+
+// Erase local workspace on devices.
+left_panel.releaseLocalWorkspace();
+top_panel.releaseLocalWorkspace();
+}
kk += A.tileNb( k );
}
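gbmm.cc, gbtrf.cc, and getrf.cc all add the same option-override idiom. The sketch below isolates it (the driver function is hypothetical; the two opts2 lines are verbatim from the diffs above): the caller's options are copied and the release strategy pinned to TileReleaseStrategy::Slate, so the internal kernels leave every tile in place and the driver's own cleanup tasks decide when workspace is freed.

```cpp
#include <slate/slate.hh>

// Hypothetical driver illustrating the idiom; not a routine in SLATE.
void cleanup_owning_driver( slate::Options const& opts )
{
    // Copy the caller's options, then force the Slate tile-release
    // strategy so internal::gemm / internal::trsm release nothing.
    slate::Options opts2 = opts;
    opts2[ slate::Option::TileReleaseStrategy ] = slate::TileReleaseStrategy::Slate;

    // ... pass opts2 to every internal routine; release workspace in
    // dedicated cleanup tasks, as in the diffs above ...
}
```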