Skip to content

Commit

Permalink
Add device regions to trsmA
Browse files Browse the repository at this point in the history
  • Loading branch information
neil-lindquist committed Nov 2, 2023
1 parent 3bedce1 commit 63ccb95
Showing 1 changed file with 96 additions and 87 deletions.
183 changes: 96 additions & 87 deletions src/internal/internal_trsmA.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,78 +225,87 @@ void trsmA(internal::TargetType<Target::Devices>,
B.tileGetForWriting(
B_tiles_set, device, LayoutConvert( layout ) );

// interior col or row
std::vector<scalar_t*> a_array0;
std::vector<scalar_t*> b_array0;
a_array0.reserve( batch_size );
b_array0.reserve( batch_size );

// bottom-right tile
// todo: replace batch trsm with plain trsm
std::vector<scalar_t*> a_array1;
std::vector<scalar_t*> b_array1;

int64_t lda0 = 0;
int64_t ldb0 = 0;
int64_t lda1 = 0;
int64_t ldb1 = 0;

int64_t mb0 = B.tileMb(0);
int64_t nb0 = B.tileNb(0);
int64_t mb1 = B.tileMb(B.mt()-1);
int64_t nb1 = B.tileNb(B.nt()-1);

auto A00d = A( 0, 0, device );
auto dAdata = A00d.data();
lda1 = lda0 = A00d.stride();
scalar_t** a_array_host = A.array_host(device, queue_index);
scalar_t** b_array_host = a_array_host + batch_size;

// Varient of device_regions_build to handle trsmA
using Params = device_regions_params<false, 2>;

int64_t batch_count = 0;
std::vector<Params> group_params;
if (side == Side::Right) {
// TODO loop over B_tiles_set instead of looking for again.
for (int64_t i = 0; i < B.mt()-1; ++i) {
if (B.tileExists( i, 0, device ))
{
auto Bi0d = B( i, 0, device );
a_array0.push_back( dAdata );
b_array0.push_back( Bi0d.data() );
ldb0 = Bi0d.stride();
}
}
{
int64_t i = B.mt()-1;
if (B.tileExists( i, 0, device ))
{
auto Bi0d = B( i, 0, device );
a_array1.push_back( dAdata );
b_array1.push_back( Bi0d.data() );
ldb1 = Bi0d.stride();
// Find ranges of matching mb's and ranges of matching nb's.
auto irange = device_regions_range( true, B );

// loop over regions
for (size_t ii = 0; ii < irange.size() - 1; ++ii) {
// Loop over the tiles in this region,
// save any that should be computed on this process & device
Params group;
group.mb = B.tileMb( irange[ ii ] );
group.nb = B.tileNb( 0 );
for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) {
if (B.tileExists( i, 0, device )) {

// Add tiles to current group
auto Aij = A( 0, 0, device );
a_array_host[ batch_count ] = Aij.data();
auto Bij = B( i, 0, device );
b_array_host[ batch_count ] = Bij.data();
if (group.count == 0) {
group.ld[0] = Aij.stride();
group.ld[1] = Bij.stride();
}
else {
assert( group.ld[0] == Aij.stride() );
assert( group.ld[1] == Bij.stride() );
}
++group.count;
++batch_count;
}
} // for i
// If any tiles in the region should be computed here, save the group
if (group.count > 0) {
group_params.push_back( group );
}
}
} // for ii
}
else {
for (int64_t j = 0; j < B.nt()-1; ++j) {
if (B.tileExists( 0, j, device ))
{
auto B0jd = B( 0, j, device );
a_array0.push_back( dAdata );
b_array0.push_back( B0jd.data() );
ldb0 = B0jd.stride();
// Find ranges of matching mb's and ranges of matching nb's.
auto jrange = device_regions_range( false, B );

// loop over regions
for (size_t jj = 0; jj < jrange.size() - 1; ++jj) {
// Loop over the tiles in this region,
// save any that should be computed on this process & device
Params group;
group.mb = B.tileMb( 0 );
group.nb = B.tileNb( jrange[ jj ] );
for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) {
if (B.tileExists( 0, j, device )) {

// Add tiles to current group
auto Aij = A( 0, 0, device );
a_array_host[ batch_count ] = Aij.data();
auto Bij = B( 0, j, device );
b_array_host[ batch_count ] = Bij.data();
if (group.count == 0) {
group.ld[0] = Aij.stride();
group.ld[1] = Bij.stride();
}
else {
assert( group.ld[0] == Aij.stride() );
assert( group.ld[1] == Bij.stride() );
}
++group.count;
++batch_count;
}
} // for i
// If any tiles in the region should be computed here, save the group
if (group.count > 0) {
group_params.push_back( group );
}
}
{
int64_t j = B.nt()-1;
if (B.tileExists( 0, j, device ))
{
auto B0jd = B( 0, j, device );
a_array1.push_back( dAdata );
b_array1.push_back( B0jd.data() );
ldb1 = B0jd.stride();
}
}
}

if (B.op() != Op::NoTrans) {
swap( mb0, nb0 );
swap( mb1, nb1 );
} // for ii
}

{
Expand All @@ -311,35 +320,35 @@ void trsmA(internal::TargetType<Target::Devices>,

blas::Queue* queue = A.compute_queue( device, queue_index );
assert( queue != nullptr );
queue->sync();

if (a_array0.size() > 0) {
std::vector<int64_t> m( 1, mb0 );
std::vector<int64_t> n( 1, nb0 );
std::vector<int64_t> lda( 1, lda0 );
std::vector<int64_t> ldb( 1, ldb0 );
for (size_t g = 0; g < group_params.size(); ++g) {

blas::batch::trsm(
layout, side_, uplo_, opA_, diag_,
m, n,
alpha_, a_array0, lda,
b_array0, ldb,
a_array0.size(), info, *queue);
}
int64_t group_count = group_params[ g ].count;

if (a_array1.size() > 0) {
std::vector<int64_t> m(1, mb1);
std::vector<int64_t> n(1, nb1);
std::vector<int64_t> lda(1, lda1);
std::vector<int64_t> ldb(1, ldb1);
std::vector<int64_t> m(1, group_params[ g ].mb);
std::vector<int64_t> n(1, group_params[ g ].nb);
std::vector<int64_t> ldda(1, group_params[ g ].ld[0]);
std::vector<int64_t> lddb(1, group_params[ g ].ld[1]);

std::vector<scalar_t*> a_array(a_array_host, a_array_host+group_count);
std::vector<scalar_t*> b_array(b_array_host, b_array_host+group_count);

if (B.op() != Op::NoTrans) {
swap(m, n);
}

blas::batch::trsm(
layout, side_, uplo_, opA_, diag_,
m, n,
alpha_, a_array1, lda,
b_array1, ldb,
a_array1.size(), info, *queue);
}
alpha_, a_array, ldda,
b_array, lddb,
group_count, info, *queue);
queue->sync();

a_array_host += group_count;
b_array_host += group_count;
}
queue->sync();
}

Expand Down

0 comments on commit 63ccb95

Please sign in to comment.