diff --git a/hip/components/cooperative_groups.hip.hpp b/hip/components/cooperative_groups.hip.hpp
index 36618bb7f3e..dce69421a31 100644
--- a/hip/components/cooperative_groups.hip.hpp
+++ b/hip/components/cooperative_groups.hip.hpp
@@ -319,6 +319,33 @@ class enable_extended_shuffle : public Group {
 #undef GKO_ENABLE_SHUFFLE_OPERATION
 
+// HIP does not support 16-bit shuffles directly, so the __half overloads
+// below copy the value into a 32-bit word, shuffle that word through the
+// underlying Group, and copy the received word back into a __half.
+//
+// _name:        shuffle member to define (shfl, shfl_up, shfl_down, shfl_xor)
+// SelectorType: type of the selector argument forwarded to Group::_name
+#define GKO_ENABLE_SHUFFLE_OPERATION_HALF(_name, SelectorType)           \
+    __device__ __forceinline__ __half _name(const __half& var,           \
+                                            SelectorType selector) const \
+    {                                                                    \
+        /* zero-init: memcpy fills only sizeof(__half) bytes of u */     \
+        uint32 u = 0;                                                    \
+        memcpy(&u, &var, sizeof(__half));                                \
+        u = static_cast<const Group*>(this)->_name(u, selector);         \
+        __half result;                                                   \
+        memcpy(&result, &u, sizeof(__half));                             \
+        return result;                                                   \
+    }
+
+    GKO_ENABLE_SHUFFLE_OPERATION_HALF(shfl, int32)
+    GKO_ENABLE_SHUFFLE_OPERATION_HALF(shfl_up, uint32)
+    GKO_ENABLE_SHUFFLE_OPERATION_HALF(shfl_down, uint32)
+    GKO_ENABLE_SHUFFLE_OPERATION_HALF(shfl_xor, int32)
+
+#undef GKO_ENABLE_SHUFFLE_OPERATION_HALF
+
+
 private:
     template