Skip to content

Commit

Permalink
fix conversion error and generator test
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Jun 13, 2023
1 parent 9418553 commit 281c9ed
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
13 changes: 9 additions & 4 deletions core/test/utils/matrix_generator_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ TYPED_TEST(MatrixGenerator, CanGenerateTridiagMatrix)
{
using T = typename TestFixture::value_type;
using Dense = typename TestFixture::mtx_type;
auto dist = std::normal_distribution<gko::remove_complex<T>>(0, 1);
auto dist = std::normal_distribution<>(0, 1);
auto engine = std::default_random_engine(42);
auto lower = gko::test::detail::get_rand_value<T>(dist, engine);
auto diag = gko::test::detail::get_rand_value<T>(dist, engine);
Expand All @@ -304,18 +304,23 @@ TYPED_TEST(MatrixGenerator, CanGenerateTridiagInverseMatrix)
{
using T = typename TestFixture::value_type;
using Dense = typename TestFixture::mtx_type;
auto dist = std::normal_distribution<gko::remove_complex<T>>(0, 1);
auto dist = std::normal_distribution<>(0, 1);
auto engine = std::default_random_engine(42);
auto lower = gko::test::detail::get_rand_value<T>(dist, engine);
auto upper = gko::test::detail::get_rand_value<T>(dist, engine);
// make diagonally dominant
auto diag = std::abs(gko::test::detail::get_rand_value<T>(dist, engine)) +
std::abs(lower) + std::abs(upper);
gko::size_type size = 50;
if (std::is_same<gko::half, gko::remove_complex<T>>::value) {
// half precision can only handle small matrix
size = 5;
}

auto mtx = gko::test::generate_tridiag_matrix<Dense>(
50, {lower, diag, upper}, this->exec);
size, {lower, diag, upper}, this->exec);
auto inv_mtx = gko::test::generate_tridiag_inverse_matrix<Dense>(
50, {lower, diag, upper}, this->exec);
size, {lower, diag, upper}, this->exec);

auto result = Dense::create(this->exec, mtx->get_size());
inv_mtx->apply(mtx, result);
Expand Down
4 changes: 2 additions & 2 deletions omp/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,

auto row_ptrs = a->get_const_row_ptrs();
auto col_idxs = a->get_const_col_idxs();
arithmetic_type valpha = alpha->at(0, 0);
arithmetic_type vbeta = beta->at(0, 0);
arithmetic_type valpha = static_cast<arithmetic_type>(alpha->at(0, 0));
arithmetic_type vbeta = static_cast<arithmetic_type>(beta->at(0, 0));

const auto a_vals =
acc::helper::build_const_rrm_accessor<arithmetic_type>(a);
Expand Down
4 changes: 2 additions & 2 deletions reference/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ void advanced_spmv(std::shared_ptr<const ReferenceExecutor> exec,

auto row_ptrs = a->get_const_row_ptrs();
auto col_idxs = a->get_const_col_idxs();
arithmetic_type valpha = alpha->at(0, 0);
arithmetic_type vbeta = beta->at(0, 0);
arithmetic_type valpha = static_cast<arithmetic_type>(alpha->at(0, 0));
arithmetic_type vbeta = static_cast<arithmetic_type>(beta->at(0, 0));

const auto a_vals =
acc::helper::build_const_rrm_accessor<arithmetic_type>(a);
Expand Down

0 comments on commit 281c9ed

Please sign in to comment.