diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 95361e19ca..9f23c44a5c 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -131,7 +131,14 @@ if(BUILD_PRIMS_BENCH) bench/prims/random/rng.cu bench/prims/random/subsample.cu bench/prims/main.cpp ) - ConfigureBench(NAME SPARSE_BENCH PATH bench/prims/sparse/convert_csr.cu bench/prims/main.cpp) + ConfigureBench( + NAME + SPARSE_BENCH + PATH + bench/prims/sparse/bitmap_to_csr.cu + bench/prims/sparse/convert_csr.cu + bench/prims/main.cpp + ) ConfigureBench( NAME diff --git a/cpp/bench/prims/sparse/bitmap_to_csr.cu b/cpp/bench/prims/sparse/bitmap_to_csr.cu new file mode 100644 index 0000000000..ed53df3265 --- /dev/null +++ b/cpp/bench/prims/sparse/bitmap_to_csr.cu @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace raft::bench::sparse { + +template +struct bench_param { + index_t n_rows; + index_t n_cols; + float sparsity; +}; + +template +inline auto operator<<(std::ostream& os, const bench_param& params) -> std::ostream& +{ + os << " rows*cols=" << params.n_rows << "*" << params.n_cols << "\tsparsity=" << params.sparsity; + return os; +} + +template +struct BitmapToCsrBench : public fixture { + BitmapToCsrBench(const bench_param& p) + : fixture(true), + params(p), + handle(stream), + bitmap_d(0, stream), + nnz(0), + indptr_d(0, stream), + indices_d(0, stream), + values_d(0, stream) + { + index_t element = raft::ceildiv(params.n_rows * params.n_cols, index_t(sizeof(bitmap_t) * 8)); + std::vector bitmap_h(element); + nnz = create_sparse_matrix(params.n_rows, params.n_cols, params.sparsity, bitmap_h); + + bitmap_d.resize(bitmap_h.size(), stream); + indptr_d.resize(params.n_rows + 1, stream); + indices_d.resize(nnz, stream); + values_d.resize(nnz, stream); + + update_device(bitmap_d.data(), bitmap_h.data(), bitmap_h.size(), stream); + + resource::sync_stream(handle); + } + + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitmap) + { + index_t total = static_cast(m * n); + index_t num_ones = static_cast((total * 1.0f) * sparsity); + index_t res = num_ones; + + for (auto& item : bitmap) { + item = static_cast(0); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, total - 1); + + while (num_ones > 0) { + index_t index = dis(gen); + + bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))]; + index_t bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1) == 0) { + element |= (static_cast(1) << bit_position); + num_ones--; + } + } + return res; + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + auto bitmap = + raft::core::bitmap_view(bitmap_d.data(), params.n_rows, params.n_cols); + 
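// (Editor's note: the figures below are an added illustration, not part of the original diff,
// and assume bitmap_t = uint32_t.) For the largest case registered in this benchmark,
// n_rows = 1024, n_cols = 1024 * 1024 and sparsity = 0.01, create_sparse_matrix() sets
// nnz = floor(1024 * 1048576 * 0.01) = 10,737,418 bits, while the bitmap itself occupies
// ceildiv(1024 * 1048576, 32) = 33,554,432 words, i.e. 128 MiB of device memory.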
+ auto csr_view = raft::make_device_compressed_structure_view( + indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz); + auto csr = raft::make_device_csr_matrix(handle, csr_view); + + raft::sparse::convert::bitmap_to_csr(handle, bitmap, csr); + + resource::sync_stream(handle); + loop_on_state(state, [this, &bitmap, &csr]() { + raft::sparse::convert::bitmap_to_csr(handle, bitmap, csr); + }); + } + + protected: + const raft::device_resources handle; + + bench_param params; + + rmm::device_uvector bitmap_d; + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector values_d; + + index_t nnz; +}; // struct BitmapToCsrBench + +template +const std::vector> getInputs() +{ + std::vector> param_vec; + struct TestParams { + index_t m; + index_t n; + float sparsity; + }; + + const std::vector params_group = raft::util::itertools::product( + {index_t(10), index_t(1024)}, {index_t(1024 * 1024)}, {0.01f, 0.1f, 0.2f, 0.5f}); + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.sparsity})); + } + return param_vec; +} + +RAFT_BENCH_REGISTER((BitmapToCsrBench), "", getInputs()); +RAFT_BENCH_REGISTER((BitmapToCsrBench), "", getInputs()); + +} // namespace raft::bench::sparse diff --git a/cpp/include/raft/core/bitmap.cuh b/cpp/include/raft/core/bitmap.cuh new file mode 100644 index 0000000000..829c84ed25 --- /dev/null +++ b/cpp/include/raft/core/bitmap.cuh @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::core { +/** + * @defgroup bitmap Bitmap + * @{ + */ +/** + * @brief View of a RAFT Bitmap. + * + * This lightweight structure represents and manipulates a two-dimensional bitmap matrix view + * in row-major order. This class provides functionality for handling a matrix where each element + * is represented as a bit in a bitmap. + * + * @tparam bitmap_t Underlying type of the bitmap array. Default is uint32_t. + * @tparam index_t Indexing type used. Default is uint32_t. + */ +template +struct bitmap_view : public bitset_view { + static_assert((std::is_same::value || + std::is_same::value), + "The bitmap_t must be uint32_t or uint64_t."); + /** + * @brief Create a bitmap view from a device raw pointer. + * + * @param bitmap_ptr Device raw pointer + * @param rows Number of rows in the matrix. + * @param cols Number of columns in the matrix. + */ + _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t rows, index_t cols) + : bitset_view(bitmap_ptr, rows * cols), rows_(rows), cols_(cols) + { + } + + /** + * @brief Create a bitmap view from a device vector view of the bitmap. + * + * @param bitmap_span Device vector view of the bitmap + * @param rows Number of rows in the matrix. + * @param cols Number of columns in the matrix.
+ */ + _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, + index_t rows, + index_t cols) + : bitset_view(bitmap_span, rows * cols), rows_(rows), cols_(cols) + { + } + + private: + // Hide the constructors of bitset_view. + _RAFT_HOST_DEVICE bitmap_view(bitmap_t* bitmap_ptr, index_t bitmap_len) + : bitset_view(bitmap_ptr, bitmap_len) + { + } + + _RAFT_HOST_DEVICE bitmap_view(raft::device_vector_view bitmap_span, + index_t bitmap_len) + : bitset_view(bitmap_span, bitmap_len) + { + } + + public: + /** + * @brief Device function to test if the bit at a given row and col is set in the bitmap. + * + * @param row Row index of the bit to test + * @param col Col index of the bit to test + * @return bool True if the bit at (row, col) is set in the bitmap + */ + inline _RAFT_DEVICE auto test(const index_t row, const index_t col) const -> bool + { + return test(row * cols_ + col); + } + + /** + * @brief Device function to set the bit at a given row and col to new_value in the bitmap. + * + * @param row Row index of the bit to set + * @param col Col index of the bit to set + * @param new_value Value to set the bit to (true or false) + */ + inline _RAFT_DEVICE void set(const index_t row, const index_t col, bool new_value) const + { + set(row * cols_ + col, new_value); + } + + /** + * @brief Get the total number of rows + * @return index_t The total number of rows + */ + inline _RAFT_HOST_DEVICE index_t get_n_rows() const { return rows_; } + + /** + * @brief Get the total number of columns + * @return index_t The total number of columns + */ + inline _RAFT_HOST_DEVICE index_t get_n_cols() const { return cols_; } + + private: + index_t rows_; + index_t cols_; +}; + +/** @} */ +} // end namespace raft::core diff --git a/cpp/include/raft/sparse/convert/csr.cuh b/cpp/include/raft/sparse/convert/csr.cuh index 999e64cb0b..081192ed44 100644 --- a/cpp/include/raft/sparse/convert/csr.cuh +++ b/cpp/include/raft/sparse/convert/csr.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,10 @@ #pragma once +#include +#include #include +#include #include #include @@ -102,6 +105,30 @@ void adj_to_csr(raft::resources const& handle, detail::adj_to_csr(handle, adj, row_ind, num_rows, num_cols, tmp, out_col_ind); } +/** + * @brief Converts a bitmap matrix to a Compressed Sparse Row (CSR) format matrix. + * + * @tparam bitmap_t The data type of the elements in the bitmap matrix. + * @tparam index_t The data type used for indexing the elements in the matrices. + * @tparam csr_matrix_t Specifies the CSR matrix type, constrained to + * raft::device_csr_matrix. + * + * @param[in] handle The RAFT handle containing the CUDA stream for operations. + * @param[in] bitmap The bitmap matrix view to be converted to CSR format. + * @param[out] csr Output parameter where the resulting CSR matrix is stored. In the + * bitmap, each '1' bit corresponds to a non-zero element in the CSR matrix.
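 * (Editor's illustration, not part of the original diff: a minimal usage sketch. It assumes an
 * owning CSR matrix created with a precomputed non-zero count `nnz`, mirroring the accompanying
 * test; the exact template arguments are an assumption.)
 *
 *   auto bitmap = raft::core::bitmap_view(bitmap_ptr, n_rows, n_cols);
 *   auto csr    = raft::make_device_csr_matrix<float, index_t>(handle, n_rows, n_cols, nnz);
 *   raft::sparse::convert::bitmap_to_csr(handle, bitmap, csr);
 *   // csr now stores one explicit 1.0f per set bit, addressed via (indptr, indices).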
+ */ +template >> +void bitmap_to_csr(raft::resources const& handle, + raft::core::bitmap_view bitmap, + csr_matrix_t& csr) +{ + detail::bitmap_to_csr(handle, bitmap, csr); +} + }; // end NAMESPACE convert }; // end NAMESPACE sparse }; // end NAMESPACE raft diff --git a/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh new file mode 100644 index 0000000000..b0315486ff --- /dev/null +++ b/cpp/include/raft/sparse/convert/detail/bitmap_to_csr.cuh @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // detail::popc +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace cg = cooperative_groups; + +namespace raft { +namespace sparse { +namespace convert { +namespace detail { + +// Threads per block in calc_nnz_by_rows_kernel. +static const constexpr int calc_nnz_by_rows_tpb = 32; + +template +RAFT_KERNEL __launch_bounds__(calc_nnz_by_rows_tpb) calc_nnz_by_rows_kernel(const bitmap_t* bitmap, + index_t num_rows, + index_t num_cols, + index_t bitmap_num, + nnz_t* nnz_per_row) +{ + constexpr bitmap_t FULL_MASK = ~bitmap_t(0u); + constexpr bitmap_t ONE = bitmap_t(1u); + constexpr index_t BITS_PER_BITMAP = sizeof(bitmap_t) * 8; + + auto block = cg::this_thread_block(); + auto tile = cg::tiled_partition<32>(block); + + int lane_id = threadIdx.x & 0x1f; + + for (index_t row = blockIdx.x; row < num_rows; row += gridDim.x) { + index_t offset = 0; + index_t s_bit = row * num_cols; + index_t e_bit = s_bit + num_cols; + index_t l_sum = 0; + + while (offset < num_cols) { + index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; + bitmap_t l_bitmap = bitmap_t(0); + + if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } + + if (s_bit > bitmap_idx * BITS_PER_BITMAP) { + l_bitmap >>= (s_bit - bitmap_idx * BITS_PER_BITMAP); + l_bitmap <<= (s_bit - bitmap_idx * BITS_PER_BITMAP); + } + + if ((bitmap_idx + 1) * BITS_PER_BITMAP > e_bit) { + l_bitmap <<= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); + l_bitmap >>= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); + } + + l_sum += static_cast(raft::detail::popc(l_bitmap)); + offset += BITS_PER_BITMAP * warpSize; + } + + l_sum = cg::reduce(tile, l_sum, cg::plus()); + + if (lane_id == 0) { *(nnz_per_row + row) += static_cast(l_sum); } + } +} + +template +void calc_nnz_by_rows(raft::resources const& handle, + const bitmap_t* bitmap, + index_t num_rows, + index_t num_cols, + nnz_t* nnz_per_row) +{ + auto stream = resource::get_cuda_stream(handle); + const index_t total = num_rows * num_cols; + const index_t bitmap_num = raft::ceildiv(total, index_t(sizeof(bitmap_t) * 8)); + + int dev_id, sm_count, blocks_per_sm; + + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, 
calc_nnz_by_rows_kernel, calc_nnz_by_rows_tpb, 0); + + index_t max_active_blocks = sm_count * blocks_per_sm; + auto grid = std::min(max_active_blocks, raft::ceildiv(bitmap_num, index_t(calc_nnz_by_rows_tpb))); + auto block = calc_nnz_by_rows_tpb; + + calc_nnz_by_rows_kernel + <<>>(bitmap, num_rows, num_cols, bitmap_num, nnz_per_row); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +/* + Execute the exclusive_scan within one warp with no inter-warp communication. + This function calculates the exclusive prefix sum of `value` across threads within the same warp. + Each thread in the warp will end up with the sum of all the values of the threads with lower IDs + in the same warp, with the first thread always getting a sum of 0. +*/ +template +RAFT_DEVICE_INLINE_FUNCTION value_t warp_exclusive_scan(value_t value) +{ + int lane_id = threadIdx.x & 0x1f; + value_t shifted_value = __shfl_up_sync(0xffffffff, value, 1, warpSize); + if (lane_id == 0) shifted_value = 0; + + value_t sum = shifted_value; + + for (int i = 1; i < warpSize; i *= 2) { + value_t n = __shfl_up_sync(0xffffffff, sum, i, warpSize); + if (lane_id >= i) { sum += n; } + } + return sum; +} + +// Threads per block in fill_indices_by_rows_kernel. +static const constexpr int fill_indices_by_rows_tpb = 32; + +template +RAFT_KERNEL __launch_bounds__(fill_indices_by_rows_tpb) + fill_indices_by_rows_kernel(const bitmap_t* bitmap, + const index_t* indptr, + index_t num_rows, + index_t num_cols, + nnz_t nnz, + index_t bitmap_num, + index_t* indices) +{ + constexpr bitmap_t FULL_MASK = ~bitmap_t(0u); + constexpr bitmap_t ONE = bitmap_t(1u); + constexpr index_t BITS_PER_BITMAP = sizeof(bitmap_t) * 8; + + int lane_id = threadIdx.x & 0x1f; + + // Ensure the HBM allocated for CSR values is sufficient to handle all non-zero bitmap bits. + // An assert will trigger if the allocated HBM is insufficient when `NDEBUG` isn't defined. + // Note: Assertion is active only if `NDEBUG` is undefined. 
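// (Editor's worked example, not part of the original diff.) The shift pairs used in this kernel
// and in calc_nnz_by_rows_kernel clear bits that belong to neighbouring rows before popc/scan.
// E.g. with 32-bit words, num_cols = 20 and row = 1 (s_bit = 20, e_bit = 40): word 0 is shifted
// right then left by 20 to drop row 0's bits, and word 1 is shifted left then right by
// 2 * 32 - 40 = 24 to drop the bits past the row's end.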
+ if constexpr (check_nnz) { + if (lane_id == 0) { assert(nnz < indptr[num_rows]); } + } + +#pragma unroll + for (index_t row = blockIdx.x; row < num_rows; row += gridDim.x) { + index_t g_sum = 0; + index_t s_bit = row * num_cols; + index_t e_bit = s_bit + num_cols; + index_t indptr_row = indptr[row]; + +#pragma unroll + for (index_t offset = 0; offset < num_cols; offset += BITS_PER_BITMAP * warpSize) { + index_t bitmap_idx = lane_id + (s_bit + offset) / BITS_PER_BITMAP; + bitmap_t l_bitmap = bitmap_t(0); + index_t l_offset = offset + lane_id * BITS_PER_BITMAP - (s_bit % BITS_PER_BITMAP); + + if (bitmap_idx * BITS_PER_BITMAP < e_bit) { l_bitmap = bitmap[bitmap_idx]; } + + if (s_bit > bitmap_idx * BITS_PER_BITMAP) { + l_bitmap >>= (s_bit - bitmap_idx * BITS_PER_BITMAP); + l_bitmap <<= (s_bit - bitmap_idx * BITS_PER_BITMAP); + } + + if ((bitmap_idx + 1) * BITS_PER_BITMAP > e_bit) { + l_bitmap <<= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); + l_bitmap >>= ((bitmap_idx + 1) * BITS_PER_BITMAP - e_bit); + } + + index_t l_sum = + g_sum + warp_exclusive_scan(static_cast(raft::detail::popc(l_bitmap))); + + for (int i = 0; i < BITS_PER_BITMAP; i++) { + if (l_bitmap & (ONE << i)) { + indices[indptr_row + l_sum] = l_offset + i; + l_sum++; + } + } + g_sum = __shfl_sync(0xffffffff, l_sum, warpSize - 1); + } + } +} + +template +void fill_indices_by_rows(raft::resources const& handle, + const bitmap_t* bitmap, + const index_t* indptr, + index_t num_rows, + index_t num_cols, + nnz_t nnz, + index_t* indices) +{ + auto stream = resource::get_cuda_stream(handle); + const index_t total = num_rows * num_cols; + const index_t bitmap_num = raft::ceildiv(total, index_t(sizeof(bitmap_t) * 8)); + + int dev_id, sm_count, blocks_per_sm; + + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, + fill_indices_by_rows_kernel, + fill_indices_by_rows_tpb, + 0); + + index_t max_active_blocks = sm_count * blocks_per_sm; + auto grid = std::min(max_active_blocks, num_rows); + auto block = fill_indices_by_rows_tpb; + + fill_indices_by_rows_kernel + <<>>(bitmap, indptr, num_rows, num_cols, nnz, bitmap_num, indices); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template >> +void bitmap_to_csr(raft::resources const& handle, + raft::core::bitmap_view bitmap, + csr_matrix_t& csr) +{ + auto csr_view = csr.structure_view(); + + if (csr_view.get_n_rows() == 0 || csr_view.get_n_cols() == 0 || csr_view.get_nnz() == 0) { + return; + } + + RAFT_EXPECTS(bitmap.get_n_rows() == csr_view.get_n_rows(), + "Number of rows in bitmap must be equal to " + "number of rows in csr"); + + RAFT_EXPECTS(bitmap.get_n_cols() == csr_view.get_n_cols(), + "Number of columns in bitmap must be equal to " + "number of columns in csr"); + + auto thrust_policy = resource::get_thrust_policy(handle); + auto stream = resource::get_cuda_stream(handle); + + index_t* indptr = csr_view.get_indptr().data(); + index_t* indices = csr_view.get_indices().data(); + + RAFT_CUDA_TRY(cudaMemsetAsync(indptr, 0, (csr_view.get_n_rows() + 1) * sizeof(index_t), stream)); + + calc_nnz_by_rows(handle, bitmap.data(), csr_view.get_n_rows(), csr_view.get_n_cols(), indptr); + thrust::exclusive_scan(thrust_policy, indptr, indptr + csr_view.get_n_rows() + 1, indptr); + + if constexpr (is_device_csr_sparsity_owning_v) { + index_t nnz = 0; + RAFT_CUDA_TRY(cudaMemcpyAsync( + &nnz, indptr + csr_view.get_n_rows(), sizeof(index_t), cudaMemcpyDeviceToHost, stream)); + 
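// (Editor's worked example, not part of the original diff.) For a 2 x 3 bitmap whose logical
// rows are 101b (cols 0 and 2 set) and 010b (col 1 set): calc_nnz_by_rows leaves
// indptr = {2, 1, 0}, the exclusive_scan above turns it into {0, 2, 3} (so nnz = 3),
// fill_indices_by_rows writes indices = {0, 2, 1}, and the final fill_n sets all three
// values to 1.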
resource::sync_stream(handle); + csr.initialize_sparsity(nnz); + } + constexpr bool check_nnz = is_device_csr_sparsity_preserving_v; + fill_indices_by_rows( + handle, + bitmap.data(), + indptr, + csr_view.get_n_rows(), + csr_view.get_n_cols(), + csr_view.get_nnz(), + indices); + + thrust::fill_n(thrust_policy, + csr.get_elements().data(), + csr_view.get_nnz(), + typename csr_matrix_t::element_type(1)); +} + +}; // end NAMESPACE detail +}; // end NAMESPACE convert +}; // end NAMESPACE sparse +}; // end NAMESPACE raft diff --git a/cpp/test/sparse/convert_csr.cu b/cpp/test/sparse/convert_csr.cu index 4af792a9ea..1cd49b0bbd 100644 --- a/cpp/test/sparse/convert_csr.cu +++ b/cpp/test/sparse/convert_csr.cu @@ -16,6 +16,7 @@ #include "../test_utils.cuh" +#include #include #include #include @@ -218,5 +219,247 @@ INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, CSRAdjGraphTestL, ::testing::ValuesIn(csradjgraph_inputs_l)); +/******************************** bitmap to csr ********************************/ + +template +struct BitmapToCSRInputs { + index_t n_rows; + index_t n_cols; + float sparsity; + bool owning; +}; + +template +class BitmapToCSRTest : public ::testing::TestWithParam> { + public: + BitmapToCSRTest() + : stream(resource::get_cuda_stream(handle)), + params(::testing::TestWithParam>::GetParam()), + bitmap_d(0, stream), + indices_d(0, stream), + indptr_d(0, stream), + values_d(0, stream), + indptr_expected_d(0, stream), + indices_expected_d(0, stream), + values_expected_d(0, stream) + { + } + + protected: + index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector& bitmap) + { + index_t total = static_cast(m * n); + index_t num_ones = static_cast((total * 1.0f) * sparsity); + index_t res = num_ones; + + for (auto& item : bitmap) { + item = static_cast(0); + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, total - 1); + + while (num_ones > 0) { + index_t index = dis(gen); + + bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))]; + index_t bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1) == 0) { + element |= (static_cast(1) << bit_position); + num_ones--; + } + } + return res; + } + + void cpu_convert_to_csr(std::vector& bitmap, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + index_t index = 0; + bitmap_t element = 0; + index_t bit_position = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + index = i * cols + j; + element = bitmap[index / (8 * sizeof(bitmap_t))]; + bit_position = index % (8 * sizeof(bitmap_t)); + + if (((element >> bit_position) & 1)) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + bool csr_compare(const std::vector& row_ptrs1, + const std::vector& col_indices1, + const std::vector& row_ptrs2, + const std::vector& col_indices2) + { + if (row_ptrs1.size() != row_ptrs2.size()) { return false; } + + if (col_indices1.size() != col_indices2.size()) { return false; } + + if (!std::equal(row_ptrs1.begin(), row_ptrs1.end(), row_ptrs2.begin())) { return false; } + + for (size_t i = 0; i < row_ptrs1.size() - 1; ++i) { + size_t start_idx = row_ptrs1[i]; + size_t end_idx = row_ptrs1[i + 1]; + + std::vector cols1(col_indices1.begin() + start_idx, col_indices1.begin() + end_idx); + std::vector cols2(col_indices2.begin() + 
start_idx, col_indices2.begin() + end_idx); + + std::sort(cols1.begin(), cols1.end()); + std::sort(cols2.begin(), cols2.end()); + + if (cols1 != cols2) { return false; } + } + + return true; + } + + void SetUp() override + { + index_t element = raft::ceildiv(params.n_rows * params.n_cols, index_t(sizeof(bitmap_t) * 8)); + std::vector bitmap_h(element); + nnz = create_sparse_matrix(params.n_rows, params.n_cols, params.sparsity, bitmap_h); + + std::vector indices_h(nnz); + std::vector indptr_h(params.n_rows + 1); + + cpu_convert_to_csr(bitmap_h, params.n_rows, params.n_cols, indices_h, indptr_h); + + bitmap_d.resize(bitmap_h.size(), stream); + indptr_d.resize(params.n_rows + 1, stream); + indices_d.resize(nnz, stream); + + indptr_expected_d.resize(params.n_rows + 1, stream); + indices_expected_d.resize(nnz, stream); + values_expected_d.resize(nnz, stream); + + thrust::fill_n(resource::get_thrust_policy(handle), values_expected_d.data(), nnz, value_t{1}); + + values_d.resize(nnz, stream); + + update_device(indices_expected_d.data(), indices_h.data(), indices_h.size(), stream); + update_device(indptr_expected_d.data(), indptr_h.data(), indptr_h.size(), stream); + update_device(bitmap_d.data(), bitmap_h.data(), bitmap_h.size(), stream); + + resource::sync_stream(handle); + } + + void Run() + { + auto bitmap = + raft::core::bitmap_view(bitmap_d.data(), params.n_rows, params.n_cols); + + if (params.owning) { + auto csr = + raft::make_device_csr_matrix(handle, params.n_rows, params.n_cols, nnz); + auto csr_view = csr.structure_view(); + + convert::bitmap_to_csr(handle, bitmap, csr); + raft::copy(indptr_d.data(), csr_view.get_indptr().data(), indptr_d.size(), stream); + raft::copy(indices_d.data(), csr_view.get_indices().data(), indices_d.size(), stream); + raft::copy(values_d.data(), csr.get_elements().data(), nnz, stream); + } else { + auto csr_view = raft::make_device_compressed_structure_view( + indptr_d.data(), indices_d.data(), params.n_rows, params.n_cols, nnz); + auto csr = raft::make_device_csr_matrix(handle, csr_view); + + convert::bitmap_to_csr(handle, bitmap, csr); + raft::copy(values_d.data(), csr.get_elements().data(), nnz, stream); + } + resource::sync_stream(handle); + + std::vector indices_h(indices_expected_d.size(), 0); + std::vector indices_expected_h(indices_expected_d.size(), 0); + update_host(indices_h.data(), indices_d.data(), indices_h.size(), stream); + update_host(indices_expected_h.data(), indices_expected_d.data(), indices_h.size(), stream); + + std::vector indptr_h(indptr_expected_d.size(), 0); + std::vector indptr_expected_h(indptr_expected_d.size(), 0); + update_host(indptr_h.data(), indptr_d.data(), indptr_h.size(), stream); + update_host(indptr_expected_h.data(), indptr_expected_d.data(), indptr_h.size(), stream); + + resource::sync_stream(handle); + + ASSERT_TRUE(csr_compare(indptr_h, indices_h, indptr_expected_h, indices_expected_h)); + ASSERT_TRUE(raft::devArrMatch( + values_expected_d.data(), values_d.data(), nnz, raft::Compare(), stream)); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + BitmapToCSRInputs params; + + rmm::device_uvector bitmap_d; + + index_t nnz; + + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector values_d; + + rmm::device_uvector indptr_expected_d; + rmm::device_uvector indices_expected_d; + rmm::device_uvector values_expected_d; +}; + +using BitmapToCSRTestI = BitmapToCSRTest; +TEST_P(BitmapToCSRTestI, Result) { Run(); } + +using BitmapToCSRTestL = BitmapToCSRTest; 
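// (Editor's note, an added illustration rather than part of the original diff.) csr_compare()
// above deliberately ignores the order of column indices within a row: with
// row_ptrs = {0, 2, 3}, col_indices {2, 0, 1} and {0, 2, 1} compare equal, because each row's
// slice is sorted before comparison. Only the per-row sets of columns must match.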
+TEST_P(BitmapToCSRTestL, Result) { Run(); } + +template +const std::vector> bitmaptocsr_inputs = { + {0, 0, 0.2, false}, + {10, 32, 0.4, false}, + {10, 3, 0.2, false}, + {32, 1024, 0.4, false}, + {1024, 1048576, 0.01, false}, + {1024, 1024, 0.4, false}, + {64 * 1024 + 10, 2, 0.3, false}, // 64K + 10 is slightly over maximum of blockDim.y + {16, 16, 0.3, false}, // No peeling-remainder + {17, 16, 0.3, false}, // Check peeling-remainder + {18, 16, 0.3, false}, // Check peeling-remainder + {32 + 9, 33, 0.2, false}, // Check peeling-remainder + {2, 33, 0.2, false}, // Check peeling-remainder + {0, 0, 0.2, true}, + {10, 32, 0.4, true}, + {10, 3, 0.2, true}, + {32, 1024, 0.4, true}, + {1024, 1048576, 0.01, true}, + {1024, 1024, 0.4, true}, + {64 * 1024 + 10, 2, 0.3, true}, // 64K + 10 is slightly over maximum of blockDim.y + {16, 16, 0.3, true}, // No peeling-remainder + {17, 16, 0.3, true}, // Check peeling-remainder + {18, 16, 0.3, true}, // Check peeling-remainder + {32 + 9, 33, 0.2, true}, // Check peeling-remainder + {2, 33, 0.2, true}, // Check peeling-remainder +}; + +INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, + BitmapToCSRTestI, + ::testing::ValuesIn(bitmaptocsr_inputs)); +INSTANTIATE_TEST_CASE_P(SparseConvertCSRTest, + BitmapToCSRTestL, + ::testing::ValuesIn(bitmaptocsr_inputs)); + } // namespace sparse } // namespace raft diff --git a/docs/source/cpp_api/core.rst b/docs/source/cpp_api/core.rst index 39e57fd69a..4122a18506 100644 --- a/docs/source/cpp_api/core.rst +++ b/docs/source/cpp_api/core.rst @@ -21,4 +21,5 @@ expose in public APIs. core_interruptible.rst core_operators.rst core_math.rst - core_bitset.rst \ No newline at end of file + core_bitset.rst + core_bitmap.rst \ No newline at end of file diff --git a/docs/source/cpp_api/core_bitmap.rst b/docs/source/cpp_api/core_bitmap.rst new file mode 100644 index 0000000000..6c1dc607bf --- /dev/null +++ b/docs/source/cpp_api/core_bitmap.rst @@ -0,0 +1,15 @@ +Bitmap +====== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::core* + +.. doxygengroup:: bitmap + :project: RAFT + :members: + :content-only: \ No newline at end of file
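A usage sketch (editor's illustration, not part of the original diff; it assumes a device buffer of ``ceildiv(rows * cols, 8 * sizeof(bitmap_t))`` words has already been filled):

.. code-block:: c++

    #include <raft/core/bitmap.cuh>

    // View an existing device buffer as a rows x cols bitmap (row-major bit order).
    auto view = raft::core::bitmap_view(bitmap_ptr, rows, cols);

    // In device code, individual (row, col) bits can then be tested or set:
    //   bool is_set = view.test(row, col);
    //   view.set(row, col, true);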