Skip to content

Commit

Permalink
TMP fix compile
Browse files (browse the repository at this point in the history)
  • Loading branch information
mgates3 committed Oct 26, 2023
1 parent 1141fd9 commit e4542eb
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 57 deletions.
10 changes: 6 additions & 4 deletions src/gesv_mixed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,21 @@ int64_t gesv_mixed(
int& iter,
Options const& opts)
{
Target target = get_option( opts, Option::Target, Target::HostTask );
using real_hi = blas::real_type<scalar_hi>;

// Constants
// Assumes column major
const Layout layout = Layout::ColMajor;

bool converged = false;
using real_hi = blas::real_type<scalar_hi>;
const real_hi eps = std::numeric_limits<real_hi>::epsilon();
const scalar_hi one_hi = 1.0;

// Options
Target target = get_option( opts, Option::Target, Target::HostTask );
int64_t itermax = get_option<int64_t>( opts, Option::MaxIterations, 30 );
double tol = get_option<double>( opts, Option::Tolerance, eps*std::sqrt(A.m()) );
bool use_fallback = get_option<int64_t>( opts, Option::UseFallbackSolver, true );

bool converged = false;
iter = 0;

assert( B.mt() == A.mt() );
Expand Down
11 changes: 4 additions & 7 deletions src/gesv_mixed_gmres.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@ int64_t gesv_mixed_gmres(

// Constants
const real_hi eps = std::numeric_limits<real_hi>::epsilon();
const int64_t itermax = 30;
const int64_t restart = blas::min( 30, itermax, A.tileMb( 0 )-1 );
const int64_t mpi_rank = A.mpiRank();
const scalar_hi zero = 0.0;
const scalar_hi one = 1.0;
Expand All @@ -130,17 +128,16 @@ int64_t gesv_mixed_gmres(

// Options
Target target = get_option( opts, Option::Target, Target::HostTask );

bool converged = false;
int64_t itermax = get_option<int64_t>( opts, Option::MaxIterations, 30 );
double tol = get_option<double>( opts, Option::Tolerance, eps*std::sqrt(A.m()) );
bool use_fallback = get_option<int64_t>( opts, Option::UseFallbackSolver, true );
const int64_t restart = std::min(
std::min( int64_t( 30 ), itermax ), A.tileMb( 0 )-1 );
int64_t restart = blas::min( 30, itermax, A.tileMb( 0 )-1 );

bool converged = false;
iter = 0;

assert( B.mt() == A.mt() );
slate_assert( A.tileMb( 0 ) >= restart );
assert( A.tileMb( 0 ) >= restart );

// TODO: implement block gmres
if (B.n() != 1) {
Expand Down
2 changes: 1 addition & 1 deletion src/internal/internal_potrf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ int64_t potrf(
lapack::device_info_int* device_info)
{
return potrf( internal::TargetType<target>(), A, priority,
queue_index, device_info);
queue_index, device_info );
}

//------------------------------------------------------------------------------
Expand Down
36 changes: 16 additions & 20 deletions src/posv_mixed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,23 @@ int64_t posv_mixed(
int& iter,
Options const& opts)
{
// XXX This is only used for the memory management and may be inconsistent
// with the routines called in this routine.
Target target = get_option( opts, Option::Target, Target::HostTask );
using real_hi = blas::real_type<scalar_hi>;

// Constants
// Assumes column major
const Layout layout = Layout::ColMajor;

bool converged = false;
using real_hi = blas::real_type<scalar_hi>;
const real_hi eps = std::numeric_limits<real_hi>::epsilon();
const scalar_hi one_hi = 1.0;

// Options
// XXX This is only used for the memory management and may be inconsistent
// with the routines called in this routine.
Target target = get_option( opts, Option::Target, Target::HostTask );
int64_t itermax = get_option<int64_t>( opts, Option::MaxIterations, 30 );
double tol = get_option<double>( opts, Option::Tolerance, eps*std::sqrt(A.m()) );
bool use_fallback = get_option<int64_t>( opts, Option::UseFallbackSolver, true );

bool converged = false;
iter = 0;

assert( B.mt() == A.mt() );
Expand Down Expand Up @@ -178,7 +180,7 @@ int64_t posv_mixed(
copy( A, A_lo, opts );

// Compute the Cholesky factorization of A_lo.
int64_t info = potrf( A_lo, opts );
int64_t info = potrf( A_lo, opts );
if (info != 0) {
iter = -3;
}
Expand All @@ -202,7 +204,7 @@ int64_t posv_mixed(
colNorms( Norm::Max, X, colnorms_X.data(), opts );
colNorms( Norm::Max, R, colnorms_R.data(), opts );

if (internal::iterRefConverged<real_hi>( colnorms_R, colnorms_X, cte) ) {
if (internal::iterRefConverged<real_hi>( colnorms_R, colnorms_X, cte )) {
iter = 0;
converged = true;
}
Expand Down Expand Up @@ -248,22 +250,16 @@ int64_t posv_mixed(
iter = -itermax - 1;
}

<<<<<<< HEAD
if (use_fallback) {
// Fall back to double precision factor and solve.
// Compute the Cholesky factorization of A.
potrf( A, opts );
info = potrf( A, opts );

// Solve the system A * X = B.
=======
// Fall back to double precision factor and solve.
// Compute the Cholesky factorization of A.
info = potrf( A, opts );

// Solve the system A * X = B.
if (info == 0) {
>>>>>>> e47915db (chol: info check)
slate::copy( B, X, opts );
potrs( A, X, opts );
if (info == 0) {
slate::copy( B, X, opts );
potrs( A, X, opts );
}
}
}

Expand Down
31 changes: 9 additions & 22 deletions src/posv_mixed_gmres.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,23 +119,6 @@ int64_t posv_mixed_gmres(

// Constants
const real_hi eps = std::numeric_limits<real_hi>::epsilon();
<<<<<<< HEAD
const int64_t mpi_rank = A.mpiRank();
// Assumes column major
const Layout layout = Layout::ColMajor;

Target target = get_option( opts, Option::Target, Target::HostTask );

bool converged = false;
int64_t itermax = get_option<int64_t>( opts, Option::MaxIterations, 30 );
double tol = get_option<double>( opts, Option::Tolerance, eps*std::sqrt(A.m()) );
bool use_fallback = get_option<int64_t>( opts, Option::UseFallbackSolver, true );
const int64_t restart = std::min(
std::min( int64_t( 30 ), itermax ), A.tileMb( 0 )-1 );
iter = 0;
=======
const int64_t itermax = 30;
const int64_t restart = blas::min( 30, itermax, A.tileMb( 0 )-1 );
const int64_t mpi_rank = A.mpiRank();
const scalar_hi zero = 0.0;
const scalar_hi one = 1.0;
Expand All @@ -144,12 +127,16 @@ int64_t posv_mixed_gmres(

// Options
Target target = get_option( opts, Option::Target, Target::HostTask );
>>>>>>> e47915db (chol: info check)
int64_t itermax = get_option<int64_t>( opts, Option::MaxIterations, 30 );
double tol = get_option<double>( opts, Option::Tolerance, eps*std::sqrt(A.m()) );
bool use_fallback = get_option<int64_t>( opts, Option::UseFallbackSolver, true );
int64_t restart = blas::min( 30, itermax, A.tileMb( 0 )-1 );

bool converged = false;
iter = 0;

assert( B.mt() == A.mt() );
assert( A.tileMb( 0 ) >= restart );

// TODO: implement block gmres
if (B.n() != 1) {
Expand Down Expand Up @@ -194,15 +181,15 @@ int64_t posv_mixed_gmres(
{
#pragma omp task default(shared)
{
A.tileGetAndHoldAllOnDevices(LayoutConvert(layout));
A.tileGetAndHoldAllOnDevices( LayoutConvert( layout ) );
}
#pragma omp task default(shared)
{
B.tileGetAndHoldAllOnDevices(LayoutConvert(layout));
B.tileGetAndHoldAllOnDevices( LayoutConvert( layout ) );
}
#pragma omp task default(shared)
{
X.tileGetAndHoldAllOnDevices(LayoutConvert(layout));
X.tileGetAndHoldAllOnDevices( LayoutConvert( layout ) );
}
}
}
Expand All @@ -211,7 +198,7 @@ int64_t posv_mixed_gmres(
real_hi Anorm = norm( Norm::Inf, A, opts );

// stopping criteria
real_hi cte = Anorm * eps * std::sqrt( A.n() );
real_hi cte = Anorm * tol;

// Compute the Cholesky factorization of A in single-precision.
slate::copy( A, A_lo, opts );
Expand Down
4 changes: 1 addition & 3 deletions test/test_posv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,8 @@ void test_posv_work(Params& params, bool run)
else {
slate_error("Unknown routine!");
}

// compute and save timing/performance
time2 = barrier_get_wtime(MPI_COMM_WORLD) - time2;
// compute and save timing/performance
params.time2() = time2;
params.gflops2() = lapack::Gflop<scalar_t>::potrs(n, nrhs) / time2;
}
Expand Down Expand Up @@ -439,7 +438,6 @@ void test_posv_work(Params& params, bool run)
}
}
}

Cblacs_gridexit(ictxt);
//Cblacs_exit(1) does not handle re-entering
#else // not SLATE_HAVE_SCALAPACK
Expand Down

0 comments on commit e4542eb

Please sign in to comment.