Skip to content

Commit

Permalink
heev: document and check for square MPI grid and lower triangular matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
mgates3 committed Dec 28, 2024
1 parent f7fcc31 commit f2bd75b
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/he2hb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void he2hb(
using real_t = blas::real_type<scalar_t>;
using blas::real;

assert( A.uplo() == Uplo::Lower ); // for now
slate_assert( A.uplo() == Uplo::Lower ); // for now

// Constants
const scalar_t zero = 0.0;
Expand Down
18 changes: 17 additions & 1 deletion src/heev.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@ namespace slate {

//------------------------------------------------------------------------------
/// Distributed parallel Hermitian matrix eigen decomposition.
/// heev Computes all eigenvalues and, optionally, eigenvectors of a
/// Computes all eigenvalues and, optionally, eigenvectors of a
/// Hermitian matrix A. The matrix A is preliminarily reduced to
/// tridiagonal form using a two-stage approach:
/// First stage: reduction to Hermitian band form (see he2hb);
/// Second stage: reduction from band to tridiagonal form (see hb2st).
///
/// Note: currently requires a **square (p-by-p) MPI process grid**.
/// This is because it applies the same QR factorization on the
/// left (p block-rows) and the right (p block-cols), with a size p
/// reduction tree. We hope to eventually remove this restriction.
///
/// Note: currently requires the Hermitian matrix to be stored in its
/// lower triangle (A.uplo() == Uplo::Lower).
///
//------------------------------------------------------------------------------
/// @tparam scalar_t
/// One of float, double, std::complex<float>, std::complex<double>.
Expand Down Expand Up @@ -80,6 +87,15 @@ void heev(
MethodEig method = get_option( opts, Option::MethodEig, MethodEig::DC );
Target target = get_option( opts, Option::Target, Target::HostTask );

// Currently he2hb requires lower triangular matrix.
slate_assert( A.uplo() == Uplo::Lower );

// Currently requires square process grid.
GridOrder grid_order;
int nprow, npcol, myrow, mycol;
A.gridinfo( &grid_order, &nprow, &npcol, &myrow, &mycol );
slate_assert( nprow == npcol );

// Scale matrix to allowable range, if necessary.
real_t Anorm = norm( Norm::Max, A );
real_t alpha = 1.0;
Expand Down
6 changes: 4 additions & 2 deletions test/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ using testsweeper::ansi_bold;
using testsweeper::ansi_red;
using testsweeper::ansi_normal;

using testsweeper::no_check;
using testsweeper::skipped;

using blas::Layout, blas::Layout_help;
using blas::Side, blas::Side_help;
using blas::Uplo, blas::Uplo_help;
Expand Down Expand Up @@ -507,9 +510,8 @@ Params::Params():
ref_gbytes( "ref gbyte/s", 12, 3, PT_Out, no_data, 0, 0, "reference Gbyte/s rate" ),
ref_iters ( "ref iters", 5, PT_Out, 0, 0, 0, "reference iterations to solution" ),

// default -1 means "no check"
// name, w, type, default, min, max, help
okay ( "status", 6, PT_Out, -1, 0, 0, "success indicator" ),
okay ( "status", 6, PT_Out, no_check, 0, 0, "success indicator" ),
msg ( "", 1, PT_Out, "", "error message" )
{
// set header different than command line prefix
Expand Down
32 changes: 22 additions & 10 deletions test/test_heev.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,30 @@ void test_heev_work(Params& params, bool run)
MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
gridinfo(mpi_rank, p, q, &myrow, &mycol);

// Skip invalid or unimplemented options.
slate::HermitianMatrix<scalar_t> A( uplo, n, nb, p, q, MPI_COMM_WORLD );

// Vector Lambda (global output) has eigenvalues in descending order.
std::vector<real_t> Lambda( n );

// Test that invalid or unimplemented options throw exceptions.
bool invalid = false;
if (uplo == slate::Uplo::Upper) {
params.msg() = "skipping: Uplo::Upper isn't supported.";
return;
params.msg() = "Uplo::Upper isn't supported.";
invalid = true;
}
else if (p != q) {
params.msg() = "requires square process grid (p == q).";
invalid = true;
}
if (p != q) {
params.msg() = "skipping: requires square process grid (p == q).";
if (invalid) {
params.okay() = false; // fails unless caught below
try {
slate::eig_vals( A, Lambda, opts );
}
catch (slate::Exception const& ex) {
//params.msg() += std::string(" Caught: ") + ex.what();
params.okay() = testsweeper::skipped;
}
return;
}

Expand All @@ -106,9 +123,6 @@ void test_heev_work(Params& params, bool run)

std::vector<scalar_t> A_data;

// matrix Lambda (global output) gets eigenvalues in descending order
std::vector<real_t> Lambda(n);

// matrix Z (local output), Z(n,n), gets orthonormal eigenvectors
// corresponding to Lambda of the reference scalapack
int64_t mlocZ = num_local_rows_cols(n, nb, myrow, p);
Expand All @@ -117,11 +131,9 @@ void test_heev_work(Params& params, bool run)
std::vector<scalar_t> Z_data( lldZ * nlocZ );

// Initialize SLATE data structures
slate::HermitianMatrix<scalar_t> A;
if (origin != slate::Origin::ScaLAPACK) {
// SLATE allocates CPU or GPU tiles.
slate::Target origin_target = origin2target(origin);
A = slate::HermitianMatrix<scalar_t>(uplo, n, nb, p, q, MPI_COMM_WORLD);
A.insertLocalTiles(origin_target);
}
else {
Expand Down
2 changes: 1 addition & 1 deletion testsweeper

0 comments on commit f2bd75b

Please sign in to comment.