diff --git a/common/cuda_hip/components/volatile.hpp.inc b/common/cuda_hip/components/volatile.hpp.inc
index 8a4e27b7905..c5668cea470 100644
--- a/common/cuda_hip/components/volatile.hpp.inc
+++ b/common/cuda_hip/components/volatile.hpp.inc
@@ -32,13 +32,22 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 template
 __device__ __forceinline__
-    std::enable_if_t::value, ValueType>
+    std::enable_if_t::value, ValueType>
     load(const ValueType* values, IndexType index)
 {
     const volatile ValueType* val = values + index;
     return *val;
 }
 
+template
+__device__ __forceinline__
+    std::enable_if_t::value, ValueType>
+    load(const ValueType* values, int index)
+{
+    const volatile ValueType* val = values + index;
+    return *val;
+}
+
 template
 __device__ __forceinline__ std::enable_if_t<
     std::is_floating_point::value, thrust::complex>
@@ -50,9 +59,18 @@ load(const thrust::complex* values, IndexType index)
 }
 
 template
-__device__ __forceinline__
-    std::enable_if_t::value, void>
-    store(ValueType* values, IndexType index, ValueType value)
+__device__ __forceinline__ void store(
+    ValueType* values, IndexType index,
+    std::enable_if_t::value, ValueType> value)
+{
+    volatile ValueType* val = values + index;
+    *val = value;
+}
+
+template
+__device__ __forceinline__ void store(
+    ValueType* values, int index,
+    std::enable_if_t::value, ValueType> value)
 {
     volatile ValueType* val = values + index;
     *val = value;
diff --git a/core/solver/lower_trs.cpp b/core/solver/lower_trs.cpp
index f46dd2c7d12..1ab045b6d6d 100644
--- a/core/solver/lower_trs.cpp
+++ b/core/solver/lower_trs.cpp
@@ -135,8 +135,8 @@ void LowerTrs::generate()
     if (this->get_system_matrix()) {
         this->get_executor()->run(lower_trs::make_generate(
             this->get_system_matrix().get(), this->solve_struct_,
-            this->get_parameters().unit_diagonal, parameters_.algorithm,
-            parameters_.num_rhs));
+            this->get_parameters().unit_diagonal,
+            gko::lend(parameters_.strategy), parameters_.num_rhs));
     }
 }
 
@@ -178,8 +178,8 @@ void LowerTrs::apply_impl(const LinOp* b, LinOp* x) const
             }
             exec->run(lower_trs::make_solve(
                 lend(this->get_system_matrix()), lend(this->solve_struct_),
-                this->get_parameters().unit_diagonal, parameters_.algorithm,
-                trans_b, trans_x, dense_b, dense_x));
+                this->get_parameters().unit_diagonal, trans_b, trans_x, dense_b,
+                dense_x));
         },
         b, x);
 }
diff --git a/core/solver/lower_trs_kernels.hpp b/core/solver/lower_trs_kernels.hpp
index a6e39ce8436..cd98a0ddeeb 100644
--- a/core/solver/lower_trs_kernels.hpp
+++ b/core/solver/lower_trs_kernels.hpp
@@ -56,11 +56,11 @@ namespace lower_trs {
                                bool& do_transpose)
 
 
-#define GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(_vtype, _itype)             \
-    void generate(std::shared_ptr exec,                                   \
-                  const matrix::Csr<_vtype, _itype>* matrix,              \
-                  std::shared_ptr& solve_struct,                          \
-                  bool unit_diag, const solver::trisolve_algorithm algorithm, \
+#define GKO_DECLARE_LOWER_TRS_GENERATE_KERNEL(_vtype, _itype)            \
+    void generate(std::shared_ptr exec,                                  \
+                  const matrix::Csr<_vtype, _itype>* matrix,             \
+                  std::shared_ptr& solve_struct,                         \
+                  bool unit_diag, const solver::trisolve_strategy* strategy, \
                   const size_type num_rhs)
 
 
@@ -68,7 +68,6 @@ namespace lower_trs {
     void solve(std::shared_ptr exec,                                    \
                const matrix::Csr<_vtype, _itype>* matrix,               \
                const solver::SolveStruct* solve_struct, bool unit_diag, \
-               const solver::trisolve_algorithm algorithm,              \
                matrix::Dense<_vtype>* trans_b, matrix::Dense<_vtype>* trans_x, \
                const matrix::Dense<_vtype>* b, matrix::Dense<_vtype>* x)
 
diff --git a/core/solver/upper_trs.cpp b/core/solver/upper_trs.cpp
index 6d60fe04f88..2e848370c48 100644
--- a/core/solver/upper_trs.cpp
+++ b/core/solver/upper_trs.cpp
@@ -135,8 +135,8 @@ void UpperTrs::generate()
     if (this->get_system_matrix()) {
         this->get_executor()->run(upper_trs::make_generate(
             this->get_system_matrix().get(), this->solve_struct_,
-            this->get_parameters().unit_diagonal, parameters_.algorithm,
-            parameters_.num_rhs));
+            this->get_parameters().unit_diagonal,
+            gko::lend(parameters_.strategy), parameters_.num_rhs));
     }
 }
 
@@ -178,8 +178,8 @@ void UpperTrs::apply_impl(const LinOp* b, LinOp* x) const
             }
             exec->run(upper_trs::make_solve(
                 lend(this->get_system_matrix()), lend(this->solve_struct_),
-                this->get_parameters().unit_diagonal, parameters_.algorithm,
-                trans_b, trans_x, dense_b, dense_x));
+                this->get_parameters().unit_diagonal, trans_b, trans_x, dense_b,
+                dense_x));
         },
         b, x);
 }
diff --git a/core/solver/upper_trs_kernels.hpp b/core/solver/upper_trs_kernels.hpp
index 84a370c5f9d..c57adb3c6fe 100644
--- a/core/solver/upper_trs_kernels.hpp
+++ b/core/solver/upper_trs_kernels.hpp
@@ -56,11 +56,11 @@ namespace upper_trs {
                                bool& do_transpose)
 
 
-#define GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL(_vtype, _itype)             \
-    void generate(std::shared_ptr exec,                                   \
-                  const matrix::Csr<_vtype, _itype>* matrix,              \
-                  std::shared_ptr& solve_struct,                          \
-                  bool unit_diag, const solver::trisolve_algorithm algorithm, \
+#define GKO_DECLARE_UPPER_TRS_GENERATE_KERNEL(_vtype, _itype)            \
+    void generate(std::shared_ptr exec,                                  \
+                  const matrix::Csr<_vtype, _itype>* matrix,             \
+                  std::shared_ptr& solve_struct,                         \
+                  bool unit_diag, const solver::trisolve_strategy* strategy, \
                   const size_type num_rhs)
 
 
@@ -68,7 +68,6 @@ namespace upper_trs {
     void solve(std::shared_ptr exec,                                    \
                const matrix::Csr<_vtype, _itype>* matrix,               \
                const solver::SolveStruct* solve_struct, bool unit_diag, \
-               const solver::trisolve_algorithm algorithm,              \
                matrix::Dense<_vtype>* trans_b, matrix::Dense<_vtype>* trans_x, \
                const matrix::Dense<_vtype>* b, matrix::Dense<_vtype>* x)
 
diff --git a/cuda/CMakeLists.txt b/cuda/CMakeLists.txt
index 46946b3b696..d243f3c163f 100644
--- a/cuda/CMakeLists.txt
+++ b/cuda/CMakeLists.txt
@@ -44,6 +44,7 @@ target_sources(ginkgo_cuda
     solver/cb_gmres_kernels.cu
     solver/idr_kernels.cu
     solver/lower_trs_kernels.cu
+    solver/common_trs_kernels.cu
     solver/multigrid_kernels.cu
     solver/upper_trs_kernels.cu
     stop/criterion_kernels.cu
diff --git a/cuda/solver/common_trs_kernels.cu b/cuda/solver/common_trs_kernels.cu
new file mode 100644
index 00000000000..945e9215621
--- /dev/null
+++ b/cuda/solver/common_trs_kernels.cu
@@ -0,0 +1,1910 @@
+/*************************************************************
+Copyright (c) 2017-2022, the Ginkgo authors
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions
+are met:
+
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in the
+documentation and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS +IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED +TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*************************************************************/ + + +#include +#include +#include + + +#include +#include + + +#include +#include +#include + + +#include "core/components/prefix_sum_kernels.hpp" +#include "core/matrix/dense_kernels.hpp" +#include "core/synthesizer/implementation_selection.hpp" +#include "cuda/base/cusparse_bindings.hpp" +#include "cuda/base/math.hpp" +#include "cuda/base/pointer_mode_guard.hpp" +#include "cuda/base/types.hpp" +#include "cuda/components/atomic.cuh" +#include "cuda/components/thread_ids.cuh" +#include "cuda/components/uninitialized_array.hpp" +#include "cuda/components/volatile.cuh" + + +namespace gko { +namespace solver { + + +// struct SolveStruct { +// virtual ~SolveStruct() = default; +// }; + + +} // namespace solver + + +namespace kernels { +namespace cuda { +namespace { + + +#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11031)) + + +template +struct CudaSolveStruct : gko::solver::SolveStruct { + cusparseHandle_t handle; + cusparseSpSMDescr_t spsm_descr; + cusparseSpMatDescr_t descr_a; + size_type num_rhs; + + // Implicit parameter in spsm_solve, therefore stored here. + array work; + + CudaSolveStruct(std::shared_ptr exec, + const matrix::Csr* matrix, + size_type num_rhs, bool is_upper, bool unit_diag) + : handle{exec->get_cusparse_handle()}, + spsm_descr{}, + descr_a{}, + num_rhs{num_rhs}, + work{exec} + { + if (num_rhs == 0) { + return; + } + cusparse::pointer_mode_guard pm_guard(handle); + spsm_descr = cusparse::create_spsm_descr(); + descr_a = cusparse::create_csr( + matrix->get_size()[0], matrix->get_size()[1], + matrix->get_num_stored_elements(), + const_cast(matrix->get_const_row_ptrs()), + const_cast(matrix->get_const_col_idxs()), + const_cast(matrix->get_const_values())); + cusparse::set_attribute( + descr_a, CUSPARSE_SPMAT_FILL_MODE, + is_upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER); + cusparse::set_attribute( + descr_a, CUSPARSE_SPMAT_DIAG_TYPE, + unit_diag ? 
CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT); + + const auto rows = matrix->get_size()[0]; + // workaround suggested by NVIDIA engineers: for some reason + // cusparse needs non-nullptr input vectors even for analysis + auto descr_b = cusparse::create_dnmat( + dim<2>{matrix->get_size()[0], num_rhs}, matrix->get_size()[1], + reinterpret_cast(0xDEAD)); + auto descr_c = cusparse::create_dnmat( + dim<2>{matrix->get_size()[0], num_rhs}, matrix->get_size()[1], + reinterpret_cast(0xDEAF)); + + auto work_size = cusparse::spsm_buffer_size( + handle, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_NON_TRANSPOSE, one(), descr_a, + descr_b, descr_c, CUSPARSE_SPSM_ALG_DEFAULT, spsm_descr); + + work.resize_and_reset(work_size); + + cusparse::spsm_analysis(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_NON_TRANSPOSE, + one(), descr_a, descr_b, descr_c, + CUSPARSE_SPSM_ALG_DEFAULT, spsm_descr, + work.get_data()); + + cusparse::destroy(descr_b); + cusparse::destroy(descr_c); + } + + void solve(const matrix::Csr*, + const matrix::Dense* input, + matrix::Dense* output, matrix::Dense*, + matrix::Dense*) const + { + if (input->get_size()[1] != num_rhs) { + throw gko::ValueMismatch{ + __FILE__, + __LINE__, + __FUNCTION__, + input->get_size()[1], + num_rhs, + "the dimensions of the multivector do not match the value " + "provided at generation time. Check the value specified in " + ".with_num_rhs(...)."}; + } + cusparse::pointer_mode_guard pm_guard(handle); + auto descr_b = cusparse::create_dnmat( + input->get_size(), input->get_stride(), + const_cast(input->get_const_values())); + auto descr_c = cusparse::create_dnmat( + output->get_size(), output->get_stride(), output->get_values()); + + cusparse::spsm_solve(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_NON_TRANSPOSE, one(), + descr_a, descr_b, descr_c, + CUSPARSE_SPSM_ALG_DEFAULT, spsm_descr); + + cusparse::destroy(descr_b); + cusparse::destroy(descr_c); + } + + ~CudaSolveStruct() + { + if (descr_a) { + cusparse::destroy(descr_a); + descr_a = nullptr; + } + if (spsm_descr) { + cusparse::destroy(spsm_descr); + spsm_descr = nullptr; + } + } + + CudaSolveStruct(const SolveStruct&) = delete; + + CudaSolveStruct(SolveStruct&&) = delete; + + CudaSolveStruct& operator=(const SolveStruct&) = delete; + + CudaSolveStruct& operator=(SolveStruct&&) = delete; +}; + + +#elif (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020)) + +template +struct CudaSolveStruct : gko::solver::SolveStruct { + std::shared_ptr exec; + cusparseHandle_t handle; + int algorithm; + csrsm2Info_t solve_info; + cusparseSolvePolicy_t policy; + cusparseMatDescr_t factor_descr; + size_type num_rhs; + mutable array work; + + CudaSolveStruct(std::shared_ptr exec, + const matrix::Csr* matrix, + size_type num_rhs, bool is_upper, bool unit_diag) + : exec{exec}, + handle{exec->get_cusparse_handle()}, + algorithm{}, + solve_info{}, + policy{}, + factor_descr{}, + num_rhs{num_rhs}, + work{exec} + { + if (num_rhs == 0) { + return; + } + cusparse::pointer_mode_guard pm_guard(handle); + factor_descr = cusparse::create_mat_descr(); + solve_info = cusparse::create_solve_info(); + cusparse::set_mat_fill_mode( + factor_descr, + is_upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER); + cusparse::set_mat_diag_type( + factor_descr, + unit_diag ? 
CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT); + algorithm = 0; + policy = CUSPARSE_SOLVE_POLICY_USE_LEVEL; + + size_type work_size{}; + + cusparse::buffer_size_ext( + handle, algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], num_rhs, + matrix->get_num_stored_elements(), one(), factor_descr, + matrix->get_const_values(), matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), nullptr, num_rhs, solve_info, policy, + &work_size); + + // allocate workspace + work.resize_and_reset(work_size); + + cusparse::csrsm2_analysis( + handle, algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], num_rhs, + matrix->get_num_stored_elements(), one(), factor_descr, + matrix->get_const_values(), matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), nullptr, num_rhs, solve_info, policy, + work.get_data()); + } + + void solve(const matrix::Csr* matrix, + const matrix::Dense* input, + matrix::Dense* output, matrix::Dense*, + matrix::Dense*) const + { + if (input->get_size()[1] != num_rhs) { + throw gko::ValueMismatch{ + __FILE__, + __LINE__, + __FUNCTION__, + input->get_size()[1], + num_rhs, + "the dimensions of the multivector do not match the value " + "provided at generation time. Check the value specified in " + ".with_num_rhs(...)."}; + } + cusparse::pointer_mode_guard pm_guard(handle); + dense::copy(exec, input, output); + cusparse::csrsm2_solve( + handle, algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], + output->get_stride(), matrix->get_num_stored_elements(), + one(), factor_descr, matrix->get_const_values(), + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + output->get_values(), output->get_stride(), solve_info, policy, + work.get_data()); + } + + ~CudaSolveStruct() + { + if (factor_descr) { + cusparse::destroy(factor_descr); + factor_descr = nullptr; + } + if (solve_info) { + cusparse::destroy(solve_info); + solve_info = nullptr; + } + } + + CudaSolveStruct(const CudaSolveStruct&) = delete; + + CudaSolveStruct(CudaSolveStruct&&) = delete; + + CudaSolveStruct& operator=(const CudaSolveStruct&) = delete; + + CudaSolveStruct& operator=(CudaSolveStruct&&) = delete; +}; + + +#endif + + +void should_perform_transpose_kernel(std::shared_ptr exec, + bool& do_transpose) +{ + do_transpose = false; +} + + +constexpr int default_block_size = 512; +constexpr int fallback_block_size = 32; + + +template +__global__ void sptrsv_naive_caching_kernel( + const IndexType* const rowptrs, const IndexType* const colidxs, + const ValueType* const vals, const ValueType* const b, size_type b_stride, + ValueType* const x, size_type x_stride, const size_type n, + const size_type nrhs, bool unit_diag, bool* nan_produced, + IndexType* atomic_counter) +{ + __shared__ uninitialized_array x_s_array; + __shared__ IndexType block_base_idx; + + if (threadIdx.x == 0) { + block_base_idx = + atomic_add(atomic_counter, IndexType{1}) * default_block_size; + } + __syncthreads(); + const auto full_gid = static_cast(threadIdx.x) + block_base_idx; + const auto rhs = full_gid % nrhs; + const auto gid = full_gid / nrhs; + const auto row = is_upper ? 
n - 1 - gid : gid; + + if (gid >= n) { + return; + } + + const auto self_shmem_id = full_gid / default_block_size; + const auto self_shid = full_gid % default_block_size; + + ValueType* x_s = x_s_array; + x_s[self_shid] = nan(); + + __syncthreads(); + + // lower tri matrix: start at beginning, run forward until last entry, + // (row_end - 1) which is the diagonal entry + // upper tri matrix: start at last entry (row_end - 1), run backward + // until first entry, which is the diagonal entry + const auto row_begin = is_upper ? rowptrs[row + 1] - 1 : rowptrs[row]; + const auto row_end = is_upper ? rowptrs[row] - 1 : rowptrs[row + 1]; + const int row_step = is_upper ? -1 : 1; + + auto sum = zero(); + auto i = row_begin; + for (; i != row_end; i += row_step) { + const auto dependency = colidxs[i]; + if (is_upper ? dependency <= row : dependency >= row) { + break; + } + auto x_p = &x[dependency * x_stride + rhs]; + + const auto dependency_gid = is_upper ? (n - 1 - dependency) * nrhs + rhs + : dependency * nrhs + rhs; + const bool shmem_possible = + (dependency_gid / default_block_size) == self_shmem_id; + if (shmem_possible) { + const auto dependency_shid = dependency_gid % default_block_size; + x_p = &x_s[dependency_shid]; + } + + ValueType x = *x_p; + while (is_nan(x)) { + x = load(x_p, 0); + } + + sum += x * vals[i]; + } + + // The first entry past the triangular part will be the diagonal + const auto diag = unit_diag ? one() : vals[i]; + const auto r = (b[row * b_stride + rhs] - sum) / diag; + + store(x_s, self_shid, r); + x[row * x_stride + rhs] = r; + + // This check to ensure no infinte loops happen. + if (is_nan(r)) { + store(x_s, self_shid, zero()); + x[row * x_stride + rhs] = zero(); + *nan_produced = true; + } +} + + +template +__global__ void sptrsv_naive_legacy_kernel( + const IndexType* const rowptrs, const IndexType* const colidxs, + const ValueType* const vals, const ValueType* const b, size_type b_stride, + ValueType* const x, size_type x_stride, const size_type n, + const size_type nrhs, bool unit_diag, bool* nan_produced, + IndexType* atomic_counter) +{ + __shared__ IndexType block_base_idx; + if (threadIdx.x == 0) { + block_base_idx = + atomic_add(atomic_counter, IndexType{1}) * fallback_block_size; + } + __syncthreads(); + const auto full_gid = static_cast(threadIdx.x) + block_base_idx; + const auto rhs = full_gid % nrhs; + const auto gid = full_gid / nrhs; + const auto row = is_upper ? n - 1 - gid : gid; + + if (gid >= n) { + return; + } + + // lower tri matrix: start at beginning, run forward + // upper tri matrix: start at last entry (row_end - 1), run backward + const auto row_begin = is_upper ? rowptrs[row + 1] - 1 : rowptrs[row]; + const auto row_end = is_upper ? rowptrs[row] - 1 : rowptrs[row + 1]; + const int row_step = is_upper ? -1 : 1; + + auto sum = zero(); + auto j = row_begin; + auto col = colidxs[j]; + while (j != row_end) { + auto x_val = load(x, col * x_stride + rhs); + while (!is_nan(x_val)) { + sum += vals[j] * x_val; + j += row_step; + col = colidxs[j]; + x_val = load(x, col * x_stride + rhs); + } + // to avoid the kernel hanging on matrices without diagonal, + // we bail out if we are past the triangle, even if it's not + // the diagonal entry. This may lead to incorrect results, + // but prevents an infinite loop. + if (is_upper ? row >= col : row <= col) { + // assert(row == col); + auto diag = unit_diag ? 
one() : vals[j]; + const auto r = (b[row * b_stride + rhs] - sum) / diag; + store(x, row * x_stride + rhs, r); + // after we encountered the diagonal, we are done + // this also skips entries outside the triangle + j = row_end; + if (is_nan(r)) { + store(x, row * x_stride + rhs, zero()); + *nan_produced = true; + } + } + } +} + + +template +__global__ void sptrsv_init_kernel(bool* const nan_produced, + IndexType* const atomic_counter) +{ + *nan_produced = false; + *atomic_counter = IndexType{}; +} + + +template +struct SptrsvebcrnSolveStruct : gko::solver::SolveStruct { + bool is_upper; + bool unit_diag; + + SptrsvebcrnSolveStruct(std::shared_ptr, + const matrix::Csr*, size_type, + bool is_upper, bool unit_diag) + : is_upper{is_upper}, unit_diag{unit_diag} + {} + + void solve(std::shared_ptr exec, + const matrix::Csr* matrix, + const matrix::Dense* b, + matrix::Dense* x) const + { + // Pre-Volta GPUs may deadlock due to missing independent thread + // scheduling. + const auto is_fallback_required = exec->get_major_version() < 7; + + const auto n = matrix->get_size()[0]; + const auto nrhs = b->get_size()[1]; + + // Initialize x to all NaNs. + dense::fill(exec, x, nan()); + + array nan_produced(exec, 1); + array atomic_counter(exec, 1); + sptrsv_init_kernel<<<1, 1>>>(nan_produced.get_data(), + atomic_counter.get_data()); + + const dim3 block_size( + is_fallback_required ? fallback_block_size : default_block_size, 1, + 1); + const dim3 grid_size(ceildiv(n * nrhs, block_size.x), 1, 1); + + if (is_fallback_required) { + if (is_upper) { + sptrsv_naive_legacy_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), n, nrhs, + unit_diag, nan_produced.get_data(), + atomic_counter.get_data()); + } else { + sptrsv_naive_legacy_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), n, nrhs, + unit_diag, nan_produced.get_data(), + atomic_counter.get_data()); + } + } else { + if (is_upper) { + sptrsv_naive_caching_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), n, nrhs, + unit_diag, nan_produced.get_data(), + atomic_counter.get_data()); + } else { + sptrsv_naive_caching_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), n, nrhs, + unit_diag, nan_produced.get_data(), + atomic_counter.get_data()); + } + } +#if GKO_VERBOSE_LEVEL >= 1 + if (exec->copy_val_to_host(nan_produced.get_const_data())) { + std::cerr << "Error: triangular solve produced NaN, either not all " + "diagonal " + "elements are nonzero, or the system is very " + "ill-conditioned. 
" + "The NaN will be replaced with a zero.\n"; + } +#endif // GKO_VERBOSE_LEVEL >= 1 + } +}; + + +template +__global__ void sptrsmebrcnm_kernel( + const IndexType* const rowptrs, const IndexType* const colidxs, + const ValueType* const vals, const ValueType* const b, size_type b_stride, + ValueType* const x, size_type x_stride, const size_type n, + const IndexType nrhs, bool* nan_produced, IndexType* atomic_counter, + IndexType m, bool unit_diag) +{ + __shared__ IndexType block_base_idx; + + if (threadIdx.x == 0) { + block_base_idx = + atomic_add(atomic_counter, IndexType{1}) * default_block_size; + } + __syncthreads(); + const auto full_gid = static_cast(threadIdx.x) + block_base_idx; + const auto rhs = (full_gid / m) % nrhs; + const auto gid = full_gid / (m * nrhs); + const auto row = is_upper ? n - 1 - gid : gid; + + if (gid >= n || rhs >= nrhs || full_gid % m != 0) { + return; + } + + // lower tri matrix: start at beginning, run forward until last entry, + // (row_end - 1) which is the diagonal entry + // upper tri matrix: start at last entry (row_end - 1), run backward + // until first entry, which is the diagonal entry + const auto row_begin = is_upper ? rowptrs[row + 1] - 1 : rowptrs[row]; + const auto row_diag = is_upper ? rowptrs[row] : rowptrs[row + 1] - 1; + const int row_step = is_upper ? -1 : 1; + + auto sum = zero(); + auto i = row_begin; + for (; i != row_diag; i += row_step) { + const auto dependency = colidxs[i]; + if (is_upper ? dependency <= row : dependency >= row) { + break; + } + + auto x_p = &x[dependency * x_stride + rhs]; + + + ValueType x = *x_p; + while (is_nan(x)) { + x = load(x_p, 0); + } + + sum += x * vals[i]; + } + + const auto diag = unit_diag ? one() : vals[i]; + const auto r = (b[row * b_stride + rhs] - sum) / diag; + x[row * x_stride + rhs] = r; + + // This check to ensure no infinte loops happen. + if (is_nan(r)) { + x[row * x_stride + rhs] = zero(); + *nan_produced = true; + } +} + + +template +struct SptrsvebcrnmSolveStruct : gko::solver::SolveStruct { + bool is_upper; + IndexType m; + bool unit_diag; + + SptrsvebcrnmSolveStruct(std::shared_ptr, + const matrix::Csr*, size_type, + bool is_upper, bool unit_diag, uint8 m) + : is_upper{is_upper}, m{m}, unit_diag{unit_diag} + {} + + void solve(std::shared_ptr exec, + const matrix::Csr* matrix, + const matrix::Dense* b, + matrix::Dense* x) const + { + // Pre-Volta GPUs may deadlock due to missing independent thread + // scheduling. + const auto is_fallback_required = exec->get_major_version() < 7; + + const auto n = matrix->get_size()[0]; + const IndexType nrhs = b->get_size()[1]; + + // Initialize x to all NaNs. + dense::fill(exec, x, nan()); + + array nan_produced(exec, 1); + array atomic_counter(exec, 1); + sptrsv_init_kernel<<<1, 1>>>(nan_produced.get_data(), + atomic_counter.get_data()); + + const dim3 block_size( + is_fallback_required ? fallback_block_size : default_block_size, 1, + 1); + const dim3 grid_size( + ceildiv(n * (is_fallback_required ? 
1 : m) * nrhs, block_size.x), 1, + 1); + + if (is_fallback_required) { + if (is_upper) { + sptrsv_naive_legacy_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), n, nrhs, + unit_diag, nan_produced.get_data(), + atomic_counter.get_data()); + } else { + sptrsv_naive_legacy_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), n, nrhs, + unit_diag, nan_produced.get_data(), + atomic_counter.get_data()); + } + } else { + if (is_upper) { + sptrsmebrcnm_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), n, nrhs, + nan_produced.get_data(), atomic_counter.get_data(), m, + unit_diag); + } else { + sptrsmebrcnm_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), n, nrhs, + nan_produced.get_data(), atomic_counter.get_data(), m, + unit_diag); + } + } +#if GKO_VERBOSE_LEVEL >= 1 + if (exec->copy_val_to_host(nan_produced.get_const_data())) { + std::cerr << "Error: triangular solve produced NaN, either not all " + "diagonal " + "elements are nonzero, or the system is very " + "ill-conditioned. " + "The NaN will be replaced with a zero.\n"; + } +#endif // GKO_VERBOSE_LEVEL >= 1 + } +}; + + +template +__global__ void sptrsvelcr_kernel( + const IndexType* const rowptrs, const IndexType* const colidxs, + const ValueType* const vals, const ValueType* const b, + const size_type b_stride, ValueType* const x, const size_type x_stride, + const IndexType* const levels, const IndexType sweep, const IndexType n, + const IndexType nrhs, bool unit_diag) +{ + const auto gid = thread::get_thread_id_flat(); + const auto row = gid / nrhs; + const auto rhs = gid % nrhs; + + if (row >= n) { + return; + } + + if (levels[row] != sweep) { + return; + } + + const auto row_start = is_upper ? rowptrs[row + 1] - 1 : rowptrs[row]; + const auto row_end = is_upper ? rowptrs[row] : rowptrs[row + 1] - 1; + const auto row_step = is_upper ? -1 : 1; + + auto sum = zero(); + IndexType i = row_start; + for (; i != row_end; i += row_step) { + const auto dependency = colidxs[i]; + if (is_upper ? dependency <= row : dependency >= row) { + break; + } + + sum += x[dependency * x_stride + rhs] * vals[i]; + } + + const auto diag = unit_diag ? one() : vals[i]; + const auto r = (b[row * b_stride + rhs] - sum) / diag; + x[row * x_stride + rhs] = r; +} + + +template +__global__ void level_generation_kernel(const IndexType* const rowptrs, + const IndexType* const colidxs, + volatile IndexType* const levels, + volatile IndexType* const height, + const IndexType n, + IndexType* const atomic_counter) +{ + __shared__ uninitialized_array level_s_array; + __shared__ IndexType block_base_idx; + + if (threadIdx.x == 0) { + block_base_idx = + atomic_add(atomic_counter, IndexType{1}) * default_block_size; + } + __syncthreads(); + const auto full_gid = static_cast(threadIdx.x) + block_base_idx; + const auto gid = full_gid; + const auto row = is_upper ? 
n - 1 - gid : gid; + + if (gid >= n) { + return; + } + + const auto self_shmem_id = full_gid / default_block_size; + const auto self_shid = full_gid % default_block_size; + + IndexType* level_s = level_s_array; + level_s[self_shid] = -1; + + __syncthreads(); + + // lower tri matrix: start at beginning, run forward until last entry, + // (row_end - 1) which is the diagonal entry + // upper tri matrix: start at last entry (row_end - 1), run backward + // until first entry, which is the diagonal entry + const auto row_begin = is_upper ? rowptrs[row + 1] - 1 : rowptrs[row]; + const auto row_end = is_upper ? rowptrs[row] : rowptrs[row + 1] - 1; + const int row_step = is_upper ? -1 : 1; + + IndexType level = -one(); + for (auto i = row_begin; i != row_end; i += row_step) { + const auto dependency = colidxs[i]; + if (is_upper ? dependency <= row : dependency >= row) { + break; + } + + auto l_p = &levels[dependency]; + + const auto dependency_gid = is_upper ? n - 1 - dependency : dependency; + const bool shmem_possible = + (dependency_gid / default_block_size) == self_shmem_id; + if (shmem_possible) { + const auto dependency_shid = dependency_gid % default_block_size; + l_p = &level_s[dependency_shid]; + } + + IndexType l = *l_p; + while (l == -one()) { + l = load(l_p, 0); + } + + level = max(l, level); + } + + store(level_s, self_shid, level + 1); + levels[row] = level + 1; + + atomic_max((IndexType*)height, level + 1); +} + + +template +__global__ void sptrsv_level_counts_kernel( + const IndexType* const levels, volatile IndexType* const level_counts, + IndexType* const lperm, const IndexType n) +{ + const auto gid = blockIdx.x * blockDim.x + threadIdx.x; + const auto row = gid; + + if (row >= n) { + return; + } + + auto level = levels[row]; + + // TODO: Make this a parallel reduction from n -> #levels + const auto i = atomic_add((IndexType*)(level_counts + level), (IndexType)1); + + lperm[row] = i; +} + + +template +__global__ void sptrsv_lperm_finalize_kernel( + const IndexType* const levels, const IndexType* const level_counts, + IndexType* const lperm, const IndexType n) +{ + const auto gid = blockIdx.x * blockDim.x + threadIdx.x; + const auto row = gid; + + if (row >= n) { + return; + } + + lperm[row] += level_counts[levels[row]]; +} + + +template +struct SptrsvlrSolveStruct : solver::SolveStruct { + bool is_upper; + array levels; + IndexType height; + bool unit_diag; + + SptrsvlrSolveStruct(std::shared_ptr exec, + const matrix::Csr* matrix, + size_type, bool is_upper, bool unit_diag) + : is_upper{is_upper}, unit_diag{unit_diag} + { + const IndexType n = matrix->get_size()[0]; + cudaMemset(levels.get_data(), 0xFF, n * sizeof(IndexType)); + + array changed(exec, 1); + cudaMemset(changed.get_data(), 1, sizeof(uint8)); + + array height_d(exec, 1); + cudaMemset(height_d.get_data(), 0, sizeof(IndexType)); + + array atomic_counter(exec, 1); + cudaMemset(atomic_counter.get_data(), 0, sizeof(IndexType)); + + const auto block_size = default_block_size; + const auto block_count = (n + block_size - 1) / block_size; + + if (is_upper) { + level_generation_kernel + <<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + levels.get_data(), height_d.get_data(), n, + atomic_counter.get_data()); + } else { + level_generation_kernel + <<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + levels.get_data(), height_d.get_data(), n, + atomic_counter.get_data()); + } + + height = exec->copy_val_to_host(height_d.get_const_data()) + 1; + } + + void solve(std::shared_ptr, + const 
matrix::Csr* matrix, + const matrix::Dense* b, + matrix::Dense* x) const + { + const IndexType n = matrix->get_size()[0]; + const IndexType nrhs = b->get_size()[1]; + + for (IndexType done_for = 0; done_for < height; ++done_for) { + const dim3 block_size(default_block_size, 1, 1); + const dim3 grid_size(ceildiv(n * nrhs, block_size.x), 1, 1); + + if (is_upper) { + sptrsvelcr_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), + levels.get_const_data(), done_for, n, nrhs, unit_diag); + } else { + sptrsvelcr_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), + levels.get_const_data(), done_for, n, nrhs, unit_diag); + } + } + } +}; + + +// Values other than 32 don't work. +constexpr int32 warp_inverse_size = 32; + + +template +__global__ void sptrsvebcrwi_generate_prep_kernel( + const IndexType* const rowptrs, const IndexType* const colidxs, + IndexType* const row_skip_counts, const size_type n) +{ + const auto gid = thread::get_thread_id_flat(); + const auto row = gid; + + if (row >= n) { + return; + } + + const auto row_start = is_upper ? rowptrs[row + 1] - 1 : rowptrs[row]; + const auto row_end = + is_upper ? rowptrs[row] - 1 : rowptrs[row + 1]; // Includes diagonal + const auto row_step = is_upper ? -1 : 1; + const auto local_inv_count = + is_upper ? warp_inverse_size - (row % warp_inverse_size) - 1 + : row % warp_inverse_size; + + // TODO: Evaluate end-to-start iteration with early break optimization + // Note: This optimization is only sensible when a hint + // "does this use compact storage" is set to false. + // FIXME: Document a requirement of sorted indices, then + // break on first hit in the diagonal box, calculating + // the number of not-visited entries. That is more + // efficient for compact storage schemes. + IndexType row_skip_count = 0; + for (IndexType i = row_start; i != row_end; i += row_step) { + const auto dep = colidxs[i]; + + if (is_upper) { + // Includes diagonal, entries from the other factor evaluate to + // negative + if (dep - row <= local_inv_count) { + ++row_skip_count; + } + } else { + if (row - dep <= local_inv_count) { + ++row_skip_count; + } + } + } + + row_skip_counts[row] = row_skip_count; +} + + +template +__global__ void sptrsvebcrwi_generate_inv_kernel( + const IndexType* const rowptrs, const IndexType* const colidxs, + const ValueType* const vals, IndexType* const row_skip_counts, + ValueType* const band_inv, // zero initialized + uint32* const masks, const size_type n, const bool unit_diag) +{ + const auto gid = thread::get_thread_id_flat(); + const auto inv_block = gid / warp_inverse_size; + const auto rhs = gid % warp_inverse_size; + + const auto local_start_row = is_upper ? warp_inverse_size - 1 : 0; + const auto local_end_row = is_upper ? -1 : warp_inverse_size; + const auto local_step_row = is_upper ? -1 : 1; + +#pragma unroll + for (IndexType _i = local_start_row; _i != local_end_row; + _i += local_step_row) { + const auto row = (gid / warp_inverse_size) * warp_inverse_size + _i; + + // Skips entries beyond matrix size, in the last/first block + if (row >= n) { + continue; + } + + // Go though all block-internal dependencies of the row + + const auto row_start = is_upper + ? 
rowptrs[row] + row_skip_counts[row] - 1 + : rowptrs[row + 1] - row_skip_counts[row]; + const auto row_end = is_upper ? rowptrs[row] : rowptrs[row + 1] - 1; + const auto row_step = is_upper ? -1 : 1; + + auto sum = zero(); + IndexType i = row_start; + for (; i != row_end; i += row_step) { + const auto dep = colidxs[i]; + + // To skip out-of-triangle entries for compressed storage + if (dep == row) { + break; + } + + sum += + band_inv[inv_block * (warp_inverse_size * warp_inverse_size) + + dep % warp_inverse_size + rhs * warp_inverse_size] * + vals[i]; + } + + const auto diag = unit_diag ? one() : vals[i]; + const auto r = + ((rhs == _i ? one() : zero()) - sum) / diag; + + band_inv[inv_block * (warp_inverse_size * warp_inverse_size) + + row % warp_inverse_size + rhs * warp_inverse_size] = r; + } + + + if (gid >= n) { + return; + } + + const auto local_row = rhs; + const auto row = gid; + + const auto activemask = __activemask(); + + // Discover connected components. + + // Abuse masks as intermediate storage for component descriptors + store(masks, row, local_row); + __syncwarp(activemask); + + for (IndexType _i = 0; _i < warp_inverse_size; ++_i) { + uint32 current_min = local_row; + + const auto h_start = is_upper ? local_row + 1 : 0; + const auto h_end = is_upper ? warp_inverse_size : local_row; + const auto v_start = is_upper ? 0 : local_row + 1; + const auto v_end = is_upper ? local_row : warp_inverse_size; + + for (IndexType i = h_start; i < h_end; ++i) { + if (band_inv[inv_block * (warp_inverse_size * warp_inverse_size) + + local_row + i * warp_inverse_size] != 0.0) { + const auto load1 = load(masks, row - local_row + i); + if (current_min > load1) { + current_min = load1; + } + } + } + for (IndexType i = v_start; i < v_end; ++i) { + if (band_inv[inv_block * (warp_inverse_size * warp_inverse_size) + + i + local_row * warp_inverse_size] != 0.0) { + const auto load2 = load(masks, row - local_row + i); + if (current_min > load2) { + current_min = load2; + } + } + } + + // That was one round of fixed-point min iteration. + store(masks, row, current_min); + __syncwarp(activemask); + } + + // Now translate that into masks. + uint32 mask = 0b0; + const auto component = load(masks, row); + for (IndexType i = 0; i < warp_inverse_size; ++i) { + if (load(masks, row - local_row + i) == component) { + mask |= (0b1 << (is_upper ? warp_inverse_size - i - 1 : i)); + } + } + + __syncwarp(activemask); + + masks[row] = mask; +} + + +template +__global__ void sptrsvebcrwi_kernel( + const IndexType* const rowptrs, const IndexType* const colidxs, + const IndexType* const row_skip_counts, const ValueType* const vals, + const ValueType* const b, const size_type b_stride, ValueType* const x, + const size_type x_stride, const ValueType* const band_inv, + const uint32* const masks, const size_type n, const size_type nrhs, + bool* nan_produced) +{ + const auto gid = thread::get_thread_id_flat(); + const auto row = + is_upper ? ((IndexType)n + blockDim.x - 1) / blockDim.x * blockDim.x - + gid - 1 + : gid; + const auto rhs = blockDim.y * blockIdx.y + threadIdx.y; + + if (row >= n) { + return; + } + if (rhs >= nrhs) { + return; + } + + const int self_shid = row % default_block_size; + const auto skip_count = row_skip_counts[row]; + + const auto row_start = is_upper ? rowptrs[row + 1] - 1 : rowptrs[row]; + const auto row_end = + is_upper ? rowptrs[row] + skip_count - 1 + : rowptrs[row + 1] - skip_count; // no -1, as skip_count >= 1 + const auto row_step = is_upper ? 
-1 : 1; + + ValueType sum = 0.0; + for (IndexType i = row_start; i != row_end; i += row_step) { + const auto dependency = colidxs[i]; + auto x_p = &x[dependency * x_stride + rhs]; + + ValueType x = *x_p; + while (is_nan(x)) { + x = load(x_p, 0); + } + + sum += x * vals[i]; + } + + __shared__ uninitialized_array b_s_array; + ValueType* b_s = b_s_array; + store(b_s, self_shid, b[row * b_stride + rhs] - sum); + + // Now sync all necessary threads before going into the mult. + // Inactive threads can not have a sync bit set. + const auto syncmask = masks[row]; + __syncwarp(syncmask); + + const auto band_inv_block = + band_inv + + (warp_inverse_size * warp_inverse_size) * (row / warp_inverse_size) + + row % warp_inverse_size; + const auto local_offset = row % warp_inverse_size; + + ValueType inv_sum = zero(); + for (int i = 0; i < warp_inverse_size; ++i) { + inv_sum += band_inv_block[i * warp_inverse_size] * + load(b_s, self_shid - local_offset + i); + } + + const auto r = inv_sum; + x[row * x_stride + rhs] = r; + + // This check to ensure no infinte loops happen. + if (is_nan(r)) { + x[row * x_stride + rhs] = zero(); + *nan_produced = true; + } +} + + +template +struct SptrsvebrwiSolveStruct : gko::solver::SolveStruct { + bool is_upper; + bool unit_diag; + array band_inv; + array row_skip_counts; + array masks; + + SptrsvebrwiSolveStruct(std::shared_ptr exec, + const matrix::Csr* matrix, + size_type, bool is_upper, bool unit_diag) + : is_upper{is_upper}, + unit_diag{unit_diag}, + band_inv{exec, static_cast(warp_inverse_size) * + static_cast(warp_inverse_size) * + ceildiv(matrix->get_size()[0], + static_cast(warp_inverse_size))}, + row_skip_counts{exec, matrix->get_size()[0]}, + masks{exec, matrix->get_size()[0]} + { + const auto n = matrix->get_size()[0]; + const auto inv_blocks_count = ceildiv(n, warp_inverse_size); + + cudaMemset(band_inv.get_data(), 0, + warp_inverse_size * warp_inverse_size * inv_blocks_count * + sizeof(ValueType)); + cudaMemset(masks.get_data(), 0, n * sizeof(uint32)); + + const dim3 block_size(default_block_size, 1, 1); + const dim3 grid_size(ceildiv(n, block_size.x), 1, 1); + + if (is_upper) { + sptrsvebcrwi_generate_prep_kernel + <<>>(matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), + row_skip_counts.get_data(), n); + sptrsvebcrwi_generate_inv_kernel< + decltype(as_cuda_type(ValueType{})), IndexType, true> + <<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + row_skip_counts.get_data(), + as_cuda_type(band_inv.get_data()), masks.get_data(), n, + unit_diag); + } else { + sptrsvebcrwi_generate_prep_kernel + <<>>(matrix->get_const_row_ptrs(), + matrix->get_const_col_idxs(), + row_skip_counts.get_data(), n); + sptrsvebcrwi_generate_inv_kernel< + decltype(as_cuda_type(ValueType{})), IndexType, false> + <<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + row_skip_counts.get_data(), + as_cuda_type(band_inv.get_data()), masks.get_data(), n, + unit_diag); + } + } + + void solve(std::shared_ptr exec, + const matrix::Csr* matrix, + const matrix::Dense* b, + matrix::Dense* x) const + { + const auto n = matrix->get_size()[0]; + const auto nrhs = b->get_size()[1]; + + // TODO: Optimize for multiple rhs, by calling to a device gemm. 
+ + dense::fill(exec, x, nan()); + + array nan_produced(exec, {false}); + + const dim3 block_size(default_block_size, 1, 1); + const dim3 grid_size(ceildiv(n, block_size.x), nrhs, 1); + if (is_upper) { + sptrsvebcrwi_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + row_skip_counts.get_const_data(), + as_cuda_type(matrix->get_const_values()), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), + as_cuda_type(band_inv.get_const_data()), masks.get_const_data(), + n, nrhs, nan_produced.get_data()); + } else { + sptrsvebcrwi_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + row_skip_counts.get_const_data(), + as_cuda_type(matrix->get_const_values()), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), + as_cuda_type(band_inv.get_const_data()), masks.get_const_data(), + n, nrhs, nan_produced.get_data()); + } + } +}; + + +template +__global__ void sptrsvebcrwvs_write_vwarp_ids( + const IndexType* const vwarp_offsets, IndexType* const vwarp_ids, + const IndexType num_vwarps, const IndexType n) +{ + const auto gid = thread::get_thread_id_flat(); + + if (gid >= num_vwarps) { + return; + } + + const auto vwarp_start = vwarp_offsets[gid]; + const auto vwarp_end = vwarp_offsets[gid + 1]; + + for (IndexType i = vwarp_start; i < vwarp_end; ++i) { + vwarp_ids[i] = is_upper ? n - gid - 1 : gid; + } +} + +// This is "heavily inspired" by cppreference. +template +__device__ const T* lower_bound(const T* first, const T* const last, + const T value) +{ + const T* p; + auto count = last - first; + auto step = count; + while (count > 0) { + p = first; + step = count / 2; + p += step; + if (*p < value) { + first = ++p; + count -= step + 1; + } else { + count = step; + } + } + return first; +} + +template +__global__ void sptrsvebcrwvs_generate_assigned_sizes( + const IndexType* const rowptrs, const IndexType* const colidxs, + const double avg_threads_per_row, IndexType* const assigned_sizes, + IndexType* const entry_counts, const IndexType n, const IndexType nnz) +{ + const IndexType gid = thread::get_thread_id_flat(); + const auto row = is_upper ? n - gid - 1 : gid; + const auto row_write_location = gid; + const int32 thread = threadIdx.x; + + if (gid >= n) { + return; + } + + const auto diag_pos = + lower_bound(colidxs + rowptrs[row], colidxs + rowptrs[row + 1], row) - + (colidxs + rowptrs[row]); + const auto valid_entry_count = + is_upper ? 
rowptrs[row + 1] - rowptrs[row] - diag_pos : diag_pos + 1; + entry_counts[row] = valid_entry_count; + + const double avg_nnz = (double)nnz / n; + const double perfect_size = + (valid_entry_count)*avg_threads_per_row / avg_nnz; + const IndexType assigned_size = std::max( + std::min((IndexType)__double2int_rn(perfect_size), (IndexType)32), + (IndexType)1); + + volatile __shared__ int32 block_size_assigner[1]; + volatile __shared__ int32 block_size_assigner_lock[1]; + + *block_size_assigner = 0; + *block_size_assigner_lock = -1; + + __syncthreads(); + + while (*block_size_assigner_lock != thread - 1) { + } + + const auto prev_offset = *block_size_assigner; + *block_size_assigner += assigned_size; + + int32 shrinked_size = 0; + if ((prev_offset + assigned_size) / 32 > prev_offset / 32) { + shrinked_size = ((prev_offset + assigned_size) / 32) * 32 - prev_offset; + *block_size_assigner += shrinked_size - assigned_size; + } + + __threadfence(); + *block_size_assigner_lock = thread; + + assigned_sizes[row_write_location] = + shrinked_size == 0 ? assigned_size : shrinked_size; + + // This part to ensure each assigner block starts on 32*k, meaning the cuts + // are well-placed. + if (thread == default_block_size - 1) { + if ((prev_offset + + (shrinked_size == 0 ? assigned_size : shrinked_size)) % + 32 != + 0) { + assigned_sizes[row_write_location] = + (prev_offset / 32 + 1) * 32 - prev_offset; + } + } +} + + +template +__global__ void sptrsvebcrwvs_kernel( + const IndexType* const rowptrs, const IndexType* const colidxs, + const ValueType* const vals, const IndexType* const vwarp_ids, + const IndexType* const vwarp_offsets, const IndexType* const entry_counts, + const ValueType* const b, const size_type b_stride, ValueType* const x, + const size_type x_stride, bool* const nan_produced, + const IndexType num_vthreads, const IndexType n, const IndexType nrhs, + const bool unit_diag) +{ + const auto gid = blockIdx.x * blockDim.x + threadIdx.x; + const auto thread = threadIdx.x; + const auto rhs = blockIdx.y * blockDim.y + threadIdx.y; + + if (gid >= num_vthreads) { + return; + } + if (rhs >= nrhs) { + return; + } + + const auto vwarp = vwarp_ids[gid]; + const auto vwarp_start = vwarp_offsets[is_upper ? n - vwarp - 1 : vwarp]; + const auto vwarp_end = vwarp_offsets[is_upper ? n - vwarp : vwarp + 1]; + const auto vwarp_size = vwarp_end - vwarp_start; + const auto row = vwarp; + const IndexType vthread = gid - vwarp_start; + + if (row >= n) { + return; + } + + const auto row_begin = is_upper ? rowptrs[row + 1] - 1 : rowptrs[row]; + const auto row_end = is_upper ? rowptrs[row] : rowptrs[row + 1] - 1; + const IndexType row_step = is_upper ? -1 : 1; + + const auto valid_entry_count = entry_counts[row]; + const auto start_offset = (valid_entry_count - 1) % vwarp_size; + + auto sum = zero(); + // i is adjusted for vthread 0 to hit the diagonal + IndexType i = + unit_diag + ? row_begin + row_step * vthread + : row_begin + + ((row_step * vthread + row_step * start_offset) % vwarp_size); + for (; (is_upper && i > row_end) || (!is_upper && i < row_end); + i += row_step * vwarp_size) { + const auto dependency = colidxs[i]; + + if (is_upper ? 
dependency <= row : dependency >= row) { + break; + } + + volatile auto x_p = &x[x_stride * dependency + rhs]; + + auto l = *x_p; + while (is_nan(l)) { + l = load(x_p, 0); + } + + + sum += l * vals[i]; + } + + uint32 syncmask = ((1 << vwarp_size) - 1) << (vwarp_start & 31); + + ValueType total = sum; + for (int offset = 1; offset < vwarp_size; ++offset) { + auto a = real(sum); + const auto received_a = __shfl_down_sync(syncmask, a, offset); + const auto should_add = (syncmask >> ((thread & 31) + offset)) & 1 == 1; + total += should_add * received_a; + if (gko::is_complex()) { + auto b = imag(sum); + const auto received_b = __shfl_down_sync(syncmask, b, offset); + auto ptotal = + (thrust::complex>*)&total; + *ptotal += should_add * received_b * + (thrust::complex>) + unit_root(4); + } + } + + if (vthread == 0) { + const auto diag = unit_diag ? one() : vals[i]; + const auto r = (b[row * b_stride + rhs] - total) / diag; + x[row * x_stride + rhs] = r; + + // This check to ensure no infinte loops happen. + if (is_nan(r)) { + x[row * x_stride + rhs] = zero(); + *nan_produced = true; + } + } +} + + +template +struct SptrsvebrwvSolveStruct : gko::solver::SolveStruct { + bool is_upper; + bool unit_diag; + IndexType vthread_count; + array vwarp_ids; + array vwarp_offsets; + array entry_counts; + + SptrsvebrwvSolveStruct(std::shared_ptr exec, + const matrix::Csr* matrix, + size_type, bool is_upper, bool unit_diag) + : is_upper{is_upper}, + unit_diag{unit_diag}, + vwarp_offsets{exec, matrix->get_size()[0] + 1}, + entry_counts{exec, matrix->get_size()[0]}, + vwarp_ids{exec} + { + const auto desired_avg_threads_per_row = 1.0; + + const IndexType n = matrix->get_size()[0]; + const IndexType nnz = matrix->get_num_stored_elements(); + + array assigned_sizes(exec, n); + + const auto block_size = default_block_size; + const auto block_count = (n + block_size - 1) / block_size; + + if (is_upper) { + sptrsvebcrwvs_generate_assigned_sizes + <<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + desired_avg_threads_per_row, assigned_sizes.get_data(), + entry_counts.get_data(), n, nnz); + } else { + sptrsvebcrwvs_generate_assigned_sizes + <<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + desired_avg_threads_per_row, assigned_sizes.get_data(), + entry_counts.get_data(), n, nnz); + } + + cudaMemcpy(vwarp_offsets.get_data(), assigned_sizes.get_const_data(), + n * sizeof(IndexType), cudaMemcpyDeviceToDevice); + components::prefix_sum(exec, vwarp_offsets.get_data(), n + 1); + + cudaMemcpy(&vthread_count, vwarp_offsets.get_const_data() + n, + sizeof(IndexType), cudaMemcpyDeviceToHost); + + vwarp_ids.resize_and_reset(vthread_count); + const auto block_size_vwarped = default_block_size; + const auto block_count_vwarped = + (n + block_size_vwarped - 1) / block_size_vwarped; + if (is_upper) { + sptrsvebcrwvs_write_vwarp_ids + <<>>( + vwarp_offsets.get_const_data(), vwarp_ids.get_data(), n, n); + } else { + sptrsvebcrwvs_write_vwarp_ids + <<>>( + vwarp_offsets.get_const_data(), vwarp_ids.get_data(), n, n); + } + } + + void solve(std::shared_ptr exec, + const matrix::Csr* matrix, + const matrix::Dense* b, + matrix::Dense* x) const + { + const IndexType n = matrix->get_size()[0]; + const IndexType nrhs = b->get_size()[1]; + + // TODO: Optimize for multiple rhs. 
+ + dense::fill(exec, x, nan()); + + array nan_produced(exec, {false}); + + const dim3 block_size(default_block_size, 1024 / default_block_size, 1); + const dim3 grid_size(ceildiv(vthread_count, block_size.x), + ceildiv(nrhs, block_size.y), 1); + + if (is_upper) { + sptrsvebcrwvs_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + vwarp_ids.get_const_data(), vwarp_offsets.get_const_data(), + entry_counts.get_const_data(), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), + nan_produced.get_data(), vthread_count, n, nrhs, unit_diag); + } else { + sptrsvebcrwvs_kernel<<>>( + matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), + as_cuda_type(matrix->get_const_values()), + vwarp_ids.get_const_data(), vwarp_offsets.get_const_data(), + entry_counts.get_const_data(), + as_cuda_type(b->get_const_values()), b->get_stride(), + as_cuda_type(x->get_values()), x->get_stride(), + nan_produced.get_data(), vthread_count, n, nrhs, unit_diag); + } + +#if GKO_VERBOSE_LEVEL >= 1 + if (exec->copy_val_to_host(nan_produced.get_const_data())) { + std::cerr << "Error: triangular solve produced NaN, either not all " + "diagonal " + "elements are nonzero, or the system is very " + "ill-conditioned. " + "The NaN will be replaced with a zero.\n"; + } +#endif // GKO_VERBOSE_LEVEL >= 1 + } +}; + + +template +struct BlockedSolveStruct : solver::SolveStruct { + struct pos_size_depth { + std::pair pos; + std::pair size; + IndexType depth; + + pos_size_depth left_child(IndexType max_depth) const + { + if (depth == max_depth - 1) { // Check if triangle + return pos_size_depth{ + std::make_pair(pos.first - size.second, pos.second), + std::make_pair(size.second, size.second), max_depth}; + } else { + return pos_size_depth{ + std::make_pair(pos.first - size.first / 2, pos.second), + std::make_pair(size.first / 2, + size.second - size.first / 2), + depth + 1}; + } + } + + pos_size_depth right_child(IndexType max_depth) const + { + if (depth == max_depth - 1) { + return pos_size_depth{ + std::make_pair(pos.first, pos.second + size.second), + std::make_pair(size.first, size.first), max_depth}; + } else { + return pos_size_depth{ + std::make_pair(pos.first + size.first / 2, + pos.second + size.second), + std::make_pair( + ceildiv(size.first, 2), + (pos.first + size.first - (pos.second + size.second)) / + 2), + depth + 1}; + } + } + }; + + std::vector> solvers; + std::vector>> blocks; + std::vector block_coords; + + void solve(std::shared_ptr exec, + const matrix::Csr* matrix, + const matrix::Dense* b, + matrix::Dense* x) const + { + auto mb = matrix::Dense::create(exec); + mb->copy_from(b); + + const auto block_count = blocks.size(); + for (IndexType i = 0; i < block_count; ++i) { + if (i % 2 == 0) { + const auto bv = + mb->create_submatrix(span{block_coords[i].pos.second, + block_coords[i].pos.second + + block_coords[i].size.second}, + span{0, 1}); + auto xv = + x->create_submatrix(span{block_coords[i].pos.first, + block_coords[i].pos.first + + block_coords[i].size.first}, + span{0, 1}); + + solvers[i / 2]->solve(exec, blocks[i].get(), bv.get(), xv.get(), + bv.get(), xv.get()); + } else { + const auto xv = + x->create_submatrix(span{block_coords[i].pos.second, + block_coords[i].pos.second + + block_coords[i].size.second}, + span{0, 1}); + auto bv = + mb->create_submatrix(span{block_coords[i].pos.first, + block_coords[i].pos.first + + block_coords[i].size.first}, + span{0, 1}); + auto neg_one = + 
gko::initialize>({-1}, exec); + auto one = + gko::initialize>({1}, exec); + blocks[i]->apply(neg_one.get(), xv.get(), one.get(), bv.get()); + } + } + } + + + BlockedSolveStruct( + std::shared_ptr exec, + const matrix::Csr* matrix, + const gko::size_type num_rhs, bool is_upper, bool unit_diag, + std::shared_ptr< + std::vector>> + solver_ids) + { + const auto host_exec = exec->get_master(); + const auto n = matrix->get_size()[0]; + const auto sptrsv_count = solver_ids->size(); + const auto block_count = 2 * sptrsv_count - 1; + const auto depth = get_significant_bit(sptrsv_count); + + // Generate the block sizes and positions. + array blocks(host_exec, block_count); + pos_size_depth* blocksp = blocks.get_data(); + blocksp[0] = pos_size_depth{std::make_pair(n / 2, 0), + std::make_pair(ceildiv(n, 2), n / 2), 0}; + IndexType write = 1; + for (IndexType read = 0; write < block_count; ++read) { + const auto cur = blocksp[read]; + blocksp[write++] = cur.left_child(depth); + blocksp[write++] = cur.right_child(depth); + } + + // Generate a permutation to execution order + array perm(host_exec, block_count); + IndexType* permp = perm.get_data(); + for (IndexType i = 0; i <= depth; ++i) { + const auto step = 2 << i; + const auto start = (1 << i) - 1; + const auto add = sptrsv_count / (1 << i) - 1; + + for (IndexType j = start; j < block_count; j += step) { + permp[j] = (j - start) / step + add; + } + } + + // Apply the perm + // For upper_trs, we also need to reflect the cuts + for (IndexType i = 0; i < block_count; ++i) { + auto block = blocksp[permp[i]]; + + if (is_upper) { + std::swap(block.pos.first, block.pos.second); + std::swap(block.size.first, block.size.second); + } + + block_coords.push_back(block); + this->blocks.push_back(std::move(matrix->create_submatrix( + span{block.pos.first, block.pos.first + block.size.first}, + span{block.pos.second, block.pos.second + block.size.second}))); + this->blocks[i]->set_strategy( + std::make_shared< + typename matrix::Csr::automatical>( + exec)); + } + + if (is_upper) { + for (auto i = 0; i < block_count / 2; ++i) { + std::swap(block_coords[i], block_coords[block_count - i - 1]); + std::swap(this->blocks[i], this->blocks[block_count - i - 1]); + } + } + + // Finally create the appropriate solvers + for (IndexType i = 0; i < sptrsv_count; ++i) { + this->solvers.push_back(std::make_shared()); + solver::SolveStruct::generate( + exec, this->blocks[2 * i].get(), this->solvers[i], num_rhs, + solver_ids.get()->at(i).get(), is_upper, unit_diag); + } + } +}; + + +} // namespace +} // namespace cuda +} // namespace kernels + + +template +void gko::solver::SolveStruct::generate( + std::shared_ptr exec, + const matrix::Csr* matrix, + std::shared_ptr& solve_struct, + const gko::size_type num_rhs, + const gko::solver::trisolve_strategy* strategy, bool is_upper, + bool unit_diag) +{ + if (matrix->get_size()[0] == 0) { + return; + } + if (strategy->type == gko::solver::trisolve_type::sparselib) { + if (gko::kernels::cuda::cusparse::is_supported::value) { + solve_struct = std::make_shared< + gko::kernels::cuda::CudaSolveStruct>( + exec, matrix, num_rhs, is_upper, unit_diag); + } else { + GKO_NOT_SUPPORTED(solve_struct); + } + } else if (strategy->type == gko::solver::trisolve_type::level) { + solve_struct = std::make_shared< + gko::kernels::cuda::SptrsvlrSolveStruct>( + exec, matrix, num_rhs, is_upper, unit_diag); + } else if (strategy->type == gko::solver::trisolve_type::winv) { + solve_struct = std::make_shared< + gko::kernels::cuda::SptrsvebrwiSolveStruct>( + exec, 
matrix, num_rhs, is_upper, unit_diag); + } else if (strategy->type == gko::solver::trisolve_type::wvar) { + solve_struct = std::make_shared< + gko::kernels::cuda::SptrsvebrwvSolveStruct>( + exec, matrix, num_rhs, is_upper, unit_diag); + } else if (strategy->type == gko::solver::trisolve_type::thinned) { + solve_struct = std::make_shared< + gko::kernels::cuda::SptrsvebcrnmSolveStruct>( + exec, matrix, num_rhs, is_upper, unit_diag, strategy->thinned_m); + } else if (strategy->type == gko::solver::trisolve_type::block) { + solve_struct = std::make_shared< + gko::kernels::cuda::BlockedSolveStruct>( + exec, matrix, num_rhs, is_upper, unit_diag, strategy->block_inner); + } else if (strategy->type == gko::solver::trisolve_type::syncfree) { + solve_struct = std::make_shared< + gko::kernels::cuda::SptrsvebcrnSolveStruct>( + exec, matrix, num_rhs, is_upper, unit_diag); + } +} + +#define GKO_DECLARE_SOLVER_SOLVESTRUCT_GENERATE(_vtype, _itype) \ + void gko::solver::SolveStruct::generate( \ + std::shared_ptr exec, \ + const matrix::Csr<_vtype, _itype>* matrix, \ + std::shared_ptr& solve_struct, \ + const gko::size_type num_rhs, \ + const gko::solver::trisolve_strategy* strategy, bool is_upper, \ + bool unit_diag) + +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( + GKO_DECLARE_SOLVER_SOLVESTRUCT_GENERATE); + + +template +void gko::solver::SolveStruct::solve( + std::shared_ptr exec, + const matrix::Csr* matrix, + matrix::Dense* trans_b, matrix::Dense* trans_x, + const matrix::Dense* b, matrix::Dense* x) const +{ + if (matrix->get_size()[0] == 0 || b->get_size()[1] == 0) { + return; + } + if (auto sptrsvebcrn_struct = + dynamic_cast*>(this)) { + sptrsvebcrn_struct->solve(exec, matrix, b, x); + } else if (auto sptrsvlr_struct = + dynamic_cast*>(this)) { + sptrsvlr_struct->solve(exec, matrix, b, x); + } else if (auto sptrsvebcrnm_struct = dynamic_cast< + const gko::kernels::cuda::SptrsvebcrnmSolveStruct< + ValueType, IndexType>*>(this)) { + sptrsvebcrnm_struct->solve(exec, matrix, b, x); + } else if (auto sptrsvebrwi_struct = dynamic_cast< + const gko::kernels::cuda::SptrsvebrwiSolveStruct< + ValueType, IndexType>*>(this)) { + sptrsvebrwi_struct->solve(exec, matrix, b, x); + } else if (auto sptrsvb_struct = + dynamic_cast*>(this)) { + sptrsvb_struct->solve(exec, matrix, b, x); + } else if (auto sptrsvwv_struct = dynamic_cast< + const gko::kernels::cuda::SptrsvebrwvSolveStruct< + ValueType, IndexType>*>(this)) { + sptrsvwv_struct->solve(exec, matrix, b, x); + } else if (gko::kernels::cuda::cusparse::is_supported< + ValueType, + IndexType>::value) { // Must always be last check + if (auto cuda_solve_struct = + dynamic_cast*>(this)) { + cuda_solve_struct->solve(matrix, b, x, trans_b, trans_x); + } else { + GKO_NOT_SUPPORTED(this); + } + } else { + GKO_NOT_IMPLEMENTED; + } +} + +#define GKO_DECLARE_SOLVER_SOLVESTRUCT_SOLVE(_vtype, _itype) \ + void gko::solver::SolveStruct::solve( \ + std::shared_ptr exec, \ + const matrix::Csr<_vtype, _itype>* matrix, \ + matrix::Dense<_vtype>* trans_b, matrix::Dense<_vtype>* trans_x, \ + const matrix::Dense<_vtype>* b, matrix::Dense<_vtype>* x) const + +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( + GKO_DECLARE_SOLVER_SOLVESTRUCT_SOLVE); + + +} // namespace gko diff --git a/cuda/solver/common_trs_kernels.cuh b/cuda/solver/common_trs_kernels.cuh index f61d70f1a76..7f5e72ce75f 100644 --- a/cuda/solver/common_trs_kernels.cuh +++ b/cuda/solver/common_trs_kernels.cuh @@ -34,42 +34,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
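// ---------------------------------------------------------------------------
// Hedged usage sketch (not part of the patch): how the strategy dispatch in
// gko::solver::SolveStruct::generate above can be driven from user code.
// Only trisolve_type, trisolve_strategy, its constructors/static helpers and
// the with_strategy(...) factory parameter are taken from this patch; the
// executor setup, the matrix argument, and the concrete parameter values
// (thinned_m = 4, the inner strategy list) are illustrative assumptions.
// ---------------------------------------------------------------------------
#include <memory>
#include <vector>

#include <ginkgo/ginkgo.hpp>

std::unique_ptr<gko::solver::LowerTrs<double, int>> make_thinned_solver(
    std::shared_ptr<const gko::CudaExecutor> exec,
    std::shared_ptr<gko::matrix::Csr<double, int>> mtx)
{
    // trisolve_type::thinned forwards thinned_m to SptrsvebcrnmSolveStruct
    // in the dispatch above.
    auto thinned = std::make_shared<gko::solver::trisolve_strategy>(
        gko::solver::trisolve_type::thinned, gko::uint8{4});
    return gko::solver::LowerTrs<double, int>::build()
        .with_strategy(thinned)
        .on(exec)
        ->generate(mtx);
}

std::shared_ptr<gko::solver::trisolve_strategy> make_blocked_strategy()
{
    // trisolve_type::block forwards this list to BlockedSolveStruct, which
    // generates one inner solver per diagonal block from its entries.
    auto inner = std::make_shared<
        std::vector<std::shared_ptr<gko::solver::trisolve_strategy>>>();
    inner->push_back(gko::solver::trisolve_strategy::syncfree());
    inner->push_back(gko::solver::trisolve_strategy::sparselib());
    return std::make_shared<gko::solver::trisolve_strategy>(
        gko::solver::trisolve_type::block, inner);
}
// ---------------------------------------------------------------------------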
#define GKO_CUDA_SOLVER_COMMON_TRS_KERNELS_CUH_ -#include -#include -#include - - #include -#include - -#include -#include -#include - -#include "core/matrix/dense_kernels.hpp" -#include "core/synthesizer/implementation_selection.hpp" -#include "cuda/base/cusparse_bindings.hpp" -#include "cuda/base/math.hpp" -#include "cuda/base/pointer_mode_guard.hpp" -#include "cuda/base/types.hpp" -#include "cuda/components/atomic.cuh" -#include "cuda/components/thread_ids.cuh" -#include "cuda/components/uninitialized_array.hpp" -#include "cuda/components/volatile.cuh" +#include namespace gko { -namespace solver { - - -struct SolveStruct { - virtual ~SolveStruct() = default; -}; - - -} // namespace solver namespace kernels { @@ -77,508 +48,19 @@ namespace cuda { namespace { -#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11031)) - - -template -struct CudaSolveStruct : gko::solver::SolveStruct { - cusparseHandle_t handle; - cusparseSpSMDescr_t spsm_descr; - cusparseSpMatDescr_t descr_a; - size_type num_rhs; - - // Implicit parameter in spsm_solve, therefore stored here. - array work; - - CudaSolveStruct(std::shared_ptr exec, - const matrix::Csr* matrix, - size_type num_rhs, bool is_upper, bool unit_diag) - : handle{exec->get_cusparse_handle()}, - spsm_descr{}, - descr_a{}, - num_rhs{num_rhs}, - work{exec} - { - if (num_rhs == 0) { - return; - } - cusparse::pointer_mode_guard pm_guard(handle); - spsm_descr = cusparse::create_spsm_descr(); - descr_a = cusparse::create_csr( - matrix->get_size()[0], matrix->get_size()[1], - matrix->get_num_stored_elements(), - const_cast(matrix->get_const_row_ptrs()), - const_cast(matrix->get_const_col_idxs()), - const_cast(matrix->get_const_values())); - cusparse::set_attribute( - descr_a, CUSPARSE_SPMAT_FILL_MODE, - is_upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER); - cusparse::set_attribute( - descr_a, CUSPARSE_SPMAT_DIAG_TYPE, - unit_diag ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT); - - const auto rows = matrix->get_size()[0]; - // workaround suggested by NVIDIA engineers: for some reason - // cusparse needs non-nullptr input vectors even for analysis - auto descr_b = cusparse::create_dnmat( - dim<2>{matrix->get_size()[0], num_rhs}, matrix->get_size()[1], - reinterpret_cast(0xDEAD)); - auto descr_c = cusparse::create_dnmat( - dim<2>{matrix->get_size()[0], num_rhs}, matrix->get_size()[1], - reinterpret_cast(0xDEAF)); - - auto work_size = cusparse::spsm_buffer_size( - handle, CUSPARSE_OPERATION_NON_TRANSPOSE, - CUSPARSE_OPERATION_NON_TRANSPOSE, one(), descr_a, - descr_b, descr_c, CUSPARSE_SPSM_ALG_DEFAULT, spsm_descr); - - work.resize_and_reset(work_size); - - cusparse::spsm_analysis(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, - CUSPARSE_OPERATION_NON_TRANSPOSE, - one(), descr_a, descr_b, descr_c, - CUSPARSE_SPSM_ALG_DEFAULT, spsm_descr, - work.get_data()); - - cusparse::destroy(descr_b); - cusparse::destroy(descr_c); - } - - void solve(const matrix::Csr*, - const matrix::Dense* input, - matrix::Dense* output, matrix::Dense*, - matrix::Dense*) const - { - if (input->get_size()[1] != num_rhs) { - throw gko::ValueMismatch{ - __FILE__, - __LINE__, - __FUNCTION__, - input->get_size()[1], - num_rhs, - "the dimensions of the multivector do not match the value " - "provided at generation time. 
Check the value specified in " - ".with_num_rhs(...)."}; - } - cusparse::pointer_mode_guard pm_guard(handle); - auto descr_b = cusparse::create_dnmat( - input->get_size(), input->get_stride(), - const_cast(input->get_const_values())); - auto descr_c = cusparse::create_dnmat( - output->get_size(), output->get_stride(), output->get_values()); - - cusparse::spsm_solve(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, - CUSPARSE_OPERATION_NON_TRANSPOSE, one(), - descr_a, descr_b, descr_c, - CUSPARSE_SPSM_ALG_DEFAULT, spsm_descr); - - cusparse::destroy(descr_b); - cusparse::destroy(descr_c); - } - - ~CudaSolveStruct() - { - if (descr_a) { - cusparse::destroy(descr_a); - descr_a = nullptr; - } - if (spsm_descr) { - cusparse::destroy(spsm_descr); - spsm_descr = nullptr; - } - } - - CudaSolveStruct(const SolveStruct&) = delete; - - CudaSolveStruct(SolveStruct&&) = delete; - - CudaSolveStruct& operator=(const SolveStruct&) = delete; - - CudaSolveStruct& operator=(SolveStruct&&) = delete; -}; - - -#elif (defined(CUDA_VERSION) && (CUDA_VERSION >= 9020)) - -template -struct CudaSolveStruct : gko::solver::SolveStruct { - std::shared_ptr exec; - cusparseHandle_t handle; - int algorithm; - csrsm2Info_t solve_info; - cusparseSolvePolicy_t policy; - cusparseMatDescr_t factor_descr; - size_type num_rhs; - mutable array work; - - CudaSolveStruct(std::shared_ptr exec, - const matrix::Csr* matrix, - size_type num_rhs, bool is_upper, bool unit_diag) - : exec{exec}, - handle{exec->get_cusparse_handle()}, - algorithm{}, - solve_info{}, - policy{}, - factor_descr{}, - num_rhs{num_rhs}, - work{exec} - { - if (num_rhs == 0) { - return; - } - cusparse::pointer_mode_guard pm_guard(handle); - factor_descr = cusparse::create_mat_descr(); - solve_info = cusparse::create_solve_info(); - cusparse::set_mat_fill_mode( - factor_descr, - is_upper ? CUSPARSE_FILL_MODE_UPPER : CUSPARSE_FILL_MODE_LOWER); - cusparse::set_mat_diag_type( - factor_descr, - unit_diag ? CUSPARSE_DIAG_TYPE_UNIT : CUSPARSE_DIAG_TYPE_NON_UNIT); - algorithm = 0; - policy = CUSPARSE_SOLVE_POLICY_USE_LEVEL; - - size_type work_size{}; - - cusparse::buffer_size_ext( - handle, algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, - CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], num_rhs, - matrix->get_num_stored_elements(), one(), factor_descr, - matrix->get_const_values(), matrix->get_const_row_ptrs(), - matrix->get_const_col_idxs(), nullptr, num_rhs, solve_info, policy, - &work_size); - - // allocate workspace - work.resize_and_reset(work_size); - - cusparse::csrsm2_analysis( - handle, algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, - CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], num_rhs, - matrix->get_num_stored_elements(), one(), factor_descr, - matrix->get_const_values(), matrix->get_const_row_ptrs(), - matrix->get_const_col_idxs(), nullptr, num_rhs, solve_info, policy, - work.get_data()); - } - - void solve(const matrix::Csr* matrix, - const matrix::Dense* input, - matrix::Dense* output, matrix::Dense*, - matrix::Dense*) const - { - if (input->get_size()[1] != num_rhs) { - throw gko::ValueMismatch{ - __FILE__, - __LINE__, - __FUNCTION__, - input->get_size()[1], - num_rhs, - "the dimensions of the multivector do not match the value " - "provided at generation time. 
Check the value specified in " - ".with_num_rhs(...)."}; - } - cusparse::pointer_mode_guard pm_guard(handle); - dense::copy(exec, input, output); - cusparse::csrsm2_solve( - handle, algorithm, CUSPARSE_OPERATION_NON_TRANSPOSE, - CUSPARSE_OPERATION_TRANSPOSE, matrix->get_size()[0], - output->get_stride(), matrix->get_num_stored_elements(), - one(), factor_descr, matrix->get_const_values(), - matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), - output->get_values(), output->get_stride(), solve_info, policy, - work.get_data()); - } - - ~CudaSolveStruct() - { - if (factor_descr) { - cusparse::destroy(factor_descr); - factor_descr = nullptr; - } - if (solve_info) { - cusparse::destroy(solve_info); - solve_info = nullptr; - } - } - - CudaSolveStruct(const CudaSolveStruct&) = delete; - - CudaSolveStruct(CudaSolveStruct&&) = delete; - - CudaSolveStruct& operator=(const CudaSolveStruct&) = delete; - - CudaSolveStruct& operator=(CudaSolveStruct&&) = delete; -}; - - -#endif - - -void should_perform_transpose_kernel(std::shared_ptr exec, +void should_perform_transpose_kernel(std::shared_ptr, bool& do_transpose) { do_transpose = false; } -template -void generate_kernel(std::shared_ptr exec, - const matrix::Csr* matrix, - std::shared_ptr& solve_struct, - const gko::size_type num_rhs, bool is_upper, - bool unit_diag) -{ - if (matrix->get_size()[0] == 0) { - return; - } - if (cusparse::is_supported::value) { - solve_struct = std::make_shared>( - exec, matrix, num_rhs, is_upper, unit_diag); - } else { - GKO_NOT_IMPLEMENTED; - } -} - - -template -void solve_kernel(std::shared_ptr exec, - const matrix::Csr* matrix, - const solver::SolveStruct* solve_struct, - matrix::Dense* trans_b, - matrix::Dense* trans_x, - const matrix::Dense* b, - matrix::Dense* x) -{ - if (matrix->get_size()[0] == 0 || b->get_size()[1] == 0) { - return; - } - using vec = matrix::Dense; - - if (cusparse::is_supported::value) { - if (auto cuda_solve_struct = - dynamic_cast*>( - solve_struct)) { - cuda_solve_struct->solve(matrix, b, x, trans_b, trans_x); - } else { - GKO_NOT_SUPPORTED(solve_struct); - } - } else { - GKO_NOT_IMPLEMENTED; - } -} - - -constexpr int default_block_size = 512; -constexpr int fallback_block_size = 32; - - -template -__global__ void sptrsv_naive_caching_kernel( - const IndexType* const rowptrs, const IndexType* const colidxs, - const ValueType* const vals, const ValueType* const b, size_type b_stride, - ValueType* const x, size_type x_stride, const size_type n, - const size_type nrhs, bool unit_diag, bool* nan_produced, - IndexType* atomic_counter) -{ - __shared__ uninitialized_array x_s_array; - __shared__ IndexType block_base_idx; - - if (threadIdx.x == 0) { - block_base_idx = - atomic_add(atomic_counter, IndexType{1}) * default_block_size; - } - __syncthreads(); - const auto full_gid = static_cast(threadIdx.x) + block_base_idx; - const auto rhs = full_gid % nrhs; - const auto gid = full_gid / nrhs; - const auto row = is_upper ? n - 1 - gid : gid; - - if (gid >= n) { - return; - } - - const auto self_shmem_id = full_gid / default_block_size; - const auto self_shid = full_gid % default_block_size; - - ValueType* x_s = x_s_array; - x_s[self_shid] = nan(); - - __syncthreads(); - - // lower tri matrix: start at beginning, run forward until last entry, - // (row_end - 1) which is the diagonal entry - // upper tri matrix: start at last entry (row_end - 1), run backward - // until first entry, which is the diagonal entry - const auto row_begin = is_upper ? 
rowptrs[row + 1] - 1 : rowptrs[row]; - const auto row_end = is_upper ? rowptrs[row] - 1 : rowptrs[row + 1]; - const int row_step = is_upper ? -1 : 1; - - auto sum = zero(); - auto i = row_begin; - for (; i != row_end; i += row_step) { - const auto dependency = colidxs[i]; - if (is_upper ? dependency <= row : dependency >= row) { - break; - } - auto x_p = &x[dependency * x_stride + rhs]; - - const auto dependency_gid = is_upper ? (n - 1 - dependency) * nrhs + rhs - : dependency * nrhs + rhs; - const bool shmem_possible = - (dependency_gid / default_block_size) == self_shmem_id; - if (shmem_possible) { - const auto dependency_shid = dependency_gid % default_block_size; - x_p = &x_s[dependency_shid]; - } - - ValueType x = *x_p; - while (is_nan(x)) { - x = load(x_p, 0); - } - - sum += x * vals[i]; - } - - // The first entry past the triangular part will be the diagonal - const auto diag = unit_diag ? one() : vals[i]; - const auto r = (b[row * b_stride + rhs] - sum) / diag; - - store(x_s, self_shid, r); - x[row * x_stride + rhs] = r; - - // This check to ensure no infinte loops happen. - if (is_nan(r)) { - store(x_s, self_shid, zero()); - x[row * x_stride + rhs] = zero(); - *nan_produced = true; - } -} - - -template -__global__ void sptrsv_naive_legacy_kernel( - const IndexType* const rowptrs, const IndexType* const colidxs, - const ValueType* const vals, const ValueType* const b, size_type b_stride, - ValueType* const x, size_type x_stride, const size_type n, - const size_type nrhs, bool unit_diag, bool* nan_produced, - IndexType* atomic_counter) -{ - __shared__ IndexType block_base_idx; - if (threadIdx.x == 0) { - block_base_idx = - atomic_add(atomic_counter, IndexType{1}) * fallback_block_size; - } - __syncthreads(); - const auto full_gid = static_cast(threadIdx.x) + block_base_idx; - const auto rhs = full_gid % nrhs; - const auto gid = full_gid / nrhs; - const auto row = is_upper ? n - 1 - gid : gid; - - if (gid >= n) { - return; - } - - // lower tri matrix: start at beginning, run forward - // upper tri matrix: start at last entry (row_end - 1), run backward - const auto row_begin = is_upper ? rowptrs[row + 1] - 1 : rowptrs[row]; - const auto row_end = is_upper ? rowptrs[row] - 1 : rowptrs[row + 1]; - const int row_step = is_upper ? -1 : 1; - - ValueType sum = 0.0; - auto j = row_begin; - auto col = colidxs[j]; - while (j != row_end) { - auto x_val = load(x, col * x_stride + rhs); - while (!is_nan(x_val)) { - sum += vals[j] * x_val; - j += row_step; - col = colidxs[j]; - x_val = load(x, col * x_stride + rhs); - } - // to avoid the kernel hanging on matrices without diagonal, - // we bail out if we are past the triangle, even if it's not - // the diagonal entry. This may lead to incorrect results, - // but prevents an infinite loop. - if (is_upper ? row >= col : row <= col) { - // assert(row == col); - auto diag = unit_diag ? 
one() : vals[j]; - const auto r = (b[row * b_stride + rhs] - sum) / diag; - store(x, row * x_stride + rhs, r); - // after we encountered the diagonal, we are done - // this also skips entries outside the triangle - j = row_end; - if (is_nan(r)) { - store(x, row * x_stride + rhs, zero()); - *nan_produced = true; - } - } - } -} - - -template -__global__ void sptrsv_init_kernel(bool* const nan_produced, - IndexType* const atomic_counter) -{ - *nan_produced = false; - *atomic_counter = IndexType{}; -} - - -template -void sptrsv_naive_caching(std::shared_ptr exec, - const matrix::Csr* matrix, - bool unit_diag, const matrix::Dense* b, - matrix::Dense* x) -{ - // Pre-Volta GPUs may deadlock due to missing independent thread scheduling. - const auto is_fallback_required = exec->get_major_version() < 7; - - const auto n = matrix->get_size()[0]; - const auto nrhs = b->get_size()[1]; - - // Initialize x to all NaNs. - dense::fill(exec, x, nan()); - - array nan_produced(exec, 1); - array atomic_counter(exec, 1); - sptrsv_init_kernel<<<1, 1>>>(nan_produced.get_data(), - atomic_counter.get_data()); - - const dim3 block_size( - is_fallback_required ? fallback_block_size : default_block_size, 1, 1); - const dim3 grid_size(ceildiv(n * nrhs, block_size.x), 1, 1); - - if (is_fallback_required) { - sptrsv_naive_legacy_kernel<<>>( - matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), - as_cuda_type(matrix->get_const_values()), - as_cuda_type(b->get_const_values()), b->get_stride(), - as_cuda_type(x->get_values()), x->get_stride(), n, nrhs, unit_diag, - nan_produced.get_data(), atomic_counter.get_data()); - } else { - sptrsv_naive_caching_kernel<<>>( - matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), - as_cuda_type(matrix->get_const_values()), - as_cuda_type(b->get_const_values()), b->get_stride(), - as_cuda_type(x->get_values()), x->get_stride(), n, nrhs, unit_diag, - nan_produced.get_data(), atomic_counter.get_data()); - } - -#if GKO_VERBOSE_LEVEL >= 1 - if (exec->copy_val_to_host(nan_produced.get_const_data())) { - std::cerr - << "Error: triangular solve produced NaN, either not all diagonal " - "elements are nonzero, or the system is very ill-conditioned. 
" - "The NaN will be replaced with a zero.\n"; - } -#endif // GKO_VERBOSE_LEVEL >= 1 -} - - } // namespace } // namespace cuda } // namespace kernels + + } // namespace gko -#endif // GKO_CUDA_SOLVER_COMMON_TRS_KERNELS_CUH_ +#endif // GKO_CUDA_SOLVER_COMMON_TRS_KERNELS_CUH_ \ No newline at end of file diff --git a/cuda/solver/lower_trs_kernels.cu b/cuda/solver/lower_trs_kernels.cu index 6e1911221d3..4f833d2184e 100644 --- a/cuda/solver/lower_trs_kernels.cu +++ b/cuda/solver/lower_trs_kernels.cu @@ -73,13 +73,11 @@ template void generate(std::shared_ptr exec, const matrix::Csr* matrix, std::shared_ptr& solve_struct, - bool unit_diag, const solver::trisolve_algorithm algorithm, + bool unit_diag, const solver::trisolve_strategy* strategy, const size_type num_rhs) { - if (algorithm == solver::trisolve_algorithm::sparselib) { - generate_kernel(exec, matrix, solve_struct, - num_rhs, false, unit_diag); - } + gko::solver::SolveStruct::generate(exec, matrix, solve_struct, num_rhs, + strategy, false, unit_diag); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -90,16 +88,10 @@ template void solve(std::shared_ptr exec, const matrix::Csr* matrix, const solver::SolveStruct* solve_struct, bool unit_diag, - const solver::trisolve_algorithm algorithm, matrix::Dense* trans_b, matrix::Dense* trans_x, const matrix::Dense* b, matrix::Dense* x) { - if (algorithm == solver::trisolve_algorithm::sparselib) { - solve_kernel(exec, matrix, solve_struct, trans_b, - trans_x, b, x); - } else { - sptrsv_naive_caching(exec, matrix, unit_diag, b, x); - } + solve_struct->solve(exec, matrix, trans_b, trans_x, b, x); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( diff --git a/cuda/solver/upper_trs_kernels.cu b/cuda/solver/upper_trs_kernels.cu index 74d15ba4b19..25ea7f55039 100644 --- a/cuda/solver/upper_trs_kernels.cu +++ b/cuda/solver/upper_trs_kernels.cu @@ -73,13 +73,11 @@ template void generate(std::shared_ptr exec, const matrix::Csr* matrix, std::shared_ptr& solve_struct, - bool unit_diag, const solver::trisolve_algorithm algorithm, + bool unit_diag, const solver::trisolve_strategy* strategy, const size_type num_rhs) { - if (algorithm == solver::trisolve_algorithm::sparselib) { - generate_kernel(exec, matrix, solve_struct, - num_rhs, true, unit_diag); - } + gko::solver::SolveStruct::generate(exec, matrix, solve_struct, num_rhs, + strategy, true, unit_diag); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -90,16 +88,10 @@ template void solve(std::shared_ptr exec, const matrix::Csr* matrix, const solver::SolveStruct* solve_struct, bool unit_diag, - const solver::trisolve_algorithm algorithm, matrix::Dense* trans_b, matrix::Dense* trans_x, const matrix::Dense* b, matrix::Dense* x) { - if (algorithm == solver::trisolve_algorithm::sparselib) { - solve_kernel(exec, matrix, solve_struct, trans_b, - trans_x, b, x); - } else { - sptrsv_naive_caching(exec, matrix, unit_diag, b, x); - } + solve_struct->solve(exec, matrix, trans_b, trans_x, b, x); } GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( diff --git a/cuda/test/solver/lower_trs_kernels.cpp b/cuda/test/solver/lower_trs_kernels.cpp index 9d3d453d653..debb8385872 100644 --- a/cuda/test/solver/lower_trs_kernels.cpp +++ b/cuda/test/solver/lower_trs_kernels.cpp @@ -137,7 +137,7 @@ TEST_F(LowerTrs, CudaSingleRhsApplySyncfreeIsEquivalentToRef) auto lower_trs_factory = gko::solver::LowerTrs<>::build().on(ref); auto d_lower_trs_factory = gko::solver::LowerTrs<>::build() - .with_algorithm(gko::solver::trisolve_algorithm::syncfree) + 
.with_strategy(gko::solver::trisolve_strategy::syncfree()) .on(cuda); auto solver = lower_trs_factory->generate(csr_mtx); auto d_solver = d_lower_trs_factory->generate(d_csr_mtx); @@ -171,7 +171,7 @@ TEST_F(LowerTrs, CudaMultipleRhsApplySyncfreeIsEquivalentToRef) gko::solver::LowerTrs<>::build().with_num_rhs(3u).on(ref); auto d_lower_trs_factory = gko::solver::LowerTrs<>::build() - .with_algorithm(gko::solver::trisolve_algorithm::syncfree) + .with_strategy(gko::solver::trisolve_strategy::syncfree()) .with_num_rhs(3u) .on(cuda); auto solver = lower_trs_factory->generate(csr_mtx); diff --git a/cuda/test/solver/upper_trs_kernels.cpp b/cuda/test/solver/upper_trs_kernels.cpp index bf33c298e91..dacf6386df7 100644 --- a/cuda/test/solver/upper_trs_kernels.cpp +++ b/cuda/test/solver/upper_trs_kernels.cpp @@ -137,7 +137,7 @@ TEST_F(UpperTrs, CudaSingleRhsApplySyncfreelibIsEquivalentToRef) auto upper_trs_factory = gko::solver::UpperTrs<>::build().on(ref); auto d_upper_trs_factory = gko::solver::UpperTrs<>::build() - .with_algorithm(gko::solver::trisolve_algorithm::syncfree) + .with_strategy(gko::solver::trisolve_strategy::syncfree()) .on(cuda); auto solver = upper_trs_factory->generate(csr_mtx); auto d_solver = d_upper_trs_factory->generate(d_csr_mtx); @@ -171,7 +171,7 @@ TEST_F(UpperTrs, CudaMultipleRhsApplySyncfreeIsEquivalentToRef) gko::solver::UpperTrs<>::build().with_num_rhs(3u).on(ref); auto d_upper_trs_factory = gko::solver::UpperTrs<>::build() - .with_algorithm(gko::solver::trisolve_algorithm::syncfree) + .with_strategy(gko::solver::trisolve_strategy::syncfree()) .with_num_rhs(3u) .on(cuda); auto solver = upper_trs_factory->generate(csr_mtx); diff --git a/dpcpp/solver/lower_trs_kernels.dp.cpp b/dpcpp/solver/lower_trs_kernels.dp.cpp index ba808d4543e..8652c0e654d 100644 --- a/dpcpp/solver/lower_trs_kernels.dp.cpp +++ b/dpcpp/solver/lower_trs_kernels.dp.cpp @@ -70,7 +70,7 @@ template void generate(std::shared_ptr exec, const matrix::Csr* matrix, std::shared_ptr& solve_struct, - bool unit_diag, const solver::trisolve_algorithm algorithm, + bool unit_diag, const solver::trisolve_strategy strategy, const size_type num_rhs) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -85,7 +85,6 @@ template void solve(std::shared_ptr exec, const matrix::Csr* matrix, const solver::SolveStruct* solve_struct, bool unit_diag, - const solver::trisolve_algorithm algorithm, matrix::Dense* trans_b, matrix::Dense* trans_x, const matrix::Dense* b, matrix::Dense* x) GKO_NOT_IMPLEMENTED; diff --git a/dpcpp/solver/upper_trs_kernels.dp.cpp b/dpcpp/solver/upper_trs_kernels.dp.cpp index 101b0510b67..0d377073e06 100644 --- a/dpcpp/solver/upper_trs_kernels.dp.cpp +++ b/dpcpp/solver/upper_trs_kernels.dp.cpp @@ -70,7 +70,7 @@ template void generate(std::shared_ptr exec, const matrix::Csr* matrix, std::shared_ptr& solve_struct, - bool unit_diag, const solver::trisolve_algorithm algorithm, + bool unit_diag, const solver::trisolve_strategy strategy, const size_type num_rhs) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( @@ -85,7 +85,6 @@ template void solve(std::shared_ptr exec, const matrix::Csr* matrix, const solver::SolveStruct* solve_struct, bool unit_diag, - const solver::trisolve_algorithm algorithm, matrix::Dense* trans_b, matrix::Dense* trans_x, const matrix::Dense* b, matrix::Dense* x) GKO_NOT_IMPLEMENTED; diff --git a/hip/solver/lower_trs_kernels.hip.cpp b/hip/solver/lower_trs_kernels.hip.cpp index e0f97b16feb..1e0a5d68679 100644 --- a/hip/solver/lower_trs_kernels.hip.cpp +++ 
b/hip/solver/lower_trs_kernels.hip.cpp @@ -73,7 +73,7 @@ template void generate(std::shared_ptr exec, const matrix::Csr* matrix, std::shared_ptr& solve_struct, - bool unit_diag, const solver::trisolve_algorithm algorithm, + bool unit_diag, const solver::trisolve_strategy strategy, const size_type num_rhs) { generate_kernel(exec, matrix, solve_struct, num_rhs, @@ -88,7 +88,6 @@ template void solve(std::shared_ptr exec, const matrix::Csr* matrix, const solver::SolveStruct* solve_struct, bool unit_diag, - const solver::trisolve_algorithm algorithm, matrix::Dense* trans_b, matrix::Dense* trans_x, const matrix::Dense* b, matrix::Dense* x) { diff --git a/hip/solver/upper_trs_kernels.hip.cpp b/hip/solver/upper_trs_kernels.hip.cpp index b6c0beb7aaf..578e82d3f9d 100644 --- a/hip/solver/upper_trs_kernels.hip.cpp +++ b/hip/solver/upper_trs_kernels.hip.cpp @@ -73,7 +73,7 @@ template void generate(std::shared_ptr exec, const matrix::Csr* matrix, std::shared_ptr& solve_struct, - bool unit_diag, const solver::trisolve_algorithm algorithm, + bool unit_diag, const solver::trisolve_strategy strategy, const size_type num_rhs) { generate_kernel(exec, matrix, solve_struct, num_rhs, @@ -88,7 +88,6 @@ template void solve(std::shared_ptr exec, const matrix::Csr* matrix, const solver::SolveStruct* solve_struct, bool unit_diag, - const solver::trisolve_algorithm algorithm, matrix::Dense* trans_b, matrix::Dense* trans_x, const matrix::Dense* b, matrix::Dense* x) { diff --git a/include/ginkgo/core/solver/triangular.hpp b/include/ginkgo/core/solver/triangular.hpp index 6a4ba8d3d89..6ca0f16c11a 100644 --- a/include/ginkgo/core/solver/triangular.hpp +++ b/include/ginkgo/core/solver/triangular.hpp @@ -55,16 +55,68 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. namespace gko { namespace solver { - -struct SolveStruct; - +enum class trisolve_type { + sparselib, + syncfree, + thinned, + block, + winv, + wvar, + level +}; /** * A helper for algorithm selection in the triangular solvers. - * It currently only matters for the Cuda executor as there, - * we have a choice between the Ginkgo syncfree and cuSPARSE implementations. + * It currently only matters for the Cuda executor, + * as we only there have different options. */ -enum class trisolve_algorithm { sparselib, syncfree }; +struct trisolve_strategy { + trisolve_type type; + uint8 thinned_m; + std::shared_ptr>> + block_inner; + + trisolve_strategy(trisolve_type type) : type{type} {} + trisolve_strategy(trisolve_type type, uint8 thinned_m) + : type{type}, thinned_m{thinned_m} + {} + trisolve_strategy( + trisolve_type type, + std::shared_ptr>> + block_inner) + : type{type}, block_inner{block_inner} + {} + + static std::shared_ptr sparselib() + { + return std::make_shared(trisolve_type::sparselib); + } + static std::shared_ptr syncfree() + { + return std::make_shared(trisolve_type::syncfree); + } +}; + + +struct SolveStruct { + virtual ~SolveStruct() = default; + + template + void solve(std::shared_ptr exec, + const matrix::Csr* matrix, + matrix::Dense* trans_b, + matrix::Dense* trans_x, + const matrix::Dense* b, + matrix::Dense* x) const; + + template + static void generate(std::shared_ptr exec, + const matrix::Csr* matrix, + std::shared_ptr& solve_struct, + const gko::size_type num_rhs, + const trisolve_strategy* strategy, bool is_upper, + bool unit_diag); +}; template @@ -112,7 +164,7 @@ class LowerTrs : public EnableLinOp>, * Number of right hand sides. 
* * @note This value is currently only required for the CUDA - * trisolve_algorithm::sparselib algorithm. + * trisolve_strategy::sparselib algorithm. */ gko::size_type GKO_FACTORY_PARAMETER_SCALAR(num_rhs, 1u); @@ -128,8 +180,8 @@ class LowerTrs : public EnableLinOp>, * executor where the choice is between the Ginkgo (syncfree) and the * cuSPARSE (sparselib) implementation. Default is sparselib. */ - trisolve_algorithm GKO_FACTORY_PARAMETER_SCALAR( - algorithm, trisolve_algorithm::sparselib); + std::shared_ptr GKO_FACTORY_PARAMETER_VECTOR( + strategy, trisolve_strategy::sparselib()); }; GKO_ENABLE_LIN_OP_FACTORY(LowerTrs, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); @@ -262,7 +314,7 @@ class UpperTrs : public EnableLinOp>, * Number of right hand sides. * * @note This value is currently only required for the CUDA - * trisolve_algorithm::sparselib algorithm. + * trisolve_strategy::sparselib algorithm. */ gko::size_type GKO_FACTORY_PARAMETER_SCALAR(num_rhs, 1u); @@ -278,8 +330,8 @@ class UpperTrs : public EnableLinOp>, * executor where the choice is between the Ginkgo (syncfree) and the * cuSPARSE (sparselib) implementation. Default is sparselib. */ - trisolve_algorithm GKO_FACTORY_PARAMETER_SCALAR( - algorithm, trisolve_algorithm::sparselib); + std::shared_ptr GKO_FACTORY_PARAMETER_VECTOR( + strategy, trisolve_strategy::sparselib()); }; GKO_ENABLE_LIN_OP_FACTORY(UpperTrs, parameters, Factory); GKO_ENABLE_BUILD_METHOD(Factory); diff --git a/omp/solver/lower_trs_kernels.cpp b/omp/solver/lower_trs_kernels.cpp index 8629ab665dd..702910da1f0 100644 --- a/omp/solver/lower_trs_kernels.cpp +++ b/omp/solver/lower_trs_kernels.cpp @@ -70,7 +70,7 @@ template void generate(std::shared_ptr exec, const matrix::Csr* matrix, std::shared_ptr& solve_struct, - bool unit_diag, const solver::trisolve_algorithm algorithm, + bool unit_diag, const solver::trisolve_strategy* strategy, const size_type num_rhs) { // This generate kernel is here to allow for a more sophisticated @@ -90,7 +90,6 @@ template void solve(std::shared_ptr exec, const matrix::Csr* matrix, const solver::SolveStruct* solve_struct, bool unit_diag, - const solver::trisolve_algorithm algorithm, matrix::Dense* trans_b, matrix::Dense* trans_x, const matrix::Dense* b, matrix::Dense* x) { diff --git a/omp/solver/upper_trs_kernels.cpp b/omp/solver/upper_trs_kernels.cpp index 09dabd8c19c..c00f5eb80a1 100644 --- a/omp/solver/upper_trs_kernels.cpp +++ b/omp/solver/upper_trs_kernels.cpp @@ -70,7 +70,7 @@ template void generate(std::shared_ptr exec, const matrix::Csr* matrix, std::shared_ptr& solve_struct, - bool unit_diag, const solver::trisolve_algorithm algorithm, + bool unit_diag, const solver::trisolve_strategy* strategy, const size_type num_rhs) { // This generate kernel is here to allow for a more sophisticated @@ -90,7 +90,6 @@ template void solve(std::shared_ptr exec, const matrix::Csr* matrix, const solver::SolveStruct* solve_struct, bool unit_diag, - const solver::trisolve_algorithm algorithm, matrix::Dense* trans_b, matrix::Dense* trans_x, const matrix::Dense* b, matrix::Dense* x) { diff --git a/reference/solver/lower_trs_kernels.cpp b/reference/solver/lower_trs_kernels.cpp index 813c93f314a..f195c2cf09f 100644 --- a/reference/solver/lower_trs_kernels.cpp +++ b/reference/solver/lower_trs_kernels.cpp @@ -66,7 +66,7 @@ template void generate(std::shared_ptr exec, const matrix::Csr* matrix, std::shared_ptr& solve_struct, - bool unit_diag, const solver::trisolve_algorithm algorithm, + bool unit_diag, const solver::trisolve_strategy* 
algorithm, const size_type num_rhs) { // This generate kernel is here to allow for a more sophisticated @@ -87,7 +87,6 @@ template void solve(std::shared_ptr exec, const matrix::Csr* matrix, const solver::SolveStruct* solve_struct, bool unit_diag, - const solver::trisolve_algorithm algorithm, matrix::Dense*, matrix::Dense* trans_x, const matrix::Dense* b, matrix::Dense* x) { diff --git a/reference/solver/upper_trs_kernels.cpp b/reference/solver/upper_trs_kernels.cpp index 8cf5263b51f..7123d18d0be 100644 --- a/reference/solver/upper_trs_kernels.cpp +++ b/reference/solver/upper_trs_kernels.cpp @@ -66,7 +66,7 @@ template void generate(std::shared_ptr exec, const matrix::Csr* matrix, std::shared_ptr& solve_struct, - bool unit_diag, const solver::trisolve_algorithm algorithm, + bool unit_diag, const solver::trisolve_strategy* strategy, const size_type num_rhs) { // This generate kernel is here to allow for a more sophisticated @@ -87,7 +87,6 @@ template void solve(std::shared_ptr exec, const matrix::Csr* matrix, const solver::SolveStruct* solve_struct, bool unit_diag, - const solver::trisolve_algorithm algorithm, matrix::Dense* trans_b, matrix::Dense* trans_x, const matrix::Dense* b, matrix::Dense* x) { diff --git a/test/solver/solver.cpp b/test/solver/solver.cpp index 19f280df5d8..1f69a8cb453 100644 --- a/test/solver/solver.cpp +++ b/test/solver/solver.cpp @@ -298,7 +298,7 @@ struct LowerTrs : SimpleSolverTest> { std::shared_ptr exec, gko::size_type num_rhs) { return solver_type::build() - .with_algorithm(gko::solver::trisolve_algorithm::sparselib) + .with_strategy(gko::solver::trisolve_strategy::sparselib()) .with_num_rhs(num_rhs); } @@ -348,7 +348,7 @@ struct UpperTrs : SimpleSolverTest> { std::shared_ptr exec, gko::size_type num_rhs) { return solver_type::build() - .with_algorithm(gko::solver::trisolve_algorithm::sparselib) + .with_strategy(gko::solver::trisolve_strategy::sparselib()) .with_num_rhs(num_rhs); } @@ -379,7 +379,7 @@ struct LowerTrsUnitdiag : LowerTrs { std::shared_ptr exec, gko::size_type num_rhs) { return solver_type::build() - .with_algorithm(gko::solver::trisolve_algorithm::sparselib) + .with_strategy(gko::solver::trisolve_strategy::sparselib()) .with_num_rhs(num_rhs) .with_unit_diagonal(true); } @@ -391,7 +391,7 @@ struct UpperTrsUnitdiag : UpperTrs { std::shared_ptr exec, gko::size_type num_rhs) { return solver_type::build() - .with_algorithm(gko::solver::trisolve_algorithm::sparselib) + .with_strategy(gko::solver::trisolve_strategy::sparselib()) .with_num_rhs(num_rhs) .with_unit_diagonal(true); } @@ -405,8 +405,8 @@ struct LowerTrsSyncfree : LowerTrs { std::shared_ptr exec, gko::size_type iteration_count) { - return solver_type::build().with_algorithm( - gko::solver::trisolve_algorithm::syncfree); + return solver_type::build().with_strategy( + gko::solver::trisolve_strategy::syncfree()); } }; @@ -418,8 +418,8 @@ struct UpperTrsSyncfree : UpperTrs { std::shared_ptr exec, gko::size_type iteration_count) { - return solver_type::build().with_algorithm( - gko::solver::trisolve_algorithm::syncfree); + return solver_type::build().with_strategy( + gko::solver::trisolve_strategy::syncfree()); } }; @@ -432,7 +432,7 @@ struct LowerTrsSyncfreeUnitdiag : LowerTrs { gko::size_type iteration_count) { return solver_type::build() - .with_algorithm(gko::solver::trisolve_algorithm::syncfree) + .with_strategy(gko::solver::trisolve_strategy::syncfree()) .with_unit_diagonal(true); } }; @@ -446,7 +446,7 @@ struct UpperTrsSyncfreeUnitdiag : UpperTrs { gko::size_type iteration_count) { return 
solver_type::build() - .with_algorithm(gko::solver::trisolve_algorithm::syncfree) + .with_strategy(gko::solver::trisolve_strategy::syncfree()) .with_unit_diagonal(true); } };
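For reference, below is a minimal, standalone sketch of the two-phase pattern that BlockedSolveStruct::solve applies above: even-numbered blocks are diagonal triangles handed to an inner triangular solve, while odd-numbered blocks are off-diagonal rectangles applied as an update b -= A * x before the next triangle is solved. The sketch works on a small dense lower-triangular system with a fixed two-block partition; the names, sizes, and block layout are illustrative only and do not reproduce the recursive CSR partitioning of the patch.

#include <cstdio>

int main()
{
    // L is 4x4 lower triangular, partitioned into 2x2 blocks:
    // [L11  0 ] [x1]   [b1]
    // [A21 L22] [x2] = [b2]
    const double L[4][4] = {{2, 0, 0, 0},
                            {1, 2, 0, 0},
                            {1, 1, 2, 0},
                            {1, 1, 1, 2}};
    double b[4] = {2, 3, 5, 7};
    double x[4] = {};

    // forward substitution restricted to a diagonal block
    // (the "even block" inner solve)
    auto solve_block = [&](int begin, int end) {
        for (int i = begin; i < end; ++i) {
            double s = b[i];
            for (int j = begin; j < i; ++j) {
                s -= L[i][j] * x[j];
            }
            x[i] = s / L[i][i];
        }
    };
    // b2 -= A21 * x1 (the "odd block" apply with alpha = -1, beta = 1)
    auto update_block = [&](int rb, int re, int cb, int ce) {
        for (int i = rb; i < re; ++i) {
            for (int j = cb; j < ce; ++j) {
                b[i] -= L[i][j] * x[j];
            }
        }
    };

    solve_block(0, 2);         // diagonal block: inner triangular solve
    update_block(2, 4, 0, 2);  // off-diagonal block: right-hand-side update
    solve_block(2, 4);         // diagonal block: inner triangular solve

    for (int i = 0; i < 4; ++i) {
        std::printf("x[%d] = %g\n", i, x[i]);
    }
    return 0;
}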