Arbitrary regions in device set #129

Merged: 4 commits, Oct 19, 2023
193 changes: 117 additions & 76 deletions src/internal/internal_geset.cc
@@ -92,112 +92,153 @@ void set(internal::TargetType<Target::Devices>,
 {
     using ij_tuple = typename BaseMatrix<scalar_t>::ij_tuple;

-    // Define index ranges for regions of matrix.
-    // Tiles in each region are all the same size.
-    int64_t irange[4][2] = {
-        { 0,        A.mt()-1 },
-        { A.mt()-1, A.mt()   },
-        { 0,        A.mt()-1 },
-        { A.mt()-1, A.mt()   }
-    };
-    int64_t jrange[4][2] = {
-        { 0,        A.nt()-1 },
-        { 0,        A.nt()-1 },
-        { A.nt()-1, A.nt()   },
-        { A.nt()-1, A.nt()   }
-    };
+    int64_t mt = A.mt();
+    int64_t nt = A.nt();

+    // Find ranges of matching mb's.
+    std::vector< int64_t > irange;
+    int64_t last_mb = -1;
+    for (int64_t i = 0; i < mt; ++i) {
+        int64_t mb = A.tileMb( i );
+        if (mb != last_mb) {
+            last_mb = mb;
+            irange.push_back( i );
+        }
+    }
+    irange.push_back( mt );
+
+    // Find ranges of matching nb's.
+    std::vector< int64_t > jrange;
+    int64_t last_nb = -1;
+    for (int64_t j = 0; j < nt; ++j) {
+        int64_t nb = A.tileNb( j );
+        if (nb != last_nb) {
+            last_nb = nb;
+            jrange.push_back( j );
+        }
+    }
+    jrange.push_back( nt );
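These loops replace the old fixed 2x2 split (interior tiles versus the last block row and column) with run-length grouping of tile sizes: each run of consecutive tiles sharing the same mb (or nb) becomes one range, so matrices with arbitrary tile-size distributions are covered. A minimal standalone sketch of that grouping, assuming a plain vector of sizes in place of SLATE's A.tileMb()/A.tileNb() queries (find_ranges is a hypothetical helper, not a SLATE API):

    #include <cstdint>
    #include <vector>

    // Hypothetical helper: collapse a list of tile sizes into range
    // boundaries, one range per run of equal sizes.
    std::vector<int64_t> find_ranges( std::vector<int64_t> const& sizes )
    {
        std::vector<int64_t> range;
        int64_t last = -1;
        for (size_t i = 0; i < sizes.size(); ++i) {
            if (sizes[ i ] != last) {
                last = sizes[ i ];
                range.push_back( i );  // new range starts where the size changes
            }
        }
        range.push_back( (int64_t) sizes.size() );  // one-past-the-end sentinel
        return range;
    }

    // Example: sizes { 512, 512, 512, 100 } give ranges { 0, 3, 4 }:
    // tiles [ 0, 3 ) share size 512; tile [ 3, 4 ) has size 100.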

     #pragma omp taskgroup
     for (int device = 0; device < A.num_devices(); ++device) {
-        #pragma omp task slate_omp_default_none \
-            shared( A ) priority( priority ) \
-            firstprivate(device, irange, jrange, queue_index, offdiag_value, diag_value)
+        #pragma omp task priority( priority ) \
+            shared( A, irange, jrange ) \
+            firstprivate( device, queue_index, offdiag_value, diag_value )
         {
-            // temporarily, convert both into same layout
-            // todo: this is inefficient, because both matrices may have same layout already
-            //       and possibly wrong, because an input matrix is being altered
-            auto layout = Layout::ColMajor;
+            // Get local tiles for writing.
+            // Convert to column-major layout to simplify lda's.
+            // todo: this is inefficient because the diagonal is independent of layout
+            // todo: best, handle directly through the CUDA kernels
+            auto layout = LayoutConvert( Layout::ColMajor );
             std::set<ij_tuple> A_tiles_set;

             for (int64_t i = 0; i < A.mt(); ++i) {
                 for (int64_t j = 0; j < A.nt(); ++j) {
-                    if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) {
-                        A_tiles_set.insert({i, j});
+                    if (A.tileIsLocal( i, j ) && device == A.tileDevice( i, j )) {
+                        A_tiles_set.insert( { i, j } );
                     }
                 }
             }
-            A.tileGetForWriting(A_tiles_set, device, LayoutConvert(layout));
+            A.tileGetForWriting( A_tiles_set, device, layout );

-            scalar_t** a_array_host = A.array_host(device, queue_index);
+            scalar_t** a_array_host = A.array_host( device, queue_index );

+            // If offdiag == diag value, lump diag tiles with offdiag tiles
+            // in one batch.
+            bool diag_same = offdiag_value == diag_value;

+            // Build batch groups.
             int64_t batch_count = 0;
-            int64_t mb[8], nb[8], lda[8], group_count[8];
-            for (int q = 0; q < 4; ++q) {
-                group_count[q] = 0;
-                lda[q] = 0;
-                mb[q] = A.tileMb(irange[q][0]);
-                nb[q] = A.tileNb(jrange[q][0]);
-                for (int64_t i = irange[q][0]; i < irange[q][1]; ++i) {
-                    for (int64_t j = jrange[q][0]; j < jrange[q][1]; ++j) {
-                        if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) {
-                            if (i != j) {
-                                a_array_host[batch_count] = A(i, j, device).data();
-                                lda[q] = A(i, j, device).stride();
-                                ++group_count[q];
-                                ++batch_count;
-                            }
+            struct Params {
+                int64_t count, mb, nb, lda;
+                scalar_t diag_value;
+            };
+            std::vector<Params> group_params;
+            for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
+            for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
+                Params group = { 0, -1, -1, -1, offdiag_value };
+                for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) {
+                for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) {
+                    if ((diag_same || i != j)
+                        && A.tileIsLocal( i, j )
+                        && device == A.tileDevice( i, j ))
+                    {
+                        auto Aij = A( i, j, device );
+                        a_array_host[ batch_count ] = Aij.data();
+                        if (group.count == 0) {
+                            group.mb = Aij.mb();
+                            group.nb = Aij.nb();
+                            group.lda = Aij.stride();
+                        }
+                        else {
+                            assert( group.mb == Aij.mb() );
+                            assert( group.nb == Aij.nb() );
+                            assert( group.lda == Aij.stride() );
+                        }
Comment on lines +173 to 177

Contributor: Is this piece of code needed? As I understand it, it is there because we do not fully trust irange and jrange, correct?

Contributor (author): The checks on mb and nb are probably unnecessary, yeah. I'll look at removing them. The check on lda could come into play if users are providing memory themselves and have inconsistent strides. It's maybe not a case that will actually arise, but assertions are nicer to debug if it does.

Collaborator: They're asserts to ensure the program logic is correct. If compiled with -DNDEBUG, they disappear. I would leave them, though maybe they don't need to go into every function that gets regions.
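On the -DNDEBUG point: assert from <cassert> expands to nothing when NDEBUG is defined, so these stride checks cost nothing in release builds. A minimal standalone illustration (not SLATE code):

    #include <cassert>

    int main()
    {
        long long expected_lda = 512, lda = 500;
        // Fails loudly in a debug build; compiled out entirely when
        // built with -DNDEBUG, e.g. g++ -DNDEBUG demo.cc.
        assert( lda == expected_lda );
        return 0;
    }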

+                        ++group.count;
+                        ++batch_count;
+                    }
+                }} // for j, i
+                if (group.count > 0) {
+                    group_params.push_back( group );
+                }
+            }} // for jj, ii
-                        }
-                    }
-                }
-            }
-            for (int q = 4; q < 8; ++q) {
-                group_count[q] = 0;
-                lda[q] = 0;
-                mb[q] = A.tileMb(irange[q-4][0]);
-                nb[q] = A.tileNb(jrange[q-4][0]);
-                for (int64_t i = irange[q-4][0]; i < irange[q-4][1]; ++i) {
-                    for (int64_t j = jrange[q-4][0]; j < jrange[q-4][1]; ++j) {
-                        if (A.tileIsLocal(i, j) && device == A.tileDevice(i, j)) {
-                            if (i == j) {
-                                a_array_host[batch_count] = A(i, j, device).data();
-                                lda[q] = A(i, j, device).stride();
-                                ++group_count[q];
-                                ++batch_count;
-                            }
-                        }
-                    }
-                }
-            }

+            // Build batch groups for diagonal tiles,
+            // when offdiag_value != diag_value.
+            if (! diag_same) {
+                for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
+                for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
+                    Params group = { 0, -1, -1, -1, diag_value };
+                    // Diagonal tiles are only in the intersection of irange and jrange.
+                    int64_t ijstart = std::max( irange[ ii ], jrange[ jj ] );
+                    int64_t ijend   = std::min( irange[ ii+1 ], jrange[ jj+1 ] );
+                    for (int64_t ij = ijstart; ij < ijend; ++ij) {
+                        if (A.tileIsLocal( ij, ij )
+                            && device == A.tileDevice( ij, ij ))
+                        {
+                            auto Aij = A( ij, ij, device );
+                            a_array_host[ batch_count ] = Aij.data();
+                            if (group.count == 0) {
+                                group.mb = Aij.mb();
+                                group.nb = Aij.nb();
+                                group.lda = Aij.stride();
+                            }
+                            else {
+                                assert( group.mb == Aij.mb() );
+                                assert( group.nb == Aij.nb() );
+                                assert( group.lda == Aij.stride() );
+                            }
+                            ++group.count;
+                            ++batch_count;
+                        }
+                    } // for ij
+                    if (group.count > 0) {
+                        group_params.push_back( group );
+                    }
+                }} // for jj, ii
+            }
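The intersection rule here is worth a quick check: a diagonal tile (ij, ij) belongs to block-region (ii, jj) exactly when ij lies in both the row range and the column range. A standalone sketch with hypothetical range vectors (not SLATE code), printing which region owns each diagonal tile:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main()
    {
        // Hypothetical ranges: row tile size changes at 3, column at 2.
        std::vector<int64_t> irange = { 0, 3, 4 };
        std::vector<int64_t> jrange = { 0, 2, 4 };
        for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
            for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
                // Same intersection rule as the code above.
                int64_t ijstart = std::max( irange[ ii ], jrange[ jj ] );
                int64_t ijend   = std::min( irange[ ii+1 ], jrange[ jj+1 ] );
                for (int64_t ij = ijstart; ij < ijend; ++ij) {
                    printf( "region (%zu, %zu) owns diagonal tile (%lld, %lld)\n",
                            ii, jj, (long long) ij, (long long) ij );
                }
            }
        }
        // Prints: region (0, 0) owns (0,0) and (1,1); region (0, 1)
        // owns (2,2); region (1, 1) owns (3,3); region (1, 0) is empty.
        return 0;
    }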

-            scalar_t** a_array_dev = A.array_device(device, queue_index);
-
-            blas::Queue* queue = A.compute_queue(device, queue_index);
+            blas::Queue* queue = A.compute_queue( device, queue_index );
+
+            scalar_t** a_array_dev = A.array_device( device, queue_index );
             blas::device_memcpy<scalar_t*>(
                 a_array_dev, a_array_host, batch_count,
                 blas::MemcpyKind::HostToDevice, *queue);

-            for (int q = 0; q < 4; ++q) {
-                if (group_count[q] > 0) {
-                    device::batch::geset(mb[q], nb[q],
-                        offdiag_value, offdiag_value, a_array_dev, lda[q],
-                        group_count[q], *queue);
-                    a_array_dev += group_count[q];
-                }
-            }
-            for (int q = 4; q < 8; ++q) {
-                if (group_count[q] > 0) {
-                    device::batch::geset(mb[q], nb[q],
-                        offdiag_value, diag_value, a_array_dev, lda[q],
-                        group_count[q], *queue);
-                    a_array_dev += group_count[q];
-                }
-            }
+            for (size_t g = 0; g < group_params.size(); ++g) {
+                int64_t group_count = group_params[ g ].count;
+                device::batch::geset(
+                    group_params[ g ].mb,
+                    group_params[ g ].nb,
+                    offdiag_value, group_params[ g ].diag_value,
+                    a_array_dev, group_params[ g ].lda,
+                    group_count, *queue );
+                a_array_dev += group_count;
+            }

             queue->sync();
-        }
-    }
+        } // end task
+    } // end for dev
 }
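The final loop issues one batched kernel per group and advances the device pointer array by each group's count, consuming the pointers in exactly the order the host loop packed them. A self-contained host-side sketch of that pattern, where batched_set is a hypothetical stand-in for device::batch::geset:

    #include <cstdint>
    #include <vector>

    struct Group { int64_t count, mb, nb, lda; float diag_value; };

    // Hypothetical stand-in for device::batch::geset: set every tile in
    // one batch of same-shape, same-stride column-major tiles.
    void batched_set( int64_t mb, int64_t nb, float offdiag, float diag,
                      float** arrays, int64_t lda, int64_t batch )
    {
        for (int64_t k = 0; k < batch; ++k)
            for (int64_t j = 0; j < nb; ++j)
                for (int64_t i = 0; i < mb; ++i)
                    arrays[ k ][ i + j*lda ] = (i == j ? diag : offdiag);
    }

    void launch_all( std::vector<Group> const& groups,
                     float offdiag_value, float** a_array )
    {
        // The pointer array was packed group-by-group, so each launch
        // consumes group.count pointers and advances past them.
        for (Group const& g : groups) {
            batched_set( g.mb, g.nb, offdiag_value, g.diag_value,
                         a_array, g.lda, g.count );
            a_array += g.count;
        }
    }

Because the diagonal groups carry their own diag_value, the same loop handles both the diag_same case (diagonal tiles folded into off-diagonal groups) and the separate diagonal groups, which is what lets the new code collapse the old pair of q-loops into one.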

//------------------------------------------------------------------------------