diff --git a/cpp/include/raft/core/detail/copy.hpp b/cpp/include/raft/core/detail/copy.hpp index 04e74c4e58..4faded5041 100644 --- a/cpp/include/raft/core/detail/copy.hpp +++ b/cpp/include/raft/core/detail/copy.hpp @@ -32,6 +32,7 @@ #include #include #ifdef __CUDACC__ +#include #include #endif #endif @@ -449,38 +450,51 @@ mdspan_copyable_t copy(resources const& res, DstType&& dst, Sr #endif } else if constexpr (config::can_use_cublas) { #ifndef RAFT_DISABLE_CUDA - auto constexpr const alpha = typename std::remove_reference_t::value_type{1}; - auto constexpr const beta = typename std::remove_reference_t::value_type{0}; - if constexpr (std::is_same_v) { - CUBLAS_TRY(linalg::detail::cublasgeam(resource::get_cublas_handle(res), - CUBLAS_OP_T, - CUBLAS_OP_N, - dst.extent(1), - dst.extent(0), - &alpha, - src.data_handle(), - src.extent(0), - &beta, - dst.data_handle(), - dst.extent(1), - dst.data_handle(), - dst.extent(1), - resource::get_cuda_stream(res))); + if constexpr (!((std::is_same_v::value_type, half>)&&( + std::is_same_v::value_type, half>))) { + auto constexpr const alpha = typename std::remove_reference_t::value_type{1}; + auto constexpr const beta = typename std::remove_reference_t::value_type{0}; + if constexpr (std::is_same_v) { + CUBLAS_TRY(linalg::detail::cublasgeam(resource::get_cublas_handle(res), + CUBLAS_OP_T, + CUBLAS_OP_N, + dst.extent(1), + dst.extent(0), + &alpha, + src.data_handle(), + src.extent(0), + &beta, + dst.data_handle(), + dst.extent(1), + dst.data_handle(), + dst.extent(1), + resource::get_cuda_stream(res))); + } else { + CUBLAS_TRY(linalg::detail::cublasgeam(resource::get_cublas_handle(res), + CUBLAS_OP_T, + CUBLAS_OP_N, + dst.extent(0), + dst.extent(1), + &alpha, + src.data_handle(), + src.extent(1), + &beta, + dst.data_handle(), + dst.extent(0), + dst.data_handle(), + dst.extent(0), + resource::get_cuda_stream(res))); + } } else { - CUBLAS_TRY(linalg::detail::cublasgeam(resource::get_cublas_handle(res), - CUBLAS_OP_T, - CUBLAS_OP_N, - dst.extent(0), - dst.extent(1), - &alpha, - src.data_handle(), - src.extent(1), - &beta, - dst.data_handle(), - dst.extent(0), - dst.data_handle(), - dst.extent(0), - resource::get_cuda_stream(res))); +#ifdef __CUDACC__ + raft::linalg::transpose(res, dst, src); +#else + // Should never actually reach this because of enable_ifs. Included for + // safety. + RAFT_FAIL( + "raft::copy called in a way that requires custom kernel. Please use " + "raft/core/copy.cuh and include the header in a .cu file"); +#endif } #else // Not possible to reach this due to enable_ifs. Included for safety. diff --git a/cpp/include/raft/core/math.hpp b/cpp/include/raft/core/math.hpp index e082aaf41a..c5de345082 100644 --- a/cpp/include/raft/core/math.hpp +++ b/cpp/include/raft/core/math.hpp @@ -106,7 +106,13 @@ template RAFT_INLINE_FUNCTION auto asin(T x) { #ifdef __CUDA_ARCH__ - return ::asin(x); + if constexpr (std::is_same::value) { + float x_float = __half2float(x); + float result_float = ::asin(x_float); + return __float2half(result_float); + } else { + return ::asin(x); + } #else return std::asin(x); #endif @@ -337,6 +343,12 @@ RAFT_INLINE_FUNCTION auto max(const T1& x, const T2& y) ((std::is_same_v || std::is_same_v)&&( std::is_same_v || std::is_same_v))) { return ::max(x, y); + } else if constexpr (std::is_same_v && std::is_same_v) { + const float f_y = __half2float(y); + return (x < f_y) ? f_y : x; + } else if constexpr (std::is_same_v && std::is_same_v) { + const float f_x = __half2float(x); + return (f_x < y) ? 
y : f_x; } // Else, check that the types are the same and provide a generic implementation else { diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp index e42801fe32..6b10baa332 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -19,6 +19,8 @@ #include #include +#include + #include #include #include @@ -104,13 +106,27 @@ struct sq_op { { return in * in; } + + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const half& in, UnusedArgs...) const + { + return __half2float(in) * __half2float(in); + } }; struct add_op { template constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const { - return a + b; + if constexpr (std::is_same_v && std::is_same_v) { + return __half2float(a) + __half2float(b); + } else if constexpr (std::is_same_v) { + return __half2float(a) + b; + } else if constexpr (std::is_same_v) { + return a + __half2float(b); + } else { + return a + b; + } } }; diff --git a/cpp/include/raft/distance/detail/masked_distance_base.cuh b/cpp/include/raft/distance/detail/masked_distance_base.cuh index 55da634145..96b778f11f 100644 --- a/cpp/include/raft/distance/detail/masked_distance_base.cuh +++ b/cpp/include/raft/distance/detail/masked_distance_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -266,7 +266,7 @@ struct MaskedDistances : public BaseClass { for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero; + acc[i][j] = BaseClass::Zero(); } } } diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index c6b09be31e..a8a541bf53 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -200,7 +200,7 @@ struct PairwiseDistances : public BaseClass { for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero; + acc[i][j] = BaseClass::Zero(); } } } diff --git a/cpp/include/raft/linalg/contractions.cuh b/cpp/include/raft/linalg/contractions.cuh index cb6488bedf..b284bb3370 100644 --- a/cpp/include/raft/linalg/contractions.cuh +++ b/cpp/include/raft/linalg/contractions.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -164,6 +164,12 @@ struct Policy4x4 { typedef KernelPolicy Policy; typedef ColKernelPolicy ColPolicy; }; + +template +struct Policy4x4 { + typedef KernelPolicy Policy; + typedef ColKernelPolicy ColPolicy; +}; /** @} */ /** @@ -204,6 +210,12 @@ struct Policy2x8 { // this is not used just for keeping compiler happy. 
typedef KernelPolicy Policy; }; + +template +struct Policy2x8 { + typedef KernelPolicy Policy; +}; + /** @} */ /** diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index b15cb222b4..3bdcc22c1f 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -72,7 +72,9 @@ struct Contractions_NT { /** block of Y data loaded from global mem after `ldgXY()` */ DataT ldgDataY[P::LdgPerThY][P::Veclen]; - static constexpr DataT Zero = (DataT)0; + // static constexpr DataT Zero = DataT{0}; + + static constexpr DataT Zero() { return DataT{0}; } public: /** @@ -197,7 +199,7 @@ struct Contractions_NT { } else { #pragma unroll for (int j = 0; j < P::Veclen; ++j) { - ldgDataX[i][j] = Zero; + ldgDataX[i][j] = Zero(); } } } @@ -211,7 +213,7 @@ struct Contractions_NT { } else { #pragma unroll for (int j = 0; j < P::Veclen; ++j) { - ldgDataX[i][j] = Zero; + ldgDataX[i][j] = Zero(); } } } @@ -235,7 +237,7 @@ struct Contractions_NT { } else { #pragma unroll for (int j = 0; j < P::Veclen; ++j) { - ldgDataY[i][j] = Zero; + ldgDataY[i][j] = Zero(); } } } @@ -249,7 +251,7 @@ struct Contractions_NT { } else { #pragma unroll for (int j = 0; j < P::Veclen; ++j) { - ldgDataY[i][j] = Zero; + ldgDataY[i][j] = Zero(); } } } diff --git a/cpp/include/raft/linalg/detail/gemm.hpp b/cpp/include/raft/linalg/detail/gemm.hpp index 236c840040..af6a78638c 100644 --- a/cpp/include/raft/linalg/detail/gemm.hpp +++ b/cpp/include/raft/linalg/detail/gemm.hpp @@ -27,82 +27,83 @@ namespace raft::linalg::detail { -template +template void legacy_gemm(raft::resources const& res, const bool trans_a, const bool trans_b, const int m, const int n, const int k, - const T* alpha, - const T* A, + const S_T* alpha, + const A_T* A, const int lda, - const T* B, + const B_T* B, const int ldb, - const T* beta, - T* C, + const S_T* beta, + C_T* C, const int ldc, cudaStream_t stream) { - return legacy_matmul(res, - trans_a, - trans_b, - static_cast(m), - static_cast(n), - static_cast(k), - alpha, - A, - static_cast(lda), - B, - static_cast(ldb), - beta, - C, - static_cast(ldc), - stream); + return legacy_matmul(res, + trans_a, + trans_b, + static_cast(m), + static_cast(n), + static_cast(k), + alpha, + A, + static_cast(lda), + B, + static_cast(ldb), + beta, + C, + static_cast(ldc), + stream); } -template +template void legacy_gemm(raft::resources const& res, - const T* a, + const A_T* a, int n_rows_a, int n_cols_a, - const T* b, - T* c, + const B_T* b, + C_T* c, int n_rows_c, int n_cols_c, cublasOperation_t trans_a, cublasOperation_t trans_b, - T alpha, - T beta, + S_T alpha, + S_T beta, cudaStream_t stream) { int m = n_rows_c; int n = n_cols_c; auto k = trans_a == CUBLAS_OP_T ? n_rows_a : n_cols_a; - return legacy_matmul(res, - trans_a == CUBLAS_OP_T, - trans_b == CUBLAS_OP_T, - static_cast(n_rows_c), - static_cast(n_cols_c), - static_cast(k), - &alpha, - a, - static_cast(trans_a == CUBLAS_OP_T ? k : m), - b, - static_cast(trans_b == CUBLAS_OP_T ? 
n : k), - &beta, - c, - static_cast(m), - stream); + return legacy_matmul( + res, + trans_a == CUBLAS_OP_T, + trans_b == CUBLAS_OP_T, + static_cast(n_rows_c), + static_cast(n_cols_c), + static_cast(k), + &alpha, + a, + static_cast(trans_a == CUBLAS_OP_T ? k : m), + b, + static_cast(trans_b == CUBLAS_OP_T ? n : k), + &beta, + c, + static_cast(m), + stream); } -template +template void legacy_gemm(raft::resources const& res, - const T* a, + const A_T* a, int n_rows_a, int n_cols_a, - const T* b, - T* c, + const B_T* b, + C_T* c, int n_rows_c, int n_cols_c, cublasOperation_t trans_a, @@ -110,14 +111,14 @@ void legacy_gemm(raft::resources const& res, cudaStream_t stream) { return legacy_gemm( - res, a, n_rows_a, n_cols_a, b, c, n_rows_c, n_cols_c, trans_a, trans_b, T{1}, T{0}, stream); + res, a, n_rows_a, n_cols_a, b, c, n_rows_c, n_cols_c, trans_a, trans_b, C_T{1}, C_T{0}, stream); } -template +template void legacy_gemm(raft::resources const& res, - T* z, - T* x, - T* y, + z_T* z, + x_T* x, + y_T* y, int _M, int _N, int _K, @@ -125,11 +126,11 @@ void legacy_gemm(raft::resources const& res, bool isXColMajor, bool isYColMajor, cudaStream_t stream, - const T* alpha, - const T* beta) + const s_T* alpha, + const s_T* beta) { if (isZColMajor) { - return legacy_matmul( + return legacy_matmul( res, !isXColMajor, !isYColMajor, @@ -146,7 +147,7 @@ void legacy_gemm(raft::resources const& res, static_cast(_M), stream); } else { - return legacy_gemm( + return legacy_gemm( res, z, y, x, _N, _M, _K, true, !isYColMajor, !isXColMajor, stream, alpha, beta); } } diff --git a/cpp/include/raft/linalg/detail/norm.cuh b/cpp/include/raft/linalg/detail/norm.cuh index ed7e360848..24da634575 100644 --- a/cpp/include/raft/linalg/detail/norm.cuh +++ b/cpp/include/raft/linalg/detail/norm.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
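For reference, a minimal usage sketch of the re-templated gemm wrappers from the previous file (a hypothetical call, not part of this diff; the function and parameter names here are illustrative, and whether a given A/B/C type combination is actually accepted still depends on the cuBLAS dispatch underneath legacy_matmul):

#include <cuda_fp16.h>
#include <raft/core/resources.hpp>
#include <raft/linalg/gemm.cuh>

// Sketch: multiply two half matrices and write a float result through the new
// A_t/B_t/C_t/S_t template parameters (all matrices column-major, no transpose).
inline void gemm_half_in_float_out(raft::resources const& res,
                                   const half* A,  // m x k
                                   const half* B,  // k x n
                                   float* C,       // m x n
                                   int m, int n, int k,
                                   cudaStream_t stream)
{
  const float alpha = 1.0f;
  const float beta  = 0.0f;
  // Leading dimensions follow column-major storage: lda = m, ldb = k, ldc = m.
  raft::linalg::gemm(res, false, false, m, n, k, &alpha, A, m, B, k, &beta, C, m, stream);
}
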
@@ -24,8 +24,8 @@ namespace raft { namespace linalg { namespace detail { -template -void rowNormCaller(Type* dots, +template +void rowNormCaller(OutType* dots, const Type* data, IdxType D, IdxType N, @@ -36,53 +36,53 @@ void rowNormCaller(Type* dots, { switch (type) { case L1Norm: - raft::linalg::reduce(dots, - data, - D, - N, - (Type)0, - rowMajor, - true, - stream, - false, - raft::abs_op(), - raft::add_op(), - fin_op); + raft::linalg::reduce(dots, + data, + D, + N, + (OutType)0, + rowMajor, + true, + stream, + false, + raft::abs_op(), + raft::add_op(), + fin_op); break; case L2Norm: - raft::linalg::reduce(dots, - data, - D, - N, - (Type)0, - rowMajor, - true, - stream, - false, - raft::sq_op(), - raft::add_op(), - fin_op); + raft::linalg::reduce(dots, + data, + D, + N, + (OutType)0, + rowMajor, + true, + stream, + false, + raft::sq_op(), + raft::add_op(), + fin_op); break; case LinfNorm: - raft::linalg::reduce(dots, - data, - D, - N, - (Type)0, - rowMajor, - true, - stream, - false, - raft::abs_op(), - raft::max_op(), - fin_op); + raft::linalg::reduce(dots, + data, + D, + N, + (OutType)0, + rowMajor, + true, + stream, + false, + raft::abs_op(), + raft::max_op(), + fin_op); break; default: THROW("Unsupported norm type: %d", type); }; } -template -void colNormCaller(Type* dots, +template +void colNormCaller(OutType* dots, const Type* data, IdxType D, IdxType N, @@ -93,46 +93,46 @@ void colNormCaller(Type* dots, { switch (type) { case L1Norm: - raft::linalg::reduce(dots, - data, - D, - N, - (Type)0, - rowMajor, - false, - stream, - false, - raft::abs_op(), - raft::add_op(), - fin_op); + raft::linalg::reduce(dots, + data, + D, + N, + (OutType)0, + rowMajor, + false, + stream, + false, + raft::abs_op(), + raft::add_op(), + fin_op); break; case L2Norm: - raft::linalg::reduce(dots, - data, - D, - N, - (Type)0, - rowMajor, - false, - stream, - false, - raft::sq_op(), - raft::add_op(), - fin_op); + raft::linalg::reduce(dots, + data, + D, + N, + (OutType)0, + rowMajor, + false, + stream, + false, + raft::sq_op(), + raft::add_op(), + fin_op); break; case LinfNorm: - raft::linalg::reduce(dots, - data, - D, - N, - (Type)0, - rowMajor, - false, - stream, - false, - raft::abs_op(), - raft::max_op(), - fin_op); + raft::linalg::reduce(dots, + data, + D, + N, + (OutType)0, + rowMajor, + false, + stream, + false, + raft::abs_op(), + raft::max_op(), + fin_op); break; default: THROW("Unsupported norm type: %d", type); }; diff --git a/cpp/include/raft/linalg/detail/transpose.cuh b/cpp/include/raft/linalg/detail/transpose.cuh index 999e7f1974..ec60aacc9c 100644 --- a/cpp/include/raft/linalg/detail/transpose.cuh +++ b/cpp/include/raft/linalg/detail/transpose.cuh @@ -28,10 +28,84 @@ #include #include +#include + namespace raft { namespace linalg { namespace detail { +template +RAFT_KERNEL transpose_half_kernel(IndexType n_rows, + IndexType n_cols, + const half* __restrict__ in, + half* __restrict__ out) +{ + __shared__ half tile[TILE_DIM][TILE_DIM + 1]; + + for (int block_offset_y = 0; block_offset_y < n_rows; block_offset_y += gridDim.y * TILE_DIM) { + for (int block_offset_x = 0; block_offset_x < n_cols; block_offset_x += gridDim.x * TILE_DIM) { + auto x = block_offset_x + blockIdx.x * TILE_DIM + threadIdx.x; + auto y = block_offset_y + blockIdx.y * TILE_DIM + threadIdx.y; + + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { + if (x < n_cols && (y + j) < n_rows) { + tile[threadIdx.y + j][threadIdx.x] = __ldg(&in[(y + j) * n_cols + x]); + } + } + __syncthreads(); + + x = block_offset_y + blockIdx.y * TILE_DIM 
+ threadIdx.x; + y = block_offset_x + blockIdx.x * TILE_DIM + threadIdx.y; + + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { + if (x < n_rows && (y + j) < n_cols) { + out[(y + j) * n_rows + x] = tile[threadIdx.x][threadIdx.y + j]; + } + } + __syncthreads(); + } + } +} + +template +void transpose_half( + raft::resources const& handle, IndexType n_rows, IndexType n_cols, const half* in, half* out) +{ + if (n_cols == 0 || n_rows == 0) return; + auto stream = resource::get_cuda_stream(handle); + + int dev_id, sm_count; + + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + constexpr int tpb = 256; + constexpr int block_dim_x = 128 / sizeof(half); + constexpr int block_dim_y = tpb / block_dim_x; + + dim3 blocks(block_dim_x, block_dim_y); + + int max_active_blocks = 0; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, transpose_half_kernel, tpb, 0)); + int num_blocks = max_active_blocks * sm_count; + + int grid_x = (n_cols + block_dim_x - 1) / block_dim_x; + int grid_y = (n_rows + block_dim_x - 1) / block_dim_x; + + float ratio = static_cast(grid_y) / static_cast(grid_x); + int adjusted_grid_y = + std::max(std::min(grid_y, static_cast(std::sqrt(num_blocks * ratio))), 1); + int adjusted_grid_x = std::max(std::min(grid_x, num_blocks / adjusted_grid_y), 1); + + dim3 grids(adjusted_grid_x, adjusted_grid_y); + + transpose_half_kernel + <<>>(n_rows, n_cols, in, out); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + template void transpose(raft::resources const& handle, math_t* in, @@ -40,28 +114,31 @@ void transpose(raft::resources const& handle, int n_cols, cudaStream_t stream) { - cublasHandle_t cublas_h = resource::get_cublas_handle(handle); - RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream)); - int out_n_rows = n_cols; int out_n_cols = n_rows; - const math_t alpha = 1.0; - const math_t beta = 0.0; - RAFT_CUBLAS_TRY(cublasgeam(cublas_h, - CUBLAS_OP_T, - CUBLAS_OP_N, - out_n_rows, - out_n_cols, - &alpha, - in, - n_rows, - &beta, - out, - out_n_rows, - out, - out_n_rows, - stream)); + if constexpr (std::is_same_v) { + transpose_half(handle, out_n_rows, out_n_cols, in, out); + } else { + cublasHandle_t cublas_h = resource::get_cublas_handle(handle); + RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream)); + const math_t alpha = 1.0; + const math_t beta = 0.0; + RAFT_CUBLAS_TRY(cublasgeam(cublas_h, + CUBLAS_OP_T, + CUBLAS_OP_N, + out_n_rows, + out_n_cols, + &alpha, + in, + n_rows, + &beta, + out, + out_n_rows, + out, + out_n_rows, + stream)); + } } template @@ -112,6 +189,17 @@ void transpose_row_major_impl( resource::get_cuda_stream(handle))); } +template +void transpose_row_major_impl( + raft::resources const& handle, + raft::mdspan, LayoutPolicy, AccessorPolicy> in, + raft::mdspan, LayoutPolicy, AccessorPolicy> out) +{ + auto out_n_rows = in.extent(1); + auto out_n_cols = in.extent(0); + transpose_half(handle, out_n_cols, out_n_rows, in.data_handle(), out.data_handle()); +} + template void transpose_col_major_impl( raft::resources const& handle, @@ -138,6 +226,18 @@ void transpose_col_major_impl( out.stride(1), resource::get_cuda_stream(handle))); } + +template +void transpose_col_major_impl( + raft::resources const& handle, + raft::mdspan, LayoutPolicy, AccessorPolicy> in, + raft::mdspan, LayoutPolicy, AccessorPolicy> out) +{ + auto out_n_rows = in.extent(1); + auto out_n_cols = in.extent(0); + transpose_half(handle, out_n_rows, out_n_cols, in.data_handle(), out.data_handle()); +} + }; // end namespace detail }; 
// end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/gemm.cuh b/cpp/include/raft/linalg/gemm.cuh index 7b8d35706b..5444d0c861 100644 --- a/cpp/include/raft/linalg/gemm.cuh +++ b/cpp/include/raft/linalg/gemm.cuh @@ -41,7 +41,10 @@ namespace raft::linalg { * @brief the wrapper of cublas gemm function * It computes the following equation: C = alpha .* opA(A) * opB(B) + beta .* C * - * @tparam math_t the element type + * @tparam A_t the element type of A + * @tparam B_t the element type of B + * @tparam C_t the element type of C + * @tparam S_t the element type of alpha and beta * @tparam DevicePointerMode whether pointers alpha, beta point to device memory * @param [in] handle raft handle * @param [in] trans_a cublas transpose op for A @@ -59,20 +62,20 @@ namespace raft::linalg { * @param [in] ldc leading dimension of C * @param [in] stream */ -template +template void gemm(raft::resources const& handle, const bool trans_a, const bool trans_b, const int m, const int n, const int k, - const math_t* alpha, - const math_t* A, + const S_t* alpha, + const A_t* A, const int lda, - const math_t* B, + const B_t* B, const int ldb, - const math_t* beta, - math_t* C, + const S_t* beta, + C_t* C, const int ldc, cudaStream_t stream) { @@ -83,7 +86,10 @@ void gemm(raft::resources const& handle, /** * @brief the wrapper of cublas gemm function * It computes the following equation: D = alpha . opA(A) * opB(B) + beta . C - * @tparam math_t the type of input/output matrices + * @tparam A_t the element type of A + * @tparam B_t the element type of B + * @tparam C_t the element type of C + * @tparam S_t the element type of alpha and beta * @param handle raft handle * @param a input matrix * @param n_rows_a number of rows of A @@ -98,19 +104,19 @@ void gemm(raft::resources const& handle, * @param beta scalar * @param stream cuda stream */ -template +template void gemm(raft::resources const& handle, - const math_t* a, + const A_t* a, int n_rows_a, int n_cols_a, - const math_t* b, - math_t* c, + const B_t* b, + C_t* c, int n_rows_c, int n_cols_c, cublasOperation_t trans_a, cublasOperation_t trans_b, - math_t alpha, - math_t beta, + S_t alpha, + S_t beta, cudaStream_t stream) { detail::legacy_gemm( @@ -120,7 +126,9 @@ void gemm(raft::resources const& handle, /** * @brief the wrapper of cublas gemm function * It computes the following equation: D = alpha . opA(A) * opB(B) + beta . C - * @tparam math_t the type of input/output matrices + * @tparam A_t the element type of A + * @tparam B_t the element type of B + * @tparam C_t the element type of C * @param handle raft handle * @param a input matrix * @param n_rows_a number of rows of A @@ -133,13 +141,13 @@ void gemm(raft::resources const& handle, * @param trans_b cublas transpose op for B * @param stream cuda stream */ -template +template void gemm(raft::resources const& handle, - const math_t* a, + const A_t* a, int n_rows_a, int n_cols_a, - const math_t* b, - math_t* c, + const B_t* b, + C_t* c, int n_rows_c, int n_cols_c, cublasOperation_t trans_a, @@ -154,7 +162,10 @@ void gemm(raft::resources const& handle, * @brief A wrapper for CUBLS GEMM function designed for handling all possible * combinations of operand layouts. * It computes the following equation: Z = alpha . X * Y + beta . 
Z - * @tparam T Data type of input/output matrices (float/double) + * @tparam z_T the element type of z + * @tparam x_T the element type of x + * @tparam y_T the element type of y + * @tparam s_T the element type of alpha and beta, equal to z_T by default * @param handle raft handle * @param z output matrix of size M rows x N columns * @param x input matrix of size M rows x K columns @@ -169,11 +180,11 @@ void gemm(raft::resources const& handle, * @param alpha scalar * @param beta scalar */ -template +template void gemm(raft::resources const& handle, - T* z, - T* x, - T* y, + z_T* z, + x_T* x, + y_T* y, int _M, int _N, int _K, @@ -181,10 +192,10 @@ void gemm(raft::resources const& handle, bool isXColMajor, bool isYColMajor, cudaStream_t stream, - T alpha = T(1.0), - T beta = T(0.0)) + s_T alpha = s_T(1.0), + s_T beta = s_T(0.0)) { - return detail::legacy_gemm( + return detail::legacy_gemm( handle, z, x, y, _M, _N, _K, isZColMajor, isXColMajor, isYColMajor, stream, &alpha, &beta); } diff --git a/cpp/include/raft/linalg/norm.cuh b/cpp/include/raft/linalg/norm.cuh index 97a5d6135d..4270149793 100644 --- a/cpp/include/raft/linalg/norm.cuh +++ b/cpp/include/raft/linalg/norm.cuh @@ -41,6 +41,7 @@ namespace linalg { * @tparam Type the data type * @tparam Lambda device final lambda * @tparam IdxType Integer type used to for addressing + * @tparam OutType output type, equal to Type by default * @param dots the output vector of row-wise dot products * @param data the input matrix * @param D number of columns of data @@ -50,8 +51,11 @@ namespace linalg { * @param stream cuda stream where to launch work * @param fin_op the final lambda op */ -template -void rowNorm(Type* dots, +template +void rowNorm(OutType* dots, const Type* data, IdxType D, IdxType N, @@ -68,6 +72,7 @@ void rowNorm(Type* dots, * @tparam Type the data type * @tparam Lambda device final lambda * @tparam IdxType Integer type used to for addressing + * @tparam OutType output type, equal to Type by default * @param dots the output vector of column-wise dot products * @param data the input matrix * @param D number of columns of data @@ -77,8 +82,11 @@ void rowNorm(Type* dots, * @param stream cuda stream where to launch work * @param fin_op the final lambda op */ -template -void colNorm(Type* dots, +template +void colNorm(OutType* dots, const Type* data, IdxType D, IdxType N, @@ -97,7 +105,8 @@ void colNorm(Type* dots, /** * @brief Compute norm of the input matrix and perform fin_op - * @tparam ElementType Input/Output data type + * @tparam ElementType Input data type + * @tparam OutType output data type * @tparam LayoutPolicy the layout of input (raft::row_major or raft::col_major) * @tparam IdxType Integer type used to for addressing * @tparam Lambda device final lambda @@ -110,12 +119,13 @@ void colNorm(Type* dots, * @param[in] fin_op the final lambda op */ template void norm(raft::resources const& handle, raft::device_matrix_view in, - raft::device_vector_view out, + raft::device_vector_view out, NormType type, Apply apply, Lambda fin_op = raft::identity_op()) diff --git a/cpp/include/raft/random/detail/rng_device.cuh b/cpp/include/raft/random/detail/rng_device.cuh index 12c67679ba..ffbb87bd0c 100644 --- a/cpp/include/raft/random/detail/rng_device.cuh +++ b/cpp/include/raft/random/detail/rng_device.cuh @@ -22,6 +22,8 @@ #include +#include + #include #include @@ -504,6 +506,12 @@ struct PhiloxGenerator { return ret; } + DI half next_half() + { + float ret = next_float(); + return __float2half(ret); + } + DI void next(float& ret) { // 
ret = curand_uniform(&(this->philox_state)); @@ -516,6 +524,12 @@ struct PhiloxGenerator { ret = next_double(); } + DI void next(half& ret) + { + // ret = curand_uniform_double(&(this->philox_state)); + ret = next_half(); + } + DI void next(uint32_t& ret) { ret = next_u32(); } DI void next(uint64_t& ret) { ret = next_u64(); } DI void next(int32_t& ret) { ret = next_i32(); } @@ -636,6 +650,12 @@ struct PCGenerator { return ret; } + HDI half next_half() + { + float ret = next_float(); + return __float2half(ret); + } + HDI void next(uint32_t& ret) { ret = next_u32(); } HDI void next(uint64_t& ret) { ret = next_u64(); } HDI void next(int32_t& ret) { ret = next_i32(); } @@ -643,6 +663,7 @@ struct PCGenerator { HDI void next(float& ret) { ret = next_float(); } HDI void next(double& ret) { ret = next_double(); } + HDI void next(half& ret) { ret = next_half(); } /** @} */ diff --git a/cpp/include/raft/random/detail/rng_impl.cuh b/cpp/include/raft/random/detail/rng_impl.cuh index 61a944e9b6..88654dbe5d 100644 --- a/cpp/include/raft/random/detail/rng_impl.cuh +++ b/cpp/include/raft/random/detail/rng_impl.cuh @@ -30,6 +30,7 @@ #include #include +#include namespace raft { namespace random { @@ -85,7 +86,7 @@ template void uniform( RngState& rng_state, OutType* ptr, LenType len, OutType start, OutType end, cudaStream_t stream) { - static_assert(std::is_floating_point::value, + static_assert(std::is_floating_point::value || std::is_same_v, "Type for 'uniform' can only be floating point!"); UniformDistParams params; params.start = start; diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh index b1b0291a85..769d5de9be 100644 --- a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -66,24 +66,30 @@ RAFT_KERNEL __launch_bounds__(calc_nnz_by_rows_tpb) calc_nnz_by_rows_kernel(cons index_t e_bit = s_bit + num_cols; index_t l_sum = 0; + int s_gap = 0; + int e_gap = 0; + while (offset < num_cols) { index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; std::remove_const_t l_bitmap = 0; if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } - if (s_bit > bitmap_idx * BITS_PER_BITMAP) { - l_bitmap >>= (s_bit - bitmap_idx * BITS_PER_BITMAP); - l_bitmap <<= (s_bit - bitmap_idx * BITS_PER_BITMAP); - } + offset += BITS_PER_BITMAP * warpSize; - if ((bitmap_idx + 1) * BITS_PER_BITMAP > e_bit) { - l_bitmap <<= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); - l_bitmap >>= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); + s_gap = s_bit - bitmap_idx * BITS_PER_BITMAP; + if (s_gap > 0) { + l_bitmap >>= s_gap; + l_bitmap <<= s_gap; + offset -= s_gap; } + e_gap = (bitmap_idx + 1) * BITS_PER_BITMAP - e_bit; + if (e_gap > 0) { + l_bitmap <<= e_gap; + l_bitmap >>= e_gap; + } l_sum += static_cast(raft::detail::popc(l_bitmap)); - offset += BITS_PER_BITMAP * warpSize; } l_sum = cg::reduce(tile, l_sum, cg::plus()); diff --git a/cpp/include/raft/sparse/detail/cusparse_wrappers.h b/cpp/include/raft/sparse/detail/cusparse_wrappers.h index ae552cc687..53a78a8f56 100644 --- a/cpp/include/raft/sparse/detail/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/detail/cusparse_wrappers.h @@ -207,6 +207,27 @@ inline cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr, CUDA_R_64F); } template <> +inline cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr, + int64_t rows, + int64_t cols, + int64_t nnz, + int32_t* csrRowOffsets, + int32_t* 
csrColInd, + half* csrValues) +{ + return cusparseCreateCsr(spMatDescr, + rows, + cols, + nnz, + csrRowOffsets, + csrColInd, + csrValues, + CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_BASE_ZERO, + CUDA_R_16F); +} +template <> inline cusparseStatus_t cusparsecreatecsr(cusparseSpMatDescr_t* spMatDescr, int64_t rows, int64_t cols, @@ -302,6 +323,16 @@ inline cusparseStatus_t cusparsecreatednmat(cusparseDnMatDescr_t* dnMatDescr, { return cusparseCreateDnMat(dnMatDescr, rows, cols, ld, values, CUDA_R_64F, order); } +template <> +inline cusparseStatus_t cusparsecreatednmat(cusparseDnMatDescr_t* dnMatDescr, + int64_t rows, + int64_t cols, + int64_t ld, + half* values, + cusparseOrder_t order) +{ + return cusparseCreateDnMat(dnMatDescr, rows, cols, ld, values, CUDA_R_16F, order); +} /** @} */ /** @@ -658,7 +689,7 @@ inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, const T* beta, cusparseSpMatDescr_t matC, cusparseSDDMMAlg_t alg, - T* externalBuffer, + void* externalBuffer, cudaStream_t stream); template <> inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, @@ -670,7 +701,7 @@ inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, const float* beta, cusparseSpMatDescr_t matC, cusparseSDDMMAlg_t alg, - float* externalBuffer, + void* externalBuffer, cudaStream_t stream) { CUSPARSE_CHECK(cusparseSetStream(handle, stream)); @@ -684,7 +715,7 @@ inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, matC, CUDA_R_32F, alg, - static_cast(externalBuffer)); + externalBuffer); } template <> inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, @@ -696,7 +727,7 @@ inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, const double* beta, cusparseSpMatDescr_t matC, cusparseSDDMMAlg_t alg, - double* externalBuffer, + void* externalBuffer, cudaStream_t stream) { CUSPARSE_CHECK(cusparseSetStream(handle, stream)); @@ -710,7 +741,34 @@ inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, matC, CUDA_R_64F, alg, - static_cast(externalBuffer)); + externalBuffer); +} + +template <> +inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle, + cusparseOperation_t opA, + cusparseOperation_t opB, + const half* alpha, + const cusparseDnMatDescr_t matA, + const cusparseDnMatDescr_t matB, + const half* beta, + cusparseSpMatDescr_t matC, + cusparseSDDMMAlg_t alg, + void* externalBuffer, + cudaStream_t stream) +{ + CUSPARSE_CHECK(cusparseSetStream(handle, stream)); + return cusparseSDDMM(handle, + opA, + opB, + static_cast(alpha), + matA, + matB, + static_cast(beta), + matC, + CUDA_R_16F, + alg, + externalBuffer); } /** @} */ diff --git a/cpp/include/raft/sparse/distance/detail/utils.cuh b/cpp/include/raft/sparse/distance/detail/utils.cuh index 42b545180b..864d61ba2f 100644 --- a/cpp/include/raft/sparse/distance/detail/utils.cuh +++ b/cpp/include/raft/sparse/distance/detail/utils.cuh @@ -20,6 +20,7 @@ #include #include +#include #include namespace raft { @@ -41,8 +42,8 @@ inline int max_cols_per_block() sizeof(value_t); } -template -RAFT_KERNEL faster_dot_on_csr_kernel(value_t* __restrict__ dot, +template +RAFT_KERNEL faster_dot_on_csr_kernel(dot_t* __restrict__ dot, const value_idx* __restrict__ indptr, const value_idx* __restrict__ cols, const value_t* __restrict__ A, @@ -74,25 +75,28 @@ RAFT_KERNEL faster_dot_on_csr_kernel(value_t* __restrict__ dot, cur_row = row; } - value_t l_dot_ = 0.0; + dot_t l_dot_ = 0.0; for (value_idx k = vec_id; k < dim; k += blockDim.x) { asm("prefetch.global.L2 [%0];" ::"l"(B_col + k + blockDim.x)); - l_dot_ += 
s_A[k] * __ldcg(B_col + k); + if constexpr ((std::is_same_v && std::is_same_v)) { + l_dot_ += __half2float(s_A[k]) * __half2float(__ldcg(B_col + k)); + } else { + l_dot_ += s_A[k] * __ldcg(B_col + k); + } } - l_dot_ += __shfl_down_sync(0xffffffff, l_dot_, 16); - l_dot_ += __shfl_down_sync(0xffff, l_dot_, 8); - l_dot_ += __shfl_down_sync(0xff, l_dot_, 4); - l_dot_ += __shfl_down_sync(0xf, l_dot_, 2); - l_dot_ += __shfl_down_sync(0x3, l_dot_, 1); - if (lane_id == 0) { atomicAdd_block(dot + dot_id, l_dot_); } + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage; + dot_t warp_sum = WarpReduce(temp_storage).Sum(l_dot_); + + if (lane_id == 0) { atomicAdd_block(dot + dot_id, warp_sum); } } } } -template +template void faster_dot_on_csr(raft::resources const& handle, - value_t* dot, + dot_t* dot, const value_idx nnz, const value_idx* indptr, const value_idx* cols, @@ -115,47 +119,47 @@ void faster_dot_on_csr(raft::resources const& handle, if (dim < 128) { constexpr int tpb = 64; cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); auto block_y = (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; dim3 blocks(block_x, block_y, 1); - faster_dot_on_csr_kernel + faster_dot_on_csr_kernel <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); } else if (dim < 256) { constexpr int tpb = 128; cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); auto block_y = (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; dim3 blocks(block_x, block_y, 1); - faster_dot_on_csr_kernel + faster_dot_on_csr_kernel <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); } else if (dim < 512) { constexpr int tpb = 256; cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); auto block_y = (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; dim3 blocks(block_x, block_y, 1); - faster_dot_on_csr_kernel + faster_dot_on_csr_kernel <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); } else { constexpr int tpb = 512; cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); + &blocks_per_sm, faster_dot_on_csr_kernel, tpb, smem_size); auto block_x = std::min(n_rows, MAX_ROW_PER_ITER); auto block_y = (std::min(value_idx(blocks_per_sm * sm_count * 16), nnz) + block_x - 1) / block_x; dim3 blocks(block_x, block_y, 1); - faster_dot_on_csr_kernel + faster_dot_on_csr_kernel <<>>(dot, indptr, cols, A, B, nnz, n_rows, dim); } diff --git a/cpp/include/raft/sparse/linalg/detail/masked_matmul.cuh b/cpp/include/raft/sparse/linalg/detail/masked_matmul.cuh index ef74316d04..276960628d 100644 --- a/cpp/include/raft/sparse/linalg/detail/masked_matmul.cuh +++ b/cpp/include/raft/sparse/linalg/detail/masked_matmul.cuh @@ -37,14 +37,14 @@ namespace sparse { namespace linalg { namespace detail { -template +template void masked_matmul(raft::resources const& handle, raft::device_matrix_view& A, raft::device_matrix_view& B, raft::core::bitmap_view& mask, - raft::device_csr_matrix_view& C, - 
std::optional> alpha, - std::optional> beta) + raft::device_csr_matrix_view& C, + std::optional> alpha, + std::optional> beta) { index_t m = A.extent(0); index_t n = B.extent(0); @@ -60,24 +60,24 @@ void masked_matmul(raft::resources const& handle, auto stream = raft::resource::get_cuda_stream(handle); - auto C_matrix = raft::make_device_csr_matrix(handle, compressed_C_view); + auto C_matrix = raft::make_device_csr_matrix(handle, compressed_C_view); // fill C raft::sparse::convert::bitmap_to_csr(handle, mask, C_matrix); if (m > 10 || alpha.has_value() || beta.has_value()) { - auto C_view = raft::make_device_csr_matrix_view( + auto C_view = raft::make_device_csr_matrix_view( C.get_elements().data(), compressed_C_view); // create B col_major view auto B_col_major = raft::make_device_matrix_view( B.data_handle(), dim, n); - value_t default_alpha = static_cast(1.0f); - value_t default_beta = static_cast(0.0f); + output_t default_alpha = static_cast(1.0f); + output_t default_beta = static_cast(0.0f); - if (!alpha.has_value()) { alpha = raft::make_host_scalar_view(&default_alpha); } - if (!beta.has_value()) { beta = raft::make_host_scalar_view(&default_beta); } + if (!alpha.has_value()) { alpha = raft::make_host_scalar_view(&default_alpha); } + if (!beta.has_value()) { beta = raft::make_host_scalar_view(&default_beta); } raft::sparse::linalg::sddmm(handle, A, diff --git a/cpp/include/raft/sparse/linalg/detail/sddmm.hpp b/cpp/include/raft/sparse/linalg/detail/sddmm.hpp index 5088a20f46..f2e4aba644 100644 --- a/cpp/include/raft/sparse/linalg/detail/sddmm.hpp +++ b/cpp/include/raft/sparse/linalg/detail/sddmm.hpp @@ -35,11 +35,7 @@ namespace detail { * It computes the following equation: C = alpha · (op_a(A) * op_b(B) ∘ spy(C)) + beta · C * where A,B are device matrix views and C is a CSR device matrix view * - * @tparam ValueType Data type of input/output matrices (float/double) - * @tparam IndexType Type of C - * @tparam LayoutPolicyA layout of A - * @tparam LayoutPolicyB layout of B - * @tparam NZType Type of C + * @tparam OutputType Data type of input/output matrices (float/double) * * @param[in] handle raft resource handle * @param[in] descr_a input dense descriptor @@ -50,15 +46,15 @@ namespace detail { * @param[in] alpha scalar pointer * @param[in] beta scalar pointer */ -template +template void sddmm(raft::resources const& handle, cusparseDnMatDescr_t& descr_a, cusparseDnMatDescr_t& descr_b, cusparseSpMatDescr_t& descr_c, cusparseOperation_t op_a, cusparseOperation_t op_b, - const ValueType* alpha, - const ValueType* beta) + const OutputType* alpha, + const OutputType* beta) { auto alg = CUSPARSE_SDDMM_ALG_DEFAULT; size_t bufferSize; @@ -78,7 +74,7 @@ void sddmm(raft::resources const& handle, resource::sync_stream(handle); - rmm::device_uvector tmp(bufferSize, resource::get_cuda_stream(handle)); + rmm::device_uvector tmp(bufferSize, resource::get_cuda_stream(handle)); RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsesddmm(resource::get_cusparse_handle(handle), op_a, @@ -89,7 +85,7 @@ void sddmm(raft::resources const& handle, beta, descr_c, alg, - tmp.data(), + reinterpret_cast(tmp.data()), resource::get_cuda_stream(handle))); } diff --git a/cpp/include/raft/sparse/linalg/masked_matmul.hpp b/cpp/include/raft/sparse/linalg/masked_matmul.hpp index 560cd3f715..6cf6e834b9 100644 --- a/cpp/include/raft/sparse/linalg/masked_matmul.hpp +++ b/cpp/include/raft/sparse/linalg/masked_matmul.hpp @@ -35,7 +35,8 @@ namespace linalg { * multiplication using the sparsity pattern provided by the mask. 
The result is scaled by alpha * and added to beta times the original matrix C. * - * @tparam value_t Data type of elements in the input/output matrices (e.g., float, double) + * @tparam value_t Data type of elements in the input matrices (e.g., half, float, double) + * @tparam output_t Data type of elements in the output matrices (e.g., float, double) * @tparam index_t Type used for matrix indices * @tparam nnz_t Type used for the number of non-zero entries in CSR format * @tparam bitmap_t Type of the bitmap used for the mask @@ -52,14 +53,14 @@ namespace linalg { * std::nullopt) * @param[in] beta Optional scalar multiplier for the original matrix C (default: 0 if std::nullopt) */ -template +template void masked_matmul(raft::resources const& handle, raft::device_matrix_view A, raft::device_matrix_view B, raft::core::bitmap_view mask, - raft::device_csr_matrix_view C, - std::optional> alpha = std::nullopt, - std::optional> beta = std::nullopt) + raft::device_csr_matrix_view C, + std::optional> alpha = std::nullopt, + std::optional> beta = std::nullopt) { detail::masked_matmul(handle, A, B, mask, C, alpha, beta); } diff --git a/cpp/include/raft/sparse/linalg/sddmm.hpp b/cpp/include/raft/sparse/linalg/sddmm.hpp index c19f1d9081..96387e6c8b 100644 --- a/cpp/include/raft/sparse/linalg/sddmm.hpp +++ b/cpp/include/raft/sparse/linalg/sddmm.hpp @@ -29,11 +29,12 @@ namespace linalg { * followed by an element-wise multiplication with the sparsity pattern of C. * It computes the following equation: C = alpha · (opA(A) * opB(B) ∘ spy(C)) + beta · C * where A,B are device matrix views and C is a CSR device matrix view - * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam ValueType Data type of input/output matrices (float/double/half) * @tparam IndexType Type of C * @tparam NZType Type of C * @tparam LayoutPolicyA layout of A * @tparam LayoutPolicyB layout of B + * @tparam OutputType output type, equal to ValueType by default * @param[in] handle raft handle * @param[in] A input raft::device_matrix_view * @param[in] B input raft::device_matrix_view @@ -47,21 +48,23 @@ template + typename LayoutPolicyB, + typename OutputType> void sddmm(raft::resources const& handle, raft::device_matrix_view A, raft::device_matrix_view B, - raft::device_csr_matrix_view C, + raft::device_csr_matrix_view C, const raft::linalg::Operation opA, const raft::linalg::Operation opB, - raft::host_scalar_view alpha, - raft::host_scalar_view beta) + raft::host_scalar_view alpha, + raft::host_scalar_view beta) { RAFT_EXPECTS(raft::is_row_or_column_major(A), "A is not contiguous"); RAFT_EXPECTS(raft::is_row_or_column_major(B), "B is not contiguous"); - static_assert(std::is_same_v || std::is_same_v, - "The `ValueType` of sddmm only supports float/double."); + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v, + "The `ValueType` of sddmm only supports float/double/half."); auto descrA = detail::create_descriptor(A); auto descrB = detail::create_descriptor(B); diff --git a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh index 7a5a217959..0cc3b864fa 100644 --- a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh @@ -69,7 +69,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero; + acc[i][j] = 
BaseClass::Zero(); } } this->stsXY(); diff --git a/cpp/include/raft/util/cuda_dev_essentials.cuh b/cpp/include/raft/util/cuda_dev_essentials.cuh index bb9ebbba59..26f48af68b 100644 --- a/cpp/include/raft/util/cuda_dev_essentials.cuh +++ b/cpp/include/raft/util/cuda_dev_essentials.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ #pragma once +#include + // This file provides a few essential functions for use in __device__ code. The // scope is necessarily limited to ensure that compilation times are minimized. // Please make sure not to include large / expensive files from here. @@ -114,4 +116,19 @@ HDI void swapVals(T& a, T& b) b = tmp; } +/** + * @brief Convert half to float + * @tparam T the datatype of the value + * @param a need to convert + */ +template +HDI auto to_float(T& a) +{ + if constexpr (std::is_same_v::type, half>) { + return __half2float(a); + } else { + return a; + } +} + } // namespace raft diff --git a/cpp/include/raft/util/cudart_utils.hpp b/cpp/include/raft/util/cudart_utils.hpp index 2b334d1bbf..f9e7f521be 100644 --- a/cpp/include/raft/util/cudart_utils.hpp +++ b/cpp/include/raft/util/cudart_utils.hpp @@ -189,7 +189,11 @@ void print_host_vector(const char* variable_name, out << variable_name << "=["; for (size_t i = 0; i < componentsCount; ++i) { if (i != 0) out << ","; - out << host_mem[i]; + if constexpr (std::is_same_v) { + out << __half2float(host_mem[i]); + } else { + out << host_mem[i]; + } } out << "];" << std::endl; } diff --git a/cpp/test/core/mdspan_copy.cu b/cpp/test/core/mdspan_copy.cu index b68ba38914..419c1e0859 100644 --- a/cpp/test/core/mdspan_copy.cu +++ b/cpp/test/core/mdspan_copy.cu @@ -161,6 +161,59 @@ TEST(MDSpanCopy, Mdspan2DDeviceDeviceCuda) } } } + +TEST(MDSpanCopy, Mdspan2DDeviceDeviceCudaHalfWithTranspose) +{ + auto res = device_resources{}; + auto constexpr rows = std::uint32_t{30}; + auto constexpr cols = std::uint32_t{20}; + auto in_left = make_device_mdarray( + res, extents{}); + auto in_right = make_device_mdarray( + res, extents{}); + auto gen_unique_entry = [](auto&& x, auto&& y) { return x * 7 + y * 11; }; + + for (auto i = std::uint32_t{}; i < rows; ++i) { + for (auto j = std::uint32_t{}; j < cols; ++j) { + in_left(i, j) = gen_unique_entry(i, j); + in_right(i, j) = gen_unique_entry(i, j); + } + } + + auto out_left = make_device_mdarray( + res, extents{}); + auto out_right = make_device_mdarray( + res, extents{}); + + res.sync_stream(); + + // Test dtype conversion with transpose + static_assert( + detail::mdspan_copyable_with_kernel_v, + "Current implementation should use kernel for this copy"); + copy(res, out_right.view(), in_left.view()); + res.sync_stream(); + for (auto i = std::uint32_t{}; i < rows; ++i) { + for (auto j = std::uint32_t{}; j < cols; ++j) { + ASSERT_TRUE(match(__half2float(out_right(i, j)), + __half2float(gen_unique_entry(i, j)), + CompareApprox{0.0001})); + } + } + static_assert( + detail::mdspan_copyable_with_kernel_v, + "Current implementation should use kernel for this copy"); + copy(res, out_left.view(), in_right.view()); + res.sync_stream(); + for (auto i = std::uint32_t{}; i < rows; ++i) { + for (auto j = std::uint32_t{}; j < cols; ++j) { + ASSERT_TRUE(match(__half2float(out_left(i, j)), + __half2float(gen_unique_entry(i, j)), + CompareApprox{0.0001})); + } + } +} + TEST(MDSpanCopy, 
Mdspan3DDeviceHostCuda) { auto res = device_resources{}; diff --git a/cpp/test/linalg/norm.cu b/cpp/test/linalg/norm.cu index e4f064949c..f91350f222 100644 --- a/cpp/test/linalg/norm.cu +++ b/cpp/test/linalg/norm.cu @@ -24,6 +24,8 @@ #include #include +#include + #include namespace raft { @@ -48,39 +50,48 @@ template } ///// Row-wise norm test definitions -template +template RAFT_KERNEL naiveRowNormKernel( - Type* dots, const Type* data, IdxT D, IdxT N, NormType type, bool do_sqrt) + OutType* dots, const Type* data, IdxT D, IdxT N, NormType type, bool do_sqrt) { - Type acc = (Type)0; + OutType acc = (OutType)0; IdxT rowStart = threadIdx.x + static_cast(blockIdx.x) * blockDim.x; if (rowStart < N) { for (IdxT i = 0; i < D; ++i) { - if (type == L2Norm) { - acc += data[rowStart * D + i] * data[rowStart * D + i]; + if constexpr (std::is_same_v) { + if (type == L2Norm) { + acc += __half2float(data[rowStart * D + i]) * __half2float(data[rowStart * D + i]); + } else { + acc += raft::abs(__half2float(data[rowStart * D + i])); + } } else { - acc += raft::abs(data[rowStart * D + i]); + if (type == L2Norm) { + acc += data[rowStart * D + i] * data[rowStart * D + i]; + } else { + acc += raft::abs(data[rowStart * D + i]); + } } } dots[rowStart] = do_sqrt ? raft::sqrt(acc) : acc; } } -template +template void naiveRowNorm( - Type* dots, const Type* data, IdxT D, IdxT N, NormType type, bool do_sqrt, cudaStream_t stream) + OutType* dots, const Type* data, IdxT D, IdxT N, NormType type, bool do_sqrt, cudaStream_t stream) { static const IdxT TPB = 64; IdxT nblks = raft::ceildiv(N, TPB); - naiveRowNormKernel<<>>(dots, data, D, N, type, do_sqrt); + naiveRowNormKernel + <<>>(dots, data, D, N, type, do_sqrt); RAFT_CUDA_TRY(cudaPeekAtLastError()); } -template -class RowNormTest : public ::testing::TestWithParam> { +template +class RowNormTest : public ::testing::TestWithParam> { public: RowNormTest() - : params(::testing::TestWithParam>::GetParam()), + : params(::testing::TestWithParam>::GetParam()), stream(resource::get_cuda_stream(handle)), data(params.rows * params.cols, stream), dots_exp(params.rows, stream), @@ -94,7 +105,7 @@ class RowNormTest : public ::testing::TestWithParam> { IdxT rows = params.rows, cols = params.cols, len = rows * cols; uniform(handle, r, data.data(), len, T(-1.0), T(1.0)); naiveRowNorm(dots_exp.data(), data.data(), cols, rows, params.type, params.do_sqrt, stream); - auto output_view = raft::make_device_vector_view(dots_act.data(), params.rows); + auto output_view = raft::make_device_vector_view(dots_act.data(), params.rows); auto input_row_major = raft::make_device_matrix_view( data.data(), params.rows, params.cols); auto input_col_major = raft::make_device_matrix_view( @@ -119,42 +130,44 @@ class RowNormTest : public ::testing::TestWithParam> { raft::resources handle; cudaStream_t stream; - NormInputs params; - rmm::device_uvector data, dots_exp, dots_act; + NormInputs params; + rmm::device_uvector data; + rmm::device_uvector dots_exp, dots_act; }; ///// Column-wise norm test definitisons -template +template RAFT_KERNEL naiveColNormKernel( - Type* dots, const Type* data, IdxT D, IdxT N, NormType type, bool do_sqrt) + OutType* dots, const Type* data, IdxT D, IdxT N, NormType type, bool do_sqrt) { IdxT colID = threadIdx.x + static_cast(blockIdx.x) * blockDim.x; if (colID >= D) return; // avoid out-of-bounds thread - Type acc = 0; + OutType acc = 0; for (IdxT i = 0; i < N; i++) { - Type v = data[colID + i * D]; + OutType v = data[colID + i * D]; acc += type == L2Norm ? 
v * v : raft::abs(v); } dots[colID] = do_sqrt ? raft::sqrt(acc) : acc; } -template +template void naiveColNorm( - Type* dots, const Type* data, IdxT D, IdxT N, NormType type, bool do_sqrt, cudaStream_t stream) + OutType* dots, const Type* data, IdxT D, IdxT N, NormType type, bool do_sqrt, cudaStream_t stream) { static const IdxT TPB = 64; IdxT nblks = raft::ceildiv(D, TPB); - naiveColNormKernel<<>>(dots, data, D, N, type, do_sqrt); + naiveColNormKernel + <<>>(dots, data, D, N, type, do_sqrt); RAFT_CUDA_TRY(cudaPeekAtLastError()); } -template -class ColNormTest : public ::testing::TestWithParam> { +template +class ColNormTest : public ::testing::TestWithParam> { public: ColNormTest() - : params(::testing::TestWithParam>::GetParam()), + : params(::testing::TestWithParam>::GetParam()), stream(resource::get_cuda_stream(handle)), data(params.rows * params.cols, stream), dots_exp(params.cols, stream), @@ -169,7 +182,7 @@ class ColNormTest : public ::testing::TestWithParam> { uniform(handle, r, data.data(), len, T(-1.0), T(1.0)); naiveColNorm(dots_exp.data(), data.data(), cols, rows, params.type, params.do_sqrt, stream); - auto output_view = raft::make_device_vector_view(dots_act.data(), params.cols); + auto output_view = raft::make_device_vector_view(dots_act.data(), params.cols); auto input_row_major = raft::make_device_matrix_view( data.data(), params.rows, params.cols); auto input_col_major = raft::make_device_matrix_view( @@ -196,8 +209,9 @@ class ColNormTest : public ::testing::TestWithParam> { raft::resources handle; cudaStream_t stream; - NormInputs params; - rmm::device_uvector data, dots_exp, dots_act; + NormInputs params; + rmm::device_uvector data; + rmm::device_uvector dots_exp, dots_act; }; ///// Row- and column-wise tests @@ -246,6 +260,19 @@ const std::vector> inputscd_i64 = {true}, {1234ULL}); +const std::vector> inputsh_i32 = + raft::util::itertools::product>( + {0.00001f}, {11, 1234}, {7, 33, 128, 500}, {L1Norm, L2Norm}, {false, true}, {true}, {1234ULL}); +const std::vector> inputsh_i64 = + raft::util::itertools::product>( + {0.00001f}, {11, 1234}, {7, 33, 128, 500}, {L1Norm, L2Norm}, {false, true}, {true}, {1234ULL}); +const std::vector> inputsch_i32 = + raft::util::itertools::product>( + {0.00001f}, {7, 33, 128, 500}, {11, 1234}, {L1Norm, L2Norm}, {false, true}, {true}, {1234ULL}); +const std::vector> inputsch_i64 = + raft::util::itertools::product>( + {0.00001f}, {7, 33, 128, 500}, {11, 1234}, {L1Norm, L2Norm}, {false, true}, {true}, {1234ULL}); + typedef RowNormTest RowNormTestF_i32; typedef RowNormTest RowNormTestD_i32; typedef RowNormTest RowNormTestF_i64; @@ -255,6 +282,11 @@ typedef ColNormTest ColNormTestD_i32; typedef ColNormTest ColNormTestF_i64; typedef ColNormTest ColNormTestD_i64; +typedef RowNormTest RowNormTestH_i32; +typedef RowNormTest RowNormTestH_i64; +typedef ColNormTest ColNormTestH_i32; +typedef ColNormTest ColNormTestH_i64; + #define ROWNORM_TEST(test_type, test_inputs) \ TEST_P(test_type, Result) \ { \ @@ -272,5 +304,10 @@ ROWNORM_TEST(ColNormTestD_i32, inputscd_i32); ROWNORM_TEST(ColNormTestF_i64, inputscf_i64); ROWNORM_TEST(ColNormTestD_i64, inputscd_i64); +ROWNORM_TEST(RowNormTestH_i32, inputsh_i32); +ROWNORM_TEST(RowNormTestH_i64, inputsh_i64); +ROWNORM_TEST(ColNormTestH_i32, inputsch_i32); +ROWNORM_TEST(ColNormTestH_i64, inputsch_i64); + } // end namespace linalg } // end namespace raft diff --git a/cpp/test/linalg/transpose.cu b/cpp/test/linalg/transpose.cu index f6857d3ffa..cbe869a9a5 100644 --- a/cpp/test/linalg/transpose.cu +++ 
b/cpp/test/linalg/transpose.cu @@ -25,6 +25,8 @@ #include +#include + #include namespace raft { @@ -84,6 +86,8 @@ const std::vector> inputsf2 = {{0.1f, 3 * 3, 3, 3, 1234ULL const std::vector> inputsd2 = {{0.1, 3 * 3, 3, 3, 1234ULL}}; +const std::vector> inputsh2 = {{0.1, 3 * 3, 3, 3, 1234ULL}}; + typedef TransposeTest TransposeTestValF; TEST_P(TransposeTestValF, Result) { @@ -112,10 +116,49 @@ TEST_P(TransposeTestValD, Result) raft::CompareApproxAbs(params.tolerance))); } +bool validate_half(const half* h_ref, const half* h_result, half tolerance, int len) +{ + bool success = true; + for (int i = 0; i < len; ++i) { + if (raft::abs(__half2float(h_result[i]) - __half2float(h_ref[i])) >= __half2float(tolerance)) { + success = false; + break; + } + if (!success) break; + } + return success; +} + +typedef TransposeTest TransposeTestValH; +TEST_P(TransposeTestValH, Result) +{ + half data_trans_ref_h[params.len]; + half data_trans_h[params.len]; + half data_h[params.len]; + + RAFT_CUDA_TRY(cudaMemcpyAsync(data_trans_ref_h, + data_trans_ref.data(), + params.len * sizeof(half), + cudaMemcpyDeviceToHost, + stream)); + + RAFT_CUDA_TRY(cudaMemcpyAsync( + data_trans_h, data_trans.data(), params.len * sizeof(half), cudaMemcpyDeviceToHost, stream)); + RAFT_CUDA_TRY(cudaMemcpyAsync( + data_h, data.data(), params.len * sizeof(half), cudaMemcpyDeviceToHost, stream)); + + resource::sync_stream(handle, stream); + + ASSERT_TRUE(validate_half(data_trans_ref_h, data_trans_h, params.tolerance, params.len)); + ASSERT_TRUE(validate_half(data_trans_ref_h, data_h, params.tolerance, params.len)); +} + INSTANTIATE_TEST_SUITE_P(TransposeTests, TransposeTestValF, ::testing::ValuesIn(inputsf2)); INSTANTIATE_TEST_SUITE_P(TransposeTests, TransposeTestValD, ::testing::ValuesIn(inputsd2)); +INSTANTIATE_TEST_SUITE_P(TransposeTests, TransposeTestValH, ::testing::ValuesIn(inputsh2)); + namespace { /** * We hide these functions in tests for now until we have a heterogeneous mdarray diff --git a/cpp/test/sparse/masked_matmul.cu b/cpp/test/sparse/masked_matmul.cu index 0ece716a1b..f883beae32 100644 --- a/cpp/test/sparse/masked_matmul.cu +++ b/cpp/test/sparse/masked_matmul.cu @@ -24,6 +24,7 @@ #include +#include #include #include @@ -32,15 +33,15 @@ namespace raft { namespace sparse { -template +template struct MaskedMatmulInputs { - value_t tolerance; + output_t tolerance; index_t m; index_t k; index_t n; - value_t sparsity; + float sparsity; unsigned long long int seed; }; @@ -53,8 +54,13 @@ struct sum_abs_op { } }; -template -::std::ostream& operator<<(::std::ostream& os, const MaskedMatmulInputs& params) +struct float_to_half { + __host__ __device__ __half operator()(const float x) const { return __float2half(x); } +}; + +template +::std::ostream& operator<<(::std::ostream& os, + const MaskedMatmulInputs& params) { os << " m: " << params.m << "\tk: " << params.k << "\tn: " << params.n << "\tsparsity: " << params.sparsity; @@ -62,15 +68,33 @@ template return os; } +bool isCuSparseVersionGreaterThan_12_0_1() +{ + int version; + cusparseHandle_t handle; + cusparseCreate(&handle); + cusparseGetVersion(handle, &version); + + int major = version / 1000; + int minor = (version % 1000) / 100; + int patch = version % 100; + + cusparseDestroy(handle); + + return (major > 12) || (major == 12 && minor > 0) || (major == 12 && minor == 0 && patch >= 2); +} + template -class MaskedMatmulTest : public ::testing::TestWithParam> { +class MaskedMatmulTest + : public ::testing::TestWithParam> { public: MaskedMatmulTest() - : 
params(::testing::TestWithParam>::GetParam()), + : params(::testing::TestWithParam>::GetParam()), stream(resource::get_cuda_stream(handle)), a_data_d(0, resource::get_cuda_stream(handle)), b_data_d(0, resource::get_cuda_stream(handle)), @@ -142,7 +166,7 @@ class MaskedMatmulTest : public ::testing::TestWithParam& A, const std::vector& B, - std::vector& vals, + std::vector& vals, const std::vector& cols, const std::vector& row_ptrs, bool is_row_major_A, @@ -156,11 +180,15 @@ class MaskedMatmulTest : public ::testing::TestWithParam && std::is_same_v)) { + sum += __half2float(A[a_index]) * __half2float(B[b_index]); + } else { + sum += A[a_index] * B[b_index]; + } } vals[j] = sum; } @@ -183,29 +211,54 @@ class MaskedMatmulTest : public ::testing::TestWithParam(handle, 1, a_size + b_size); + auto blobs_a_b = raft::make_device_matrix(handle, 1, a_size + b_size); auto labels = raft::make_device_vector(handle, 1); - raft::random::make_blobs(blobs_a_b.data_handle(), - labels.data_handle(), - 1, - a_size + b_size, - 1, - stream, - false, - nullptr, - nullptr, - value_t(1.0), - false, - value_t(-1.0f), - value_t(1.0f), - uint64_t(2024)); - - raft::copy(a_data_h.data(), blobs_a_b.data_handle(), a_size, stream); - raft::copy(b_data_h.data(), blobs_a_b.data_handle() + a_size, b_size, stream); - - raft::copy(a_data_d.data(), blobs_a_b.data_handle(), a_size, stream); - raft::copy(b_data_d.data(), blobs_a_b.data_handle() + a_size, b_size, stream); + raft::random::make_blobs(blobs_a_b.data_handle(), + labels.data_handle(), + 1, + a_size + b_size, + 1, + stream, + false, + nullptr, + nullptr, + output_t(1.0), + false, + output_t(-1.0f), + output_t(1.0f), + uint64_t(2024)); + + if constexpr ((std::is_same_v && std::is_same_v)) { + { + thrust::device_ptr d_output_ptr = + thrust::device_pointer_cast(blobs_a_b.data_handle()); + thrust::device_ptr d_value_ptr = thrust::device_pointer_cast(a_data_d.data()); + thrust::transform(thrust::cuda::par.on(stream), + d_output_ptr, + d_output_ptr + a_size, + d_value_ptr, + float_to_half()); + } + { + thrust::device_ptr d_output_ptr = + thrust::device_pointer_cast(blobs_a_b.data_handle() + a_size); + thrust::device_ptr d_value_ptr = thrust::device_pointer_cast(b_data_d.data()); + thrust::transform(thrust::cuda::par.on(stream), + d_output_ptr, + d_output_ptr + b_size, + d_value_ptr, + float_to_half()); + } + raft::copy(a_data_h.data(), a_data_d.data(), a_size, stream); + raft::copy(b_data_h.data(), b_data_d.data(), b_size, stream); + } else { + raft::copy(a_data_h.data(), blobs_a_b.data_handle(), a_size, stream); + raft::copy(b_data_h.data(), blobs_a_b.data_handle() + a_size, b_size, stream); + + raft::copy(a_data_d.data(), blobs_a_b.data_handle(), a_size, stream); + raft::copy(b_data_d.data(), blobs_a_b.data_handle() + a_size, b_size, stream); + } resource::sync_stream(handle); @@ -213,7 +266,7 @@ class MaskedMatmulTest : public ::testing::TestWithParam c_indptr_h(params.m + 1); std::vector c_indices_h(c_true_nnz); - std::vector c_data_h(c_true_nnz); + std::vector c_data_h(c_true_nnz); cpu_convert_to_csr(bitmap_h, params.m, params.n, c_indices_h, c_indptr_h); @@ -236,7 +289,13 @@ class MaskedMatmulTest : public ::testing::TestWithParam && !isCuSparseVersionGreaterThan_12_0_1()) { + GTEST_SKIP() << "Skipping all tests for half-float as cuSparse doesn't support it."; + } + make_data(); + } void Run() { @@ -255,33 +314,33 @@ class MaskedMatmulTest : public ::testing::TestWithParam(c_indices_d.size())); - auto C = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); 
+ auto C = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); raft::sparse::linalg::masked_matmul(handle, A, B, mask, C); resource::sync_stream(handle); - ASSERT_TRUE(raft::devArrMatch(c_expected_data_d.data(), - C.get_elements().data(), - c_expected_data_d.size(), - raft::CompareApprox(params.tolerance), - stream)); + ASSERT_TRUE(raft::devArrMatch(c_expected_data_d.data(), + C.get_elements().data(), + c_expected_data_d.size(), + raft::CompareApprox(params.tolerance), + stream)); - thrust::device_ptr expected_data_ptr = + thrust::device_ptr expected_data_ptr = thrust::device_pointer_cast(c_expected_data_d.data()); - value_t sum_abs = thrust::reduce(thrust::cuda::par.on(stream), - expected_data_ptr, - expected_data_ptr + c_expected_data_d.size(), - value_t(0.0f), - sum_abs_op()); - value_t avg = sum_abs / (1.0f * c_expected_data_d.size()); - - ASSERT_GE(avg, (params.tolerance * static_cast(0.001f))); + output_t sum_abs = thrust::reduce(thrust::cuda::par.on(stream), + expected_data_ptr, + expected_data_ptr + c_expected_data_d.size(), + output_t(0.0f), + sum_abs_op()); + output_t avg = sum_abs / (1.0f * c_expected_data_d.size()); + + ASSERT_GE(avg, (params.tolerance * static_cast(0.001f))); } raft::resources handle; cudaStream_t stream; - MaskedMatmulInputs params; + MaskedMatmulInputs params; rmm::device_uvector a_data_d; rmm::device_uvector b_data_d; @@ -289,40 +348,82 @@ class MaskedMatmulTest : public ::testing::TestWithParam c_indptr_d; rmm::device_uvector c_indices_d; - rmm::device_uvector c_data_d; + rmm::device_uvector c_data_d; - rmm::device_uvector c_expected_data_d; + rmm::device_uvector c_expected_data_d; }; -using MaskedMatmulTestF = MaskedMatmulTest; +using MaskedMatmulTestF = MaskedMatmulTest; TEST_P(MaskedMatmulTestF, Result) { Run(); } -using MaskedMatmulTestD = MaskedMatmulTest; +using MaskedMatmulTestD = MaskedMatmulTest; TEST_P(MaskedMatmulTestD, Result) { Run(); } -const std::vector> sddmm_inputs_f = { +using MaskedMatmulTestH = MaskedMatmulTest; +TEST_P(MaskedMatmulTestH, Result) { Run(); } + +const std::vector> sddmm_inputs_f = { + {0.001f, 2, 255, 1023, 0.19, 1234ULL}, + {0.001f, 2, 255, 1023 * 2, 0.19, 1234ULL}, + {0.001f, 2, 255, 1023 * 3, 0.38, 1234ULL}, + {0.0001f, 10, 255, 13000, 0.01, 1234ULL}, {0.0001f, 10, 5, 32, 0.1, 1234ULL}, - {0.0001f, 1024, 32, 1024, 0.1, 1234ULL}, + {0.001f, 11, 255, 1023, 0.19, 1234ULL}, + {0.001f, 11, 255, 1023 * 2, 0.19, 1234ULL}, + {0.001f, 11, 255, 1023 * 3, 0.38, 1234ULL}, {0.0003f, 32, 1024, 1024, 0.2, 1234ULL}, + {0.0001f, 1024, 32, 1024, 0.1, 1234ULL}, {0.001f, 1024, 1024, 1024, 0.19, 1234ULL}, + {0.001f, 1023, 1023, 1023 * 3, 0.38, 1234ULL}, + {0.001f, 1025, 1025, 1025 * 3, 0.31, 1234ULL}, {0.0001f, 1024, 1024, 32, 0.3, 1234ULL}, {0.0001f, 1024, 32, 1024, 0.4, 1234ULL}, - {0.0003f, 32, 1024, 1024, 0.19, 1234ULL}, + {0.0003f, 31, 1025, 1025, 0.19, 1234ULL}, {0.001f, 1024, 1024, 1024, 0.1, 1234ULL}}; -const std::vector> sddmm_inputs_d = { - {0.0001f, 10, 5, 32, 0.01, 1234ULL}, - {0.0001f, 1024, 32, 1024, 0.1, 1234ULL}, +const std::vector> sddmm_inputs_d = { + {0.0001f, 2, 255, 1023, 0.19, 1234ULL}, + {0.0001f, 2, 255, 1023 * 2, 0.19, 1234ULL}, + {0.0001f, 2, 255, 1023 * 3, 0.38, 1234ULL}, + {0.0001f, 10, 255, 13000, 0.01, 1234ULL}, + {0.0001f, 10, 5, 32, 0.1, 1234ULL}, + {0.0001f, 11, 255, 1023, 0.19, 1234ULL}, + {0.0001f, 11, 255, 1023 * 2, 0.19, 1234ULL}, + {0.0001f, 11, 255, 1023 * 3, 0.38, 1234ULL}, {0.0001f, 32, 1024, 1024, 0.2, 1234ULL}, + {0.0001f, 1024, 32, 1024, 0.1, 1234ULL}, {0.0001f, 1024, 1024, 1024, 
0.19, 1234ULL}, + {0.0001f, 1023, 1023, 1023 * 3, 0.38, 1234ULL}, + {0.0001f, 1025, 1025, 1025 * 3, 0.31, 1234ULL}, {0.0001f, 1024, 1024, 32, 0.3, 1234ULL}, {0.0001f, 1024, 32, 1024, 0.4, 1234ULL}, - {0.0001f, 32, 1024, 1024, 0.19, 1234ULL}, + {0.0001f, 31, 1025, 1025, 0.19, 1234ULL}, {0.0001f, 1024, 1024, 1024, 0.1, 1234ULL}}; +const std::vector> sddmm_inputs_h = { + {0.001f, 2, 255, 1023, 0.19, 1234ULL}, + {0.001f, 2, 255, 1023 * 2, 0.19, 1234ULL}, + {0.001f, 2, 255, 1023 * 3, 0.38, 1234ULL}, + {0.0001f, 10, 255, 13000, 0.01, 1234ULL}, + {0.0001f, 10, 5, 32, 0.1, 1234ULL}, + {0.001f, 11, 255, 1023, 0.19, 1234ULL}, + {0.001f, 11, 255, 1023 * 2, 0.19, 1234ULL}, + {0.001f, 11, 255, 1023 * 3, 0.38, 1234ULL}, + {0.0003f, 32, 1024, 1024, 0.2, 1234ULL}, + {0.0001f, 1024, 32, 1024, 0.1, 1234ULL}, + {0.001f, 1024, 1024, 1024, 0.19, 1234ULL}, + {0.001f, 1023, 1023, 1023 * 3, 0.38, 1234ULL}, + {0.001f, 1025, 1025, 1025 * 3, 0.31, 1234ULL}, + {0.0001f, 1024, 1024, 32, 0.3, 1234ULL}, + {0.0001f, 1024, 32, 1024, 0.4, 1234ULL}, + {0.0003f, 31, 1025, 1025, 0.19, 1234ULL}, + {0.001f, 1024, 1024, 1024, 0.1, 1234ULL}}; + INSTANTIATE_TEST_CASE_P(MaskedMatmulTest, MaskedMatmulTestF, ::testing::ValuesIn(sddmm_inputs_f)); INSTANTIATE_TEST_CASE_P(MaskedMatmulTest, MaskedMatmulTestD, ::testing::ValuesIn(sddmm_inputs_d)); +INSTANTIATE_TEST_CASE_P(MaskedMatmulTest, MaskedMatmulTestH, ::testing::ValuesIn(sddmm_inputs_h)); + } // namespace sparse } // namespace raft diff --git a/cpp/test/sparse/sddmm.cu b/cpp/test/sparse/sddmm.cu index 8ff20581c9..26c2c519dd 100644 --- a/cpp/test/sparse/sddmm.cu +++ b/cpp/test/sparse/sddmm.cu @@ -22,8 +22,12 @@ #include #include +#include +#include #include +#include +#include #include #include @@ -32,16 +36,16 @@ namespace raft { namespace sparse { -template +template struct SDDMMInputs { - ValueType tolerance; + OutputType tolerance; IndexType m; IndexType k; IndexType n; - ValueType alpha; - ValueType beta; + OutputType alpha; + OutputType beta; bool transpose_a; bool transpose_b; @@ -59,6 +63,10 @@ struct sum_abs_op { } }; +struct float_to_half { + __host__ __device__ __half operator()(const float x) const { return __float2half(x); } +}; + template ::std::ostream& operator<<(::std::ostream& os, const SDDMMInputs& params) { @@ -72,11 +80,12 @@ template template -class SDDMMTest : public ::testing::TestWithParam> { + typename LayoutPolicyB = raft::layout_c_contiguous, + typename OutputType = ValueType> +class SDDMMTest : public ::testing::TestWithParam> { public: SDDMMTest() - : params(::testing::TestWithParam>::GetParam()), + : params(::testing::TestWithParam>::GetParam()), stream(resource::get_cuda_stream(handle)), a_data_d(0, resource::get_cuda_stream(handle)), b_data_d(0, resource::get_cuda_stream(handle)), @@ -88,9 +97,25 @@ class SDDMMTest : public ::testing::TestWithParam 12) || (major == 12 && minor > 0) || (major == 12 && minor == 0 && patch >= 2); + } + IndexType create_sparse_matrix(IndexType m, IndexType n, - ValueType sparsity, + OutputType sparsity, std::vector& matrix) { IndexType total_elements = static_cast(m * n); @@ -119,7 +144,7 @@ class SDDMMTest : public ::testing::TestWithParam& matrix, IndexType rows, IndexType cols, - std::vector& values, + std::vector& values, std::vector& indices, std::vector& indptr) { @@ -130,7 +155,7 @@ class SDDMMTest : public ::testing::TestWithParam(1.0); + values[offset_values] = static_cast(1.0); indices[offset_values] = static_cast(j); offset_values++; } @@ -141,7 +166,7 @@ class SDDMMTest : public ::testing::TestWithParam& A, 
const std::vector& B, - std::vector& vals, + std::vector& vals, const std::vector& cols, const std::vector& row_ptrs, bool is_row_major_A, @@ -158,11 +183,15 @@ class SDDMMTest : public ::testing::TestWithParam && std::is_same_v)) { + sum += __half2float(A[a_index]) * __half2float(B[b_index]); + } else { + sum += A[a_index] * B[b_index]; + } } vals[j] = params.alpha * sum + params.beta * vals[j]; } @@ -181,29 +210,53 @@ class SDDMMTest : public ::testing::TestWithParam(handle, 1, a_size + b_size); + auto blobs_a_b = raft::make_device_matrix(handle, 1, a_size + b_size); auto labels = raft::make_device_vector(handle, 1); - raft::random::make_blobs(blobs_a_b.data_handle(), - labels.data_handle(), - 1, - a_size + b_size, - 1, - stream, - false, - nullptr, - nullptr, - ValueType(1.0), - false, - ValueType(-1.0f), - ValueType(1.0f), - uint64_t(2024)); - - raft::copy(a_data_h.data(), blobs_a_b.data_handle(), a_size, stream); - raft::copy(b_data_h.data(), blobs_a_b.data_handle() + a_size, b_size, stream); - - raft::copy(a_data_d.data(), blobs_a_b.data_handle(), a_size, stream); - raft::copy(b_data_d.data(), blobs_a_b.data_handle() + a_size, b_size, stream); + raft::random::make_blobs(blobs_a_b.data_handle(), + labels.data_handle(), + 1, + a_size + b_size, + 1, + stream, + false, + nullptr, + nullptr, + OutputType(1.0), + false, + OutputType(-1.0f), + OutputType(1.0f), + uint64_t(2024)); + if constexpr ((std::is_same_v && std::is_same_v)) { + { + thrust::device_ptr d_output_ptr = + thrust::device_pointer_cast(blobs_a_b.data_handle()); + thrust::device_ptr d_value_ptr = thrust::device_pointer_cast(a_data_d.data()); + thrust::transform(thrust::cuda::par.on(stream), + d_output_ptr, + d_output_ptr + a_size, + d_value_ptr, + float_to_half()); + } + { + thrust::device_ptr d_output_ptr = + thrust::device_pointer_cast(blobs_a_b.data_handle() + a_size); + thrust::device_ptr d_value_ptr = thrust::device_pointer_cast(b_data_d.data()); + thrust::transform(thrust::cuda::par.on(stream), + d_output_ptr, + d_output_ptr + b_size, + d_value_ptr, + float_to_half()); + } + raft::copy(a_data_h.data(), a_data_d.data(), a_size, stream); + raft::copy(b_data_h.data(), b_data_d.data(), b_size, stream); + } else { + raft::copy(a_data_h.data(), blobs_a_b.data_handle(), a_size, stream); + raft::copy(b_data_h.data(), blobs_a_b.data_handle() + a_size, b_size, stream); + + raft::copy(a_data_d.data(), blobs_a_b.data_handle(), a_size, stream); + raft::copy(b_data_d.data(), blobs_a_b.data_handle() + a_size, b_size, stream); + } resource::sync_stream(handle); @@ -213,7 +266,7 @@ class SDDMMTest : public ::testing::TestWithParam c_indptr_h(params.m + 1); std::vector c_indices_h(c_true_nnz); - std::vector c_data_h(c_true_nnz); + std::vector c_data_h(c_true_nnz); convert_to_csr(c_dense_data_h, params.m, params.n, c_data_h, c_indices_h, c_indptr_h); @@ -238,7 +291,13 @@ class SDDMMTest : public ::testing::TestWithParam && !isCuSparseVersionGreaterThan_12_0_1()) { + GTEST_SKIP() << "Skipping all tests for half-float as cuSparse doesn't support it."; + } + make_data(); + } void Run() { @@ -258,7 +317,7 @@ class SDDMMTest : public ::testing::TestWithParam(c_indices_d.size())); - auto c = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); + auto c = raft::make_device_csr_matrix_view(c_data_d.data(), c_structure); auto op_a = params.transpose_a ? 
raft::linalg::Operation::TRANSPOSE : raft::linalg::Operation::NON_TRANSPOSE; @@ -271,41 +330,41 @@ class SDDMMTest : public ::testing::TestWithParam(¶ms.alpha), - raft::make_host_scalar_view(¶ms.beta)); + raft::make_host_scalar_view(¶ms.alpha), + raft::make_host_scalar_view(¶ms.beta)); resource::sync_stream(handle); - ASSERT_TRUE(raft::devArrMatch(c_expected_data_d.data(), - c.get_elements().data(), - c_expected_data_d.size(), - raft::CompareApprox(params.tolerance), - stream)); + ASSERT_TRUE(raft::devArrMatch(c_expected_data_d.data(), + c.get_elements().data(), + c_expected_data_d.size(), + raft::CompareApprox(params.tolerance), + stream)); - thrust::device_ptr expected_data_ptr = + thrust::device_ptr expected_data_ptr = thrust::device_pointer_cast(c_expected_data_d.data()); - ValueType sum_abs = thrust::reduce(thrust::cuda::par.on(stream), - expected_data_ptr, - expected_data_ptr + c_expected_data_d.size(), - ValueType(0.0f), - sum_abs_op()); - ValueType avg = sum_abs / (1.0f * c_expected_data_d.size()); - - ASSERT_GE(avg, (params.tolerance * static_cast(0.001f))); + OutputType sum_abs = thrust::reduce(thrust::cuda::par.on(stream), + expected_data_ptr, + expected_data_ptr + c_expected_data_d.size(), + OutputType(0.0f), + sum_abs_op()); + OutputType avg = sum_abs / (1.0f * c_expected_data_d.size()); + + ASSERT_GE(avg, (params.tolerance * static_cast(0.001f))); } raft::resources handle; cudaStream_t stream; - SDDMMInputs params; + SDDMMInputs params; rmm::device_uvector a_data_d; rmm::device_uvector b_data_d; rmm::device_uvector c_indptr_d; rmm::device_uvector c_indices_d; - rmm::device_uvector c_data_d; + rmm::device_uvector c_data_d; - rmm::device_uvector c_expected_data_d; + rmm::device_uvector c_expected_data_d; }; using SDDMMTestF_Row_Col = SDDMMTest; @@ -332,6 +391,18 @@ TEST_P(SDDMMTestD_Row_Row, Result) { Run(); } using SDDMMTestD_Col_Col = SDDMMTest; TEST_P(SDDMMTestD_Col_Col, Result) { Run(); } +using SDDMMTestHF_Row_Col = SDDMMTest; +TEST_P(SDDMMTestHF_Row_Col, Result) { Run(); } + +using SDDMMTestHF_Col_Row = SDDMMTest; +TEST_P(SDDMMTestHF_Col_Row, Result) { Run(); } + +using SDDMMTestHF_Row_Row = SDDMMTest; +TEST_P(SDDMMTestHF_Row_Row, Result) { Run(); } + +using SDDMMTestHF_Col_Col = SDDMMTest; +TEST_P(SDDMMTestHF_Col_Col, Result) { Run(); } + const std::vector> sddmm_inputs_f = { {0.0001f, 10, 5, 32, 1.0, 0.0, false, false, 0.01, 1234ULL}, {0.0001f, 1024, 32, 1024, 0.3, 0.0, true, false, 0.1, 1234ULL}, @@ -352,6 +423,16 @@ const std::vector> sddmm_inputs_d = { {0.0001f, 32, 1024, 1024, 2.0, 0.2, false, true, 0.19, 1234ULL}, {0.0001f, 1024, 1024, 1024, 0.0, 1.2, true, true, 0.1, 1234ULL}}; +const std::vector> sddmm_inputs_h_f = { + {0.0001f, 10, 5, 32, 1.0, 0.0, false, false, 0.01, 1234ULL}, + {0.0001f, 1024, 32, 1024, 0.3, 0.0, true, false, 0.1, 1234ULL}, + {0.0003f, 32, 1024, 1024, 1.0, 0.3, false, true, 0.2, 1234ULL}, + {0.001f, 1024, 1024, 1024, 0.2, 0.2, true, true, 0.19, 1234ULL}, + {0.0001f, 1024, 1024, 32, 0.1, 0.2, false, false, 0.3, 1234ULL}, + {0.0001f, 1024, 32, 1024, 1.0, 0.3, true, false, 0.4, 1234ULL}, + {0.0003f, 32, 1024, 1024, 2.0, 0.2, false, true, 0.19, 1234ULL}, + {0.001f, 1024, 1024, 1024, 0.0, 1.2, true, true, 0.1, 1234ULL}}; + INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Row_Col, ::testing::ValuesIn(sddmm_inputs_f)); INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Col_Row, ::testing::ValuesIn(sddmm_inputs_f)); INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestF_Row_Row, ::testing::ValuesIn(sddmm_inputs_f)); @@ -362,5 +443,10 @@ 
INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Col_Row, ::testing::ValuesIn(sddmm INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Row_Row, ::testing::ValuesIn(sddmm_inputs_d)); INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestD_Col_Col, ::testing::ValuesIn(sddmm_inputs_d)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestHF_Row_Col, ::testing::ValuesIn(sddmm_inputs_h_f)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestHF_Col_Row, ::testing::ValuesIn(sddmm_inputs_h_f)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestHF_Row_Row, ::testing::ValuesIn(sddmm_inputs_h_f)); +INSTANTIATE_TEST_CASE_P(SDDMMTest, SDDMMTestHF_Col_Col, ::testing::ValuesIn(sddmm_inputs_h_f)); + } // namespace sparse } // namespace raft
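Note for reviewers: the new fp16 suites above all share one data-preparation pattern. fp32 inputs are generated with make_blobs, converted on device to __half via thrust::transform with a float_to_half functor, and results are validated back in fp32. The snippet below is a minimal, self-contained sketch of that conversion step; the wrapper name to_half is illustrative only and is not part of this patch, and it must be compiled as CUDA (.cu) code.

// Minimal sketch (assumed helper, not part of the patch): convert an fp32
// device buffer to __half on the stream owned by a raft::resources handle,
// mirroring the float_to_half + thrust::transform pattern used by the
// masked_matmul and sddmm tests above.
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

#include <rmm/device_uvector.hpp>

#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/transform.h>

#include <cuda_fp16.h>

#include <cstddef>

namespace {

// Same element-wise conversion functor the tests define.
struct float_to_half {
  __host__ __device__ __half operator()(const float x) const { return __float2half(x); }
};

// Returns a new device buffer holding the __half conversion of src_f32[0..n).
inline rmm::device_uvector<__half> to_half(raft::resources const& res,
                                           const float* src_f32,
                                           std::size_t n)
{
  auto stream = raft::resource::get_cuda_stream(res);
  rmm::device_uvector<__half> dst(n, stream);
  thrust::transform(thrust::cuda::par.on(stream),
                    thrust::device_pointer_cast(src_f32),
                    thrust::device_pointer_cast(src_f32) + n,
                    thrust::device_pointer_cast(dst.data()),
                    float_to_half());
  return dst;
}

}  // namespace

Validation then goes the other way: results are widened back to float (see validate_half in transpose.cu and the CompareApprox<float> comparisons in the sparse tests), so tolerances stay in fp32 regardless of the input precision.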