Skip to content

Commit

Permalink
[FEA] add the support of masked_matmul (#2362)
Browse files Browse the repository at this point in the history
#2336

Authors:
  - rhdong (https://github.com/rhdong)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2362
  • Loading branch information
rhdong authored Jul 24, 2024
1 parent fa7c193 commit ffceee2
Show file tree
Hide file tree
Showing 6 changed files with 776 additions and 0 deletions.
1 change: 1 addition & 0 deletions cpp/bench/prims/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ if(BUILD_PRIMS_BENCH)
PATH
linalg/add.cu
linalg/map_then_reduce.cu
linalg/masked_matmul.cu
linalg/matrix_vector_op.cu
linalg/norm.cu
linalg/normalize.cu
Expand Down
268 changes: 268 additions & 0 deletions cpp/bench/prims/linalg/masked_matmul.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <common/benchmark.hpp>

#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cublas_handle.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/random/rng.cuh>
#include <raft/sparse/linalg/masked_matmul.hpp>
#include <raft/util/itertools.hpp>

#include <cusparse_v2.h>

#include <random>
#include <sstream>
#include <vector>

namespace raft::bench::linalg {

template <typename value_t>
struct MaskedMatmulBenchParams {
size_t m;
size_t k;
size_t n;
float sparsity;
value_t alpha = 1.0;
value_t beta = 0.0;
};

template <typename value_t>
inline auto operator<<(std::ostream& os, const MaskedMatmulBenchParams<value_t>& params)
-> std::ostream&
{
os << " m*k*n=" << params.m << "*" << params.k << "*" << params.n
<< "\tsparsity=" << params.sparsity;
if (params.sparsity == 1.0) { os << "<-inner product for comparison"; }
return os;
}

template <typename value_t, typename index_t = int64_t, typename bitmap_t = uint32_t>
struct MaskedMatmulBench : public fixture {
MaskedMatmulBench(const MaskedMatmulBenchParams<value_t>& p)
: fixture(true),
params(p),
handle(stream),
a_data_d(0, stream),
b_data_d(0, stream),
c_indptr_d(0, stream),
c_indices_d(0, stream),
c_data_d(0, stream),
bitmap_d(0, stream),
c_dense_data_d(0, stream)
{
index_t element = raft::ceildiv(index_t(params.m * params.n), index_t(sizeof(bitmap_t) * 8));
std::vector<bitmap_t> bitmap_h(element);

a_data_d.resize(params.m * params.k, stream);
b_data_d.resize(params.k * params.n, stream);
bitmap_d.resize(element, stream);

raft::random::RngState rng(2024ULL);
raft::random::uniform(
handle, rng, a_data_d.data(), params.m * params.k, value_t(-1.0), value_t(1.0));
raft::random::uniform(
handle, rng, b_data_d.data(), params.k * params.n, value_t(-1.0), value_t(1.0));

std::vector<bool> c_dense_data_h(params.m * params.n);

c_true_nnz = create_sparse_matrix(params.m, params.n, params.sparsity, bitmap_h);

std::vector<value_t> values(c_true_nnz);
std::vector<index_t> indices(c_true_nnz);
std::vector<index_t> indptr(params.m + 1);

c_data_d.resize(c_true_nnz, stream);
c_indptr_d.resize(params.m + 1, stream);
c_indices_d.resize(c_true_nnz, stream);
c_dense_data_d.resize(params.m * params.n, stream);

cpu_convert_to_csr(bitmap_h, params.m, params.n, indices, indptr);
RAFT_EXPECTS(c_true_nnz == c_indices_d.size(),
"Something wrong. The c_true_nnz != c_indices_d.size()!");

update_device(c_data_d.data(), values.data(), c_true_nnz, stream);
update_device(c_indices_d.data(), indices.data(), c_true_nnz, stream);
update_device(c_indptr_d.data(), indptr.data(), params.m + 1, stream);
update_device(bitmap_d.data(), bitmap_h.data(), element, stream);
}

index_t create_sparse_matrix(index_t m, index_t n, float sparsity, std::vector<bitmap_t>& bitmap)
{
index_t total = static_cast<index_t>(m * n);
index_t num_ones = static_cast<index_t>((total * 1.0f) * sparsity);
index_t res = num_ones;

for (auto& item : bitmap) {
item = static_cast<bitmap_t>(0);
}

std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<index_t> dis(0, total - 1);

while (num_ones > 0) {
index_t index = dis(gen);

bitmap_t& element = bitmap[index / (8 * sizeof(bitmap_t))];
index_t bit_position = index % (8 * sizeof(bitmap_t));

if (((element >> bit_position) & 1) == 0) {
element |= (static_cast<index_t>(1) << bit_position);
num_ones--;
}
}
return res;
}

void cpu_convert_to_csr(std::vector<bitmap_t>& bitmap,
index_t rows,
index_t cols,
std::vector<index_t>& indices,
std::vector<index_t>& indptr)
{
index_t offset_indptr = 0;
index_t offset_values = 0;
indptr[offset_indptr++] = 0;

index_t index = 0;
bitmap_t element = 0;
index_t bit_position = 0;

for (index_t i = 0; i < rows; ++i) {
for (index_t j = 0; j < cols; ++j) {
index = i * cols + j;
element = bitmap[index / (8 * sizeof(bitmap_t))];
bit_position = index % (8 * sizeof(bitmap_t));

if (((element >> bit_position) & 1)) {
indices[offset_values] = static_cast<index_t>(j);
offset_values++;
}
}
indptr[offset_indptr++] = static_cast<index_t>(offset_values);
}
}

~MaskedMatmulBench() {}

void run_benchmark(::benchmark::State& state) override
{
std::ostringstream label_stream;
label_stream << params;
state.SetLabel(label_stream.str());

auto a = raft::make_device_matrix_view<const value_t, index_t, row_major>(
a_data_d.data(), params.m, params.k);

auto b = raft::make_device_matrix_view<const value_t, index_t, row_major>(
b_data_d.data(), params.n, params.k);

auto c_structure = raft::make_device_compressed_structure_view<int64_t, int64_t, int64_t>(
c_indptr_d.data(),
c_indices_d.data(),
params.m,
params.n,
static_cast<index_t>(c_indices_d.size()));

auto mask =
raft::core::bitmap_view<const bitmap_t, index_t>(bitmap_d.data(), params.m, params.n);

auto c = raft::make_device_csr_matrix_view<value_t>(c_data_d.data(), c_structure);

if (params.sparsity < 1.0) {
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
} else {
raft::distance::pairwise_distance(handle,
a_data_d.data(),
b_data_d.data(),
c_dense_data_d.data(),
static_cast<int>(params.m),
static_cast<int>(params.n),
static_cast<int>(params.k),
raft::distance::DistanceType::InnerProduct,
true);
}
resource::sync_stream(handle);

raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
resource::sync_stream(handle);

loop_on_state(state, [this, &a, &b, &mask, &c]() {
if (params.sparsity < 1.0) {
raft::sparse::linalg::masked_matmul(handle, a, b, mask, c);
} else {
raft::distance::pairwise_distance(handle,
a_data_d.data(),
b_data_d.data(),
c_dense_data_d.data(),
static_cast<int>(params.m),
static_cast<int>(params.n),
static_cast<int>(params.k),
raft::distance::DistanceType::InnerProduct,
true);
}
resource::sync_stream(handle);
});
}

private:
const raft::device_resources handle;
MaskedMatmulBenchParams<value_t> params;

rmm::device_uvector<value_t> a_data_d;
rmm::device_uvector<value_t> b_data_d;
rmm::device_uvector<bitmap_t> bitmap_d;

rmm::device_uvector<value_t> c_dense_data_d;

size_t c_true_nnz = 0;
rmm::device_uvector<index_t> c_indptr_d;
rmm::device_uvector<index_t> c_indices_d;
rmm::device_uvector<value_t> c_data_d;
};

template <typename value_t>
static std::vector<MaskedMatmulBenchParams<value_t>> getInputs()
{
std::vector<MaskedMatmulBenchParams<value_t>> param_vec;
struct TestParams {
size_t m;
size_t k;
size_t n;
float sparsity;
};

const std::vector<TestParams> params_group =
raft::util::itertools::product<TestParams>({size_t(10), size_t(1024)},
{size_t(128), size_t(1024)},
{size_t(1024 * 1024)},
{0.01f, 0.1f, 0.2f, 0.5f, 1.0f});

param_vec.reserve(params_group.size());
for (TestParams params : params_group) {
param_vec.push_back(
MaskedMatmulBenchParams<value_t>({params.m, params.k, params.n, params.sparsity}));
}
return param_vec;
}

RAFT_BENCH_REGISTER((MaskedMatmulBench<float>), "", getInputs<float>());

} // namespace raft::bench::linalg
107 changes: 107 additions & 0 deletions cpp/include/raft/sparse/linalg/detail/masked_matmul.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain A copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/bitmap.cuh>
#include <raft/core/detail/popc.cuh>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/sparse/convert/csr.cuh>
#include <raft/sparse/distance/detail/utils.cuh>
#include <raft/sparse/linalg/sddmm.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/cuda_device.hpp>
#include <rmm/device_uvector.hpp>

namespace raft {
namespace sparse {
namespace linalg {
namespace detail {

template <typename value_t, typename index_t, typename nnz_t, typename bitmap_t>
void masked_matmul(raft::resources const& handle,
raft::device_matrix_view<const value_t, index_t, raft::row_major>& A,
raft::device_matrix_view<const value_t, index_t, raft::row_major>& B,
raft::core::bitmap_view<const bitmap_t, index_t>& mask,
raft::device_csr_matrix_view<value_t, index_t, index_t, nnz_t>& C,
std::optional<raft::host_scalar_view<value_t>> alpha,
std::optional<raft::host_scalar_view<value_t>> beta)
{
index_t m = A.extent(0);
index_t n = B.extent(0);
index_t dim = A.extent(1);

auto compressed_C_view = C.structure_view();

RAFT_EXPECTS(A.extent(1) == B.extent(1), "The dim of A must be equal to the dim of B.");
RAFT_EXPECTS(A.extent(0) == compressed_C_view.get_n_rows(),
"Number of rows in C must match the number of rows in A.");
RAFT_EXPECTS(B.extent(0) == compressed_C_view.get_n_cols(),
"Number of columns in C must match the number of columns in B.");

auto stream = raft::resource::get_cuda_stream(handle);

auto C_matrix = raft::make_device_csr_matrix<value_t, index_t>(handle, compressed_C_view);

// fill C
raft::sparse::convert::bitmap_to_csr(handle, mask, C_matrix);

if (m > 10 || alpha.has_value() || beta.has_value()) {
auto C_view = raft::make_device_csr_matrix_view<value_t, index_t, index_t, index_t>(
C.get_elements().data(), compressed_C_view);

// create B col_major view
auto B_col_major = raft::make_device_matrix_view<const value_t, index_t, raft::col_major>(
B.data_handle(), dim, n);

value_t default_alpha = static_cast<value_t>(1.0f);
value_t default_beta = static_cast<value_t>(0.0f);

if (!alpha.has_value()) { alpha = raft::make_host_scalar_view<value_t>(&default_alpha); }
if (!beta.has_value()) { beta = raft::make_host_scalar_view<value_t>(&default_beta); }

raft::sparse::linalg::sddmm(handle,
A,
B_col_major,
C_view,
raft::linalg::Operation::NON_TRANSPOSE,
raft::linalg::Operation::NON_TRANSPOSE,
*alpha,
*beta);
} else {
raft::sparse::distance::detail::faster_dot_on_csr(handle,
C.get_elements().data(),
compressed_C_view.get_nnz(),
compressed_C_view.get_indptr().data(),
compressed_C_view.get_indices().data(),
A.data_handle(),
B.data_handle(),
compressed_C_view.get_n_rows(),
dim);
}
}

} // namespace detail
} // namespace linalg
} // namespace sparse
} // namespace raft
Loading

0 comments on commit ffceee2

Please sign in to comment.