From 560c14bcd6bfcd860b5202237a6d2eb122783977 Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Sun, 12 Jan 2025 12:27:38 +0100 Subject: [PATCH] add device kernels for validation --- .../factorization/factorization_kernels.cpp | 110 +++++++++++++++++- core/factorization/factorization_kernels.hpp | 5 +- .../factorization_kernels.dp.cpp | 10 +- omp/factorization/factorization_kernels.cpp | 9 +- .../factorization/factorization_kernels.cpp | 9 +- reference/test/factorization/lu_kernels.cpp | 37 +++++- test/factorization/lu_kernels.cpp | 91 +++++++++++++++ 7 files changed, 249 insertions(+), 22 deletions(-) diff --git a/common/cuda_hip/factorization/factorization_kernels.cpp b/common/cuda_hip/factorization/factorization_kernels.cpp index 758a91bd71d..80a07c2a32d 100644 --- a/common/cuda_hip/factorization/factorization_kernels.cpp +++ b/common/cuda_hip/factorization/factorization_kernels.cpp @@ -4,11 +4,14 @@ #include "core/factorization/factorization_kernels.hpp" +#include + #include #include "common/cuda_hip/base/config.hpp" #include "common/cuda_hip/base/math.hpp" #include "common/cuda_hip/base/runtime.hpp" +#include "common/cuda_hip/base/thrust.hpp" #include "common/cuda_hip/base/types.hpp" #include "common/cuda_hip/components/cooperative_groups.hpp" #include "common/cuda_hip/components/intrinsics.hpp" @@ -16,6 +19,7 @@ #include "common/cuda_hip/components/thread_ids.hpp" #include "common/cuda_hip/factorization/factorization_helpers.hpp" #include "core/base/array_access.hpp" +#include "core/components/fill_array_kernels.hpp" #include "core/components/prefix_sum_kernels.hpp" #include "core/matrix/csr_builder.hpp" @@ -277,6 +281,73 @@ __global__ __launch_bounds__(default_block_size) void count_nnz_per_l_row( } +template +__global__ __launch_bounds__(default_block_size) void symbolic_validate( + const IndexType* __restrict__ mtx_row_ptrs, + const IndexType* __restrict__ mtx_cols, + const IndexType* __restrict__ factor_row_ptrs, + const IndexType* __restrict__ factor_cols, size_type size, + const IndexType* __restrict__ storage_offsets, + const int64* __restrict__ row_descs, const int32* __restrict__ storage, + bool* __restrict__ found, bool* __restrict__ missing) +{ + const auto row = thread::get_subwarp_id_flat(); + if (row >= size) { + return; + } + const auto warp = + group::tiled_partition(group::this_thread_block()); + const auto lane = warp.thread_rank(); + gko::matrix::csr::device_sparsity_lookup lookup{ + factor_row_ptrs, factor_cols, storage_offsets, + storage, row_descs, static_cast(row)}; + const auto mtx_begin = mtx_row_ptrs[row]; + const auto mtx_end = mtx_row_ptrs[row + 1]; + const auto factor_begin = factor_row_ptrs[row]; + const auto factor_end = factor_row_ptrs[row + 1]; + bool local_missing = false; + const auto mark_found = [&](IndexType col) { + const auto local_idx = lookup[col]; + const auto idx = local_idx + factor_begin; + if (local_idx == invalid_index()) { + local_missing = true; + } + found[idx] = true; + }; + // check the original matrix is part of the factors + for (auto nz = mtx_begin + lane; nz < mtx_end; nz += config::warp_size) { + mark_found(mtx_cols[nz]); + } + // check the diagonal is part of the factors + if (lane == 0) { + mark_found(row); + } + // check it is a valid factorization + for (auto nz = factor_begin; nz < factor_end; nz++) { + const auto dep = factor_cols[nz]; + if (dep >= row) { + continue; + } + // for every lower triangular entry + const auto dep_begin = factor_row_ptrs[dep]; + const auto dep_end = factor_row_ptrs[dep + 1]; + for (auto dep_nz = dep_begin + lane; dep_nz < dep_end; + dep_nz += config::warp_size) { + const auto col = factor_cols[dep_nz]; + // check every upper triangular entry thereof is part of the + // factorization + if (col > dep) { + mark_found(col); + } + } + } + local_missing = warp.any(local_missing); + if (lane == 0) { + missing[row] = local_missing; + } +} + + } // namespace kernel @@ -488,10 +559,41 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( template -void symbolic_validate(std::shared_ptr exec, - const matrix::Csr* system_matrix, - const matrix::Csr* factors, - bool& valid) GKO_NOT_IMPLEMENTED; +void symbolic_validate( + std::shared_ptr exec, + const matrix::Csr* system_matrix, + const matrix::Csr* factors, + const matrix::csr::lookup_data& factors_lookup, bool& valid) +{ + const auto size = system_matrix->get_size()[0]; + const auto row_ptrs = system_matrix->get_const_row_ptrs(); + const auto col_idxs = system_matrix->get_const_col_idxs(); + const auto factor_row_ptrs = factors->get_const_row_ptrs(); + const auto factor_col_idxs = factors->get_const_col_idxs(); + // this stores for each factor nonzero whether it occurred as part of the + // factorization. + array found(exec, factors->get_num_stored_elements()); + components::fill_array(exec, found.get_data(), found.get_size(), false); + // this stores for each row whether there were any elements missing + array missing(exec, size); + components::fill_array(exec, missing.get_data(), missing.get_size(), false); + if (size > 0) { + const auto num_blocks = + ceildiv(size, default_block_size / config::warp_size); + kernel::symbolic_validate<<>>( + row_ptrs, col_idxs, factor_row_ptrs, factor_col_idxs, size, + factors_lookup.storage_offsets.get_const_data(), + factors_lookup.row_descs.get_const_data(), + factors_lookup.storage.get_const_data(), found.get_data(), + missing.get_data()); + } + valid = thrust::all_of(thrust_policy(exec), found.get_const_data(), + found.get_const_data() + found.get_size(), + thrust::identity{}) && + !thrust::any_of(thrust_policy(exec), missing.get_const_data(), + missing.get_const_data() + missing.get_size(), + thrust::identity{}); +} GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_FACTORIZATION_SYMBOLIC_VALIDATE_KERNEL); diff --git a/core/factorization/factorization_kernels.hpp b/core/factorization/factorization_kernels.hpp index fdfca1d2427..4619647a23e 100644 --- a/core/factorization/factorization_kernels.hpp +++ b/core/factorization/factorization_kernels.hpp @@ -13,6 +13,7 @@ #include #include "core/base/kernel_declaration.hpp" +#include "core/matrix/csr_lookup.hpp" namespace gko { @@ -57,7 +58,9 @@ namespace kernels { void symbolic_validate( \ std::shared_ptr exec, \ const matrix::Csr* system_matrix, \ - const matrix::Csr* factors, bool& valid) + const matrix::Csr* factors, \ + const matrix::csr::lookup_data& factors_lookup, \ + bool& valid) #define GKO_DECLARE_ALL_AS_TEMPLATES \ diff --git a/dpcpp/factorization/factorization_kernels.dp.cpp b/dpcpp/factorization/factorization_kernels.dp.cpp index 18749bbf886..8fd11b3a68a 100644 --- a/dpcpp/factorization/factorization_kernels.dp.cpp +++ b/dpcpp/factorization/factorization_kernels.dp.cpp @@ -591,10 +591,12 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( template -void symbolic_validate(std::shared_ptr exec, - const matrix::Csr* system_matrix, - const matrix::Csr* factors, - bool& valid) GKO_NOT_IMPLEMENTED; +void symbolic_validate( + std::shared_ptr exec, + const matrix::Csr* system_matrix, + const matrix::Csr* factors, + const matrix::csr::lookup_data& factors_lookup, + bool& valid) GKO_NOT_IMPLEMENTED; GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( GKO_DECLARE_FACTORIZATION_SYMBOLIC_VALIDATE_KERNEL); diff --git a/omp/factorization/factorization_kernels.cpp b/omp/factorization/factorization_kernels.cpp index a6a59f7dbc3..28f772532e7 100644 --- a/omp/factorization/factorization_kernels.cpp +++ b/omp/factorization/factorization_kernels.cpp @@ -339,10 +339,11 @@ bool symbolic_validate_impl(std::shared_ptr exec, } template -void symbolic_validate(std::shared_ptr exec, - const matrix::Csr* system_matrix, - const matrix::Csr* factors, - bool& valid) +void symbolic_validate( + std::shared_ptr exec, + const matrix::Csr* system_matrix, + const matrix::Csr* factors, + const matrix::csr::lookup_data& factors_lookup, bool& valid) { valid = symbolic_validate_impl( exec, system_matrix->get_const_row_ptrs(), diff --git a/reference/factorization/factorization_kernels.cpp b/reference/factorization/factorization_kernels.cpp index 6388f62adcc..b07cb9d783a 100644 --- a/reference/factorization/factorization_kernels.cpp +++ b/reference/factorization/factorization_kernels.cpp @@ -274,10 +274,11 @@ bool symbolic_validate_impl(std::shared_ptr exec, } template -void symbolic_validate(std::shared_ptr exec, - const matrix::Csr* system_matrix, - const matrix::Csr* factors, - bool& valid) +void symbolic_validate( + std::shared_ptr exec, + const matrix::Csr* system_matrix, + const matrix::Csr* factors, + const matrix::csr::lookup_data& factors_lookup, bool& valid) { valid = symbolic_validate_impl( exec, system_matrix->get_const_row_ptrs(), diff --git a/reference/test/factorization/lu_kernels.cpp b/reference/test/factorization/lu_kernels.cpp index a0358a6f23d..00d4aefab70 100644 --- a/reference/test/factorization/lu_kernels.cpp +++ b/reference/test/factorization/lu_kernels.cpp @@ -17,6 +17,7 @@ #include #include +#include "core/base/index_range.hpp" #include "core/components/prefix_sum_kernels.hpp" #include "core/factorization/cholesky_kernels.hpp" #include "core/factorization/elimination_forest.hpp" @@ -340,19 +341,43 @@ TYPED_TEST(Lu, ValidateValidFactors) bool valid = false; gko::kernels::reference::factorization::symbolic_validate( - this->ref, this->mtx.get(), this->mtx_lu.get(), valid); + this->ref, this->mtx.get(), this->mtx_lu.get(), + gko::matrix::csr::build_lookup(this->mtx_lu.get()), valid); ASSERT_TRUE(valid); }); } +TYPED_TEST(Lu, ValidateInvalidFactorsIdentity) +{ + using value_type = typename TestFixture::value_type; + using index_type = typename TestFixture::index_type; + this->forall_matrices([this] { + bool valid = true; + gko::matrix_data data(this->mtx_lu->get_size()); + // an identity matrix is a valid factorization, but doesn't contain the + // system matrix + for (auto row : gko::irange{static_cast(data.size[0])}) { + data.nonzeros.emplace_back(row, row, gko::one()); + } + this->mtx_lu->read(data); + + gko::kernels::reference::factorization::symbolic_validate( + this->ref, this->mtx.get(), this->mtx_lu.get(), + gko::matrix::csr::build_lookup(this->mtx_lu.get()), valid); + + ASSERT_FALSE(valid); + }); +} + + TYPED_TEST(Lu, ValidateInvalidFactorsMissing) { using value_type = typename TestFixture::value_type; using index_type = typename TestFixture::index_type; this->forall_matrices([this] { - bool valid = false; + bool valid = true; gko::matrix_data data; this->mtx_lu->write(data); // delete a random entry somewhere in the middle of the matrix @@ -361,7 +386,8 @@ TYPED_TEST(Lu, ValidateInvalidFactorsMissing) this->mtx_lu->read(data); gko::kernels::reference::factorization::symbolic_validate( - this->ref, this->mtx.get(), this->mtx_lu.get(), valid); + this->ref, this->mtx.get(), this->mtx_lu.get(), + gko::matrix::csr::build_lookup(this->mtx_lu.get()), valid); ASSERT_FALSE(valid); }); @@ -373,7 +399,7 @@ TYPED_TEST(Lu, ValidateInvalidFactorsExtra) using value_type = typename TestFixture::value_type; using index_type = typename TestFixture::index_type; this->forall_matrices([this] { - bool valid = false; + bool valid = true; gko::matrix_data data; this->mtx_lu->write(data); const auto it = std::adjacent_find( @@ -385,7 +411,8 @@ TYPED_TEST(Lu, ValidateInvalidFactorsExtra) this->mtx_lu->read(data); gko::kernels::reference::factorization::symbolic_validate( - this->ref, this->mtx.get(), this->mtx_lu.get(), valid); + this->ref, this->mtx.get(), this->mtx_lu.get(), + gko::matrix::csr::build_lookup(this->mtx_lu.get()), valid); ASSERT_FALSE(valid); }); diff --git a/test/factorization/lu_kernels.cpp b/test/factorization/lu_kernels.cpp index 739e8907cab..d1d6feaea26 100644 --- a/test/factorization/lu_kernels.cpp +++ b/test/factorization/lu_kernels.cpp @@ -19,10 +19,12 @@ #include #include +#include "core/base/index_range.hpp" #include "core/components/fill_array_kernels.hpp" #include "core/components/prefix_sum_kernels.hpp" #include "core/factorization/cholesky_kernels.hpp" #include "core/factorization/elimination_forest.hpp" +#include "core/factorization/factorization_kernels.hpp" #include "core/factorization/symbolic.hpp" #include "core/matrix/csr_kernels.hpp" #include "core/matrix/csr_lookup.hpp" @@ -189,6 +191,95 @@ TYPED_TEST(Lu, KernelFactorizeIsEquivalentToRef) } +TYPED_TEST(Lu, KernelValidateValidFactors) +{ + using value_type = typename TestFixture::value_type; + using index_type = typename TestFixture::index_type; + this->forall_matrices([this] { + bool valid = false; + + gko::kernels::GKO_DEVICE_NAMESPACE::factorization::symbolic_validate( + this->exec, this->dmtx.get(), this->dmtx_lu.get(), + gko::matrix::csr::build_lookup(this->dmtx_lu.get()), valid); + + ASSERT_TRUE(valid); + }); +} + + +TYPED_TEST(Lu, KernelValidateInvalidFactorsIdentity) +{ + using value_type = typename TestFixture::value_type; + using index_type = typename TestFixture::index_type; + this->forall_matrices([this] { + bool valid = true; + gko::matrix_data data( + this->dmtx_lu->get_size()); + // an identity matrix is a valid factorization, but doesn't contain the + // system matrix + for (auto row : gko::irange{static_cast(data.size[0])}) { + data.nonzeros.emplace_back(row, row, gko::one()); + } + this->dmtx_lu->read(data); + + gko::kernels::GKO_DEVICE_NAMESPACE::factorization::symbolic_validate( + this->exec, this->dmtx.get(), this->dmtx_lu.get(), + gko::matrix::csr::build_lookup(this->dmtx_lu.get()), valid); + + ASSERT_FALSE(valid); + }); +} + + +TYPED_TEST(Lu, KernelValidateInvalidFactorsMissing) +{ + using value_type = typename TestFixture::value_type; + using index_type = typename TestFixture::index_type; + this->forall_matrices([this] { + bool valid = true; + gko::matrix_data data; + this->dmtx_lu->write(data); + // delete a random entry somewhere in the middle of the matrix + data.nonzeros.erase(data.nonzeros.begin() + + data.nonzeros.size() * 3 / 4); + this->dmtx_lu->read(data); + + gko::kernels::GKO_DEVICE_NAMESPACE::factorization::symbolic_validate( + this->exec, this->dmtx.get(), this->dmtx_lu.get(), + gko::matrix::csr::build_lookup(this->dmtx_lu.get()), valid); + + ASSERT_FALSE(valid); + }); +} + + +TYPED_TEST(Lu, KernelValidateInvalidFactorsExtra) +{ + using value_type = typename TestFixture::value_type; + using index_type = typename TestFixture::index_type; + this->forall_matrices([this] { + bool valid = true; + gko::matrix_data data; + this->dmtx_lu->write(data); + // insert an entry between two non-adjacent values in a row somewhere + // not at the beginning + const auto it = std::adjacent_find( + data.nonzeros.begin() + data.nonzeros.size() / 5, + data.nonzeros.end(), [](auto a, auto b) { + return a.row == b.row && a.column < b.column - 1; + }); + data.nonzeros.insert(it, {it->row, it->column + 1, it->value}); + this->dmtx_lu->read(data); + + gko::kernels::GKO_DEVICE_NAMESPACE::factorization::symbolic_validate( + this->exec, this->dmtx.get(), this->dmtx_lu.get(), + gko::matrix::csr::build_lookup(this->dmtx_lu.get()), valid); + + ASSERT_FALSE(valid); + }); +} + + TYPED_TEST(Lu, SymbolicCholeskyWorks) { using value_type = typename TestFixture::value_type;