From 502df54115da99d205e807d7dcaf2003d3813f45 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Wed, 23 Oct 2024 18:02:36 +0200 Subject: [PATCH 01/16] instantiation/testing/next/prev/stub type definition --- core/base/mixed_precision_types.hpp | 151 +++++++++++++++++++++++ core/device_hooks/common_kernels.inc.cpp | 63 +++++++++- core/test/utils.hpp | 48 ++++++- include/ginkgo/core/base/math.hpp | 45 +++++++ include/ginkgo/core/base/types.hpp | 116 +++++++++++++++++ 5 files changed, 418 insertions(+), 5 deletions(-) diff --git a/core/base/mixed_precision_types.hpp b/core/base/mixed_precision_types.hpp index d9747e5cad8..5ef5de94e34 100644 --- a/core/base/mixed_precision_types.hpp +++ b/core/base/mixed_precision_types.hpp @@ -7,23 +7,44 @@ #include +#include #include #ifdef GINKGO_MIXED_PRECISION + #define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1(_macro, ...) \ template _macro(float, float, float, __VA_ARGS__); \ template _macro(float, float, double, __VA_ARGS__); \ template _macro(float, double, float, __VA_ARGS__); \ template _macro(float, double, double, __VA_ARGS__) +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1_WITH_HALF(_macro, \ + ...) \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1(_macro, __VA_ARGS__); \ + GKO_ADAPT_HF(template _macro(float, half, half, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(float, half, float, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(float, half, double, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(float, float, half, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(float, double, half, __VA_ARGS__)) + #define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2(_macro, ...) \ template _macro(double, float, float, __VA_ARGS__); \ template _macro(double, float, double, __VA_ARGS__); \ template _macro(double, double, float, __VA_ARGS__); \ template _macro(double, double, double, __VA_ARGS__) +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2_WITH_HALF(_macro, \ + ...) 
\ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2(_macro, __VA_ARGS__); \ + GKO_ADAPT_HF(template _macro(double, half, half, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(double, half, float, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(double, half, double, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(double, float, half, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(double, double, half, __VA_ARGS__)) + + #define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3(_macro, ...) \ template _macro(std::complex, std::complex, \ std::complex, __VA_ARGS__); \ @@ -33,6 +54,19 @@ std::complex, __VA_ARGS__); \ template _macro(std::complex, std::complex, \ std::complex, __VA_ARGS__) +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3_WITH_HALF(_macro, \ + ...) \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3(_macro, __VA_ARGS__); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)) #define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4(_macro, ...) \ template _macro(std::complex, std::complex, \ @@ -44,22 +78,95 @@ template _macro(std::complex, std::complex, \ std::complex, __VA_ARGS__) +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4_WITH_HALF(_macro, \ + ...) 
\ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4(_macro, __VA_ARGS__); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)) + +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT5_WITH_HALF(_macro, \ + ...) \ + GKO_ADAPT_HF(template _macro(half, half, half, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(half, half, float, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(half, half, double, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(half, float, half, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(half, float, float, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(half, float, double, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(half, double, half, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(half, double, float, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(half, double, double, __VA_ARGS__)) + +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT6_WITH_HALF(_macro, \ + ...) 
\ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)) + #else #define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1(_macro, ...) \ template _macro(float, float, float, __VA_ARGS__) +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1_WITH_HALF(_macro, \ + ...) \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1(_macro, __VA_ARGS__) + #define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2(_macro, ...) \ template _macro(double, double, double, __VA_ARGS__) +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2_WITH_HALF(_macro, \ + ...) \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2(_macro, __VA_ARGS__) + #define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3(_macro, ...) \ template _macro(std::complex, std::complex, \ std::complex, __VA_ARGS__) +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3_WITH_HALF(_macro, \ + ...) \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3(_macro, __VA_ARGS__) + #define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4(_macro, ...) \ template _macro(std::complex, std::complex, \ std::complex, __VA_ARGS__) +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4_WITH_HALF(_macro, \ + ...) 
\ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4(_macro, __VA_ARGS__) + +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT5_WITH_HALF(_macro, \ + ...) \ + GKO_ADAPT_HF(template _macro(half, half, half, __VA_ARGS__)) + +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT6_WITH_HALF(_macro, \ + ...) \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + std::complex, __VA_ARGS__)) + + #endif @@ -69,11 +176,27 @@ GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3(_macro, __VA_ARGS__); \ GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4(_macro, __VA_ARGS__) +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_WITH_HALF(_macro, ...) \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1_WITH_HALF(_macro, \ + __VA_ARGS__); \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2_WITH_HALF(_macro, \ + __VA_ARGS__); \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3_WITH_HALF(_macro, \ + __VA_ARGS__); \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4_WITH_HALF(_macro, \ + __VA_ARGS__); \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT5_WITH_HALF(_macro, \ + __VA_ARGS__); \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT6_WITH_HALF(_macro, \ + __VA_ARGS__) #define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE(_macro) \ GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE(_macro, int32); \ GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE(_macro, int64) +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF(_macro) \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_WITH_HALF(_macro, int32); \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_WITH_HALF(_macro, int64) #ifdef GINKGO_MIXED_PRECISION #define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2(_macro, ...) \ @@ -85,12 +208,36 @@ template _macro(std::complex, std::complex, __VA_ARGS__); \ template _macro(std::complex, std::complex, __VA_ARGS__); \ template _macro(std::complex, std::complex, __VA_ARGS__) + +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2_WITH_HALF(_macro, ...) 
\ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2(_macro, __VA_ARGS__); \ + GKO_ADAPT_HF(template _macro(half, half, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(half, float, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(half, double, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(float, half, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(double, half, __VA_ARGS__)); \ + GKO_ADAPT_HF( \ + template _macro(std::complex, std::complex, __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + __VA_ARGS__)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex, \ + __VA_ARGS__)) #else #define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2(_macro, ...) \ template _macro(float, float, __VA_ARGS__); \ template _macro(double, double, __VA_ARGS__); \ template _macro(std::complex, std::complex, __VA_ARGS__); \ template _macro(std::complex, std::complex, __VA_ARGS__) + +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2_WITH_HALF(_macro, ...) 
\ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2(_macro, __VA_ARGS__); \ + GKO_ADAPT_HF(template _macro(half, half, __VA_ARGS__)); \ + GKO_ADAPT_HF( \ + template _macro(std::complex, std::complex, __VA_ARGS__)) #endif @@ -98,5 +245,9 @@ GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2(_macro, int32); \ GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2(_macro, int64) +#define GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2_WITH_HALF( \ + _macro) \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2_WITH_HALF(_macro, int32); \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_2_WITH_HALF(_macro, int64) #endif // GKO_CORE_BASE_MIXED_PRECISION_TYPES_HPP_ diff --git a/core/device_hooks/common_kernels.inc.cpp b/core/device_hooks/common_kernels.inc.cpp index 98d85b2b6d2..6ffeb1c5f71 100644 --- a/core/device_hooks/common_kernels.inc.cpp +++ b/core/device_hooks/common_kernels.inc.cpp @@ -79,26 +79,37 @@ #define GKO_STUB(_macro) _macro GKO_NOT_COMPILED(GKO_HOOK_MODULE) -#define GKO_STUB_VALUE_CONVERSION(_macro) \ - template \ - _macro(SourceType, TargetType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ - GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION(_macro) #define GKO_STUB_NON_COMPLEX_VALUE_TYPE(_macro) \ template \ _macro(ValueType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE(_macro) +#define GKO_STUB_NON_COMPLEX_VALUE_TYPE_WITH_HALF(_macro) \ + template \ + _macro(ValueType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ + GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE_WITH_HALF(_macro) + #define GKO_STUB_VALUE_TYPE(_macro) \ template \ _macro(ValueType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(_macro) +#define GKO_STUB_VALUE_TYPE_WITH_HALF(_macro) \ + template \ + _macro(ValueType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ + GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(_macro) + #define GKO_STUB_VALUE_AND_SCALAR_TYPE(_macro) \ template \ _macro(ValueType, ScalarType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ 
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE(_macro) +#define GKO_STUB_VALUE_AND_SCALAR_TYPE_WITH_HALF(_macro) \ + template \ + _macro(ValueType, ScalarType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ + GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE_WITH_HALF(_macro) + #define GKO_STUB_INDEX_TYPE(_macro) \ template \ _macro(IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ @@ -114,16 +125,31 @@ _macro(ValueType, IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_AND_INDEX_TYPE(_macro) +#define GKO_STUB_NON_COMPLEX_VALUE_AND_INDEX_TYPE_WITH_HALF(_macro) \ + template \ + _macro(ValueType, IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ + GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_AND_INDEX_TYPE_WITH_HALF(_macro) + #define GKO_STUB_VALUE_AND_INDEX_TYPE(_macro) \ template \ _macro(ValueType, IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(_macro) +#define GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(_macro) \ + template \ + _macro(ValueType, IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ + GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(_macro) + #define GKO_STUB_VALUE_AND_INT32_TYPE(_macro) \ template \ _macro(ValueType, IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(_macro) +#define GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(_macro) \ + template \ + _macro(ValueType, IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ + GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(_macro) + #define GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE(_macro) \ template \ @@ -131,6 +157,13 @@ GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE(_macro) +#define GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF(_macro) \ + template \ + _macro(InputValueType, MatrixValueType, OutputValueType, IndexType) \ + GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF(_macro) + #define 
GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_2(_macro) \ template \ @@ -138,6 +171,13 @@ GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2(_macro) +#define GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_2_WITH_HALF(_macro) \ + template \ + _macro(InputValueType, OutputValueType, IndexType) \ + GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ + GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2_WITH_HALF(_macro) + #define GKO_STUB_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(_macro) \ template \ @@ -150,16 +190,31 @@ _macro(IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(_macro) +#define GKO_STUB_TEMPLATE_TYPE_WITH_HALF(_macro) \ + template \ + _macro(IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ + GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE_WITH_HALF(_macro) + #define GKO_STUB_VALUE_CONVERSION(_macro) \ template \ _macro(SourceType, TargetType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION(_macro) +#define GKO_STUB_VALUE_CONVERSION_WITH_HALF(_macro) \ + template \ + _macro(SourceType, TargetType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ + GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION_WITH_HALF(_macro) + #define GKO_STUB_VALUE_CONVERSION_OR_COPY(_macro) \ template \ _macro(SourceType, TargetType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION_OR_COPY(_macro) +#define GKO_STUB_VALUE_CONVERSION_OR_COPY_WITH_HALF(_macro) \ + template \ + _macro(SourceType, TargetType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ + GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION_OR_COPY_WITH_HALF(_macro) + #define GKO_STUB_CB_GMRES(_macro) \ template \ _macro(ValueType, ValueTypeKrylovBases) GKO_NOT_COMPILED(GKO_HOOK_MODULE); \ diff --git a/core/test/utils.hpp b/core/test/utils.hpp index eee2900d731..ab9326400e0 100644 --- a/core/test/utils.hpp +++ b/core/test/utils.hpp @@ -15,6 +15,7 @@ #include +#include #include #include #include @@ -327,10 +328,25 @@ using RealValueTypes = ::testing::Types; #endif +using 
RealValueTypesWithHalf = ::testing::Types< +#if GINKGO_ENABLE_HALF + gko::half, +#endif +#if !GINKGO_DPCPP_SINGLE_MODE + double, +#endif + float>; + using ComplexValueTypes = add_inner_wrapper_t; +using ComplexValueTypesWithHalf = + add_inner_wrapper_t; + using ValueTypes = merge_type_list_t; +using ValueTypesWithHalf = + merge_type_list_t; + using IndexTypes = ::testing::Types; using IntegerTypes = merge_type_list_t>; @@ -341,22 +357,44 @@ using LocalGlobalIndexTypes = using PODTypes = merge_type_list_t; +using PODTypesWithHalf = + merge_type_list_t; + using ComplexAndPODTypes = merge_type_list_t; +using ComplexAndPODTypesWithHalf = + merge_type_list_t; + using ValueIndexTypes = cartesian_type_product_t; +using ValueIndexTypesWithHalf = + cartesian_type_product_t; + using RealValueIndexTypes = cartesian_type_product_t; +using RealValueIndexTypesWithHalf = + cartesian_type_product_t; + using ComplexValueIndexTypes = cartesian_type_product_t; +using ComplexValueIndexTypesWithHalf = + cartesian_type_product_t; + using TwoValueIndexType = add_to_cartesian_type_product_t< merge_type_list_t< cartesian_type_product_t, cartesian_type_product_t>, IndexTypes>; +using TwoValueIndexTypeWithHalf = add_to_cartesian_type_product_t< + merge_type_list_t, + cartesian_type_product_t>, + IndexTypes>; + using ValueLocalGlobalIndexTypes = add_to_cartesian_type_product_left_t; @@ -365,7 +403,6 @@ template struct reduction_factor { using nc_output = remove_complex; using nc_precision = remove_complex; - static const nc_output value; }; @@ -456,4 +493,13 @@ struct TupleTypenameNameGenerator { }; +#define SKIP_IF_HALF(type) \ + if (std::is_same, gko::half>::value) { \ + GTEST_SKIP() << "Skip due to half mode"; \ + } \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") + + #endif // GKO_CORE_TEST_UTILS_HPP_ diff --git a/include/ginkgo/core/base/math.hpp b/include/ginkgo/core/base/math.hpp index 5e15bb05d6a..73da407194e 100644 --- 
a/include/ginkgo/core/base/math.hpp +++ b/include/ginkgo/core/base/math.hpp @@ -383,6 +383,31 @@ struct next_precision_impl> { }; +template +struct next_precision_with_half_impl {}; + + +template <> +struct next_precision_with_half_impl { + using type = float; +}; + +template <> +struct next_precision_with_half_impl { + using type = double; +}; + +template <> +struct next_precision_with_half_impl { + using type = gko::half; +}; + +template +struct next_precision_with_half_impl> { + using type = std::complex::type>; +}; + + template struct reduce_precision_impl { using type = T; @@ -477,6 +502,26 @@ using next_precision = typename detail::next_precision_impl::type; template using previous_precision = next_precision; +/** + * Obtains the next type in the singly-linked precision list with half. + */ +#if GINKGO_ENABLE_HALF +template +using next_precision_with_half = + typename detail::next_precision_with_half_impl::type; + +template +using previous_precision_with_half = + next_precision_with_half>; +#else +// fallback to float/double list +template +using next_precision_with_half = next_precision; + +template +using previous_precision_with_half = previous_precision; +#endif + /** * Obtains the next type in the hierarchy with lower precision than T. 
diff --git a/include/ginkgo/core/base/types.hpp b/include/ginkgo/core/base/types.hpp index 1d5963c0fe8..5e1fb2a14e3 100644 --- a/include/ginkgo/core/base/types.hpp +++ b/include/ginkgo/core/base/types.hpp @@ -17,6 +17,7 @@ #include #include +#include #include @@ -399,6 +400,17 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, _enable_macro(CudaExecutor, cuda) +// cuda half operation is supported from arch 5.3 +#if GINKGO_ENABLE_HALF && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530) +#define GKO_ADAPT_HF(_macro) _macro +#else +#define GKO_ADAPT_HF(_macro) \ + static_assert(true, \ + "This assert is used to counter the false positive extra " \ + "semi-colon warnings") +#endif + + /** * Instantiates a template for each non-complex value type compiled by Ginkgo. * @@ -418,6 +430,10 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, template _macro(double) #endif +#define GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE_WITH_HALF(_macro) \ + GKO_ADAPT_HF(template _macro(half)); \ + GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE(_macro) + /** * Instantiates a template for each value type compiled by Ginkgo. @@ -440,6 +456,11 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, template _macro(std::complex) #endif +#define GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(_macro) \ + GKO_ADAPT_HF(template _macro(half)); \ + GKO_ADAPT_HF(template _macro(std::complex)); \ + GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(_macro) + // Helper macro to make Windows builds work // In MSVC, __VA_ARGS__ behave like one argument by default. 
@@ -528,6 +549,12 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, template _macro(std::complex, double) #endif +#define GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE_WITH_HALF(_macro) \ + GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE(_macro); \ + GKO_ADAPT_HF(template _macro(half, half)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex)); \ + GKO_ADAPT_HF(template _macro(std::complex, half)) + /** * Instantiates a template for each index type compiled by Ginkgo. @@ -566,6 +593,11 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, template _macro(float, int64); \ template _macro(double, int64) #endif +#define GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_AND_INDEX_TYPE_WITH_HALF( \ + _macro) \ + GKO_ADAPT_HF(template _macro(half, int32)); \ + GKO_ADAPT_HF(template _macro(half, int64)); \ + GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_AND_INDEX_TYPE(_macro) #if GINKGO_DPCPP_SINGLE_MODE #define GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(_macro) \ @@ -583,6 +615,11 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, template _macro(std::complex, int32) #endif +#define GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(_macro) \ + GKO_ADAPT_HF(template _macro(half, int32)); \ + GKO_ADAPT_HF(template _macro(std::complex, int32)); \ + GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(_macro) + /** * Instantiates a template for each value and index type compiled by Ginkgo. 
@@ -610,6 +647,13 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, template _macro(std::complex, int64) #endif +#define GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(_macro) \ + GKO_ADAPT_HF(template _macro(half, int32)); \ + GKO_ADAPT_HF(template _macro(half, int64)); \ + GKO_ADAPT_HF(template _macro(std::complex, int32)); \ + GKO_ADAPT_HF(template _macro(std::complex, int64)); \ + GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(_macro) + /** * Instantiates a template for each non-complex value, local and global index @@ -643,6 +687,14 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, template _macro(double, int64, int64) #endif +#define GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_WITH_HALF( \ + _macro) \ + GKO_ADAPT_HF(template _macro(half, int32, int32)); \ + GKO_ADAPT_HF(template _macro(half, int32, int64)); \ + GKO_ADAPT_HF(template _macro(half, int64, int64)); \ + GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE( \ + _macro) + /** * Instantiates a template for each value and index type compiled by Ginkgo. 
@@ -677,6 +729,16 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, template _macro(std::complex, int64, int64) #endif +#define GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_WITH_HALF( \ + _macro) \ + GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(_macro); \ + GKO_ADAPT_HF(template _macro(half, int32, int32)); \ + GKO_ADAPT_HF(template _macro(half, int32, int64)); \ + GKO_ADAPT_HF(template _macro(half, int64, int64)); \ + GKO_ADAPT_HF(template _macro(std::complex, int32, int32)); \ + GKO_ADAPT_HF(template _macro(std::complex, int32, int64)); \ + GKO_ADAPT_HF(template _macro(std::complex, int64, int64)) + #if GINKGO_DPCPP_SINGLE_MODE #define GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION(_macro) \ @@ -732,6 +794,40 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, template _macro(std::complex, std::complex) #endif +#if GINKGO_DPCPP_SINGLE_MODE +#define GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION_WITH_HALF(_macro) \ + GKO_ADAPT_HF(template <> _macro(half, double) GKO_NOT_IMPLEMENTED); \ + GKO_ADAPT_HF(template <> _macro(double, half) GKO_NOT_IMPLEMENTED); \ + GKO_ADAPT_HF(template _macro(float, half)); \ + GKO_ADAPT_HF(template _macro(half, float)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex)); \ + GKO_ADAPT_HF(template <> _macro(std::complex, std::complex) \ + GKO_NOT_IMPLEMENTED); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex)); \ + GKO_ADAPT_HF(template <> _macro(std::complex, std::complex) \ + GKO_NOT_IMPLEMENTED); \ + GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION(_macro) +#else +#define GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION_WITH_HALF(_macro) \ + GKO_ADAPT_HF(template _macro(half, double)); \ + GKO_ADAPT_HF(template _macro(double, half)); \ + GKO_ADAPT_HF(template _macro(float, half)); \ + GKO_ADAPT_HF(template _macro(half, float)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex)); \ + GKO_ADAPT_HF(template 
_macro(std::complex, std::complex)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex)); \ + GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION(_macro) +#endif + +#define GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION_OR_COPY_WITH_HALF(_macro) \ + GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION_WITH_HALF(_macro); \ + GKO_ADAPT_HF(template _macro(half, half)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex)); \ + template _macro(float, float); \ + template _macro(double, double); \ + template _macro(std::complex, std::complex); \ + template _macro(std::complex, std::complex) /** * Instantiates a template for each value type pair compiled by Ginkgo. @@ -749,6 +845,11 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, template _macro(std::complex, std::complex); \ template _macro(std::complex, std::complex) +#define GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_PAIR_WITH_HALF(_macro) \ + GKO_ADAPT_HF(template _macro(half, half)); \ + GKO_ADAPT_HF(template _macro(std::complex, half)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex)); \ + GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_PAIR(_macro) /** * Instantiates a template for each combined value and index type compiled by @@ -771,6 +872,12 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, template _macro(std::complex, std::complex); \ template _macro(std::complex, std::complex) +#define GKO_INSTANTIATE_FOR_EACH_COMBINED_VALUE_AND_INDEX_TYPE_WITH_HALF( \ + _macro) \ + GKO_ADAPT_HF(template _macro(half, half)); \ + GKO_ADAPT_HF(template _macro(std::complex, std::complex)); \ + GKO_INSTANTIATE_FOR_EACH_COMBINED_VALUE_AND_INDEX_TYPE(_macro) + /** * Instantiates a template for each value and index type compiled by Ginkgo. 
* @@ -789,6 +896,10 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, template _macro(int32); \ template _macro(int64) +#define GKO_INSTANTIATE_FOR_EACH_POD_TYPE_WITH_HALF(_macro) \ + GKO_ADAPT_HF(template _macro(half)); \ + GKO_ADAPT_HF(template _macro(std::complex)); \ + GKO_INSTANTIATE_FOR_EACH_POD_TYPE(_macro) /** * Instantiates a template for each normal type @@ -803,6 +914,11 @@ GKO_ATTRIBUTES constexpr bool operator!=(precision_reduction x, GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(_macro); \ template _macro(gko::size_type) +#define GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE_WITH_HALF(_macro) \ + GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(_macro); \ + GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(_macro); \ + template _macro(gko::size_type) + /** * Instantiates a template for int32 type. From c53ab57f1dcca8a9eb3501e07c0a08bd0b5ecc2a Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Wed, 23 Oct 2024 18:25:12 +0200 Subject: [PATCH 02/16] half option --- CMakeLists.txt | 6 ++++++ cmake/get_info.cmake | 2 +- include/ginkgo/config.hpp.in | 5 +++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4cad0e1bca4..c48d12989aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,6 +32,12 @@ option(GINKGO_BUILD_DOC "Generate documentation" OFF) option(GINKGO_FAST_TESTS "Reduces the input size for a few tests known to be time-intensive" OFF) option(GINKGO_TEST_NONDEFAULT_STREAM "Uses non-default streams in CUDA and HIP tests" OFF) option(GINKGO_MIXED_PRECISION "Instantiate true mixed-precision kernels (otherwise they will be conversion-based using implicit temporary storage)" OFF) +option(GINKGO_ENABLE_HALF "Enable the use of half precision" ON) +# We do not support MSVC. 
SYCL will come later +if(MSVC OR GINKGO_BUILD_SYCL) + message(STATUS "HALF is not supported in MSVC, and later support in SYCL") + set(GINKGO_ENABLE_HALF OFF CACHE BOOL "Enable the use of half precision" FORCE) +endif() option(GINKGO_SKIP_DEPENDENCY_UPDATE "Do not update dependencies each time the project is rebuilt" ON) option(GINKGO_WITH_CLANG_TIDY "Make Ginkgo call `clang-tidy` to find programming issues." OFF) diff --git a/cmake/get_info.cmake b/cmake/get_info.cmake index 63f43c645f0..57816ab8008 100644 --- a/cmake/get_info.cmake +++ b/cmake/get_info.cmake @@ -130,7 +130,7 @@ foreach(log_type ${log_types}) "GINKGO_BUILD_OMP;GINKGO_BUILD_MPI;GINKGO_BUILD_REFERENCE;GINKGO_BUILD_CUDA;GINKGO_BUILD_HIP;GINKGO_BUILD_SYCL") ginkgo_print_module_footer(${${log_type}} " Enabled features:") ginkgo_print_foreach_variable(${${log_type}} - "GINKGO_MIXED_PRECISION;GINKGO_HAVE_GPU_AWARE_MPI") + "GINKGO_MIXED_PRECISION;GINKGO_HAVE_GPU_AWARE_MPI;GINKGO_ENABLE_HALF") ginkgo_print_module_footer(${${log_type}} " Tests, benchmarks and examples:") ginkgo_print_foreach_variable(${${log_type}} "GINKGO_BUILD_TESTS;GINKGO_FAST_TESTS;GINKGO_BUILD_EXAMPLES;GINKGO_EXTLIB_EXAMPLE;GINKGO_BUILD_BENCHMARKS;GINKGO_BENCHMARK_ENABLE_TUNING") diff --git a/include/ginkgo/config.hpp.in b/include/ginkgo/config.hpp.in index 1dfa6bc61bc..cf25dcd3c77 100644 --- a/include/ginkgo/config.hpp.in +++ b/include/ginkgo/config.hpp.in @@ -105,6 +105,11 @@ #define GKO_HAVE_HWLOC @GINKGO_HAVE_HWLOC@ // clang-format on +/* Is half operation available ? */ +// clang-format off +#cmakedefine01 GINKGO_ENABLE_HALF +// clang-format on + /* Do we need to use blocking communication in our SpMV? */ // clang-format off From 6f14b13cb625a6677698fbbadd202115264e2ddc Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. 
Tsai" Date: Wed, 23 Oct 2024 22:39:38 +0200 Subject: [PATCH 03/16] device type mapping --- common/cuda_hip/base/math.hpp | 124 +++++++++++++++++++++++++++++---- common/cuda_hip/base/types.hpp | 14 ++++ cuda/base/types.hpp | 1 - hip/base/types.hip.hpp | 1 - 4 files changed, 126 insertions(+), 14 deletions(-) diff --git a/common/cuda_hip/base/math.hpp b/common/cuda_hip/base/math.hpp index 8c655174524..7f0391d904c 100644 --- a/common/cuda_hip/base/math.hpp +++ b/common/cuda_hip/base/math.hpp @@ -11,6 +11,21 @@ #include +#ifdef GKO_COMPILING_CUDA + + +#include + + +#elif defined(GKO_COMPILING_HIP) + + +#include + + +#endif + + namespace gko { @@ -18,9 +33,35 @@ namespace gko { // __device__ function (even though it is constexpr) template struct device_numeric_limits { - static constexpr auto inf = std::numeric_limits::infinity(); - static constexpr auto max = std::numeric_limits::max(); - static constexpr auto min = std::numeric_limits::min(); + static constexpr auto inf() { return std::numeric_limits::infinity(); } + static constexpr auto max() { return std::numeric_limits::max(); } + static constexpr auto min() { return std::numeric_limits::min(); } +}; + +template <> +struct device_numeric_limits<__half> { + // from __half documentation, it accepts unsigned short + // __half and __half_raw does not have constexpr constructor + static GKO_ATTRIBUTES GKO_INLINE auto inf() + { + __half_raw bits; + bits.x = static_cast(0b0111110000000000u); + return __half{bits}; + } + + static GKO_ATTRIBUTES GKO_INLINE auto max() + { + __half_raw bits; + bits.x = static_cast(0b0111101111111111u); + return __half{bits}; + } + + static GKO_ATTRIBUTES GKO_INLINE auto min() + { + __half_raw bits; + bits.x = static_cast(0b0000010000000000u); + return __half{bits}; + } }; @@ -33,15 +74,6 @@ struct remove_complex_impl> { }; -template -struct is_complex_impl> - : public std::integral_constant {}; - - -template -struct is_complex_or_scalar_impl> : std::is_scalar {}; - - template struct 
truncate_type_impl> { using type = thrust::complex::type>; @@ -52,4 +84,72 @@ struct truncate_type_impl> { } // namespace gko +namespace thrust { + + +template <> +GKO_ATTRIBUTES GKO_INLINE complex<__half> sqrt<__half>(const complex<__half>& a) +{ + return sqrt(static_cast>(a)); +} + + +template <> +GKO_ATTRIBUTES GKO_INLINE __half abs<__half>(const complex<__half>& z) +{ + return abs(static_cast>(z)); +} + + +} // namespace thrust + + +namespace gko { + + +// It is required by NVHPC 23.3, `isnan` is undefined when NVHPC is used as a +// host compiler. +#if defined(__CUDACC__) || defined(GKO_COMPILING_HIP) + +__device__ __forceinline__ bool is_nan(const __half& val) +{ + // from the cuda_fp16.hpp +#if GINKGO_HIP_PLATFORM_HCC || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) + return __hisnan(val); +#else + return isnan(static_cast(val)); +#endif +} + +__device__ __forceinline__ bool is_nan(const thrust::complex<__half>& val) +{ + return is_nan(val.real()) || is_nan(val.imag()); +} + + +__device__ __forceinline__ __half abs(const __half& val) +{ +#if GINKGO_HIP_PLATFORM_HCC || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) + return __habs(val); +#else + return abs(static_cast(val)); +#endif +} + +__device__ __forceinline__ __half sqrt(const __half& val) +{ +#if GINKGO_HIP_PLATFORM_HCC || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) + return hsqrt(val); +#else + return sqrt(static_cast(val)); +#endif +} + + +#endif + + +} // namespace gko + + #endif // GKO_COMMON_CUDA_HIP_BASE_MATH_HPP_ diff --git a/common/cuda_hip/base/types.hpp b/common/cuda_hip/base/types.hpp index 08f0516d691..42ca57eb0bf 100644 --- a/common/cuda_hip/base/types.hpp +++ b/common/cuda_hip/base/types.hpp @@ -14,5 +14,19 @@ #error "Executor definition missing" #endif +#define THRUST_HALF_FRIEND_OPERATOR(_op, _opeq) \ + GKO_ATTRIBUTES GKO_INLINE thrust::complex<__half> operator _op( \ + const thrust::complex<__half> lhs, const thrust::complex<__half> rhs) \ + { \ + return thrust::complex{lhs} 
_op thrust::complex(rhs); \ + } + +THRUST_HALF_FRIEND_OPERATOR(+, +=) +THRUST_HALF_FRIEND_OPERATOR(-, -=) +THRUST_HALF_FRIEND_OPERATOR(*, *=) +THRUST_HALF_FRIEND_OPERATOR(/, /=) + +#undef THRUST_HALF_FRIEND_OPERATOR + #endif // GKO_COMMON_CUDA_HIP_BASE_TYPES_HPP_ diff --git a/cuda/base/types.hpp b/cuda/base/types.hpp index 05f07ceb8dd..05b604923da 100644 --- a/cuda/base/types.hpp +++ b/cuda/base/types.hpp @@ -20,7 +20,6 @@ namespace gko { - namespace kernels { namespace cuda { namespace detail { diff --git a/hip/base/types.hip.hpp b/hip/base/types.hip.hpp index c3982b7562e..6b78cceea99 100644 --- a/hip/base/types.hip.hpp +++ b/hip/base/types.hip.hpp @@ -26,7 +26,6 @@ #include "common/cuda_hip/base/runtime.hpp" - namespace gko { namespace kernels { namespace hip { From 476cf288792cfe80450c7d2454e997cf8e325673 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Thu, 28 Nov 2024 16:46:04 +0100 Subject: [PATCH 04/16] consider custom namespace for thrust::complex<__half> and benchmark --- benchmark/CMakeLists.txt | 6 ++++++ common/cuda_hip/base/math.hpp | 5 +++++ common/cuda_hip/base/thrust_macro.hpp | 22 ++++++++++++++++++++++ common/cuda_hip/base/types.hpp | 15 +++++++++------ 4 files changed, 42 insertions(+), 6 deletions(-) create mode 100644 common/cuda_hip/base/thrust_macro.hpp diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index 55ed76d1613..c780a497c32 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -18,6 +18,9 @@ function(ginkgo_benchmark_cusparse_linops type def) PRIVATE $<$:--expt-relaxed-constexpr>) endif() + if(GINKGO_CUDA_CUSTOM_THRUST_NAMESPACE) + target_compile_definitions(cusparse_linops_${type} PRIVATE THRUST_CUB_WRAPPED_NAMESPACE=gko) + endif() # make the dependency public to catch issues target_compile_definitions(cusparse_linops_${type} PUBLIC ${def}) target_compile_definitions(cusparse_linops_${type} PRIVATE GKO_COMPILING_CUDA) @@ -28,6 +31,9 @@ endfunction() function(ginkgo_benchmark_hipsparse_linops 
type def) add_library(hipsparse_linops_${type} utils/hip_linops.hip.cpp) set_source_files_properties(utils/hip_linops.hip.cpp PROPERTIES LANGUAGE HIP) + if(GINKGO_CUDA_CUSTOM_THRUST_NAMESPACE) + target_compile_definitions(hipsparse_linops_${type} PRIVATE THRUST_CUB_WRAPPED_NAMESPACE=gko) + endif() target_compile_definitions(hipsparse_linops_${type} PUBLIC ${def}) target_compile_definitions(hipsparse_linops_${type} PRIVATE GKO_COMPILING_HIP) target_include_directories(hipsparse_linops_${type} SYSTEM PRIVATE ${HIPBLAS_INCLUDE_DIRS} ${HIPSPARSE_INCLUDE_DIRS}) diff --git a/common/cuda_hip/base/math.hpp b/common/cuda_hip/base/math.hpp index 7f0391d904c..3d2975c1eee 100644 --- a/common/cuda_hip/base/math.hpp +++ b/common/cuda_hip/base/math.hpp @@ -26,6 +26,9 @@ #endif +#include "common/cuda_hip/base/thrust_macro.hpp" + + namespace gko { @@ -84,6 +87,7 @@ struct truncate_type_impl> { } // namespace gko +GKO_THRUST_NAEMSPACE_PREFIX namespace thrust { @@ -102,6 +106,7 @@ GKO_ATTRIBUTES GKO_INLINE __half abs<__half>(const complex<__half>& z) } // namespace thrust +GKO_THRUST_NAEMSPACE_POSTFIX namespace gko { diff --git a/common/cuda_hip/base/thrust_macro.hpp b/common/cuda_hip/base/thrust_macro.hpp new file mode 100644 index 00000000000..c5e3fc40010 --- /dev/null +++ b/common/cuda_hip/base/thrust_macro.hpp @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef GKO_COMMON_CUDA_HIP_BASE_THRUST_MACRO_HPP_ +#define GKO_COMMON_CUDA_HIP_BASE_THRUST_MACRO_HPP_ + +// although thrust provides the similar thing, these macro are only defined when +// they supported. 
Thus, we need to provide our own macro to make it work with +// the old version +#ifdef THRUST_CUB_WRAPPED_NAMESPACE +#define GKO_THRUST_NAEMSPACE_PREFIX namespace THRUST_CUB_WRAPPED_NAMESPACE { +#define GKO_THRUST_NAEMSPACE_POSTFIX } +#define GKO_THRUST_QUALIFIER ::THRUST_CUB_WRAPPED_NAMESPACE::thrust +#else +#define GKO_THRUST_NAEMSPACE_PREFIX +#define GKO_THRUST_NAEMSPACE_POSTFIX +#define GKO_THRUST_QUALIFIER ::thrust +#endif // THRUST_CUB_WRAPPED_NAMESPACE + + +#endif // GKO_COMMON_CUDA_HIP_BASE_THRUST_MACRO_HPP_ diff --git a/common/cuda_hip/base/types.hpp b/common/cuda_hip/base/types.hpp index 42ca57eb0bf..e65b179ed68 100644 --- a/common/cuda_hip/base/types.hpp +++ b/common/cuda_hip/base/types.hpp @@ -5,7 +5,7 @@ #ifndef GKO_COMMON_CUDA_HIP_BASE_TYPES_HPP_ #define GKO_COMMON_CUDA_HIP_BASE_TYPES_HPP_ - +#include "common/cuda_hip/base/math.hpp" #if defined(GKO_COMPILING_CUDA) #include "cuda/base/types.hpp" #elif defined(GKO_COMPILING_HIP) @@ -14,11 +14,14 @@ #error "Executor definition missing" #endif -#define THRUST_HALF_FRIEND_OPERATOR(_op, _opeq) \ - GKO_ATTRIBUTES GKO_INLINE thrust::complex<__half> operator _op( \ - const thrust::complex<__half> lhs, const thrust::complex<__half> rhs) \ - { \ - return thrust::complex{lhs} _op thrust::complex(rhs); \ + +#define THRUST_HALF_FRIEND_OPERATOR(_op, _opeq) \ + GKO_ATTRIBUTES GKO_INLINE GKO_THRUST_QUALIFIER::complex<__half> \ + operator _op(const GKO_THRUST_QUALIFIER::complex<__half> lhs, \ + const GKO_THRUST_QUALIFIER::complex<__half> rhs) \ + { \ + return GKO_THRUST_QUALIFIER::complex{ \ + lhs} _op GKO_THRUST_QUALIFIER::complex(rhs); \ } THRUST_HALF_FRIEND_OPERATOR(+, +=) From 7cc8c6f43c4661b567c7560742136bf7f3fae465 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. 
Tsai" Date: Thu, 24 Oct 2024 01:11:29 +0200 Subject: [PATCH 05/16] atomic and cooperative_groups --- common/cuda_hip/components/atomic.hpp | 54 ++++++++++++++++++++++- hip/components/cooperative_groups.hip.hpp | 12 ++--- omp/components/atomic.hpp | 54 +++++++++++++++++++++-- 3 files changed, 108 insertions(+), 12 deletions(-) diff --git a/common/cuda_hip/components/atomic.hpp b/common/cuda_hip/components/atomic.hpp index aeb77d48c75..954bc7476ed 100644 --- a/common/cuda_hip/components/atomic.hpp +++ b/common/cuda_hip/components/atomic.hpp @@ -39,6 +39,7 @@ struct atomic_helper { }; +// TODO: consider it implemented by memcpy. template __forceinline__ __device__ ResultType reinterpret(ValueType val) { @@ -95,15 +96,64 @@ __forceinline__ __device__ ResultType reinterpret(ValueType val) } \ }; + +#define GKO_BIND_ATOMIC_HELPER_FAKE_STRUCTURE(CONVERTER_TYPE) \ + template \ + struct atomic_helper< \ + ValueType, \ + std::enable_if_t<(sizeof(ValueType) == sizeof(CONVERTER_TYPE))>> { \ + __forceinline__ __device__ static ValueType atomic_add( \ + ValueType* __restrict__ addr, ValueType val) \ + { \ + assert(false); \ + using c_type = CONVERTER_TYPE; \ + return atomic_wrapper( \ + addr, [&val](c_type& old, c_type assumed, c_type* c_addr) { \ + old = *c_addr; \ + *c_addr = reinterpret( \ + val + reinterpret(assumed)); \ + }); \ + } \ + __forceinline__ __device__ static ValueType atomic_max( \ + ValueType* __restrict__ addr, ValueType val) \ + { \ + assert(false); \ + using c_type = CONVERTER_TYPE; \ + return atomic_wrapper( \ + addr, [&val](c_type& old, c_type assumed, c_type* c_addr) { \ + if (reinterpret(assumed) < val) { \ + old = *c_addr; \ + *c_addr = reinterpret(assumed); \ + } \ + }); \ + } \ + \ + private: \ + template \ + __forceinline__ __device__ static ValueType atomic_wrapper( \ + ValueType* __restrict__ addr, Callable set_old) \ + { \ + CONVERTER_TYPE* address_as_converter = \ + reinterpret_cast(addr); \ + CONVERTER_TYPE old = *address_as_converter; \ + 
CONVERTER_TYPE assumed = old; \ + set_old(old, assumed, address_as_converter); \ + return reinterpret(old); \ + } \ + }; + // Support 64-bit ATOMIC_ADD and ATOMIC_MAX GKO_BIND_ATOMIC_HELPER_STRUCTURE(unsigned long long int); // Support 32-bit ATOMIC_ADD and ATOMIC_MAX GKO_BIND_ATOMIC_HELPER_STRUCTURE(unsigned int); -#if defined(CUDA_VERSION) -// Support 16-bit ATOMIC_ADD and ATOMIC_MAX only on CUDA +#if defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700) +// Support 16-bit atomicCAS, atomicADD, and atomicMAX only on CUDA with CC +// >= 7.0 GKO_BIND_ATOMIC_HELPER_STRUCTURE(unsigned short int); +#else +GKO_BIND_ATOMIC_HELPER_FAKE_STRUCTURE(unsigned short int) #endif diff --git a/hip/components/cooperative_groups.hip.hpp b/hip/components/cooperative_groups.hip.hpp index 36618bb7f3e..46c2fb195bc 100644 --- a/hip/components/cooperative_groups.hip.hpp +++ b/hip/components/cooperative_groups.hip.hpp @@ -306,7 +306,7 @@ class enable_extended_shuffle : public Group { SelectorType selector) const \ { \ return shuffle_impl( \ - [this](uint32 v, SelectorType s) { \ + [this](uint16 v, SelectorType s) { \ return static_cast(this)->_name(v, s); \ }, \ var, selector); \ @@ -326,12 +326,12 @@ class enable_extended_shuffle : public Group { shuffle_impl(ShuffleOperator intrinsic_shuffle, const ValueType var, SelectorType selector) { - static_assert(sizeof(ValueType) % sizeof(uint32) == 0, - "Unable to shuffle sizes which are not 4-byte multiples"); - constexpr auto value_size = sizeof(ValueType) / sizeof(uint32); + static_assert(sizeof(ValueType) % sizeof(uint16) == 0, + "Unable to shuffle sizes which are not 2-byte multiples"); + constexpr auto value_size = sizeof(ValueType) / sizeof(uint16); ValueType result; - auto var_array = reinterpret_cast(&var); - auto result_array = reinterpret_cast(&result); + auto var_array = reinterpret_cast(&var); + auto result_array = reinterpret_cast(&result); #pragma unroll for (std::size_t i = 0; i < value_size; ++i) { result_array[i] = 
intrinsic_shuffle(var_array[i], selector); diff --git a/omp/components/atomic.hpp b/omp/components/atomic.hpp index c3580cd36bb..35b94a65fe5 100644 --- a/omp/components/atomic.hpp +++ b/omp/components/atomic.hpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -32,10 +33,55 @@ void atomic_add(ValueType& out, ValueType val) // The C++ standard explicitly allows casting complex* to double* // [complex.numbers.general] auto values = reinterpret_cast*>(&out); -#pragma omp atomic - values[0] += real(val); -#pragma omp atomic - values[1] += imag(val); + atomic_add(values[0], real(val)); + atomic_add(values[1], imag(val)); +} + + +template +inline ResultType copy_cast(const ValueType& val) +{ + static_assert( + sizeof(ValueType) == sizeof(ResultType) && + std::alignment_of_v == std::alignment_of_v, + "only copy the same alignment and size type"); + ResultType res; + std::memcpy(&res, &val, sizeof(ValueType)); + return res; +} + + +template <> +void atomic_add(half& out, half val) +{ +#ifdef __NVCOMPILER +// NVC++ uses atomic capture on uint16 leads the following error. +// use of undefined value '%L.B*' br label %L.B* !llvm.loop !*, !dbg !* +#pragma omp critical + { + out += val; + } +#else + static_assert( + sizeof(half) == sizeof(uint16_t) && + std::alignment_of_v == std::alignment_of_v, + "half does not fulfill the requirement of reinterpret_cast to half or " + "vice versa."); + // It is undefined behavior with reinterpret_cast, but we do not have any + // workaround when the #omp atomic does not support custom precision + uint16_t* address_as_converter = reinterpret_cast(&out); + uint16_t old = *address_as_converter; + uint16_t assumed; + do { + assumed = old; + auto answer = copy_cast(copy_cast(assumed) + val); +#pragma omp atomic capture + { + old = *address_as_converter; + *address_as_converter = (old == assumed) ? 
answer : old; + } + } while (assumed != old); +#endif } From a85f46234012aca25cbceb640027e2e8692487a4 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Thu, 24 Oct 2024 01:16:50 +0200 Subject: [PATCH 06/16] fix math and device_numeric_limit --- common/cuda_hip/base/math.hpp | 11 ++ common/cuda_hip/components/merging.hpp | 4 +- .../factorization/par_ict_kernels.cpp | 4 +- .../factorization/par_ilut_select_kernels.hpp | 4 +- .../factorization/par_ilut_spgeam_kernels.cpp | 4 +- common/cuda_hip/reorder/rcm_kernels.cpp | 2 +- cuda/test/base/math.cu | 4 +- hip/test/base/math.hip.cpp | 4 +- include/ginkgo/core/base/math.hpp | 103 +++++------------- 9 files changed, 54 insertions(+), 86 deletions(-) diff --git a/common/cuda_hip/base/math.hpp b/common/cuda_hip/base/math.hpp index 3d2975c1eee..f83533d8f0d 100644 --- a/common/cuda_hip/base/math.hpp +++ b/common/cuda_hip/base/math.hpp @@ -83,6 +83,17 @@ struct truncate_type_impl> { }; +template +struct is_complex_impl> : public std::true_type {}; + +template <> +struct is_complex_or_scalar_impl<__half> : public std::true_type {}; + +template +struct is_complex_or_scalar_impl> + : public is_complex_or_scalar_impl {}; + + } // namespace detail } // namespace gko diff --git a/common/cuda_hip/components/merging.hpp b/common/cuda_hip/components/merging.hpp index ab070741fbd..b832a97176e 100644 --- a/common/cuda_hip/components/merging.hpp +++ b/common/cuda_hip/components/merging.hpp @@ -131,7 +131,7 @@ __forceinline__ __device__ void group_merge(const ValueType* __restrict__ a, IndexType a_begin{}; IndexType b_begin{}; auto lane = static_cast(group.thread_rank()); - auto sentinel = device_numeric_limits::max; + auto sentinel = device_numeric_limits::max(); auto a_cur = checked_load(a, a_begin + lane, a_size, sentinel); auto b_cur = checked_load(b, b_begin + lane, b_size, sentinel); for (IndexType c_begin{}; c_begin < c_size; c_begin += group_size) { @@ -240,7 +240,7 @@ __forceinline__ __device__ void sequential_merge( auto 
c_size = a_size + b_size; IndexType a_begin{}; IndexType b_begin{}; - auto sentinel = device_numeric_limits::max; + auto sentinel = device_numeric_limits::max(); auto a_cur = checked_load(a, a_begin, a_size, sentinel); auto b_cur = checked_load(b, b_begin, b_size, sentinel); for (IndexType c_begin{}; c_begin < c_size; c_begin++) { diff --git a/common/cuda_hip/factorization/par_ict_kernels.cpp b/common/cuda_hip/factorization/par_ict_kernels.cpp index 94aa5e5124e..3446f124123 100644 --- a/common/cuda_hip/factorization/par_ict_kernels.cpp +++ b/common/cuda_hip/factorization/par_ict_kernels.cpp @@ -128,7 +128,7 @@ __global__ __launch_bounds__(default_block_size) void ict_tri_spgeam_init( IndexType l_new_begin = l_new_row_ptrs[row]; - constexpr auto sentinel = device_numeric_limits::max; + constexpr auto sentinel = device_numeric_limits::max(); // load column indices and values for the first merge step auto a_col = checked_load(a_col_idxs, a_begin + lane, a_end, sentinel); auto a_val = checked_load(a_vals, a_begin + lane, a_end, zero()); @@ -456,4 +456,4 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( } // namespace par_ict_factorization } // namespace GKO_DEVICE_NAMESPACE } // namespace kernels -} // namespace gko \ No newline at end of file +} // namespace gko diff --git a/common/cuda_hip/factorization/par_ilut_select_kernels.hpp b/common/cuda_hip/factorization/par_ilut_select_kernels.hpp index 6f5940c2b14..79a562ff834 100644 --- a/common/cuda_hip/factorization/par_ilut_select_kernels.hpp +++ b/common/cuda_hip/factorization/par_ilut_select_kernels.hpp @@ -254,7 +254,7 @@ __global__ __launch_bounds__(basecase_block_size) void basecase_select( const ValueType* __restrict__ input, IndexType size, IndexType rank, ValueType* __restrict__ out) { - constexpr auto sentinel = device_numeric_limits::inf; + constexpr auto sentinel = device_numeric_limits::inf(); ValueType local[basecase_local_size]; __shared__ ValueType sh_local[basecase_size]; for (int i = 0; i < 
basecase_local_size; ++i) { @@ -301,4 +301,4 @@ __global__ __launch_bounds__(config::warp_size) void find_bucket( } // namespace kernels } // namespace gko -#endif // GKO_COMMON_CUDA_HIP_FACTORIZATION_PAR_ILUT_SELECT_KERNELS_HIP_HPP_ \ No newline at end of file +#endif // GKO_COMMON_CUDA_HIP_FACTORIZATION_PAR_ILUT_SELECT_KERNELS_HIP_HPP_ diff --git a/common/cuda_hip/factorization/par_ilut_spgeam_kernels.cpp b/common/cuda_hip/factorization/par_ilut_spgeam_kernels.cpp index 6cc77660394..a29cf6f2cb3 100644 --- a/common/cuda_hip/factorization/par_ilut_spgeam_kernels.cpp +++ b/common/cuda_hip/factorization/par_ilut_spgeam_kernels.cpp @@ -150,7 +150,7 @@ __global__ __launch_bounds__(default_block_size) void tri_spgeam_init( IndexType l_new_begin = l_new_row_ptrs[row]; IndexType u_new_begin = u_new_row_ptrs[row]; - constexpr auto sentinel = device_numeric_limits::max; + constexpr auto sentinel = device_numeric_limits::max(); // load column indices and values for the first merge step auto a_col = checked_load(a_col_idxs, a_begin + lane, a_end, sentinel); auto a_val = checked_load(a_vals, a_begin + lane, a_end, zero()); @@ -396,4 +396,4 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( } // namespace par_ilut_factorization } // namespace GKO_DEVICE_NAMESPACE } // namespace kernels -} // namespace gko \ No newline at end of file +} // namespace gko diff --git a/common/cuda_hip/reorder/rcm_kernels.cpp b/common/cuda_hip/reorder/rcm_kernels.cpp index 75050d3e977..2bb18cbdd22 100644 --- a/common/cuda_hip/reorder/rcm_kernels.cpp +++ b/common/cuda_hip/reorder/rcm_kernels.cpp @@ -525,7 +525,7 @@ __global__ __launch_bounds__(default_block_size) void ubfs_min_neighbor_kernel( const auto begin = row_ptrs[row]; const auto end = row_ptrs[row + 1]; const auto cur_level = node_levels[row]; - auto min_neighbor = device_numeric_limits::max; + auto min_neighbor = device_numeric_limits::max(); for (auto nz = begin; nz < end; nz++) { const auto col = col_idxs[nz]; const auto neighbor_level = 
node_levels[col]; diff --git a/cuda/test/base/math.cu b/cuda/test/base/math.cu index d1d9373b0ef..1025c3cc489 100644 --- a/cuda/test/base/math.cu +++ b/cuda/test/base/math.cu @@ -26,7 +26,7 @@ namespace kernel { template __device__ bool test_real_is_finite_function(FuncType isfin) { - constexpr T inf = gko::device_numeric_limits::inf; + constexpr T inf = gko::device_numeric_limits::inf(); constexpr T quiet_nan = NAN; bool test_true{}; bool test_false{}; @@ -46,7 +46,7 @@ __device__ bool test_complex_is_finite_function(FuncType isfin) "Template type must be a complex type."); using T = gko::remove_complex; using c_type = gko::kernels::cuda::cuda_type; - constexpr T inf = gko::device_numeric_limits::inf; + constexpr T inf = gko::device_numeric_limits::inf(); constexpr T quiet_nan = NAN; bool test_true{}; bool test_false{}; diff --git a/hip/test/base/math.hip.cpp b/hip/test/base/math.hip.cpp index f01b56739d9..f69ca804aa9 100644 --- a/hip/test/base/math.hip.cpp +++ b/hip/test/base/math.hip.cpp @@ -32,7 +32,7 @@ namespace kernel { template __device__ bool test_real_is_finite_function(FuncType isfin) { - constexpr T inf = gko::device_numeric_limits::inf; + constexpr T inf = gko::device_numeric_limits::inf(); constexpr T quiet_nan = NAN; bool test_true{}; bool test_false{}; @@ -52,7 +52,7 @@ __device__ bool test_complex_is_finite_function(FuncType isfin) "Template type must be a complex type."); using T = gko::remove_complex; using c_type = gko::kernels::hip::hip_type; - constexpr T inf = gko::device_numeric_limits::inf; + constexpr T inf = gko::device_numeric_limits::inf(); constexpr T quiet_nan = NAN; bool test_true{}; bool test_false{}; diff --git a/include/ginkgo/core/base/math.hpp b/include/ginkgo/core/base/math.hpp index 73da407194e..e308b092ea6 100644 --- a/include/ginkgo/core/base/math.hpp +++ b/include/ginkgo/core/base/math.hpp @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -21,79 +22,6 @@ namespace gko { -class half; - - -// HIP should not 
see std::abs or std::sqrt, we want the custom implementation. -// Hence, provide the using declaration only for some cases -namespace kernels { -namespace reference { - - -using std::abs; - - -using std::sqrt; - - -} // namespace reference -} // namespace kernels - - -namespace kernels { -namespace omp { - - -using std::abs; - - -using std::sqrt; - - -} // namespace omp -} // namespace kernels - - -namespace kernels { -namespace cuda { - - -using std::abs; - - -using std::sqrt; - - -} // namespace cuda -} // namespace kernels - - -namespace kernels { -namespace dpcpp { - - -using std::abs; - - -using std::sqrt; - - -} // namespace dpcpp -} // namespace kernels - - -namespace test { - - -using std::abs; - - -using std::sqrt; - - -} // namespace test - - // type manipulations @@ -706,6 +634,13 @@ GKO_INLINE constexpr T one() return T(1); } +template <> +GKO_INLINE constexpr half one() +{ + constexpr auto bits = static_cast(0b0'01111'0000000000u); + return half::create_from_bits(bits); +} + /** * Returns the multiplicative identity for T. @@ -983,6 +918,7 @@ GKO_INLINE constexpr auto squared_norm(const T& x) return real(conj(x) * x); } +using std::abs; /** * Returns the absolute value of the object. @@ -1008,6 +944,27 @@ abs(const T& x) return sqrt(squared_norm(x)); } +// increase the priority in function lookup +GKO_INLINE gko::half abs(const std::complex& x) +{ + // Using float abs not sqrt on norm to avoid overflow + return static_cast(abs(std::complex(x))); +} + + +using std::sqrt; + +GKO_INLINE gko::half sqrt(gko::half a) +{ + return gko::half(std::sqrt(float(a))); +} + +GKO_INLINE std::complex sqrt(std::complex a) +{ + return std::complex(sqrt(std::complex( + static_cast(a.real()), static_cast(a.imag())))); +} + /** * Returns the value of pi. From 7b4829cbd3541144ec512ad7aa78306543a5a288 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. 
Tsai" Date: Thu, 24 Oct 2024 02:00:19 +0200 Subject: [PATCH 07/16] array operation in half --- .../components/absolute_array_kernels.cpp | 6 ++++-- .../unified/components/fill_array_kernels.cpp | 17 +++++++++++++---- .../components/precision_conversion_kernels.cpp | 3 ++- .../unified/components/reduce_array_kernels.cpp | 3 ++- core/base/array.cpp | 9 +++++---- core/base/segmented_array.cpp | 2 +- core/device_hooks/common_kernels.inc.cpp | 12 ++++++------ include/ginkgo/core/base/segmented_array.hpp | 7 ++++++- reference/components/absolute_array_kernels.cpp | 6 ++++-- reference/components/fill_array_kernels.cpp | 5 +++-- .../components/precision_conversion_kernels.cpp | 3 ++- reference/components/reduce_array_kernels.cpp | 3 ++- 12 files changed, 50 insertions(+), 26 deletions(-) diff --git a/common/unified/components/absolute_array_kernels.cpp b/common/unified/components/absolute_array_kernels.cpp index c9ab364353c..423fa234c39 100644 --- a/common/unified/components/absolute_array_kernels.cpp +++ b/common/unified/components/absolute_array_kernels.cpp @@ -23,7 +23,8 @@ void inplace_absolute_array(std::shared_ptr exec, data); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_INPLACE_ABSOLUTE_ARRAY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_INPLACE_ABSOLUTE_ARRAY_KERNEL); template @@ -37,7 +38,8 @@ void outplace_absolute_array(std::shared_ptr exec, n, in, out); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_OUTPLACE_ABSOLUTE_ARRAY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_OUTPLACE_ABSOLUTE_ARRAY_KERNEL); } // namespace components diff --git a/common/unified/components/fill_array_kernels.cpp b/common/unified/components/fill_array_kernels.cpp index d78a6e9f346..3e87d782974 100644 --- a/common/unified/components/fill_array_kernels.cpp +++ b/common/unified/components/fill_array_kernels.cpp @@ -23,7 +23,7 @@ void fill_array(std::shared_ptr exec, ValueType* array, array, val); } 
-GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_FILL_ARRAY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE_WITH_HALF(GKO_DECLARE_FILL_ARRAY_KERNEL); template GKO_DECLARE_FILL_ARRAY_KERNEL(bool); @@ -32,11 +32,20 @@ void fill_seq_array(std::shared_ptr exec, ValueType* array, size_type n) { run_kernel( - exec, [] GKO_KERNEL(auto idx, auto array) { array[idx] = idx; }, n, - array); + exec, + [] GKO_KERNEL(auto idx, auto array) { + if constexpr (std::is_same_v, half>) { + // __half can not be from int64_t + array[idx] = static_cast(idx); + } else { + array[idx] = idx; + } + }, + n, array); } -GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_FILL_SEQ_ARRAY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE_WITH_HALF( + GKO_DECLARE_FILL_SEQ_ARRAY_KERNEL); } // namespace components diff --git a/common/unified/components/precision_conversion_kernels.cpp b/common/unified/components/precision_conversion_kernels.cpp index 0402d9bef68..94a8d4e4d0f 100644 --- a/common/unified/components/precision_conversion_kernels.cpp +++ b/common/unified/components/precision_conversion_kernels.cpp @@ -23,7 +23,8 @@ void convert_precision(std::shared_ptr exec, size, in, out); } -GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION(GKO_DECLARE_CONVERT_PRECISION_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION_WITH_HALF( + GKO_DECLARE_CONVERT_PRECISION_KERNEL); } // namespace components diff --git a/common/unified/components/reduce_array_kernels.cpp b/common/unified/components/reduce_array_kernels.cpp index bc8da6fa311..1e7d19264cd 100644 --- a/common/unified/components/reduce_array_kernels.cpp +++ b/common/unified/components/reduce_array_kernels.cpp @@ -34,7 +34,8 @@ void reduce_add_array(std::shared_ptr exec, arr, result); } -GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_REDUCE_ADD_ARRAY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE_WITH_HALF( + GKO_DECLARE_REDUCE_ADD_ARRAY_KERNEL); } // namespace components diff --git a/core/base/array.cpp b/core/base/array.cpp index 
a41f7c07e55..7a98223a7b2 100644 --- a/core/base/array.cpp +++ b/core/base/array.cpp @@ -51,7 +51,8 @@ void convert_data(std::shared_ptr exec, size_type size, void convert_data(std::shared_ptr, size_type, \ const From*, To*) -GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION(GKO_DECLARE_ARRAY_CONVERSION); +GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION_WITH_HALF( + GKO_DECLARE_ARRAY_CONVERSION); } // namespace detail @@ -88,19 +89,19 @@ ValueType reduce_add(const array& input_arr, #define GKO_DECLARE_ARRAY_FILL(_type) void array<_type>::fill(const _type value) -GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_ARRAY_FILL); +GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE_WITH_HALF(GKO_DECLARE_ARRAY_FILL); #define GKO_DECLARE_ARRAY_REDUCE_ADD(_type) \ void reduce_add(const array<_type>& arr, array<_type>& value) -GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_ARRAY_REDUCE_ADD); +GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE_WITH_HALF(GKO_DECLARE_ARRAY_REDUCE_ADD); #define GKO_DECLARE_ARRAY_REDUCE_ADD2(_type) \ _type reduce_add(const array<_type>& arr, const _type val) -GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_ARRAY_REDUCE_ADD2); +GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE_WITH_HALF(GKO_DECLARE_ARRAY_REDUCE_ADD2); } // namespace gko diff --git a/core/base/segmented_array.cpp b/core/base/segmented_array.cpp index d113139f8e2..4c6356799f9 100644 --- a/core/base/segmented_array.cpp +++ b/core/base/segmented_array.cpp @@ -180,7 +180,7 @@ segmented_array& segmented_array::operator=(segmented_array&& other) #define GKO_DECLARE_SEGMENTED_ARRAY(_type) struct segmented_array<_type> -GKO_INSTANTIATE_FOR_EACH_POD_TYPE(GKO_DECLARE_SEGMENTED_ARRAY); +GKO_INSTANTIATE_FOR_EACH_POD_TYPE_WITH_HALF(GKO_DECLARE_SEGMENTED_ARRAY); } // namespace gko diff --git a/core/device_hooks/common_kernels.inc.cpp b/core/device_hooks/common_kernels.inc.cpp index 6ffeb1c5f71..224aacc7369 100644 --- a/core/device_hooks/common_kernels.inc.cpp +++ b/core/device_hooks/common_kernels.inc.cpp @@ -238,19 +238,19 @@ namespace 
GKO_HOOK_MODULE { namespace components { -GKO_STUB_VALUE_CONVERSION(GKO_DECLARE_CONVERT_PRECISION_KERNEL); +GKO_STUB_VALUE_CONVERSION_WITH_HALF(GKO_DECLARE_CONVERT_PRECISION_KERNEL); GKO_STUB_INDEX_TYPE(GKO_DECLARE_PREFIX_SUM_NONNEGATIVE_KERNEL); // explicitly instantiate for size_type, as this is // used in the SellP format template GKO_DECLARE_PREFIX_SUM_NONNEGATIVE_KERNEL(size_type); -GKO_STUB_TEMPLATE_TYPE(GKO_DECLARE_FILL_ARRAY_KERNEL); +GKO_STUB_TEMPLATE_TYPE_WITH_HALF(GKO_DECLARE_FILL_ARRAY_KERNEL); template GKO_DECLARE_FILL_ARRAY_KERNEL(bool); -GKO_STUB_TEMPLATE_TYPE(GKO_DECLARE_FILL_SEQ_ARRAY_KERNEL); -GKO_STUB_TEMPLATE_TYPE(GKO_DECLARE_REDUCE_ADD_ARRAY_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_INPLACE_ABSOLUTE_ARRAY_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_OUTPLACE_ABSOLUTE_ARRAY_KERNEL); +GKO_STUB_TEMPLATE_TYPE_WITH_HALF(GKO_DECLARE_FILL_SEQ_ARRAY_KERNEL); +GKO_STUB_TEMPLATE_TYPE_WITH_HALF(GKO_DECLARE_REDUCE_ADD_ARRAY_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_INPLACE_ABSOLUTE_ARRAY_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_OUTPLACE_ABSOLUTE_ARRAY_KERNEL); GKO_STUB_VALUE_AND_INDEX_TYPE( GKO_DECLARE_DEVICE_MATRIX_DATA_REMOVE_ZEROS_KERNEL); GKO_STUB_VALUE_AND_INDEX_TYPE( diff --git a/include/ginkgo/core/base/segmented_array.hpp b/include/ginkgo/core/base/segmented_array.hpp index 49a7e6f9d38..b34605cc902 100644 --- a/include/ginkgo/core/base/segmented_array.hpp +++ b/include/ginkgo/core/base/segmented_array.hpp @@ -2,7 +2,10 @@ // // SPDX-License-Identifier: BSD-3-Clause -#pragma once +#ifndef GKO_PUBLIC_CORE_BASE_SEGMENTED_ARRAY_HPP_ +#define GKO_PUBLIC_CORE_BASE_SEGMENTED_ARRAY_HPP_ + + #include #include @@ -183,3 +186,5 @@ class copy_back_deleter> } // namespace detail } // namespace gko + +#endif // GKO_PUBLIC_CORE_BASE_SEGMENTED_ARRAY_HPP_ diff --git a/reference/components/absolute_array_kernels.cpp b/reference/components/absolute_array_kernels.cpp index 964e1f80d6a..759caae894c 100644 --- 
a/reference/components/absolute_array_kernels.cpp +++ b/reference/components/absolute_array_kernels.cpp @@ -20,7 +20,8 @@ void inplace_absolute_array(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_INPLACE_ABSOLUTE_ARRAY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_INPLACE_ABSOLUTE_ARRAY_KERNEL); template @@ -33,7 +34,8 @@ void outplace_absolute_array(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_OUTPLACE_ABSOLUTE_ARRAY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_OUTPLACE_ABSOLUTE_ARRAY_KERNEL); } // namespace components diff --git a/reference/components/fill_array_kernels.cpp b/reference/components/fill_array_kernels.cpp index 1649aa87982..663ad8f5b6b 100644 --- a/reference/components/fill_array_kernels.cpp +++ b/reference/components/fill_array_kernels.cpp @@ -20,7 +20,7 @@ void fill_array(std::shared_ptr exec, ValueType* array, std::fill_n(array, n, val); } -GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_FILL_ARRAY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE_WITH_HALF(GKO_DECLARE_FILL_ARRAY_KERNEL); template GKO_DECLARE_FILL_ARRAY_KERNEL(bool); @@ -31,7 +31,8 @@ void fill_seq_array(std::shared_ptr exec, std::iota(array, array + n, 0); } -GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_FILL_SEQ_ARRAY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE_WITH_HALF( + GKO_DECLARE_FILL_SEQ_ARRAY_KERNEL); } // namespace components diff --git a/reference/components/precision_conversion_kernels.cpp b/reference/components/precision_conversion_kernels.cpp index db12d9316ee..5ec37a1cd72 100644 --- a/reference/components/precision_conversion_kernels.cpp +++ b/reference/components/precision_conversion_kernels.cpp @@ -20,7 +20,8 @@ void convert_precision(std::shared_ptr exec, std::copy_n(in, size, out); } -GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION(GKO_DECLARE_CONVERT_PRECISION_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION_WITH_HALF( + 
GKO_DECLARE_CONVERT_PRECISION_KERNEL); } // namespace components diff --git a/reference/components/reduce_array_kernels.cpp b/reference/components/reduce_array_kernels.cpp index a70ef95a878..3c3c6f620ec 100644 --- a/reference/components/reduce_array_kernels.cpp +++ b/reference/components/reduce_array_kernels.cpp @@ -22,7 +22,8 @@ void reduce_add_array(std::shared_ptr exec, val.get_const_data()[0]); } -GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE(GKO_DECLARE_REDUCE_ADD_ARRAY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_TEMPLATE_TYPE_WITH_HALF( + GKO_DECLARE_REDUCE_ADD_ARRAY_KERNEL); } // namespace components From 561f173af5742258588985c02b151aa3213bf53f Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Wed, 23 Oct 2024 18:30:29 +0200 Subject: [PATCH 08/16] matrix with half --- common/cuda_hip/matrix/coo_kernels.cpp | 10 +- .../matrix/csr_kernels.instantiate.cpp | 124 ++++--- .../cuda_hip/matrix/csr_kernels.template.cpp | 6 +- common/cuda_hip/matrix/dense_kernels.cpp | 33 +- common/cuda_hip/matrix/diagonal_kernels.cpp | 2 +- common/cuda_hip/matrix/ell_kernels.cpp | 4 +- .../matrix/fbcsr_kernels.instantiate.cpp | 21 +- .../matrix/fbcsr_kernels.template.cpp | 2 +- common/cuda_hip/matrix/sellp_kernels.cpp | 5 +- .../cuda_hip/matrix/sparsity_csr_kernels.cpp | 14 +- common/unified/matrix/coo_kernels.cpp | 4 +- common/unified/matrix/csr_kernels.cpp | 16 +- .../matrix/dense_kernels.instantiate.cpp | 100 ++--- common/unified/matrix/diagonal_kernels.cpp | 14 +- common/unified/matrix/ell_kernels.cpp | 13 +- common/unified/matrix/hybrid_kernels.cpp | 4 +- .../matrix/scaled_permutation_kernels.cpp | 4 +- common/unified/matrix/sellp_kernels.cpp | 10 +- .../unified/matrix/sparsity_csr_kernels.cpp | 6 +- core/device_hooks/common_kernels.inc.cpp | 351 +++++++++++------- core/matrix/coo.cpp | 29 +- core/matrix/csr.cpp | 29 +- core/matrix/dense.cpp | 38 +- core/matrix/diagonal.cpp | 30 +- core/matrix/ell.cpp | 30 +- core/matrix/fbcsr.cpp | 32 +- core/matrix/hybrid.cpp | 32 +- 
core/matrix/identity.cpp | 4 +- core/matrix/permutation.cpp | 7 +- core/matrix/row_gatherer.cpp | 13 +- core/matrix/scaled_permutation.cpp | 2 +- core/matrix/sellp.cpp | 33 +- core/matrix/sparsity_csr.cpp | 3 +- dpcpp/matrix/coo_kernels.dp.cpp | 10 +- dpcpp/matrix/csr_kernels.dp.cpp | 74 ++-- dpcpp/matrix/dense_kernels.dp.cpp | 86 +++-- dpcpp/matrix/diagonal_kernels.dp.cpp | 2 +- dpcpp/matrix/ell_kernels.dp.cpp | 4 +- dpcpp/matrix/fbcsr_kernels.dp.cpp | 21 +- dpcpp/matrix/sellp_kernels.dp.cpp | 5 +- dpcpp/matrix/sparsity_csr_kernels.dp.cpp | 14 +- .../ginkgo/core/base/precision_dispatch.hpp | 29 +- include/ginkgo/core/matrix/coo.hpp | 57 ++- include/ginkgo/core/matrix/csr.hpp | 73 ++-- include/ginkgo/core/matrix/dense.hpp | 34 +- include/ginkgo/core/matrix/diagonal.hpp | 36 +- include/ginkgo/core/matrix/ell.hpp | 57 ++- include/ginkgo/core/matrix/fbcsr.hpp | 61 ++- include/ginkgo/core/matrix/hybrid.hpp | 40 +- include/ginkgo/core/matrix/sellp.hpp | 57 ++- omp/matrix/coo_kernels.cpp | 10 +- omp/matrix/csr_kernels.cpp | 60 +-- omp/matrix/dense_kernels.cpp | 33 +- omp/matrix/diagonal_kernels.cpp | 2 +- omp/matrix/ell_kernels.cpp | 4 +- omp/matrix/fbcsr_kernels.cpp | 21 +- omp/matrix/sellp_kernels.cpp | 5 +- omp/matrix/sparsity_csr_kernels.cpp | 10 +- reference/matrix/coo_kernels.cpp | 14 +- reference/matrix/csr_kernels.cpp | 76 ++-- reference/matrix/dense_kernels.cpp | 132 ++++--- reference/matrix/diagonal_kernels.cpp | 16 +- reference/matrix/ell_kernels.cpp | 19 +- reference/matrix/fbcsr_kernels.cpp | 21 +- reference/matrix/hybrid_kernels.cpp | 4 +- .../matrix/scaled_permutation_kernels.cpp | 4 +- reference/matrix/sellp_kernels.cpp | 15 +- reference/matrix/sparsity_csr_kernels.cpp | 16 +- reference/test/base/combination.cpp | 13 +- reference/test/matrix/coo_kernels.cpp | 17 +- reference/test/matrix/csr_kernels.cpp | 8 +- reference/test/matrix/dense_kernels.cpp | 21 +- reference/test/matrix/diagonal_kernels.cpp | 15 +- reference/test/matrix/ell_kernels.cpp | 17 +- 
reference/test/matrix/fbcsr_kernels.cpp | 17 +- reference/test/matrix/hybrid_kernels.cpp | 17 +- reference/test/matrix/scaled_permutation.cpp | 3 +- reference/test/matrix/sellp_kernels.cpp | 26 +- test/matrix/fbcsr_kernels.cpp | 14 +- test/matrix/matrix.cpp | 10 +- 80 files changed, 1480 insertions(+), 845 deletions(-) diff --git a/common/cuda_hip/matrix/coo_kernels.cpp b/common/cuda_hip/matrix/coo_kernels.cpp index cffe18d981b..4609f9f7f95 100644 --- a/common/cuda_hip/matrix/coo_kernels.cpp +++ b/common/cuda_hip/matrix/coo_kernels.cpp @@ -238,7 +238,8 @@ void spmv(std::shared_ptr exec, spmv2(exec, a, b, c); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_SPMV_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_COO_SPMV_KERNEL); template @@ -253,7 +254,7 @@ void advanced_spmv(std::shared_ptr exec, advanced_spmv2(exec, alpha, a, b, c); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_COO_ADVANCED_SPMV_KERNEL); @@ -295,7 +296,8 @@ void spmv2(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_SPMV2_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_COO_SPMV2_KERNEL); template @@ -338,7 +340,7 @@ void advanced_spmv2(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_COO_ADVANCED_SPMV2_KERNEL); diff --git a/common/cuda_hip/matrix/csr_kernels.instantiate.cpp b/common/cuda_hip/matrix/csr_kernels.instantiate.cpp index f62ca1c1815..2e28de95f5d 100644 --- a/common/cuda_hip/matrix/csr_kernels.instantiate.cpp +++ b/common/cuda_hip/matrix/csr_kernels.instantiate.cpp @@ -17,108 +17,136 @@ namespace csr { // begin -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CONVERT_TO_FBCSR_KERNEL); // split 
-GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1(GKO_DECLARE_CSR_SPMV_KERNEL, - int32); +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1_WITH_HALF( + GKO_DECLARE_CSR_SPMV_KERNEL, int32); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2(GKO_DECLARE_CSR_SPMV_KERNEL, - int32); +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2_WITH_HALF( + GKO_DECLARE_CSR_SPMV_KERNEL, int32); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3(GKO_DECLARE_CSR_SPMV_KERNEL, - int32); +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3_WITH_HALF( + GKO_DECLARE_CSR_SPMV_KERNEL, int32); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4(GKO_DECLARE_CSR_SPMV_KERNEL, - int32); +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4_WITH_HALF( + GKO_DECLARE_CSR_SPMV_KERNEL, int32); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1(GKO_DECLARE_CSR_SPMV_KERNEL, - int64); +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT5_WITH_HALF( + GKO_DECLARE_CSR_SPMV_KERNEL, int32); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2(GKO_DECLARE_CSR_SPMV_KERNEL, - int64); +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT6_WITH_HALF( + GKO_DECLARE_CSR_SPMV_KERNEL, int32); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3(GKO_DECLARE_CSR_SPMV_KERNEL, - int64); +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1_WITH_HALF( + GKO_DECLARE_CSR_SPMV_KERNEL, int64); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4(GKO_DECLARE_CSR_SPMV_KERNEL, - int64); +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2_WITH_HALF( + GKO_DECLARE_CSR_SPMV_KERNEL, int64); +// split +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3_WITH_HALF( + GKO_DECLARE_CSR_SPMV_KERNEL, int64); +// split +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4_WITH_HALF( + GKO_DECLARE_CSR_SPMV_KERNEL, int64); +// split +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT5_WITH_HALF( + GKO_DECLARE_CSR_SPMV_KERNEL, int64); +// split +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT6_WITH_HALF( + 
GKO_DECLARE_CSR_SPMV_KERNEL, int64); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1_WITH_HALF( + GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int32); +// split +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int32); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int32); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int32); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT5_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int32); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT6_WITH_HALF( + GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int32); +// split +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT1_WITH_HALF( + GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int64); +// split +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2_WITH_HALF( + GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int64); +// split +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int64); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT2( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int64); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT3( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT5_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int64); // split -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT4( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_TYPE_SPLIT6_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL, int64); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_TRANSPOSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + 
GKO_DECLARE_CSR_TRANSPOSE_KERNEL); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CONJ_TRANSPOSE_KERNEL); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_SORT_BY_COLUMN_INDEX); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEMM_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_SPGEMM_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPGEMM_KERNEL); GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CSR_BUILD_LOOKUP_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEAM_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_SPGEAM_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_FILL_IN_DENSE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_NONSYMM_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_SYMM_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_ROW_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_NONSYMM_SCALE_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_SYMM_SCALE_PERMUTE_KERNEL); 
-GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ROW_SCALE_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_ROW_SCALE_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_SPAN_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_INDEX_SET_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_FROM_INDEX_SET_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_IS_SORTED_BY_COLUMN_INDEX); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_EXTRACT_DIAGONAL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_EXTRACT_DIAGONAL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CHECK_DIAGONAL_ENTRIES_EXIST); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ADD_SCALED_IDENTITY_KERNEL); // end diff --git a/common/cuda_hip/matrix/csr_kernels.template.cpp b/common/cuda_hip/matrix/csr_kernels.template.cpp index 909349ed7ab..f808e234670 100644 --- a/common/cuda_hip/matrix/csr_kernels.template.cpp +++ b/common/cuda_hip/matrix/csr_kernels.template.cpp @@ -278,7 +278,7 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_spmv( { using arithmetic_type = typename output_accessor::arithmetic_type; using 
output_type = typename output_accessor::storage_type; - const arithmetic_type scale_factor = alpha[0]; + const auto scale_factor = static_cast(alpha[0]); spmv_kernel(nwarps, num_rows, val, col_idxs, row_ptrs, srow, b, c, [&scale_factor](const arithmetic_type& x) { return static_cast(scale_factor * x); @@ -486,7 +486,7 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_reduce( const IndexType* __restrict__ last_row, const MatrixValueType* __restrict__ alpha, acc::range c) { - const arithmetic_type alpha_val = alpha[0]; + const auto alpha_val = static_cast(alpha[0]); merge_path_reduce( nwarps, last_val, last_row, c, [&alpha_val](const arithmetic_type& x) { return alpha_val * x; }); @@ -1193,7 +1193,7 @@ __global__ __launch_bounds__(default_block_size) void build_csr_lookup( const auto i = base_i + lane; const auto col = i < row_len ? local_cols[i] - : device_numeric_limits::max; + : device_numeric_limits::max(); const auto rel_col = static_cast(col - min_col); const auto block = rel_col / bitmap_block_size; const auto col_in_block = rel_col % bitmap_block_size; diff --git a/common/cuda_hip/matrix/dense_kernels.cpp b/common/cuda_hip/matrix/dense_kernels.cpp index d8391ace023..d0d4985dd82 100644 --- a/common/cuda_hip/matrix/dense_kernels.cpp +++ b/common/cuda_hip/matrix/dense_kernels.cpp @@ -461,7 +461,7 @@ void convert_to_coo(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_COO_KERNEL); @@ -491,7 +491,7 @@ void convert_to_csr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_CSR_KERNEL); @@ -521,7 +521,7 @@ void convert_to_ell(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_ELL_KERNEL); @@ -544,7 +544,7 @@ void 
convert_to_fbcsr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_FBCSR_KERNEL); @@ -565,7 +565,7 @@ void count_nonzero_blocks_per_row(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COUNT_NONZERO_BLOCKS_PER_ROW_KERNEL); @@ -598,7 +598,7 @@ void convert_to_hybrid(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_HYBRID_KERNEL); @@ -629,7 +629,7 @@ void convert_to_sellp(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_SELLP_KERNEL); @@ -657,7 +657,7 @@ void convert_to_sparsity_csr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_SPARSITY_CSR_KERNEL); @@ -681,7 +681,7 @@ void compute_dot_dispatch(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_DOT_DISPATCH_KERNEL); @@ -706,7 +706,7 @@ void compute_conj_dot_dispatch(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_CONJ_DOT_DISPATCH_KERNEL); @@ -729,7 +729,7 @@ void compute_norm2_dispatch(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_NORM2_DISPATCH_KERNEL); @@ -760,7 +760,8 @@ void simple_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_SIMPLE_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_SIMPLE_APPLY_KERNEL); template @@ -787,7 +788,7 @@ 
void apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_APPLY_KERNEL); template @@ -812,7 +813,8 @@ void transpose(std::shared_ptr exec, } }; -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_TRANSPOSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_TRANSPOSE_KERNEL); template @@ -837,7 +839,8 @@ void conj_transpose(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL); } // namespace dense diff --git a/common/cuda_hip/matrix/diagonal_kernels.cpp b/common/cuda_hip/matrix/diagonal_kernels.cpp index e12d3ed4f9f..baee454c36d 100644 --- a/common/cuda_hip/matrix/diagonal_kernels.cpp +++ b/common/cuda_hip/matrix/diagonal_kernels.cpp @@ -81,7 +81,7 @@ void apply_to_csr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DIAGONAL_APPLY_TO_CSR_KERNEL); diff --git a/common/cuda_hip/matrix/ell_kernels.cpp b/common/cuda_hip/matrix/ell_kernels.cpp index bfdd3f21e51..16371166662 100644 --- a/common/cuda_hip/matrix/ell_kernels.cpp +++ b/common/cuda_hip/matrix/ell_kernels.cpp @@ -354,7 +354,7 @@ void spmv(std::shared_ptr exec, b, c); } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_SPMV_KERNEL); @@ -388,7 +388,7 @@ void advanced_spmv(std::shared_ptr exec, b, c, alpha, beta); } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_ADVANCED_SPMV_KERNEL); diff --git a/common/cuda_hip/matrix/fbcsr_kernels.instantiate.cpp b/common/cuda_hip/matrix/fbcsr_kernels.instantiate.cpp index a3beaac4a85..a7a0263cd35 100644 --- 
a/common/cuda_hip/matrix/fbcsr_kernels.instantiate.cpp +++ b/common/cuda_hip/matrix/fbcsr_kernels.instantiate.cpp @@ -17,26 +17,27 @@ namespace fbcsr { // begin -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_FILL_IN_MATRIX_DATA_KERNEL); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_FILL_IN_DENSE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_CONVERT_TO_CSR_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_IS_SORTED_BY_COLUMN_INDEX); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_SORT_BY_COLUMN_INDEX); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_EXTRACT_DIAGONAL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_SPMV_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_FBCSR_SPMV_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_ADVANCED_SPMV_KERNEL); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_TRANSPOSE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_CONJ_TRANSPOSE_KERNEL); // end diff --git a/common/cuda_hip/matrix/fbcsr_kernels.template.cpp b/common/cuda_hip/matrix/fbcsr_kernels.template.cpp index 23f865b6ace..e10cf10b540 100644 --- a/common/cuda_hip/matrix/fbcsr_kernels.template.cpp +++ b/common/cuda_hip/matrix/fbcsr_kernels.template.cpp @@ -564,7 +564,7 @@ void 
transpose_blocks_impl(syn::value_list, if (grid_dim > 0) { kernel::transpose_blocks <<get_stream()>>>( - nbnz, mat->get_values()); + nbnz, as_device_type(mat->get_values())); } } diff --git a/common/cuda_hip/matrix/sellp_kernels.cpp b/common/cuda_hip/matrix/sellp_kernels.cpp index 3e8fba395b3..4d37a0452a6 100644 --- a/common/cuda_hip/matrix/sellp_kernels.cpp +++ b/common/cuda_hip/matrix/sellp_kernels.cpp @@ -105,7 +105,8 @@ void spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SELLP_SPMV_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SELLP_SPMV_KERNEL); template @@ -131,7 +132,7 @@ void advanced_spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SELLP_ADVANCED_SPMV_KERNEL); diff --git a/common/cuda_hip/matrix/sparsity_csr_kernels.cpp b/common/cuda_hip/matrix/sparsity_csr_kernels.cpp index 269708e19ae..ddda357fa31 100644 --- a/common/cuda_hip/matrix/sparsity_csr_kernels.cpp +++ b/common/cuda_hip/matrix/sparsity_csr_kernels.cpp @@ -72,11 +72,11 @@ __device__ void device_classical_spmv(const size_type num_rows, const auto subrow = thread::get_subwarp_num_flat(); const auto subid = subwarp_tile.thread_rank(); const IndexType column_id = blockIdx.y; - const arithmetic_type value = val[0]; + const auto value = static_cast(val[0]); auto row = thread::get_subwarp_id_flat(); for (; row < num_rows; row += subrow) { const auto ind_end = row_ptrs[row + 1]; - arithmetic_type temp_val = zero(); + auto temp_val = zero(); for (auto ind = row_ptrs[row] + subid; ind < ind_end; ind += subwarp_size) { temp_val += value * b(col_idxs[ind], column_id); @@ -138,7 +138,7 @@ void transpose(std::shared_ptr exec, matrix::SparsityCsr* trans) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_TRANSPOSE_KERNEL); @@ -246,7 
+246,7 @@ void spmv(std::shared_ptr exec, syn::value_list(), syn::type_list<>(), exec, a, b, c); } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_SPMV_KERNEL); @@ -264,7 +264,7 @@ void advanced_spmv(std::shared_ptr exec, syn::value_list(), syn::type_list<>(), exec, a, b, c, alpha, beta); } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_ADVANCED_SPMV_KERNEL); @@ -297,7 +297,7 @@ void sort_by_column_index(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_SORT_BY_COLUMN_INDEX); @@ -320,7 +320,7 @@ void is_sorted_by_column_index( cpu_array = gpu_array; } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_IS_SORTED_BY_COLUMN_INDEX); diff --git a/common/unified/matrix/coo_kernels.cpp b/common/unified/matrix/coo_kernels.cpp index ce13d7500ab..233dffc6f37 100644 --- a/common/unified/matrix/coo_kernels.cpp +++ b/common/unified/matrix/coo_kernels.cpp @@ -38,7 +38,7 @@ void extract_diagonal(std::shared_ptr exec, diag->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_COO_EXTRACT_DIAGONAL_KERNEL); @@ -58,7 +58,7 @@ void fill_in_dense(std::shared_ptr exec, orig->get_const_row_idxs(), orig->get_const_col_idxs(), result); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_COO_FILL_IN_DENSE_KERNEL); diff --git a/common/unified/matrix/csr_kernels.cpp b/common/unified/matrix/csr_kernels.cpp index 5236c1c9da9..d5741bb3e1c 100644 --- a/common/unified/matrix/csr_kernels.cpp +++ b/common/unified/matrix/csr_kernels.cpp @@ -52,7 +52,7 @@ void 
inv_col_permute(std::shared_ptr exec, col_permuted->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_COL_PERMUTE_KERNEL); @@ -86,7 +86,7 @@ void inv_col_scale_permute(std::shared_ptr exec, col_permuted->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_COL_SCALE_PERMUTE_KERNEL); @@ -102,7 +102,8 @@ void scale(std::shared_ptr exec, x->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SCALE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_SCALE_KERNEL); template @@ -117,7 +118,8 @@ void inv_scale(std::shared_ptr exec, x->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_INV_SCALE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_INV_SCALE_KERNEL); template @@ -152,7 +154,7 @@ void convert_to_sellp(std::shared_ptr exec, output->get_col_idxs(), output->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CONVERT_TO_SELLP_KERNEL); @@ -183,7 +185,7 @@ void convert_to_ell(std::shared_ptr exec, output->get_col_idxs(), output->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CONVERT_TO_ELL_KERNEL); @@ -227,7 +229,7 @@ void convert_to_hybrid(std::shared_ptr exec, result->get_coo_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CONVERT_TO_HYBRID_KERNEL); diff --git a/common/unified/matrix/dense_kernels.instantiate.cpp b/common/unified/matrix/dense_kernels.instantiate.cpp index aca8ad5bec4..dcf48573fc6 100644 --- a/common/unified/matrix/dense_kernels.instantiate.cpp +++ 
b/common/unified/matrix/dense_kernels.instantiate.cpp @@ -12,87 +12,99 @@ namespace dense { // begin -GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION_OR_COPY( +GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION_OR_COPY_WITH_HALF( GKO_DECLARE_DENSE_COPY_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_FILL_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_FILL_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_FILL_IN_MATRIX_DATA_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_DENSE_SCALE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_SCALE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_SCALE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE_WITH_HALF( GKO_DECLARE_DENSE_ADD_SCALED_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE_WITH_HALF( GKO_DECLARE_DENSE_SUB_SCALED_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_ADD_SCALED_DIAG_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_SUB_SCALED_DIAG_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_SQRT_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_ADD_SCALED_DIAG_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_SUB_SCALED_DIAG_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_COMPUTE_SQRT_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_SYMM_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_SYMM_PERMUTE_KERNEL); 
-GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_NONSYMM_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_NONSYMM_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2_WITH_HALF( GKO_DECLARE_DENSE_ROW_GATHER_KERNEL); -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2_WITH_HALF( GKO_DECLARE_DENSE_ADVANCED_ROW_GATHER_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COL_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_ROW_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_COL_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_SYMM_SCALE_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_SYMM_SCALE_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_NONSYMM_SCALE_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_NONSYMM_SCALE_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_ROW_SCALE_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( 
GKO_DECLARE_DENSE_INV_ROW_SCALE_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COL_SCALE_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_COL_SCALE_PERMUTE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_EXTRACT_DIAGONAL_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_INPLACE_ABSOLUTE_DENSE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_OUTPLACE_ABSOLUTE_DENSE_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_MAKE_COMPLEX_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_GET_REAL_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_GET_IMAG_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_EXTRACT_DIAGONAL_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_INPLACE_ABSOLUTE_DENSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_OUTPLACE_ABSOLUTE_DENSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_MAKE_COMPLEX_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_GET_REAL_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_GET_IMAG_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE_WITH_HALF( GKO_DECLARE_DENSE_ADD_SCALED_IDENTITY_KERNEL); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_DOT_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_COMPUTE_DOT_KERNEL); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_CONJ_DOT_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_COMPUTE_CONJ_DOT_KERNEL); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_NORM2_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + 
GKO_DECLARE_DENSE_COMPUTE_NORM2_KERNEL); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_NORM1_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_COMPUTE_NORM1_KERNEL); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_MAX_NNZ_PER_ROW_KERNEL); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_SLICE_SETS_KERNEL); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COUNT_NONZEROS_PER_ROW_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COUNT_NONZEROS_PER_ROW_KERNEL_SIZE_T); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_SQUARED_NORM2_KERNEL); // split -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_MEAN_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_COMPUTE_MEAN_KERNEL); // end diff --git a/common/unified/matrix/diagonal_kernels.cpp b/common/unified/matrix/diagonal_kernels.cpp index dae037a5134..75960e800d7 100644 --- a/common/unified/matrix/diagonal_kernels.cpp +++ b/common/unified/matrix/diagonal_kernels.cpp @@ -36,7 +36,8 @@ void apply_to_dense(std::shared_ptr exec, b->get_size(), a->get_const_values(), b, c, inverse); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DIAGONAL_APPLY_TO_DENSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DIAGONAL_APPLY_TO_DENSE_KERNEL); template @@ -53,7 +54,7 @@ void right_apply_to_dense(std::shared_ptr exec, b->get_size(), a->get_const_values(), b, c); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DIAGONAL_RIGHT_APPLY_TO_DENSE_KERNEL); @@ -74,7 +75,7 @@ void right_apply_to_csr(std::shared_ptr exec, 
c->get_const_col_idxs()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DIAGONAL_RIGHT_APPLY_TO_CSR_KERNEL); @@ -95,7 +96,7 @@ void fill_in_matrix_data(std::shared_ptr exec, output->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DIAGONAL_FILL_IN_MATRIX_DATA_KERNEL); @@ -120,7 +121,7 @@ void convert_to_csr(std::shared_ptr exec, result->get_col_idxs(), result->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DIAGONAL_CONVERT_TO_CSR_KERNEL); @@ -137,7 +138,8 @@ void conj_transpose(std::shared_ptr exec, orig->get_size()[0], orig->get_const_values(), trans->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DIAGONAL_CONJ_TRANSPOSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DIAGONAL_CONJ_TRANSPOSE_KERNEL); } // namespace diagonal diff --git a/common/unified/matrix/ell_kernels.cpp b/common/unified/matrix/ell_kernels.cpp index 6d23e08b68b..24fc90a888e 100644 --- a/common/unified/matrix/ell_kernels.cpp +++ b/common/unified/matrix/ell_kernels.cpp @@ -67,7 +67,7 @@ void fill_in_matrix_data(std::shared_ptr exec, output->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_FILL_IN_MATRIX_DATA_KERNEL); @@ -94,7 +94,7 @@ void fill_in_dense(std::shared_ptr exec, source->get_const_values(), result); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_FILL_IN_DENSE_KERNEL); @@ -121,7 +121,8 @@ void copy(std::shared_ptr exec, result->get_col_idxs(), result->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_ELL_COPY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_ELL_COPY_KERNEL); 
template @@ -150,7 +151,7 @@ void convert_to_csr(std::shared_ptr exec, result->get_col_idxs(), result->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_CONVERT_TO_CSR_KERNEL); @@ -172,7 +173,7 @@ void count_nonzeros_per_row(std::shared_ptr exec, static_cast(source->get_stride()), source->get_const_col_idxs()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_COUNT_NONZEROS_PER_ROW_KERNEL); @@ -198,7 +199,7 @@ void extract_diagonal(std::shared_ptr exec, orig->get_const_values(), diag->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_EXTRACT_DIAGONAL_KERNEL); diff --git a/common/unified/matrix/hybrid_kernels.cpp b/common/unified/matrix/hybrid_kernels.cpp index 8a21a2415f7..79a596febea 100644 --- a/common/unified/matrix/hybrid_kernels.cpp +++ b/common/unified/matrix/hybrid_kernels.cpp @@ -89,7 +89,7 @@ void fill_in_matrix_data(std::shared_ptr exec, result->get_coo_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_HYBRID_FILL_IN_MATRIX_DATA_KERNEL); @@ -150,7 +150,7 @@ void convert_to_csr(std::shared_ptr exec, coo_row_ptrs, result->get_col_idxs(), result->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_HYBRID_CONVERT_TO_CSR_KERNEL); diff --git a/common/unified/matrix/scaled_permutation_kernels.cpp b/common/unified/matrix/scaled_permutation_kernels.cpp index 3eaab65e8e6..4cdc7974e50 100644 --- a/common/unified/matrix/scaled_permutation_kernels.cpp +++ b/common/unified/matrix/scaled_permutation_kernels.cpp @@ -32,7 +32,7 @@ void invert(std::shared_ptr exec, size, input_scale, input_permutation, output_scale, output_permutation); } 
-GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SCALED_PERMUTATION_INVERT_KERNEL); @@ -58,7 +58,7 @@ void compose(std::shared_ptr exec, output_permutation, output_scale); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SCALED_PERMUTATION_COMPOSE_KERNEL); diff --git a/common/unified/matrix/sellp_kernels.cpp b/common/unified/matrix/sellp_kernels.cpp index 93b71ff43f2..23bfe160a69 100644 --- a/common/unified/matrix/sellp_kernels.cpp +++ b/common/unified/matrix/sellp_kernels.cpp @@ -87,7 +87,7 @@ void fill_in_matrix_data(std::shared_ptr exec, output->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SELLP_FILL_IN_MATRIX_DATA_KERNEL); @@ -119,7 +119,7 @@ void fill_in_dense(std::shared_ptr exec, source->get_const_values(), result); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SELLP_FILL_IN_DENSE_KERNEL); @@ -149,7 +149,7 @@ void count_nonzeros_per_row(std::shared_ptr exec, source->get_const_slice_sets(), source->get_const_col_idxs(), result); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SELLP_COUNT_NONZEROS_PER_ROW_KERNEL); @@ -183,7 +183,7 @@ void convert_to_csr(std::shared_ptr exec, result->get_col_idxs(), result->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SELLP_CONVERT_TO_CSR_KERNEL); @@ -215,7 +215,7 @@ void extract_diagonal(std::shared_ptr exec, orig->get_const_values(), diag->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SELLP_EXTRACT_DIAGONAL_KERNEL); diff --git a/common/unified/matrix/sparsity_csr_kernels.cpp 
b/common/unified/matrix/sparsity_csr_kernels.cpp index c5a9c79a89b..b3f26358ad3 100644 --- a/common/unified/matrix/sparsity_csr_kernels.cpp +++ b/common/unified/matrix/sparsity_csr_kernels.cpp @@ -41,7 +41,7 @@ void fill_in_dense(std::shared_ptr exec, input->get_const_col_idxs(), input->get_const_value(), output); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_FILL_IN_DENSE_KERNEL); @@ -70,7 +70,7 @@ void diagonal_element_prefix_sum( components::prefix_sum_nonnegative(exec, prefix_sum, num_rows + 1); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_DIAGONAL_ELEMENT_PREFIX_SUM_KERNEL); @@ -106,7 +106,7 @@ void remove_diagonal_elements(std::shared_ptr exec, matrix->get_col_idxs()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_REMOVE_DIAGONAL_ELEMENTS_KERNEL); diff --git a/core/device_hooks/common_kernels.inc.cpp b/core/device_hooks/common_kernels.inc.cpp index 224aacc7369..78b80ec2859 100644 --- a/core/device_hooks/common_kernels.inc.cpp +++ b/core/device_hooks/common_kernels.inc.cpp @@ -411,69 +411,93 @@ GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL); namespace dense { -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_SIMPLE_APPLY_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_APPLY_KERNEL); -GKO_STUB_VALUE_CONVERSION_OR_COPY(GKO_DECLARE_DENSE_COPY_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_FILL_KERNEL); -GKO_STUB_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_DENSE_SCALE_KERNEL); -GKO_STUB_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_DENSE_INV_SCALE_KERNEL); -GKO_STUB_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_DENSE_ADD_SCALED_KERNEL); -GKO_STUB_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_DENSE_SUB_SCALED_KERNEL); -GKO_STUB_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_DENSE_ADD_SCALED_IDENTITY_KERNEL); 
-GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_ADD_SCALED_DIAG_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_SUB_SCALED_DIAG_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_DOT_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_DOT_DISPATCH_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_CONJ_DOT_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_CONJ_DOT_DISPATCH_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_NORM2_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_NORM2_DISPATCH_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_NORM1_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_MEAN_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_SQUARED_NORM2_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_SQRT_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_FILL_IN_MATRIX_DATA_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_CONVERT_TO_COO_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_CONVERT_TO_CSR_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_CONVERT_TO_ELL_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_CONVERT_TO_FBCSR_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_CONVERT_TO_HYBRID_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_CONVERT_TO_SELLP_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_CONVERT_TO_SPARSITY_CSR_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_MAX_NNZ_PER_ROW_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_SLICE_SETS_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_COUNT_NONZEROS_PER_ROW_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_SIMPLE_APPLY_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_APPLY_KERNEL); +GKO_STUB_VALUE_CONVERSION_OR_COPY_WITH_HALF(GKO_DECLARE_DENSE_COPY_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_FILL_KERNEL); +GKO_STUB_VALUE_AND_SCALAR_TYPE_WITH_HALF(GKO_DECLARE_DENSE_SCALE_KERNEL); 
+GKO_STUB_VALUE_AND_SCALAR_TYPE_WITH_HALF(GKO_DECLARE_DENSE_INV_SCALE_KERNEL); +GKO_STUB_VALUE_AND_SCALAR_TYPE_WITH_HALF(GKO_DECLARE_DENSE_ADD_SCALED_KERNEL); +GKO_STUB_VALUE_AND_SCALAR_TYPE_WITH_HALF(GKO_DECLARE_DENSE_SUB_SCALED_KERNEL); +GKO_STUB_VALUE_AND_SCALAR_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_ADD_SCALED_IDENTITY_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_ADD_SCALED_DIAG_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_SUB_SCALED_DIAG_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_COMPUTE_DOT_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_COMPUTE_DOT_DISPATCH_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_COMPUTE_CONJ_DOT_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_COMPUTE_CONJ_DOT_DISPATCH_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_COMPUTE_NORM2_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_COMPUTE_NORM2_DISPATCH_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_COMPUTE_NORM1_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_COMPUTE_MEAN_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_COMPUTE_SQUARED_NORM2_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_COMPUTE_SQRT_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_FILL_IN_MATRIX_DATA_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_CONVERT_TO_COO_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_CONVERT_TO_CSR_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_CONVERT_TO_ELL_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_CONVERT_TO_FBCSR_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_CONVERT_TO_HYBRID_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_CONVERT_TO_SELLP_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_CONVERT_TO_SPARSITY_CSR_KERNEL); 
+GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_COMPUTE_MAX_NNZ_PER_ROW_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_COMPUTE_SLICE_SETS_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_COUNT_NONZEROS_PER_ROW_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COUNT_NONZEROS_PER_ROW_KERNEL_SIZE_T); -GKO_STUB_VALUE_AND_INDEX_TYPE( +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COUNT_NONZERO_BLOCKS_PER_ROW_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_TRANSPOSE_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_SYMM_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_INV_SYMM_PERMUTE_KERNEL); -GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_2(GKO_DECLARE_DENSE_ROW_GATHER_KERNEL); -GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_2( +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_TRANSPOSE_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_DENSE_SYMM_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_INV_SYMM_PERMUTE_KERNEL); +GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_2_WITH_HALF( + GKO_DECLARE_DENSE_ROW_GATHER_KERNEL); +GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_2_WITH_HALF( GKO_DECLARE_DENSE_ADVANCED_ROW_GATHER_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_COL_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_INV_ROW_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_INV_COL_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_NONSYMM_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_INV_NONSYMM_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_SYMM_SCALE_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_INV_SYMM_SCALE_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_ROW_SCALE_PERMUTE_KERNEL); 
-GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_COL_SCALE_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_INV_ROW_SCALE_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_INV_COL_SCALE_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DENSE_NONSYMM_SCALE_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE( +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_DENSE_COL_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_INV_ROW_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_INV_COL_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_NONSYMM_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_INV_NONSYMM_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_SYMM_SCALE_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_INV_SYMM_SCALE_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_ROW_SCALE_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_COL_SCALE_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_INV_ROW_SCALE_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_INV_COL_SCALE_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_NONSYMM_SCALE_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_NONSYMM_SCALE_PERMUTE_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DENSE_EXTRACT_DIAGONAL_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_INPLACE_ABSOLUTE_DENSE_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_OUTPLACE_ABSOLUTE_DENSE_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_MAKE_COMPLEX_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_GET_REAL_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_GET_IMAG_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_EXTRACT_DIAGONAL_KERNEL); 
+GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_INPLACE_ABSOLUTE_DENSE_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_OUTPLACE_ABSOLUTE_DENSE_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_MAKE_COMPLEX_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_GET_REAL_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_GET_IMAG_KERNEL); } // namespace dense @@ -482,13 +506,17 @@ GKO_STUB_VALUE_TYPE(GKO_DECLARE_GET_IMAG_KERNEL); namespace diagonal { -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DIAGONAL_APPLY_TO_DENSE_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DIAGONAL_RIGHT_APPLY_TO_DENSE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DIAGONAL_APPLY_TO_CSR_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DIAGONAL_RIGHT_APPLY_TO_CSR_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DIAGONAL_CONVERT_TO_CSR_KERNEL); -GKO_STUB_VALUE_TYPE(GKO_DECLARE_DIAGONAL_CONJ_TRANSPOSE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DIAGONAL_FILL_IN_MATRIX_DATA_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DIAGONAL_APPLY_TO_DENSE_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DIAGONAL_RIGHT_APPLY_TO_DENSE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DIAGONAL_APPLY_TO_CSR_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DIAGONAL_RIGHT_APPLY_TO_CSR_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DIAGONAL_CONVERT_TO_CSR_KERNEL); +GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DIAGONAL_CONJ_TRANSPOSE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DIAGONAL_FILL_IN_MATRIX_DATA_KERNEL); } // namespace diagonal @@ -675,17 +703,21 @@ GKO_STUB_NON_COMPLEX_VALUE_TYPE(GKO_DECLARE_MULTIGRID_KCYCLE_CHECK_STOP_KERNEL); namespace sparsity_csr { -GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SPARSITY_CSR_SPMV_KERNEL); -GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE( +GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SPARSITY_CSR_SPMV_KERNEL); +GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( 
GKO_DECLARE_SPARSITY_CSR_ADVANCED_SPMV_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SPARSITY_CSR_FILL_IN_DENSE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE( +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SPARSITY_CSR_FILL_IN_DENSE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_DIAGONAL_ELEMENT_PREFIX_SUM_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE( +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_REMOVE_DIAGONAL_ELEMENTS_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SPARSITY_CSR_TRANSPOSE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SPARSITY_CSR_SORT_BY_COLUMN_INDEX); -GKO_STUB_VALUE_AND_INDEX_TYPE( +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SPARSITY_CSR_TRANSPOSE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SPARSITY_CSR_SORT_BY_COLUMN_INDEX); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_IS_SORTED_BY_COLUMN_INDEX); @@ -695,38 +727,54 @@ GKO_STUB_VALUE_AND_INDEX_TYPE( namespace csr { -GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPMV_KERNEL); -GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEMM_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_ADVANCED_SPGEMM_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEAM_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_FILL_IN_DENSE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_CONVERT_TO_ELL_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_CONVERT_TO_FBCSR_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_CONVERT_TO_HYBRID_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_CONVERT_TO_SELLP_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_TRANSPOSE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_CONJ_TRANSPOSE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_INV_NONSYMM_PERMUTE_KERNEL); 
-GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_INV_SYMM_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_INV_COL_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_INV_ROW_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_INV_NONSYMM_SCALE_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_INV_SYMM_SCALE_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_ROW_SCALE_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_INV_COL_SCALE_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_INV_ROW_SCALE_PERMUTE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SORT_BY_COLUMN_INDEX); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_IS_SORTED_BY_COLUMN_INDEX); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_EXTRACT_DIAGONAL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_SPAN_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE( +GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_CSR_SPMV_KERNEL); +GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_CSR_SPGEMM_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_CSR_ADVANCED_SPGEMM_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_CSR_SPGEAM_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_CSR_FILL_IN_DENSE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_CSR_CONVERT_TO_ELL_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_CONVERT_TO_FBCSR_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_CONVERT_TO_HYBRID_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_CONVERT_TO_SELLP_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_CSR_TRANSPOSE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_CSR_CONJ_TRANSPOSE_KERNEL); 
+GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_INV_NONSYMM_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_INV_SYMM_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_CSR_INV_COL_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_CSR_INV_ROW_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_INV_NONSYMM_SCALE_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_INV_SYMM_SCALE_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_ROW_SCALE_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_INV_COL_SCALE_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_INV_ROW_SCALE_PERMUTE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_CSR_SORT_BY_COLUMN_INDEX); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_IS_SORTED_BY_COLUMN_INDEX); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_CSR_EXTRACT_DIAGONAL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_SPAN_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_INDEX_SET_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_CHECK_DIAGONAL_ENTRIES_EXIST); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_ADD_SCALED_IDENTITY_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE( +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_CHECK_DIAGONAL_ENTRIES_EXIST); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_ADD_SCALED_IDENTITY_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_FROM_INDEX_SET_KERNEL); 
GKO_STUB_INDEX_TYPE(GKO_DECLARE_CSR_BUILD_LOOKUP_OFFSETS_KERNEL); GKO_STUB_INDEX_TYPE(GKO_DECLARE_CSR_BUILD_LOOKUP_KERNEL); @@ -735,12 +783,14 @@ GKO_STUB_INDEX_TYPE(GKO_DECLARE_CSR_BENCHMARK_LOOKUP_KERNEL); template GKO_DECLARE_CSR_SCALE_KERNEL(ValueType, IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SCALE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_SCALE_KERNEL); template GKO_DECLARE_CSR_INV_SCALE_KERNEL(ValueType, IndexType) GKO_NOT_COMPILED(GKO_HOOK_MODULE); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_INV_SCALE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_INV_SCALE_KERNEL); } // namespace csr @@ -749,16 +799,20 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_INV_SCALE_KERNEL); namespace fbcsr { -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_SPMV_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_ADVANCED_SPMV_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_FILL_IN_MATRIX_DATA_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_FILL_IN_DENSE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_CONVERT_TO_CSR_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_TRANSPOSE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_CONJ_TRANSPOSE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_IS_SORTED_BY_COLUMN_INDEX); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_SORT_BY_COLUMN_INDEX); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_EXTRACT_DIAGONAL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_FBCSR_SPMV_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_FBCSR_ADVANCED_SPMV_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_FBCSR_FILL_IN_MATRIX_DATA_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_FBCSR_FILL_IN_DENSE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + 
GKO_DECLARE_FBCSR_CONVERT_TO_CSR_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_FBCSR_TRANSPOSE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_FBCSR_CONJ_TRANSPOSE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_FBCSR_IS_SORTED_BY_COLUMN_INDEX); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_FBCSR_SORT_BY_COLUMN_INDEX); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_FBCSR_EXTRACT_DIAGONAL); } // namespace fbcsr @@ -767,12 +821,13 @@ GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_EXTRACT_DIAGONAL); namespace coo { -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_SPMV_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_ADVANCED_SPMV_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_SPMV2_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_ADVANCED_SPMV2_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_FILL_IN_DENSE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_EXTRACT_DIAGONAL_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_COO_SPMV_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_COO_ADVANCED_SPMV_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_COO_SPMV2_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_COO_ADVANCED_SPMV2_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_COO_FILL_IN_DENSE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_COO_EXTRACT_DIAGONAL_KERNEL); } // namespace coo @@ -781,15 +836,19 @@ GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_EXTRACT_DIAGONAL_KERNEL); namespace ell { -GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE(GKO_DECLARE_ELL_SPMV_KERNEL); -GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE(GKO_DECLARE_ELL_ADVANCED_SPMV_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_ELL_FILL_IN_MATRIX_DATA_KERNEL); +GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_ELL_SPMV_KERNEL); +GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_ELL_ADVANCED_SPMV_KERNEL); 
+GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_ELL_FILL_IN_MATRIX_DATA_KERNEL); GKO_STUB_INDEX_TYPE(GKO_DECLARE_ELL_COMPUTE_MAX_ROW_NNZ_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_ELL_FILL_IN_DENSE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_ELL_COPY_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_ELL_CONVERT_TO_CSR_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_ELL_COUNT_NONZEROS_PER_ROW_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_ELL_EXTRACT_DIAGONAL_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_ELL_FILL_IN_DENSE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_ELL_COPY_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_ELL_CONVERT_TO_CSR_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_ELL_COUNT_NONZEROS_PER_ROW_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_ELL_EXTRACT_DIAGONAL_KERNEL); } // namespace ell @@ -822,8 +881,10 @@ namespace hybrid { GKO_STUB(GKO_DECLARE_HYBRID_COMPUTE_COO_ROW_PTRS_KERNEL); GKO_STUB(GKO_DECLARE_HYBRID_COMPUTE_ROW_NNZ); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_HYBRID_FILL_IN_MATRIX_DATA_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_HYBRID_CONVERT_TO_CSR_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_HYBRID_FILL_IN_MATRIX_DATA_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_HYBRID_CONVERT_TO_CSR_KERNEL); } // namespace hybrid @@ -842,8 +903,10 @@ GKO_STUB_INDEX_TYPE(GKO_DECLARE_PERMUTATION_COMPOSE_KERNEL); namespace scaled_permutation { -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SCALED_PERMUTATION_INVERT_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SCALED_PERMUTATION_COMPOSE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SCALED_PERMUTATION_INVERT_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SCALED_PERMUTATION_COMPOSE_KERNEL); } // namespace scaled_permutation @@ -852,14 +915,18 @@ 
GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SCALED_PERMUTATION_COMPOSE_KERNEL); namespace sellp { -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SELLP_SPMV_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SELLP_FILL_IN_MATRIX_DATA_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_SELLP_SPMV_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SELLP_FILL_IN_MATRIX_DATA_KERNEL); GKO_STUB_INDEX_TYPE(GKO_DECLARE_SELLP_COMPUTE_SLICE_SETS_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SELLP_ADVANCED_SPMV_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SELLP_FILL_IN_DENSE_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SELLP_CONVERT_TO_CSR_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SELLP_COUNT_NONZEROS_PER_ROW_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SELLP_EXTRACT_DIAGONAL_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_SELLP_ADVANCED_SPMV_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_SELLP_FILL_IN_DENSE_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SELLP_CONVERT_TO_CSR_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SELLP_COUNT_NONZEROS_PER_ROW_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SELLP_EXTRACT_DIAGONAL_KERNEL); } // namespace sellp diff --git a/core/matrix/coo.cpp b/core/matrix/coo.cpp index 1368dc261c3..7b3b3876295 100644 --- a/core/matrix/coo.cpp +++ b/core/matrix/coo.cpp @@ -214,7 +214,7 @@ void Coo::apply2_impl(const LinOp* alpha, const LinOp* b, template void Coo::convert_to( - Coo, IndexType>* result) const + Coo, IndexType>* result) const { result->values_ = this->values_; result->row_idxs_ = this->row_idxs_; @@ -225,12 +225,35 @@ void Coo::convert_to( template void Coo::move_to( - Coo, IndexType>* result) + Coo, IndexType>* result) { this->convert_to(result); } +#if GINKGO_ENABLE_HALF +template +void Coo::convert_to( + Coo>, + IndexType>* result) const +{ + result->values_ = this->values_; + result->row_idxs_ = 
this->row_idxs_; + result->col_idxs_ = this->col_idxs_; + result->set_size(this->get_size()); +} + + +template +void Coo::move_to( + Coo>, + IndexType>* result) +{ + this->convert_to(result); +} +#endif + + template void Coo::convert_to( Csr* result) const @@ -404,7 +427,7 @@ Coo::compute_absolute() const #define GKO_DECLARE_COO_MATRIX(ValueType, IndexType) \ class Coo -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_MATRIX); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_COO_MATRIX); } // namespace matrix diff --git a/core/matrix/csr.cpp b/core/matrix/csr.cpp index 897eb1a48db..1bb3e778478 100644 --- a/core/matrix/csr.cpp +++ b/core/matrix/csr.cpp @@ -304,7 +304,7 @@ void Csr::apply_impl(const LinOp* alpha, const LinOp* b, template void Csr::convert_to( - Csr, IndexType>* result) const + Csr, IndexType>* result) const { result->values_ = this->values_; result->col_idxs_ = this->col_idxs_; @@ -316,11 +316,34 @@ void Csr::convert_to( template void Csr::move_to( - Csr, IndexType>* result) + Csr, IndexType>* result) { this->convert_to(result); } +#if GINKGO_ENABLE_HALF +template +void Csr::convert_to( + Csr>, + IndexType>* result) const +{ + result->values_ = this->values_; + result->col_idxs_ = this->col_idxs_; + result->row_ptrs_ = this->row_ptrs_; + result->set_size(this->get_size()); + convert_strategy_helper(result); +} + + +template +void Csr::move_to( + Csr>, + IndexType>* result) +{ + this->convert_to(result); +} +#endif + template void Csr::convert_to( @@ -1047,7 +1070,7 @@ void Csr::add_scaled_identity_impl(const LinOp* a, #define GKO_DECLARE_CSR_MATRIX(ValueType, IndexType) \ class Csr -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_MATRIX); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_CSR_MATRIX); } // namespace matrix diff --git a/core/matrix/dense.cpp b/core/matrix/dense.cpp index 367b0232969..071e689232e 100644 --- a/core/matrix/dense.cpp +++ b/core/matrix/dense.cpp @@ -582,7 
+582,7 @@ Dense::Dense(Dense&& other) : Dense(other.get_executor()) template void Dense::convert_to( - Dense>* result) const + Dense>* result) const { if (result->get_size() != this->get_size()) { result->set_size(this->get_size()); @@ -597,12 +597,41 @@ void Dense::convert_to( template -void Dense::move_to(Dense>* result) +void Dense::move_to( + Dense>* result) { this->convert_to(result); } +#if GINKGO_ENABLE_HALF +template +void Dense::convert_to( + Dense>>* + result) const +{ + if (result->get_size() != this->get_size()) { + result->set_size(this->get_size()); + result->stride_ = stride_; + result->values_.resize_and_reset(result->get_size()[0] * + result->stride_); + } + auto exec = this->get_executor(); + exec->run(dense::make_copy( + this, make_temporary_output_clone(exec, result).get())); +} + + +template +void Dense::move_to( + Dense>>* + result) +{ + this->convert_to(result); +} +#endif + + template template void Dense::convert_impl(Coo* result) const @@ -1519,7 +1548,8 @@ template void gather_mixed_real_complex(Function fn, LinOp* out) { #ifdef GINKGO_MIXED_PRECISION - run>(out, fn); + run, + next_precision_with_half>>(out, fn); #else precision_dispatch(fn, out); #endif @@ -2029,7 +2059,7 @@ Dense::Dense(std::shared_ptr exec, #define GKO_DECLARE_DENSE_MATRIX(_type) class Dense<_type> -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_MATRIX); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_MATRIX); } // namespace matrix diff --git a/core/matrix/diagonal.cpp b/core/matrix/diagonal.cpp index 1a442ffc789..85c5739b529 100644 --- a/core/matrix/diagonal.cpp +++ b/core/matrix/diagonal.cpp @@ -149,7 +149,7 @@ std::unique_ptr Diagonal::conj_transpose() const template void Diagonal::convert_to( - Diagonal>* result) const + Diagonal>* result) const { result->values_ = this->values_; result->set_size(this->get_size()); @@ -157,12 +157,34 @@ void Diagonal::convert_to( template -void Diagonal::move_to(Diagonal>* result) +void Diagonal::move_to( + 
Diagonal>* result) { this->convert_to(result); } +#if GINKGO_ENABLE_HALF +template +void Diagonal::convert_to( + Diagonal>>* + result) const +{ + result->values_ = this->values_; + result->set_size(this->get_size()); +} + + +template +void Diagonal::move_to( + Diagonal>>* + result) +{ + this->convert_to(result); +} +#endif + + template void Diagonal::convert_to(Csr* result) const { @@ -373,7 +395,7 @@ std::unique_ptr> Diagonal::create_const( #define GKO_DECLARE_DIAGONAL_MATRIX(value_type) class Diagonal -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DIAGONAL_MATRIX); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DIAGONAL_MATRIX); } // namespace matrix @@ -391,7 +413,7 @@ std::unique_ptr DiagonalExtractable::extract_diagonal_linop() #define GKO_DECLARE_DIAGONAL_EXTRACTABLE(value_type) \ std::unique_ptr \ DiagonalExtractable::extract_diagonal_linop() const -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DIAGONAL_EXTRACTABLE); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DIAGONAL_EXTRACTABLE); } // namespace gko diff --git a/core/matrix/ell.cpp b/core/matrix/ell.cpp index 600c2ceb9d2..eafd9fa9cad 100644 --- a/core/matrix/ell.cpp +++ b/core/matrix/ell.cpp @@ -154,7 +154,7 @@ void Ell::apply_impl(const LinOp* alpha, const LinOp* b, template void Ell::convert_to( - Ell, IndexType>* result) const + Ell, IndexType>* result) const { result->values_ = this->values_; result->col_idxs_ = this->col_idxs_; @@ -166,12 +166,36 @@ void Ell::convert_to( template void Ell::move_to( - Ell, IndexType>* result) + Ell, IndexType>* result) { this->convert_to(result); } +#if GINKGO_ENABLE_HALF +template +void Ell::convert_to( + Ell>, + IndexType>* result) const +{ + result->values_ = this->values_; + result->col_idxs_ = this->col_idxs_; + result->num_stored_elements_per_row_ = this->num_stored_elements_per_row_; + result->stride_ = this->stride_; + result->set_size(this->get_size()); +} + + +template +void Ell::move_to( + Ell>, + IndexType>* result) +{ + 
this->convert_to(result); +} +#endif + + template void Ell::convert_to(Dense* result) const { @@ -401,7 +425,7 @@ Ell::Ell(std::shared_ptr exec, #define GKO_DECLARE_ELL_MATRIX(ValueType, IndexType) \ class Ell -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_ELL_MATRIX); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_ELL_MATRIX); } // namespace matrix diff --git a/core/matrix/fbcsr.cpp b/core/matrix/fbcsr.cpp index 8ed9b117280..f1612be10e0 100644 --- a/core/matrix/fbcsr.cpp +++ b/core/matrix/fbcsr.cpp @@ -145,7 +145,7 @@ void Fbcsr::apply_impl(const LinOp* alpha, const LinOp* b, template void Fbcsr::convert_to( - Fbcsr, IndexType>* result) const + Fbcsr, IndexType>* const result) const { result->values_ = this->values_; result->col_idxs_ = this->col_idxs_; @@ -158,12 +158,37 @@ void Fbcsr::convert_to( template void Fbcsr::move_to( - Fbcsr, IndexType>* result) + Fbcsr, IndexType>* const result) { this->convert_to(result); } +#if GINKGO_ENABLE_HALF +template +void Fbcsr::convert_to( + Fbcsr>, + IndexType>* const result) const +{ + result->values_ = this->values_; + result->col_idxs_ = this->col_idxs_; + result->row_ptrs_ = this->row_ptrs_; + result->set_size(this->get_size()); + // block sizes are immutable except for assignment/conversion + result->bs_ = this->bs_; +} + + +template +void Fbcsr::move_to( + Fbcsr>, + IndexType>* const result) +{ + this->convert_to(result); +} +#endif + + template void Fbcsr::convert_to(Dense* result) const { @@ -474,7 +499,8 @@ Fbcsr::Fbcsr(std::shared_ptr exec, #define GKO_DECLARE_FBCSR_MATRIX(ValueType, IndexType) \ class Fbcsr -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_MATRIX); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_FBCSR_MATRIX); } // namespace matrix diff --git a/core/matrix/hybrid.cpp b/core/matrix/hybrid.cpp index d450a0dfc35..72137558a10 100644 --- a/core/matrix/hybrid.cpp +++ b/core/matrix/hybrid.cpp @@ -203,7 +203,7 @@ void 
Hybrid::apply_impl(const LinOp* alpha, template void Hybrid::convert_to( - Hybrid, IndexType>* result) const + Hybrid, IndexType>* result) const { this->ell_->convert_to(result->ell_); this->coo_->convert_to(result->coo_); @@ -216,12 +216,37 @@ void Hybrid::convert_to( template void Hybrid::move_to( - Hybrid, IndexType>* result) + Hybrid, IndexType>* result) { this->convert_to(result); } +#if GINKGO_ENABLE_HALF +template +void Hybrid::convert_to( + Hybrid>, + IndexType>* result) const +{ + this->ell_->convert_to(result->ell_.get()); + this->coo_->convert_to(result->coo_.get()); + // TODO set strategy correctly + // There is no way to correctly clone the strategy like in + // Csr::convert_to + result->set_size(this->get_size()); +} + + +template +void Hybrid::move_to( + Hybrid>, + IndexType>* result) +{ + this->convert_to(result); +} +#endif + + template void Hybrid::convert_to(Dense* result) const { @@ -418,7 +443,8 @@ Hybrid::compute_absolute() const #define GKO_DECLARE_HYBRID_MATRIX(ValueType, IndexType) \ class Hybrid -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_HYBRID_MATRIX); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_HYBRID_MATRIX); } // namespace matrix diff --git a/core/matrix/identity.cpp b/core/matrix/identity.cpp index 7e035be82a3..ecd93b6f959 100644 --- a/core/matrix/identity.cpp +++ b/core/matrix/identity.cpp @@ -83,9 +83,9 @@ std::unique_ptr> Identity::create( #define GKO_DECLARE_IDENTITY_MATRIX(_type) class Identity<_type> -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_IDENTITY_MATRIX); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_IDENTITY_MATRIX); #define GKO_DECLARE_IDENTITY_FACTORY(_type) class IdentityFactory<_type> -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_IDENTITY_FACTORY); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_IDENTITY_FACTORY); } // namespace matrix diff --git a/core/matrix/permutation.cpp b/core/matrix/permutation.cpp index 0fe7ba2b2ce..b6b9ff2d7e4 100644 
--- a/core/matrix/permutation.cpp +++ b/core/matrix/permutation.cpp @@ -267,8 +267,11 @@ void dispatch_dense(const LinOp* op, Functor fn) { using matrix::Dense; using std::complex; - run, std::complex>(op, - fn); + run, +#endif + double, float, std::complex, std::complex>(op, fn); } diff --git a/core/matrix/row_gatherer.cpp b/core/matrix/row_gatherer.cpp index fecc60a0ca9..56fcbf93d88 100644 --- a/core/matrix/row_gatherer.cpp +++ b/core/matrix/row_gatherer.cpp @@ -4,6 +4,7 @@ #include "ginkgo/core/matrix/row_gatherer.hpp" +#include #include #include "core/base/dispatch_helper.hpp" @@ -64,7 +65,11 @@ RowGatherer::create_const( template void RowGatherer::apply_impl(const LinOp* in, LinOp* out) const { - run, std::complex>( + run, +#endif + float, double, std::complex, std::complex>( in, [&](auto gather) { gather->row_gather(&row_idxs_, out); }); } @@ -72,7 +77,11 @@ template void RowGatherer::apply_impl(const LinOp* alpha, const LinOp* in, const LinOp* beta, LinOp* out) const { - run, std::complex>( + run, +#endif + float, double, std::complex, std::complex>( in, [&](auto gather) { gather->row_gather(alpha, &row_idxs_, beta, out); }); } diff --git a/core/matrix/scaled_permutation.cpp b/core/matrix/scaled_permutation.cpp index 0f295d6b5be..bbe353e543e 100644 --- a/core/matrix/scaled_permutation.cpp +++ b/core/matrix/scaled_permutation.cpp @@ -174,7 +174,7 @@ void ScaledPermutation::write( #define GKO_DECLARE_SCALED_PERMUTATION_MATRIX(ValueType, IndexType) \ class ScaledPermutation -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SCALED_PERMUTATION_MATRIX); diff --git a/core/matrix/sellp.cpp b/core/matrix/sellp.cpp index a4787e758bf..bd81b08bada 100644 --- a/core/matrix/sellp.cpp +++ b/core/matrix/sellp.cpp @@ -176,7 +176,7 @@ void Sellp::apply_impl(const LinOp* alpha, const LinOp* b, template void Sellp::convert_to( - Sellp, IndexType>* result) const + Sellp, IndexType>* result) const { 
result->values_ = this->values_; result->col_idxs_ = this->col_idxs_; @@ -190,12 +190,38 @@ void Sellp::convert_to( template void Sellp::move_to( - Sellp, IndexType>* result) + Sellp, IndexType>* result) { this->convert_to(result); } +#if GINKGO_ENABLE_HALF +template +void Sellp::convert_to( + Sellp>, + IndexType>* result) const +{ + result->values_ = this->values_; + result->col_idxs_ = this->col_idxs_; + result->slice_lengths_ = this->slice_lengths_; + result->slice_sets_ = this->slice_sets_; + result->slice_size_ = this->slice_size_; + result->stride_factor_ = this->stride_factor_; + result->set_size(this->get_size()); +} + + +template +void Sellp::move_to( + Sellp>, + IndexType>* result) +{ + this->convert_to(result); +} +#endif + + template void Sellp::convert_to(Dense* result) const { @@ -363,7 +389,8 @@ Sellp::compute_absolute() const #define GKO_DECLARE_SELLP_MATRIX(ValueType, IndexType) \ class Sellp -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SELLP_MATRIX); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SELLP_MATRIX); } // namespace matrix diff --git a/core/matrix/sparsity_csr.cpp b/core/matrix/sparsity_csr.cpp index 9b8ea04da52..a4d8b2fa281 100644 --- a/core/matrix/sparsity_csr.cpp +++ b/core/matrix/sparsity_csr.cpp @@ -346,7 +346,8 @@ bool SparsityCsr::is_sorted_by_column_index() const #define GKO_DECLARE_SPARSITY_MATRIX(ValueType, IndexType) \ class SparsityCsr -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SPARSITY_MATRIX); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SPARSITY_MATRIX); } // namespace matrix diff --git a/dpcpp/matrix/coo_kernels.dp.cpp b/dpcpp/matrix/coo_kernels.dp.cpp index 595af92b33b..7e8a9acfac3 100644 --- a/dpcpp/matrix/coo_kernels.dp.cpp +++ b/dpcpp/matrix/coo_kernels.dp.cpp @@ -259,7 +259,8 @@ void spmv(std::shared_ptr exec, spmv2(exec, a, b, c); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_SPMV_KERNEL); 
+GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_COO_SPMV_KERNEL); template @@ -274,7 +275,7 @@ void advanced_spmv(std::shared_ptr exec, advanced_spmv2(exec, alpha, a, b, c); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_COO_ADVANCED_SPMV_KERNEL); @@ -311,7 +312,8 @@ void spmv2(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_SPMV2_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_COO_SPMV2_KERNEL); template @@ -350,7 +352,7 @@ void advanced_spmv2(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_COO_ADVANCED_SPMV2_KERNEL); diff --git a/dpcpp/matrix/csr_kernels.dp.cpp b/dpcpp/matrix/csr_kernels.dp.cpp index 4dce0aa6ac2..efcb9b7f470 100644 --- a/dpcpp/matrix/csr_kernels.dp.cpp +++ b/dpcpp/matrix/csr_kernels.dp.cpp @@ -31,6 +31,7 @@ #include "dpcpp/base/dim3.dp.hpp" #include "dpcpp/base/dpct.hpp" #include "dpcpp/base/helper.hpp" +#include "dpcpp/base/onemkl_bindings.hpp" #include "dpcpp/components/atomic.dp.hpp" #include "dpcpp/components/cooperative_groups.dp.hpp" #include "dpcpp/components/reduction.dp.hpp" @@ -266,7 +267,7 @@ void abstract_spmv( { using arithmetic_type = typename output_accessor::arithmetic_type; using output_type = typename output_accessor::storage_type; - const arithmetic_type scale_factor = alpha[0]; + const auto scale_factor = static_cast(alpha[0]); spmv_kernel( nwarps, num_rows, val, col_idxs, row_ptrs, srow, b, c, [&scale_factor](const arithmetic_type& x) { @@ -479,8 +480,8 @@ void abstract_merge_path_spmv( sycl::nd_item<3> item_ct1, IndexType* shared_row_ptrs) { using type = typename output_accessor::arithmetic_type; - const type alpha_val = alpha[0]; - const type beta_val = beta[0]; + const type alpha_val = static_cast(alpha[0]); + const type beta_val = static_cast(beta[0]); 
merge_path_spmv( num_rows, val, col_idxs, row_ptrs, srow, b, c, row_out, val_out, [&alpha_val](const type& x) { return alpha_val * x; }, @@ -566,7 +567,7 @@ void abstract_reduce( uninitialized_array& tmp_ind, uninitialized_array& tmp_val) { - const arithmetic_type alpha_val = alpha[0]; + const auto alpha_val = static_cast(alpha[0]); merge_path_reduce( nwarps, last_val, last_row, c, [&alpha_val](const arithmetic_type& x) { return alpha_val * x; }, @@ -694,8 +695,8 @@ void abstract_classical_spmv( acc::range c, sycl::nd_item<3> item_ct1) { using type = typename output_accessor::arithmetic_type; - const type alpha_val = alpha[0]; - const type beta_val = beta[0]; + const type alpha_val = static_cast(alpha[0]); + const type beta_val = static_cast(beta[0]); device_classical_spmv( num_rows, val, col_idxs, row_ptrs, b, c, [&alpha_val, &beta_val](const type& x, const type& y) { @@ -1393,8 +1394,9 @@ bool try_general_sparselib_spmv(std::shared_ptr exec, const ValueType host_beta, matrix::Dense* c) { - bool try_sparselib = !is_complex(); - if (try_sparselib) { + constexpr bool try_sparselib = + !is_complex() && !std::is_same::value; + if constexpr (try_sparselib) { oneapi::mkl::sparse::matrix_handle_t mat_handle; oneapi::mkl::sparse::init_matrix_handle(&mat_handle); oneapi::mkl::sparse::set_csr_data( @@ -1532,7 +1534,7 @@ void spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_SPMV_KERNEL); @@ -1604,7 +1606,7 @@ void advanced_spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL); @@ -1684,7 +1686,7 @@ void calculate_nonzeros_per_row_in_span( row_nnz->get_data()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_SPAN_KERNEL); @@ -1696,7 
+1698,7 @@ void calculate_nonzeros_per_row_in_index_set( const gko::index_set& col_index_set, IndexType* row_nnz) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_INDEX_SET_KERNEL); @@ -1723,7 +1725,7 @@ void compute_submatrix(std::shared_ptr exec, result->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_KERNEL); @@ -1735,7 +1737,7 @@ void compute_submatrix_from_index_set( const gko::index_set& col_index_set, matrix::Csr* result) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_FROM_INDEX_SET_KERNEL); @@ -1997,7 +1999,8 @@ void spgemm(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEMM_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_SPGEMM_KERNEL); template @@ -2130,7 +2133,7 @@ void advanced_spgemm(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPGEMM_KERNEL); @@ -2216,7 +2219,8 @@ void spgeam(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEAM_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_SPGEAM_KERNEL); template @@ -2237,7 +2241,7 @@ void fill_in_dense(std::shared_ptr exec, result->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_FILL_IN_DENSE_KERNEL); @@ -2247,7 +2251,7 @@ void convert_to_fbcsr(std::shared_ptr exec, array& row_ptrs, array& col_idxs, array& values) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( 
+GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CONVERT_TO_FBCSR_KERNEL); @@ -2310,7 +2314,8 @@ void transpose(std::shared_ptr exec, generic_transpose(exec, orig, trans); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_TRANSPOSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_TRANSPOSE_KERNEL); template @@ -2321,7 +2326,7 @@ void conj_transpose(std::shared_ptr exec, generic_transpose(exec, orig, trans); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CONJ_TRANSPOSE_KERNEL); @@ -2347,7 +2352,7 @@ void inv_symm_permute(std::shared_ptr exec, permuted->get_col_idxs(), permuted->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_SYMM_PERMUTE_KERNEL); @@ -2374,7 +2379,7 @@ void inv_nonsymm_permute(std::shared_ptr exec, permuted->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_NONSYMM_PERMUTE_KERNEL); @@ -2400,7 +2405,7 @@ void row_permute(std::shared_ptr exec, row_permuted->get_col_idxs(), row_permuted->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL); @@ -2426,7 +2431,7 @@ void inv_row_permute(std::shared_ptr exec, row_permuted->get_col_idxs(), row_permuted->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_ROW_PERMUTE_KERNEL); @@ -2452,7 +2457,7 @@ void inv_symm_scale_permute(std::shared_ptr exec, permuted->get_col_idxs(), permuted->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_SYMM_SCALE_PERMUTE_KERNEL); @@ -2482,7 +2487,7 @@ void 
inv_nonsymm_scale_permute(std::shared_ptr exec, permuted->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_NONSYMM_SCALE_PERMUTE_KERNEL); @@ -2508,7 +2513,7 @@ void row_scale_permute(std::shared_ptr exec, row_permuted->get_col_idxs(), row_permuted->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ROW_SCALE_PERMUTE_KERNEL); @@ -2534,7 +2539,7 @@ void inv_row_scale_permute(std::shared_ptr exec, row_permuted->get_col_idxs(), row_permuted->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_ROW_SCALE_PERMUTE_KERNEL); @@ -2592,7 +2597,7 @@ void sort_by_column_index(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_SORT_BY_COLUMN_INDEX); @@ -2624,7 +2629,7 @@ void is_sorted_by_column_index( *is_sorted = get_element(is_sorted_device_array, 0); }; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_IS_SORTED_BY_COLUMN_INDEX); @@ -2648,7 +2653,8 @@ void extract_diagonal(std::shared_ptr exec, orig_row_ptrs, orig_col_idxs, diag_values); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_EXTRACT_DIAGONAL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_EXTRACT_DIAGONAL); template @@ -2672,7 +2678,7 @@ void check_diagonal_entries_exist(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CHECK_DIAGONAL_ENTRIES_EXIST); @@ -2695,7 +2701,7 @@ void add_scaled_identity(std::shared_ptr exec, mtx->get_const_col_idxs(), mtx->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( 
+GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ADD_SCALED_IDENTITY_KERNEL); diff --git a/dpcpp/matrix/dense_kernels.dp.cpp b/dpcpp/matrix/dense_kernels.dp.cpp index 04f3229eaed..c6eb163bc7d 100644 --- a/dpcpp/matrix/dense_kernels.dp.cpp +++ b/dpcpp/matrix/dense_kernels.dp.cpp @@ -177,7 +177,7 @@ void compute_dot_dispatch(std::shared_ptr exec, compute_dot(exec, x, y, result, tmp); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_DOT_DISPATCH_KERNEL); @@ -192,7 +192,7 @@ void compute_conj_dot_dispatch(std::shared_ptr exec, compute_conj_dot(exec, x, y, result, tmp); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_CONJ_DOT_DISPATCH_KERNEL); @@ -206,7 +206,7 @@ void compute_norm2_dispatch(std::shared_ptr exec, compute_norm2(exec, x, result, tmp); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_NORM2_DISPATCH_KERNEL); @@ -217,21 +217,26 @@ void simple_apply(std::shared_ptr exec, matrix::Dense* c) { using namespace oneapi::mkl; - if (b->get_stride() != 0 && c->get_stride() != 0) { - if (a->get_size()[1] > 0) { - oneapi::mkl::blas::row_major::gemm( - *exec->get_queue(), transpose::nontrans, transpose::nontrans, - c->get_size()[0], c->get_size()[1], a->get_size()[1], - one(), a->get_const_values(), a->get_stride(), - b->get_const_values(), b->get_stride(), zero(), - c->get_values(), c->get_stride()); - } else { - dense::fill(exec, c, zero()); + if constexpr (onemkl::is_supported::value) { + if (b->get_stride() != 0 && c->get_stride() != 0) { + if (a->get_size()[1] > 0) { + oneapi::mkl::blas::row_major::gemm( + *exec->get_queue(), transpose::nontrans, + transpose::nontrans, c->get_size()[0], c->get_size()[1], + a->get_size()[1], one(), a->get_const_values(), + a->get_stride(), b->get_const_values(), b->get_stride(), + zero(), c->get_values(), 
c->get_stride()); + } else { + dense::fill(exec, c, zero()); + } } + } else { + GKO_NOT_IMPLEMENTED; } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_SIMPLE_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_SIMPLE_APPLY_KERNEL); template @@ -241,23 +246,28 @@ void apply(std::shared_ptr exec, const matrix::Dense* beta, matrix::Dense* c) { using namespace oneapi::mkl; - if (b->get_stride() != 0 && c->get_stride() != 0) { - if (a->get_size()[1] > 0) { - oneapi::mkl::blas::row_major::gemm( - *exec->get_queue(), transpose::nontrans, transpose::nontrans, - c->get_size()[0], c->get_size()[1], a->get_size()[1], - exec->copy_val_to_host(alpha->get_const_values()), - a->get_const_values(), a->get_stride(), b->get_const_values(), - b->get_stride(), - exec->copy_val_to_host(beta->get_const_values()), - c->get_values(), c->get_stride()); - } else { - dense::scale(exec, beta, c); + if constexpr (onemkl::is_supported::value) { + if (b->get_stride() != 0 && c->get_stride() != 0) { + if (a->get_size()[1] > 0) { + oneapi::mkl::blas::row_major::gemm( + *exec->get_queue(), transpose::nontrans, + transpose::nontrans, c->get_size()[0], c->get_size()[1], + a->get_size()[1], + exec->copy_val_to_host(alpha->get_const_values()), + a->get_const_values(), a->get_stride(), + b->get_const_values(), b->get_stride(), + exec->copy_val_to_host(beta->get_const_values()), + c->get_values(), c->get_stride()); + } else { + dense::scale(exec, beta, c); + } } + } else { + GKO_NOT_IMPLEMENTED; } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_APPLY_KERNEL); template @@ -292,7 +302,7 @@ void convert_to_coo(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_COO_KERNEL); @@ -326,7 +336,7 @@ void convert_to_csr(std::shared_ptr exec, }); } 
-GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_CSR_KERNEL); @@ -365,7 +375,7 @@ void convert_to_ell(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_ELL_KERNEL); @@ -375,7 +385,7 @@ void convert_to_fbcsr(std::shared_ptr exec, matrix::Fbcsr* result) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_FBCSR_KERNEL); @@ -385,7 +395,7 @@ void count_nonzero_blocks_per_row(std::shared_ptr exec, int bs, IndexType* result) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COUNT_NONZERO_BLOCKS_PER_ROW_KERNEL); @@ -441,7 +451,7 @@ void convert_to_hybrid(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_HYBRID_KERNEL); @@ -484,7 +494,7 @@ void convert_to_sellp(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_SELLP_KERNEL); @@ -516,7 +526,7 @@ void convert_to_sparsity_csr(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_SPARSITY_CSR_KERNEL); @@ -538,7 +548,8 @@ void transpose(std::shared_ptr exec, queue, orig, trans); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_TRANSPOSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_TRANSPOSE_KERNEL); template @@ -565,7 +576,8 @@ void conj_transpose(std::shared_ptr exec, trans->get_values(), trans->get_stride()); } 
-GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL); } // namespace dense diff --git a/dpcpp/matrix/diagonal_kernels.dp.cpp b/dpcpp/matrix/diagonal_kernels.dp.cpp index 2b63138abbe..272a6dbd581 100644 --- a/dpcpp/matrix/diagonal_kernels.dp.cpp +++ b/dpcpp/matrix/diagonal_kernels.dp.cpp @@ -82,7 +82,7 @@ void apply_to_csr(std::shared_ptr exec, inverse); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DIAGONAL_APPLY_TO_CSR_KERNEL); diff --git a/dpcpp/matrix/ell_kernels.dp.cpp b/dpcpp/matrix/ell_kernels.dp.cpp index a97cb602d52..b33ed28b12d 100644 --- a/dpcpp/matrix/ell_kernels.dp.cpp +++ b/dpcpp/matrix/ell_kernels.dp.cpp @@ -415,7 +415,7 @@ void spmv(std::shared_ptr exec, exec, num_worker_per_row, a, b, c); } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_SPMV_KERNEL); @@ -451,7 +451,7 @@ void advanced_spmv(std::shared_ptr exec, exec, num_worker_per_row, a, b, c, alpha, beta); } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_ADVANCED_SPMV_KERNEL); diff --git a/dpcpp/matrix/fbcsr_kernels.dp.cpp b/dpcpp/matrix/fbcsr_kernels.dp.cpp index e9eb02f5fb2..7d53b862d67 100644 --- a/dpcpp/matrix/fbcsr_kernels.dp.cpp +++ b/dpcpp/matrix/fbcsr_kernels.dp.cpp @@ -32,7 +32,8 @@ void spmv(std::shared_ptr exec, const matrix::Dense* b, matrix::Dense* c) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_SPMV_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_FBCSR_SPMV_KERNEL); template @@ -43,7 +44,7 @@ void advanced_spmv(std::shared_ptr exec, const matrix::Dense* beta, matrix::Dense* c) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( 
+GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_ADVANCED_SPMV_KERNEL); @@ -54,7 +55,7 @@ void fill_in_matrix_data(std::shared_ptr exec, array& col_idxs, array& values) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_FILL_IN_MATRIX_DATA_KERNEL); @@ -63,7 +64,7 @@ void fill_in_dense(std::shared_ptr exec, const matrix::Fbcsr* source, matrix::Dense* result) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_FILL_IN_DENSE_KERNEL); @@ -73,7 +74,7 @@ void convert_to_csr(const std::shared_ptr exec, matrix::Csr* result) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_CONVERT_TO_CSR_KERNEL); @@ -82,7 +83,7 @@ void transpose(std::shared_ptr exec, const matrix::Fbcsr* orig, matrix::Fbcsr* trans) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_TRANSPOSE_KERNEL); @@ -92,7 +93,7 @@ void conj_transpose(std::shared_ptr exec, matrix::Fbcsr* trans) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_CONJ_TRANSPOSE_KERNEL); @@ -102,7 +103,7 @@ void is_sorted_by_column_index( const matrix::Fbcsr* to_check, bool* is_sorted) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_IS_SORTED_BY_COLUMN_INDEX); @@ -111,7 +112,7 @@ void sort_by_column_index(const std::shared_ptr exec, matrix::Fbcsr* to_sort) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_SORT_BY_COLUMN_INDEX); @@ -120,7 +121,7 @@ void 
extract_diagonal(std::shared_ptr exec, const matrix::Fbcsr* orig, matrix::Diagonal* diag) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_EXTRACT_DIAGONAL); diff --git a/dpcpp/matrix/sellp_kernels.dp.cpp b/dpcpp/matrix/sellp_kernels.dp.cpp index 9c0fe717e8a..e83e8f2ce1a 100644 --- a/dpcpp/matrix/sellp_kernels.dp.cpp +++ b/dpcpp/matrix/sellp_kernels.dp.cpp @@ -119,7 +119,8 @@ void spmv(std::shared_ptr exec, b->get_const_values(), c->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SELLP_SPMV_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SELLP_SPMV_KERNEL); template @@ -142,7 +143,7 @@ void advanced_spmv(std::shared_ptr exec, beta->get_const_values(), c->get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SELLP_ADVANCED_SPMV_KERNEL); diff --git a/dpcpp/matrix/sparsity_csr_kernels.dp.cpp b/dpcpp/matrix/sparsity_csr_kernels.dp.cpp index 66c57ac5b35..0e076794ac8 100644 --- a/dpcpp/matrix/sparsity_csr_kernels.dp.cpp +++ b/dpcpp/matrix/sparsity_csr_kernels.dp.cpp @@ -57,11 +57,11 @@ void device_classical_spmv(const size_type num_rows, const auto subrow = thread::get_subwarp_num_flat(item_ct1); const auto subid = subgroup_tile.thread_rank(); const IndexType column_id = item_ct1.get_group(1); - const arithmetic_type value = static_cast(val[0]); + const auto value = static_cast(val[0]); auto row = thread::get_subwarp_id_flat(item_ct1); for (; row < num_rows; row += subrow) { const auto ind_end = row_ptrs[row + 1]; - arithmetic_type temp_val = zero(); + auto temp_val = zero(); for (auto ind = row_ptrs[row] + subid; ind < ind_end; ind += subgroup_size) { temp_val += value * b(col_idxs[ind], column_id); @@ -237,7 +237,7 @@ void spmv(std::shared_ptr exec, syn::value_list(), syn::type_list<>(), exec, a, b, c); } 
-GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_SPMV_KERNEL); @@ -255,7 +255,7 @@ void advanced_spmv(std::shared_ptr exec, syn::value_list(), syn::type_list<>(), exec, a, b, c, alpha, beta); } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_ADVANCED_SPMV_KERNEL); @@ -265,7 +265,7 @@ void transpose(std::shared_ptr exec, matrix::SparsityCsr* trans) GKO_NOT_IMPLEMENTED; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_TRANSPOSE_KERNEL); @@ -290,7 +290,7 @@ void sort_by_column_index(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_SORT_BY_COLUMN_INDEX); @@ -324,7 +324,7 @@ void is_sorted_by_column_index( cpu_array = gpu_array; }; -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_IS_SORTED_BY_COLUMN_INDEX); diff --git a/include/ginkgo/core/base/precision_dispatch.hpp b/include/ginkgo/core/base/precision_dispatch.hpp index 8875b7d46f3..ad31a6b19e8 100644 --- a/include/ginkgo/core/base/precision_dispatch.hpp +++ b/include/ginkgo/core/base/precision_dispatch.hpp @@ -48,13 +48,15 @@ make_temporary_conversion(Ptr&& matrix) { using Pointee = detail::pointee; using Dense = matrix::Dense; - using NextDense = matrix::Dense>; + using NextDense = matrix::Dense>; + using NextNextDense = matrix::Dense< + next_precision_with_half>>; using MaybeConstDense = std::conditional_t::value, const Dense, Dense>; auto result = detail::temporary_conversion< - MaybeConstDense>::template create(matrix); + MaybeConstDense>::template create(matrix); if (!result) { - GKO_NOT_SUPPORTED(*matrix); + GKO_NOT_SUPPORTED(matrix); } return result; } @@ 
-226,23 +228,26 @@ void mixed_precision_dispatch(Function fn, const LinOp* in, LinOp* out) { #ifdef GINKGO_MIXED_PRECISION using fst_type = matrix::Dense; - using snd_type = matrix::Dense>; - if (auto dense_in = dynamic_cast(in)) { + using snd_type = matrix::Dense>; + using trd_type = matrix::Dense< + next_precision_with_half>>; + auto dispatch_out_vector = [&](auto dense_in) { if (auto dense_out = dynamic_cast(out)) { fn(dense_in, dense_out); } else if (auto dense_out = dynamic_cast(out)) { fn(dense_in, dense_out); - } else { - GKO_NOT_SUPPORTED(out); - } - } else if (auto dense_in = dynamic_cast(in)) { - if (auto dense_out = dynamic_cast(out)) { - fn(dense_in, dense_out); - } else if (auto dense_out = dynamic_cast(out)) { + } else if (auto dense_out = dynamic_cast(out)) { fn(dense_in, dense_out); } else { GKO_NOT_SUPPORTED(out); } + }; + if (auto dense_in = dynamic_cast(in)) { + dispatch_out_vector(dense_in); + } else if (auto dense_in = dynamic_cast(in)) { + dispatch_out_vector(dense_in); + } else if (auto dense_in = dynamic_cast(in)) { + dispatch_out_vector(dense_in); } else { GKO_NOT_SUPPORTED(in); } diff --git a/include/ginkgo/core/matrix/coo.hpp b/include/ginkgo/core/matrix/coo.hpp index 9373107df69..a0edf5aa862 100644 --- a/include/ginkgo/core/matrix/coo.hpp +++ b/include/ginkgo/core/matrix/coo.hpp @@ -47,15 +47,21 @@ class Hybrid; * @ingroup LinOp */ template -class Coo : public EnableLinOp>, - public ConvertibleTo, IndexType>>, - public ConvertibleTo>, - public ConvertibleTo>, - public DiagonalExtractable, - public ReadableFromMatrixData, - public WritableToMatrixData, - public EnableAbsoluteComputation< - remove_complex>> { +class Coo + : public EnableLinOp>, + public ConvertibleTo, IndexType>>, +#if GINKGO_ENABLE_HALF + public ConvertibleTo< + Coo>, + IndexType>>, +#endif + public ConvertibleTo>, + public ConvertibleTo>, + public DiagonalExtractable, + public ReadableFromMatrixData, + public WritableToMatrixData, + public EnableAbsoluteComputation< + 
remove_complex>> { friend class EnablePolymorphicObject; friend class Csr; friend class Dense; @@ -66,8 +72,10 @@ class Coo : public EnableLinOp>, public: using EnableLinOp::convert_to; using EnableLinOp::move_to; - using ConvertibleTo, IndexType>>::convert_to; - using ConvertibleTo, IndexType>>::move_to; + using ConvertibleTo< + Coo, IndexType>>::convert_to; + using ConvertibleTo< + Coo, IndexType>>::move_to; using ConvertibleTo>::convert_to; using ConvertibleTo>::move_to; using ConvertibleTo>::convert_to; @@ -80,12 +88,33 @@ class Coo : public EnableLinOp>, using device_mat_data = device_matrix_data; using absolute_type = remove_complex; - friend class Coo, IndexType>; + friend class Coo, IndexType>; + + void convert_to(Coo, IndexType>* result) + const override; + + void move_to( + Coo, IndexType>* result) override; + +#if GINKGO_ENABLE_HALF + friend class Coo< + previous_precision_with_half>, + IndexType>; + using ConvertibleTo< + Coo>, + IndexType>>::convert_to; + using ConvertibleTo< + Coo>, + IndexType>>::move_to; void convert_to( - Coo, IndexType>* result) const override; + Coo>, + IndexType>* result) const override; - void move_to(Coo, IndexType>* result) override; + void move_to( + Coo>, + IndexType>* result) override; +#endif void convert_to(Csr* other) const override; diff --git a/include/ginkgo/core/matrix/csr.hpp b/include/ginkgo/core/matrix/csr.hpp index f27fe12a934..2f66683085f 100644 --- a/include/ginkgo/core/matrix/csr.hpp +++ b/include/ginkgo/core/matrix/csr.hpp @@ -98,23 +98,29 @@ void strategy_rebuild_helper(Csr* result); * @ingroup LinOp */ template -class Csr : public EnableLinOp>, - public ConvertibleTo, IndexType>>, - public ConvertibleTo>, - public ConvertibleTo>, - public ConvertibleTo>, - public ConvertibleTo>, - public ConvertibleTo>, - public ConvertibleTo>, - public ConvertibleTo>, - public DiagonalExtractable, - public ReadableFromMatrixData, - public WritableToMatrixData, - public Transposable, - public Permutable, - public 
EnableAbsoluteComputation< - remove_complex>>, - public ScaledIdentityAddable { +class Csr + : public EnableLinOp>, + public ConvertibleTo, IndexType>>, +#if GINKGO_ENABLE_HALF + public ConvertibleTo< + Csr>, + IndexType>>, +#endif + public ConvertibleTo>, + public ConvertibleTo>, + public ConvertibleTo>, + public ConvertibleTo>, + public ConvertibleTo>, + public ConvertibleTo>, + public ConvertibleTo>, + public DiagonalExtractable, + public ReadableFromMatrixData, + public WritableToMatrixData, + public Transposable, + public Permutable, + public EnableAbsoluteComputation< + remove_complex>>, + public ScaledIdentityAddable { friend class EnablePolymorphicObject; friend class Coo; friend class Dense; @@ -130,8 +136,10 @@ class Csr : public EnableLinOp>, public: using EnableLinOp::convert_to; using EnableLinOp::move_to; - using ConvertibleTo, IndexType>>::convert_to; - using ConvertibleTo, IndexType>>::move_to; + using ConvertibleTo< + Csr, IndexType>>::convert_to; + using ConvertibleTo< + Csr, IndexType>>::move_to; using ConvertibleTo>::convert_to; using ConvertibleTo>::move_to; using ConvertibleTo>::convert_to; @@ -688,12 +696,33 @@ class Csr : public EnableLinOp>, index_type max_length_per_row_; }; - friend class Csr, IndexType>; + friend class Csr, IndexType>; + + void convert_to(Csr, IndexType>* result) + const override; + + void move_to( + Csr, IndexType>* result) override; + +#if GINKGO_ENABLE_HALF + friend class Csr< + previous_precision_with_half>, + IndexType>; + using ConvertibleTo< + Csr>, + IndexType>>::convert_to; + using ConvertibleTo< + Csr>, + IndexType>>::move_to; void convert_to( - Csr, IndexType>* result) const override; + Csr>, + IndexType>* result) const override; - void move_to(Csr, IndexType>* result) override; + void move_to( + Csr>, + IndexType>* result) override; +#endif void convert_to(Dense* other) const override; diff --git a/include/ginkgo/core/matrix/dense.hpp b/include/ginkgo/core/matrix/dense.hpp index bccd3adcd54..9ae96ca46d6 
100644 --- a/include/ginkgo/core/matrix/dense.hpp +++ b/include/ginkgo/core/matrix/dense.hpp @@ -87,7 +87,11 @@ class SparsityCsr; template class Dense : public EnableLinOp>, - public ConvertibleTo>>, + public ConvertibleTo>>, +#if GINKGO_ENABLE_HALF + public ConvertibleTo< + Dense>>>, +#endif public ConvertibleTo>, public ConvertibleTo>, public ConvertibleTo>, @@ -135,8 +139,8 @@ class Dense public: using EnableLinOp::convert_to; using EnableLinOp::move_to; - using ConvertibleTo>>::convert_to; - using ConvertibleTo>>::move_to; + using ConvertibleTo>>::convert_to; + using ConvertibleTo>>::move_to; using ConvertibleTo>::convert_to; using ConvertibleTo>::move_to; using ConvertibleTo>::convert_to; @@ -276,11 +280,29 @@ class Dense return other->create_const_view_of_impl(); } - friend class Dense>; + friend class Dense>; - void convert_to(Dense>* result) const override; + void convert_to( + Dense>* result) const override; - void move_to(Dense>* result) override; + void move_to(Dense>* result) override; + +#if GINKGO_ENABLE_HALF + friend class Dense< + previous_precision_with_half>>; + using ConvertibleTo>>>::convert_to; + using ConvertibleTo>>>::move_to; + + void convert_to( + Dense>>* + result) const override; + + void move_to( + Dense>>* + result) override; +#endif void convert_to(Coo* result) const override; diff --git a/include/ginkgo/core/matrix/diagonal.hpp b/include/ginkgo/core/matrix/diagonal.hpp index 56906a4d96f..3b11399138b 100644 --- a/include/ginkgo/core/matrix/diagonal.hpp +++ b/include/ginkgo/core/matrix/diagonal.hpp @@ -41,7 +41,11 @@ class Diagonal : public EnableLinOp>, public ConvertibleTo>, public ConvertibleTo>, - public ConvertibleTo>>, + public ConvertibleTo>>, +#if GINKGO_ENABLE_HALF + public ConvertibleTo>>>, +#endif public Transposable, public WritableToMatrixData, public WritableToMatrixData, @@ -60,8 +64,9 @@ class Diagonal using ConvertibleTo>::move_to; using ConvertibleTo>::convert_to; using ConvertibleTo>::move_to; - using 
ConvertibleTo>>::convert_to; - using ConvertibleTo>>::move_to; + using ConvertibleTo< + Diagonal>>::convert_to; + using ConvertibleTo>>::move_to; using value_type = ValueType; using index_type = int64; @@ -71,15 +76,34 @@ class Diagonal using device_mat_data32 = device_matrix_data; using absolute_type = remove_complex; - friend class Diagonal>; + friend class Diagonal>; std::unique_ptr transpose() const override; std::unique_ptr conj_transpose() const override; - void convert_to(Diagonal>* result) const override; + void convert_to( + Diagonal>* result) const override; - void move_to(Diagonal>* result) override; + void move_to( + Diagonal>* result) override; + +#if GINKGO_ENABLE_HALF + friend class Diagonal< + previous_precision_with_half>>; + using ConvertibleTo>>>::convert_to; + using ConvertibleTo>>>::move_to; + + void convert_to( + Diagonal>>* + result) const override; + + void move_to( + Diagonal>>* + result) override; +#endif void convert_to(Csr* result) const override; diff --git a/include/ginkgo/core/matrix/ell.hpp b/include/ginkgo/core/matrix/ell.hpp index 37f4c0e7f55..adbd3505855 100644 --- a/include/ginkgo/core/matrix/ell.hpp +++ b/include/ginkgo/core/matrix/ell.hpp @@ -49,28 +49,36 @@ class Hybrid; * @ingroup LinOp */ template -class Ell : public EnableLinOp>, - public ConvertibleTo, IndexType>>, - public ConvertibleTo>, - public ConvertibleTo>, - public DiagonalExtractable, - public ReadableFromMatrixData, - public WritableToMatrixData, - public EnableAbsoluteComputation< - remove_complex>> { +class Ell + : public EnableLinOp>, + public ConvertibleTo, IndexType>>, +#if GINKGO_ENABLE_HALF + public ConvertibleTo< + Ell>, + IndexType>>, +#endif + public ConvertibleTo>, + public ConvertibleTo>, + public DiagonalExtractable, + public ReadableFromMatrixData, + public WritableToMatrixData, + public EnableAbsoluteComputation< + remove_complex>> { friend class EnablePolymorphicObject; friend class Dense; friend class Coo; friend class Csr; friend class Ell, 
IndexType>; - friend class Ell, IndexType>; + friend class Ell, IndexType>; friend class Hybrid; public: using EnableLinOp::convert_to; using EnableLinOp::move_to; - using ConvertibleTo, IndexType>>::convert_to; - using ConvertibleTo, IndexType>>::move_to; + using ConvertibleTo< + Ell, IndexType>>::convert_to; + using ConvertibleTo< + Ell, IndexType>>::move_to; using ConvertibleTo>::convert_to; using ConvertibleTo>::move_to; using ConvertibleTo>::convert_to; @@ -83,10 +91,31 @@ class Ell : public EnableLinOp>, using device_mat_data = device_matrix_data; using absolute_type = remove_complex; + void convert_to(Ell, IndexType>* result) + const override; + + void move_to( + Ell, IndexType>* result) override; + +#if GINKGO_ENABLE_HALF + friend class Ell< + previous_precision_with_half>, + IndexType>; + using ConvertibleTo< + Ell>, + IndexType>>::convert_to; + using ConvertibleTo< + Ell>, + IndexType>>::move_to; + void convert_to( - Ell, IndexType>* result) const override; + Ell>, + IndexType>* result) const override; - void move_to(Ell, IndexType>* result) override; + void move_to( + Ell>, + IndexType>* result) override; +#endif void convert_to(Dense* other) const override; diff --git a/include/ginkgo/core/matrix/fbcsr.hpp b/include/ginkgo/core/matrix/fbcsr.hpp index ce327e7e8a0..283807b242c 100644 --- a/include/ginkgo/core/matrix/fbcsr.hpp +++ b/include/ginkgo/core/matrix/fbcsr.hpp @@ -96,17 +96,24 @@ inline IndexType get_num_blocks(const int block_size, const IndexType size) * @ingroup LinOp */ template -class Fbcsr : public EnableLinOp>, - public ConvertibleTo, IndexType>>, - public ConvertibleTo>, - public ConvertibleTo>, - public ConvertibleTo>, - public DiagonalExtractable, - public ReadableFromMatrixData, - public WritableToMatrixData, - public Transposable, - public EnableAbsoluteComputation< - remove_complex>> { +class Fbcsr + : public EnableLinOp>, + public ConvertibleTo< + Fbcsr, IndexType>>, +#if GINKGO_ENABLE_HALF + public ConvertibleTo< + Fbcsr>, + 
IndexType>>, +#endif + public ConvertibleTo>, + public ConvertibleTo>, + public ConvertibleTo>, + public DiagonalExtractable, + public ReadableFromMatrixData, + public WritableToMatrixData, + public Transposable, + public EnableAbsoluteComputation< + remove_complex>> { friend class EnablePolymorphicObject; friend class Csr; friend class Dense; @@ -136,8 +143,9 @@ class Fbcsr : public EnableLinOp>, using EnableLinOp>::convert_to; using ConvertibleTo< - Fbcsr, IndexType>>::convert_to; - using ConvertibleTo, IndexType>>::move_to; + Fbcsr, IndexType>>::convert_to; + using ConvertibleTo< + Fbcsr, IndexType>>::move_to; using ConvertibleTo>::convert_to; using ConvertibleTo>::move_to; using ConvertibleTo>::convert_to; @@ -145,12 +153,33 @@ class Fbcsr : public EnableLinOp>, using ConvertibleTo>::convert_to; using ConvertibleTo>::move_to; - friend class Fbcsr, IndexType>; + friend class Fbcsr, IndexType>; + + void convert_to(Fbcsr, IndexType>* + result) const override; + + void move_to( + Fbcsr, IndexType>* result) override; + +#if GINKGO_ENABLE_HALF + friend class Fbcsr< + previous_precision_with_half>, + IndexType>; + using ConvertibleTo< + Fbcsr>, + IndexType>>::convert_to; + using ConvertibleTo< + Fbcsr>, + IndexType>>::move_to; void convert_to( - Fbcsr, IndexType>* result) const override; + Fbcsr>, + IndexType>* result) const override; - void move_to(Fbcsr, IndexType>* result) override; + void move_to( + Fbcsr>, + IndexType>* result) override; +#endif void convert_to(Dense* other) const override; diff --git a/include/ginkgo/core/matrix/hybrid.hpp b/include/ginkgo/core/matrix/hybrid.hpp index 5e995cb0ba0..24cb3ed26c7 100644 --- a/include/ginkgo/core/matrix/hybrid.hpp +++ b/include/ginkgo/core/matrix/hybrid.hpp @@ -41,7 +41,13 @@ class Csr; template class Hybrid : public EnableLinOp>, - public ConvertibleTo, IndexType>>, + public ConvertibleTo< + Hybrid, IndexType>>, +#if GINKGO_ENABLE_HALF + public ConvertibleTo< + Hybrid>, + IndexType>>, +#endif public ConvertibleTo>, 
public ConvertibleTo>, public DiagonalExtractable, @@ -59,8 +65,9 @@ class Hybrid using EnableLinOp::convert_to; using EnableLinOp::move_to; using ConvertibleTo< - Hybrid, IndexType>>::convert_to; - using ConvertibleTo, IndexType>>::move_to; + Hybrid, IndexType>>::convert_to; + using ConvertibleTo< + Hybrid, IndexType>>::move_to; using ConvertibleTo>::convert_to; using ConvertibleTo>::move_to; using ConvertibleTo>::convert_to; @@ -355,12 +362,33 @@ class Hybrid imbalance_bounded_limit strategy_; }; - friend class Hybrid, IndexType>; + friend class Hybrid, IndexType>; + + void convert_to(Hybrid, IndexType>* + result) const override; + + void move_to(Hybrid, IndexType>* result) + override; + +#if GINKGO_ENABLE_HALF + friend class Hybrid< + previous_precision_with_half>, + IndexType>; + using ConvertibleTo< + Hybrid>, + IndexType>>::convert_to; + using ConvertibleTo< + Hybrid>, + IndexType>>::move_to; void convert_to( - Hybrid, IndexType>* result) const override; + Hybrid>, + IndexType>* result) const override; - void move_to(Hybrid, IndexType>* result) override; + void move_to( + Hybrid>, + IndexType>* result) override; +#endif void convert_to(Dense* other) const override; diff --git a/include/ginkgo/core/matrix/sellp.hpp b/include/ginkgo/core/matrix/sellp.hpp index e6520324030..6140a832c85 100644 --- a/include/ginkgo/core/matrix/sellp.hpp +++ b/include/ginkgo/core/matrix/sellp.hpp @@ -40,15 +40,22 @@ class Csr; * @ingroup LinOp */ template -class Sellp : public EnableLinOp>, - public ConvertibleTo, IndexType>>, - public ConvertibleTo>, - public ConvertibleTo>, - public DiagonalExtractable, - public ReadableFromMatrixData, - public WritableToMatrixData, - public EnableAbsoluteComputation< - remove_complex>> { +class Sellp + : public EnableLinOp>, + public ConvertibleTo< + Sellp, IndexType>>, +#if GINKGO_ENABLE_HALF + public ConvertibleTo< + Sellp>, + IndexType>>, +#endif + public ConvertibleTo>, + public ConvertibleTo>, + public DiagonalExtractable, + public 
ReadableFromMatrixData, + public WritableToMatrixData, + public EnableAbsoluteComputation< + remove_complex>> { friend class EnablePolymorphicObject; friend class Dense; friend class Csr; @@ -58,8 +65,9 @@ class Sellp : public EnableLinOp>, using EnableLinOp::convert_to; using EnableLinOp::move_to; using ConvertibleTo< - Sellp, IndexType>>::convert_to; - using ConvertibleTo, IndexType>>::move_to; + Sellp, IndexType>>::convert_to; + using ConvertibleTo< + Sellp, IndexType>>::move_to; using ConvertibleTo>::convert_to; using ConvertibleTo>::move_to; using ConvertibleTo>::convert_to; @@ -72,12 +80,33 @@ class Sellp : public EnableLinOp>, using device_mat_data = device_matrix_data; using absolute_type = remove_complex; - friend class Sellp, IndexType>; + friend class Sellp, IndexType>; + + void convert_to(Sellp, IndexType>* + result) const override; + + void move_to( + Sellp, IndexType>* result) override; + +#if GINKGO_ENABLE_HALF + friend class Sellp< + previous_precision_with_half>, + IndexType>; + using ConvertibleTo< + Sellp>, + IndexType>>::convert_to; + using ConvertibleTo< + Sellp>, + IndexType>>::move_to; void convert_to( - Sellp, IndexType>* result) const override; + Sellp>, + IndexType>* result) const override; - void move_to(Sellp, IndexType>* result) override; + void move_to( + Sellp>, + IndexType>* result) override; +#endif void convert_to(Dense* other) const override; diff --git a/omp/matrix/coo_kernels.cpp b/omp/matrix/coo_kernels.cpp index 021795d8e9c..6d4a46b7ed3 100644 --- a/omp/matrix/coo_kernels.cpp +++ b/omp/matrix/coo_kernels.cpp @@ -42,7 +42,8 @@ void spmv(std::shared_ptr exec, spmv2(exec, a, b, c); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_SPMV_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_COO_SPMV_KERNEL); template @@ -57,7 +58,7 @@ void advanced_spmv(std::shared_ptr exec, advanced_spmv2(exec, alpha, a, b, c); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( 
+GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_COO_ADVANCED_SPMV_KERNEL); @@ -306,7 +307,8 @@ void spmv2(std::shared_ptr exec, generic_spmv2(exec, a, b, c, one()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_SPMV2_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_COO_SPMV2_KERNEL); template @@ -319,7 +321,7 @@ void advanced_spmv2(std::shared_ptr exec, generic_spmv2(exec, a, b, c, alpha->at(0, 0)); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_COO_ADVANCED_SPMV2_KERNEL); diff --git a/omp/matrix/csr_kernels.cpp b/omp/matrix/csr_kernels.cpp index 87b328b1093..d9c7b9840c1 100644 --- a/omp/matrix/csr_kernels.cpp +++ b/omp/matrix/csr_kernels.cpp @@ -77,7 +77,7 @@ void spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_SPMV_KERNEL); @@ -95,8 +95,8 @@ void advanced_spmv(std::shared_ptr exec, auto row_ptrs = a->get_const_row_ptrs(); auto col_idxs = a->get_const_col_idxs(); - arithmetic_type valpha = alpha->at(0, 0); - arithmetic_type vbeta = beta->at(0, 0); + auto valpha = static_cast(alpha->at(0, 0)); + auto vbeta = static_cast(beta->at(0, 0)); const auto a_vals = acc::helper::build_const_rrm_accessor(a); @@ -118,7 +118,7 @@ void advanced_spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL); @@ -374,7 +374,8 @@ void spgemm(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEMM_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_SPGEMM_KERNEL); template @@ -490,7 +491,7 @@ void advanced_spgemm(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( 
+GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPGEMM_KERNEL); @@ -540,7 +541,8 @@ void spgeam(std::shared_ptr exec, [](IndexType, IndexType) {}); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEAM_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_SPGEAM_KERNEL); template @@ -563,7 +565,7 @@ void fill_in_dense(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_FILL_IN_DENSE_KERNEL); @@ -633,7 +635,7 @@ void convert_to_fbcsr(std::shared_ptr exec, std::copy(col_idx_vec.begin(), col_idx_vec.end(), col_idxs.get_data()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CONVERT_TO_FBCSR_KERNEL); @@ -692,7 +694,8 @@ void transpose(std::shared_ptr exec, [](const ValueType x) { return x; }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_TRANSPOSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_TRANSPOSE_KERNEL); template @@ -704,7 +707,7 @@ void conj_transpose(std::shared_ptr exec, [](const ValueType x) { return conj(x); }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CONJ_TRANSPOSE_KERNEL); @@ -728,7 +731,7 @@ void calculate_nonzeros_per_row_in_span( } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_SPAN_KERNEL); @@ -775,7 +778,7 @@ void calculate_nonzeros_per_row_in_index_set( } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_INDEX_SET_KERNEL); @@ -808,7 +811,7 @@ void compute_submatrix(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( 
+GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_KERNEL); @@ -868,7 +871,7 @@ void compute_submatrix_from_index_set( } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_FROM_INDEX_SET_KERNEL); @@ -881,7 +884,7 @@ void inv_symm_permute(std::shared_ptr exec, inv_nonsymm_permute(exec, perm, perm, orig, permuted); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_SYMM_PERMUTE_KERNEL); @@ -921,7 +924,7 @@ void inv_nonsymm_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_NONSYMM_PERMUTE_KERNEL); @@ -959,7 +962,7 @@ void row_permute(std::shared_ptr exec, const IndexType* perm, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL); @@ -998,7 +1001,7 @@ void inv_row_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_ROW_PERMUTE_KERNEL); @@ -1011,7 +1014,7 @@ void inv_symm_scale_permute(std::shared_ptr exec, inv_nonsymm_scale_permute(exec, scale, perm, scale, perm, orig, permuted); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_SYMM_SCALE_PERMUTE_KERNEL); @@ -1055,7 +1058,7 @@ void inv_nonsymm_scale_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_NONSYMM_SCALE_PERMUTE_KERNEL); @@ -1096,7 +1099,7 @@ void row_scale_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( 
GKO_DECLARE_CSR_ROW_SCALE_PERMUTE_KERNEL); @@ -1137,7 +1140,7 @@ void inv_row_scale_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_ROW_SCALE_PERMUTE_KERNEL); @@ -1160,7 +1163,7 @@ void sort_by_column_index(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_SORT_BY_COLUMN_INDEX); @@ -1188,7 +1191,7 @@ void is_sorted_by_column_index( *is_sorted = local_is_sorted; } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_IS_SORTED_BY_COLUMN_INDEX); @@ -1214,7 +1217,8 @@ void extract_diagonal(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_EXTRACT_DIAGONAL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_EXTRACT_DIAGONAL); template @@ -1241,7 +1245,7 @@ void check_diagonal_entries_exist(std::shared_ptr exec, has_all_diags = l_has_all_diags; } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CHECK_DIAGONAL_ENTRIES_EXIST); @@ -1270,7 +1274,7 @@ void add_scaled_identity(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ADD_SCALED_IDENTITY_KERNEL); diff --git a/omp/matrix/dense_kernels.cpp b/omp/matrix/dense_kernels.cpp index d1c0f2f8949..4ca5aa0c075 100644 --- a/omp/matrix/dense_kernels.cpp +++ b/omp/matrix/dense_kernels.cpp @@ -46,7 +46,7 @@ void compute_dot_dispatch(std::shared_ptr exec, compute_dot(exec, x, y, result, tmp); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_DOT_DISPATCH_KERNEL); @@ -60,7 +60,7 @@ void compute_conj_dot_dispatch(std::shared_ptr exec, compute_conj_dot(exec, x, y, 
result, tmp); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_CONJ_DOT_DISPATCH_KERNEL); @@ -73,7 +73,7 @@ void compute_norm2_dispatch(std::shared_ptr exec, compute_norm2(exec, x, result, tmp); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_NORM2_DISPATCH_KERNEL); @@ -100,7 +100,8 @@ void simple_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_SIMPLE_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_SIMPLE_APPLY_KERNEL); template @@ -136,7 +137,7 @@ void apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_APPLY_KERNEL); template @@ -168,7 +169,7 @@ void convert_to_coo(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_COO_KERNEL); @@ -199,7 +200,7 @@ void convert_to_csr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_CSR_KERNEL); @@ -232,7 +233,7 @@ void convert_to_ell(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_ELL_KERNEL); @@ -280,7 +281,7 @@ void convert_to_fbcsr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_FBCSR_KERNEL); @@ -326,7 +327,7 @@ void convert_to_hybrid(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_HYBRID_KERNEL); @@ -368,7 +369,7 @@ void convert_to_sellp(std::shared_ptr exec, } 
} -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_SELLP_KERNEL); @@ -398,7 +399,7 @@ void convert_to_sparsity_csr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_SPARSITY_CSR_KERNEL); @@ -415,7 +416,8 @@ void transpose(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_TRANSPOSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_TRANSPOSE_KERNEL); template @@ -431,7 +433,8 @@ void conj_transpose(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL); template @@ -461,7 +464,7 @@ void count_nonzero_blocks_per_row(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COUNT_NONZERO_BLOCKS_PER_ROW_KERNEL); diff --git a/omp/matrix/diagonal_kernels.cpp b/omp/matrix/diagonal_kernels.cpp index 71363c7bc6e..c16e740dc45 100644 --- a/omp/matrix/diagonal_kernels.cpp +++ b/omp/matrix/diagonal_kernels.cpp @@ -43,7 +43,7 @@ void apply_to_csr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DIAGONAL_APPLY_TO_CSR_KERNEL); diff --git a/omp/matrix/ell_kernels.cpp b/omp/matrix/ell_kernels.cpp index c35a3654b86..dc200ae0f93 100644 --- a/omp/matrix/ell_kernels.cpp +++ b/omp/matrix/ell_kernels.cpp @@ -185,7 +185,7 @@ void spmv(std::shared_ptr exec, spmv_blocked<4>(exec, a, b, c, out); } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_SPMV_KERNEL); @@ -228,7 +228,7 @@ void advanced_spmv(std::shared_ptr exec, 
spmv_blocked<4>(exec, a, b, c, out); } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_ADVANCED_SPMV_KERNEL); diff --git a/omp/matrix/fbcsr_kernels.cpp b/omp/matrix/fbcsr_kernels.cpp index d17d47a7467..14dcb1db77a 100644 --- a/omp/matrix/fbcsr_kernels.cpp +++ b/omp/matrix/fbcsr_kernels.cpp @@ -74,7 +74,8 @@ void spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_SPMV_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_FBCSR_SPMV_KERNEL); template @@ -118,7 +119,7 @@ void advanced_spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_ADVANCED_SPMV_KERNEL); @@ -176,7 +177,7 @@ void fill_in_matrix_data(std::shared_ptr exec, std::copy(col_idx_vec.begin(), col_idx_vec.end(), col_idxs.get_data()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_FILL_IN_MATRIX_DATA_KERNEL); @@ -209,7 +210,7 @@ void fill_in_dense(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_FILL_IN_DENSE_KERNEL); @@ -255,7 +256,7 @@ void convert_to_csr(const std::shared_ptr exec, row_ptrs[result->get_size()[0]] = source->get_num_stored_elements(); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_CONVERT_TO_CSR_KERNEL); @@ -330,7 +331,7 @@ void transpose(std::shared_ptr exec, [](const ValueType x) { return x; }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_TRANSPOSE_KERNEL); @@ -343,7 +344,7 @@ void conj_transpose(std::shared_ptr exec, [](const ValueType x) { return conj(x); }); } 
-GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_CONJ_TRANSPOSE_KERNEL); @@ -371,7 +372,7 @@ void is_sorted_by_column_index( *is_sorted = local_is_sorted; } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_IS_SORTED_BY_COLUMN_INDEX); @@ -426,7 +427,7 @@ void sort_by_column_index(const std::shared_ptr exec, syn::value_list(), syn::type_list<>(), to_sort); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_SORT_BY_COLUMN_INDEX); @@ -463,7 +464,7 @@ void extract_diagonal(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_EXTRACT_DIAGONAL); diff --git a/omp/matrix/sellp_kernels.cpp b/omp/matrix/sellp_kernels.cpp index 7f8b16264ce..6306093b36d 100644 --- a/omp/matrix/sellp_kernels.cpp +++ b/omp/matrix/sellp_kernels.cpp @@ -155,7 +155,8 @@ void spmv(std::shared_ptr exec, spmv_blocked<4>(exec, a, b, c, out); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SELLP_SPMV_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SELLP_SPMV_KERNEL); template @@ -194,7 +195,7 @@ void advanced_spmv(std::shared_ptr exec, spmv_blocked<4>(exec, a, b, c, out); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SELLP_ADVANCED_SPMV_KERNEL); diff --git a/omp/matrix/sparsity_csr_kernels.cpp b/omp/matrix/sparsity_csr_kernels.cpp index 35bb42c70a6..560ee6d4890 100644 --- a/omp/matrix/sparsity_csr_kernels.cpp +++ b/omp/matrix/sparsity_csr_kernels.cpp @@ -58,7 +58,7 @@ void spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_SPMV_KERNEL); @@ -95,7 
+95,7 @@ void advanced_spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_ADVANCED_SPMV_KERNEL); @@ -149,7 +149,7 @@ void transpose(std::shared_ptr exec, transpose_and_transform(exec, trans, orig); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_TRANSPOSE_KERNEL); @@ -168,7 +168,7 @@ void sort_by_column_index(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_SORT_BY_COLUMN_INDEX); @@ -197,7 +197,7 @@ void is_sorted_by_column_index( *is_sorted = local_is_sorted; } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_IS_SORTED_BY_COLUMN_INDEX); diff --git a/reference/matrix/coo_kernels.cpp b/reference/matrix/coo_kernels.cpp index f9bf9f5f33d..ebb8c1dfce6 100644 --- a/reference/matrix/coo_kernels.cpp +++ b/reference/matrix/coo_kernels.cpp @@ -38,7 +38,8 @@ void spmv(std::shared_ptr exec, spmv2(exec, a, b, c); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_SPMV_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_COO_SPMV_KERNEL); template @@ -53,7 +54,7 @@ void advanced_spmv(std::shared_ptr exec, advanced_spmv2(exec, alpha, a, b, c); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_COO_ADVANCED_SPMV_KERNEL); @@ -73,7 +74,8 @@ void spmv2(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_COO_SPMV2_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_COO_SPMV2_KERNEL); template @@ -96,7 +98,7 @@ void advanced_spmv2(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( 
+GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_COO_ADVANCED_SPMV2_KERNEL); @@ -113,7 +115,7 @@ void fill_in_dense(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_COO_FILL_IN_DENSE_KERNEL); @@ -136,7 +138,7 @@ void extract_diagonal(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_COO_EXTRACT_DIAGONAL_KERNEL); diff --git a/reference/matrix/csr_kernels.cpp b/reference/matrix/csr_kernels.cpp index a0607110b79..679844084d2 100644 --- a/reference/matrix/csr_kernels.cpp +++ b/reference/matrix/csr_kernels.cpp @@ -76,7 +76,7 @@ void spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_SPMV_KERNEL); @@ -94,8 +94,8 @@ void advanced_spmv(std::shared_ptr exec, auto row_ptrs = a->get_const_row_ptrs(); auto col_idxs = a->get_const_col_idxs(); - arithmetic_type valpha = alpha->at(0, 0); - arithmetic_type vbeta = beta->at(0, 0); + auto valpha = static_cast(alpha->at(0, 0)); + auto vbeta = static_cast(beta->at(0, 0)); const auto a_vals = acc::helper::build_const_rrm_accessor(a); @@ -116,7 +116,7 @@ void advanced_spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPMV_KERNEL); @@ -240,7 +240,8 @@ void spgemm(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEMM_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_SPGEMM_KERNEL); template @@ -295,7 +296,7 @@ void advanced_spgemm(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ADVANCED_SPGEMM_KERNEL); @@ -345,7 +346,8 @@ 
void spgeam(std::shared_ptr exec, [](IndexType, IndexType) {}); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SPGEAM_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_SPGEAM_KERNEL); template @@ -367,7 +369,7 @@ void fill_in_dense(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_FILL_IN_DENSE_KERNEL); @@ -414,7 +416,7 @@ void convert_to_sellp(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CONVERT_TO_SELLP_KERNEL); @@ -445,7 +447,7 @@ void convert_to_ell(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CONVERT_TO_ELL_KERNEL); @@ -515,7 +517,7 @@ void convert_to_fbcsr(std::shared_ptr exec, std::copy(col_idx_vec.begin(), col_idx_vec.end(), col_idxs.get_data()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CONVERT_TO_FBCSR_KERNEL); @@ -574,7 +576,8 @@ void transpose(std::shared_ptr exec, [](const ValueType x) { return x; }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_TRANSPOSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_TRANSPOSE_KERNEL); template @@ -586,7 +589,7 @@ void conj_transpose(std::shared_ptr exec, [](const ValueType x) { return conj(x); }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CONJ_TRANSPOSE_KERNEL); @@ -610,7 +613,7 @@ void calculate_nonzeros_per_row_in_span( } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_SPAN_KERNEL); @@ -657,7 +660,7 @@ void calculate_nonzeros_per_row_in_index_set( } } 
-GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CALC_NNZ_PER_ROW_IN_INDEX_SET_KERNEL); @@ -691,7 +694,7 @@ void compute_submatrix(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_KERNEL); @@ -749,7 +752,7 @@ void compute_submatrix_from_index_set( } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_COMPUTE_SUB_MATRIX_FROM_INDEX_SET_KERNEL); @@ -800,7 +803,7 @@ void convert_to_hybrid(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CONVERT_TO_HYBRID_KERNEL); @@ -813,7 +816,7 @@ void inv_symm_permute(std::shared_ptr exec, inv_nonsymm_permute(exec, perm, perm, orig, permuted); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_SYMM_PERMUTE_KERNEL); @@ -851,7 +854,7 @@ void inv_nonsymm_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_NONSYMM_PERMUTE_KERNEL); @@ -886,7 +889,7 @@ void row_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL); @@ -921,7 +924,7 @@ void inv_row_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_ROW_PERMUTE_KERNEL); @@ -951,7 +954,7 @@ void inv_col_permute(std::shared_ptr exec, cp_row_ptrs[num_rows] = in_row_ptrs[num_rows]; } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_COL_PERMUTE_KERNEL); @@ 
-964,7 +967,7 @@ void inv_symm_scale_permute(std::shared_ptr exec, inv_nonsymm_scale_permute(exec, scale, perm, scale, perm, orig, permuted); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_SYMM_SCALE_PERMUTE_KERNEL); @@ -1006,7 +1009,7 @@ void inv_nonsymm_scale_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_NONSYMM_SCALE_PERMUTE_KERNEL); @@ -1043,7 +1046,7 @@ void row_scale_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ROW_SCALE_PERMUTE_KERNEL); @@ -1080,7 +1083,7 @@ void inv_row_scale_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_ROW_SCALE_PERMUTE_KERNEL); @@ -1111,7 +1114,7 @@ void inv_col_scale_permute(std::shared_ptr exec, cp_row_ptrs[num_rows] = in_row_ptrs[num_rows]; } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_INV_COL_SCALE_PERMUTE_KERNEL); @@ -1133,7 +1136,7 @@ void sort_by_column_index(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_SORT_BY_COLUMN_INDEX); @@ -1157,7 +1160,7 @@ void is_sorted_by_column_index( return; } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_IS_SORTED_BY_COLUMN_INDEX); @@ -1182,7 +1185,8 @@ void extract_diagonal(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_EXTRACT_DIAGONAL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_EXTRACT_DIAGONAL); template @@ -1198,7 +1202,8 @@ void scale(std::shared_ptr exec, 
} } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_SCALE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_SCALE_KERNEL); template @@ -1214,7 +1219,8 @@ void inv_scale(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_CSR_INV_SCALE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_CSR_INV_SCALE_KERNEL); template @@ -1240,7 +1246,7 @@ void check_diagonal_entries_exist(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_CHECK_DIAGONAL_ENTRIES_EXIST); @@ -1263,7 +1269,7 @@ void add_scaled_identity(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_CSR_ADD_SCALED_IDENTITY_KERNEL); diff --git a/reference/matrix/dense_kernels.cpp b/reference/matrix/dense_kernels.cpp index 921a49998b7..561073c8c2d 100644 --- a/reference/matrix/dense_kernels.cpp +++ b/reference/matrix/dense_kernels.cpp @@ -56,7 +56,8 @@ void simple_apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_SIMPLE_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_SIMPLE_APPLY_KERNEL); template @@ -89,7 +90,7 @@ void apply(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_APPLY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_APPLY_KERNEL); template @@ -105,7 +106,7 @@ void copy(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION_OR_COPY( +GKO_INSTANTIATE_FOR_EACH_VALUE_CONVERSION_OR_COPY_WITH_HALF( GKO_DECLARE_DENSE_COPY_KERNEL); @@ -120,7 +121,7 @@ void fill(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_FILL_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_FILL_KERNEL); template @@ -142,7 +143,8 @@ void 
scale(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE(GKO_DECLARE_DENSE_SCALE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_SCALE_KERNEL); template @@ -165,7 +167,7 @@ void inv_scale(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_SCALE_KERNEL); @@ -189,7 +191,7 @@ void add_scaled(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE_WITH_HALF( GKO_DECLARE_DENSE_ADD_SCALED_KERNEL); @@ -213,7 +215,7 @@ void sub_scaled(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE_WITH_HALF( GKO_DECLARE_DENSE_SUB_SCALED_KERNEL); @@ -229,7 +231,8 @@ void add_scaled_diag(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_ADD_SCALED_DIAG_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_ADD_SCALED_DIAG_KERNEL); template @@ -244,7 +247,8 @@ void sub_scaled_diag(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_SUB_SCALED_DIAG_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_SUB_SCALED_DIAG_KERNEL); template @@ -263,7 +267,8 @@ void compute_dot(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_DOT_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_COMPUTE_DOT_KERNEL); template @@ -275,7 +280,7 @@ void compute_dot_dispatch(std::shared_ptr exec, compute_dot(exec, x, y, result, tmp); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_DOT_DISPATCH_KERNEL); @@ -295,7 +300,8 @@ void compute_conj_dot(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_CONJ_DOT_KERNEL); 
+GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_COMPUTE_CONJ_DOT_KERNEL); template @@ -308,7 +314,7 @@ void compute_conj_dot_dispatch(std::shared_ptr exec, compute_conj_dot(exec, x, y, result, tmp); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_CONJ_DOT_DISPATCH_KERNEL); @@ -331,7 +337,8 @@ void compute_norm2(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_NORM2_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_COMPUTE_NORM2_KERNEL); template @@ -343,7 +350,7 @@ void compute_norm2_dispatch(std::shared_ptr exec, compute_norm2(exec, x, result, tmp); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_NORM2_DISPATCH_KERNEL); @@ -363,7 +370,8 @@ void compute_norm1(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_NORM1_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_COMPUTE_NORM1_KERNEL); template @@ -386,7 +394,8 @@ void compute_mean(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_COMPUTE_MEAN_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_COMPUTE_MEAN_KERNEL); template @@ -400,7 +409,7 @@ void fill_in_matrix_data(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_FILL_IN_MATRIX_DATA_KERNEL); @@ -420,7 +429,7 @@ void compute_squared_norm2(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_SQUARED_NORM2_KERNEL); @@ -435,7 +444,7 @@ void compute_sqrt(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_SQRT_KERNEL); @@ -466,7 +475,7 @@ 
void convert_to_coo(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_COO_KERNEL); @@ -498,7 +507,7 @@ void convert_to_csr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_CSR_KERNEL); @@ -530,7 +539,7 @@ void convert_to_ell(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_ELL_KERNEL); @@ -577,7 +586,7 @@ void convert_to_fbcsr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_FBCSR_KERNEL); @@ -626,7 +635,7 @@ void convert_to_hybrid(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_HYBRID_KERNEL); @@ -662,7 +671,7 @@ void convert_to_sellp(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_SELLP_KERNEL); @@ -692,7 +701,7 @@ void convert_to_sparsity_csr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_CONVERT_TO_SPARSITY_CSR_KERNEL); @@ -713,7 +722,7 @@ void compute_max_nnz_per_row(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_MAX_NNZ_PER_ROW_KERNEL); @@ -745,7 +754,7 @@ void compute_slice_sets(std::shared_ptr exec, components::prefix_sum_nonnegative(exec, slice_sets, num_slices + 1); } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COMPUTE_SLICE_SETS_KERNEL); @@ 
-765,9 +774,9 @@ void count_nonzeros_per_row(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COUNT_NONZEROS_PER_ROW_KERNEL); -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COUNT_NONZEROS_PER_ROW_KERNEL_SIZE_T); @@ -797,7 +806,7 @@ void count_nonzero_blocks_per_row(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COUNT_NONZERO_BLOCKS_PER_ROW_KERNEL); @@ -813,7 +822,8 @@ void transpose(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_TRANSPOSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_TRANSPOSE_KERNEL); template @@ -828,7 +838,8 @@ void conj_transpose(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_CONJ_TRANSPOSE_KERNEL); template @@ -844,7 +855,7 @@ void symm_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_SYMM_PERMUTE_KERNEL); @@ -862,7 +873,7 @@ void inv_symm_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_SYMM_PERMUTE_KERNEL); @@ -879,7 +890,7 @@ void nonsymm_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_NONSYMM_PERMUTE_KERNEL); @@ -896,7 +907,7 @@ void inv_nonsymm_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_NONSYMM_PERMUTE_KERNEL); @@ -912,7 +923,7 @@ void 
row_gather(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2_WITH_HALF( GKO_DECLARE_DENSE_ROW_GATHER_KERNEL); @@ -937,7 +948,7 @@ void advanced_row_gather(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2_WITH_HALF( GKO_DECLARE_DENSE_ADVANCED_ROW_GATHER_KERNEL); @@ -953,7 +964,7 @@ void col_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COL_PERMUTE_KERNEL); @@ -970,7 +981,7 @@ void inv_row_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_ROW_PERMUTE_KERNEL); @@ -987,7 +998,7 @@ void inv_col_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_COL_PERMUTE_KERNEL); @@ -1006,7 +1017,7 @@ void symm_scale_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_SYMM_SCALE_PERMUTE_KERNEL); @@ -1025,7 +1036,7 @@ void inv_symm_scale_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_SYMM_SCALE_PERMUTE_KERNEL); @@ -1048,7 +1059,7 @@ void nonsymm_scale_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_NONSYMM_SCALE_PERMUTE_KERNEL); @@ -1071,7 +1082,7 @@ void inv_nonsymm_scale_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( 
GKO_DECLARE_DENSE_INV_NONSYMM_SCALE_PERMUTE_KERNEL); @@ -1089,7 +1100,7 @@ void row_scale_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_ROW_SCALE_PERMUTE_KERNEL); @@ -1107,7 +1118,7 @@ void inv_row_scale_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_ROW_SCALE_PERMUTE_KERNEL); @@ -1125,7 +1136,7 @@ void col_scale_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_COL_SCALE_PERMUTE_KERNEL); @@ -1143,7 +1154,7 @@ void inv_col_scale_permute(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DENSE_INV_COL_SCALE_PERMUTE_KERNEL); @@ -1158,7 +1169,8 @@ void extract_diagonal(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_EXTRACT_DIAGONAL_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DENSE_EXTRACT_DIAGONAL_KERNEL); template @@ -1173,7 +1185,8 @@ void inplace_absolute_dense(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_INPLACE_ABSOLUTE_DENSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_INPLACE_ABSOLUTE_DENSE_KERNEL); template @@ -1189,7 +1202,8 @@ void outplace_absolute_dense(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_OUTPLACE_ABSOLUTE_DENSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_OUTPLACE_ABSOLUTE_DENSE_KERNEL); template @@ -1205,7 +1219,7 @@ void make_complex(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_MAKE_COMPLEX_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_MAKE_COMPLEX_KERNEL); template @@ -1221,7 +1235,7 @@ void 
get_real(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_GET_REAL_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_GET_REAL_KERNEL); template @@ -1237,7 +1251,7 @@ void get_imag(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_GET_IMAG_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_GET_IMAG_KERNEL); template @@ -1257,7 +1271,7 @@ void add_scaled_identity(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_SCALAR_TYPE_WITH_HALF( GKO_DECLARE_DENSE_ADD_SCALED_IDENTITY_KERNEL); diff --git a/reference/matrix/diagonal_kernels.cpp b/reference/matrix/diagonal_kernels.cpp index 028b7685c2b..47d59728ab0 100644 --- a/reference/matrix/diagonal_kernels.cpp +++ b/reference/matrix/diagonal_kernels.cpp @@ -35,7 +35,8 @@ void apply_to_dense(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DIAGONAL_APPLY_TO_DENSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DIAGONAL_APPLY_TO_DENSE_KERNEL); template @@ -52,7 +53,7 @@ void right_apply_to_dense(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( GKO_DECLARE_DIAGONAL_RIGHT_APPLY_TO_DENSE_KERNEL); @@ -77,7 +78,7 @@ void apply_to_csr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DIAGONAL_APPLY_TO_CSR_KERNEL); @@ -101,7 +102,7 @@ void right_apply_to_csr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DIAGONAL_RIGHT_APPLY_TO_CSR_KERNEL); @@ -118,7 +119,7 @@ void fill_in_matrix_data(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DIAGONAL_FILL_IN_MATRIX_DATA_KERNEL); @@ -141,7 +142,7 @@ void 
convert_to_csr(std::shared_ptr exec, row_ptrs[size] = size; } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DIAGONAL_CONVERT_TO_CSR_KERNEL); @@ -159,7 +160,8 @@ void conj_transpose(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DIAGONAL_CONJ_TRANSPOSE_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF( + GKO_DECLARE_DIAGONAL_CONJ_TRANSPOSE_KERNEL); } // namespace diagonal diff --git a/reference/matrix/ell_kernels.cpp b/reference/matrix/ell_kernels.cpp index 1fa37c4e250..ece95b38a39 100644 --- a/reference/matrix/ell_kernels.cpp +++ b/reference/matrix/ell_kernels.cpp @@ -68,7 +68,7 @@ void spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_SPMV_KERNEL); @@ -107,7 +107,7 @@ void advanced_spmv(std::shared_ptr exec, for (size_type j = 0; j < c->get_size()[1]; j++) { for (size_type row = 0; row < a->get_size()[0]; row++) { - arithmetic_type result = c->at(row, j); + auto result = static_cast(c->at(row, j)); result *= beta_val; for (size_type i = 0; i < num_stored_elements_per_row; i++) { arithmetic_type val = a_vals(row + i * stride); @@ -121,7 +121,7 @@ void advanced_spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_ADVANCED_SPMV_KERNEL); @@ -161,7 +161,7 @@ void fill_in_matrix_data(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_FILL_IN_MATRIX_DATA_KERNEL); @@ -185,7 +185,7 @@ void fill_in_dense(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_FILL_IN_DENSE_KERNEL); @@ -203,7 +203,8 @@ void copy(std::shared_ptr exec, } } 
-GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_ELL_COPY_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_ELL_COPY_KERNEL); template @@ -234,7 +235,7 @@ void convert_to_csr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_CONVERT_TO_CSR_KERNEL); @@ -258,7 +259,7 @@ void count_nonzeros_per_row(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_COUNT_NONZEROS_PER_ROW_KERNEL); @@ -283,7 +284,7 @@ void extract_diagonal(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_ELL_EXTRACT_DIAGONAL_KERNEL); diff --git a/reference/matrix/fbcsr_kernels.cpp b/reference/matrix/fbcsr_kernels.cpp index 4c170a973a7..048158136be 100644 --- a/reference/matrix/fbcsr_kernels.cpp +++ b/reference/matrix/fbcsr_kernels.cpp @@ -74,7 +74,8 @@ void spmv(const std::shared_ptr, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_FBCSR_SPMV_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_FBCSR_SPMV_KERNEL); template @@ -118,7 +119,7 @@ void advanced_spmv(const std::shared_ptr, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_ADVANCED_SPMV_KERNEL); @@ -176,7 +177,7 @@ void fill_in_matrix_data(std::shared_ptr exec, std::copy(col_idx_vec.begin(), col_idx_vec.end(), col_idxs.get_data()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_FILL_IN_MATRIX_DATA_KERNEL); @@ -212,7 +213,7 @@ void fill_in_dense(const std::shared_ptr, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_FILL_IN_DENSE_KERNEL); @@ 
-271,7 +272,7 @@ void convert_to_csr(const std::shared_ptr, static_cast(source->get_num_stored_elements()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_CONVERT_TO_CSR_KERNEL); @@ -353,7 +354,7 @@ void transpose(std::shared_ptr exec, [](const ValueType x) { return x; }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_TRANSPOSE_KERNEL); @@ -366,7 +367,7 @@ void conj_transpose(std::shared_ptr exec, [](const ValueType x) { return conj(x); }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_CONJ_TRANSPOSE_KERNEL); @@ -391,7 +392,7 @@ void is_sorted_by_column_index( return; } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_IS_SORTED_BY_COLUMN_INDEX); @@ -448,7 +449,7 @@ void sort_by_column_index(const std::shared_ptr exec, syn::value_list(), syn::type_list<>(), to_sort); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_SORT_BY_COLUMN_INDEX); @@ -487,7 +488,7 @@ void extract_diagonal(std::shared_ptr, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_FBCSR_EXTRACT_DIAGONAL); diff --git a/reference/matrix/hybrid_kernels.cpp b/reference/matrix/hybrid_kernels.cpp index f2a06c321f2..5fe013297f3 100644 --- a/reference/matrix/hybrid_kernels.cpp +++ b/reference/matrix/hybrid_kernels.cpp @@ -86,7 +86,7 @@ void fill_in_matrix_data(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_HYBRID_FILL_IN_MATRIX_DATA_KERNEL); @@ -130,7 +130,7 @@ void convert_to_csr(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( 
+GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_HYBRID_CONVERT_TO_CSR_KERNEL); diff --git a/reference/matrix/scaled_permutation_kernels.cpp b/reference/matrix/scaled_permutation_kernels.cpp index b00e06f72f2..a352c0f777d 100644 --- a/reference/matrix/scaled_permutation_kernels.cpp +++ b/reference/matrix/scaled_permutation_kernels.cpp @@ -26,7 +26,7 @@ void invert(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SCALED_PERMUTATION_INVERT_KERNEL); @@ -51,7 +51,7 @@ void compose(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SCALED_PERMUTATION_COMPOSE_KERNEL); diff --git a/reference/matrix/sellp_kernels.cpp b/reference/matrix/sellp_kernels.cpp index 120194d6952..70cfc3cac3a 100644 --- a/reference/matrix/sellp_kernels.cpp +++ b/reference/matrix/sellp_kernels.cpp @@ -55,7 +55,8 @@ void spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_SELLP_SPMV_KERNEL); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_SELLP_SPMV_KERNEL); template @@ -96,7 +97,7 @@ void advanced_spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SELLP_ADVANCED_SPMV_KERNEL); @@ -163,7 +164,7 @@ void fill_in_matrix_data(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SELLP_FILL_IN_MATRIX_DATA_KERNEL); @@ -198,7 +199,7 @@ void fill_in_dense(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SELLP_FILL_IN_DENSE_KERNEL); @@ -234,7 +235,7 @@ void count_nonzeros_per_row(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( 
+GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SELLP_COUNT_NONZEROS_PER_ROW_KERNEL); @@ -280,7 +281,7 @@ void convert_to_csr(std::shared_ptr exec, result_row_ptrs[num_rows] = cur_ptr; } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SELLP_CONVERT_TO_CSR_KERNEL); @@ -317,7 +318,7 @@ void extract_diagonal(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SELLP_EXTRACT_DIAGONAL_KERNEL); diff --git a/reference/matrix/sparsity_csr_kernels.cpp b/reference/matrix/sparsity_csr_kernels.cpp index c511a16a292..b773d3b9a50 100644 --- a/reference/matrix/sparsity_csr_kernels.cpp +++ b/reference/matrix/sparsity_csr_kernels.cpp @@ -55,7 +55,7 @@ void spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_SPMV_KERNEL); @@ -92,7 +92,7 @@ void advanced_spmv(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_ADVANCED_SPMV_KERNEL); @@ -113,7 +113,7 @@ void fill_in_dense(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_FILL_IN_DENSE_KERNEL); @@ -138,7 +138,7 @@ void diagonal_element_prefix_sum( prefix_sum[num_rows] = num_diag; } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_DIAGONAL_ELEMENT_PREFIX_SUM_KERNEL); @@ -173,7 +173,7 @@ void remove_diagonal_elements(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_REMOVE_DIAGONAL_ELEMENTS_KERNEL); @@ -227,7 +227,7 @@ void 
transpose(std::shared_ptr exec, transpose_and_transform(exec, orig, trans); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_TRANSPOSE_KERNEL); @@ -245,7 +245,7 @@ void sort_by_column_index(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_SORT_BY_COLUMN_INDEX); @@ -269,7 +269,7 @@ void is_sorted_by_column_index( return; } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_SPARSITY_CSR_IS_SORTED_BY_COLUMN_INDEX); diff --git a/reference/test/base/combination.cpp b/reference/test/base/combination.cpp index aea578f4e7e..149aaa33256 100644 --- a/reference/test/base/combination.cpp +++ b/reference/test/base/combination.cpp @@ -34,7 +34,8 @@ class Combination : public ::testing::Test { std::vector> operators; }; -TYPED_TEST_SUITE(Combination, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Combination, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(Combination, CopiesOnSameExecutor) @@ -114,7 +115,7 @@ TYPED_TEST(Combination, AppliesToMixedVector) cmb = [ 8 7 ] [ 5 4 ] */ - using value_type = gko::next_precision; + using value_type = gko::next_precision_with_half; using Mtx = gko::matrix::Dense; auto cmb = gko::Combination::create( this->coefficients[0], this->operators[0], this->coefficients[1], @@ -156,7 +157,8 @@ TYPED_TEST(Combination, AppliesToMixedComplexVector) cmb = [ 8 7 ] [ 5 4 ] */ - using value_type = gko::to_complex>; + using value_type = + gko::to_complex>; using Mtx = gko::matrix::Dense; auto cmb = gko::Combination::create( this->coefficients[0], this->operators[0], this->coefficients[1], @@ -200,7 +202,7 @@ TYPED_TEST(Combination, AppliesLinearCombinationToMixedVector) cmb = [ 8 7 ] [ 5 4 ] */ - using value_type = gko::next_precision; + using value_type = 
gko::next_precision_with_half; using Mtx = gko::matrix::Dense; auto cmb = gko::Combination::create( this->coefficients[0], this->operators[0], this->coefficients[1], @@ -248,7 +250,8 @@ TYPED_TEST(Combination, AppliesLinearCombinationToMixedComplexVector) cmb = [ 8 7 ] [ 5 4 ] */ - using MixedDense = gko::matrix::Dense>; + using MixedDense = + gko::matrix::Dense>; using MixedDenseComplex = gko::to_complex; using value_type = typename MixedDenseComplex::value_type; auto cmb = gko::Combination::create( diff --git a/reference/test/matrix/coo_kernels.cpp b/reference/test/matrix/coo_kernels.cpp index fcca61a33d4..53efc588e1c 100644 --- a/reference/test/matrix/coo_kernels.cpp +++ b/reference/test/matrix/coo_kernels.cpp @@ -79,16 +79,17 @@ TYPED_TEST(Coo, ConvertsToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Coo = typename TestFixture::Mtx; using OtherCoo = gko::matrix::Coo; auto tmp = OtherCoo::create(this->exec); auto res = Coo::create(this->exec); // If OtherType is more precise: 0, otherwise r - auto residual = - r::value < r::value - ? gko::remove_complex{0} - : static_cast>(r::value); + auto residual = r::value < r::value + ? 
gko::remove_complex{0} + : gko::remove_complex{ + static_cast>( + r::value)}; this->mtx->convert_to(tmp); tmp->convert_to(res); @@ -101,7 +102,7 @@ TYPED_TEST(Coo, MovesToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Coo = typename TestFixture::Mtx; using OtherCoo = gko::matrix::Coo; auto tmp = OtherCoo::create(this->exec); @@ -214,7 +215,7 @@ TYPED_TEST(Coo, ConvertsEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Coo = typename TestFixture::Mtx; using OtherCoo = gko::matrix::Coo; auto empty = OtherCoo::create(this->exec); @@ -231,7 +232,7 @@ TYPED_TEST(Coo, MovesEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Coo = typename TestFixture::Mtx; using OtherCoo = gko::matrix::Coo; auto empty = OtherCoo::create(this->exec); diff --git a/reference/test/matrix/csr_kernels.cpp b/reference/test/matrix/csr_kernels.cpp index 2dd68bd9239..b84ac958f02 100644 --- a/reference/test/matrix/csr_kernels.cpp +++ b/reference/test/matrix/csr_kernels.cpp @@ -788,7 +788,7 @@ TYPED_TEST(Csr, ConvertsToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Csr = typename TestFixture::Mtx; using OtherCsr = gko::matrix::Csr; auto tmp = OtherCsr::create(this->exec); @@ -814,7 +814,7 @@ TYPED_TEST(Csr, MovesToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using 
OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Csr = typename TestFixture::Mtx; using OtherCsr = gko::matrix::Csr; auto tmp = OtherCsr::create(this->exec); @@ -992,7 +992,7 @@ TYPED_TEST(Csr, ConvertsEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Csr = typename TestFixture::Mtx; using OtherCsr = gko::matrix::Csr; auto empty = OtherCsr::create(this->exec); @@ -1011,7 +1011,7 @@ TYPED_TEST(Csr, MovesEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Csr = typename TestFixture::Mtx; using OtherCsr = gko::matrix::Csr; auto empty = OtherCsr::create(this->exec); diff --git a/reference/test/matrix/dense_kernels.cpp b/reference/test/matrix/dense_kernels.cpp index 51b0aa148fd..a8d37ce5a09 100644 --- a/reference/test/matrix/dense_kernels.cpp +++ b/reference/test/matrix/dense_kernels.cpp @@ -75,8 +75,7 @@ class Dense : public ::testing::Test { return gko::test::generate_random_matrix( num_rows, num_cols, std::uniform_int_distribution(num_cols, num_cols), - std::normal_distribution>(0.0, 1.0), - rand_engine, exec); + std::normal_distribution<>(0.0, 1.0), rand_engine, exec); } }; @@ -751,9 +750,11 @@ TYPED_TEST(Dense, ConvertsToPrecision) auto tmp = OtherDense::create(this->exec); auto res = Dense::create(this->exec); // If OtherT is more precise: 0, otherwise r - auto residual = r::value < r::value - ? gko::remove_complex{0} - : static_cast>(r::value); + auto residual = + r::value < r::value + ? 
gko::remove_complex{0} + : gko::remove_complex{ + static_cast>(r::value)}; this->mtx1->convert_to(tmp); tmp->convert_to(res); @@ -771,9 +772,11 @@ TYPED_TEST(Dense, MovesToPrecision) auto tmp = OtherDense::create(this->exec); auto res = Dense::create(this->exec); // If OtherT is more precise: 0, otherwise r - auto residual = r::value < r::value - ? gko::remove_complex{0} - : static_cast>(r::value); + auto residual = + r::value < r::value + ? gko::remove_complex{0} + : gko::remove_complex{ + static_cast>(r::value)}; this->mtx1->move_to(tmp); tmp->move_to(res); @@ -3549,7 +3552,7 @@ class DenseComplex : public ::testing::Test { }; -TYPED_TEST_SUITE(DenseComplex, gko::test::ComplexValueTypes, +TYPED_TEST_SUITE(DenseComplex, gko::test::ComplexValueTypesWithHalf, TypenameNameGenerator); diff --git a/reference/test/matrix/diagonal_kernels.cpp b/reference/test/matrix/diagonal_kernels.cpp index b0932c7eb66..e2ac67190d0 100644 --- a/reference/test/matrix/diagonal_kernels.cpp +++ b/reference/test/matrix/diagonal_kernels.cpp @@ -85,16 +85,17 @@ TYPED_TEST_SUITE(Diagonal, gko::test::ValueTypes, TypenameNameGenerator); TYPED_TEST(Diagonal, ConvertsToPrecision) { using ValueType = typename TestFixture::value_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Diagonal = typename TestFixture::Diag; using OtherDiagonal = gko::matrix::Diagonal; auto tmp = OtherDiagonal::create(this->exec); auto res = Diagonal::create(this->exec); // If OtherType is more precise: 0, otherwise r - auto residual = - r::value < r::value - ? gko::remove_complex{0} - : static_cast>(r::value); + auto residual = r::value < r::value + ? 
gko::remove_complex{0} + : gko::remove_complex{ + static_cast>( + r::value)}; this->diag1->convert_to(tmp); tmp->convert_to(res); @@ -106,7 +107,7 @@ TYPED_TEST(Diagonal, ConvertsToPrecision) TYPED_TEST(Diagonal, MovesToPrecision) { using ValueType = typename TestFixture::value_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Diagonal = typename TestFixture::Diag; using OtherDiagonal = gko::matrix::Diagonal; auto tmp = OtherDiagonal::create(this->exec); @@ -672,7 +673,7 @@ class DiagonalComplex : public ::testing::Test { using Diag = gko::matrix::Diagonal; }; -TYPED_TEST_SUITE(DiagonalComplex, gko::test::ComplexValueTypes, +TYPED_TEST_SUITE(DiagonalComplex, gko::test::ComplexValueTypesWithHalf, TypenameNameGenerator); diff --git a/reference/test/matrix/ell_kernels.cpp b/reference/test/matrix/ell_kernels.cpp index e1eef9f087c..6214db82d1c 100644 --- a/reference/test/matrix/ell_kernels.cpp +++ b/reference/test/matrix/ell_kernels.cpp @@ -443,16 +443,17 @@ TYPED_TEST(Ell, ConvertsToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Ell = typename TestFixture::Mtx; using OtherEll = gko::matrix::Ell; auto tmp = OtherEll::create(this->exec); auto res = Ell::create(this->exec); // If OtherType is more precise: 0, otherwise r - auto residual = - r::value < r::value - ? gko::remove_complex{0} - : static_cast>(r::value); + auto residual = r::value < r::value + ? 
gko::remove_complex{0} + : gko::remove_complex{ + static_cast>( + r::value)}; this->mtx1->convert_to(tmp); tmp->convert_to(res); @@ -465,7 +466,7 @@ TYPED_TEST(Ell, MovesToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Ell = typename TestFixture::Mtx; using OtherEll = gko::matrix::Ell; auto tmp = OtherEll::create(this->exec); @@ -734,7 +735,7 @@ TYPED_TEST(Ell, ConvertsEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Ell = typename TestFixture::Mtx; using OtherEll = gko::matrix::Ell; auto empty = Ell::create(this->exec); @@ -751,7 +752,7 @@ TYPED_TEST(Ell, MovesEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Ell = typename TestFixture::Mtx; using OtherEll = gko::matrix::Ell; auto empty = Ell::create(this->exec); diff --git a/reference/test/matrix/fbcsr_kernels.cpp b/reference/test/matrix/fbcsr_kernels.cpp index f7c6d2197ef..665df4ace31 100644 --- a/reference/test/matrix/fbcsr_kernels.cpp +++ b/reference/test/matrix/fbcsr_kernels.cpp @@ -271,16 +271,17 @@ TYPED_TEST(Fbcsr, ConvertsToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Fbcsr = typename TestFixture::Mtx; using OtherFbcsr = gko::matrix::Fbcsr; auto tmp = OtherFbcsr::create(this->exec); auto res = Fbcsr::create(this->exec); // If OtherType is more precise: 0, otherwise r - auto residual = - r::value < r::value - ? 
gko::remove_complex{0} - : static_cast>(r::value); + auto residual = r::value < r::value + ? gko::remove_complex{0} + : gko::remove_complex{ + static_cast>( + r::value)}; this->mtx->convert_to(tmp); tmp->convert_to(res); @@ -293,7 +294,7 @@ TYPED_TEST(Fbcsr, MovesToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Fbcsr = typename TestFixture::Mtx; using OtherFbcsr = gko::matrix::Fbcsr; auto tmp = OtherFbcsr::create(this->exec); @@ -391,7 +392,7 @@ TYPED_TEST(Fbcsr, ConvertsEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Fbcsr = typename TestFixture::Mtx; using OtherFbcsr = gko::matrix::Fbcsr; auto empty = OtherFbcsr::create(this->exec); @@ -410,7 +411,7 @@ TYPED_TEST(Fbcsr, MovesEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Fbcsr = typename TestFixture::Mtx; using OtherFbcsr = gko::matrix::Fbcsr; auto empty = OtherFbcsr::create(this->exec); diff --git a/reference/test/matrix/hybrid_kernels.cpp b/reference/test/matrix/hybrid_kernels.cpp index 754e599b8fe..87fd4c02811 100644 --- a/reference/test/matrix/hybrid_kernels.cpp +++ b/reference/test/matrix/hybrid_kernels.cpp @@ -233,16 +233,17 @@ TYPED_TEST(Hybrid, ConvertsToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Hybrid = typename TestFixture::Mtx; using OtherHybrid = gko::matrix::Hybrid; auto tmp = OtherHybrid::create(this->exec); auto res = 
Hybrid::create(this->exec); // If OtherType is more precise: 0, otherwise r - auto residual = - r::value < r::value - ? gko::remove_complex{0} - : static_cast>(r::value); + auto residual = r::value < r::value + ? gko::remove_complex{0} + : gko::remove_complex{ + static_cast>( + r::value)}; this->mtx1->convert_to(tmp); tmp->convert_to(res); @@ -255,7 +256,7 @@ TYPED_TEST(Hybrid, MovesToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Hybrid = typename TestFixture::Mtx; using OtherHybrid = gko::matrix::Hybrid; auto tmp = OtherHybrid::create(this->exec); @@ -366,7 +367,7 @@ TYPED_TEST(Hybrid, ConvertsEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Hybrid = typename TestFixture::Mtx; using OtherHybrid = gko::matrix::Hybrid; auto other = Hybrid::create(this->exec); @@ -383,7 +384,7 @@ TYPED_TEST(Hybrid, MovesEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Hybrid = typename TestFixture::Mtx; using OtherHybrid = gko::matrix::Hybrid; auto other = Hybrid::create(this->exec); diff --git a/reference/test/matrix/scaled_permutation.cpp b/reference/test/matrix/scaled_permutation.cpp index ba65705bf29..6d8d49f5662 100644 --- a/reference/test/matrix/scaled_permutation.cpp +++ b/reference/test/matrix/scaled_permutation.cpp @@ -145,8 +145,7 @@ TYPED_TEST(ScaledPermutation, CombineWithInverse) using index_type = typename TestFixture::index_type; const gko::size_type size = 20; auto rng = std::default_random_engine{3754}; - auto dist = std::uniform_real_distribution>{ - 1.0, 2.0}; + 
auto dist = std::uniform_real_distribution<>{1.0, 2.0}; auto perm = gko::matrix::ScaledPermutation::create( this->exec, size); std::iota(perm->get_permutation(), perm->get_permutation() + size, 0); diff --git a/reference/test/matrix/sellp_kernels.cpp b/reference/test/matrix/sellp_kernels.cpp index a39d8e16832..3208b8c42be 100644 --- a/reference/test/matrix/sellp_kernels.cpp +++ b/reference/test/matrix/sellp_kernels.cpp @@ -189,16 +189,17 @@ TYPED_TEST(Sellp, ConvertsToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Sellp = typename TestFixture::Mtx; using OtherSellp = gko::matrix::Sellp; auto tmp = OtherSellp::create(this->exec); auto res = Sellp::create(this->exec); // If OtherType is more precise: 0, otherwise r - auto residual = - r::value < r::value - ? gko::remove_complex{0} - : static_cast>(r::value); + auto residual = r::value < r::value + ? gko::remove_complex{0} + : gko::remove_complex{ + static_cast>( + r::value)}; this->mtx1->convert_to(tmp); tmp->convert_to(res); @@ -211,16 +212,17 @@ TYPED_TEST(Sellp, MovesToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Sellp = typename TestFixture::Mtx; using OtherSellp = gko::matrix::Sellp; auto tmp = OtherSellp::create(this->exec); auto res = Sellp::create(this->exec); // If OtherType is more precise: 0, otherwise r - auto residual = - r::value < r::value - ? gko::remove_complex{0} - : static_cast>(r::value); + auto residual = r::value < r::value + ? 
gko::remove_complex{0} + : gko::remove_complex{ + static_cast>( + r::value)}; this->mtx1->move_to(tmp); tmp->move_to(res); @@ -308,7 +310,7 @@ TYPED_TEST(Sellp, ConvertsEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Sellp = typename TestFixture::Mtx; using OtherSellp = gko::matrix::Sellp; auto empty = OtherSellp::create(this->exec); @@ -327,7 +329,7 @@ TYPED_TEST(Sellp, MovesEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = typename gko::next_precision; + using OtherType = gko::next_precision; using Sellp = typename TestFixture::Mtx; using OtherSellp = gko::matrix::Sellp; auto empty = OtherSellp::create(this->exec); diff --git a/test/matrix/fbcsr_kernels.cpp b/test/matrix/fbcsr_kernels.cpp index 8cff04c28a0..4ff8e1fc36a 100644 --- a/test/matrix/fbcsr_kernels.cpp +++ b/test/matrix/fbcsr_kernels.cpp @@ -37,7 +37,7 @@ class Fbcsr : public CommonTestFixture { std::unique_ptr rsorted; - std::normal_distribution> distb; + std::normal_distribution<> distb; std::default_random_engine engine; value_type get_random_value() @@ -123,6 +123,9 @@ TYPED_TEST(Fbcsr, SpmvIsEquivalentToRefSorted) using Mtx = typename TestFixture::Mtx; using Dense = typename TestFixture::Dense; using value_type = typename Mtx::value_type; + if (this->exec->get_master() != this->exec) { + SKIP_IF_HALF(value_type); + } auto drand = gko::clone(this->exec, this->rsorted); auto x = Dense::create(this->ref, gko::dim<2>(this->rsorted->get_size()[1], 1)); @@ -145,6 +148,9 @@ TYPED_TEST(Fbcsr, SpmvMultiIsEquivalentToRefSorted) using Mtx = typename TestFixture::Mtx; using Dense = typename TestFixture::Dense; using value_type = typename Mtx::value_type; + if (this->exec->get_master() != this->exec) { + SKIP_IF_HALF(value_type); + } auto drand = 
gko::clone(this->exec, this->rsorted); auto x = Dense::create(this->ref, gko::dim<2>(this->rsorted->get_size()[1], 3)); @@ -168,6 +174,9 @@ TYPED_TEST(Fbcsr, AdvancedSpmvIsEquivalentToRefSorted) using Dense = typename TestFixture::Dense; using value_type = typename TestFixture::value_type; using real_type = typename TestFixture::real_type; + if (this->exec->get_master() != this->exec) { + SKIP_IF_HALF(value_type); + } auto drand = gko::clone(this->exec, this->rsorted); auto x = Dense::create(this->ref, gko::dim<2>(this->rsorted->get_size()[1], 1)); @@ -198,6 +207,9 @@ TYPED_TEST(Fbcsr, AdvancedSpmvMultiIsEquivalentToRefSorted) using Dense = typename TestFixture::Dense; using value_type = typename TestFixture::value_type; using real_type = typename TestFixture::real_type; + if (this->exec->get_master() != this->exec) { + SKIP_IF_HALF(value_type); + } auto drand = gko::clone(this->exec, this->rsorted); auto x = Dense::create(this->ref, gko::dim<2>(this->rsorted->get_size()[1], 3)); diff --git a/test/matrix/matrix.cpp b/test/matrix/matrix.cpp index eea1a67ef5f..0b06f76df85 100644 --- a/test/matrix/matrix.cpp +++ b/test/matrix/matrix.cpp @@ -586,10 +586,7 @@ class Matrix : public CommonTestFixture { template gko::matrix_data gen_dense_data(gko::dim<2> size) { - return { - size, - std::normal_distribution>(0.0, 1.0), - rand_engine}; + return {size, std::normal_distribution<>(0.0, 1.0), rand_engine}; } template @@ -609,10 +606,7 @@ class Matrix : public CommonTestFixture { return {gko::initialize( {gko::test::detail::get_rand_value< typename VecType::value_type>( - std::normal_distribution< - gko::remove_complex>( - 0.0, 1.0), - rand_engine)}, + std::normal_distribution<>(0.0, 1.0), rand_engine)}, ref), exec}; } From 036485a1089785470ba9566a874416936d34f37a Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. 
Tsai" Date: Thu, 24 Oct 2024 11:59:30 +0200 Subject: [PATCH 09/16] device_matrix_data and mtx_io --- .../base/device_matrix_data_kernels.cpp | 16 +++++++-- .../base/device_matrix_data_kernels.cpp | 4 +-- core/base/device_matrix_data.cpp | 3 +- core/base/mtx_io.cpp | 35 ++++++++++++++----- core/device_hooks/common_kernels.inc.cpp | 12 ++++--- core/test/base/mtx_io.cpp | 20 +++++++++-- dpcpp/base/device_matrix_data_kernels.dp.cpp | 4 +-- omp/base/device_matrix_data_kernels.cpp | 6 ++-- reference/base/device_matrix_data_kernels.cpp | 10 +++--- test/base/device_matrix_data_kernels.cpp | 7 ++-- 10 files changed, 80 insertions(+), 37 deletions(-) diff --git a/common/cuda_hip/base/device_matrix_data_kernels.cpp b/common/cuda_hip/base/device_matrix_data_kernels.cpp index c5742653a93..ebfed84dba2 100644 --- a/common/cuda_hip/base/device_matrix_data_kernels.cpp +++ b/common/cuda_hip/base/device_matrix_data_kernels.cpp @@ -12,6 +12,7 @@ #include #include +#include "common/cuda_hip/base/math.hpp" #include "common/cuda_hip/base/thrust.hpp" #include "common/cuda_hip/base/types.hpp" @@ -22,6 +23,15 @@ namespace GKO_DEVICE_NAMESPACE { namespace components { +// __half `!=` operation is only available in __device__ +// Although gko::is_nonzero is constexpr, it still shows calling __device__ in +// __host__ +template +GKO_INLINE __device__ constexpr bool is_nonzero(T value) +{ + return value != zero(); +} + template void remove_zeros(std::shared_ptr exec, array& values, array& row_idxs, @@ -58,7 +68,7 @@ void remove_zeros(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_REMOVE_ZEROS_KERNEL); @@ -102,7 +112,7 @@ void sum_duplicates(std::shared_ptr exec, size_type, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_SUM_DUPLICATES_KERNEL); @@ -117,7 +127,7 @@ void 
sort_row_major(std::shared_ptr exec, it + data.get_num_stored_elements(), vals); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_SORT_ROW_MAJOR_KERNEL); diff --git a/common/unified/base/device_matrix_data_kernels.cpp b/common/unified/base/device_matrix_data_kernels.cpp index d801b47fcd5..b72c6bf3476 100644 --- a/common/unified/base/device_matrix_data_kernels.cpp +++ b/common/unified/base/device_matrix_data_kernels.cpp @@ -30,7 +30,7 @@ void soa_to_aos(std::shared_ptr exec, in.get_const_col_idxs(), in.get_const_values(), out); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_SOA_TO_AOS_KERNEL); @@ -50,7 +50,7 @@ void aos_to_soa(std::shared_ptr exec, out.get_values()); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_AOS_TO_SOA_KERNEL); diff --git a/core/base/device_matrix_data.cpp b/core/base/device_matrix_data.cpp index 4c71fffe275..cb9d332f5ab 100644 --- a/core/base/device_matrix_data.cpp +++ b/core/base/device_matrix_data.cpp @@ -157,7 +157,8 @@ device_matrix_data::empty_out() #define GKO_DECLARE_DEVICE_MATRIX_DATA(ValueType, IndexType) \ class device_matrix_data -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DEVICE_MATRIX_DATA); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DEVICE_MATRIX_DATA); } // namespace gko diff --git a/core/base/mtx_io.cpp b/core/base/mtx_io.cpp index 33c3b07d487..0897349d08c 100644 --- a/core/base/mtx_io.cpp +++ b/core/base/mtx_io.cpp @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -757,19 +758,28 @@ static constexpr uint64 binary_format_magic() { constexpr auto is_int = std::is_same::value; constexpr auto is_long = std::is_same::value; + constexpr auto is_half = std::is_same::value; constexpr auto is_double = 
std::is_same::value; constexpr auto is_float = std::is_same::value; constexpr auto is_complex_double = std::is_same>::value; constexpr auto is_complex_float = std::is_same>::value; + constexpr auto is_complex_half = + std::is_same>::value; static_assert(is_int || is_long, "invalid storage index type"); - static_assert( - is_double || is_float || is_complex_double || is_complex_float, - "invalid storage value type"); + static_assert(is_half || is_complex_half || is_double || is_float || + is_complex_double || is_complex_float, + "invalid storage value type"); constexpr auto index_bit = is_int ? 'I' : 'L'; constexpr auto value_bit = - is_double ? 'D' : (is_float ? 'S' : (is_complex_double ? 'Z' : 'C')); + is_double + ? 'D' + : (is_float + ? 'S' + : (is_complex_double + ? 'Z' + : (is_complex_float ? 'C' : (is_half ? 'H' : 'X')))); constexpr uint64 shift = 256; constexpr uint64 type_bits = index_bit * shift + value_bit; return 'G' + @@ -879,12 +889,16 @@ matrix_data read_binary_raw(std::istream& is) } DECLARE_OVERLOAD(double, int32) DECLARE_OVERLOAD(float, int32) + DECLARE_OVERLOAD(half, int32) DECLARE_OVERLOAD(std::complex, int32) DECLARE_OVERLOAD(std::complex, int32) + DECLARE_OVERLOAD(std::complex, int32) DECLARE_OVERLOAD(double, int64) DECLARE_OVERLOAD(float, int64) + DECLARE_OVERLOAD(half, int64) DECLARE_OVERLOAD(std::complex, int64) DECLARE_OVERLOAD(std::complex, int64) + DECLARE_OVERLOAD(std::complex, int64) #undef DECLARE_OVERLOAD else { @@ -970,11 +984,14 @@ void write_raw(std::ostream& os, const matrix_data& data, const matrix_data& data) #define GKO_DECLARE_READ_GENERIC_RAW(ValueType, IndexType) \ matrix_data read_generic_raw(std::istream& is) -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_READ_RAW); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_WRITE_RAW); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_READ_BINARY_RAW); -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_WRITE_BINARY_RAW); 
-GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_READ_GENERIC_RAW); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_READ_RAW); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF(GKO_DECLARE_WRITE_RAW); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_READ_BINARY_RAW); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_WRITE_BINARY_RAW); +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_READ_GENERIC_RAW); } // namespace gko diff --git a/core/device_hooks/common_kernels.inc.cpp b/core/device_hooks/common_kernels.inc.cpp index 78b80ec2859..439cda481a2 100644 --- a/core/device_hooks/common_kernels.inc.cpp +++ b/core/device_hooks/common_kernels.inc.cpp @@ -251,14 +251,16 @@ GKO_STUB_TEMPLATE_TYPE_WITH_HALF(GKO_DECLARE_FILL_SEQ_ARRAY_KERNEL); GKO_STUB_TEMPLATE_TYPE_WITH_HALF(GKO_DECLARE_REDUCE_ADD_ARRAY_KERNEL); GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_INPLACE_ABSOLUTE_ARRAY_KERNEL); GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_OUTPLACE_ABSOLUTE_ARRAY_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE( +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_REMOVE_ZEROS_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE( +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_SUM_DUPLICATES_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE( +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_SORT_ROW_MAJOR_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DEVICE_MATRIX_DATA_AOS_TO_SOA_KERNEL); -GKO_STUB_VALUE_AND_INDEX_TYPE(GKO_DECLARE_DEVICE_MATRIX_DATA_SOA_TO_AOS_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DEVICE_MATRIX_DATA_AOS_TO_SOA_KERNEL); +GKO_STUB_VALUE_AND_INDEX_TYPE_WITH_HALF( + GKO_DECLARE_DEVICE_MATRIX_DATA_SOA_TO_AOS_KERNEL); template GKO_DECLARE_CONVERT_PTRS_TO_IDXS(IndexType, RowPtrType) diff --git a/core/test/base/mtx_io.cpp b/core/test/base/mtx_io.cpp index 8ac1ced0e50..14d44335b85 100644 --- 
a/core/test/base/mtx_io.cpp +++ b/core/test/base/mtx_io.cpp @@ -7,6 +7,7 @@ #include +#include #include #include #include @@ -570,6 +571,12 @@ TEST(MtxReader, ReadsBinary) test_read(gko::matrix_data{}); test_read(gko::matrix_data, gko::int64>{}); test_read(gko::matrix_data, gko::int64>{}); +#if GINKGO_ENABLE_HALF + test_read(gko::matrix_data{}); + test_read(gko::matrix_data, gko::int32>{}); + test_read(gko::matrix_data{}); + test_read(gko::matrix_data, gko::int64>{}); +#endif } @@ -625,6 +632,12 @@ TEST(MtxReader, ReadsComplexBinary) test_read_fail(gko::matrix_data{}); test_read(gko::matrix_data, gko::int64>{}); test_read(gko::matrix_data, gko::int64>{}); +#if GINKGO_ENABLE_HALF + test_read_fail(gko::matrix_data{}); + test_read(gko::matrix_data, gko::int32>{}); + test_read_fail(gko::matrix_data{}); + test_read(gko::matrix_data, gko::int64>{}); +#endif } @@ -960,7 +973,7 @@ class RealDummyLinOpTest : public ::testing::Test { typename std::tuple_element<1, decltype(ValueIndexType())>::type; }; -TYPED_TEST_SUITE(RealDummyLinOpTest, gko::test::RealValueIndexTypes, +TYPED_TEST_SUITE(RealDummyLinOpTest, gko::test::RealValueIndexTypesWithHalf, PairTypenameNameGenerator); @@ -1165,7 +1178,7 @@ class DenseTest : public ::testing::Test { using index_type = typename std::tuple_element<1, ValueIndexType>::type; }; -TYPED_TEST_SUITE(DenseTest, gko::test::RealValueIndexTypes, +TYPED_TEST_SUITE(DenseTest, gko::test::RealValueIndexTypesWithHalf, PairTypenameNameGenerator); @@ -1209,7 +1222,8 @@ class ComplexDummyLinOpTest : public ::testing::Test { typename std::tuple_element<1, decltype(ValueIndexType())>::type; }; -TYPED_TEST_SUITE(ComplexDummyLinOpTest, gko::test::ComplexValueIndexTypes, +TYPED_TEST_SUITE(ComplexDummyLinOpTest, + gko::test::ComplexValueIndexTypesWithHalf, PairTypenameNameGenerator); diff --git a/dpcpp/base/device_matrix_data_kernels.dp.cpp b/dpcpp/base/device_matrix_data_kernels.dp.cpp index f39615613fe..a5f58831a27 100644 --- 
a/dpcpp/base/device_matrix_data_kernels.dp.cpp +++ b/dpcpp/base/device_matrix_data_kernels.dp.cpp @@ -49,7 +49,7 @@ void remove_zeros(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_REMOVE_ZEROS_KERNEL); @@ -112,7 +112,7 @@ void sort_row_major(std::shared_ptr exec, }); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_SORT_ROW_MAJOR_KERNEL); diff --git a/omp/base/device_matrix_data_kernels.cpp b/omp/base/device_matrix_data_kernels.cpp index bce89e2f409..cb2dabd3010 100644 --- a/omp/base/device_matrix_data_kernels.cpp +++ b/omp/base/device_matrix_data_kernels.cpp @@ -69,7 +69,7 @@ void remove_zeros(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_REMOVE_ZEROS_KERNEL); @@ -127,7 +127,7 @@ void sum_duplicates(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_SUM_DUPLICATES_KERNEL); @@ -142,7 +142,7 @@ void sort_row_major(std::shared_ptr exec, aos_to_soa(exec, tmp, data); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_SORT_ROW_MAJOR_KERNEL); diff --git a/reference/base/device_matrix_data_kernels.cpp b/reference/base/device_matrix_data_kernels.cpp index f9a23b35e69..78a2e25a712 100644 --- a/reference/base/device_matrix_data_kernels.cpp +++ b/reference/base/device_matrix_data_kernels.cpp @@ -29,7 +29,7 @@ void soa_to_aos(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_SOA_TO_AOS_KERNEL); @@ -46,7 +46,7 @@ void aos_to_soa(std::shared_ptr exec, } 
} -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_AOS_TO_SOA_KERNEL); @@ -78,7 +78,7 @@ void remove_zeros(std::shared_ptr exec, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_REMOVE_ZEROS_KERNEL); @@ -127,7 +127,7 @@ void sum_duplicates(std::shared_ptr exec, size_type, } } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_SUM_DUPLICATES_KERNEL); @@ -142,7 +142,7 @@ void sort_row_major(std::shared_ptr exec, aos_to_soa(exec, tmp, data); } -GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( +GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE_WITH_HALF( GKO_DECLARE_DEVICE_MATRIX_DATA_SORT_ROW_MAJOR_KERNEL); diff --git a/test/base/device_matrix_data_kernels.cpp b/test/base/device_matrix_data_kernels.cpp index 6ddc926b76c..d2543ae7cbb 100644 --- a/test/base/device_matrix_data_kernels.cpp +++ b/test/base/device_matrix_data_kernels.cpp @@ -35,8 +35,7 @@ class DeviceMatrixData : public CommonTestFixture { 0, host_data.size[0] - 1); std::uniform_int_distribution col_distr( 0, host_data.size[1] - 1); - std::uniform_real_distribution> - val_distr(1.0, 2.0); + std::uniform_real_distribution<> val_distr(1.0, 2.0); // add random entries for (int i = 0; i < 1000; i++) { host_data.nonzeros.emplace_back( @@ -85,7 +84,7 @@ class DeviceMatrixData : public CommonTestFixture { gko::matrix_data deduplicated_data; }; -TYPED_TEST_SUITE(DeviceMatrixData, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(DeviceMatrixData, gko::test::ValueIndexTypesWithHalf, PairTypenameNameGenerator); @@ -339,7 +338,7 @@ TYPED_TEST(DeviceMatrixData, SumsDuplicates) arrays.values.set_executor(this->exec->get_master()); for (int i = 0; i < arrays.values.get_size(); i++) { max_error = std::max( - max_error, std::abs(arrays.values.get_const_data()[i] - + 
max_error, gko::abs(arrays.values.get_const_data()[i] - ref_arrays.values.get_const_data()[i])); } // when Hip with GNU < 7, it will give a little difference. From b29a8f6795bc1a5dfbcd44319913da0b5b1d128c Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Thu, 24 Oct 2024 15:39:04 +0200 Subject: [PATCH 10/16] components such as array/iterator/segmented_array test with half --- core/test/base/array.cpp | 3 ++- core/test/base/iterator_factory.cpp | 4 ++-- core/test/base/segmented_array.cpp | 3 ++- core/test/components/addressable_pq.cpp | 4 ++-- cuda/test/base/array.cpp | 3 ++- reference/test/base/array.cpp | 3 ++- reference/test/components/absolute_array_kernels.cpp | 3 ++- reference/test/components/fill_array_kernels.cpp | 2 +- reference/test/components/reduce_array_kernels.cpp | 2 +- test/components/fill_array_kernels.cpp | 2 +- test/components/reduce_array_kernels.cpp | 11 ++++++++--- 11 files changed, 25 insertions(+), 15 deletions(-) diff --git a/core/test/base/array.cpp b/core/test/base/array.cpp index f7e03855d06..23515d70fc4 100644 --- a/core/test/base/array.cpp +++ b/core/test/base/array.cpp @@ -40,7 +40,8 @@ class Array : public ::testing::Test { gko::array x; }; -TYPED_TEST_SUITE(Array, gko::test::ComplexAndPODTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Array, gko::test::ComplexAndPODTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(Array, CanBeCreatedWithoutAnExecutor) diff --git a/core/test/base/iterator_factory.cpp b/core/test/base/iterator_factory.cpp index bbc3bbfd04f..3685242f78a 100644 --- a/core/test/base/iterator_factory.cpp +++ b/core/test/base/iterator_factory.cpp @@ -78,7 +78,7 @@ class ZipIterator : public ::testing::Test { const std::vector ordered_value; }; -TYPED_TEST_SUITE(ZipIterator, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(ZipIterator, gko::test::ValueIndexTypesWithHalf, PairTypenameNameGenerator); @@ -366,7 +366,7 @@ class PermuteIterator : public ::testing::Test { using value_type = ValueType; }; 
-TYPED_TEST_SUITE(PermuteIterator, gko::test::ComplexAndPODTypes, +TYPED_TEST_SUITE(PermuteIterator, gko::test::ComplexAndPODTypesWithHalf, TypenameNameGenerator); diff --git a/core/test/base/segmented_array.cpp b/core/test/base/segmented_array.cpp index 2741990036f..31444d71d18 100644 --- a/core/test/base/segmented_array.cpp +++ b/core/test/base/segmented_array.cpp @@ -27,7 +27,8 @@ class SegmentedArray : public ::testing::Test { std::shared_ptr exec = gko::ReferenceExecutor::create(); }; -TYPED_TEST_SUITE(SegmentedArray, gko::test::PODTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(SegmentedArray, gko::test::PODTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(SegmentedArray, CanConstructFromExecutor) diff --git a/core/test/components/addressable_pq.cpp b/core/test/components/addressable_pq.cpp index 6301cd44fb4..87fcb289a77 100644 --- a/core/test/components/addressable_pq.cpp +++ b/core/test/components/addressable_pq.cpp @@ -91,8 +91,8 @@ class AddressablePriorityQueue : public ::testing::Test { std::shared_ptr exec; }; -TYPED_TEST_SUITE(AddressablePriorityQueue, gko::test::RealValueIndexTypes, - TypenameNameGenerator); +TYPED_TEST_SUITE(AddressablePriorityQueue, + gko::test::RealValueIndexTypesWithHalf, TypenameNameGenerator); TYPED_TEST(AddressablePriorityQueue, InitializesCorrectly) diff --git a/cuda/test/base/array.cpp b/cuda/test/base/array.cpp index db7d4c54536..7294cbff29f 100644 --- a/cuda/test/base/array.cpp +++ b/cuda/test/base/array.cpp @@ -32,7 +32,8 @@ class Array : public CudaTestFixture { gko::array x; }; -TYPED_TEST_SUITE(Array, gko::test::ComplexAndPODTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Array, gko::test::ComplexAndPODTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(Array, CanCreateTemporaryCloneOnDifferentExecutor) diff --git a/reference/test/base/array.cpp b/reference/test/base/array.cpp index 666ab13063c..2c69f1afc8e 100644 --- a/reference/test/base/array.cpp +++ b/reference/test/base/array.cpp @@ -28,7 +28,8 @@ class Array 
: public ::testing::Test { gko::array x; }; -TYPED_TEST_SUITE(Array, gko::test::ComplexAndPODTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Array, gko::test::ComplexAndPODTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(Array, CanBeFilledWithValue) diff --git a/reference/test/components/absolute_array_kernels.cpp b/reference/test/components/absolute_array_kernels.cpp index c192d540032..5ad75440c88 100644 --- a/reference/test/components/absolute_array_kernels.cpp +++ b/reference/test/components/absolute_array_kernels.cpp @@ -43,7 +43,8 @@ class AbsoluteArray : public ::testing::Test { gko::array vals; }; -TYPED_TEST_SUITE(AbsoluteArray, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(AbsoluteArray, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(AbsoluteArray, InplaceEqualsExpected) diff --git a/reference/test/components/fill_array_kernels.cpp b/reference/test/components/fill_array_kernels.cpp index 3c7520c6847..0a9239ce1bd 100644 --- a/reference/test/components/fill_array_kernels.cpp +++ b/reference/test/components/fill_array_kernels.cpp @@ -40,7 +40,7 @@ class FillArray : public ::testing::Test { gko::array seqs; }; -TYPED_TEST_SUITE(FillArray, gko::test::ComplexAndPODTypes, +TYPED_TEST_SUITE(FillArray, gko::test::ComplexAndPODTypesWithHalf, TypenameNameGenerator); diff --git a/reference/test/components/reduce_array_kernels.cpp b/reference/test/components/reduce_array_kernels.cpp index 8286817c853..c8839bc178d 100644 --- a/reference/test/components/reduce_array_kernels.cpp +++ b/reference/test/components/reduce_array_kernels.cpp @@ -31,7 +31,7 @@ class ReduceArray : public ::testing::Test { gko::array vals; }; -TYPED_TEST_SUITE(ReduceArray, gko::test::ComplexAndPODTypes, +TYPED_TEST_SUITE(ReduceArray, gko::test::ComplexAndPODTypesWithHalf, TypenameNameGenerator); diff --git a/test/components/fill_array_kernels.cpp b/test/components/fill_array_kernels.cpp index 3d494b3f5f0..4237a75304a 100644 --- 
a/test/components/fill_array_kernels.cpp +++ b/test/components/fill_array_kernels.cpp @@ -36,7 +36,7 @@ class FillArray : public CommonTestFixture { gko::array seqs; }; -TYPED_TEST_SUITE(FillArray, gko::test::ComplexAndPODTypes, +TYPED_TEST_SUITE(FillArray, gko::test::ComplexAndPODTypesWithHalf, TypenameNameGenerator); diff --git a/test/components/reduce_array_kernels.cpp b/test/components/reduce_array_kernels.cpp index b7407801a32..7940feec661 100644 --- a/test/components/reduce_array_kernels.cpp +++ b/test/components/reduce_array_kernels.cpp @@ -20,14 +20,19 @@ template class ReduceArray : public CommonTestFixture { protected: using value_type = T; + static constexpr bool using_half = + std::is_same_v, gko::half>; + + // due to half accuracy, the summation ordering will affect the result + // easily ReduceArray() - : total_size(6355), + : total_size(using_half ? 1024 : 6355), out{ref, I{2}}, dout{exec, out}, vals{ref, total_size}, dvals{exec} { - std::fill_n(vals.get_data(), total_size, 3); + std::fill_n(vals.get_data(), total_size, using_half ? 1 : 3); dvals = vals; } @@ -38,7 +43,7 @@ class ReduceArray : public CommonTestFixture { gko::array dvals; }; -TYPED_TEST_SUITE(ReduceArray, gko::test::ComplexAndPODTypes, +TYPED_TEST_SUITE(ReduceArray, gko::test::ComplexAndPODTypesWithHalf, TypenameNameGenerator); From 2be0042c02ca4f9fbb34268366e7fbfbb42a82bf Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. 
Tsai" Date: Thu, 24 Oct 2024 15:39:41 +0200 Subject: [PATCH 11/16] matrix test with half --- core/test/matrix/coo.cpp | 3 +- core/test/matrix/coo_builder.cpp | 2 +- core/test/matrix/csr.cpp | 3 +- core/test/matrix/csr_builder.cpp | 2 +- core/test/matrix/dense.cpp | 2 +- core/test/matrix/diagonal.cpp | 3 +- core/test/matrix/ell.cpp | 3 +- core/test/matrix/fbcsr.cpp | 7 ++- core/test/matrix/fbcsr_builder.cpp | 2 +- core/test/matrix/hybrid.cpp | 3 +- core/test/matrix/identity.cpp | 6 +- core/test/matrix/permutation.cpp | 2 +- core/test/matrix/row_gatherer.cpp | 2 +- core/test/matrix/sellp.cpp | 3 +- core/test/matrix/sparsity_csr.cpp | 2 +- core/test/utils/fb_matrix_generator.hpp | 13 ++--- core/test/utils/value_generator.hpp | 6 +- hip/test/matrix/fbcsr_kernels.cpp | 56 +++++++++++++------ reference/test/matrix/coo_kernels.cpp | 33 +++++------ reference/test/matrix/csr_kernels.cpp | 48 ++++++++-------- reference/test/matrix/dense_kernels.cpp | 19 ++++--- reference/test/matrix/diagonal_kernels.cpp | 14 +++-- reference/test/matrix/ell_kernels.cpp | 46 +++++++-------- reference/test/matrix/fbcsr_kernels.cpp | 15 ++--- reference/test/matrix/hybrid_kernels.cpp | 20 ++++--- reference/test/matrix/identity.cpp | 6 +- reference/test/matrix/permutation.cpp | 2 +- reference/test/matrix/scaled_permutation.cpp | 2 +- reference/test/matrix/sellp_kernels.cpp | 19 ++++--- reference/test/matrix/sparsity_csr.cpp | 2 +- .../test/matrix/sparsity_csr_kernels.cpp | 12 ++-- test/matrix/fbcsr_kernels.cpp | 23 ++++++-- 32 files changed, 219 insertions(+), 162 deletions(-) diff --git a/core/test/matrix/coo.cpp b/core/test/matrix/coo.cpp index ffb8d5aee9f..56735e792d5 100644 --- a/core/test/matrix/coo.cpp +++ b/core/test/matrix/coo.cpp @@ -77,7 +77,8 @@ class Coo : public ::testing::Test { } }; -TYPED_TEST_SUITE(Coo, gko::test::ValueIndexTypes, PairTypenameNameGenerator); +TYPED_TEST_SUITE(Coo, gko::test::ValueIndexTypesWithHalf, + PairTypenameNameGenerator); TYPED_TEST(Coo, KnowsItsSize) diff 
--git a/core/test/matrix/coo_builder.cpp b/core/test/matrix/coo_builder.cpp index 9bfae5cf3af..b1b22c5848a 100644 --- a/core/test/matrix/coo_builder.cpp +++ b/core/test/matrix/coo_builder.cpp @@ -32,7 +32,7 @@ class CooBuilder : public ::testing::Test { std::unique_ptr mtx; }; -TYPED_TEST_SUITE(CooBuilder, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(CooBuilder, gko::test::ValueIndexTypesWithHalf, PairTypenameNameGenerator); diff --git a/core/test/matrix/csr.cpp b/core/test/matrix/csr.cpp index 4bbdc63851a..f199de423e8 100644 --- a/core/test/matrix/csr.cpp +++ b/core/test/matrix/csr.cpp @@ -82,7 +82,8 @@ class Csr : public ::testing::Test { } }; -TYPED_TEST_SUITE(Csr, gko::test::ValueIndexTypes, PairTypenameNameGenerator); +TYPED_TEST_SUITE(Csr, gko::test::ValueIndexTypesWithHalf, + PairTypenameNameGenerator); TYPED_TEST(Csr, KnowsItsSize) diff --git a/core/test/matrix/csr_builder.cpp b/core/test/matrix/csr_builder.cpp index 24cbe4718c5..2accb57770c 100644 --- a/core/test/matrix/csr_builder.cpp +++ b/core/test/matrix/csr_builder.cpp @@ -33,7 +33,7 @@ class CsrBuilder : public ::testing::Test { std::unique_ptr mtx; }; -TYPED_TEST_SUITE(CsrBuilder, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(CsrBuilder, gko::test::ValueIndexTypesWithHalf, PairTypenameNameGenerator); diff --git a/core/test/matrix/dense.cpp b/core/test/matrix/dense.cpp index e7158a15aed..f1a673840ea 100644 --- a/core/test/matrix/dense.cpp +++ b/core/test/matrix/dense.cpp @@ -48,7 +48,7 @@ class Dense : public ::testing::Test { std::unique_ptr> mtx; }; -TYPED_TEST_SUITE(Dense, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Dense, gko::test::ValueTypesWithHalf, TypenameNameGenerator); TYPED_TEST(Dense, CanBeEmpty) diff --git a/core/test/matrix/diagonal.cpp b/core/test/matrix/diagonal.cpp index de03a9350bb..7e598d67a5e 100644 --- a/core/test/matrix/diagonal.cpp +++ b/core/test/matrix/diagonal.cpp @@ -47,7 +47,8 @@ class Diagonal : public ::testing::Test { } }; 
-TYPED_TEST_SUITE(Diagonal, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Diagonal, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(Diagonal, KnowsItsSize) diff --git a/core/test/matrix/ell.cpp b/core/test/matrix/ell.cpp index bcc2b591a50..93fc73dde18 100644 --- a/core/test/matrix/ell.cpp +++ b/core/test/matrix/ell.cpp @@ -79,7 +79,8 @@ class Ell : public ::testing::Test { } }; -TYPED_TEST_SUITE(Ell, gko::test::ValueIndexTypes, PairTypenameNameGenerator); +TYPED_TEST_SUITE(Ell, gko::test::ValueIndexTypesWithHalf, + PairTypenameNameGenerator); TYPED_TEST(Ell, KnowsItsSize) diff --git a/core/test/matrix/fbcsr.cpp b/core/test/matrix/fbcsr.cpp index 3d3d4ee738d..fd024532a14 100644 --- a/core/test/matrix/fbcsr.cpp +++ b/core/test/matrix/fbcsr.cpp @@ -131,7 +131,7 @@ class FbcsrSample : public ::testing::Test { }; -TYPED_TEST_SUITE(FbcsrSample, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(FbcsrSample, gko::test::ValueIndexTypesWithHalf, PairTypenameNameGenerator); @@ -183,7 +183,7 @@ template class FbcsrSampleComplex : public FbcsrSample {}; -TYPED_TEST_SUITE(FbcsrSampleComplex, gko::test::ComplexValueIndexTypes, +TYPED_TEST_SUITE(FbcsrSampleComplex, gko::test::ComplexValueIndexTypesWithHalf, PairTypenameNameGenerator); @@ -282,7 +282,8 @@ class Fbcsr : public ::testing::Test { } }; -TYPED_TEST_SUITE(Fbcsr, gko::test::ValueIndexTypes, PairTypenameNameGenerator); +TYPED_TEST_SUITE(Fbcsr, gko::test::ValueIndexTypesWithHalf, + PairTypenameNameGenerator); TYPED_TEST(Fbcsr, GetNumBlocksCorrectlyThrows) diff --git a/core/test/matrix/fbcsr_builder.cpp b/core/test/matrix/fbcsr_builder.cpp index d91a0c7b70a..241c7ccc6eb 100644 --- a/core/test/matrix/fbcsr_builder.cpp +++ b/core/test/matrix/fbcsr_builder.cpp @@ -33,7 +33,7 @@ class FbcsrBuilder : public ::testing::Test { std::unique_ptr mtx; }; -TYPED_TEST_SUITE(FbcsrBuilder, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(FbcsrBuilder, gko::test::ValueIndexTypesWithHalf, 
PairTypenameNameGenerator); diff --git a/core/test/matrix/hybrid.cpp b/core/test/matrix/hybrid.cpp index d1a69312755..6b1e2a4a747 100644 --- a/core/test/matrix/hybrid.cpp +++ b/core/test/matrix/hybrid.cpp @@ -96,7 +96,8 @@ class Hybrid : public ::testing::Test { } }; -TYPED_TEST_SUITE(Hybrid, gko::test::ValueIndexTypes, PairTypenameNameGenerator); +TYPED_TEST_SUITE(Hybrid, gko::test::ValueIndexTypesWithHalf, + PairTypenameNameGenerator); TYPED_TEST(Hybrid, KnowsItsSize) diff --git a/core/test/matrix/identity.cpp b/core/test/matrix/identity.cpp index bcf9c036992..80defae4441 100644 --- a/core/test/matrix/identity.cpp +++ b/core/test/matrix/identity.cpp @@ -31,7 +31,8 @@ class Identity : public ::testing::Test { std::shared_ptr exec; }; -TYPED_TEST_SUITE(Identity, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Identity, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(Identity, CanBeEmpty) @@ -81,7 +82,8 @@ class IdentityFactory : public ::testing::Test { using value_type = T; }; -TYPED_TEST_SUITE(IdentityFactory, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(IdentityFactory, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(IdentityFactory, CanGenerateIdentityMatrix) diff --git a/core/test/matrix/permutation.cpp b/core/test/matrix/permutation.cpp index edb1532696b..fcd5aad789c 100644 --- a/core/test/matrix/permutation.cpp +++ b/core/test/matrix/permutation.cpp @@ -52,7 +52,7 @@ class Permutation : public ::testing::Test { std::unique_ptr> mtx; }; -TYPED_TEST_SUITE(Permutation, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(Permutation, gko::test::ValueIndexTypesWithHalf, PairTypenameNameGenerator); diff --git a/core/test/matrix/row_gatherer.cpp b/core/test/matrix/row_gatherer.cpp index 801f639c206..b808828cc08 100644 --- a/core/test/matrix/row_gatherer.cpp +++ b/core/test/matrix/row_gatherer.cpp @@ -65,7 +65,7 @@ class RowGatherer : public ::testing::Test { std::unique_ptr out; }; 
-TYPED_TEST_SUITE(RowGatherer, gko::test::TwoValueIndexType, +TYPED_TEST_SUITE(RowGatherer, gko::test::TwoValueIndexTypeWithHalf, TupleTypenameNameGenerator); diff --git a/core/test/matrix/sellp.cpp b/core/test/matrix/sellp.cpp index 123d7bae773..a79fcf2bbd3 100644 --- a/core/test/matrix/sellp.cpp +++ b/core/test/matrix/sellp.cpp @@ -107,7 +107,8 @@ class Sellp : public ::testing::Test { } }; -TYPED_TEST_SUITE(Sellp, gko::test::ValueIndexTypes, PairTypenameNameGenerator); +TYPED_TEST_SUITE(Sellp, gko::test::ValueIndexTypesWithHalf, + PairTypenameNameGenerator); TYPED_TEST(Sellp, KnowsItsSize) diff --git a/core/test/matrix/sparsity_csr.cpp b/core/test/matrix/sparsity_csr.cpp index e929f960f1e..67f8237adb6 100644 --- a/core/test/matrix/sparsity_csr.cpp +++ b/core/test/matrix/sparsity_csr.cpp @@ -74,7 +74,7 @@ class SparsityCsr : public ::testing::Test { } }; -TYPED_TEST_SUITE(SparsityCsr, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(SparsityCsr, gko::test::ValueIndexTypesWithHalf, PairTypenameNameGenerator); diff --git a/core/test/utils/fb_matrix_generator.hpp b/core/test/utils/fb_matrix_generator.hpp index 034dd95fce1..786f836e10a 100644 --- a/core/test/utils/fb_matrix_generator.hpp +++ b/core/test/utils/fb_matrix_generator.hpp @@ -131,16 +131,15 @@ std::unique_ptr> generate_fbcsr_from_csr( const IndexType* const row_ptrs = fmtx->get_const_row_ptrs(); const IndexType* const col_idxs = fmtx->get_const_col_idxs(); ValueType* const vals = fmtx->get_values(); - std::uniform_real_distribution> - off_diag_dist(-1.0, 1.0); + std::uniform_real_distribution<> off_diag_dist(-1.0, 1.0); for (IndexType ibrow = 0; ibrow < nbrows; ibrow++) { if (row_diag_dominant) { const IndexType nrownz = (row_ptrs[ibrow + 1] - row_ptrs[ibrow]) * block_size; - std::uniform_real_distribution> - diag_dist(1.01 * nrownz, 2 * nrownz); + std::uniform_real_distribution<> diag_dist(1.01 * nrownz, + 2 * nrownz); for (IndexType ibz = row_ptrs[ibrow]; ibz < row_ptrs[ibrow + 1]; ibz++) { @@ -205,13 
+204,11 @@ std::unique_ptr> generate_random_fbcsr( matrix::Csr>( nbrows, nbcols, std::uniform_int_distribution(0, nbcols - 1), - std::normal_distribution(0.0, 1.0), - std::move(engine), ref) + std::normal_distribution<>(0.0, 1.0), std::move(engine), ref) : generate_random_matrix>( nbrows, nbcols, std::uniform_int_distribution(0, nbcols - 1), - std::normal_distribution(0.0, 1.0), - std::move(engine), ref); + std::normal_distribution<>(0.0, 1.0), std::move(engine), ref); if (unsort && rand_csr_ref->is_sorted_by_column_index()) { unsort_matrix(rand_csr_ref, engine); } diff --git a/core/test/utils/value_generator.hpp b/core/test/utils/value_generator.hpp index f18f2170c96..19e01b33356 100644 --- a/core/test/utils/value_generator.hpp +++ b/core/test/utils/value_generator.hpp @@ -33,7 +33,7 @@ template typename std::enable_if::value, ValueType>::type get_rand_value(ValueDistribution&& value_dist, Engine&& gen) { - return value_dist(gen); + return static_cast(value_dist(gen)); } /** @@ -45,7 +45,9 @@ template typename std::enable_if::value, ValueType>::type get_rand_value(ValueDistribution&& value_dist, Engine&& gen) { - return ValueType(value_dist(gen), value_dist(gen)); + using real_type = remove_complex; + return ValueType(static_cast(value_dist(gen)), + static_cast(value_dist(gen))); } diff --git a/hip/test/matrix/fbcsr_kernels.cpp b/hip/test/matrix/fbcsr_kernels.cpp index 0b4b16086ca..536ff3dc01c 100644 --- a/hip/test/matrix/fbcsr_kernels.cpp +++ b/hip/test/matrix/fbcsr_kernels.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -41,7 +42,7 @@ class Fbcsr : public HipTestFixture { std::unique_ptr rsorted_ref; - std::normal_distribution> distb; + std::normal_distribution<> distb; std::default_random_engine engine; value_type get_random_value() @@ -60,7 +61,8 @@ class Fbcsr : public HipTestFixture { } }; -TYPED_TEST_SUITE(Fbcsr, gko::test::RealValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Fbcsr, gko::test::RealValueTypesWithHalf, + 
TypenameNameGenerator); TYPED_TEST(Fbcsr, CanWriteFromMatrixOnDevice) @@ -145,11 +147,15 @@ TYPED_TEST(Fbcsr, SpmvIsEquivalentToRefSorted) this->ref, gko::dim<2>(this->rsorted_ref->get_size()[0], 1)); auto prod_hip = Dense::create(this->exec, prod_ref->get_size()); - rand_hip->apply(x_hip, prod_hip); - this->rsorted_ref->apply(x_ref, prod_ref); + if (std::is_same::value) { + ASSERT_THROW(rand_hip->apply(x_hip, prod_hip), gko::NotImplemented); + } else { + rand_hip->apply(x_hip, prod_hip); + this->rsorted_ref->apply(x_ref, prod_ref); - const double tol = r::value; - GKO_ASSERT_MTX_NEAR(prod_ref, prod_hip, 5 * tol); + const double tol = r::value; + GKO_ASSERT_MTX_NEAR(prod_ref, prod_hip, 5 * tol); + } } @@ -169,11 +175,15 @@ TYPED_TEST(Fbcsr, SpmvMultiIsEquivalentToRefSorted) this->ref, gko::dim<2>(this->rsorted_ref->get_size()[0], 3)); auto prod_hip = Dense::create(this->exec, prod_ref->get_size()); - rand_hip->apply(x_hip, prod_hip); - this->rsorted_ref->apply(x_ref, prod_ref); + if (std::is_same::value) { + ASSERT_THROW(rand_hip->apply(x_hip, prod_hip), gko::NotImplemented); + } else { + rand_hip->apply(x_hip, prod_hip); + this->rsorted_ref->apply(x_ref, prod_ref); - const double tol = r::value; - GKO_ASSERT_MTX_NEAR(prod_ref, prod_hip, 5 * tol); + const double tol = r::value; + GKO_ASSERT_MTX_NEAR(prod_ref, prod_hip, 5 * tol); + } } @@ -205,11 +215,16 @@ TYPED_TEST(Fbcsr, AdvancedSpmvIsEquivalentToRefSorted) auto beta = Dense::create(this->exec); beta->copy_from(beta_ref); - rand_hip->apply(alpha, x_hip, beta, prod_hip); - this->rsorted_ref->apply(alpha_ref, x_ref, beta_ref, prod_ref); + if (std::is_same::value) { + ASSERT_THROW(rand_hip->apply(alpha, x_hip, beta, prod_hip), + gko::NotImplemented); + } else { + rand_hip->apply(alpha, x_hip, beta, prod_hip); + this->rsorted_ref->apply(alpha_ref, x_ref, beta_ref, prod_ref); - const double tol = r::value; - GKO_ASSERT_MTX_NEAR(prod_ref, prod_hip, 5 * tol); + const double tol = r::value; + 
GKO_ASSERT_MTX_NEAR(prod_ref, prod_hip, 5 * tol); + } } @@ -241,11 +256,16 @@ TYPED_TEST(Fbcsr, AdvancedSpmvMultiIsEquivalentToRefSorted) auto beta = Dense::create(this->exec); beta->copy_from(beta_ref); - rand_hip->apply(alpha, x_hip, beta, prod_hip); - this->rsorted_ref->apply(alpha_ref, x_ref, beta_ref, prod_ref); + if (std::is_same::value) { + ASSERT_THROW(rand_hip->apply(alpha, x_hip, beta, prod_hip), + gko::NotImplemented); + } else { + rand_hip->apply(alpha, x_hip, beta, prod_hip); + this->rsorted_ref->apply(alpha_ref, x_ref, beta_ref, prod_ref); - const double tol = r::value; - GKO_ASSERT_MTX_NEAR(prod_ref, prod_hip, 5 * tol); + const double tol = r::value; + GKO_ASSERT_MTX_NEAR(prod_ref, prod_hip, 5 * tol); + } } diff --git a/reference/test/matrix/coo_kernels.cpp b/reference/test/matrix/coo_kernels.cpp index 53efc588e1c..6ffea5d0e7d 100644 --- a/reference/test/matrix/coo_kernels.cpp +++ b/reference/test/matrix/coo_kernels.cpp @@ -32,7 +32,8 @@ class Coo : public ::testing::Test { using Csr = gko::matrix::Csr; using Mtx = gko::matrix::Coo; using Vec = gko::matrix::Dense; - using MixedVec = gko::matrix::Dense>; + using MixedVec = + gko::matrix::Dense>; Coo() : exec(gko::ReferenceExecutor::create()), mtx(Mtx::create(exec)) { @@ -72,24 +73,24 @@ class Coo : public ::testing::Test { std::unique_ptr uns_mtx; }; -TYPED_TEST_SUITE(Coo, gko::test::ValueIndexTypes, PairTypenameNameGenerator); +TYPED_TEST_SUITE(Coo, gko::test::ValueIndexTypesWithHalf, + PairTypenameNameGenerator); TYPED_TEST(Coo, ConvertsToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Coo = typename TestFixture::Mtx; using OtherCoo = gko::matrix::Coo; auto tmp = OtherCoo::create(this->exec); auto res = Coo::create(this->exec); // If OtherType is more precise: 0, otherwise r - auto residual = r::value < r::value - ? 
gko::remove_complex{0} - : gko::remove_complex{ - static_cast>( - r::value)}; + auto residual = + r::value < r::value + ? gko::remove_complex{0} + : static_cast>(r::value); this->mtx->convert_to(tmp); tmp->convert_to(res); @@ -102,7 +103,7 @@ TYPED_TEST(Coo, MovesToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Coo = typename TestFixture::Mtx; using OtherCoo = gko::matrix::Coo; auto tmp = OtherCoo::create(this->exec); @@ -215,7 +216,7 @@ TYPED_TEST(Coo, ConvertsEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Coo = typename TestFixture::Mtx; using OtherCoo = gko::matrix::Coo; auto empty = OtherCoo::create(this->exec); @@ -232,7 +233,7 @@ TYPED_TEST(Coo, MovesEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Coo = typename TestFixture::Mtx; using OtherCoo = gko::matrix::Coo; auto empty = OtherCoo::create(this->exec); @@ -703,7 +704,7 @@ TYPED_TEST(Coo, AppliesToComplex) TYPED_TEST(Coo, AppliesToMixedComplex) { using mixed_value_type = - gko::next_precision; + gko::next_precision_with_half; using mixed_complex_type = gko::to_complex; using Vec = gko::matrix::Dense; auto exec = gko::ReferenceExecutor::create(); @@ -759,7 +760,7 @@ TYPED_TEST(Coo, AdvancedAppliesToComplex) TYPED_TEST(Coo, AdvancedAppliesToMixedComplex) { using mixed_value_type = - gko::next_precision; + gko::next_precision_with_half; using mixed_complex_type = gko::to_complex; using MixedDense = gko::matrix::Dense; using MixedDenseComplex = gko::matrix::Dense; @@ -817,7 +818,7 @@ TYPED_TEST(Coo, 
ApplyAddsToComplex) TYPED_TEST(Coo, ApplyAddsToMixedComplex) { using mixed_value_type = - gko::next_precision; + gko::next_precision_with_half; using mixed_complex_type = gko::to_complex; using MixedVec = gko::matrix::Dense; auto exec = gko::ReferenceExecutor::create(); @@ -874,7 +875,7 @@ TYPED_TEST(Coo, ApplyAddsScaledToComplex) TYPED_TEST(Coo, ApplyAddsScaledToMixedComplex) { using mixed_value_type = - gko::next_precision; + gko::next_precision_with_half; using mixed_complex_type = gko::to_complex; using MixedDense = gko::matrix::Dense; using MixedDenseComplex = gko::matrix::Dense; @@ -911,7 +912,7 @@ class CooComplex : public ::testing::Test { using Mtx = gko::matrix::Coo; }; -TYPED_TEST_SUITE(CooComplex, gko::test::ComplexValueIndexTypes, +TYPED_TEST_SUITE(CooComplex, gko::test::ComplexValueIndexTypesWithHalf, PairTypenameNameGenerator); diff --git a/reference/test/matrix/csr_kernels.cpp b/reference/test/matrix/csr_kernels.cpp index b84ac958f02..b417eb93f52 100644 --- a/reference/test/matrix/csr_kernels.cpp +++ b/reference/test/matrix/csr_kernels.cpp @@ -46,7 +46,8 @@ class Csr : public ::testing::Test { using Ell = gko::matrix::Ell; using Hybrid = gko::matrix::Hybrid; using Vec = gko::matrix::Dense; - using MixedVec = gko::matrix::Dense>; + using MixedVec = + gko::matrix::Dense>; using Perm = gko::matrix::Permutation; using ScaledPerm = gko::matrix::ScaledPermutation; @@ -347,7 +348,8 @@ class Csr : public ::testing::Test { index_type invalid_index = gko::invalid_index(); }; -TYPED_TEST_SUITE(Csr, gko::test::ValueIndexTypes, PairTypenameNameGenerator); +TYPED_TEST_SUITE(Csr, gko::test::ValueIndexTypesWithHalf, + PairTypenameNameGenerator); TYPED_TEST(Csr, AppliesToDenseVector) @@ -368,7 +370,7 @@ TYPED_TEST(Csr, MixedAppliesToDenseVector1) { // Both vectors have the same value type which differs from the matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec = typename 
gko::matrix::Dense; auto x = gko::initialize({2.0, 1.0, 4.0}, this->exec); auto y = Vec::create(this->exec, gko::dim<2>{2, 1}); @@ -383,7 +385,7 @@ TYPED_TEST(Csr, MixedAppliesToDenseVector2) { // Input vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; using Vec2 = gko::matrix::Dense; auto x = gko::initialize({2.0, 1.0, 4.0}, this->exec); @@ -399,9 +401,9 @@ TYPED_TEST(Csr, MixedAppliesToDenseVector3) { // Output vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; - using Vec2 = gko::matrix::Dense>; + using Vec2 = gko::matrix::Dense>; auto x = gko::initialize({2.0, 1.0, 4.0}, this->exec); auto y = Vec1::create(this->exec, gko::dim<2>{2, 1}); @@ -432,7 +434,7 @@ TYPED_TEST(Csr, MixedAppliesToDenseMatrix1) { // Both vectors have the same value type which differs from the matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec = gko::matrix::Dense; // clang-format off auto x = gko::initialize( @@ -456,7 +458,7 @@ TYPED_TEST(Csr, MixedAppliesToDenseMatrix2) { // Input vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; using Vec2 = gko::matrix::Dense; // clang-format off @@ -481,7 +483,7 @@ TYPED_TEST(Csr, MixedAppliesToDenseMatrix3) { // Output vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; using Vec2 = gko::matrix::Dense; // clang-format off @@ -522,7 +524,7 @@ 
TYPED_TEST(Csr, MixedAppliesLinearCombinationToDenseVector1) { // Both vectors have the same value type which differs from the matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec = gko::matrix::Dense; auto alpha = gko::initialize({-1.0}, this->exec); auto beta = gko::initialize({2.0}, this->exec); @@ -539,7 +541,7 @@ TYPED_TEST(Csr, MixedAppliesLinearCombinationToDenseVector2) { // Input vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; using Vec2 = gko::matrix::Dense; auto alpha = gko::initialize({-1.0}, this->exec); @@ -557,7 +559,7 @@ TYPED_TEST(Csr, MixedAppliesLinearCombinationToDenseVector3) { // Output vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; using Vec2 = gko::matrix::Dense; auto alpha = gko::initialize({-1.0}, this->exec); @@ -595,7 +597,7 @@ TYPED_TEST(Csr, MixedAppliesLinearCombinationToDenseMatrix1) { // Both vectors have the same value type which differs from the matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec = gko::matrix::Dense; auto alpha = gko::initialize({-1.0}, this->exec); auto beta = gko::initialize({2.0}, this->exec); @@ -619,7 +621,7 @@ TYPED_TEST(Csr, MixedAppliesLinearCombinationToDenseMatrix2) { // Input vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; using Vec2 = gko::matrix::Dense; auto alpha = gko::initialize({-1.0}, this->exec); @@ -639,7 +641,7 @@ TYPED_TEST(Csr, 
MixedAppliesLinearCombinationToDenseMatrix3) { // Output vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; using Vec2 = gko::matrix::Dense; auto alpha = gko::initialize({-1.0}, this->exec); @@ -788,7 +790,7 @@ TYPED_TEST(Csr, ConvertsToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Csr = typename TestFixture::Mtx; using OtherCsr = gko::matrix::Csr; auto tmp = OtherCsr::create(this->exec); @@ -814,7 +816,7 @@ TYPED_TEST(Csr, MovesToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Csr = typename TestFixture::Mtx; using OtherCsr = gko::matrix::Csr; auto tmp = OtherCsr::create(this->exec); @@ -992,7 +994,7 @@ TYPED_TEST(Csr, ConvertsEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Csr = typename TestFixture::Mtx; using OtherCsr = gko::matrix::Csr; auto empty = OtherCsr::create(this->exec); @@ -1011,7 +1013,7 @@ TYPED_TEST(Csr, MovesEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Csr = typename TestFixture::Mtx; using OtherCsr = gko::matrix::Csr; auto empty = OtherCsr::create(this->exec); @@ -2048,7 +2050,7 @@ TYPED_TEST(Csr, AppliesToComplex) TYPED_TEST(Csr, AppliesToMixedComplex) { using mixed_value_type = - gko::next_precision; + 
gko::next_precision_with_half; using mixed_complex_type = gko::to_complex; using Vec = gko::matrix::Dense; auto exec = gko::ReferenceExecutor::create(); @@ -2104,7 +2106,7 @@ TYPED_TEST(Csr, AdvancedAppliesToComplex) TYPED_TEST(Csr, AdvancedAppliesToMixedComplex) { using mixed_value_type = - gko::next_precision; + gko::next_precision_with_half; using mixed_complex_type = gko::to_complex; using MixedDense = gko::matrix::Dense; using MixedDenseComplex = gko::matrix::Dense; @@ -2245,7 +2247,7 @@ class CsrComplex : public ::testing::Test { using Mtx = gko::matrix::Csr; }; -TYPED_TEST_SUITE(CsrComplex, gko::test::ComplexValueIndexTypes, +TYPED_TEST_SUITE(CsrComplex, gko::test::ComplexValueIndexTypesWithHalf, PairTypenameNameGenerator); @@ -2590,7 +2592,7 @@ class CsrLookup : public ::testing::Test { index_type invalid_index = gko::invalid_index(); }; -TYPED_TEST_SUITE(CsrLookup, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(CsrLookup, gko::test::ValueIndexTypesWithHalf, PairTypenameNameGenerator); TYPED_TEST(CsrLookup, GeneratesLookupDataOffsets) diff --git a/reference/test/matrix/dense_kernels.cpp b/reference/test/matrix/dense_kernels.cpp index a8d37ce5a09..3854cd56dff 100644 --- a/reference/test/matrix/dense_kernels.cpp +++ b/reference/test/matrix/dense_kernels.cpp @@ -37,7 +37,8 @@ class Dense : public ::testing::Test { protected: using value_type = T; using Mtx = gko::matrix::Dense; - using MixedMtx = gko::matrix::Dense>; + using MixedMtx = + gko::matrix::Dense>; using ComplexMtx = gko::to_complex; using RealMtx = gko::remove_complex; Dense() @@ -80,7 +81,7 @@ class Dense : public ::testing::Test { }; -TYPED_TEST_SUITE(Dense, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Dense, gko::test::ValueTypesWithHalf, TypenameNameGenerator); TYPED_TEST(Dense, CopyRespectsStride) @@ -745,7 +746,7 @@ TYPED_TEST(Dense, ConvertsToPrecision) { using Dense = typename TestFixture::Mtx; using T = typename TestFixture::value_type; - using OtherT = typename 
gko::next_precision; + using OtherT = typename gko::next_precision_with_half; using OtherDense = typename gko::matrix::Dense; auto tmp = OtherDense::create(this->exec); auto res = Dense::create(this->exec); @@ -767,7 +768,7 @@ TYPED_TEST(Dense, MovesToPrecision) { using Dense = typename TestFixture::Mtx; using T = typename TestFixture::value_type; - using OtherT = typename gko::next_precision; + using OtherT = typename gko::next_precision_with_half; using OtherDense = typename gko::matrix::Dense; auto tmp = OtherDense::create(this->exec); auto res = Dense::create(this->exec); @@ -1066,7 +1067,7 @@ TYPED_TEST(Dense, AppliesToComplex) TYPED_TEST(Dense, AppliesToMixedComplex) { using mixed_value_type = - gko::next_precision; + gko::next_precision_with_half; using mixed_complex_type = gko::to_complex; using Vec = gko::matrix::Dense; auto exec = gko::ReferenceExecutor::create(); @@ -1120,7 +1121,7 @@ TYPED_TEST(Dense, AdvancedAppliesToComplex) TYPED_TEST(Dense, AdvancedAppliesToMixedComplex) { using mixed_value_type = - gko::next_precision; + gko::next_precision_with_half; using mixed_complex_type = gko::to_complex; using MixedDense = gko::matrix::Dense; using MixedDenseComplex = gko::matrix::Dense; @@ -1359,7 +1360,7 @@ class DenseWithIndexType std::unique_ptr scale_perm0; }; -TYPED_TEST_SUITE(DenseWithIndexType, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(DenseWithIndexType, gko::test::ValueIndexTypesWithHalf, PairTypenameNameGenerator); @@ -2013,7 +2014,7 @@ TYPED_TEST(Dense, ConvertsEmptyToPrecision) { using Dense = typename TestFixture::Mtx; using T = typename TestFixture::value_type; - using OtherT = typename gko::next_precision; + using OtherT = typename gko::next_precision_with_half; using OtherDense = typename gko::matrix::Dense; auto empty = OtherDense::create(this->exec); auto res = Dense::create(this->exec); @@ -2028,7 +2029,7 @@ TYPED_TEST(Dense, MovesEmptyToPrecision) { using Dense = typename TestFixture::Mtx; using T = typename TestFixture::value_type; 
- using OtherT = typename gko::next_precision; + using OtherT = typename gko::next_precision_with_half; using OtherDense = typename gko::matrix::Dense; auto empty = OtherDense::create(this->exec); auto res = Dense::create(this->exec); diff --git a/reference/test/matrix/diagonal_kernels.cpp b/reference/test/matrix/diagonal_kernels.cpp index e2ac67190d0..d1208e96178 100644 --- a/reference/test/matrix/diagonal_kernels.cpp +++ b/reference/test/matrix/diagonal_kernels.cpp @@ -30,7 +30,8 @@ class Diagonal : public ::testing::Test { using Csr = gko::matrix::Csr; using Diag = gko::matrix::Diagonal; using Dense = gko::matrix::Dense; - using MixedDense = gko::matrix::Dense>; + using MixedDense = + gko::matrix::Dense>; Diagonal() : exec(gko::ReferenceExecutor::create()), @@ -79,13 +80,14 @@ class Diagonal : public ::testing::Test { std::unique_ptr dense3; }; -TYPED_TEST_SUITE(Diagonal, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Diagonal, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(Diagonal, ConvertsToPrecision) { using ValueType = typename TestFixture::value_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Diagonal = typename TestFixture::Diag; using OtherDiagonal = gko::matrix::Diagonal; auto tmp = OtherDiagonal::create(this->exec); @@ -107,7 +109,7 @@ TYPED_TEST(Diagonal, ConvertsToPrecision) TYPED_TEST(Diagonal, MovesToPrecision) { using ValueType = typename TestFixture::value_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Diagonal = typename TestFixture::Diag; using OtherDiagonal = gko::matrix::Diagonal; auto tmp = OtherDiagonal::create(this->exec); @@ -574,7 +576,7 @@ TYPED_TEST(Diagonal, AppliesToComplex) TYPED_TEST(Diagonal, AppliesToMixedComplex) { using mixed_value_type = - gko::next_precision; + gko::next_precision_with_half; using mixed_complex_type = gko::to_complex; using Vec = gko::matrix::Dense; auto exec 
= gko::ReferenceExecutor::create(); @@ -634,7 +636,7 @@ TYPED_TEST(Diagonal, AppliesLinearCombinationToComplex) TYPED_TEST(Diagonal, AppliesLinearCombinationToMixedComplex) { using mixed_value_type = - gko::next_precision; + gko::next_precision_with_half; using mixed_complex_type = gko::to_complex; using Vec = gko::matrix::Dense; using Scalar = gko::matrix::Dense; diff --git a/reference/test/matrix/ell_kernels.cpp b/reference/test/matrix/ell_kernels.cpp index 6214db82d1c..7f3c770c603 100644 --- a/reference/test/matrix/ell_kernels.cpp +++ b/reference/test/matrix/ell_kernels.cpp @@ -30,7 +30,8 @@ class Ell : public ::testing::Test { using Mtx = gko::matrix::Ell; using Csr = gko::matrix::Csr; using Vec = gko::matrix::Dense; - using MixedVec = gko::matrix::Dense>; + using MixedVec = + gko::matrix::Dense>; Ell() : exec(gko::ReferenceExecutor::create()), @@ -72,7 +73,8 @@ class Ell : public ::testing::Test { std::unique_ptr mtx2; }; -TYPED_TEST_SUITE(Ell, gko::test::ValueIndexTypes, PairTypenameNameGenerator); +TYPED_TEST_SUITE(Ell, gko::test::ValueIndexTypesWithHalf, + PairTypenameNameGenerator); TYPED_TEST(Ell, AppliesToDenseVector) @@ -91,7 +93,7 @@ TYPED_TEST(Ell, MixedAppliesToDenseVector1) { // Both vectors have the same value type which differs from the matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec = typename gko::matrix::Dense; auto x = gko::initialize({2.0, 1.0, 4.0}, this->exec); auto y = Vec::create(this->exec, gko::dim<2>{2, 1}); @@ -106,7 +108,7 @@ TYPED_TEST(Ell, MixedAppliesToDenseVector2) { // Input vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; using Vec2 = gko::matrix::Dense; auto x = gko::initialize({2.0, 1.0, 4.0}, this->exec); @@ -122,9 +124,9 @@ TYPED_TEST(Ell, MixedAppliesToDenseVector3) { // 
Output vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; - using Vec2 = gko::matrix::Dense>; + using Vec2 = gko::matrix::Dense>; auto x = gko::initialize({2.0, 1.0, 4.0}, this->exec); auto y = Vec1::create(this->exec, gko::dim<2>{2, 1}); @@ -160,7 +162,7 @@ TYPED_TEST(Ell, MixedAppliesToDenseMatrix1) { // Both vectors have the same value type which differs from the matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec = gko::matrix::Dense; // clang-format off auto x = gko::initialize( @@ -184,7 +186,7 @@ TYPED_TEST(Ell, MixedAppliesToDenseMatrix2) { // Input vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; using Vec2 = gko::matrix::Dense; // clang-format off @@ -209,7 +211,7 @@ TYPED_TEST(Ell, MixedAppliesToDenseMatrix3) { // Output vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; using Vec2 = gko::matrix::Dense; // clang-format off @@ -248,7 +250,7 @@ TYPED_TEST(Ell, MixedAppliesLinearCombinationToDenseVector1) { // Both vectors have the same value type which differs from the matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec = gko::matrix::Dense; auto alpha = gko::initialize({-1.0}, this->exec); auto beta = gko::initialize({2.0}, this->exec); @@ -265,7 +267,7 @@ TYPED_TEST(Ell, MixedAppliesLinearCombinationToDenseVector2) { // Input vector has same value type as matrix using T = typename TestFixture::value_type; - 
using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; using Vec2 = gko::matrix::Dense; auto alpha = gko::initialize({-1.0}, this->exec); @@ -283,7 +285,7 @@ TYPED_TEST(Ell, MixedAppliesLinearCombinationToDenseVector3) { // Output vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; using Vec2 = gko::matrix::Dense; auto alpha = gko::initialize({-1.0}, this->exec); @@ -327,7 +329,7 @@ TYPED_TEST(Ell, MixedAppliesLinearCombinationToDenseMatrix1) { // Both vectors have the same value type which differs from the matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec = gko::matrix::Dense; auto alpha = gko::initialize({-1.0}, this->exec); auto beta = gko::initialize({2.0}, this->exec); @@ -355,7 +357,7 @@ TYPED_TEST(Ell, MixedAppliesLinearCombinationToDenseMatrix2) { // Input vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; using Vec2 = gko::matrix::Dense; auto alpha = gko::initialize({-1.0}, this->exec); @@ -384,7 +386,7 @@ TYPED_TEST(Ell, MixedAppliesLinearCombinationToDenseMatrix3) { // Output vector has same value type as matrix using T = typename TestFixture::value_type; - using next_T = gko::next_precision; + using next_T = gko::next_precision_with_half; using Vec1 = typename TestFixture::Vec; using Vec2 = gko::matrix::Dense; auto alpha = gko::initialize({-1.0}, this->exec); @@ -443,7 +445,7 @@ TYPED_TEST(Ell, ConvertsToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = 
gko::next_precision_with_half; using Ell = typename TestFixture::Mtx; using OtherEll = gko::matrix::Ell; auto tmp = OtherEll::create(this->exec); @@ -466,7 +468,7 @@ TYPED_TEST(Ell, MovesToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Ell = typename TestFixture::Mtx; using OtherEll = gko::matrix::Ell; auto tmp = OtherEll::create(this->exec); @@ -735,7 +737,7 @@ TYPED_TEST(Ell, ConvertsEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Ell = typename TestFixture::Mtx; using OtherEll = gko::matrix::Ell; auto empty = Ell::create(this->exec); @@ -752,7 +754,7 @@ TYPED_TEST(Ell, MovesEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Ell = typename TestFixture::Mtx; using OtherEll = gko::matrix::Ell; auto empty = Ell::create(this->exec); @@ -897,7 +899,7 @@ TYPED_TEST(Ell, AppliesToComplex) TYPED_TEST(Ell, AppliesToMixedComplex) { using mixed_value_type = - gko::next_precision; + gko::next_precision_with_half; using mixed_complex_type = gko::to_complex; using Vec = gko::matrix::Dense; auto exec = gko::ReferenceExecutor::create(); @@ -954,7 +956,7 @@ TYPED_TEST(Ell, AdvancedAppliesToComplex) TYPED_TEST(Ell, AdvancedAppliesToMixedComplex) { using mixed_value_type = - gko::next_precision; + gko::next_precision_with_half; using mixed_complex_type = gko::to_complex; using MixedDense = gko::matrix::Dense; using MixedDenseComplex = gko::matrix::Dense; @@ -992,7 +994,7 @@ class EllComplex : public ::testing::Test { using Mtx = gko::matrix::Ell; }; -TYPED_TEST_SUITE(EllComplex, 
gko::test::ComplexValueIndexTypes, +TYPED_TEST_SUITE(EllComplex, gko::test::ComplexValueIndexTypesWithHalf, PairTypenameNameGenerator); diff --git a/reference/test/matrix/fbcsr_kernels.cpp b/reference/test/matrix/fbcsr_kernels.cpp index 665df4ace31..9d9e2144cc3 100644 --- a/reference/test/matrix/fbcsr_kernels.cpp +++ b/reference/test/matrix/fbcsr_kernels.cpp @@ -104,7 +104,8 @@ class Fbcsr : public ::testing::Test { const std::unique_ptr mtxsq; }; -TYPED_TEST_SUITE(Fbcsr, gko::test::ValueIndexTypes, PairTypenameNameGenerator); +TYPED_TEST_SUITE(Fbcsr, gko::test::ValueIndexTypesWithHalf, + PairTypenameNameGenerator); template @@ -114,7 +115,7 @@ std::unique_ptr> get_some_vectors( { using RT = gko::remove_complex; std::default_random_engine engine(39); - std::normal_distribution dist(0.0, 5.0); + std::normal_distribution<> dist(0.0, 5.0); std::uniform_int_distribution<> nnzdist(1, nrhs); return gko::test::generate_random_matrix>( nrows, nrhs, nnzdist, dist, engine, exec); @@ -271,7 +272,7 @@ TYPED_TEST(Fbcsr, ConvertsToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Fbcsr = typename TestFixture::Mtx; using OtherFbcsr = gko::matrix::Fbcsr; auto tmp = OtherFbcsr::create(this->exec); @@ -294,7 +295,7 @@ TYPED_TEST(Fbcsr, MovesToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Fbcsr = typename TestFixture::Mtx; using OtherFbcsr = gko::matrix::Fbcsr; auto tmp = OtherFbcsr::create(this->exec); @@ -392,7 +393,7 @@ TYPED_TEST(Fbcsr, ConvertsEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = 
gko::next_precision_with_half; using Fbcsr = typename TestFixture::Mtx; using OtherFbcsr = gko::matrix::Fbcsr; auto empty = OtherFbcsr::create(this->exec); @@ -411,7 +412,7 @@ TYPED_TEST(Fbcsr, MovesEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Fbcsr = typename TestFixture::Mtx; using OtherFbcsr = gko::matrix::Fbcsr; auto empty = OtherFbcsr::create(this->exec); @@ -619,7 +620,7 @@ class FbcsrComplex : public ::testing::Test { using Csr = gko::matrix::Csr; }; -TYPED_TEST_SUITE(FbcsrComplex, gko::test::ComplexValueIndexTypes, +TYPED_TEST_SUITE(FbcsrComplex, gko::test::ComplexValueIndexTypesWithHalf, PairTypenameNameGenerator); diff --git a/reference/test/matrix/hybrid_kernels.cpp b/reference/test/matrix/hybrid_kernels.cpp index 87fd4c02811..c5e6496dce1 100644 --- a/reference/test/matrix/hybrid_kernels.cpp +++ b/reference/test/matrix/hybrid_kernels.cpp @@ -32,7 +32,8 @@ class Hybrid : public ::testing::Test { using Mtx = gko::matrix::Hybrid; using Vec = gko::matrix::Dense; using Csr = gko::matrix::Csr; - using MixedVec = gko::matrix::Dense>; + using MixedVec = + gko::matrix::Dense>; Hybrid() : exec(gko::ReferenceExecutor::create()), @@ -96,7 +97,8 @@ class Hybrid : public ::testing::Test { std::unique_ptr mtx3; }; -TYPED_TEST_SUITE(Hybrid, gko::test::ValueIndexTypes, PairTypenameNameGenerator); +TYPED_TEST_SUITE(Hybrid, gko::test::ValueIndexTypesWithHalf, + PairTypenameNameGenerator); TYPED_TEST(Hybrid, AppliesToDenseVector) @@ -233,7 +235,7 @@ TYPED_TEST(Hybrid, ConvertsToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Hybrid = typename TestFixture::Mtx; using OtherHybrid = gko::matrix::Hybrid; auto tmp = 
OtherHybrid::create(this->exec); @@ -256,7 +258,7 @@ TYPED_TEST(Hybrid, MovesToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Hybrid = typename TestFixture::Mtx; using OtherHybrid = gko::matrix::Hybrid; auto tmp = OtherHybrid::create(this->exec); @@ -367,7 +369,7 @@ TYPED_TEST(Hybrid, ConvertsEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Hybrid = typename TestFixture::Mtx; using OtherHybrid = gko::matrix::Hybrid; auto other = Hybrid::create(this->exec); @@ -384,7 +386,7 @@ TYPED_TEST(Hybrid, MovesEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Hybrid = typename TestFixture::Mtx; using OtherHybrid = gko::matrix::Hybrid; auto other = Hybrid::create(this->exec); @@ -699,7 +701,7 @@ TYPED_TEST(Hybrid, AppliesToComplex) TYPED_TEST(Hybrid, AppliesToMixedComplex) { using mixed_value_type = - gko::next_precision; + gko::next_precision_with_half; using mixed_complex_type = gko::to_complex; using Vec = gko::matrix::Dense; auto exec = gko::ReferenceExecutor::create(); @@ -756,7 +758,7 @@ TYPED_TEST(Hybrid, AdvancedAppliesToComplex) TYPED_TEST(Hybrid, AdvancedAppliesToMixedComplex) { using mixed_value_type = - gko::next_precision; + gko::next_precision_with_half; using mixed_complex_type = gko::to_complex; using MixedDense = gko::matrix::Dense; using MixedDenseComplex = gko::matrix::Dense; @@ -795,7 +797,7 @@ class HybridComplex : public ::testing::Test { using Mtx = gko::matrix::Hybrid; }; -TYPED_TEST_SUITE(HybridComplex, gko::test::ComplexValueIndexTypes, 
+TYPED_TEST_SUITE(HybridComplex, gko::test::ComplexValueIndexTypesWithHalf, PairTypenameNameGenerator); diff --git a/reference/test/matrix/identity.cpp b/reference/test/matrix/identity.cpp index 11953de338a..82704145978 100644 --- a/reference/test/matrix/identity.cpp +++ b/reference/test/matrix/identity.cpp @@ -19,7 +19,8 @@ class Identity : public ::testing::Test { using value_type = T; using Id = gko::matrix::Identity; using Vec = gko::matrix::Dense; - using MixedVec = gko::matrix::Dense>; + using MixedVec = + gko::matrix::Dense>; using ComplexVec = gko::to_complex; using MixedComplexVec = gko::to_complex; @@ -29,7 +30,8 @@ class Identity : public ::testing::Test { }; -TYPED_TEST_SUITE(Identity, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Identity, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(Identity, AppliesToVector) diff --git a/reference/test/matrix/permutation.cpp b/reference/test/matrix/permutation.cpp index 5418f97353b..b646a6fc67f 100644 --- a/reference/test/matrix/permutation.cpp +++ b/reference/test/matrix/permutation.cpp @@ -51,7 +51,7 @@ class Permutation : public ::testing::Test { std::shared_ptr exec; }; -TYPED_TEST_SUITE(Permutation, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(Permutation, gko::test::ValueIndexTypesWithHalf, PairTypenameNameGenerator); diff --git a/reference/test/matrix/scaled_permutation.cpp b/reference/test/matrix/scaled_permutation.cpp index 6d8d49f5662..f2b3e66b4cd 100644 --- a/reference/test/matrix/scaled_permutation.cpp +++ b/reference/test/matrix/scaled_permutation.cpp @@ -58,7 +58,7 @@ class ScaledPermutation : public ::testing::Test { std::unique_ptr perm2; }; -TYPED_TEST_SUITE(ScaledPermutation, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(ScaledPermutation, gko::test::ValueIndexTypesWithHalf, PairTypenameNameGenerator); diff --git a/reference/test/matrix/sellp_kernels.cpp b/reference/test/matrix/sellp_kernels.cpp index 3208b8c42be..23251c63b8f 100644 --- 
a/reference/test/matrix/sellp_kernels.cpp +++ b/reference/test/matrix/sellp_kernels.cpp @@ -50,7 +50,8 @@ class Sellp : public ::testing::Test { std::unique_ptr mtx2; }; -TYPED_TEST_SUITE(Sellp, gko::test::ValueIndexTypes, PairTypenameNameGenerator); +TYPED_TEST_SUITE(Sellp, gko::test::ValueIndexTypesWithHalf, + PairTypenameNameGenerator); TYPED_TEST(Sellp, AppliesToDenseVector) @@ -67,7 +68,8 @@ TYPED_TEST(Sellp, AppliesToDenseVector) TYPED_TEST(Sellp, AppliesToMixedDenseVector) { - using value_type = gko::next_precision; + using value_type = + gko::next_precision_with_half; using Vec = gko::matrix::Dense; auto x = gko::initialize({2.0, 1.0, 4.0}, this->exec); auto y = Vec::create(this->exec, gko::dim<2>{2, 1}); @@ -116,7 +118,8 @@ TYPED_TEST(Sellp, AppliesLinearCombinationToDenseVector) TYPED_TEST(Sellp, AppliesLinearCombinationToMixedDenseVector) { - using value_type = gko::next_precision; + using value_type = + gko::next_precision_with_half; using Vec = gko::matrix::Dense; auto alpha = gko::initialize({-1.0}, this->exec); auto beta = gko::initialize({2.0}, this->exec); @@ -189,7 +192,7 @@ TYPED_TEST(Sellp, ConvertsToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Sellp = typename TestFixture::Mtx; using OtherSellp = gko::matrix::Sellp; auto tmp = OtherSellp::create(this->exec); @@ -212,7 +215,7 @@ TYPED_TEST(Sellp, MovesToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Sellp = typename TestFixture::Mtx; using OtherSellp = gko::matrix::Sellp; auto tmp = OtherSellp::create(this->exec); @@ -310,7 +313,7 @@ TYPED_TEST(Sellp, ConvertsEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename 
TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Sellp = typename TestFixture::Mtx; using OtherSellp = gko::matrix::Sellp; auto empty = OtherSellp::create(this->exec); @@ -329,7 +332,7 @@ TYPED_TEST(Sellp, MovesEmptyToPrecision) { using ValueType = typename TestFixture::value_type; using IndexType = typename TestFixture::index_type; - using OtherType = gko::next_precision; + using OtherType = gko::next_precision_with_half; using Sellp = typename TestFixture::Mtx; using OtherSellp = gko::matrix::Sellp; auto empty = OtherSellp::create(this->exec); @@ -751,7 +754,7 @@ class SellpComplex : public ::testing::Test { using Mtx = gko::matrix::Sellp; }; -TYPED_TEST_SUITE(SellpComplex, gko::test::ComplexValueIndexTypes, +TYPED_TEST_SUITE(SellpComplex, gko::test::ComplexValueIndexTypesWithHalf, PairTypenameNameGenerator); diff --git a/reference/test/matrix/sparsity_csr.cpp b/reference/test/matrix/sparsity_csr.cpp index d8ed6147e30..8db0dee144f 100644 --- a/reference/test/matrix/sparsity_csr.cpp +++ b/reference/test/matrix/sparsity_csr.cpp @@ -47,7 +47,7 @@ class SparsityCsr : public ::testing::Test { std::unique_ptr mtx; }; -TYPED_TEST_SUITE(SparsityCsr, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(SparsityCsr, gko::test::ValueIndexTypesWithHalf, PairTypenameNameGenerator); diff --git a/reference/test/matrix/sparsity_csr_kernels.cpp b/reference/test/matrix/sparsity_csr_kernels.cpp index f08d6c352ca..30805d033ab 100644 --- a/reference/test/matrix/sparsity_csr_kernels.cpp +++ b/reference/test/matrix/sparsity_csr_kernels.cpp @@ -125,7 +125,7 @@ class SparsityCsr : public ::testing::Test { std::unique_ptr mtx3_unsorted; }; -TYPED_TEST_SUITE(SparsityCsr, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(SparsityCsr, gko::test::ValueIndexTypesWithHalf, PairTypenameNameGenerator); @@ -145,7 +145,7 @@ TYPED_TEST(SparsityCsr, AppliesToDenseVector) TYPED_TEST(SparsityCsr, AppliesToMixedDenseVector) { - using T = 
gko::next_precision; + using T = gko::next_precision_with_half; using Vec = gko::matrix::Dense; auto x = gko::initialize({2.0, 1.0, 4.0}, this->exec); auto y = Vec::create(this->exec, gko::dim<2>{2, 1}); @@ -192,7 +192,7 @@ TYPED_TEST(SparsityCsr, AppliesLinearCombinationToDenseVector) TYPED_TEST(SparsityCsr, AppliesLinearCombinationToMixedDenseVector) { - using T = gko::next_precision; + using T = gko::next_precision_with_half; using Vec = gko::matrix::Dense; auto alpha = gko::initialize({-1.0}, this->exec); auto beta = gko::initialize({2.0}, this->exec); @@ -243,8 +243,8 @@ TYPED_TEST(SparsityCsr, AppliesToComplex) TYPED_TEST(SparsityCsr, AppliesToMixedComplex) { - using T = - gko::next_precision>; + using T = gko::next_precision_with_half< + gko::to_complex>; using Vec = gko::matrix::Dense; auto x = gko::initialize({T{2.0, 4.0}, T{1.0, 2.0}, T{4.0, 8.0}}, this->exec); @@ -279,7 +279,7 @@ TYPED_TEST(SparsityCsr, AppliesLinearCombinationToComplex) TYPED_TEST(SparsityCsr, AppliesLinearCombinationToMixedComplex) { using Vec = gko::matrix::Dense< - gko::next_precision>; + gko::next_precision_with_half>; using ComplexVec = gko::to_complex; using T = typename ComplexVec::value_type; auto alpha = gko::initialize({-1.0}, this->exec); diff --git a/test/matrix/fbcsr_kernels.cpp b/test/matrix/fbcsr_kernels.cpp index 4ff8e1fc36a..5e3d4b1a112 100644 --- a/test/matrix/fbcsr_kernels.cpp +++ b/test/matrix/fbcsr_kernels.cpp @@ -48,18 +48,23 @@ class Fbcsr : public CommonTestFixture { void generate_sin(gko::ptr_param x) { value_type* const xarr = x->get_values(); + // we do not have sin for half, so we compute sin in double or + // complex + using working_type = std::conditional_t(), + std::complex, double>; for (index_type i = 0; i < x->get_size()[0] * x->get_size()[1]; i++) { - xarr[i] = - static_cast(2.0) * - std::sin(static_cast(i / 2.0) + get_random_value()); + xarr[i] = static_cast( + 2.0 * std::sin(i / 2.0 + + static_cast(get_random_value()))); } } }; #ifdef 
GKO_COMPILING_HIP -TYPED_TEST_SUITE(Fbcsr, gko::test::RealValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Fbcsr, gko::test::RealValueTypesWithHalf, + TypenameNameGenerator); #else -TYPED_TEST_SUITE(Fbcsr, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Fbcsr, gko::test::ValueTypesWithHalf, TypenameNameGenerator); #endif TYPED_TEST(Fbcsr, CanWriteFromMatrixOnDevice) @@ -124,6 +129,8 @@ TYPED_TEST(Fbcsr, SpmvIsEquivalentToRefSorted) using Dense = typename TestFixture::Dense; using value_type = typename Mtx::value_type; if (this->exec->get_master() != this->exec) { + // FBCSR on accelerator does not have half precision apply through + // vendor libraries. SKIP_IF_HALF(value_type); } auto drand = gko::clone(this->exec, this->rsorted); @@ -149,6 +156,8 @@ TYPED_TEST(Fbcsr, SpmvMultiIsEquivalentToRefSorted) using Dense = typename TestFixture::Dense; using value_type = typename Mtx::value_type; if (this->exec->get_master() != this->exec) { + // FBCSR on accelerator does not have half precision apply through + // vendor libraries. SKIP_IF_HALF(value_type); } auto drand = gko::clone(this->exec, this->rsorted); @@ -175,6 +184,8 @@ TYPED_TEST(Fbcsr, AdvancedSpmvIsEquivalentToRefSorted) using value_type = typename TestFixture::value_type; using real_type = typename TestFixture::real_type; if (this->exec->get_master() != this->exec) { + // FBCSR on accelerator does not have half precision apply through + // vendor libraries. SKIP_IF_HALF(value_type); } auto drand = gko::clone(this->exec, this->rsorted); @@ -208,6 +219,8 @@ TYPED_TEST(Fbcsr, AdvancedSpmvMultiIsEquivalentToRefSorted) using value_type = typename TestFixture::value_type; using real_type = typename TestFixture::real_type; if (this->exec->get_master() != this->exec) { + // FBCSR on accelerator does not have half precision apply through + // vendor libraries. 
SKIP_IF_HALF(value_type); } auto drand = gko::clone(this->exec, this->rsorted); From 8d3e4b58676b660d3cb16b3cc1ccb3072485a7b8 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Fri, 25 Oct 2024 16:29:56 +0200 Subject: [PATCH 12/16] base such as composition/combination with half and corr. test --- core/base/block_operator.cpp | 8 ++++++-- core/base/combination.cpp | 2 +- core/base/composition.cpp | 2 +- core/base/dense_cache.cpp | 2 +- core/base/perturbation.cpp | 2 +- core/test/base/combination.cpp | 3 ++- core/test/base/composition.cpp | 3 ++- core/test/base/dense_cache.cpp | 3 ++- reference/test/base/composition.cpp | 13 ++++++++----- reference/test/base/perturbation.cpp | 13 ++++++++----- 10 files changed, 32 insertions(+), 19 deletions(-) diff --git a/core/base/block_operator.cpp b/core/base/block_operator.cpp index f53375301a8..68c00aeee70 100644 --- a/core/base/block_operator.cpp +++ b/core/base/block_operator.cpp @@ -19,8 +19,12 @@ namespace { template auto dispatch_dense(Fn&& fn, LinOp* v) { - return run, - std::complex>(v, std::forward(fn)); + return run, +#endif + std::complex, std::complex>(v, + std::forward(fn)); } diff --git a/core/base/combination.cpp b/core/base/combination.cpp index 3b30b77d38c..53af6742f6e 100644 --- a/core/base/combination.cpp +++ b/core/base/combination.cpp @@ -168,7 +168,7 @@ void Combination::apply_impl(const LinOp* alpha, const LinOp* b, #define GKO_DECLARE_COMBINATION(_type) class Combination<_type> -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_COMBINATION); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_COMBINATION); } // namespace gko diff --git a/core/base/composition.cpp b/core/base/composition.cpp index 82c8152300b..f6a7df21e45 100644 --- a/core/base/composition.cpp +++ b/core/base/composition.cpp @@ -222,7 +222,7 @@ void Composition::apply_impl(const LinOp* alpha, const LinOp* b, #define GKO_DECLARE_COMPOSITION(_type) class Composition<_type> 
-GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_COMPOSITION); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_COMPOSITION); } // namespace gko diff --git a/core/base/dense_cache.cpp b/core/base/dense_cache.cpp index 38a0decfa46..096ad1f761a 100644 --- a/core/base/dense_cache.cpp +++ b/core/base/dense_cache.cpp @@ -33,7 +33,7 @@ void DenseCache::init_from( #define GKO_DECLARE_DENSE_CACHE(_type) struct DenseCache<_type> -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DENSE_CACHE); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_DENSE_CACHE); } // namespace detail diff --git a/core/base/perturbation.cpp b/core/base/perturbation.cpp index 87501361c05..b17cba209e1 100644 --- a/core/base/perturbation.cpp +++ b/core/base/perturbation.cpp @@ -182,7 +182,7 @@ void Perturbation::apply_impl(const LinOp* alpha, const LinOp* b, #define GKO_DECLARE_PERTURBATION(_type) class Perturbation<_type> -GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_PERTURBATION); +GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_PERTURBATION); } // namespace gko diff --git a/core/test/base/combination.cpp b/core/test/base/combination.cpp index 73c30ffe11c..63c73cfa168 100644 --- a/core/test/base/combination.cpp +++ b/core/test/base/combination.cpp @@ -43,7 +43,8 @@ class Combination : public ::testing::Test { std::vector> coefficients; }; -TYPED_TEST_SUITE(Combination, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Combination, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(Combination, CanBeEmpty) diff --git a/core/test/base/composition.cpp b/core/test/base/composition.cpp index 122755b8f92..58c86894fc8 100644 --- a/core/test/base/composition.cpp +++ b/core/test/base/composition.cpp @@ -41,7 +41,8 @@ class Composition : public ::testing::Test { std::vector> operators; }; -TYPED_TEST_SUITE(Composition, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Composition, gko::test::ValueTypesWithHalf, + 
TypenameNameGenerator); TYPED_TEST(Composition, CanBeEmpty) diff --git a/core/test/base/dense_cache.cpp b/core/test/base/dense_cache.cpp index 526187610a4..54d904617db 100644 --- a/core/test/base/dense_cache.cpp +++ b/core/test/base/dense_cache.cpp @@ -31,7 +31,8 @@ class DenseCache : public ::testing::Test { }; -TYPED_TEST_SUITE(DenseCache, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(DenseCache, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(DenseCache, CanDefaultConstruct) diff --git a/reference/test/base/composition.cpp b/reference/test/base/composition.cpp index f736edb53f9..d17b8602ce8 100644 --- a/reference/test/base/composition.cpp +++ b/reference/test/base/composition.cpp @@ -75,7 +75,8 @@ class Composition : public ::testing::Test { std::shared_ptr product; }; -TYPED_TEST_SUITE(Composition, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Composition, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(Composition, CopiesOnSameExecutor) @@ -142,7 +143,7 @@ TYPED_TEST(Composition, AppliesSingleToMixedVector) cmp = [ -9 -2 ] [ 27 26 ] */ - using Mtx = gko::matrix::Dense>; + using Mtx = gko::matrix::Dense>; using value_type = typename Mtx::value_type; auto cmp = gko::Composition::create(this->product); auto x = gko::initialize({1.0, 2.0}, this->exec); @@ -182,7 +183,8 @@ TYPED_TEST(Composition, AppliesSingleToMixedComplexVector) cmp = [ -9 -2 ] [ 27 26 ] */ - using value_type = gko::next_precision>; + using value_type = + gko::next_precision_with_half>; using Mtx = gko::matrix::Dense; auto cmp = gko::Composition::create(this->product); auto x = gko::initialize( @@ -222,7 +224,7 @@ TYPED_TEST(Composition, AppliesSingleLinearCombinationToMixedVector) cmp = [ -9 -2 ] [ 27 26 ] */ - using value_type = gko::next_precision; + using value_type = gko::next_precision_with_half; using Mtx = gko::matrix::Dense; auto cmp = gko::Composition::create(this->product); auto alpha = gko::initialize({3.0}, 
this->exec); @@ -267,7 +269,8 @@ TYPED_TEST(Composition, AppliesSingleLinearCombinationToMixedComplexVector) cmp = [ -9 -2 ] [ 27 26 ] */ - using MixedDense = gko::matrix::Dense>; + using MixedDense = + gko::matrix::Dense>; using MixedDenseComplex = gko::to_complex; using value_type = typename MixedDenseComplex::value_type; auto cmp = gko::Composition::create(this->product); diff --git a/reference/test/base/perturbation.cpp b/reference/test/base/perturbation.cpp index b6be9ab1563..50a5fe7db20 100644 --- a/reference/test/base/perturbation.cpp +++ b/reference/test/base/perturbation.cpp @@ -33,7 +33,8 @@ class Perturbation : public ::testing::Test { std::shared_ptr scalar; }; -TYPED_TEST_SUITE(Perturbation, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(Perturbation, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(Perturbation, CopiesOnSameExecutor) @@ -101,7 +102,7 @@ TYPED_TEST(Perturbation, AppliesToMixedVector) cmp = I + 2 * [ 2 ] * [ 3 2 ] [ 1 ] */ - using Mtx = gko::matrix::Dense>; + using Mtx = gko::matrix::Dense>; using value_type = typename Mtx::value_type; auto cmp = gko::Perturbation::create(this->scalar, this->basis, this->projector); @@ -143,7 +144,8 @@ TYPED_TEST(Perturbation, AppliesToMixedComplexVector) cmp = I + 2 * [ 2 ] * [ 3 2 ] [ 1 ] */ - using value_type = gko::to_complex>; + using value_type = + gko::to_complex>; using Mtx = gko::matrix::Dense; auto cmp = gko::Perturbation::create(this->scalar, this->basis, this->projector); @@ -185,7 +187,7 @@ TYPED_TEST(Perturbation, AppliesLinearCombinationToMixedVector) cmp = I + 2 * [ 2 ] * [ 3 2 ] [ 1 ] */ - using value_type = gko::next_precision; + using value_type = gko::next_precision_with_half; using Mtx = gko::matrix::Dense; auto cmp = gko::Perturbation::create(this->scalar, this->basis, this->projector); @@ -232,7 +234,8 @@ TYPED_TEST(Perturbation, AppliesLinearCombinationToMixedComplexVector) cmp = I + 2 * [ 2 ] * [ 3 2 ] [ 1 ] */ - using MixedDense = 
gko::matrix::Dense>; + using MixedDense = + gko::matrix::Dense>; using MixedDenseComplex = gko::to_complex; using value_type = typename MixedDenseComplex::value_type; auto cmp = gko::Perturbation::create(this->scalar, this->basis, From b2fa55a6fe0f22cbf8fe78a8ba1ceb744c3ac0fa Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Mon, 4 Nov 2024 15:15:17 +0100 Subject: [PATCH 13/16] test_utils test --- core/test/utils/array_generator_test.cpp | 18 +++++---- core/test/utils/matrix_generator.hpp | 18 +++++++-- core/test/utils/matrix_generator_test.cpp | 49 ++++++++++++++--------- core/test/utils/matrix_utils_test.cpp | 11 ++--- core/test/utils/unsort_matrix_test.cpp | 2 +- core/test/utils/value_generator_test.cpp | 16 +++++--- reference/test/utils/assertions_test.cpp | 3 +- 7 files changed, 73 insertions(+), 44 deletions(-) diff --git a/core/test/utils/array_generator_test.cpp b/core/test/utils/array_generator_test.cpp index ae66e4686da..ca96761ea4e 100644 --- a/core/test/utils/array_generator_test.cpp +++ b/core/test/utils/array_generator_test.cpp @@ -18,11 +18,12 @@ template class ArrayGenerator : public ::testing::Test { protected: using value_type = T; + using check_type = double; ArrayGenerator() : exec(gko::ReferenceExecutor::create()) { array = gko::test::generate_random_array( - 500, std::normal_distribution>(20.0, 5.0), + 500, std::normal_distribution<>(20.0, 5.0), std::default_random_engine(42), exec); } @@ -30,15 +31,17 @@ class ArrayGenerator : public ::testing::Test { gko::array array; template - ValueType get_nth_moment(int n, ValueType c, InputIterator sample_start, - InputIterator sample_end, Closure closure_op) + check_type get_nth_moment(int n, ValueType c, InputIterator sample_start, + InputIterator sample_end, Closure closure_op) { using std::pow; - ValueType res = 0; - ValueType num_elems = 0; + check_type res = 0; + check_type num_elems = 0; while (sample_start != sample_end) { auto tmp = *(sample_start++); - res += pow(closure_op(tmp) - c, 
n); + res += pow(static_cast(closure_op(tmp)) - + static_cast(c), + n); num_elems += 1; } return res / num_elems; @@ -62,7 +65,8 @@ class ArrayGenerator : public ::testing::Test { } }; -TYPED_TEST_SUITE(ArrayGenerator, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(ArrayGenerator, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(ArrayGenerator, OutputHasCorrectSize) diff --git a/core/test/utils/matrix_generator.hpp b/core/test/utils/matrix_generator.hpp index 56ff38c520d..01ee40cdadc 100644 --- a/core/test/utils/matrix_generator.hpp +++ b/core/test/utils/matrix_generator.hpp @@ -659,10 +659,20 @@ gko::matrix_data generate_tridiag_inverse_matrix_data( auto off_diag = i < j ? upper : lower; auto min_idx = std::min(i, j); auto max_idx = std::max(i, j); - auto val = sign * - static_cast( - std::pow(off_diag, max_idx - min_idx)) * - alpha[min_idx] * beta[max_idx + 1] / alpha.back(); + // NVHPC 23.3 with O3 gives wrong result with std::pow on + // complex. We use the float variant to help it, also for + // half. 
+ using pow_type = std::conditional_t< + std::is_same, + gko::half>::value, + std::conditional_t(), + std::complex, float>, + ValueType>; + auto val = + sign * + static_cast(std::pow( + static_cast(off_diag), max_idx - min_idx)) * + alpha[min_idx] * beta[max_idx + 1] / alpha.back(); md.nonzeros.emplace_back(i, j, val); } } diff --git a/core/test/utils/matrix_generator_test.cpp b/core/test/utils/matrix_generator_test.cpp index 43756bc1709..61710540e24 100644 --- a/core/test/utils/matrix_generator_test.cpp +++ b/core/test/utils/matrix_generator_test.cpp @@ -20,31 +20,32 @@ template class MatrixGenerator : public ::testing::Test { protected: using value_type = T; + using check_type = double; using real_type = gko::remove_complex; using mtx_type = gko::matrix::Dense; MatrixGenerator() : exec(gko::ReferenceExecutor::create()), mtx(gko::test::generate_random_matrix( - 500, 100, std::normal_distribution(50, 5), - std::normal_distribution(20.0, 5.0), + 500, 100, std::normal_distribution<>(50, 5), + std::normal_distribution<>(20.0, 5.0), std::default_random_engine(42), exec)), dense_mtx(gko::test::generate_random_dense_matrix( - 500, 100, std::normal_distribution(20.0, 5.0), + 500, 100, std::normal_distribution<>(20.0, 5.0), std::default_random_engine(41), exec)), l_mtx(gko::test::generate_random_lower_triangular_matrix( - 4, true, std::normal_distribution(50, 5), - std::normal_distribution(20.0, 5.0), + 4, true, std::normal_distribution<>(50, 5), + std::normal_distribution<>(20.0, 5.0), std::default_random_engine(42), exec)), u_mtx(gko::test::generate_random_upper_triangular_matrix( - 4, true, std::normal_distribution(50, 5), - std::normal_distribution(20.0, 5.0), + 4, true, std::normal_distribution<>(50, 5), + std::normal_distribution<>(20.0, 5.0), std::default_random_engine(42), exec)), lower_bandwidth(2), upper_bandwidth(3), band_mtx(gko::test::generate_random_band_matrix( 100, lower_bandwidth, upper_bandwidth, - std::normal_distribution(20.0, 5.0), + 
std::normal_distribution<>(20.0, 5.0), std::default_random_engine(42), exec)), nnz_per_row_sample(500, 0), values_sample(0), @@ -96,15 +97,17 @@ class MatrixGenerator : public ::testing::Test { template - ValueType get_nth_moment(int n, ValueType c, InputIterator sample_start, - InputIterator sample_end, Closure closure_op) + check_type get_nth_moment(int n, ValueType c, InputIterator sample_start, + InputIterator sample_end, Closure closure_op) { using std::pow; - ValueType res = 0; - ValueType num_elems = 0; + check_type res = 0; + check_type num_elems = 0; while (sample_start != sample_end) { auto tmp = *(sample_start++); - res += pow(closure_op(tmp) - c, n); + res += pow(static_cast(closure_op(tmp)) - + static_cast(c), + n); num_elems += 1; } return res / num_elems; @@ -128,7 +131,8 @@ class MatrixGenerator : public ::testing::Test { } }; -TYPED_TEST_SUITE(MatrixGenerator, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(MatrixGenerator, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(MatrixGenerator, OutputHasCorrectSize) @@ -247,7 +251,7 @@ TYPED_TEST(MatrixGenerator, CanGenerateTridiagMatrix) { using T = typename TestFixture::value_type; using Dense = typename TestFixture::mtx_type; - auto dist = std::normal_distribution>(0, 1); + auto dist = std::normal_distribution<>(0, 1); auto engine = std::default_random_engine(42); auto lower = gko::test::detail::get_rand_value(dist, engine); auto diag = gko::test::detail::get_rand_value(dist, engine); @@ -271,18 +275,23 @@ TYPED_TEST(MatrixGenerator, CanGenerateTridiagInverseMatrix) { using T = typename TestFixture::value_type; using Dense = typename TestFixture::mtx_type; - auto dist = std::normal_distribution>(0, 1); + auto dist = std::normal_distribution<>(0, 1); auto engine = std::default_random_engine(42); auto lower = gko::test::detail::get_rand_value(dist, engine); auto upper = gko::test::detail::get_rand_value(dist, engine); // make diagonally dominant - auto diag = 
std::abs(gko::test::detail::get_rand_value(dist, engine)) + - std::abs(lower) + std::abs(upper); + auto diag = gko::abs(gko::test::detail::get_rand_value(dist, engine)) + + gko::abs(lower) + gko::abs(upper); + gko::size_type size = 50; + if (std::is_same_v, gko::half>) { + // half precision can only handle the inverse of a small matrix. + size = 5; + } auto mtx = gko::test::generate_tridiag_matrix( - 50, {lower, diag, upper}, this->exec); + size, {lower, diag, upper}, this->exec); auto inv_mtx = gko::test::generate_tridiag_inverse_matrix( - 50, {lower, diag, upper}, this->exec); + size, {lower, diag, upper}, this->exec); auto result = Dense::create(this->exec, mtx->get_size()); inv_mtx->apply(mtx, result); diff --git a/core/test/utils/matrix_utils_test.cpp b/core/test/utils/matrix_utils_test.cpp index 3c67571e1b2..f742d4561a2 100644 --- a/core/test/utils/matrix_utils_test.cpp +++ b/core/test/utils/matrix_utils_test.cpp @@ -30,8 +30,8 @@ class MatrixUtils : public ::testing::Test { MatrixUtils() : exec(gko::ReferenceExecutor::create()), data(gko::test::generate_random_matrix_data( - 500, 500, std::normal_distribution(50, 5), - std::normal_distribution(20.0, 5.0), + 500, 500, std::normal_distribution<>(50, 5), + std::normal_distribution<>(20.0, 5.0), std::default_random_engine(42))), rectangular_data(gko::dim<2>(500, 100)) {} @@ -41,7 +41,8 @@ class MatrixUtils : public ::testing::Test { mtx_data rectangular_data; }; -TYPED_TEST_SUITE(MatrixUtils, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(MatrixUtils, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(MatrixUtils, MakeSymmetricThrowsError) @@ -241,7 +242,7 @@ TYPED_TEST(MatrixUtils, MakeHpdMatrixCorrectly) TYPED_TEST(MatrixUtils, MakeHpdMatrixWithRatioCorrectly) { using T = typename TestFixture::value_type; - gko::remove_complex ratio = 1.00001; + gko::remove_complex ratio = 1.01; auto cpy_data = this->data; gko::utils::make_hpd(this->data, ratio); @@ -276,7 +277,7 @@ 
TYPED_TEST(MatrixUtils, MakeSpdMatrixCorrectly) TYPED_TEST(MatrixUtils, MakeSpdMatrixWithRatioCorrectly) { using T = typename TestFixture::value_type; - gko::remove_complex ratio = 1.00001; + gko::remove_complex ratio = 1.01; auto cpy_data = this->data; gko::utils::make_spd(this->data, ratio); diff --git a/core/test/utils/unsort_matrix_test.cpp b/core/test/utils/unsort_matrix_test.cpp index 5d2f88f982a..40ec65b08db 100644 --- a/core/test/utils/unsort_matrix_test.cpp +++ b/core/test/utils/unsort_matrix_test.cpp @@ -119,7 +119,7 @@ class UnsortMatrix : public ::testing::Test { std::unique_ptr coo_empty; }; -TYPED_TEST_SUITE(UnsortMatrix, gko::test::ValueIndexTypes, +TYPED_TEST_SUITE(UnsortMatrix, gko::test::ValueIndexTypesWithHalf, PairTypenameNameGenerator); diff --git a/core/test/utils/value_generator_test.cpp b/core/test/utils/value_generator_test.cpp index 633565a66ef..57473c41b6e 100644 --- a/core/test/utils/value_generator_test.cpp +++ b/core/test/utils/value_generator_test.cpp @@ -20,19 +20,22 @@ template class ValueGenerator : public ::testing::Test { protected: using value_type = T; + using check_type = double; ValueGenerator() {} template - ValueType get_nth_moment(int n, ValueType c, InputIterator sample_start, - InputIterator sample_end, Closure closure_op) + check_type get_nth_moment(int n, ValueType c, InputIterator sample_start, + InputIterator sample_end, Closure closure_op) { using std::pow; - ValueType res = 0; - ValueType num_elems = 0; + check_type res = 0; + check_type num_elems = 0; while (sample_start != sample_end) { auto tmp = *(sample_start++); - res += pow(closure_op(tmp) - c, n); + res += pow(static_cast(closure_op(tmp)) - + static_cast(c), + n); num_elems += 1; } return res / num_elems; @@ -56,7 +59,8 @@ class ValueGenerator : public ::testing::Test { } }; -TYPED_TEST_SUITE(ValueGenerator, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(ValueGenerator, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); 
TYPED_TEST(ValueGenerator, OutputHasCorrectAverageAndDeviation) diff --git a/reference/test/utils/assertions_test.cpp b/reference/test/utils/assertions_test.cpp index 98f1ec68e0d..9c6b544172e 100644 --- a/reference/test/utils/assertions_test.cpp +++ b/reference/test/utils/assertions_test.cpp @@ -17,7 +17,8 @@ namespace { template class MatricesNear : public ::testing::Test {}; -TYPED_TEST_SUITE(MatricesNear, gko::test::ValueTypes, TypenameNameGenerator); +TYPED_TEST_SUITE(MatricesNear, gko::test::ValueTypesWithHalf, + TypenameNameGenerator); TYPED_TEST(MatricesNear, CanPassAnyMatrixType) From 8910f831e32b3d0afdc93c0441b7efb94ca208ba Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Thu, 21 Nov 2024 11:14:50 +0100 Subject: [PATCH 14/16] constexpr restriction for nvc++ --- accessor/reference_helper.hpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/accessor/reference_helper.hpp b/accessor/reference_helper.hpp index a3a77352f8f..61e15bf8b22 100644 --- a/accessor/reference_helper.hpp +++ b/accessor/reference_helper.hpp @@ -12,10 +12,8 @@ #include "utils.hpp" -// CUDA TOOLKIT < 11 does not support constexpr in combination with -// thrust::complex, which is why constexpr is only present in later versions -#if defined(__CUDA_ARCH__) && defined(__CUDACC_VER_MAJOR__) && \ - (__CUDACC_VER_MAJOR__ < 11) +// NVC++ disallows constexpr functions with a non-literal return type like half +#if defined(__NVCOMPILER) && GINKGO_ENABLE_HALF #define GKO_ACC_ENABLE_REFERENCE_CONSTEXPR @@ -23,7 +21,7 @@ #define GKO_ACC_ENABLE_REFERENCE_CONSTEXPR constexpr -#endif // __CUDA_ARCH__ && __CUDACC_VER_MAJOR__ && __CUDACC_VER_MAJOR__ < 11 +#endif namespace gko { From 042161502c578cae9dbe7629e176e992cb92dd19 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Mon, 25 Nov 2024 15:03:07 +0100 Subject: [PATCH 15/16] cuda with CC<70 and hip do not support 16 bit atomic. 
throw error or fallback to a working version if it is the case for matrix --- common/cuda_hip/components/atomic.hpp | 48 -------- common/cuda_hip/matrix/coo_kernels.cpp | 114 ++++++++++-------- .../cuda_hip/matrix/csr_kernels.template.cpp | 97 ++++++++------- common/cuda_hip/matrix/ell_kernels.cpp | 93 ++++++++------ hip/components/cooperative_groups.hip.hpp | 12 +- 5 files changed, 182 insertions(+), 182 deletions(-) diff --git a/common/cuda_hip/components/atomic.hpp b/common/cuda_hip/components/atomic.hpp index 954bc7476ed..cd59485dac9 100644 --- a/common/cuda_hip/components/atomic.hpp +++ b/common/cuda_hip/components/atomic.hpp @@ -96,52 +96,6 @@ __forceinline__ __device__ ResultType reinterpret(ValueType val) } \ }; - -#define GKO_BIND_ATOMIC_HELPER_FAKE_STRUCTURE(CONVERTER_TYPE) \ - template \ - struct atomic_helper< \ - ValueType, \ - std::enable_if_t<(sizeof(ValueType) == sizeof(CONVERTER_TYPE))>> { \ - __forceinline__ __device__ static ValueType atomic_add( \ - ValueType* __restrict__ addr, ValueType val) \ - { \ - assert(false); \ - using c_type = CONVERTER_TYPE; \ - return atomic_wrapper( \ - addr, [&val](c_type& old, c_type assumed, c_type* c_addr) { \ - old = *c_addr; \ - *c_addr = reinterpret( \ - val + reinterpret(assumed)); \ - }); \ - } \ - __forceinline__ __device__ static ValueType atomic_max( \ - ValueType* __restrict__ addr, ValueType val) \ - { \ - assert(false); \ - using c_type = CONVERTER_TYPE; \ - return atomic_wrapper( \ - addr, [&val](c_type& old, c_type assumed, c_type* c_addr) { \ - if (reinterpret(assumed) < val) { \ - old = *c_addr; \ - *c_addr = reinterpret(assumed); \ - } \ - }); \ - } \ - \ - private: \ - template \ - __forceinline__ __device__ static ValueType atomic_wrapper( \ - ValueType* __restrict__ addr, Callable set_old) \ - { \ - CONVERTER_TYPE* address_as_converter = \ - reinterpret_cast(addr); \ - CONVERTER_TYPE old = *address_as_converter; \ - CONVERTER_TYPE assumed = old; \ - set_old(old, assumed, address_as_converter); 
\ - return reinterpret(old); \ - } \ - }; - // Support 64-bit ATOMIC_ADD and ATOMIC_MAX GKO_BIND_ATOMIC_HELPER_STRUCTURE(unsigned long long int); // Support 32-bit ATOMIC_ADD and ATOMIC_MAX @@ -152,8 +106,6 @@ GKO_BIND_ATOMIC_HELPER_STRUCTURE(unsigned int); // Support 16-bit atomicCAS, atomicADD, and atomicMAX only on CUDA with CC // >= 7.0 GKO_BIND_ATOMIC_HELPER_STRUCTURE(unsigned short int); -#else -GKO_BIND_ATOMIC_HELPER_FAKE_STRUCTURE(unsigned short int) #endif diff --git a/common/cuda_hip/matrix/coo_kernels.cpp b/common/cuda_hip/matrix/coo_kernels.cpp index 4609f9f7f95..88d6dced504 100644 --- a/common/cuda_hip/matrix/coo_kernels.cpp +++ b/common/cuda_hip/matrix/coo_kernels.cpp @@ -268,30 +268,38 @@ void spmv2(std::shared_ptr exec, const dim3 coo_block(config::warp_size, warps_in_block, 1); const auto nwarps = host_kernel::calculate_nwarps(exec, nnz); - if (nwarps > 0 && b_ncols > 0) { - // TODO: b_ncols needs to be tuned for ROCm. - if (b_ncols < 4) { - const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols); - int num_lines = ceildiv(nnz, nwarps * config::warp_size); - - abstract_spmv<<get_stream()>>>( - nnz, num_lines, as_device_type(a->get_const_values()), - a->get_const_col_idxs(), - as_device_type(a->get_const_row_idxs()), - as_device_type(b->get_const_values()), b->get_stride(), - as_device_type(c->get_values()), c->get_stride()); - } else { - int num_elems = - ceildiv(nnz, nwarps * config::warp_size) * config::warp_size; - const dim3 coo_grid(ceildiv(nwarps, warps_in_block), - ceildiv(b_ncols, config::warp_size)); - - abstract_spmm<<get_stream()>>>( - nnz, num_elems, as_device_type(a->get_const_values()), - a->get_const_col_idxs(), - as_device_type(a->get_const_row_idxs()), b_ncols, - as_device_type(b->get_const_values()), b->get_stride(), - as_device_type(c->get_values()), c->get_stride()); +// not support 16 bit atomic +#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700)) + if constexpr (std::is_same_v, gko::half>) { + GKO_NOT_SUPPORTED(c); + 
} else +#endif + { + if (nwarps > 0 && b_ncols > 0) { + // TODO: b_ncols needs to be tuned for ROCm. + if (b_ncols < 4) { + const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols); + int num_lines = ceildiv(nnz, nwarps * config::warp_size); + + abstract_spmv<<get_stream()>>>( + nnz, num_lines, as_device_type(a->get_const_values()), + a->get_const_col_idxs(), + as_device_type(a->get_const_row_idxs()), + as_device_type(b->get_const_values()), b->get_stride(), + as_device_type(c->get_values()), c->get_stride()); + } else { + int num_elems = ceildiv(nnz, nwarps * config::warp_size) * + config::warp_size; + const dim3 coo_grid(ceildiv(nwarps, warps_in_block), + ceildiv(b_ncols, config::warp_size)); + + abstract_spmm<<get_stream()>>>( + nnz, num_elems, as_device_type(a->get_const_values()), + a->get_const_col_idxs(), + as_device_type(a->get_const_row_idxs()), b_ncols, + as_device_type(b->get_const_values()), b->get_stride(), + as_device_type(c->get_values()), c->get_stride()); + } } } } @@ -312,30 +320,40 @@ void advanced_spmv2(std::shared_ptr exec, const dim3 coo_block(config::warp_size, warps_in_block, 1); const auto b_ncols = b->get_size()[1]; - if (nwarps > 0 && b_ncols > 0) { - // TODO: b_ncols needs to be tuned for ROCm. 
- if (b_ncols < 4) { - int num_lines = ceildiv(nnz, nwarps * config::warp_size); - const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols); - - abstract_spmv<<get_stream()>>>( - nnz, num_lines, as_device_type(alpha->get_const_values()), - as_device_type(a->get_const_values()), a->get_const_col_idxs(), - as_device_type(a->get_const_row_idxs()), - as_device_type(b->get_const_values()), b->get_stride(), - as_device_type(c->get_values()), c->get_stride()); - } else { - int num_elems = - ceildiv(nnz, nwarps * config::warp_size) * config::warp_size; - const dim3 coo_grid(ceildiv(nwarps, warps_in_block), - ceildiv(b_ncols, config::warp_size)); - - abstract_spmm<<get_stream()>>>( - nnz, num_elems, as_device_type(alpha->get_const_values()), - as_device_type(a->get_const_values()), a->get_const_col_idxs(), - as_device_type(a->get_const_row_idxs()), b_ncols, - as_device_type(b->get_const_values()), b->get_stride(), - as_device_type(c->get_values()), c->get_stride()); + // not support 16 bit atomic +#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700)) + if constexpr (std::is_same_v, gko::half>) { + GKO_NOT_SUPPORTED(c); + } else +#endif + { + if (nwarps > 0 && b_ncols > 0) { + // TODO: b_ncols needs to be tuned for ROCm. 
+ if (b_ncols < 4) { + int num_lines = ceildiv(nnz, nwarps * config::warp_size); + const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols); + + abstract_spmv<<get_stream()>>>( + nnz, num_lines, as_device_type(alpha->get_const_values()), + as_device_type(a->get_const_values()), + a->get_const_col_idxs(), + as_device_type(a->get_const_row_idxs()), + as_device_type(b->get_const_values()), b->get_stride(), + as_device_type(c->get_values()), c->get_stride()); + } else { + int num_elems = ceildiv(nnz, nwarps * config::warp_size) * + config::warp_size; + const dim3 coo_grid(ceildiv(nwarps, warps_in_block), + ceildiv(b_ncols, config::warp_size)); + + abstract_spmm<<get_stream()>>>( + nnz, num_elems, as_device_type(alpha->get_const_values()), + as_device_type(a->get_const_values()), + a->get_const_col_idxs(), + as_device_type(a->get_const_row_idxs()), b_ncols, + as_device_type(b->get_const_values()), b->get_stride(), + as_device_type(c->get_values()), c->get_stride()); + } } } } diff --git a/common/cuda_hip/matrix/csr_kernels.template.cpp b/common/cuda_hip/matrix/csr_kernels.template.cpp index f808e234670..bd2423d4306 100644 --- a/common/cuda_hip/matrix/csr_kernels.template.cpp +++ b/common/cuda_hip/matrix/csr_kernels.template.cpp @@ -2064,7 +2064,7 @@ GKO_ENABLE_IMPLEMENTATION_SELECTION(select_classical_spmv, classical_spmv); template -void load_balance_spmv(std::shared_ptr exec, +bool load_balance_spmv(std::shared_ptr exec, const matrix::Csr* a, const matrix::Dense* b, matrix::Dense* c, @@ -2074,42 +2074,54 @@ void load_balance_spmv(std::shared_ptr exec, using arithmetic_type = highest_precision; - if (beta) { - dense::scale(exec, beta, c); - } else { - dense::fill(exec, c, zero()); - } - const IndexType nwarps = a->get_num_srow_elements(); - if (nwarps > 0) { - const dim3 csr_block(config::warp_size, warps_in_block, 1); - const dim3 csr_grid(ceildiv(nwarps, warps_in_block), b->get_size()[1]); - const auto a_vals = - acc::helper::build_const_rrm_accessor(a); - const 
auto b_vals = - acc::helper::build_const_rrm_accessor(b); - auto c_vals = acc::helper::build_rrm_accessor(c); - if (alpha) { - if (csr_grid.x > 0 && csr_grid.y > 0) { - kernel::abstract_spmv<<get_stream()>>>( - nwarps, static_cast(a->get_size()[0]), - as_device_type(alpha->get_const_values()), - acc::as_device_range(a_vals), a->get_const_col_idxs(), - as_device_type(a->get_const_row_ptrs()), - as_device_type(a->get_const_srow()), - acc::as_device_range(b_vals), acc::as_device_range(c_vals)); - } + // not support 16 bit atomic +#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700)) + if constexpr (std::is_same_v, half>) { + return false; + } else +#endif + { + if (beta) { + dense::scale(exec, beta, c); } else { - if (csr_grid.x > 0 && csr_grid.y > 0) { - kernel::abstract_spmv<<get_stream()>>>( - nwarps, static_cast(a->get_size()[0]), - acc::as_device_range(a_vals), a->get_const_col_idxs(), - as_device_type(a->get_const_row_ptrs()), - as_device_type(a->get_const_srow()), - acc::as_device_range(b_vals), acc::as_device_range(c_vals)); + dense::fill(exec, c, zero()); + } + const IndexType nwarps = a->get_num_srow_elements(); + if (nwarps > 0) { + const dim3 csr_block(config::warp_size, warps_in_block, 1); + const dim3 csr_grid(ceildiv(nwarps, warps_in_block), + b->get_size()[1]); + const auto a_vals = + acc::helper::build_const_rrm_accessor(a); + const auto b_vals = + acc::helper::build_const_rrm_accessor(b); + auto c_vals = acc::helper::build_rrm_accessor(c); + if (alpha) { + if (csr_grid.x > 0 && csr_grid.y > 0) { + kernel::abstract_spmv<<get_stream()>>>( + nwarps, static_cast(a->get_size()[0]), + as_device_type(alpha->get_const_values()), + acc::as_device_range(a_vals), a->get_const_col_idxs(), + as_device_type(a->get_const_row_ptrs()), + as_device_type(a->get_const_srow()), + acc::as_device_range(b_vals), + acc::as_device_range(c_vals)); + } + } else { + if (csr_grid.x > 0 && csr_grid.y > 0) { + kernel::abstract_spmv<<get_stream()>>>( + nwarps, 
static_cast(a->get_size()[0]), + acc::as_device_range(a_vals), a->get_const_col_idxs(), + as_device_type(a->get_const_row_ptrs()), + as_device_type(a->get_const_srow()), + acc::as_device_range(b_vals), + acc::as_device_range(c_vals)); + } } } + return true; } } @@ -2257,8 +2269,6 @@ void spmv(std::shared_ptr exec, { if (c->get_size()[0] == 0 || c->get_size()[1] == 0) { // empty output: nothing to do - } else if (a->get_strategy()->get_name() == "load_balance") { - host_kernel::load_balance_spmv(exec, a, b, c); } else if (a->get_strategy()->get_name() == "merge_path") { using arithmetic_type = highest_precision; @@ -2273,8 +2283,10 @@ void spmv(std::shared_ptr exec, syn::value_list(), syn::type_list<>(), exec, a, b, c); } else { bool use_classical = true; - if (a->get_strategy()->get_name() == "sparselib" || - a->get_strategy()->get_name() == "cusparse") { + if (a->get_strategy()->get_name() == "load_balance") { + use_classical = !host_kernel::load_balance_spmv(exec, a, b, c); + } else if (a->get_strategy()->get_name() == "sparselib" || + a->get_strategy()->get_name() == "cusparse") { use_classical = !host_kernel::try_sparselib_spmv(exec, a, b, c); } if (use_classical) { @@ -2316,8 +2328,6 @@ void advanced_spmv(std::shared_ptr exec, { if (c->get_size()[0] == 0 || c->get_size()[1] == 0) { // empty output: nothing to do - } else if (a->get_strategy()->get_name() == "load_balance") { - host_kernel::load_balance_spmv(exec, a, b, c, alpha, beta); } else if (a->get_strategy()->get_name() == "merge_path") { using arithmetic_type = highest_precision; @@ -2333,8 +2343,11 @@ void advanced_spmv(std::shared_ptr exec, beta); } else { bool use_classical = true; - if (a->get_strategy()->get_name() == "sparselib" || - a->get_strategy()->get_name() == "cusparse") { + if (a->get_strategy()->get_name() == "load_balance") { + use_classical = + !host_kernel::load_balance_spmv(exec, a, b, c, alpha, beta); + } else if (a->get_strategy()->get_name() == "sparselib" || + 
a->get_strategy()->get_name() == "cusparse") { use_classical = !host_kernel::try_sparselib_spmv(exec, a, b, c, alpha, beta); } diff --git a/common/cuda_hip/matrix/ell_kernels.cpp b/common/cuda_hip/matrix/ell_kernels.cpp index 16371166662..23079092162 100644 --- a/common/cuda_hip/matrix/ell_kernels.cpp +++ b/common/cuda_hip/matrix/ell_kernels.cpp @@ -91,7 +91,7 @@ __device__ void spmv_kernel( using arithmetic_type = typename a_accessor::arithmetic_type; const auto tidx = thread::get_thread_id_flat(); const decltype(tidx) column_id = blockIdx.y; - if (num_thread_per_worker == 1) { + if constexpr (num_thread_per_worker == 1) { // Specialize the num_thread_per_worker = 1. It doesn't need the shared // memory, __syncthreads, and atomic_add if (tidx < num_rows) { @@ -137,7 +137,7 @@ __device__ void spmv_kernel( __syncthreads(); if (idx_in_worker == 0) { const auto c_ind = x * c_stride + column_id; - if (atomic) { + if constexpr (atomic) { atomic_add(&(c[c_ind]), op(storage[threadIdx.x], c[c_ind])); } else { c[c_ind] = op(storage[threadIdx.x], c[c_ind]); @@ -179,7 +179,7 @@ __global__ __launch_bounds__(default_block_size) void spmv( using arithmetic_type = typename a_accessor::arithmetic_type; const auto alpha_val = alpha(0); const OutputValueType beta_val = beta[0]; - if (atomic) { + if constexpr (atomic) { // Because the atomic operation changes the values of c during // computation, it can not directly do alpha * a * b + beta * c // operation. The beta * c needs to be done before calling this kernel. 
@@ -240,42 +240,59 @@ void abstract_spmv(syn::value_list, const dim3 grid_size(ceildiv(nrows * num_worker_per_row, block_size.x), b->get_size()[1], 1); - const auto a_vals = acc::range( - std::array{{static_cast( - num_stored_elements_per_row * stride)}}, - a->get_const_values()); - const auto b_vals = acc::range( - std::array{ - {static_cast(b->get_size()[0]), - static_cast(b->get_size()[1])}}, - b->get_const_values(), - std::array{ - {static_cast(b->get_stride())}}); - - if (alpha == nullptr && beta == nullptr) { - if (grid_size.x > 0 && grid_size.y > 0) { - kernel::spmv - <<get_stream()>>>( - nrows, num_worker_per_row, acc::as_device_range(a_vals), - a->get_const_col_idxs(), stride, - num_stored_elements_per_row, acc::as_device_range(b_vals), - as_device_type(c->get_values()), c->get_stride()); - } - } else if (alpha != nullptr && beta != nullptr) { - const auto alpha_val = acc::range( - std::array{1}, alpha->get_const_values()); - if (grid_size.x > 0 && grid_size.y > 0) { - kernel::spmv - <<get_stream()>>>( - nrows, num_worker_per_row, acc::as_device_range(alpha_val), - acc::as_device_range(a_vals), a->get_const_col_idxs(), - stride, num_stored_elements_per_row, - acc::as_device_range(b_vals), - as_device_type(beta->get_const_values()), - as_device_type(c->get_values()), c->get_stride()); - } - } else { +// not support 16 bit atomic +#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700)) + // We do atomic on shared memory when num_thread_per_worker is not 1. + // If atomic is also true, we also do atomic on out_vector. 
+ constexpr bool shared_half = + std::is_same_v, half>; + constexpr bool atomic_half_out = + atomic && std::is_same_v, half>; + if constexpr (num_thread_per_worker != 1 && + (shared_half || atomic_half_out)) { GKO_KERNEL_NOT_FOUND; + } else +#endif + { + const auto a_vals = acc::range( + std::array{{static_cast( + num_stored_elements_per_row * stride)}}, + a->get_const_values()); + const auto b_vals = acc::range( + std::array{ + {static_cast(b->get_size()[0]), + static_cast(b->get_size()[1])}}, + b->get_const_values(), + std::array{ + {static_cast(b->get_stride())}}); + + if (alpha == nullptr && beta == nullptr) { + if (grid_size.x > 0 && grid_size.y > 0) { + kernel::spmv + <<get_stream()>>>( + nrows, num_worker_per_row, acc::as_device_range(a_vals), + a->get_const_col_idxs(), stride, + num_stored_elements_per_row, + acc::as_device_range(b_vals), + as_device_type(c->get_values()), c->get_stride()); + } + } else if (alpha != nullptr && beta != nullptr) { + const auto alpha_val = acc::range( + std::array{1}, alpha->get_const_values()); + if (grid_size.x > 0 && grid_size.y > 0) { + kernel::spmv + <<get_stream()>>>( + nrows, num_worker_per_row, + acc::as_device_range(alpha_val), + acc::as_device_range(a_vals), a->get_const_col_idxs(), + stride, num_stored_elements_per_row, + acc::as_device_range(b_vals), + as_device_type(beta->get_const_values()), + as_device_type(c->get_values()), c->get_stride()); + } + } else { + GKO_KERNEL_NOT_FOUND; + } } } diff --git a/hip/components/cooperative_groups.hip.hpp b/hip/components/cooperative_groups.hip.hpp index 46c2fb195bc..36618bb7f3e 100644 --- a/hip/components/cooperative_groups.hip.hpp +++ b/hip/components/cooperative_groups.hip.hpp @@ -306,7 +306,7 @@ class enable_extended_shuffle : public Group { SelectorType selector) const \ { \ return shuffle_impl( \ - [this](uint16 v, SelectorType s) { \ + [this](uint32 v, SelectorType s) { \ return static_cast(this)->_name(v, s); \ }, \ var, selector); \ @@ -326,12 +326,12 @@ class 
enable_extended_shuffle : public Group { shuffle_impl(ShuffleOperator intrinsic_shuffle, const ValueType var, SelectorType selector) { - static_assert(sizeof(ValueType) % sizeof(uint16) == 0, - "Unable to shuffle sizes which are not 2-byte multiples"); - constexpr auto value_size = sizeof(ValueType) / sizeof(uint16); + static_assert(sizeof(ValueType) % sizeof(uint32) == 0, + "Unable to shuffle sizes which are not 4-byte multiples"); + constexpr auto value_size = sizeof(ValueType) / sizeof(uint32); ValueType result; - auto var_array = reinterpret_cast(&var); - auto result_array = reinterpret_cast(&result); + auto var_array = reinterpret_cast(&var); + auto result_array = reinterpret_cast(&result); #pragma unroll for (std::size_t i = 0; i < value_size; ++i) { result_array[i] = intrinsic_shuffle(var_array[i], selector); From 8190bf643b28eb5f421c4c2e36dd85175676a9f3 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Tue, 26 Nov 2024 11:13:59 +0100 Subject: [PATCH 16/16] implement half shuffle via 32 bit impl --- hip/components/cooperative_groups.hip.hpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/hip/components/cooperative_groups.hip.hpp b/hip/components/cooperative_groups.hip.hpp index 36618bb7f3e..dce69421a31 100644 --- a/hip/components/cooperative_groups.hip.hpp +++ b/hip/components/cooperative_groups.hip.hpp @@ -319,6 +319,27 @@ class enable_extended_shuffle : public Group { #undef GKO_ENABLE_SHUFFLE_OPERATION +// hip does not support 16bit shuffle directly +#define GKO_ENABLE_SHUFFLE_OPERATION_HALF(_name, SelectorType) \ + __device__ __forceinline__ __half _name(const __half& var, \ + SelectorType selector) const \ + { \ + uint32 u; \ + memcpy(&u, &var, sizeof(__half)); \ + u = static_cast(this)->_name(u, selector); \ + __half result; \ + memcpy(&result, &u, sizeof(__half)); \ + return result; \ + } + + GKO_ENABLE_SHUFFLE_OPERATION_HALF(shfl, int32) + GKO_ENABLE_SHUFFLE_OPERATION_HALF(shfl_up, uint32) + 
GKO_ENABLE_SHUFFLE_OPERATION_HALF(shfl_down, uint32) + GKO_ENABLE_SHUFFLE_OPERATION_HALF(shfl_xor, int32) + +#undef GKO_ENABLE_SHUFFLE_OPERATION_HALF + + private: template