diff --git a/src/internal/internal_trsmA.cc b/src/internal/internal_trsmA.cc index 8f4918bec..a00490ff4 100644 --- a/src/internal/internal_trsmA.cc +++ b/src/internal/internal_trsmA.cc @@ -225,78 +225,87 @@ void trsmA(internal::TargetType, B.tileGetForWriting( B_tiles_set, device, LayoutConvert( layout ) ); - // interior col or row - std::vector a_array0; - std::vector b_array0; - a_array0.reserve( batch_size ); - b_array0.reserve( batch_size ); - - // bottom-right tile - // todo: replace batch trsm with plain trsm - std::vector a_array1; - std::vector b_array1; - - int64_t lda0 = 0; - int64_t ldb0 = 0; - int64_t lda1 = 0; - int64_t ldb1 = 0; - - int64_t mb0 = B.tileMb(0); - int64_t nb0 = B.tileNb(0); - int64_t mb1 = B.tileMb(B.mt()-1); - int64_t nb1 = B.tileNb(B.nt()-1); - - auto A00d = A( 0, 0, device ); - auto dAdata = A00d.data(); - lda1 = lda0 = A00d.stride(); + scalar_t** a_array_host = A.array_host(device, queue_index); + scalar_t** b_array_host = a_array_host + batch_size; + // Varient of device_regions_build to handle trsmA + using Params = device_regions_params; + + int64_t batch_count = 0; + std::vector group_params; if (side == Side::Right) { - // TODO loop over B_tiles_set instead of looking for again. - for (int64_t i = 0; i < B.mt()-1; ++i) { - if (B.tileExists( i, 0, device )) - { - auto Bi0d = B( i, 0, device ); - a_array0.push_back( dAdata ); - b_array0.push_back( Bi0d.data() ); - ldb0 = Bi0d.stride(); - } - } - { - int64_t i = B.mt()-1; - if (B.tileExists( i, 0, device )) - { - auto Bi0d = B( i, 0, device ); - a_array1.push_back( dAdata ); - b_array1.push_back( Bi0d.data() ); - ldb1 = Bi0d.stride(); + // Find ranges of matching mb's and ranges of matching nb's. + auto irange = device_regions_range( true, B ); + + // loop over regions + for (size_t ii = 0; ii < irange.size() - 1; ++ii) { + // Loop over the tiles in this region, + // save any that should be computed on this process & device + Params group; + group.mb = B.tileMb( irange[ ii ] ); + group.nb = B.tileNb( 0 ); + for (int64_t i = irange[ ii ]; i < irange[ ii+1 ]; ++i) { + if (B.tileExists( i, 0, device )) { + + // Add tiles to current group + auto Aij = A( 0, 0, device ); + a_array_host[ batch_count ] = Aij.data(); + auto Bij = B( i, 0, device ); + b_array_host[ batch_count ] = Bij.data(); + if (group.count == 0) { + group.ld[0] = Aij.stride(); + group.ld[1] = Bij.stride(); + } + else { + assert( group.ld[0] == Aij.stride() ); + assert( group.ld[1] == Bij.stride() ); + } + ++group.count; + ++batch_count; + } + } // for i + // If any tiles in the region should be computed here, save the group + if (group.count > 0) { + group_params.push_back( group ); } - } + } // for ii } else { - for (int64_t j = 0; j < B.nt()-1; ++j) { - if (B.tileExists( 0, j, device )) - { - auto B0jd = B( 0, j, device ); - a_array0.push_back( dAdata ); - b_array0.push_back( B0jd.data() ); - ldb0 = B0jd.stride(); + // Find ranges of matching mb's and ranges of matching nb's. + auto jrange = device_regions_range( false, B ); + + // loop over regions + for (size_t jj = 0; jj < jrange.size() - 1; ++jj) { + // Loop over the tiles in this region, + // save any that should be computed on this process & device + Params group; + group.mb = B.tileMb( 0 ); + group.nb = B.tileNb( jrange[ jj ] ); + for (int64_t j = jrange[ jj ]; j < jrange[ jj+1 ]; ++j) { + if (B.tileExists( 0, j, device )) { + + // Add tiles to current group + auto Aij = A( 0, 0, device ); + a_array_host[ batch_count ] = Aij.data(); + auto Bij = B( 0, j, device ); + b_array_host[ batch_count ] = Bij.data(); + if (group.count == 0) { + group.ld[0] = Aij.stride(); + group.ld[1] = Bij.stride(); + } + else { + assert( group.ld[0] == Aij.stride() ); + assert( group.ld[1] == Bij.stride() ); + } + ++group.count; + ++batch_count; + } + } // for i + // If any tiles in the region should be computed here, save the group + if (group.count > 0) { + group_params.push_back( group ); } - } - { - int64_t j = B.nt()-1; - if (B.tileExists( 0, j, device )) - { - auto B0jd = B( 0, j, device ); - a_array1.push_back( dAdata ); - b_array1.push_back( B0jd.data() ); - ldb1 = B0jd.stride(); - } - } - } - - if (B.op() != Op::NoTrans) { - swap( mb0, nb0 ); - swap( mb1, nb1 ); + } // for ii } { @@ -311,35 +320,35 @@ void trsmA(internal::TargetType, blas::Queue* queue = A.compute_queue( device, queue_index ); assert( queue != nullptr ); + queue->sync(); - if (a_array0.size() > 0) { - std::vector m( 1, mb0 ); - std::vector n( 1, nb0 ); - std::vector lda( 1, lda0 ); - std::vector ldb( 1, ldb0 ); + for (size_t g = 0; g < group_params.size(); ++g) { - blas::batch::trsm( - layout, side_, uplo_, opA_, diag_, - m, n, - alpha_, a_array0, lda, - b_array0, ldb, - a_array0.size(), info, *queue); - } + int64_t group_count = group_params[ g ].count; - if (a_array1.size() > 0) { - std::vector m(1, mb1); - std::vector n(1, nb1); - std::vector lda(1, lda1); - std::vector ldb(1, ldb1); + std::vector m(1, group_params[ g ].mb); + std::vector n(1, group_params[ g ].nb); + std::vector ldda(1, group_params[ g ].ld[0]); + std::vector lddb(1, group_params[ g ].ld[1]); + + std::vector a_array(a_array_host, a_array_host+group_count); + std::vector b_array(b_array_host, b_array_host+group_count); + + if (B.op() != Op::NoTrans) { + swap(m, n); + } blas::batch::trsm( layout, side_, uplo_, opA_, diag_, m, n, - alpha_, a_array1, lda, - b_array1, ldb, - a_array1.size(), info, *queue); - } + alpha_, a_array, ldda, + b_array, lddb, + group_count, info, *queue); + queue->sync(); + a_array_host += group_count; + b_array_host += group_count; + } queue->sync(); }