Skip to content

Commit

Permalink
CUDA with CC < 7.0 and HIP do not support 16-bit atomics. Throw an error or …
Browse files Browse the repository at this point in the history
…fall back to a working version if that is the case, for matrix kernels
  • Loading branch information
yhmtsai committed Nov 30, 2024
1 parent 1c4a700 commit 56c0661
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 182 deletions.
48 changes: 0 additions & 48 deletions common/cuda_hip/components/atomic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,52 +96,6 @@ __forceinline__ __device__ ResultType reinterpret(ValueType val)
} \
};


#define GKO_BIND_ATOMIC_HELPER_FAKE_STRUCTURE(CONVERTER_TYPE) \
template <typename ValueType> \
struct atomic_helper< \
ValueType, \
std::enable_if_t<(sizeof(ValueType) == sizeof(CONVERTER_TYPE))>> { \
__forceinline__ __device__ static ValueType atomic_add( \
ValueType* __restrict__ addr, ValueType val) \
{ \
assert(false); \
using c_type = CONVERTER_TYPE; \
return atomic_wrapper( \
addr, [&val](c_type& old, c_type assumed, c_type* c_addr) { \
old = *c_addr; \
*c_addr = reinterpret<c_type>( \
val + reinterpret<ValueType>(assumed)); \
}); \
} \
__forceinline__ __device__ static ValueType atomic_max( \
ValueType* __restrict__ addr, ValueType val) \
{ \
assert(false); \
using c_type = CONVERTER_TYPE; \
return atomic_wrapper( \
addr, [&val](c_type& old, c_type assumed, c_type* c_addr) { \
if (reinterpret<ValueType>(assumed) < val) { \
old = *c_addr; \
*c_addr = reinterpret<c_type>(assumed); \
} \
}); \
} \
\
private: \
template <typename Callable> \
__forceinline__ __device__ static ValueType atomic_wrapper( \
ValueType* __restrict__ addr, Callable set_old) \
{ \
CONVERTER_TYPE* address_as_converter = \
reinterpret_cast<CONVERTER_TYPE*>(addr); \
CONVERTER_TYPE old = *address_as_converter; \
CONVERTER_TYPE assumed = old; \
set_old(old, assumed, address_as_converter); \
return reinterpret<ValueType>(old); \
} \
};

// Support 64-bit ATOMIC_ADD and ATOMIC_MAX
GKO_BIND_ATOMIC_HELPER_STRUCTURE(unsigned long long int);
// Support 32-bit ATOMIC_ADD and ATOMIC_MAX
Expand All @@ -152,8 +106,6 @@ GKO_BIND_ATOMIC_HELPER_STRUCTURE(unsigned int);
// Support 16-bit atomicCAS, atomicADD, and atomicMAX only on CUDA with CC
// >= 7.0
GKO_BIND_ATOMIC_HELPER_STRUCTURE(unsigned short int);
#else
GKO_BIND_ATOMIC_HELPER_FAKE_STRUCTURE(unsigned short int)
#endif


Expand Down
114 changes: 66 additions & 48 deletions common/cuda_hip/matrix/coo_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,30 +268,38 @@ void spmv2(std::shared_ptr<const DefaultExecutor> exec,
const dim3 coo_block(config::warp_size, warps_in_block, 1);
const auto nwarps = host_kernel::calculate_nwarps(exec, nnz);

if (nwarps > 0 && b_ncols > 0) {
// TODO: b_ncols needs to be tuned for ROCm.
if (b_ncols < 4) {
const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols);
int num_lines = ceildiv(nnz, nwarps * config::warp_size);

abstract_spmv<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
nnz, num_lines, as_device_type(a->get_const_values()),
a->get_const_col_idxs(),
as_device_type(a->get_const_row_idxs()),
as_device_type(b->get_const_values()), b->get_stride(),
as_device_type(c->get_values()), c->get_stride());
} else {
int num_elems =
ceildiv(nnz, nwarps * config::warp_size) * config::warp_size;
const dim3 coo_grid(ceildiv(nwarps, warps_in_block),
ceildiv(b_ncols, config::warp_size));

abstract_spmm<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
nnz, num_elems, as_device_type(a->get_const_values()),
a->get_const_col_idxs(),
as_device_type(a->get_const_row_idxs()), b_ncols,
as_device_type(b->get_const_values()), b->get_stride(),
as_device_type(c->get_values()), c->get_stride());
// not support 16 bit atomic
#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
if constexpr (std::is_same_v<remove_complex<ValueType>, gko::half>) {
GKO_NOT_SUPPORTED(c);
} else
#endif
{
if (nwarps > 0 && b_ncols > 0) {
// TODO: b_ncols needs to be tuned for ROCm.
if (b_ncols < 4) {
const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols);
int num_lines = ceildiv(nnz, nwarps * config::warp_size);

abstract_spmv<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
nnz, num_lines, as_device_type(a->get_const_values()),
a->get_const_col_idxs(),
as_device_type(a->get_const_row_idxs()),
as_device_type(b->get_const_values()), b->get_stride(),
as_device_type(c->get_values()), c->get_stride());
} else {
int num_elems = ceildiv(nnz, nwarps * config::warp_size) *
config::warp_size;
const dim3 coo_grid(ceildiv(nwarps, warps_in_block),
ceildiv(b_ncols, config::warp_size));

abstract_spmm<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
nnz, num_elems, as_device_type(a->get_const_values()),
a->get_const_col_idxs(),
as_device_type(a->get_const_row_idxs()), b_ncols,
as_device_type(b->get_const_values()), b->get_stride(),
as_device_type(c->get_values()), c->get_stride());
}
}
}
}
Expand All @@ -312,30 +320,40 @@ void advanced_spmv2(std::shared_ptr<const DefaultExecutor> exec,
const dim3 coo_block(config::warp_size, warps_in_block, 1);
const auto b_ncols = b->get_size()[1];

if (nwarps > 0 && b_ncols > 0) {
// TODO: b_ncols needs to be tuned for ROCm.
if (b_ncols < 4) {
int num_lines = ceildiv(nnz, nwarps * config::warp_size);
const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols);

abstract_spmv<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
nnz, num_lines, as_device_type(alpha->get_const_values()),
as_device_type(a->get_const_values()), a->get_const_col_idxs(),
as_device_type(a->get_const_row_idxs()),
as_device_type(b->get_const_values()), b->get_stride(),
as_device_type(c->get_values()), c->get_stride());
} else {
int num_elems =
ceildiv(nnz, nwarps * config::warp_size) * config::warp_size;
const dim3 coo_grid(ceildiv(nwarps, warps_in_block),
ceildiv(b_ncols, config::warp_size));

abstract_spmm<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
nnz, num_elems, as_device_type(alpha->get_const_values()),
as_device_type(a->get_const_values()), a->get_const_col_idxs(),
as_device_type(a->get_const_row_idxs()), b_ncols,
as_device_type(b->get_const_values()), b->get_stride(),
as_device_type(c->get_values()), c->get_stride());
// not support 16 bit atomic
#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
if constexpr (std::is_same_v<remove_complex<ValueType>, gko::half>) {
GKO_NOT_SUPPORTED(c);
} else
#endif
{
if (nwarps > 0 && b_ncols > 0) {
// TODO: b_ncols needs to be tuned for ROCm.
if (b_ncols < 4) {
int num_lines = ceildiv(nnz, nwarps * config::warp_size);
const dim3 coo_grid(ceildiv(nwarps, warps_in_block), b_ncols);

abstract_spmv<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
nnz, num_lines, as_device_type(alpha->get_const_values()),
as_device_type(a->get_const_values()),
a->get_const_col_idxs(),
as_device_type(a->get_const_row_idxs()),
as_device_type(b->get_const_values()), b->get_stride(),
as_device_type(c->get_values()), c->get_stride());
} else {
int num_elems = ceildiv(nnz, nwarps * config::warp_size) *
config::warp_size;
const dim3 coo_grid(ceildiv(nwarps, warps_in_block),
ceildiv(b_ncols, config::warp_size));

abstract_spmm<<<coo_grid, coo_block, 0, exec->get_stream()>>>(
nnz, num_elems, as_device_type(alpha->get_const_values()),
as_device_type(a->get_const_values()),
a->get_const_col_idxs(),
as_device_type(a->get_const_row_idxs()), b_ncols,
as_device_type(b->get_const_values()), b->get_stride(),
as_device_type(c->get_values()), c->get_stride());
}
}
}
}
Expand Down
97 changes: 55 additions & 42 deletions common/cuda_hip/matrix/csr_kernels.template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2064,7 +2064,7 @@ GKO_ENABLE_IMPLEMENTATION_SELECTION(select_classical_spmv, classical_spmv);

template <typename MatrixValueType, typename InputValueType,
typename OutputValueType, typename IndexType>
void load_balance_spmv(std::shared_ptr<const DefaultExecutor> exec,
bool load_balance_spmv(std::shared_ptr<const DefaultExecutor> exec,
const matrix::Csr<MatrixValueType, IndexType>* a,
const matrix::Dense<InputValueType>* b,
matrix::Dense<OutputValueType>* c,
Expand All @@ -2074,42 +2074,54 @@ void load_balance_spmv(std::shared_ptr<const DefaultExecutor> exec,
using arithmetic_type =
highest_precision<InputValueType, OutputValueType, MatrixValueType>;

if (beta) {
dense::scale(exec, beta, c);
} else {
dense::fill(exec, c, zero<OutputValueType>());
}
const IndexType nwarps = a->get_num_srow_elements();
if (nwarps > 0) {
const dim3 csr_block(config::warp_size, warps_in_block, 1);
const dim3 csr_grid(ceildiv(nwarps, warps_in_block), b->get_size()[1]);
const auto a_vals =
acc::helper::build_const_rrm_accessor<arithmetic_type>(a);
const auto b_vals =
acc::helper::build_const_rrm_accessor<arithmetic_type>(b);
auto c_vals = acc::helper::build_rrm_accessor<arithmetic_type>(c);
if (alpha) {
if (csr_grid.x > 0 && csr_grid.y > 0) {
kernel::abstract_spmv<<<csr_grid, csr_block, 0,
exec->get_stream()>>>(
nwarps, static_cast<IndexType>(a->get_size()[0]),
as_device_type(alpha->get_const_values()),
acc::as_device_range(a_vals), a->get_const_col_idxs(),
as_device_type(a->get_const_row_ptrs()),
as_device_type(a->get_const_srow()),
acc::as_device_range(b_vals), acc::as_device_range(c_vals));
}
// not support 16 bit atomic
#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
if constexpr (std::is_same_v<remove_complex<OutputValueType>, half>) {
return false;
} else
#endif
{
if (beta) {
dense::scale(exec, beta, c);
} else {
if (csr_grid.x > 0 && csr_grid.y > 0) {
kernel::abstract_spmv<<<csr_grid, csr_block, 0,
exec->get_stream()>>>(
nwarps, static_cast<IndexType>(a->get_size()[0]),
acc::as_device_range(a_vals), a->get_const_col_idxs(),
as_device_type(a->get_const_row_ptrs()),
as_device_type(a->get_const_srow()),
acc::as_device_range(b_vals), acc::as_device_range(c_vals));
dense::fill(exec, c, zero<OutputValueType>());
}
const IndexType nwarps = a->get_num_srow_elements();
if (nwarps > 0) {
const dim3 csr_block(config::warp_size, warps_in_block, 1);
const dim3 csr_grid(ceildiv(nwarps, warps_in_block),
b->get_size()[1]);
const auto a_vals =
acc::helper::build_const_rrm_accessor<arithmetic_type>(a);
const auto b_vals =
acc::helper::build_const_rrm_accessor<arithmetic_type>(b);
auto c_vals = acc::helper::build_rrm_accessor<arithmetic_type>(c);
if (alpha) {
if (csr_grid.x > 0 && csr_grid.y > 0) {
kernel::abstract_spmv<<<csr_grid, csr_block, 0,
exec->get_stream()>>>(
nwarps, static_cast<IndexType>(a->get_size()[0]),
as_device_type(alpha->get_const_values()),
acc::as_device_range(a_vals), a->get_const_col_idxs(),
as_device_type(a->get_const_row_ptrs()),
as_device_type(a->get_const_srow()),
acc::as_device_range(b_vals),
acc::as_device_range(c_vals));
}
} else {
if (csr_grid.x > 0 && csr_grid.y > 0) {
kernel::abstract_spmv<<<csr_grid, csr_block, 0,
exec->get_stream()>>>(
nwarps, static_cast<IndexType>(a->get_size()[0]),
acc::as_device_range(a_vals), a->get_const_col_idxs(),
as_device_type(a->get_const_row_ptrs()),
as_device_type(a->get_const_srow()),
acc::as_device_range(b_vals),
acc::as_device_range(c_vals));
}
}
}
return true;
}
}

Expand Down Expand Up @@ -2257,8 +2269,6 @@ void spmv(std::shared_ptr<const DefaultExecutor> exec,
{
if (c->get_size()[0] == 0 || c->get_size()[1] == 0) {
// empty output: nothing to do
} else if (a->get_strategy()->get_name() == "load_balance") {
host_kernel::load_balance_spmv(exec, a, b, c);
} else if (a->get_strategy()->get_name() == "merge_path") {
using arithmetic_type =
highest_precision<InputValueType, OutputValueType, MatrixValueType>;
Expand All @@ -2273,8 +2283,10 @@ void spmv(std::shared_ptr<const DefaultExecutor> exec,
syn::value_list<int>(), syn::type_list<>(), exec, a, b, c);
} else {
bool use_classical = true;
if (a->get_strategy()->get_name() == "sparselib" ||
a->get_strategy()->get_name() == "cusparse") {
if (a->get_strategy()->get_name() == "load_balance") {
use_classical = !host_kernel::load_balance_spmv(exec, a, b, c);
} else if (a->get_strategy()->get_name() == "sparselib" ||
a->get_strategy()->get_name() == "cusparse") {
use_classical = !host_kernel::try_sparselib_spmv(exec, a, b, c);
}
if (use_classical) {
Expand Down Expand Up @@ -2316,8 +2328,6 @@ void advanced_spmv(std::shared_ptr<const DefaultExecutor> exec,
{
if (c->get_size()[0] == 0 || c->get_size()[1] == 0) {
// empty output: nothing to do
} else if (a->get_strategy()->get_name() == "load_balance") {
host_kernel::load_balance_spmv(exec, a, b, c, alpha, beta);
} else if (a->get_strategy()->get_name() == "merge_path") {
using arithmetic_type =
highest_precision<InputValueType, OutputValueType, MatrixValueType>;
Expand All @@ -2333,8 +2343,11 @@ void advanced_spmv(std::shared_ptr<const DefaultExecutor> exec,
beta);
} else {
bool use_classical = true;
if (a->get_strategy()->get_name() == "sparselib" ||
a->get_strategy()->get_name() == "cusparse") {
if (a->get_strategy()->get_name() == "load_balance") {
use_classical =
!host_kernel::load_balance_spmv(exec, a, b, c, alpha, beta);
} else if (a->get_strategy()->get_name() == "sparselib" ||
a->get_strategy()->get_name() == "cusparse") {
use_classical =
!host_kernel::try_sparselib_spmv(exec, a, b, c, alpha, beta);
}
Expand Down
Loading

0 comments on commit 56c0661

Please sign in to comment.