Skip to content

Commit

Permalink
does not work for the other executor
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Jan 14, 2023
1 parent 6ae9695 commit a695ef0
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 11 deletions.
4 changes: 3 additions & 1 deletion core/base/extended_float.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,9 @@ class complex<gko::half> {
const value_type& imag = value_type(0.f))
: real_(real), imag_(imag)
{}
template <typename T, typename U>
template <typename T, typename U,
typename = std::enable_if_t<std::is_scalar<T>::value &&
std::is_scalar<U>::value>>
explicit complex(const T& real, const U& imag)
: complex(static_cast<value_type>(real), static_cast<value_type>(imag))
{}
Expand Down
8 changes: 4 additions & 4 deletions core/test/utils/assertions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,10 +587,10 @@ ::testing::AssertionResult values_near<std::complex<half>, std::complex<half>>(
std::complex<half> val2, double abs_error)
{
using T = std::complex<float32>;
T Tval1;
T Tval2;
Tval1 = val1;
Tval2 = val2;
// T{val1} calls the constructor of complex<float>() -> which gives the
// complex<float>(double/float) ambiguous
T Tval1 = val1;
T Tval2 = val2;
const double diff = abs(Tval1 - Tval2);
if (diff <= abs_error) return ::testing::AssertionSuccess();

Expand Down
14 changes: 9 additions & 5 deletions cuda/base/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <ginkgo/core/base/matrix_data.hpp>

// namespace std {

// thrust calls the c function not the function from std
// Maybe override the function from thrust directlry
GKO_ATTRIBUTES GKO_INLINE __half hypot(__half a, __half b)
{
return hypot(static_cast<float>(a), static_cast<float>(b));
Expand All @@ -61,32 +63,34 @@ GKO_ATTRIBUTES GKO_INLINE thrust::complex<__half> sqrt(
return sqrt(static_cast<thrust::complex<float>>(a));
}

// } // namespace std

namespace thrust {


// Dircetly call float versrion from here?
template <>
GKO_ATTRIBUTES GKO_INLINE __half abs<__half>(const complex<__half>& z)
{
return hypot(z.real(), z.imag());
}


} // namespace thrust


#define THRUST_HALF_FRIEND_OPERATOR(_op, _opeq) \
GKO_ATTRIBUTES GKO_INLINE thrust::complex<__half> operator _op( \
const thrust::complex<__half> lhs, const thrust::complex<__half> rhs) \
{ \
auto result = lhs; \
result _opeq rhs; \
return result; \
return thrust::complex<float>{lhs} + thrust::complex<float>(rhs); \
}

THRUST_HALF_FRIEND_OPERATOR(+, +=)
THRUST_HALF_FRIEND_OPERATOR(-, -=)
THRUST_HALF_FRIEND_OPERATOR(*, *=)
THRUST_HALF_FRIEND_OPERATOR(/, /=)


namespace gko {


Expand Down
42 changes: 41 additions & 1 deletion hip/base/types.hip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,46 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/base/matrix_data.hpp>


// thrust calls the c function not the function from std
// Maybe override the function from thrust directlry
GKO_ATTRIBUTES GKO_INLINE __half hypot(__half a, __half b)
{
return hypot(static_cast<float>(a), static_cast<float>(b));
}

GKO_ATTRIBUTES GKO_INLINE thrust::complex<__half> sqrt(
thrust::complex<__half> a)
{
return sqrt(static_cast<thrust::complex<float>>(a));
}


namespace thrust {


// Dircetly call float versrion from here?
template <>
GKO_ATTRIBUTES GKO_INLINE __half abs<__half>(const complex<__half>& z)
{
return hypot(static_cast<float>(z.real()), static_cast<float>(z.imag()));
}


} // namespace thrust

#define THRUST_HALF_FRIEND_OPERATOR(_op, _opeq) \
GKO_ATTRIBUTES GKO_INLINE thrust::complex<__half> operator _op( \
const thrust::complex<__half> lhs, const thrust::complex<__half> rhs) \
{ \
return thrust::complex<float>{lhs} + thrust::complex<float>(rhs); \
}

THRUST_HALF_FRIEND_OPERATOR(+, +=)
THRUST_HALF_FRIEND_OPERATOR(-, -=)
THRUST_HALF_FRIEND_OPERATOR(*, *=)
THRUST_HALF_FRIEND_OPERATOR(/, /=)


namespace gko {
#if defined(__CUDA_ARCH__)
#if __CUDA_ARCH__ >= 700
Expand Down Expand Up @@ -323,7 +363,7 @@ struct hip_struct_member_type_impl {

template <typename T>
struct hip_struct_member_type_impl<std::complex<T>> {
using type = fake_complex<T>;
using type = fake_complex<typename hip_struct_member_type_impl<T>::type>;
};

template <>
Expand Down

0 comments on commit a695ef0

Please sign in to comment.