diff --git a/core/test/utils/assertions.hpp b/core/test/utils/assertions.hpp index 174d4536657..87a4e878fc7 100644 --- a/core/test/utils/assertions.hpp +++ b/core/test/utils/assertions.hpp @@ -140,7 +140,6 @@ template void print_componentwise_error(Ostream& os, const MatrixData1& first, const MatrixData2& second) { - using std::abs; using vt = typename detail::biggest_valuetype< typename MatrixData1::value_type, typename MatrixData2::value_type>::type; diff --git a/cuda/base/types.hpp b/cuda/base/types.hpp index c7fe79b5a6f..367674ac163 100644 --- a/cuda/base/types.hpp +++ b/cuda/base/types.hpp @@ -100,10 +100,6 @@ __device__ __forceinline__ bool is_nan(const thrust::complex<__half>& val) #endif -namespace kernels { -namespace cuda { - - #ifdef __CUDACC__ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 @@ -138,6 +134,9 @@ __device__ __forceinline__ __half sqrt(const __half& val) #endif #endif + +namespace kernels { +namespace cuda { namespace detail { /** diff --git a/hip/base/types.hip.hpp b/hip/base/types.hip.hpp index a52dfe0b239..febead4f370 100644 --- a/hip/base/types.hip.hpp +++ b/hip/base/types.hip.hpp @@ -88,35 +88,8 @@ THRUST_HALF_FRIEND_OPERATOR(/, /=) namespace gko { -#if GINKGO_HIP_PLATFORM_NVCC -// from the cuda_fp16.hpp -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 -__device__ __forceinline__ bool is_nan(const __half& val) -{ - return __hisnan(val); -} -#if CUDA_VERSION >= 10020 -__device__ __forceinline__ __half abs(const __half& val) { return __habs(val); } -#else -__device__ __forceinline__ __half abs(const __half& val) -{ - return abs(static_cast(val)); -} -#endif -#else -__device__ __forceinline__ bool is_nan(const __half& val) -{ - return is_nan(static_cast(val)); -} -__device__ __forceinline__ __half abs(const __half& val) -{ - return abs(static_cast(val)); -} -#endif - -#else // Not nvidia device __device__ __forceinline__ bool is_nan(const __half& val) { return __hisnan(val); @@ -125,8 +98,6 @@ __device__ __forceinline__ bool is_nan(const __half& val) // rocm40 __habs is not constexpr __device__ __forceinline__ __half abs(const __half& val) { return __habs(val); } -#endif - namespace kernels { namespace hip { diff --git a/include/ginkgo/core/base/math.hpp b/include/ginkgo/core/base/math.hpp index fb73c9c3cd6..79802a08350 100644 --- a/include/ginkgo/core/base/math.hpp +++ b/include/ginkgo/core/base/math.hpp @@ -35,106 +35,6 @@ class complex; namespace gko { -using std::abs; -using std::sqrt; - -GKO_INLINE gko::half abs(gko::half a) { return gko::half((a > 0) ? a : -a); } - -GKO_INLINE gko::half abs(std::complex a) -{ - // Using float abs not sqrt on norm to avoid overflow - return gko::half(abs(std::complex(a))); -} - - -GKO_INLINE gko::half sqrt(gko::half a) -{ - return gko::half(std::sqrt(float(a))); -} - -GKO_INLINE std::complex sqrt(std::complex a) -{ - return std::complex(sqrt(std::complex( - static_cast(a.real()), static_cast(a.imag())))); -} - - -} // namespace gko - - -namespace gko { - - -// HIP should not see std::abs or std::sqrt, we want the custom implementation. -// Hence, provide the using declaration only for some cases -namespace kernels { -namespace reference { - - -using std::abs; - - -using std::sqrt; - - -} // namespace reference -} // namespace kernels - - -namespace kernels { -namespace omp { - - -using std::abs; - - -using std::sqrt; - - -} // namespace omp -} // namespace kernels - - -namespace kernels { -namespace cuda { - - -using std::abs; - - -using std::sqrt; - - -} // namespace cuda -} // namespace kernels - - -namespace kernels { -namespace dpcpp { - - -using std::abs; - - -using std::sqrt; - - -} // namespace dpcpp -} // namespace kernels - - -namespace test { - - -using std::abs; - - -using std::sqrt; - - -} // namespace test - - // type manipulations @@ -1030,6 +930,7 @@ GKO_INLINE constexpr auto squared_norm(const T& x) return real(conj(x) * x); } +using std::abs; /** * Returns the absolute value of the object. @@ -1055,6 +956,27 @@ abs(const T& x) return sqrt(squared_norm(x)); } +// increase the priority in function lookup +GKO_INLINE gko::half abs(const std::complex& x) +{ + // Using float abs not sqrt on norm to avoid overflow + return static_cast(abs(std::complex(x))); +} + + +using std::sqrt; + +GKO_INLINE gko::half sqrt(gko::half a) +{ + return gko::half(std::sqrt(float(a))); +} + +GKO_INLINE std::complex sqrt(std::complex a) +{ + return std::complex(sqrt(std::complex( + static_cast(a.real()), static_cast(a.imag())))); +} + /** * Returns the value of pi.