Skip to content

Commit

Permalink
add device kernels for validation
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Jan 12, 2025
1 parent 64ead5e commit 560c14b
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 22 deletions.
110 changes: 106 additions & 4 deletions common/cuda_hip/factorization/factorization_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,22 @@

#include "core/factorization/factorization_kernels.hpp"

#include <thrust/logical.h>

#include <ginkgo/core/base/array.hpp>

#include "common/cuda_hip/base/config.hpp"
#include "common/cuda_hip/base/math.hpp"
#include "common/cuda_hip/base/runtime.hpp"
#include "common/cuda_hip/base/thrust.hpp"
#include "common/cuda_hip/base/types.hpp"
#include "common/cuda_hip/components/cooperative_groups.hpp"
#include "common/cuda_hip/components/intrinsics.hpp"
#include "common/cuda_hip/components/searching.hpp"
#include "common/cuda_hip/components/thread_ids.hpp"
#include "common/cuda_hip/factorization/factorization_helpers.hpp"
#include "core/base/array_access.hpp"
#include "core/components/fill_array_kernels.hpp"
#include "core/components/prefix_sum_kernels.hpp"
#include "core/matrix/csr_builder.hpp"

Expand Down Expand Up @@ -277,6 +281,73 @@ __global__ __launch_bounds__(default_block_size) void count_nnz_per_l_row(
}


template <typename IndexType>
__global__ __launch_bounds__(default_block_size) void symbolic_validate(
const IndexType* __restrict__ mtx_row_ptrs,
const IndexType* __restrict__ mtx_cols,
const IndexType* __restrict__ factor_row_ptrs,
const IndexType* __restrict__ factor_cols, size_type size,
const IndexType* __restrict__ storage_offsets,
const int64* __restrict__ row_descs, const int32* __restrict__ storage,
bool* __restrict__ found, bool* __restrict__ missing)
{
const auto row = thread::get_subwarp_id_flat<config::warp_size>();
if (row >= size) {
return;
}
const auto warp =
group::tiled_partition<config::warp_size>(group::this_thread_block());
const auto lane = warp.thread_rank();
gko::matrix::csr::device_sparsity_lookup<IndexType> lookup{
factor_row_ptrs, factor_cols, storage_offsets,
storage, row_descs, static_cast<size_type>(row)};
const auto mtx_begin = mtx_row_ptrs[row];
const auto mtx_end = mtx_row_ptrs[row + 1];
const auto factor_begin = factor_row_ptrs[row];
const auto factor_end = factor_row_ptrs[row + 1];
bool local_missing = false;
const auto mark_found = [&](IndexType col) {
const auto local_idx = lookup[col];
const auto idx = local_idx + factor_begin;
if (local_idx == invalid_index<IndexType>()) {
local_missing = true;
}
found[idx] = true;
};
// check the original matrix is part of the factors
for (auto nz = mtx_begin + lane; nz < mtx_end; nz += config::warp_size) {
mark_found(mtx_cols[nz]);
}
// check the diagonal is part of the factors
if (lane == 0) {
mark_found(row);
}
// check it is a valid factorization
for (auto nz = factor_begin; nz < factor_end; nz++) {
const auto dep = factor_cols[nz];
if (dep >= row) {
continue;
}
// for every lower triangular entry
const auto dep_begin = factor_row_ptrs[dep];
const auto dep_end = factor_row_ptrs[dep + 1];
for (auto dep_nz = dep_begin + lane; dep_nz < dep_end;
dep_nz += config::warp_size) {
const auto col = factor_cols[dep_nz];
// check every upper triangular entry thereof is part of the
// factorization
if (col > dep) {
mark_found(col);
}
}
}
local_missing = warp.any(local_missing);
if (lane == 0) {
missing[row] = local_missing;
}
}


} // namespace kernel


Expand Down Expand Up @@ -488,10 +559,41 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(


template <typename ValueType, typename IndexType>
void symbolic_validate(std::shared_ptr<const DefaultExecutor> exec,
const matrix::Csr<ValueType, IndexType>* system_matrix,
const matrix::Csr<ValueType, IndexType>* factors,
bool& valid) GKO_NOT_IMPLEMENTED;
void symbolic_validate(
std::shared_ptr<const DefaultExecutor> exec,
const matrix::Csr<ValueType, IndexType>* system_matrix,
const matrix::Csr<ValueType, IndexType>* factors,
const matrix::csr::lookup_data<IndexType>& factors_lookup, bool& valid)
{
const auto size = system_matrix->get_size()[0];
const auto row_ptrs = system_matrix->get_const_row_ptrs();
const auto col_idxs = system_matrix->get_const_col_idxs();
const auto factor_row_ptrs = factors->get_const_row_ptrs();
const auto factor_col_idxs = factors->get_const_col_idxs();
// this stores for each factor nonzero whether it occurred as part of the
// factorization.
array<bool> found(exec, factors->get_num_stored_elements());
components::fill_array(exec, found.get_data(), found.get_size(), false);
// this stores for each row whether there were any elements missing
array<bool> missing(exec, size);
components::fill_array(exec, missing.get_data(), missing.get_size(), false);
if (size > 0) {
const auto num_blocks =
ceildiv(size, default_block_size / config::warp_size);
kernel::symbolic_validate<<<num_blocks, default_block_size>>>(
row_ptrs, col_idxs, factor_row_ptrs, factor_col_idxs, size,
factors_lookup.storage_offsets.get_const_data(),
factors_lookup.row_descs.get_const_data(),
factors_lookup.storage.get_const_data(), found.get_data(),
missing.get_data());
}
valid = thrust::all_of(thrust_policy(exec), found.get_const_data(),
found.get_const_data() + found.get_size(),
thrust::identity<bool>{}) &&
!thrust::any_of(thrust_policy(exec), missing.get_const_data(),
missing.get_const_data() + missing.get_size(),
thrust::identity<bool>{});
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_FACTORIZATION_SYMBOLIC_VALIDATE_KERNEL);
Expand Down
5 changes: 4 additions & 1 deletion core/factorization/factorization_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <ginkgo/core/matrix/csr.hpp>

#include "core/base/kernel_declaration.hpp"
#include "core/matrix/csr_lookup.hpp"


namespace gko {
Expand Down Expand Up @@ -57,7 +58,9 @@ namespace kernels {
void symbolic_validate( \
std::shared_ptr<const DefaultExecutor> exec, \
const matrix::Csr<ValueType, IndexType>* system_matrix, \
const matrix::Csr<ValueType, IndexType>* factors, bool& valid)
const matrix::Csr<ValueType, IndexType>* factors, \
const matrix::csr::lookup_data<IndexType>& factors_lookup, \
bool& valid)


#define GKO_DECLARE_ALL_AS_TEMPLATES \
Expand Down
10 changes: 6 additions & 4 deletions dpcpp/factorization/factorization_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,10 +591,12 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(


template <typename ValueType, typename IndexType>
void symbolic_validate(std::shared_ptr<const DefaultExecutor> exec,
const matrix::Csr<ValueType, IndexType>* system_matrix,
const matrix::Csr<ValueType, IndexType>* factors,
bool& valid) GKO_NOT_IMPLEMENTED;
void symbolic_validate(
std::shared_ptr<const DefaultExecutor> exec,
const matrix::Csr<ValueType, IndexType>* system_matrix,
const matrix::Csr<ValueType, IndexType>* factors,
const matrix::csr::lookup_data<IndexType>& factors_lookup,
bool& valid) GKO_NOT_IMPLEMENTED;

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
GKO_DECLARE_FACTORIZATION_SYMBOLIC_VALIDATE_KERNEL);
Expand Down
9 changes: 5 additions & 4 deletions omp/factorization/factorization_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,11 @@ bool symbolic_validate_impl(std::shared_ptr<const DefaultExecutor> exec,
}

template <typename ValueType, typename IndexType>
void symbolic_validate(std::shared_ptr<const DefaultExecutor> exec,
const matrix::Csr<ValueType, IndexType>* system_matrix,
const matrix::Csr<ValueType, IndexType>* factors,
bool& valid)
void symbolic_validate(
std::shared_ptr<const DefaultExecutor> exec,
const matrix::Csr<ValueType, IndexType>* system_matrix,
const matrix::Csr<ValueType, IndexType>* factors,
const matrix::csr::lookup_data<IndexType>& factors_lookup, bool& valid)
{
valid = symbolic_validate_impl(
exec, system_matrix->get_const_row_ptrs(),
Expand Down
9 changes: 5 additions & 4 deletions reference/factorization/factorization_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,11 @@ bool symbolic_validate_impl(std::shared_ptr<const DefaultExecutor> exec,
}

template <typename ValueType, typename IndexType>
void symbolic_validate(std::shared_ptr<const DefaultExecutor> exec,
const matrix::Csr<ValueType, IndexType>* system_matrix,
const matrix::Csr<ValueType, IndexType>* factors,
bool& valid)
void symbolic_validate(
std::shared_ptr<const DefaultExecutor> exec,
const matrix::Csr<ValueType, IndexType>* system_matrix,
const matrix::Csr<ValueType, IndexType>* factors,
const matrix::csr::lookup_data<IndexType>& factors_lookup, bool& valid)
{
valid = symbolic_validate_impl(
exec, system_matrix->get_const_row_ptrs(),
Expand Down
37 changes: 32 additions & 5 deletions reference/test/factorization/lu_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <ginkgo/core/matrix/csr.hpp>
#include <ginkgo/core/matrix/sparsity_csr.hpp>

#include "core/base/index_range.hpp"
#include "core/components/prefix_sum_kernels.hpp"
#include "core/factorization/cholesky_kernels.hpp"
#include "core/factorization/elimination_forest.hpp"
Expand Down Expand Up @@ -340,19 +341,43 @@ TYPED_TEST(Lu, ValidateValidFactors)
bool valid = false;

gko::kernels::reference::factorization::symbolic_validate(
this->ref, this->mtx.get(), this->mtx_lu.get(), valid);
this->ref, this->mtx.get(), this->mtx_lu.get(),
gko::matrix::csr::build_lookup(this->mtx_lu.get()), valid);

ASSERT_TRUE(valid);
});
}


TYPED_TEST(Lu, ValidateInvalidFactorsIdentity)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
this->forall_matrices([this] {
bool valid = true;
gko::matrix_data<value_type, index_type> data(this->mtx_lu->get_size());
// an identity matrix is a valid factorization, but doesn't contain the
// system matrix
for (auto row : gko::irange{static_cast<index_type>(data.size[0])}) {
data.nonzeros.emplace_back(row, row, gko::one<value_type>());
}
this->mtx_lu->read(data);

gko::kernels::reference::factorization::symbolic_validate(
this->ref, this->mtx.get(), this->mtx_lu.get(),
gko::matrix::csr::build_lookup(this->mtx_lu.get()), valid);

ASSERT_FALSE(valid);
});
}


TYPED_TEST(Lu, ValidateInvalidFactorsMissing)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
this->forall_matrices([this] {
bool valid = false;
bool valid = true;
gko::matrix_data<value_type, index_type> data;
this->mtx_lu->write(data);
// delete a random entry somewhere in the middle of the matrix
Expand All @@ -361,7 +386,8 @@ TYPED_TEST(Lu, ValidateInvalidFactorsMissing)
this->mtx_lu->read(data);

gko::kernels::reference::factorization::symbolic_validate(
this->ref, this->mtx.get(), this->mtx_lu.get(), valid);
this->ref, this->mtx.get(), this->mtx_lu.get(),
gko::matrix::csr::build_lookup(this->mtx_lu.get()), valid);

ASSERT_FALSE(valid);
});
Expand All @@ -373,7 +399,7 @@ TYPED_TEST(Lu, ValidateInvalidFactorsExtra)
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
this->forall_matrices([this] {
bool valid = false;
bool valid = true;
gko::matrix_data<value_type, index_type> data;
this->mtx_lu->write(data);
const auto it = std::adjacent_find(
Expand All @@ -385,7 +411,8 @@ TYPED_TEST(Lu, ValidateInvalidFactorsExtra)
this->mtx_lu->read(data);

gko::kernels::reference::factorization::symbolic_validate(
this->ref, this->mtx.get(), this->mtx_lu.get(), valid);
this->ref, this->mtx.get(), this->mtx_lu.get(),
gko::matrix::csr::build_lookup(this->mtx_lu.get()), valid);

ASSERT_FALSE(valid);
});
Expand Down
Loading

0 comments on commit 560c14b

Please sign in to comment.