Skip to content

Commit

Permalink
Merge pull request #133 from mgates3/chol-info
Browse files Browse the repository at this point in the history
chol: info check
  • Loading branch information
mgates3 authored Nov 30, 2023
2 parents ff9543f + 9b2b2df commit 0e9cac3
Show file tree
Hide file tree
Showing 13 changed files with 562 additions and 514 deletions.
20 changes: 10 additions & 10 deletions include/slate/simplified_api.hh
Original file line number Diff line number Diff line change
Expand Up @@ -390,56 +390,56 @@ void chol_solve(

// posv
template <typename scalar_t>
void chol_solve(
int64_t chol_solve(
HermitianMatrix<scalar_t>& A,
Matrix<scalar_t>& B,
Options const& opts = Options())
{
posv(A, B, opts);
return posv( A, B, opts );
}

// forward real-symmetric matrices to posv;
// disabled for complex
template <typename scalar_t>
void chol_solve(
int64_t chol_solve(
SymmetricMatrix<scalar_t>& A,
Matrix<scalar_t>& B,
Options const& opts = Options(),
enable_if_t< ! is_complex<scalar_t>::value >* = nullptr)
{
posv(A, B, opts);
return posv( A, B, opts );
}

//-----------------------------------------
// chol_factor()

// pbtrf
template <typename scalar_t>
void chol_factor(
int64_t chol_factor(
HermitianBandMatrix<scalar_t>& A,
Options const& opts = Options())
{
pbtrf(A, opts);
return pbtrf( A, opts );
}

// potrf
template <typename scalar_t>
void chol_factor(
int64_t chol_factor(
HermitianMatrix<scalar_t>& A,
Options const& opts = Options())
{
potrf(A, opts);
return potrf( A, opts );
}

// forward real-symmetric matrices to potrf;
// disabled for complex
template <typename scalar_t>
void chol_factor(
int64_t chol_factor(
SymmetricMatrix<scalar_t>& A,
Options const& opts = Options(),
enable_if_t< ! is_complex<scalar_t>::value >* = nullptr)
{
potrf(A, opts);
return potrf( A, opts );
}

//-----------------------------------------
Expand Down
32 changes: 16 additions & 16 deletions include/slate/slate.hh
Original file line number Diff line number Diff line change
Expand Up @@ -650,44 +650,44 @@ void getri(
//-----------------------------------------
// pbsv()
template <typename scalar_t>
void pbsv(
int64_t pbsv(
HermitianBandMatrix<scalar_t>& A,
Matrix<scalar_t>& B,
Options const& opts = Options());

//-----------------------------------------
// posv()
template <typename scalar_t>
void posv(
int64_t posv(
HermitianMatrix<scalar_t>& A,
Matrix<scalar_t>& B,
Options const& opts = Options());

// forward real-symmetric matrices to posv;
// disabled for complex
template <typename scalar_t>
void posv(
int64_t posv(
SymmetricMatrix<scalar_t>& A,
Matrix<scalar_t>& B,
Options const& opts = Options(),
enable_if_t< ! is_complex<scalar_t>::value >* = nullptr)
{
HermitianMatrix<scalar_t> AH(A);
posv(AH, B, opts);
return posv( AH, B, opts );
}

//-----------------------------------------
// posv_mixed()
template <typename scalar_t>
void posv_mixed(
int64_t posv_mixed(
HermitianMatrix<scalar_t>& A,
Matrix<scalar_t>& B,
Matrix<scalar_t>& X,
int& iter,
Options const& opts = Options());

template <typename scalar_hi, typename scalar_lo>
void posv_mixed(
int64_t posv_mixed(
HermitianMatrix<scalar_hi>& A,
Matrix<scalar_hi>& B,
Matrix<scalar_hi>& X,
Expand All @@ -698,40 +698,40 @@ void posv_mixed(

template <typename scalar_t>
[[deprecated( "Use posv_mixed instead. Will be removed 2024-02." )]]
void posvMixed(
int64_t posvMixed(
HermitianMatrix<scalar_t>& A,
Matrix<scalar_t>& B,
Matrix<scalar_t>& X,
int& iter,
Options const& opts = Options())
{
posv_mixed( A, B, X, iter, opts );
return posv_mixed( A, B, X, iter, opts );
}

template <typename scalar_hi, typename scalar_lo>
[[deprecated( "Use posv_mixed instead. Will be removed 2024-02." )]]
void posvMixed(
int64_t posvMixed(
HermitianMatrix<scalar_hi>& A,
Matrix<scalar_hi>& B,
Matrix<scalar_hi>& X,
int& iter,
Options const& opts = Options())
{
posv_mixed( A, B, X, iter, opts );
return posv_mixed( A, B, X, iter, opts );
}

//-----------------------------------------
// posv_mixed_gmres()
template <typename scalar_t>
void posv_mixed_gmres(
int64_t posv_mixed_gmres(
HermitianMatrix<scalar_t>& A,
Matrix<scalar_t>& B,
Matrix<scalar_t>& X,
int& iter,
Options const& opts = Options());

template <typename scalar_hi, typename scalar_lo>
void posv_mixed_gmres(
int64_t posv_mixed_gmres(
HermitianMatrix<scalar_hi>& A,
Matrix<scalar_hi>& B,
Matrix<scalar_hi>& X,
Expand All @@ -743,27 +743,27 @@ void posv_mixed_gmres(
//-----------------------------------------
// pbtrf()
template <typename scalar_t>
void pbtrf(
int64_t pbtrf(
HermitianBandMatrix<scalar_t>& A,
Options const& opts = Options());

//-----------------------------------------
// potrf()
template <typename scalar_t>
void potrf(
int64_t potrf(
HermitianMatrix<scalar_t>& A,
Options const& opts = Options());

// forward real-symmetric matrices to potrf;
// disabled for complex
template <typename scalar_t>
void potrf(
int64_t potrf(
SymmetricMatrix<scalar_t>& A,
Options const& opts = Options(),
enable_if_t< ! is_complex<scalar_t>::value >* = nullptr)
{
HermitianMatrix<scalar_t> AH(A);
potrf(AH, opts);
return potrf( AH, opts );
}

//-----------------------------------------
Expand Down
13 changes: 7 additions & 6 deletions src/gesv_mixed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,23 @@ int64_t gesv_mixed(
int& iter,
Options const& opts)
{
Timer t_gesv_mixed;
using real_hi = blas::real_type<scalar_hi>;

Target target = get_option( opts, Option::Target, Target::HostTask );
Timer t_gesv_mixed;

// Constants
// Assumes column major
const Layout layout = Layout::ColMajor;

bool converged = false;
using real_hi = blas::real_type<scalar_hi>;
const real_hi eps = std::numeric_limits<real_hi>::epsilon();
const scalar_hi one_hi = 1.0;

// Options
Target target = get_option( opts, Option::Target, Target::HostTask );
int64_t itermax = get_option<int64_t>( opts, Option::MaxIterations, 30 );
double tol = get_option<double>( opts, Option::Tolerance, eps*std::sqrt(A.m()) );
bool use_fallback = get_option<int64_t>( opts, Option::UseFallbackSolver, true );

bool converged = false;
iter = 0;

assert( B.mt() == A.mt() );
Expand All @@ -146,7 +148,6 @@ int64_t gesv_mixed(
if (target == Target::Devices) {
#pragma omp parallel
#pragma omp master
#pragma omp taskgroup
{
#pragma omp task slate_omp_default_none \
shared( A ) firstprivate( layout )
Expand Down
33 changes: 16 additions & 17 deletions src/gesv_mixed_gmres.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,22 +128,22 @@ int64_t gesv_mixed_gmres(
// Assumes column major
const Layout layout = Layout::ColMajor;

// Options
Target target = get_option( opts, Option::Target, Target::HostTask );

bool converged = false;
int64_t itermax = get_option<int64_t>( opts, Option::MaxIterations, 30 );
double tol = get_option<double>( opts, Option::Tolerance, eps*std::sqrt(A.m()) );
bool use_fallback = get_option<int64_t>( opts, Option::UseFallbackSolver, true );
const int64_t restart = std::min(
std::min( int64_t( 30 ), itermax ), A.tileMb( 0 )-1 );
int64_t restart = blas::min( 30, itermax, A.tileMb( 0 )-1 );

bool converged = false;
iter = 0;

assert( B.mt() == A.mt() );
slate_assert( A.tileMb( 0 ) >= restart );
assert( A.tileMb( 0 ) >= restart );

// TODO: implement block gmres
if (B.n() != 1) {
slate_not_implemented( "block-GMRES is not yet supported" );
slate_not_implemented( "block-GMRES for multiple RHS is not yet supported" );
}

// workspace
Expand All @@ -164,7 +164,7 @@ int64_t gesv_mixed_gmres(

// workspace vector for the orthogonalization process
auto z = X.template emptyLike<scalar_hi>();
z.insertLocalTiles(target);
z.insertLocalTiles( target );

// Hessenberg Matrix. Allocate as a single tile
slate::Matrix<scalar_hi> H(
Expand Down Expand Up @@ -236,8 +236,7 @@ int64_t gesv_mixed_gmres(
timers[ "gesv_mixed_gmres::gemm_hi" ] += t_gemm_hi.stop();
colNorms( Norm::Max, X, colnorms_X.data(), opts );
colNorms( Norm::Max, R, colnorms_R.data(), opts );
if (internal::iterRefConverged<real_hi>( colnorms_R, colnorms_X, cte ))
{
if (internal::iterRefConverged<real_hi>( colnorms_R, colnorms_X, cte )) {
iter = iiter;
converged = true;
break;
Expand Down Expand Up @@ -272,7 +271,7 @@ int64_t gesv_mixed_gmres(
// excessive restarting or delayed completion.
int j = 0;
for (; j < restart && iiter < itermax
&& !internal::iterRefConverged(
&& ! internal::iterRefConverged(
arnoldi_residual, colnorms_X, cte );
++j, ++iiter) {
auto Vj1 = V.slice( 0, V.m()-1, j+1, j+1 );
Expand Down Expand Up @@ -341,15 +340,15 @@ int64_t gesv_mixed_gmres(
auto H_00 = H( 0, 0 );
for (int64_t i = 0; i < j; ++i) {
blas::rot( 1, &H_00.at( i, j ), 1, &H_00.at( i+1, j ), 1,
givens_alpha[i], givens_beta[i] );
givens_alpha[i], givens_beta[i] );
}
scalar_hi H_jj = H_00.at( j, j ), H_j1j = H_00.at( j+1, j );
blas::rotg( &H_jj, & H_j1j, &givens_alpha[j], &givens_beta[j] );
blas::rot( 1, &H_00.at( j, j ), 1, &H_00.at( j+1, j ), 1,
givens_alpha[j], givens_beta[j] );
givens_alpha[j], givens_beta[j] );
auto S_00 = S( 0, 0 );
blas::rot( 1, &S_00.at( j, 0 ), 1, &S_00.at( j+1, 0 ), 1,
givens_alpha[j], givens_beta[j] );
givens_alpha[j], givens_beta[j] );
arnoldi_residual[0] = cabs1( S_00.at( j+1, 0 ) );
}
timers[ "gesv_mixed_gmres::rotations" ] += t_gesv_mixed_gmres_rotations.stop();
Expand Down Expand Up @@ -411,7 +410,6 @@ int64_t gesv_mixed_gmres(
return info;
}


//------------------------------------------------------------------------------
// Explicit instantiations.
template <>
Expand All @@ -422,7 +420,8 @@ int64_t gesv_mixed_gmres<double>(
int& iter,
Options const& opts)
{
return gesv_mixed_gmres<double, float>( A, pivots, B, X, iter, opts );
return gesv_mixed_gmres<double, float>(
A, pivots, B, X, iter, opts );
}

template <>
Expand All @@ -433,8 +432,8 @@ int64_t gesv_mixed_gmres< std::complex<double> >(
int& iter,
Options const& opts)
{
return gesv_mixed_gmres<std::complex<double>, std::complex<float>>(
A, pivots, B, X, iter, opts );
return gesv_mixed_gmres< std::complex<double>, std::complex<float> >(
A, pivots, B, X, iter, opts );
}

} // namespace slate
18 changes: 4 additions & 14 deletions src/internal/internal.hh
Original file line number Diff line number Diff line change
Expand Up @@ -675,20 +675,10 @@ void unmbr_tb2bd(Side side, Op op,
//-----------------------------------------
// potrf()
template <Target target=Target::HostTask, typename scalar_t>
void potrf(HermitianMatrix<scalar_t>&& A,
int priority=0, int64_t queue_index=0,
lapack::device_info_int* device_info=nullptr);

// forward real-symmetric matrices to potrf;
// disabled for complex
template <Target target=Target::HostTask, typename scalar_t>
void potrf(SymmetricMatrix<scalar_t>&& A,
int priority=0, int64_t queue_index=0,
lapack::device_info_int* device_info=nullptr,
enable_if_t< ! is_complex<scalar_t>::value >* = nullptr)
{
potrf<target>(SymmetricMatrix<scalar_t>(A), priority);
}
int64_t potrf(
HermitianMatrix<scalar_t>&& A,
int priority=0, int64_t queue_index=0,
lapack::device_info_int* device_info=nullptr );

//-----------------------------------------
// hegst()
Expand Down
Loading

0 comments on commit 0e9cac3

Please sign in to comment.