Skip to content

Commit

Permalink
refine the code and fix error without half
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Oct 21, 2024
1 parent 5b687cf commit af5b228
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 87 deletions.
17 changes: 6 additions & 11 deletions benchmark/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,12 @@ function(ginkgo_add_single_benchmark_executable name use_lib_linops macro_def ty
target_compile_definitions("${name}" PRIVATE "${macro_def}")
ginkgo_benchmark_add_tuning_maybe("${name}")
if("${use_lib_linops}")
if ("${type}" STREQUAL "h")
# only cuda supports half currently
if (GINKGO_BUILD_CUDA)
target_compile_definitions("${name}" PRIVATE HAS_CUDA=1)
target_link_libraries("${name}" cusparse_linops_${type})
endif()
else()
if (GINKGO_BUILD_CUDA)
target_compile_definitions("${name}" PRIVATE HAS_CUDA=1)
target_link_libraries("${name}" cusparse_linops_${type})
endif()
if(GINKGO_BUILD_CUDA)
target_compile_definitions("${name}" PRIVATE HAS_CUDA=1)
target_link_libraries("${name}" cusparse_linops_${type})
endif()
# only cuda supports half currently
if(NOT ("${type}" STREQUAL "h"))
if (GINKGO_BUILD_HIP)
target_compile_definitions("${name}" PRIVATE HAS_HIP=1)
target_link_libraries("${name}" hipsparse_linops_${type})
Expand Down
22 changes: 13 additions & 9 deletions common/cuda_hip/base/device_matrix_data_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ namespace GKO_DEVICE_NAMESPACE {
namespace components {


// __half != only in __device__
// Although gko::is_nonzero is constexpr, it still shows calling __device__ in
// __host__
template <typename T>
GKO_INLINE __device__ constexpr bool is_nonzero(T value)
{
return value != zero<T>();
}

template <typename ValueType, typename IndexType>
void remove_zeros(std::shared_ptr<const DefaultExecutor> exec,
array<ValueType>& values, array<IndexType>& row_idxs,
Expand All @@ -31,13 +40,9 @@ void remove_zeros(std::shared_ptr<const DefaultExecutor> exec,
auto value_ptr = as_device_type(values.get_const_data());
auto size = values.get_size();
// count nonzeros
// __half != is only device, can not call __device__ from a __host__
// __device__ (is_nonzero)
auto nnz =
thrust::count_if(thrust_policy(exec), value_ptr, value_ptr + size,
[] __device__(device_value_type value) {
return value != zero(value);
});
auto nnz = thrust::count_if(
thrust_policy(exec), value_ptr, value_ptr + size,
[] __device__(device_value_type value) { return is_nonzero(value); });
if (nnz < size) {
using tuple_type =
thrust::tuple<IndexType, IndexType, device_value_type>;
Expand All @@ -53,8 +58,7 @@ void remove_zeros(std::shared_ptr<const DefaultExecutor> exec,
as_device_type(new_values.get_data())));
thrust::copy_if(thrust_policy(exec), it, it + size, out_it,
[] __device__(tuple_type entry) {
return thrust::get<2>(entry) !=
zero(thrust::get<2>(entry));
return is_nonzero(thrust::get<2>(entry));
});
// swap out storage
values = std::move(new_values);
Expand Down
2 changes: 1 addition & 1 deletion core/matrix/permutation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ void dispatch_dense(const LinOp* op, Functor fn)
using matrix::Dense;
using std::complex;
run<Dense,
#ifdef GINKGO_ENABLE_HALF
#if GINKGO_ENABLE_HALF
gko::half, std::complex<gko::half>,
#endif
double, float, std::complex<double>, std::complex<float>>(op, fn);
Expand Down
3 changes: 2 additions & 1 deletion core/multigrid/pgm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ Pgm<ValueType, IndexType>::generate_local(
auto abs_mtx = local_matrix->compute_absolute();
// abs_mtx is already real valuetype, so transpose is enough
auto weight_mtx = gko::as<weight_csr_type>(abs_mtx->transpose());
auto half_scalar = initialize<matrix::Dense<real_type>>({half(0.5)}, exec);
auto half_scalar =
initialize<matrix::Dense<real_type>>({real_type{0.5}}, exec);
auto identity = matrix::Identity<real_type>::create(exec, num_rows);
// W = (abs_mtx + transpose(abs_mtx))/2
abs_mtx->apply(half_scalar, identity, half_scalar, weight_mtx);
Expand Down
14 changes: 5 additions & 9 deletions core/reorder/mc64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ void initialize_weights(const matrix::Csr<ValueType, IndexType>* host_mtx,
array<remove_complex<ValueType>>& row_maxima_array,
gko::experimental::reorder::mc64_strategy strategy)
{
auto inf = static_cast<remove_complex<ValueType>>(
std::numeric_limits<remove_complex<ValueType>>::infinity());
const auto inf = std::numeric_limits<remove_complex<ValueType>>::infinity();
const auto num_rows = host_mtx->get_size()[0];
const auto row_ptrs = host_mtx->get_const_row_ptrs();
const auto col_idxs = host_mtx->get_const_col_idxs();
Expand All @@ -50,7 +49,7 @@ void initialize_weights(const matrix::Csr<ValueType, IndexType>* host_mtx,
for (IndexType row = 0; row < num_rows; row++) {
const auto row_begin = row_ptrs[row];
const auto row_end = row_ptrs[row + 1];
auto row_max = static_cast<remove_complex<ValueType>>(-inf);
auto row_max = -inf;
for (IndexType idx = row_begin; idx < row_end; idx++) {
const auto weight = calculate_weight(values[idx]);
weights[idx] = weight;
Expand Down Expand Up @@ -181,8 +180,7 @@ void shortest_augmenting_path(
addressable_priority_queue<ValueType, IndexType>& queue,
std::vector<IndexType>& q_j, ValueType tolerance)
{
auto inf =
static_cast<ValueType>(std::numeric_limits<ValueType>::infinity());
const auto inf = std::numeric_limits<ValueType>::infinity();
auto weights = weights_array.get_data();
auto dual_u = dual_u_array.get_data();
auto distance = distance_array.get_data();
Expand Down Expand Up @@ -436,8 +434,7 @@ void compute_scaling(const matrix::Csr<ValueType, IndexType>* host_mtx,
mc64_strategy strategy, ValueType* row_scaling,
ValueType* col_scaling)
{
auto inf = static_cast<remove_complex<ValueType>>(
std::numeric_limits<remove_complex<ValueType>>::infinity());
const auto inf = std::numeric_limits<remove_complex<ValueType>>::infinity();
const auto num_rows = host_mtx->get_size()[0];
const auto weights = weights_array.get_const_data();
const auto dual_u = dual_u_array.get_const_data();
Expand Down Expand Up @@ -541,8 +538,7 @@ std::unique_ptr<LinOp> Mc64<ValueType, IndexType>::generate_impl(
marked_cols.fill(0);
matched_idxs.fill(0);
unmatched_rows.fill(0);
auto inf = static_cast<remove_complex<ValueType>>(
std::numeric_limits<remove_complex<ValueType>>::infinity());
const auto inf = std::numeric_limits<remove_complex<ValueType>>::infinity();
dual_u.fill(inf);
distance.fill(inf);

Expand Down
12 changes: 4 additions & 8 deletions core/test/base/extended_float.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ TEST_F(FloatToHalf, ConvertsNan)
{
half x = create_from_bits("0" "11111111" "00000000000000000000001");

#if defined(SYCL_LANGUAGE_VERSION) && \
(__LIBSYCL_MAJOR_VERSION > 5 || (__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7))
#if defined(SYCL_LANGUAGE_VERSION)
// Sycl put the 1000000000, but ours put mask
ASSERT_EQ(get_bits(x), get_bits("0" "11111" "1000000000"));
#else
Expand All @@ -126,8 +125,7 @@ TEST_F(FloatToHalf, ConvertsNegNan)
{
half x = create_from_bits("1" "11111111" "00010000000000000000000");

#if defined(SYCL_LANGUAGE_VERSION) && \
(__LIBSYCL_MAJOR_VERSION > 5 || (__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7))
#if defined(SYCL_LANGUAGE_VERSION)
// Sycl put the 1000000000, but ours put mask
ASSERT_EQ(get_bits(x), get_bits("1" "11111" "1000000000"));
#else
Expand Down Expand Up @@ -254,8 +252,7 @@ TEST_F(HalfToFloat, ConvertsNan)
{
float x = create_from_bits("0" "11111" "0001001000");

#if defined(SYCL_LANGUAGE_VERSION) && \
(__LIBSYCL_MAJOR_VERSION > 5 || (__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7))
#if defined(SYCL_LANGUAGE_VERSION)
// sycl keeps significand
ASSERT_EQ(get_bits(x), get_bits("0" "11111111" "00010010000000000000000"));
#else
Expand All @@ -268,8 +265,7 @@ TEST_F(HalfToFloat, ConvertsNegNan)
{
float x = create_from_bits("1" "11111" "0000000001");

#if defined(SYCL_LANGUAGE_VERSION) && \
(__LIBSYCL_MAJOR_VERSION > 5 || (__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7))
#if defined(SYCL_LANGUAGE_VERSION)
// sycl keeps significand
ASSERT_EQ(get_bits(x), get_bits("1" "11111111" "00000000010000000000000"));
#else
Expand Down
59 changes: 27 additions & 32 deletions include/ginkgo/core/base/exception.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,21 @@


namespace gko {


/**
* The Error class is used to report exceptional behaviour in library
* functions. Ginkgo uses C++ exception mechanism to this end, and the
* Error class represents a base class for all types of errors. The exact
* list of errors which could occur during the execution of a certain
* library routine is provided in the documentation of that routine, along
* with a short description of the situation when that error can occur.
* During runtime, these errors can be detected by using standard C++
* try-catch blocks, and a human-readable error description can be obtained
* by calling the Error::what() method.
* Error class represents a base class for all types of errors. The exact list
* of errors which could occur during the execution of a certain library
* routine is provided in the documentation of that routine, along with a short
* description of the situation when that error can occur.
* During runtime, these errors can be detected by using standard C++ try-catch
* blocks, and a human-readable error description can be obtained by calling
* the Error::what() method.
*
* As an example, trying to compute a matrix-vector product with arguments
* of incompatible size will result in a DimensionMismatch error, which is
* As an example, trying to compute a matrix-vector product with arguments of
* incompatible size will result in a DimensionMismatch error, which is
* demonstrated in the following program.
*
* ```cpp
Expand Down Expand Up @@ -66,8 +68,8 @@ class Error : public std::exception {
{}

/**
* Returns a human-readable string with a more detailed description of
* the error.
* Returns a human-readable string with a more detailed description of the
* error.
*/
virtual const char* what() const noexcept override { return what_.c_str(); }

Expand Down Expand Up @@ -96,8 +98,8 @@ class NotImplemented : public Error {


/**
* NotCompiled is thrown when attempting to call an operation which is a
* part of a module that was not compiled on the system.
* NotCompiled is thrown when attempting to call an operation which is a part of
* a module that was not compiled on the system.
*/
class NotCompiled : public Error {
public:
Expand Down Expand Up @@ -234,8 +236,7 @@ class CurandError : public Error {


/**
* CusparseError is thrown when a cuSPARSE routine throws a non-zero error
* code.
* CusparseError is thrown when a cuSPARSE routine throws a non-zero error code.
*/
class CusparseError : public Error {
public:
Expand Down Expand Up @@ -304,8 +305,7 @@ class HipError : public Error {


/**
* HipblasError is thrown when a hipBLAS routine throws a non-zero error
* code.
* HipblasError is thrown when a hipBLAS routine throws a non-zero error code.
*/
class HipblasError : public Error {
public:
Expand All @@ -328,8 +328,7 @@ class HipblasError : public Error {


/**
* HiprandError is thrown when a hipRAND routine throws a non-zero error
* code.
* HiprandError is thrown when a hipRAND routine throws a non-zero error code.
*/
class HiprandError : public Error {
public:
Expand Down Expand Up @@ -436,8 +435,7 @@ class DimensionMismatch : public Error {
* @param second_name The name of the second operator
* @param second_rows The output dimension of the second operator
* @param second_cols The input dimension of the second operator
* @param clarification An additional message describing the error
* further
* @param clarification An additional message describing the error further
*/
DimensionMismatch(const std::string& file, int line,
const std::string& func, const std::string& first_name,
Expand Down Expand Up @@ -469,8 +467,7 @@ class BadDimension : public Error {
* @param op_name The name of the operator
* @param op_num_rows The row dimension of the operator
* @param op_num_cols The column dimension of the operator
* @param clarification An additional message further describing the
* error
* @param clarification An additional message further describing the error
*/
BadDimension(const std::string& file, int line, const std::string& func,
const std::string& op_name, size_type op_num_rows,
Expand All @@ -486,8 +483,8 @@ class BadDimension : public Error {
/**
* Error that denotes issues between block sizes and matrix dimensions
*
* \tparam IndexType Type of index used by the linear algebra object that
* is incompatible with the required block size.
* \tparam IndexType Type of index used by the linear algebra object that is
* incompatible with the required block size.
*/
template <typename IndexType>
class BlockSizeError : public Error {
Expand Down Expand Up @@ -520,8 +517,7 @@ class ValueMismatch : public Error {
* @param func The function name where the error occurred
* @param val1 The first value to be compared.
* @param val2 The second value to be compared.
* @param clarification An additional message further describing the
* error
* @param clarification An additional message further describing the error
*/
ValueMismatch(const std::string& file, int line, const std::string& func,
size_type val1, size_type val2,
Expand Down Expand Up @@ -580,9 +576,8 @@ class OutOfBoundsError : public Error {


/**
* OverflowError is thrown when an index calculation for storage
* requirements overflows. This most likely means that the index type is too
* small.
* OverflowError is thrown when an index calculation for storage requirements
* overflows. This most likely means that the index type is too small.
*/
class OverflowError : public Error {
public:
Expand Down Expand Up @@ -619,8 +614,8 @@ class StreamError : public Error {


/**
* KernelNotFound is thrown if Ginkgo cannot find a kernel which satisfies
* the criteria imposed by the input arguments.
* KernelNotFound is thrown if Ginkgo cannot find a kernel which satisfies the
* criteria imposed by the input arguments.
*/
class KernelNotFound : public Error {
public:
Expand Down
3 changes: 3 additions & 0 deletions include/ginkgo/core/base/half.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,13 +534,15 @@ class complex<gko::half> {
imag_ += val.imag();
return *this;
}

template <typename T>
complex& operator-=(const complex<T>& val)
{
real_ -= val.real();
imag_ -= val.imag();
return *this;
}

template <typename T>
complex& operator*=(const complex<T>& val)
{
Expand All @@ -551,6 +553,7 @@ class complex<gko::half> {
imag_ = result_f.imag();
return *this;
}

template <typename T>
complex& operator/=(const complex<T>& val)
{
Expand Down
7 changes: 7 additions & 0 deletions include/ginkgo/core/base/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,13 @@ GKO_INLINE constexpr T one()
return T(1);
}

template <>
GKO_INLINE constexpr half one<half>()
{
constexpr auto bits = static_cast<uint16>(0b0'01111'0000000000u);
return half::create_from_bits(bits);
}


/**
* Returns the multiplicative identity for T.
Expand Down
6 changes: 4 additions & 2 deletions include/ginkgo/core/base/matrix_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,17 @@ template <typename ValueType, typename Distribution, typename Generator>
typename std::enable_if<!is_complex_s<ValueType>::value, ValueType>::type
get_rand_value(Distribution&& dist, Generator&& gen)
{
return ValueType(dist(gen));
return static_cast<ValueType>(dist(gen));
}


template <typename ValueType, typename Distribution, typename Generator>
typename std::enable_if<is_complex_s<ValueType>::value, ValueType>::type
get_rand_value(Distribution&& dist, Generator&& gen)
{
return ValueType(dist(gen), dist(gen));
using real_value_type = remove_complex<ValueType>;
return ValueType{static_cast<real_value_type>(dist(gen)),
static_cast<real_value_type>(dist(gen))};
}


Expand Down
7 changes: 1 addition & 6 deletions include/ginkgo/core/base/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,8 @@ using uint64 = std::uint64_t;
*/
using uintptr = std::uintptr_t;

// #if defined(SYCL_LANGUAGE_VERSION) && \
// (__LIBSYCL_MAJOR_VERSION > 5 || \
// (__LIBSYCL_MAJOR_VERSION == 5 && __LIBSYCL_MINOR_VERSION >= 7))
// using half = sycl::half;
// #else

class half;
// #endif


/**
Expand Down
Loading

0 comments on commit af5b228

Please sign in to comment.