Skip to content

Commit

Permalink
Merge pull request #126 from Treece-Burgess/timers
Browse files Browse the repository at this point in the history
Add additional timers to gesv,  posv,  heev, and svd
  • Loading branch information
mgates3 authored Oct 17, 2023
2 parents 1686089 + 98370cf commit 3e89d68
Show file tree
Hide file tree
Showing 10 changed files with 175 additions and 12 deletions.
8 changes: 8 additions & 0 deletions src/gesv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,24 @@ void gesv(Matrix<scalar_t>& A, Pivots& pivots,
Matrix<scalar_t>& B,
Options const& opts)
{
// Driver: calls getrf( A, pivots, opts ) and then getrs( A, pivots, B, opts ),
// storing the result of each Timer::stop() in the `timers` map under the
// keys "gesv::getrf", "gesv::getrs", and "gesv".
// NOTE(review): Timer and timers are declared outside this hunk; stop()
// presumably returns the elapsed seconds since the Timer was constructed
// (the gesv tester reports these entries in "(s)" columns) -- confirm.

// Overall timer; constructed first so that the "gesv" entry also covers the
// argument checks below and both phases.
Timer t_gesv;

slate_assert(A.mt() == A.nt()); // square
slate_assert(B.mt() == A.mt());

// factorization
// t_getrf is constructed immediately before the call and stopped immediately
// after, so "gesv::getrf" brackets only the getrf call.
Timer t_getrf;
getrf(A, pivots, opts);
timers[ "gesv::getrf" ] = t_getrf.stop();

// solve
// Same bracketing for getrs, which receives the A and pivots passed to getrf
// above together with B.
Timer t_getrs;
getrs(A, pivots, B, opts);
timers[ "gesv::getrs" ] = t_getrs.stop();

// todo: return value for errors?

// Recorded last: total for the whole routine.
timers[ "gesv" ] = t_gesv.stop();
}

//------------------------------------------------------------------------------
Expand Down
15 changes: 15 additions & 0 deletions src/heev.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ void heev(
Matrix<scalar_t>& Z,
Options const& opts)
{
Timer t_heev;

using real_t = blas::real_type<scalar_t>;
using std::real;

Expand Down Expand Up @@ -101,7 +103,9 @@ void heev(

// 1. Reduce to band form.
TriangularFactors<scalar_t> T;
Timer t_he2hb;
he2hb(A, T, opts);
timers[ "heev::he2hb" ] = t_he2hb.stop();

// Copy band.
// Currently, gathers band matrix to rank 0.
Expand All @@ -125,7 +129,9 @@ void heev(
V.insertLocalTiles();

// 2. Reduce band to real symmetric tri-diagonal.
Timer t_hb2st;
hb2st(Aband, V, opts);
timers[ "heev::hb2st" ] = t_hb2st.stop();

// Copy diagonal and super-diagonal to vectors.
internal::copyhb2st( Aband, Lambda, E );
Expand All @@ -136,6 +142,7 @@ void heev(
// Bcast the Lambda and E vectors (diagonal and sup/super-diagonal).
MPI_Bcast( &Lambda[0], n, mpi_real_type, 0, A.mpiComm() );
MPI_Bcast( &E[0], n-1, mpi_real_type, 0, A.mpiComm() );
Timer t_stev;
if (method == MethodEig::QR) {
// QR iteration to get eigenvalues and eigenvectors of tridiagonal.
steqr2( Job::Vec, Lambda, E, Z );
Expand All @@ -154,6 +161,7 @@ void heev(
copy( Zreal, Z );
}
}
timers[ "heev::stev" ] = t_stev.stop();

// Find the total number of processors.
int mpi_size;
Expand All @@ -165,18 +173,24 @@ void heev(
redistribute(Z, Z1d, opts);

// Back-transform: Z = Q1 * Q2 * Z.
Timer t_unmtr_hb2st;
unmtr_hb2st( Side::Left, Op::NoTrans, V, Z1d, opts );
timers[ "heev::unmtr_hb2st" ] = t_unmtr_hb2st.stop();

redistribute(Z1d, Z, opts);
Timer t_unmtr_he2hb;
unmtr_he2hb( Side::Left, Op::NoTrans, A, T, Z, opts );
timers[ "heev::unmtr_he2hb" ] = t_unmtr_he2hb.stop();
}
else {
Timer t_stev;
if (A.mpiRank() == 0) {
// QR iteration to get eigenvalues.
sterf<real_t>( Lambda, E, opts );
}
// Bcast eigenvalues.
MPI_Bcast( &Lambda[0], n, mpi_real_type, 0, A.mpiComm() );
timers[ "heev::stev" ] = t_stev.stop();
}

// If matrix was scaled, then rescale eigenvalues appropriately.
Expand All @@ -185,6 +199,7 @@ void heev(
// todo: deal with not all eigenvalues converging, cf. LAPACK.
blas::scal( n, Anorm/alpha, &Lambda[0], 1 );
}
timers[ "heev" ] = t_heev.stop();
}

//------------------------------------------------------------------------------
Expand Down
8 changes: 8 additions & 0 deletions src/posv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,23 @@ void posv(HermitianMatrix<scalar_t>& A,
Matrix<scalar_t>& B,
Options const& opts)
{
// Driver: calls potrf( A, opts ) and then potrs( A, B, opts ), storing the
// result of each Timer::stop() in the `timers` map under the keys
// "posv::potrf", "posv::potrs", and "posv".
// NOTE(review): Timer and timers are declared outside this hunk; stop()
// presumably returns the elapsed seconds since the Timer was constructed
// -- confirm.

// Overall timer; constructed first so that the "posv" entry also covers the
// argument check below and both phases.
Timer t_posv;

// B.mt() must equal A.mt().
slate_assert(B.mt() == A.mt());

// factorization
// t_potrf is constructed immediately before the call and stopped immediately
// after, so "posv::potrf" brackets only the potrf call.
Timer t_potrf;
potrf(A, opts);
timers[ "posv::potrf" ] = t_potrf.stop();

// solve
// Same bracketing for potrs, which receives the A passed to potrf above
// together with B.
Timer t_potrs;
potrs(A, B, opts);
timers[ "posv::potrs" ] = t_potrs.stop();

// todo: return value for errors?

// Recorded last: total for the whole routine.
timers[ "posv" ] = t_posv.stop();
}

//------------------------------------------------------------------------------
Expand Down
30 changes: 30 additions & 0 deletions src/svd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ void svd(
Matrix<scalar_t>& VT,
Options const& opts)
{
Timer t_svd;

using real_t = blas::real_type<scalar_t>;
using std::swap;
using blas::max;
Expand Down Expand Up @@ -141,8 +143,12 @@ void svd(
bool lq_path = n > m;
Matrix<scalar_t> Ahat, Uhat, VThat;
TriangularFactors<scalar_t> TQ;
timers[ "svd::geqrf" ] = 0;
timers[ "svd::gelqf" ] = 0;
if (qr_path) {
Timer t_geqrf;
geqrf( A, TQ, opts );
timers[ "svd::geqrf" ] = t_geqrf.stop();

// Upper triangular part of A (R).
auto R_ = A.slice(0, n-1, 0, n-1);
Expand All @@ -166,7 +172,9 @@ void svd(
}
}
else if (lq_path) {
Timer t_gelqf;
gelqf( A, TQ, opts );
timers[ "svd::gelqf" ] = t_gelqf.stop();
swap(m, n);

// Lower triangular part of A (R).
Expand Down Expand Up @@ -204,7 +212,9 @@ void svd(

// 1. Reduce to band form.
TriangularFactors<scalar_t> TU, TV;
Timer t_ge2tb;
ge2tb(Ahat, TU, TV, opts);
timers[ "svd::ge2tb" ] = t_ge2tb.stop();

// Currently, tb2bd and bdsqr run on a single node, gathers band matrix to rank 0.
TriangularBandMatrix<scalar_t> Aband( Uplo::Upper, Diag::NonUnit,
Expand Down Expand Up @@ -236,7 +246,9 @@ void svd(
U2.insertLocalTiles();

// Reduce band to bi-diagonal.
Timer t_tb2bd;
tb2bd( Aband, U2, VT2, opts );
timers[ "svd::tb2bd" ] = t_tb2bd.stop();

// Copy diagonal and super-diagonal to vectors.
internal::copytb2bd(Aband, Sigma, E);
Expand Down Expand Up @@ -283,11 +295,13 @@ void svd(
// QR iteration
//bdsqr<scalar_t>(jobu, jobvt, Sigma, E, Uhat, VThat, opts);
// Call the SVD
Timer t_bdsvd;
lapack::bdsqr(Uplo::Upper, min_mn, ncvt, nru, 0,
&Sigma[0], &E[0],
&VT1D_row_cyclic_data[0], ldvt,
&U1D_row_cyclic_data[0], ldu,
dummy, 1);
timers[ "svd::bdsvd" ] = t_bdsvd.stop();

// If matrix was scaled, then rescale singular values appropriately.
if (is_scale) {
Expand Down Expand Up @@ -316,18 +330,24 @@ void svd(
redistribute(U1d_row_cyclic, U1d, opts);

// First, U = U2 * U ===> U1d = U2 * U1d
Timer t_unmbr_tb2bd_U;
unmtr_hb2st( Side::Left, Op::NoTrans, U2, U1d, opts );
timers[ "svd::unmbr_tb2bd_U" ] = t_unmbr_tb2bd_U.stop();

// Redistribute U1d into U
redistribute(U1d, Uhat, opts);

// Second, U = U1 * U ===> U = Ahat * U
Timer t_unmbr_ge2tb_U;
unmbr_ge2tb( Side::Left, Op::NoTrans, Ahat, TU, Uhat, opts );
timers[ "svd::unmbr_ge2tb_U" ] = t_unmbr_ge2tb_U.stop();
Timer t_unmqr;
if (qr_path) {
// When initial QR was used.
// U = Q*U;
unmqr( Side::Left, slate::Op::NoTrans, A, TQ, U, opts );
}
timers[ "svd::unmqr" ] = t_unmqr.stop();
}

// Back-transform: VT = VT * VT2 * VT1.
Expand All @@ -351,21 +371,28 @@ void svd(
redistribute(V, V1d, opts);

// First: V = VT2 * V ===> V1d = VT2 * V1d
Timer t_unmbr_tb2bd_V;
unmtr_hb2st( Side::Left, Op::NoTrans, VT2, V1d, opts );
timers[ "svd::unmbr_tb2bd_V" ] = t_unmbr_tb2bd_V.stop();

// Redistribute V1d into V
auto V1dT = conj_transpose(V1d);
redistribute(V1dT, VThat, opts);

// Second: VT = VT1 * VT ===> VT = Ahat * VT
Timer t_unmbr_ge2tb_V;
unmbr_ge2tb( Side::Right, Op::NoTrans, Ahat, TV, VThat, opts );
timers[ "svd::unmbr_ge2tb_V" ] = t_unmbr_ge2tb_V.stop();
Timer t_unmlq;
if (lq_path) {
// VT = VT*Q;
unmlq( Side::Right, slate::Op::NoTrans, A, TQ, VT, opts );
}
timers[ "svd::unmlq" ] = t_unmlq.stop();
}
}
else {
Timer t_bdsvd;
if (A.mpiRank() == 0) {
// QR iteration
//bdsqr<scalar_t>(jobu, jobvt, Sigma, E, U, VT, opts);
Expand All @@ -386,7 +413,10 @@ void svd(

// Bcast singular values.
MPI_Bcast( &Sigma[0], min_mn, mpi_real_type, 0, A.mpiComm() );
timers[ "svd::bdsvd" ] = t_bdsvd.stop();
}

timers[ "svd" ] = t_svd.stop();
}

//------------------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions test/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,8 @@ Params::Params():
time8 ("time (s)", 9, 3, ParamType::Output, no_data_flag, 0, 0, "extra timer"),
time9 ("time (s)", 9, 3, ParamType::Output, no_data_flag, 0, 0, "extra timer"),
time10 ("time (s)", 9, 3, ParamType::Output, no_data_flag, 0, 0, "extra timer"),
time11 ("time (s)", 9, 3, ParamType::Output, no_data_flag, 0, 0, "extra timer"),
time12 ("time (s)", 9, 3, ParamType::Output, no_data_flag, 0, 0, "extra timer"),
iters ("iters", 5, ParamType::Output, 0, 0, 0, "iterations to solution"),

ref_time ("ref time (s)", 12, 3, ParamType::Output, no_data_flag, 0, 0, "reference time to solution"),
Expand Down
2 changes: 2 additions & 0 deletions test/test.hh
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ public:
testsweeper::ParamDouble time8;
testsweeper::ParamDouble time9;
testsweeper::ParamDouble time10;
testsweeper::ParamDouble time11;
testsweeper::ParamDouble time12;
testsweeper::ParamInt iters;

testsweeper::ParamDouble ref_time;
Expand Down
29 changes: 24 additions & 5 deletions test/test_gesv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,18 @@ void test_gesv_work(Params& params, bool run)
bool trace = params.trace() == 'y';
bool nonuniform_nb = params.nonuniform_nb() == 'y';
int verbose = params.verbose();
int timer_level = params.timer_level();
SLATE_UNUSED(verbose);
slate::Origin origin = params.origin();
slate::Target target = params.target();
slate::GridOrder grid_order = params.grid_order();
params.matrix.mark();
params.matrixB.mark();

// Currently only gesv* supports timer_level >= 2.
if (params.routine != "gesv")
timer_level = 1;

// NoPiv and CALU ignore threshold.
double pivot_threshold = params.pivot_threshold();

Expand All @@ -85,15 +90,24 @@ void test_gesv_work(Params& params, bool run)
params.gflops();
params.ref_time();
params.ref_gflops();
params.time2();
params.time2.name( "trs time (s)" );
params.time2.width( 12 );
params.gflops2();
params.gflops2.name( "trs gflop/s" );

bool do_getrs = params.routine == "getrs"
|| (check && params.routine == "getrf");

if (do_getrs) {
params.time2();
params.time2.name( "trs time (s)" );
params.time2.width( 12 );
params.gflops2();
params.gflops2.name( "trs gflop/s" );
}
if (timer_level >= 2) {
params.time2();
params.time3();
params.time2.name( "getrf (s)" );
params.time3.name( "getrs (s)" );
}

if (params.routine == "gesv_mixed" || params.routine == "gesv_mixed_gmres") {
params.iters();
}
Expand Down Expand Up @@ -292,6 +306,11 @@ void test_gesv_work(Params& params, bool run)
params.time() = time;
params.gflops() = gflop / time;

if (timer_level >= 2) {
params.time2() = slate::timers[ "gesv::getrf" ];
params.time3() = slate::timers[ "gesv::getrs" ];
}

//==================================================
// Run SLATE test: getrs
// getrs: Solve AX = B after factoring A above.
Expand Down
20 changes: 20 additions & 0 deletions test/test_heev.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ void test_heev_work(Params& params, bool run)
bool check = params.check() == 'y' && ! ref_only;
bool trace = params.trace() == 'y';
int verbose = params.verbose();
int timer_level = params.timer_level();
slate::Origin origin = params.origin();
slate::Target target = params.target();
slate::MethodEig method_eig = params.method_eig();
Expand All @@ -59,6 +60,18 @@ void test_heev_work(Params& params, bool run)
params.error.name( "value err" );
params.error2.name( "back err" );
params.ortho.name( "Z orth." );
if (timer_level >= 2) {
params.time2();
params.time3();
params.time4();
params.time5();
params.time6();
params.time2.name( "he2hb (s)" );
params.time3.name( "hb2st (s)" );
params.time4.name( "stev (s)" );
params.time5.name( "unmtr_hb2st (s)" );
params.time6.name( "unmtr_he2hb (s)" );
}

if (! run)
return;
Expand Down Expand Up @@ -175,6 +188,13 @@ void test_heev_work(Params& params, bool run)

// compute and save timing/performance
params.time() = time;
if (timer_level >= 2) {
params.time2() = slate::timers[ "heev::he2hb" ];
params.time3() = slate::timers[ "heev::hb2st" ];
params.time4() = slate::timers[ "heev::stev" ];
params.time5() = slate::timers[ "heev::unmtr_hb2st" ];
params.time6() = slate::timers[ "heev::unmtr_he2hb" ];
}

if (check && jobz == slate::Job::Vec) {
//==================================================
Expand Down
Loading

0 comments on commit 3e89d68

Please sign in to comment.