Skip to content

Commit

Permalink
reduce abs/sqrt location
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Oct 21, 2024
1 parent af5b228 commit 2a40a7d
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 134 deletions.
1 change: 0 additions & 1 deletion core/test/utils/assertions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ template <typename Ostream, typename MatrixData1, typename MatrixData2>
void print_componentwise_error(Ostream& os, const MatrixData1& first,
const MatrixData2& second)
{
using std::abs;
using vt = typename detail::biggest_valuetype<
typename MatrixData1::value_type,
typename MatrixData2::value_type>::type;
Expand Down
7 changes: 3 additions & 4 deletions cuda/base/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,6 @@ __device__ __forceinline__ bool is_nan(const thrust::complex<__half>& val)
#endif


namespace kernels {
namespace cuda {


#ifdef __CUDACC__
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530

Expand Down Expand Up @@ -138,6 +134,9 @@ __device__ __forceinline__ __half sqrt(const __half& val)
#endif
#endif


namespace kernels {
namespace cuda {
namespace detail {

/**
Expand Down
29 changes: 0 additions & 29 deletions hip/base/types.hip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,35 +88,8 @@ THRUST_HALF_FRIEND_OPERATOR(/, /=)


namespace gko {
#if GINKGO_HIP_PLATFORM_NVCC
// from the cuda_fp16.hpp
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
__device__ __forceinline__ bool is_nan(const __half& val)
{
return __hisnan(val);
}

#if CUDA_VERSION >= 10020
__device__ __forceinline__ __half abs(const __half& val) { return __habs(val); }
#else
__device__ __forceinline__ __half abs(const __half& val)
{
return abs(static_cast<float>(val));
}
#endif
#else
__device__ __forceinline__ bool is_nan(const __half& val)
{
return is_nan(static_cast<float>(val));
}

__device__ __forceinline__ __half abs(const __half& val)
{
return abs(static_cast<float>(val));
}
#endif

#else // Not nvidia device
__device__ __forceinline__ bool is_nan(const __half& val)
{
return __hisnan(val);
Expand All @@ -125,8 +98,6 @@ __device__ __forceinline__ bool is_nan(const __half& val)
// rocm40 __habs is not constexpr
__device__ __forceinline__ __half abs(const __half& val) { return __habs(val); }

#endif


namespace kernels {
namespace hip {
Expand Down
122 changes: 22 additions & 100 deletions include/ginkgo/core/base/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,106 +35,6 @@ class complex;
namespace gko {


using std::abs;
using std::sqrt;

GKO_INLINE gko::half abs(gko::half a) { return gko::half((a > 0) ? a : -a); }

GKO_INLINE gko::half abs(std::complex<gko::half> a)
{
// Using float abs not sqrt on norm to avoid overflow
return gko::half(abs(std::complex<float>(a)));
}


GKO_INLINE gko::half sqrt(gko::half a)
{
return gko::half(std::sqrt(float(a)));
}

GKO_INLINE std::complex<gko::half> sqrt(std::complex<gko::half> a)
{
return std::complex<gko::half>(sqrt(std::complex<float>(
static_cast<float>(a.real()), static_cast<float>(a.imag()))));
}


} // namespace gko


namespace gko {


// HIP should not see std::abs or std::sqrt, we want the custom implementation.
// Hence, provide the using declaration only for some cases
namespace kernels {
namespace reference {


using std::abs;


using std::sqrt;


} // namespace reference
} // namespace kernels


namespace kernels {
namespace omp {


using std::abs;


using std::sqrt;


} // namespace omp
} // namespace kernels


namespace kernels {
namespace cuda {


using std::abs;


using std::sqrt;


} // namespace cuda
} // namespace kernels


namespace kernels {
namespace dpcpp {


using std::abs;


using std::sqrt;


} // namespace dpcpp
} // namespace kernels


namespace test {


using std::abs;


using std::sqrt;


} // namespace test


// type manipulations


Expand Down Expand Up @@ -1030,6 +930,7 @@ GKO_INLINE constexpr auto squared_norm(const T& x)
return real(conj(x) * x);
}

using std::abs;

/**
* Returns the absolute value of the object.
Expand All @@ -1055,6 +956,27 @@ abs(const T& x)
return sqrt(squared_norm(x));
}

// increase the priority in function lookup
GKO_INLINE gko::half abs(const std::complex<gko::half>& x)
{
// Using float abs not sqrt on norm to avoid overflow
return static_cast<gko::half>(abs(std::complex<float>(x)));
}


using std::sqrt;

GKO_INLINE gko::half sqrt(gko::half a)
{
return gko::half(std::sqrt(float(a)));
}

GKO_INLINE std::complex<gko::half> sqrt(std::complex<gko::half> a)
{
return std::complex<gko::half>(sqrt(std::complex<float>(
static_cast<float>(a.real()), static_cast<float>(a.imag()))));
}


/**
* Returns the value of pi.
Expand Down

0 comments on commit 2a40a7d

Please sign in to comment.