Skip to content

Commit

Permalink
Remove tile life from hbmm
Browse files Browse the repository at this point in the history
  • Loading branch information
neil-lindquist committed Dec 1, 2023
1 parent 00dc003 commit 6c90ba4
Showing 1 changed file with 87 additions and 10 deletions.
97 changes: 87 additions & 10 deletions src/hbmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,21 @@ void hbmm(
using blas::min;
using BcastList = typename Matrix<scalar_t>::BcastList;
const scalar_t one = 1.0;
const int64_t priority_0 = 0;
const int64_t queue_0 = 0;

// Assumes column major
const Layout layout = Layout::ColMajor;

// Options
int64_t lookahead = get_option<int64_t>( opts, Option::Lookahead, 1 );

// Use only TileReleaseStrategy::Slate for hbmm.
// Internal hbmm routine called here won't release
// any tiles. This routine will clean up tiles.
Options opts2 = opts;
opts2[ Option::TileReleaseStrategy ] = TileReleaseStrategy::Slate;

// if on right, change to left by transposing A, B, C to get
// op(C) = op(A)*op(B)
if (side == Side::Right) {
Expand Down Expand Up @@ -153,7 +161,8 @@ void hbmm(
Side::Left,
alpha, A.sub(0, 0),
B.sub(0, 0, 0, B.nt()-1),
beta, C.sub(0, 0, 0, C.nt()-1));
beta, C.sub(0, 0, 0, C.nt()-1),
priority_0, opts2 );

int64_t i_end = min(0 + kdt + 1, A.mt());

Expand All @@ -162,7 +171,7 @@ void hbmm(
alpha, A.sub(1, i_end-1, 0, 0),
B.sub(0, 0, 0, B.nt()-1),
beta, C.sub(1, i_end-1, 0, C.nt()-1),
layout);
layout, priority_0, queue_0, opts2 );
}

if (beta != one) {
Expand All @@ -186,6 +195,22 @@ void hbmm(
}
}

#pragma omp task depend(in:gemm[0])
{
int64_t i_end = min(0 + kdt + 1, A.mt());

auto A_colblock = A.sub(0, i_end-1, 0, 0);
auto B_rowblock = B.sub(0, 0, 0, B.nt()-1);

// Erase remote tiles on all devices including host
A_colblock.releaseRemoteWorkspace();
B_rowblock.releaseRemoteWorkspace();

// Erase local workspace on devices.
A_colblock.releaseLocalWorkspace();
B_rowblock.releaseLocalWorkspace();
}

for (int64_t k = 1; k < A.nt(); ++k) {
// send next block col of A and block row of B
if (k+lookahead < A.nt()) {
Expand Down Expand Up @@ -237,22 +262,40 @@ void hbmm(
alpha, conj_transpose( Arow_k ),
B.sub(k, k, 0, B.nt()-1),
one, C.sub(i_begin, k-1, 0, C.nt()-1),
layout);
layout, priority_0, queue_0, opts2 );

internal::hemm<Target::HostTask>(
Side::Left,
alpha, A.sub(k, k),
B.sub(k, k, 0, B.nt()-1),
one, C.sub(k, k, 0, C.nt()-1));
one, C.sub(k, k, 0, C.nt()-1),
priority_0, opts2 );

if (i_end-1 > k) {
internal::gemm<target>(
alpha, A.sub(k+1, i_end-1, k, k),
B.sub(k, k, 0, B.nt()-1),
one, C.sub(k+1, i_end-1, 0, C.nt()-1),
layout);
layout, priority_0, queue_0, opts2 );
}
}

#pragma omp task depend(in:gemm[k])
{
auto A_rowblock = A.sub(k, k, i_begin, k);
auto A_colblock = A.sub(k+1, i_end-1, k, k);
auto B_rowblock = B.sub(k, k, 0, B.nt()-1);

// Erase remote tiles on all devices including host
A_colblock.releaseRemoteWorkspace();
A_rowblock.releaseRemoteWorkspace();
B_rowblock.releaseRemoteWorkspace();

// Erase local workspace on devices.
A_colblock.releaseLocalWorkspace();
A_rowblock.releaseLocalWorkspace();
B_rowblock.releaseLocalWorkspace();
}
}
}
else {
Expand Down Expand Up @@ -315,7 +358,8 @@ void hbmm(
Side::Left,
alpha, A.sub(0, 0),
B.sub(0, 0, 0, B.nt()-1),
beta, C.sub(0, 0, 0, C.nt()-1));
beta, C.sub(0, 0, 0, C.nt()-1),
priority_0, opts2 );

int64_t i_end = min(0 + kdt + 1, A.mt());

Expand All @@ -325,7 +369,7 @@ void hbmm(
alpha, conj_transpose( Arow_k ),
B.sub(0, 0, 0, B.nt()-1),
beta, C.sub(1, i_end-1, 0, C.nt()-1),
layout);
layout, priority_0, queue_0, opts2 );
}

if (beta != one) {
Expand All @@ -346,6 +390,21 @@ void hbmm(
}
}

#pragma omp task depend(in:gemm[0])
{
int64_t i_end = min(0 + kdt + 1, A.mt());
auto A_colblock = A.sub(0, 0, 0, i_end-1);
auto B_rowblock = B.sub(0, 0, 0, B.nt()-1);

// Erase remote tiles on all devices including host
A_colblock.releaseRemoteWorkspace();
B_rowblock.releaseRemoteWorkspace();

// Erase local workspace on devices.
A_colblock.releaseLocalWorkspace();
B_rowblock.releaseLocalWorkspace();
}

for (int64_t k = 1; k < A.nt(); ++k) {
// send next block col of A and block row of B
if (k+lookahead < A.nt()) {
Expand Down Expand Up @@ -395,23 +454,41 @@ void hbmm(
alpha, A.sub(i_begin, k-1, k, k),
B.sub(k, k, 0, B.nt()-1),
one, C.sub(i_begin, k-1, 0, C.nt()-1),
layout);
layout, priority_0, queue_0, opts2 );

internal::hemm<Target::HostTask>(
Side::Left,
alpha, A.sub(k, k),
B.sub(k, k, 0, B.nt()-1),
one, C.sub(k, k, 0, C.nt()-1));
one, C.sub(k, k, 0, C.nt()-1),
priority_0, opts2 );

if (i_end-1 > k) {
auto Arow_k = A.sub(k, k, k+1, i_end-1);
internal::gemm<target>(
alpha, conj_transpose( Arow_k ),
B.sub(k, k, 0, B.nt()-1),
one, C.sub(k+1, i_end-1, 0, C.nt()-1),
layout);
layout, priority_0, queue_0, opts2 );
}
}

#pragma omp task depend(in:gemm[k])
{
auto A_colblock = A.sub(i_begin, k, k, k);
auto A_rowblock = A.sub(k, k, k+1, i_end-1);
auto B_rowblock = B.sub(k, k, 0, B.nt()-1);

// Erase remote tiles on all devices including host
A_colblock.releaseRemoteWorkspace();
A_rowblock.releaseRemoteWorkspace();
B_rowblock.releaseRemoteWorkspace();

// Erase local workspace on devices.
A_colblock.releaseLocalWorkspace();
A_rowblock.releaseLocalWorkspace();
B_rowblock.releaseLocalWorkspace();
}
}
}
}
Expand Down

0 comments on commit 6c90ba4

Please sign in to comment.