Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Half precision support #1257

Closed
wants to merge 62 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
4ed36a8
only can compile cuda/omp
yhmtsai Nov 29, 2022
ddc7c16
next_precision to itself when complex only float, double add empty co…
yhmtsai Jan 5, 2023
501e6c7
can compile with cuda/omp/ref (without test)
yhmtsai Jan 8, 2023
957d29c
compile for cuda/sycl/test/mpi (hip needs trick)
yhmtsai Jan 11, 2023
a0c389c
hip finally
yhmtsai Jan 12, 2023
fe5e491
fix the narrow issue and atomic support
yhmtsai Jan 12, 2023
6c3c12b
fixed more error
yhmtsai Jan 12, 2023
a0ee872
fix the op order and gdb
yhmtsai Jan 12, 2023
5acbf27
add the rand template not_implemented
yhmtsai Jan 12, 2023
b171312
this version can compile/run complex<half> on cuda114
yhmtsai Jan 12, 2023
6c17701
does not work for the other executor
yhmtsai Jan 14, 2023
cdbf0a0
fix complex issue and sqrt issue
yhmtsai Feb 6, 2023
209c799
try fix the compilation issue from MSVC and MacOS
yhmtsai Feb 6, 2023
75b54fa
move the half to public and use sycl::half for dpcpp
yhmtsai Feb 7, 2023
48ea338
limit the next precision in test and benchmark
yhmtsai Feb 7, 2023
fdcc066
allow disable half operation
yhmtsai Feb 7, 2023
f0a8a07
fix macro
yhmtsai Feb 8, 2023
f041b4a
clean and refine the code
yhmtsai Feb 8, 2023
3154a04
move half.hpp out of type.hpp
yhmtsai Feb 8, 2023
58784ab
enable half for testing
yhmtsai Feb 8, 2023
9b2465b
__habs is added in cuda10.2
yhmtsai Feb 8, 2023
e2a6c9a
fix nullptr and missing instantiation.
yhmtsai Feb 9, 2023
51cf597
fix missing device_type and ptr_param
yhmtsai Mar 23, 2023
0a42796
update rounding
yhmtsai Mar 25, 2023
e3b81df
do not use distribution with half
yhmtsai Mar 27, 2023
c9fd747
WIP fix half of failed test
yhmtsai Mar 27, 2023
684cadb
fix/skip half test and fix numeric_limit on device
yhmtsai Jun 13, 2023
60767ed
mkl csr does not support half
yhmtsai Jun 21, 2023
d65255a
add half to batch_vector
yhmtsai Sep 7, 2023
5c0454f
fix hip thrust complex op, avoid const in nvhpc, reduce job in windows
yhmtsai Sep 12, 2023
da15916
fix nvc++ atomic, dpcpp half
yhmtsai Sep 13, 2023
5f9e3ff
make half test optional
yhmtsai Sep 14, 2023
fe45560
nvhpc optimization/computation error workaround
yhmtsai Sep 15, 2023
c7f0d2a
some math func is not defined if nvhpc is for host
yhmtsai Sep 29, 2023
710e037
add half spmv benchmark (with cusparse for cuda)
yhmtsai Sep 30, 2023
34845f3
fixes batched support for half
MarcelKoch Oct 24, 2023
48afbb5
generate PTX load/stores for half
MarcelKoch Nov 3, 2023
a51f136
fix mc64 for half
MarcelKoch Dec 12, 2023
60123dc
fix hip memory.hip.hpp for half
MarcelKoch Dec 19, 2023
8f1e28f
WIP: can compile but three tests are still failed
yhmtsai Apr 20, 2024
6dbd616
fix config, ambiguous namespace, and batch
yhmtsai Jul 3, 2024
cd270e1
update format
yhmtsai Jul 3, 2024
69d5b59
check the failed tests
yhmtsai Sep 17, 2024
57fc170
fix windows and icpx
yhmtsai Sep 18, 2024
18e825f
hip does not support atomic on 16 bits
yhmtsai Sep 18, 2024
825f76f
fix batch
yhmtsai Sep 18, 2024
81d63ac
add miss instantiation
yhmtsai Sep 19, 2024
2a6d382
update documentation, remove half.hpp
yhmtsai Sep 24, 2024
8731fc3
put function in gko not std
yhmtsai Sep 24, 2024
64406f3
fix after rebase
yhmtsai Oct 2, 2024
baa95f7
hip does not support 16bit shuffle
yhmtsai Oct 3, 2024
4bb8093
merge two #if block
yhmtsai Oct 7, 2024
c539398
do not use attributes in sqrt and abs
yhmtsai Oct 8, 2024
d0e2446
make half constexpr
yhmtsai Oct 8, 2024
0d777df
isolate half out of device completely
yhmtsai Oct 16, 2024
56e2af8
bits constexpr construct half and make numeric_limit in half
yhmtsai Oct 18, 2024
6cc26d7
refine the code and fix error without half
yhmtsai Oct 21, 2024
3d15350
reduce abs/sqrt location
yhmtsai Oct 21, 2024
c4697a5
move the math function to math
yhmtsai Oct 22, 2024
377432a
nohalf
yhmtsai Oct 22, 2024
e806a0a
cbgmres without half
yhmtsai Oct 23, 2024
3e49252
direct without half
yhmtsai Oct 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ option(GINKGO_BUILD_DOC "Generate documentation" OFF)
option(GINKGO_FAST_TESTS "Reduces the input size for a few tests known to be time-intensive" OFF)
option(GINKGO_TEST_NONDEFAULT_STREAM "Uses non-default streams in CUDA and HIP tests" OFF)
option(GINKGO_MIXED_PRECISION "Instantiate true mixed-precision kernels (otherwise they will be conversion-based using implicit temporary storage)" OFF)
option(GINKGO_ENABLE_HALF "Enable the use of half precision" ON)
# We do not support MSVC. SYCL will come later
if(MSVC OR GINKGO_BUILD_SYCL)
message(STATUS "HALF is not supported in MSVC, and later support in SYCL")
set(GINKGO_ENABLE_HALF OFF CACHE BOOL "Enable the use of half precision" FORCE)
endif()
option(GINKGO_SKIP_DEPENDENCY_UPDATE
"Do not update dependencies each time the project is rebuilt" ON)
option(GINKGO_WITH_CLANG_TIDY "Make Ginkgo call `clang-tidy` to find programming issues." OFF)
Expand Down
15 changes: 14 additions & 1 deletion accessor/cuda_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@
#include "utils.hpp"


struct __half;


Comment on lines +20 to +22
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some ABI/name mangling differences between struct and class in MSVC - are you sure this will be defined as a struct always?

Also for cleaner headers/exports, maybe we should make this conditional on CUDA compilation?

namespace gko {


class half;


namespace acc {
namespace detail {

Expand All @@ -27,6 +35,11 @@ struct cuda_type {
using type = T;
};

template <>
struct cuda_type<gko::half> {
using type = __half;
};

yhmtsai marked this conversation as resolved.
Show resolved Hide resolved
// Unpack cv and reference / pointer qualifiers
template <typename T>
struct cuda_type<const T> {
Expand Down Expand Up @@ -57,7 +70,7 @@ struct cuda_type<T&&> {
// Transform std::complex to thrust::complex
template <typename T>
struct cuda_type<std::complex<T>> {
using type = thrust::complex<T>;
using type = thrust::complex<typename cuda_type<T>::type>;
};


Expand Down
14 changes: 13 additions & 1 deletion accessor/hip_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@
#include "utils.hpp"


struct __half;


namespace gko {


class half;


namespace acc {
namespace detail {

Expand Down Expand Up @@ -53,11 +61,15 @@ struct hip_type<T&&> {
using type = typename hip_type<T>::type&&;
};

template <>
struct hip_type<gko::half> {
using type = __half;
};
yhmtsai marked this conversation as resolved.
Show resolved Hide resolved

// Transform std::complex to thrust::complex
template <typename T>
struct hip_type<std::complex<T>> {
using type = thrust::complex<T>;
using type = thrust::complex<typename hip_type<T>::type>;
};


Expand Down
6 changes: 4 additions & 2 deletions accessor/reference_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

// CUDA TOOLKIT < 11 does not support constexpr in combination with
// thrust::complex, which is why constexpr is only present in later versions
#if defined(__CUDA_ARCH__) && defined(__CUDACC_VER_MAJOR__) && \
(__CUDACC_VER_MAJOR__ < 11)
// TODO: NVC++ constexpr
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this TODO addressed with the constexpr PR?

#if (defined(__CUDA_ARCH__) && defined(__CUDACC_VER_MAJOR__) && \
(__CUDACC_VER_MAJOR__ < 11)) || \
(defined(__NVCOMPILER) && GINKGO_ENABLE_HALF)

#define GKO_ACC_ENABLE_REFERENCE_CONSTEXPR

Expand Down
22 changes: 14 additions & 8 deletions benchmark/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,20 @@ function(ginkgo_add_single_benchmark_executable name use_lib_linops macro_def ty
target_compile_definitions("${name}" PRIVATE "${macro_def}")
ginkgo_benchmark_add_tuning_maybe("${name}")
if("${use_lib_linops}")
if (GINKGO_BUILD_CUDA)
if(GINKGO_BUILD_CUDA)
target_compile_definitions("${name}" PRIVATE HAS_CUDA=1)
target_link_libraries("${name}" cusparse_linops_${type})
endif()
if (GINKGO_BUILD_HIP)
target_compile_definitions("${name}" PRIVATE HAS_HIP=1)
target_link_libraries("${name}" hipsparse_linops_${type})
endif()
if (GINKGO_BUILD_SYCL)
target_compile_definitions("${name}" PRIVATE HAS_DPCPP=1)
target_link_libraries("${name}" onemkl_linops_${type})
# only cuda supports half currently
if(NOT ("${type}" STREQUAL "h"))
if (GINKGO_BUILD_HIP)
target_compile_definitions("${name}" PRIVATE HAS_HIP=1)
target_link_libraries("${name}" hipsparse_linops_${type})
endif()
if (GINKGO_BUILD_SYCL)
target_compile_definitions("${name}" PRIVATE HAS_DPCPP=1)
target_link_libraries("${name}" onemkl_linops_${type})
endif()
endif()
endif()
endfunction(ginkgo_add_single_benchmark_executable)
Expand Down Expand Up @@ -117,6 +120,9 @@ if (GINKGO_BUILD_CUDA)
ginkgo_benchmark_cusparse_linops(s GKO_BENCHMARK_USE_SINGLE_PRECISION)
ginkgo_benchmark_cusparse_linops(z GKO_BENCHMARK_USE_DOUBLE_COMPLEX_PRECISION)
ginkgo_benchmark_cusparse_linops(c GKO_BENCHMARK_USE_SINGLE_COMPLEX_PRECISION)
if (GINKGO_ENABLE_HALF)
ginkgo_benchmark_cusparse_linops(h GKO_BENCHMARK_USE_HALF_PRECISION)
endif()
add_library(cuda_timer utils/cuda_timer.cpp)
target_link_libraries(cuda_timer PRIVATE ginkgo CUDA::cudart)
ginkgo_compile_features(cuda_timer)
Expand Down
17 changes: 13 additions & 4 deletions benchmark/run_all_benchmarks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ elif [ "${BENCHMARK_PRECISION}" == "dcomplex" ]; then
BENCH_SUFFIX="_dcomplex"
elif [ "${BENCHMARK_PRECISION}" == "scomplex" ]; then
BENCH_SUFFIX="_scomplex"
elif [ "${BENCHMARK_PRECISION}" == "half" ]; then
BENCH_SUFFIX="_half"
else
echo "BENCHMARK_PRECISION is set to the not supported \"${BENCHMARK_PRECISION}\"." 1>&2
echo "Currently supported values: \"double\", \"single\", \"dcomplex\" and \"scomplex\"" 1>&2
echo "Currently supported values: \"double\", \"single\", \"half\", \"dcomplex\" and \"scomplex\"" 1>&2
exit 1
fi

Expand Down Expand Up @@ -216,9 +218,16 @@ keep_latest() {
compute_matrix_statistics() {
[ "${DRY_RUN}" == "true" ] && return
cp "$1" "$1.imd" # make sure we're not loosing the original input
./matrix_statistics/matrix_statistics${BENCH_SUFFIX} \
--backup="$1.bkp" --double_buffer="$1.bkp2" \
<"$1.imd" 2>&1 >"$1"
if [ "${BENCH_SUFFIX}" == "_half" ]; then
# half precision benchmark still uses single for statistics
./matrix_statistics/matrix_statistics_single \
--backup="$1.bkp" --double_buffer="$1.bkp2" \
<"$1.imd" 2>&1 >"$1"
else
./matrix_statistics/matrix_statistics${BENCH_SUFFIX} \
--backup="$1.bkp" --double_buffer="$1.bkp2" \
<"$1.imd" 2>&1 >"$1"
fi
keep_latest "$1" "$1.bkp" "$1.bkp2" "$1.imd"
}

Expand Down
5 changes: 5 additions & 0 deletions benchmark/spmv/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
ginkgo_add_typed_benchmark_executables(spmv "YES" spmv.cpp)
# TODO: move to all benchmark
if (GINKGO_ENABLE_HALF)
ginkgo_add_single_benchmark_executable(
"spmv_half" "YES" "GKO_BENCHMARK_USE_HALF_PRECISION" "h" spmv.cpp)
endif()
Comment on lines +2 to +6
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why didn't you move it to benchmark/CMakeLists.txt? Is this still a TODO?

if(GINKGO_BUILD_MPI)
add_subdirectory(distributed)
endif()
4 changes: 3 additions & 1 deletion benchmark/spmv/spmv_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ struct SpmvBenchmark : Benchmark<spmv_benchmark_state<Generator>> {
exec->synchronize();
auto max_relative_norm2 =
compute_max_relative_norm2(x_clone.get(), state.answer.get());
format_case["max_relative_norm2"] = max_relative_norm2;
format_case["max_relative_norm2"] =
static_cast<typename gko::detail::arth_type<rc_etype>::type>(
max_relative_norm2);
}

IterationControl ic{timer};
Expand Down
40 changes: 25 additions & 15 deletions benchmark/utils/cuda_linops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,14 +527,19 @@ class CusparseHybrid
((CUDA_VERSION >= 10020) && !(defined(_WIN32) || defined(__CYGWIN__)))


// cuSPARSE does not support 16 bit compute for full 16 bit floating point
// input. Also, the scalar must be the compute type, i.e. float.
template <typename ValueType>
void cusparse_generic_spmv(std::shared_ptr<const gko::CudaExecutor> gpu_exec,
const cusparseSpMatDescr_t mat,
const gko::array<ValueType>& scalars,
const gko::LinOp* b, gko::LinOp* x,
cusparseOperation_t trans, cusparseSpMVAlg_t alg)
void cusparse_generic_spmv(
std::shared_ptr<const gko::CudaExecutor> gpu_exec,
const cusparseSpMatDescr_t mat,
const gko::array<typename gko::detail::arth_type<ValueType>::type>& scalars,
const gko::LinOp* b, gko::LinOp* x, cusparseOperation_t trans,
cusparseSpMVAlg_t alg)
{
cudaDataType_t cu_value = gko::kernels::cuda::cuda_data_type<ValueType>();
cudaDataType_t compute_value = gko::kernels::cuda::cuda_data_type<
typename gko::detail::arth_type<ValueType>::type>();
using gko::kernels::cuda::as_culibs_type;
auto dense_b = gko::as<gko::matrix::Dense<ValueType>>(b);
auto dense_x = gko::as<gko::matrix::Dense<ValueType>>(x);
Expand All @@ -553,13 +558,14 @@ void cusparse_generic_spmv(std::shared_ptr<const gko::CudaExecutor> gpu_exec,
gko::size_type buffer_size = 0;
GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSpMV_bufferSize(
gpu_exec->get_sparselib_handle(), trans, &scalars.get_const_data()[0],
mat, vecb, &scalars.get_const_data()[1], vecx, cu_value, alg,
mat, vecb, &scalars.get_const_data()[1], vecx, compute_value, alg,
&buffer_size));
gko::array<char> buffer_array(gpu_exec, buffer_size);
auto dbuffer = buffer_array.get_data();
GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseSpMV(
gpu_exec->get_sparselib_handle(), trans, &scalars.get_const_data()[0],
mat, vecb, &scalars.get_const_data()[1], vecx, cu_value, alg, dbuffer));
mat, vecb, &scalars.get_const_data()[1], vecx, compute_value, alg,
dbuffer));
GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseDestroyDnVec(vecx));
GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseDestroyDnVec(vecb));
}
Expand Down Expand Up @@ -638,8 +644,8 @@ class CusparseGenericCsr
protected:
void apply_impl(const gko::LinOp* b, gko::LinOp* x) const override
{
cusparse_generic_spmv(this->get_gpu_exec(), mat_, scalars, b, x, trans_,
Alg);
cusparse_generic_spmv<ValueType>(this->get_gpu_exec(), mat_, scalars, b,
x, trans_, Alg);
}

void apply_impl(const gko::LinOp* alpha, const gko::LinOp* b,
Expand All @@ -655,9 +661,11 @@ class CusparseGenericCsr
{}

private:
using compute_type = typename gko::detail::arth_type<ValueType>::type;
// Contains {alpha, beta}
gko::array<ValueType> scalars{
this->get_executor(), {gko::one<ValueType>(), gko::zero<ValueType>()}};
gko::array<compute_type> scalars{
this->get_executor(),
{gko::one<compute_type>(), gko::zero<compute_type>()}};
std::shared_ptr<csr> csr_;
cusparseOperation_t trans_;
cusparseSpMatDescr_t mat_;
Expand Down Expand Up @@ -730,8 +738,8 @@ class CusparseGenericCoo
protected:
void apply_impl(const gko::LinOp* b, gko::LinOp* x) const override
{
cusparse_generic_spmv(this->get_gpu_exec(), mat_, scalars, b, x, trans_,
default_csr_alg);
cusparse_generic_spmv<ValueType>(this->get_gpu_exec(), mat_, scalars, b,
x, trans_, default_csr_alg);
}

void apply_impl(const gko::LinOp* alpha, const gko::LinOp* b,
Expand All @@ -746,9 +754,11 @@ class CusparseGenericCoo
{}

private:
using compute_type = typename gko::detail::arth_type<ValueType>::type;
// Contains {alpha, beta}
gko::array<ValueType> scalars{
this->get_executor(), {gko::one<ValueType>(), gko::zero<ValueType>()}};
gko::array<compute_type> scalars{
this->get_executor(),
{gko::one<compute_type>(), gko::zero<compute_type>()}};
std::shared_ptr<coo> coo_;
cusparseOperation_t trans_;
cusparseSpMatDescr_t mat_;
Expand Down
5 changes: 1 addition & 4 deletions benchmark/utils/generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,7 @@ struct DefaultSystemGenerator {
{
auto res = Vec::create(exec);
res->read(gko::matrix_data<ValueType, itype>(
size,
std::uniform_real_distribution<gko::remove_complex<ValueType>>(-1.0,
1.0),
get_engine()));
size, std::uniform_real_distribution<>(-1.0, 1.0), get_engine()));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will potentially give us conversion warnings in the future, but I'm generally fine with it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is handled in the get_rand_value,

return res;
}

Expand Down
36 changes: 35 additions & 1 deletion benchmark/utils/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ using itype = gko::int32;
#if defined(GKO_BENCHMARK_USE_DOUBLE_PRECISION) || \
defined(GKO_BENCHMARK_USE_SINGLE_PRECISION) || \
defined(GKO_BENCHMARK_USE_DOUBLE_COMPLEX_PRECISION) || \
defined(GKO_BENCHMARK_USE_SINGLE_COMPLEX_PRECISION)
defined(GKO_BENCHMARK_USE_SINGLE_COMPLEX_PRECISION) || \
defined(GKO_BENCHMARK_USE_HALF_PRECISION)
// separate ifdefs to catch duplicate definitions
#ifdef GKO_BENCHMARK_USE_DOUBLE_PRECISION
using etype = double;
Expand All @@ -31,11 +32,44 @@ using etype = std::complex<double>;
#ifdef GKO_BENCHMARK_USE_SINGLE_COMPLEX_PRECISION
using etype = std::complex<float>;
#endif
#ifdef GKO_BENCHMARK_USE_HALF_PRECISION
#include <ginkgo/core/base/half.hpp>
using etype = gko::half;
#endif
#else // default to double precision
using etype = double;
#endif

using rc_etype = gko::remove_complex<etype>;


namespace detail {


// singly linked list of all our supported precisions
template <typename T>
struct next_precision_impl {};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would probably also be a good time to implement prev_precision?


template <>
struct next_precision_impl<float> {
using type = double;
};

template <>
struct next_precision_impl<double> {
using type = float;
};


template <typename T>
struct next_precision_impl<std::complex<T>> {
using type = std::complex<typename next_precision_impl<T>::type>;
};


} // namespace detail

template <typename T>
using next_precision = typename detail::next_precision_impl<T>::type;
MarcelKoch marked this conversation as resolved.
Show resolved Hide resolved

#endif // GKO_BENCHMARK_UTILS_TYPES_HPP_
3 changes: 3 additions & 0 deletions cmake/get_info.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,14 @@ if(TARGET hwloc)
ginkgo_print_variable(${detailed_log} "HWLOC_LIBRARIES")
ginkgo_print_variable(${detailed_log} "HWLOC_INCLUDE_DIRS")
endif()
ginkgo_print_variable(${minimal_log} "GINKGO_ENABLE_HALF")
ginkgo_print_variable(${detailed_log} "GINKGO_ENABLE_HALF")
ginkgo_print_module_footer(${detailed_log} "")

ginkgo_print_generic_header(${detailed_log} " Extensions:")
ginkgo_print_variable(${detailed_log} "GINKGO_EXTENSION_KOKKOS_CHECK_TYPE_ALIGNMENT")


_minimal(
"
--\n-- Detailed information (More compiler flags, module configuration) can be found in detailed.log
Expand Down
10 changes: 10 additions & 0 deletions common/cuda_hip/base/device_matrix_data_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <thrust/sort.h>
#include <thrust/tuple.h>

#include "common/cuda_hip/base/math.hpp"
#include "common/cuda_hip/base/thrust.hpp"
#include "common/cuda_hip/base/types.hpp"

Expand All @@ -22,6 +23,15 @@ namespace GKO_DEVICE_NAMESPACE {
namespace components {


// __half != only in __device__
// Although gko::is_nonzero is constexpr, it still shows calling __device__ in
// __host__
template <typename T>
GKO_INLINE __device__ constexpr bool is_nonzero(T value)
{
return value != zero<T>();
}

template <typename ValueType, typename IndexType>
void remove_zeros(std::shared_ptr<const DefaultExecutor> exec,
array<ValueType>& values, array<IndexType>& row_idxs,
Expand Down
Loading
Loading