diff --git a/cpp/include/raft/sparse/matrix/detail/preprocessing.cuh b/cpp/include/raft/sparse/matrix/detail/preprocessing.cuh new file mode 100644 index 0000000000..38c0d7405d --- /dev/null +++ b/cpp/include/raft/sparse/matrix/detail/preprocessing.cuh @@ -0,0 +1,539 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::sparse::matrix::detail { + +/** + * @brief Calculates the BM25 values for a target matrix. + * @param num_feats: The total number of features in the matrix + * @param avg_feat_len: The avg length of all features combined. + * @param k_param: K value required by BM25 algorithm. + * @param b_param: B value required by BM25 algorithm. 
+ */ +template +struct bm25 { + bm25(IndexType num_feats, ValueType avg_feat_len, ValueType k_param, ValueType b_param) + { + total_feats = num_feats; + avg_feat_length = avg_feat_len; + k = k_param; + b = b_param; + } + + // Returns idf * bm in ValueType so double-precision results are not truncated to float. + ValueType __device__ operator()(const ValueType& value, + const ValueType& num_feats_id_occ, + const ValueType& feat_length) + { + ValueType tf = ValueType(value / feat_length); + ValueType idf = raft::log(total_feats / num_feats_id_occ); + ValueType bm = ((k + 1) * tf) / (k * ((1.0f - b) + b * (feat_length / avg_feat_length)) + tf); + + return idf * bm; + } + ValueType avg_feat_length; + IndexType total_feats; + ValueType k; + ValueType b; +}; + +/** + * @brief Calculates the tfidf values for a target matrix. Term frequency is the value divided + * by its feature length; idf is log(total_feats / num_feats_id_occ). + * @param total_feats_param: The total number of features in the matrix + */ +template +struct tfidf { + tfidf(IndexType total_feats_param) { total_feats = total_feats_param; } + + // Returns tf * idf in ValueType so double-precision results are not truncated to float. + ValueType __device__ operator()(const ValueType& value, + const ValueType& num_feats_id_occ, + const ValueType& feat_length) + { + ValueType tf = ValueType(value / feat_length); + ValueType idf = raft::log(total_feats / num_feats_id_occ); + return tf * idf; + } + IndexType total_feats; +}; + +template +struct mapper { + mapper(raft::device_vector_view map) : map(map) {} + + // Looks up map[value]; returned as ValueType so large counts are not rounded through float. + ValueType __device__ operator()(const ValueType& value) + { + ValueType new_value = map[value]; + if (new_value) { + return new_value; + } else { + return ValueType(0); + } + } + + raft::device_vector_view map; +}; + +template +struct map_to { + map_to(raft::device_vector_view map) : map(map) {} + + // Writes count into map[key]; the returned 0.0f is a placeholder for the map output. + float __device__ operator()(const IndexType& key, const ValueType& count) + { + map[key] = count; + return 0.0f; + } + + raft::device_vector_view map; +}; + +/** + * @brief Get unique counts + * @tparam IndexType: the type of the edge indexes in the matrix + * @tparam ValueType: the type of the values for edges + * @tparam IdxT: the type of the 
index values + * @param handle: raft resource handle + * @param sort_vector: Input COO array that contains the keys. + * @param secondary_vector: Input with secondary keys of COO, (columns or rows). + * @param data: Input COO values array. + * @param itr_vals: Input array that is reduced per key to produce counts_out. + * @param keys_out: Output array with one entry for each key. (same size as counts_out) + * @param counts_out: Output array with cumulative sum for each key. (same size as keys_out) + */ +template +void get_uniques_counts(raft::resources& handle, + raft::device_vector_view sort_vector, + raft::device_vector_view secondary_vector, + raft::device_vector_view data, + raft::device_vector_view itr_vals, + raft::device_vector_view keys_out, + raft::device_vector_view counts_out) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + raft::sparse::op::coo_sort(int(sort_vector.size()), + int(secondary_vector.size()), + int(data.size()), + sort_vector.data_handle(), + secondary_vector.data_handle(), + data.data_handle(), + stream); + // replace this call with raft version when available + // (https://github.com/rapidsai/raft/issues/2477) + RAFT_CHECK_CUDA(stream); + thrust::reduce_by_key(raft::resource::get_thrust_policy(handle), + sort_vector.data_handle(), + sort_vector.data_handle() + sort_vector.size(), + itr_vals.data_handle(), + keys_out.data_handle(), + counts_out.data_handle()); +} + +/** + * @brief Broadcasts values to target indices of vector based on key/value look up + * @tparam IndexType: the type of the edge indexes in the matrix + * @tparam ValueType: the type of the values for edges + * @tparam IdxT: the type of the index values + * @param handle: raft resource handle + * @param origin: Input array that has values to use for computation + * @param keys: Input array that has keys, should be the size of unique + * @param counts: Input array that contains the computed counts (one per key) + * @param result: Output array that scatters the counts to origin value 
positions. Same size as + * origin array. + * @param key_size: Upper bound on the key values; the lookup table is allocated with + * key_size + 1 entries and indexed by key. + */ +template +void create_mapped_vector(raft::resources& handle, + const raft::device_vector_view origin, + const raft::device_vector_view keys, + const raft::device_vector_view counts, + raft::device_vector_view result, + IndexType key_size) +{ + // index into the last element and then add 1 to it. + auto origin_map = raft::make_device_vector(handle, key_size + 1); + raft::matrix::fill(handle, origin_map.view(), 0.0f); + + auto dummy_vec = raft::make_device_vector(handle, keys.size()); + raft::linalg::map(handle, + dummy_vec.view(), + map_to(origin_map.view()), + raft::make_const_mdspan(keys), + raft::make_const_mdspan(counts)); + + raft::linalg::map(handle, result, raft::cast_op{}, raft::make_const_mdspan(origin)); + raft::linalg::map( + handle, result, mapper(origin_map.view()), raft::make_const_mdspan(result)); +} + +/** + * @brief Compute row(id) counts + * @tparam IndexType: the type of the edge indexes in the matrix + * @tparam ValueType: the type of the values for edges + * @tparam IdxT: the type of the index values + * @param handle: raft resource handle + * @param rows: Input COO rows array + * @param columns: Input COO columns array + * @param values: Input COO values array + * @param id_counts: Output array that stores counts per row, scattered to same shape as rows. 
+ * @param n_rows: Number of rows in matrix + */ +template +void get_id_counts(raft::resources& handle, + raft::device_vector_view rows, + raft::device_vector_view columns, + raft::device_vector_view values, + raft::device_vector_view id_counts, + IndexType n_rows) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + raft::sparse::op::coo_sort(int(rows.size()), + int(columns.size()), + int(values.size()), + rows.data_handle(), + columns.data_handle(), + values.data_handle(), + stream); + + auto rows_counts = raft::make_device_vector(handle, n_rows); + raft::matrix::fill(handle, rows_counts.view(), 0); + + raft::sparse::linalg::coo_degree(raft::make_const_mdspan(rows).data_handle(), + int(rows.size()), + rows_counts.data_handle(), + stream); + + raft::linalg::map( + handle, id_counts, mapper(rows_counts.view()), raft::make_const_mdspan(rows)); +} + +/** + * @brief Sums the values of each feature (column), scatters that sum to every nonzero of the + * column, and returns the total of the per-feature sums divided by n_cols. + * @tparam IndexType: the type of the edge indexes in the matrix + * @tparam ValueType: the type of the values for edges + * @tparam IdxT: the type of the index values + * @param handle: raft resource handle + * @param rows: Input COO rows array (passed on to coo_sort; presumably reordered in place) + * @param columns: Input COO columns array (passed on to coo_sort; presumably reordered in place) + * @param values: Input COO values array (passed on to coo_sort; presumably reordered in place) + * @param feat_lengths: Output array that stores, per nonzero, the summed values of its feature + * @param n_cols: Number of columns in matrix + */ +template +float get_feature_data(raft::resources& handle, + raft::device_vector_view rows, + raft::device_vector_view columns, + raft::device_vector_view values, + raft::device_vector_view feat_lengths, + IndexType n_cols) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto preserved_columns = raft::make_device_vector(handle, columns.size()); + + int uniq_cols = + raft::sparse::neighbors::get_n_components(columns.data_handle(), columns.size(), stream); + raft::copy(preserved_columns.data_handle(), columns.data_handle(), columns.size(), 
stream); + + auto col_keys = raft::make_device_vector(handle, uniq_cols); + auto col_counts = raft::make_device_vector(handle, uniq_cols); + + // Reduce the values per unique column; col_counts holds one sum per key in col_keys. + get_uniques_counts(handle, columns, rows, values, values, col_keys.view(), col_counts.view()); + + // Add up the per-column sums on the device. + auto total_feature_lengths = raft::make_device_scalar(handle, 0); + raft::linalg::mapReduce(total_feature_lengths.data_handle(), + col_counts.size(), + 0, + raft::identity_op(), + raft::add_op(), + stream, + col_counts.data_handle()); + auto total_feature_lengths_host = raft::make_host_scalar(handle, 0); + raft::copy(total_feature_lengths_host.data_handle(), + total_feature_lengths.data_handle(), + total_feature_lengths.size(), + stream); + // The copy above is only enqueued on the stream; wait for it before reading on the host. + raft::resource::sync_stream(handle); + ValueType avg_feat_length = ValueType(total_feature_lengths_host(0)) / n_cols; + // Scatter each column's sum back to the nonzeros using the pre-sort column order. + create_mapped_vector( + handle, preserved_columns.view(), col_keys.view(), col_counts.view(), feat_lengths, n_cols); + return avg_feat_length; +} + +/** + * @brief Gather per feature mean values and id counts, returns the cumulative avg feature length. 
+ * @tparam IndexType: the type of the edge indexes in the matrix + * @tparam ValueType: the type of the values for edges + * @tparam IdxT: the type of the index values + * @param handle: raft resource handle + * @param rows: Input COO rows array + * @param columns: Input COO columns array + * @param values: Input COO values array + * @param feat_lengths: Output array filled by get_feature_data with the per feature value + * @param id_counts: Output array that stores id(row) counts for nz values + * @param n_rows: Number of rows in matrix + * @param n_cols: Number of columns in matrix + * @return The average feature length reported by get_feature_data + */ +template +float sparse_search_preprocess(raft::resources& handle, + raft::device_vector_view rows, + raft::device_vector_view columns, + raft::device_vector_view values, + raft::device_vector_view feat_lengths, + raft::device_vector_view id_counts, + IndexType n_rows, + IndexType n_cols) +{ + auto avg_feature_len = get_feature_data(handle, rows, columns, values, feat_lengths, n_cols); + + get_id_counts(handle, rows, columns, values, id_counts, n_rows); + + return avg_feature_len; +} + +/** + * @brief Use TFIDF algorithm to encode features in COO sparse matrix + * @tparam IndexType: the type of the edge indexes in the matrix + * @tparam ValueType: the type of the values for edges + * @tparam IdxT: the type of the index values + * @param handle: raft resource handle + * @param rows: Input COO rows array + * @param columns: Input COO columns array + * @param values: Input COO values array + * @param values_out: Output COO values array + * @param n_rows: Number of rows in matrix + * @param n_cols: Number of columns in matrix + */ +template +void base_encode_tfidf(raft::resources& handle, + raft::device_vector_view rows, + raft::device_vector_view columns, + raft::device_vector_view values, + raft::device_vector_view values_out, + IndexType n_rows, + IndexType n_cols) +{ + auto feat_lengths = raft::make_device_vector(handle, values.size()); + auto id_counts = raft::make_device_vector(handle, 
values.size()); + // TF-IDF does not use the average feature length, so the return value is discarded. + sparse_search_preprocess( + handle, rows, columns, values, feat_lengths.view(), id_counts.view(), n_rows, n_cols); + + // Combine each value with its row count and feature length to produce the encoded value. + raft::linalg::map(handle, + values_out, + tfidf(n_cols), + raft::make_const_mdspan(values), + raft::make_const_mdspan(id_counts.view()), + raft::make_const_mdspan(feat_lengths.view())); +} + +/** + * @brief Use TFIDF algorithm to encode features in COO sparse matrix + * @tparam IndexType: the type of the edge indexes in the matrix + * @tparam ValueType: the type of the values for edges + * @tparam IdxT: the type of the index values + * @param handle: raft resource handle + * @param coo_in: Input COO matrix + * @param values_out: Output COO values array + */ +template +void encode_tfidf(raft::resources& handle, + raft::device_coo_matrix_view coo_in, + raft::device_vector_view values_out) +{ + auto rows = raft::make_device_vector_view( + coo_in.structure_view().get_rows().data(), coo_in.structure_view().get_rows().size()); + auto columns = raft::make_device_vector_view( + coo_in.structure_view().get_cols().data(), coo_in.structure_view().get_cols().size()); + auto values = raft::make_device_vector_view(coo_in.get_elements().data(), + coo_in.get_elements().size()); + + base_encode_tfidf(handle, + rows, + columns, + values, + values_out, + coo_in.structure_view().get_n_rows(), + coo_in.structure_view().get_n_cols()); +} + +/** + * @brief Use TFIDF algorithm to encode features in CSR sparse matrix + * @tparam IndexType: the type of the edge indexes in the matrix + * @tparam ValueType: the type of the values for edges + * @tparam IdxT: the type of the index values + * @param handle: raft resource handle + * @param csr_in: Input CSR matrix + * @param values_out: Output values array + */ +template +void encode_tfidf(raft::resources& handle, + raft::device_csr_matrix_view csr_in, + raft::device_vector_view values_out) +{ + cudaStream_t stream = 
raft::resource::get_cuda_stream(handle); + + auto indptr = raft::make_device_vector_view( + csr_in.structure_view().get_indptr().data(), csr_in.structure_view().get_indptr().size()); + auto indices = raft::make_device_vector_view( + csr_in.structure_view().get_indices().data(), csr_in.structure_view().get_indices().size()); + auto values = raft::make_device_vector_view(csr_in.get_elements().data(), + csr_in.get_elements().size()); + + auto rows = raft::make_device_vector(handle, values.size()); + + // Expand the CSR row offsets into one row index per nonzero so the COO path can be reused. + raft::sparse::convert::csr_to_coo(indptr.data_handle(), + csr_in.structure_view().get_n_rows(), + rows.data_handle(), + rows.size(), + stream); + + base_encode_tfidf(handle, + rows.view(), + indices, + values, + values_out, + csr_in.structure_view().get_n_rows(), + csr_in.structure_view().get_n_cols()); +} + +/** + * @brief Use BM25 algorithm to encode features in COO sparse matrix + * @tparam IndexType: the type of the edge indexes in the matrix + * @tparam ValueType: the type of the values for edges + * @tparam IdxT: the type of the index values + * @param handle: raft resource handle + * @param rows: Input COO rows array + * @param columns: Input COO columns array + * @param values: Input COO values array + * @param values_out: Output COO values array + * @param n_rows: Number of rows in matrix + * @param n_cols: Number of columns in matrix + * @param k_param: K value to use for BM25 algorithm + * @param b_param: B value to use for BM25 algorithm + */ +template +void base_encode_bm25(raft::resources& handle, + raft::device_vector_view rows, + raft::device_vector_view columns, + raft::device_vector_view values, + raft::device_vector_view values_out, + IndexType n_rows, + IndexType n_cols, + ValueType k_param = 1.6f, + ValueType b_param = 0.75f) +{ + // Per-nonzero scratch buffers filled by sparse_search_preprocess. + auto feat_lengths = raft::make_device_vector(handle, values.size()); + auto id_counts = raft::make_device_vector(handle, values.size()); + + auto avg_feat_length 
= sparse_search_preprocess( + handle, rows, columns, values, feat_lengths.view(), id_counts.view(), n_rows, n_cols); + + raft::linalg::map(handle, + values_out, + bm25(n_cols, avg_feat_length, k_param, b_param), + raft::make_const_mdspan(values), + raft::make_const_mdspan(id_counts.view()), + raft::make_const_mdspan(feat_lengths.view())); +} + +/** + * @brief Use BM25 algorithm to encode features in COO sparse matrix + * @tparam IndexType: the type of the edge indexes in the matrix + * @tparam ValueType: the type of the values for edges + * @tparam IdxT: the type of the index values + * @param handle: raft resource handle + * @param coo_in: Input COO matrix + * @param values_out: Output values array + * @param k_param: K value to use for BM25 algorithm + * @param b_param: B value to use for BM25 algorithm + */ +template +void encode_bm25(raft::resources& handle, + raft::device_coo_matrix_view coo_in, + raft::device_vector_view values_out, + ValueType k_param = 1.6f, + ValueType b_param = 0.75f) +{ + auto rows = raft::make_device_vector_view( + coo_in.structure_view().get_rows().data(), coo_in.structure_view().get_rows().size()); + auto columns = raft::make_device_vector_view( + coo_in.structure_view().get_cols().data(), coo_in.structure_view().get_cols().size()); + auto values = raft::make_device_vector_view(coo_in.get_elements().data(), + coo_in.get_elements().size()); + + // Forward the caller's k/b; omitting them would silently fall back to the defaults. + base_encode_bm25(handle, + rows, + columns, + values, + values_out, + coo_in.structure_view().get_n_rows(), + coo_in.structure_view().get_n_cols(), + k_param, + b_param); +} + +/** + * @brief Use BM25 algorithm to encode features in CSR sparse matrix + * @tparam IndexType: the type of the edge indexes in the matrix + * @tparam ValueType: the type of the values for edges + * @tparam IdxT: the type of the index values + * @param handle: raft resource handle + * @param csr_in: Input CSR matrix + * @param values_out: Output values array + * @param k_param: K value to use for BM25 algorithm + * @param b_param: B value to 
use for BM25 algorithm + */ +template +void encode_bm25(raft::resources& handle, + raft::device_csr_matrix_view csr_in, + raft::device_vector_view values_out, + ValueType k_param = 1.6f, + ValueType b_param = 0.75f) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + auto indptr = raft::make_device_vector_view( + csr_in.structure_view().get_indptr().data(), csr_in.structure_view().get_indptr().size()); + auto indices = raft::make_device_vector_view( + csr_in.structure_view().get_indices().data(), csr_in.structure_view().get_indices().size()); + auto values = raft::make_device_vector_view(csr_in.get_elements().data(), + csr_in.get_elements().size()); + + auto rows = raft::make_device_vector(handle, values.size()); + + // Expand the CSR row offsets into one row index per nonzero so the COO path can be reused. + raft::sparse::convert::csr_to_coo(indptr.data_handle(), + csr_in.structure_view().get_n_rows(), + rows.data_handle(), + rows.size(), + stream); + + // Forward the caller's k/b; omitting them would silently fall back to the defaults. + base_encode_bm25(handle, + rows.view(), + indices, + values, + values_out, + csr_in.structure_view().get_n_rows(), + csr_in.structure_view().get_n_cols(), + k_param, + b_param); +} + +} // namespace raft::sparse::matrix::detail \ No newline at end of file diff --git a/cpp/include/raft/sparse/matrix/preprocessing.cuh b/cpp/include/raft/sparse/matrix/preprocessing.cuh new file mode 100644 index 0000000000..58335fb5c7 --- /dev/null +++ b/cpp/include/raft/sparse/matrix/preprocessing.cuh @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace raft::sparse::matrix { + +/** + * @brief Use BM25 algorithm to encode features in COO sparse matrix + * @tparam IndexType is the type of the edges index in the coo matrix + * @tparam ValueType is the type of the values array in the coo matrix + * @tparam IdxT is the type of the indices of arrays in matrix + * @param handle: raft resource handle + * @param coo_in: Input COO matrix + * @param values_out: Output values array + * @param k_param: K value to use for BM25 algorithm + * @param b_param: B value to use for BM25 algorithm + */ +template +void encode_bm25(raft::resources& handle, + raft::device_coo_matrix_view coo_in, + raft::device_vector_view values_out, + float k_param = 1.6f, + float b_param = 0.75f) +{ + return matrix::detail::encode_bm25( + handle, coo_in, values_out, k_param, b_param); +} + +/** + * @brief Use BM25 algorithm to encode features in CSR sparse matrix + * @tparam IndexType is the type of the edges index in the csr matrix + * @tparam ValueType is the type of the values array in the csr matrix + * @tparam IdxT is the type of the indices of arrays in matrix + * @param handle: raft resource handle + * @param csr_in: Input CSR matrix + * @param values_out: Output values array + * @param k_param: K value to use for BM25 algorithm + * @param b_param: B value to use for BM25 algorithm + */ +template +void encode_bm25(raft::resources& handle, + raft::device_csr_matrix_view csr_in, + raft::device_vector_view values_out, + float k_param = 1.6f, + float b_param = 0.75f) +{ + return matrix::detail::encode_bm25( + handle, csr_in, values_out, k_param, b_param); +} + +/** + * @brief Use TFIDF algorithm to encode features in COO sparse matrix + * @tparam IndexType is the type of the edges index in the coo matrix + * @tparam ValueType is the type of the values array in the coo matrix + * @tparam IdxT is the type of the indices of arrays in matrix + * @param 
handle: raft resource handle + * @param coo_in: Input COO matrix + * @param values_out: Output COO values array + */ +template +void encode_tfidf(raft::resources& handle, + raft::device_coo_matrix_view coo_in, + raft::device_vector_view values_out) +{ + return matrix::detail::encode_tfidf(handle, coo_in, values_out); +} + +/** + * @brief Use TFIDF algorithm to encode features in CSR sparse matrix + * @tparam IndexType is the type of the edges index in the csr matrix + * @tparam ValueType is the type of the values array in the csr matrix + * @tparam IdxT is the type of the indices of arrays in matrix + * @param handle: raft resource handle + * @param csr_in: Input CSR matrix + * @param values_out: Output values array + */ +template +void encode_tfidf(raft::resources& handle, + raft::device_csr_matrix_view csr_in, + raft::device_vector_view values_out) +{ + return matrix::detail::encode_tfidf(handle, csr_in, values_out); +} + +} // namespace raft::sparse::matrix diff --git a/cpp/include/raft/sparse/neighbors/brute_force.cuh b/cpp/include/raft/sparse/neighbors/brute_force.cuh index 47e00a012f..8e8f36c2c3 100644 --- a/cpp/include/raft/sparse/neighbors/brute_force.cuh +++ b/cpp/include/raft/sparse/neighbors/brute_force.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,8 @@ namespace raft::sparse::neighbors::brute_force { /** * Search the sparse kNN for the k-nearest neighbors of a set of sparse query vectors * using some distance implementation + * template parameter value_idx is the type of the Indptr and Indices arrays. + * template parameter value_t is the type of the Data array. 
* @param[in] idxIndptr csr indptr of the index matrix (size n_idx_rows + 1) * @param[in] idxIndices csr column indices array of the index matrix (size n_idx_nnz) * @param[in] idxData csr data array of the index matrix (size idxNNZ) diff --git a/cpp/include/raft/sparse/neighbors/knn.cuh b/cpp/include/raft/sparse/neighbors/knn.cuh index 2cf68818aa..ec278e12e4 100644 --- a/cpp/include/raft/sparse/neighbors/knn.cuh +++ b/cpp/include/raft/sparse/neighbors/knn.cuh @@ -30,8 +30,11 @@ " Please use the sparse/spatial version instead.") #endif +#include #include +#include #include +#include namespace raft::sparse::neighbors { @@ -59,7 +62,7 @@ namespace raft::sparse::neighbors { * @param[in] metric distance metric/measure to use * @param[in] metricArg potential argument for metric (currently unused) */ -template +template void brute_force_knn(const value_idx* idxIndptr, const value_idx* idxIndices, const value_t* idxData, @@ -103,4 +106,183 @@ void brute_force_knn(const value_idx* idxIndptr, metricArg); } +/** + * Search the sparse kNN for the k-nearest neighbors of a set of sparse query vectors + * using some distance implementation + * @tparam value_idx is the type of the edges index in the csr matrix + * @tparam value_t is the type of the values array in the csr matrix + * @param[in] csr_idx index csr matrix + * @param[in] csr_query query csr matrix + * @param[out] output_indices dense matrix for output indices (size n_query_rows * k) + * @param[out] output_dists dense matrix for output distances (size n_query_rows * k) + * @param[in] k the number of neighbors to query + * @param[in] handle CUDA resource::get_cuda_stream(handle) to order operations with respect to + * @param[in] batch_size_index maximum number of rows to use from index matrix per batch + * @param[in] batch_size_query maximum number of rows to use from query matrix per batch + * @param[in] metric distance metric/measure to use + * @param[in] metricArg potential argument for metric (currently unused) + */ 
+template +void brute_force_knn(raft::device_csr_matrix csr_idx, + raft::device_csr_matrix csr_query, + device_vector_view output_indices, + device_vector_view output_dists, + int k, + raft::resources const& handle, + size_t batch_size_index = 2 << 14, // 2 << 14 == 32768 rows per batch + size_t batch_size_query = 2 << 14, + raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, + float metricArg = 0) +{ + auto idxIndptr = csr_idx.structure_view().get_indptr(); + auto idxIndices = csr_idx.structure_view().get_indices(); + auto idxData = csr_idx.view().get_elements(); + + RAFT_EXPECTS(idxData.size() > 0, "No Values were detected in the Index CSR Matrix."); + + auto queryIndptr = csr_query.structure_view().get_indptr(); + auto queryIndices = csr_query.structure_view().get_indices(); + auto queryData = csr_query.view().get_elements(); + + RAFT_EXPECTS(queryData.size() > 0, "No Values were detected in the Query CSR Matrix."); + + brute_force::knn(idxIndptr.data(), + idxIndices.data(), + idxData.data(), + idxIndices.size(), + idxIndptr.size() - 1, + csr_idx.structure_view().get_n_cols(), + queryIndptr.data(), + queryIndices.data(), + queryData.data(), + queryIndices.size(), + queryIndptr.size() - 1, + csr_query.structure_view().get_n_cols(), + output_indices.data_handle(), + output_dists.data_handle(), + k, + handle, + batch_size_index, + batch_size_query, + metric, + metricArg); +} + +/** + * Search the sparse kNN for the k-nearest neighbors of a set of sparse query vectors + * using some distance implementation + * @tparam value_idx is the type of the edges index in the coo matrix + * @tparam value_t is the type of the values array in the coo matrix + * @param[in] coo_idx index coo matrix + * @param[in] coo_query query coo matrix + * @param[out] output_indices dense matrix for output indices (size n_query_rows * k) + * @param[out] output_dists dense matrix for output distances (size n_query_rows * k) + * @param[in] k the number of neighbors to query + * @param[in] 
handle CUDA resource::get_cuda_stream(handle) to order operations with respect to + * @param[in] batch_size_index maximum number of rows to use from index matrix per batch + * @param[in] batch_size_query maximum number of rows to use from query matrix per batch + * @param[in] metric distance metric/measure to use + * @param[in] metricArg potential argument for metric (currently unused) + */ +template +void brute_force_knn(raft::device_coo_matrix coo_idx, + raft::device_coo_matrix coo_query, + device_vector_view output_indices, + device_vector_view output_dists, + int k, + raft::resources const& handle, + size_t batch_size_index = 2 << 14, // 2 << 14 == 32768 rows per batch + size_t batch_size_query = 2 << 14, + raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, + float metricArg = 0) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + auto idxRows = coo_idx.structure_view().get_rows(); + auto idxCols = coo_idx.structure_view().get_cols(); + auto idxData = coo_idx.view().get_elements(); + + RAFT_EXPECTS(idxData.size() > 0, "No Values were detected in the Index COO Matrix."); + + auto queryRows = coo_query.structure_view().get_rows(); + auto queryCols = coo_query.structure_view().get_cols(); + auto queryData = coo_query.view().get_elements(); + + RAFT_EXPECTS(queryData.size() > 0, "No Values were detected in the Query COO Matrix."); + + // The values array must be idxData so the values are permuted together with rows/cols. + raft::sparse::op::coo_sort(int(idxRows.size()), + int(idxCols.size()), + int(idxData.size()), + idxRows.data(), + idxCols.data(), + idxData.data(), + stream); + + raft::sparse::op::coo_sort(int(queryRows.size()), + int(queryCols.size()), + int(queryData.size()), + queryRows.data(), + queryCols.data(), + queryData.data(), + stream); + // + 1 is to account for the 0 at the beginning of the csr representation + // Each offsets buffer is sized from its own matrix (index vs. query row counts may differ). + auto idxRowsCsr = raft::make_device_vector( + handle, coo_idx.structure_view().get_n_rows() + 1); + auto queryRowsCsr = raft::make_device_vector( + handle, coo_query.structure_view().get_n_rows() + 1); 
+ + raft::sparse::convert::sorted_coo_to_csr(idxRows.data(), + int(idxRows.size()), + idxRowsCsr.data_handle(), + coo_idx.structure_view().get_n_rows() + 1, + stream); + + raft::sparse::convert::sorted_coo_to_csr(queryRows.data(), + int(queryRows.size()), + queryRowsCsr.data_handle(), + coo_query.structure_view().get_n_rows() + 1, + stream); + + brute_force::knn(idxRowsCsr.data_handle(), + idxCols.data(), + idxData.data(), + idxCols.size(), + idxRowsCsr.size() - 1, + coo_idx.structure_view().get_n_cols(), + queryRowsCsr.data_handle(), + queryCols.data(), + queryData.data(), + queryCols.size(), + queryRowsCsr.size() - 1, + coo_query.structure_view().get_n_cols(), + output_indices.data_handle(), + output_dists.data_handle(), + k, + handle, + batch_size_index, + batch_size_query, + metric, + metricArg); +} + }; // namespace raft::sparse::neighbors diff --git a/cpp/template/build.sh b/cpp/template/build.sh new file mode 100755 index 0000000000..d7e011e366 --- /dev/null +++ b/cpp/template/build.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +# Copyright (c) 2023-2024, NVIDIA CORPORATION. + +# raft empty project template build script + +# Abort script on first error +set -e + +PARALLEL_LEVEL=${PARALLEL_LEVEL:=`nproc`} + +BUILD_TYPE=Release +BUILD_DIR=build/ + +RAFT_REPO_REL="" +EXTRA_CMAKE_ARGS="" +set -e + + +if [[ ${RAFT_REPO_REL} != "" ]]; then + RAFT_REPO_PATH="`readlink -f \"${RAFT_REPO_REL}\"`" + EXTRA_CMAKE_ARGS="${EXTRA_CMAKE_ARGS} -DCPM_raft_SOURCE=${RAFT_REPO_PATH}" +fi + +if [ "$1" == "clean" ]; then + rm -rf build + exit 0 +fi + +mkdir -p $BUILD_DIR +cd $BUILD_DIR + +cmake \ + -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ + -DRAFT_NVTX=OFF \ + -DCMAKE_CUDA_ARCHITECTURES="RAPIDS" \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + ${EXTRA_CMAKE_ARGS} \ + ../ + +cmake --build . 
-j${PARALLEL_LEVEL} diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 621ee6c160..f9c1a95685 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -259,6 +259,8 @@ if(BUILD_TESTS) sparse/spgemmi.cu sparse/spmm.cu sparse/symmetrize.cu + sparse/preprocess_csr.cu + sparse/preprocess_coo.cu ) ConfigureTest( diff --git a/cpp/test/preprocess_utils.cu b/cpp/test/preprocess_utils.cu new file mode 100644 index 0000000000..a240e2e4e2 --- /dev/null +++ b/cpp/test/preprocess_utils.cu @@ -0,0 +1,228 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
// NOTE(review): in this view of the file every template parameter list and
// template argument list (the text between '<' and '>') is missing. The
// declarations below are reproduced with exactly the tokens that are visible;
// nothing has been guessed.
namespace raft::util {

/**
 * @brief Device functor that maps a value to 1.f when it is non-zero and to 0.f otherwise.
 *
 * The second argument (`idx`) is accepted but not used.
 */
template
struct check_zeroes {
  float __device__ operator()(const T1& value, const T2& idx)
  {
    if (value == 0) {
      return 0.f;
    } else {
      return 1.f;
    }
  }
};

/**
 * @brief Computes reference TF-IDF or BM25 values on the host from a dense matrix.
 *
 * For every non-zero entry, visited row by row and then column by column, one
 * value is appended to `results`:
 *   tf   = val / h_output_cols_lengths(0, col)
 *   idf  = log(num_cols / h_output_rows_cnt(0, row))
 *   bm25 = ((k1 + 1) * tf) / (k1 * ((1 - b) + b * (col_length / avg_col_length)) + tf)
 * with k1 = 1.6 and b = 0.75. The stored value is `tf * idf` when `tf_idf` is true and
 * `idf * bm25` otherwise.
 *
 * @param handle raft resources; its CUDA stream is used for all copies and reductions
 * @param dense_values device buffer holding the dense matrix (num_rows * num_cols entries)
 * @param results device buffer receiving one value per non-zero entry
 * @param num_rows number of rows of the dense matrix
 * @param num_cols number of columns of the dense matrix
 * @param tf_idf true to produce TF-IDF values, false to produce BM25 values
 */
template
void preproc(raft::resources& handle,
             raft::device_vector_view dense_values,
             raft::device_vector_view results,
             int num_rows,
             int num_cols,
             bool tf_idf)
{
  cudaStream_t stream = raft::resource::get_cuda_stream(handle);

  auto host_dense_vals = raft::make_host_vector(handle, dense_values.size());
  raft::copy(
    host_dense_vals.data_handle(), dense_values.data_handle(), dense_values.size(), stream);

  auto host_matrix =
    raft::make_host_matrix_view(host_dense_vals.data_handle(), num_rows, num_cols);
  auto device_matrix = raft::make_device_matrix(handle, num_rows, num_cols);

  raft::copy(device_matrix.data_handle(), host_matrix.data_handle(), host_matrix.size(), stream);

  // One reduced value per column (num_cols outputs).
  // NOTE(review): presumably the per-column sum ("feature length") — confirm against the
  // meaning of the two boolean flags of raft::linalg::reduce.
  auto output_cols_lengths = raft::make_device_matrix(handle, 1, num_cols);
  raft::linalg::reduce(output_cols_lengths.data_handle(),
                       device_matrix.data_handle(),
                       num_rows,
                       num_cols,
                       0.0f,
                       false,
                       true,
                       stream);
  auto h_output_cols_lengths = raft::make_host_matrix(handle, 1, num_cols);
  raft::copy(h_output_cols_lengths.data_handle(),
             output_cols_lengths.data_handle(),
             output_cols_lengths.size(),
             stream);

  // Sum of all per-column values, used below to derive the average column length.
  auto output_cols_length_sum = raft::make_device_scalar(handle, 0);
  raft::linalg::mapReduce(output_cols_length_sum.data_handle(),
                          num_cols,
                          0,
                          raft::identity_op(),
                          raft::add_op(),
                          stream,
                          output_cols_lengths.data_handle());
  auto h_output_cols_length_sum = raft::make_host_scalar(handle, 0);
  raft::copy(h_output_cols_length_sum.data_handle(),
             output_cols_length_sum.data_handle(),
             output_cols_length_sum.size(),
             stream);

  // The copies above are only ordered on `stream`; wait for them before the host
  // dereferences the destination buffers.
  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));

  T2 avg_col_length = T2(h_output_cols_length_sum(0)) / num_cols;

  // NOTE(review): output_rows_freq is computed but never read in this function.
  auto output_rows_freq = raft::make_device_matrix(handle, 1, num_rows);
  raft::linalg::reduce(output_rows_freq.data_handle(),
                       device_matrix.data_handle(),
                       num_rows,
                       num_cols,
                       0.0f,
                       false,
                       false,
                       stream);

  // Same reduction, but each element is first mapped through check_zeroes, so the
  // result counts non-zero entries instead of summing values (num_rows outputs).
  auto output_rows_cnt = raft::make_device_matrix(handle, 1, num_rows);
  raft::linalg::reduce(output_rows_cnt.data_handle(),
                       device_matrix.data_handle(),
                       num_rows,
                       num_cols,
                       0.0f,
                       false,
                       false,
                       stream,
                       false,
                       check_zeroes());
  auto h_output_rows_cnt = raft::make_host_matrix(handle, 1, num_rows);
  raft::copy(
    h_output_rows_cnt.data_handle(), output_rows_cnt.data_handle(), output_rows_cnt.size(), stream);

  // NOTE(review): out_device_matrix is filled with zeros but never read afterwards.
  auto out_device_matrix = raft::make_device_matrix(handle, num_rows, num_cols);
  raft::matrix::fill(handle, out_device_matrix.view(), 0.0f);
  auto out_host_matrix = raft::make_host_matrix(handle, num_rows, num_cols);
  auto out_host_vector = raft::make_host_vector(handle, results.size());

  // h_output_rows_cnt is read by the loop below; make sure its copy has landed.
  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));

  float k1  = 1.6f;
  float b   = 0.75f;
  int count = 0;
  float result;
  for (int row = 0; row < num_rows; row++) {
    for (int col = 0; col < num_cols; col++) {
      float val = host_matrix(row, col);
      if (val == 0) {
        out_host_matrix(row, col) = 0.0f;
      } else {
        float tf  = float(val / h_output_cols_lengths(0, col));
        float idf = raft::log(num_cols / h_output_rows_cnt(0, row));
        if (tf_idf) {
          result = tf * idf;
        } else {
          float bm25 = ((k1 + 1) * tf) /
                       (k1 * ((1 - b) + b * (h_output_cols_lengths(0, col) / avg_col_length)) + tf);
          result = idf * bm25;
        }
        out_host_matrix(row, col) = result;
        // Non-zero entries are packed densely, in row-major visiting order.
        out_host_vector(count) = result;
        count++;
      }
    }
  }

  raft::copy(results.data_handle(), out_host_vector.data_handle(), out_host_vector.size(), stream);
  // out_host_vector is destroyed when this function returns; finish the upload first.
  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
}

/**
 * @brief Produces reference TF-IDF / BM25 values for a CSR matrix.
 *
 * The CSR input is expanded into a dense num_rows * num_cols device buffer with
 * raft::sparse::convert::csr_to_dense and then handed to preproc().
 *
 * @param handle raft resources; its CUDA stream is used for all work
 * @param csr_in CSR matrix view to convert
 * @param results device buffer receiving one value per non-zero entry
 * @param tf_idf true to produce TF-IDF values, false (default) to produce BM25 values
 */
template
void calc_tfidf_bm25(raft::resources& handle,
                     raft::device_csr_matrix_view csr_in,
                     raft::device_vector_view results,
                     bool tf_idf = false)
{
  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
  int num_rows        = csr_in.structure_view().get_n_rows();
  int num_cols        = csr_in.structure_view().get_n_cols();
  int rows_size       = csr_in.structure_view().get_indptr().size();
  int cols_size       = csr_in.structure_view().get_indices().size();
  int elements_size   = csr_in.get_elements().size();

  auto indptr = raft::make_device_vector_view(
    csr_in.structure_view().get_indptr().data(), rows_size);
  auto indices = raft::make_device_vector_view(
    csr_in.structure_view().get_indices().data(), cols_size);
  auto values =
    raft::make_device_vector_view(csr_in.get_elements().data(), elements_size);
  auto dense_values = raft::make_device_vector(handle, num_rows * num_cols);

  cusparseHandle_t cu_handle;
  RAFT_CUSPARSE_TRY(cusparseCreate(&cu_handle));

  raft::sparse::convert::csr_to_dense(cu_handle,
                                      num_rows,
                                      num_cols,
                                      elements_size,
                                      indptr.data_handle(),
                                      indices.data_handle(),
                                      values.data_handle(),
                                      num_rows,
                                      dense_values.data_handle(),
                                      stream,
                                      true);

  RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
  // The handle was created locally and the conversion has completed: release it so
  // repeated calls do not leak cuSPARSE contexts.
  RAFT_CUSPARSE_TRY(cusparseDestroy(cu_handle));

  preproc(handle, dense_values.view(), results, num_rows, num_cols, tf_idf);
}

/**
 * @brief Fills `rows`, `columns` and `values` with a random COO data set.
 *
 * Row/column ids come from raft::random::rmat_rectangular_gen driven by a uniformly
 * random theta buffer; values are random integers drawn with uniformInt(1,
 * max_term_occurence_doc) and converted with cast_op.
 *
 * @param handle raft resources; its CUDA stream is used for all work
 * @param rows output row ids
 * @param columns output column ids
 * @param values output values (values.size() edges are generated)
 * @param max_term_occurence_doc upper argument passed to uniformInt for the values
 * @param num_rows_unique forwarded to rmat_rectangular_gen
 *        (NOTE(review): callers pass a power-of-two exponent here — confirm)
 * @param num_cols_unique forwarded to rmat_rectangular_gen (same remark)
 * @param seed seed for the random number generator state
 */
template
void create_dataset(raft::resources& handle,
                    raft::device_vector_view rows,
                    raft::device_vector_view columns,
                    raft::device_vector_view values,
                    int max_term_occurence_doc = 5,
                    int num_rows_unique        = 7,
                    int num_cols_unique        = 7,
                    int seed                   = 12345)
{
  cudaStream_t stream = raft::resource::get_cuda_stream(handle);
  raft::random::RngState rng(seed);

  // Scratch output of the generator: two ids per edge.
  auto d_out = raft::make_device_vector(handle, rows.size() * 2);

  // Four theta entries per level of the larger of the two dimensions.
  int theta_guide = max(num_rows_unique, num_cols_unique);
  auto theta      = raft::make_device_vector(handle, theta_guide * 4);

  raft::random::uniform(handle, rng, theta.view(), 0.0f, 1.0f);

  raft::random::rmat_rectangular_gen(d_out.data_handle(),
                                     rows.data_handle(),
                                     columns.data_handle(),
                                     theta.data_handle(),
                                     num_rows_unique,
                                     num_cols_unique,
                                     int(values.size()),
                                     stream,
                                     rng);

  auto vals = raft::make_device_vector(handle, rows.size());
  raft::random::uniformInt(handle, rng, vals.view(), 1, max_term_occurence_doc);
  raft::linalg::map(handle, values, raft::cast_op{}, raft::make_const_mdspan(vals.view()));
}

};  // namespace raft::util
// NOTE(review): template parameter/argument lists are not visible in this view of the
// file; the code below is left exactly as shown and only comments were added.
namespace raft {
namespace sparse {
namespace selection {

using namespace raft;
using namespace raft::sparse;

/**
 * @brief Parameters of one sparse k-NN test case.
 *
 * The input matrix is described in CSR form (indptr_h / indices_h / data_h) together
 * with the expected distances and neighbor indices.
 */
template
struct SparseKNNInputs {
  value_idx n_cols;  // number of columns of the input matrix

  std::vector indptr_h;   // CSR row offsets (n_rows + 1 entries)
  std::vector indices_h;  // CSR column indices
  std::vector data_h;     // CSR values

  std::vector out_dists_ref_h;    // expected distances, n_rows * k entries
  std::vector out_indices_ref_h;  // expected neighbor indices, n_rows * k entries

  int k;  // number of neighbors to query

  int batch_size_index = 2;
  int batch_size_query = 2;

  raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded;
};

/** @brief Stream operator required by gtest for parameter printing; prints nothing. */
template
::std::ostream& operator<<(::std::ostream& os, const SparseKNNInputs& dims)
{
  return os;
}

/**
 * @brief Runs raft::sparse::neighbors::brute_force_knn on a COO matrix (used both as
 * index and as query) and compares the result with the reference values.
 */
template
class SparseKNNCOOTest : public ::testing::TestWithParam> {
 public:
  // All device buffers start empty; they are sized in make_data().
  SparseKNNCOOTest()
    : params(::testing::TestWithParam>::GetParam()),
      indptr(0, resource::get_cuda_stream(handle)),
      indices(0, resource::get_cuda_stream(handle)),
      data(0, resource::get_cuda_stream(handle)),
      out_indices(0, resource::get_cuda_stream(handle)),
      out_dists(0, resource::get_cuda_stream(handle)),
      out_indices_ref(0, resource::get_cuda_stream(handle)),
      out_dists_ref(0, resource::get_cuda_stream(handle))
  {
  }

 protected:
  /**
   * @brief Uploads the test data, converts it from CSR to COO, runs the k-NN search and
   * copies the results into out_indices / out_dists.
   */
  void SetUp() override
  {
    n_rows = params.indptr_h.size() - 1;
    nnz    = params.indices_h.size();
    k      = params.k;

    auto out_indices_dev = raft::make_device_vector(handle, n_rows * k);
    auto out_dists_dev   = raft::make_device_vector(handle, n_rows * k);

    auto rows = raft::make_device_vector(handle, nnz);

    cudaStream_t stream = raft::resource::get_cuda_stream(handle);

    // Must run after n_rows and k are set: it sizes the output buffers with n_rows * k.
    make_data();

    // Expand the CSR row offsets into one row id per non-zero.
    // NOTE(review): indptr.size() is n_rows + 1, not n_rows — confirm this is the row
    // count csr_to_coo expects.
    raft::sparse::convert::csr_to_coo(
      indptr.data(), int(indptr.size()), rows.data_handle(), nnz, stream);

    auto coo_struct_view = raft::make_device_coordinate_structure_view(
      rows.data_handle(), indices.data(), n_rows, params.n_cols, int(data.size()));
    auto c_matrix = raft::make_device_coo_matrix(
      handle, coo_struct_view);
    // `data` is already a device buffer, so this fills the matrix values from device
    // memory.
    raft::update_device(
      c_matrix.view().get_elements().data(), data.data(), data.size(), stream);

    // The same matrix is passed as index and as query.
    raft::sparse::neighbors::brute_force_knn(c_matrix,
                                             c_matrix,
                                             out_indices_dev.view(),
                                             out_dists_dev.view(),
                                             k,
                                             handle,
                                             params.batch_size_index,
                                             params.batch_size_query,
                                             params.metric);

    raft::copy(out_indices.data(), out_indices_dev.data_handle(), out_indices_dev.size(), stream);
    raft::copy(out_dists.data(), out_dists_dev.data_handle(), out_dists_dev.size(), stream);

    // out_indices_dev / out_dists_dev go out of scope at the end of SetUp(); wait for
    // the copies issued on the stream before that happens.
    RAFT_CUDA_TRY(cudaStreamSynchronize(resource::get_cuda_stream(handle)));
  }

  /**
   * @brief Asserts that distances match the reference within 1e-4 and that neighbor
   * indices match exactly.
   */
  void compare()
  {
    ASSERT_TRUE(devArrMatch(
      out_dists_ref.data(), out_dists.data(), n_rows * k, CompareApprox(1e-4)));
    ASSERT_TRUE(
      devArrMatch(out_indices_ref.data(), out_indices.data(), n_rows * k, Compare()));
  }

 protected:
  /**
   * @brief Resizes the device buffers and uploads the CSR inputs and the reference
   * outputs from `params`.
   */
  void make_data()
  {
    std::vector indptr_h  = params.indptr_h;
    std::vector indices_h = params.indices_h;
    std::vector data_h      = params.data_h;

    auto stream = resource::get_cuda_stream(handle);
    indptr.resize(indptr_h.size(), stream);
    indices.resize(indices_h.size(), stream);
    data.resize(data_h.size(), stream);

    update_device(indptr.data(), indptr_h.data(), indptr_h.size(), stream);
    update_device(indices.data(), indices_h.data(), indices_h.size(), stream);
    update_device(data.data(), data_h.data(), data_h.size(), stream);

    std::vector out_dists_ref_h     = params.out_dists_ref_h;
    std::vector out_indices_ref_h = params.out_indices_ref_h;

    out_indices_ref.resize(out_indices_ref_h.size(), stream);
    out_dists_ref.resize(out_dists_ref_h.size(), stream);

    update_device(
      out_indices_ref.data(), out_indices_ref_h.data(), out_indices_ref_h.size(), stream);
    update_device(out_dists_ref.data(), out_dists_ref_h.data(), out_dists_ref_h.size(), stream);

    // Result buffers: k entries per row.
    out_dists.resize(n_rows * k, stream);
    out_indices.resize(n_rows * k, stream);
  }

  raft::resources handle;

  int n_rows, nnz, k;

  // input data
  rmm::device_uvector indptr, indices;
  rmm::device_uvector data;

  // output data
  rmm::device_uvector out_indices;
  rmm::device_uvector out_dists;

  // expected output data
  rmm::device_uvector out_indices_ref;
  rmm::device_uvector out_dists_ref;

  SparseKNNInputs params;
};

// Single test case: a 4-row, 9-column matrix with two non-zeros per row, k = 2.
const std::vector> inputs_i32_f = {
  {9,                                                 // ncols
   {0, 2, 4, 6, 8},                                   // indptr
   {0, 4, 0, 3, 0, 2, 0, 8},                          // indices
   {0.0f, 1.0f, 5.0f, 6.0f, 5.0f, 6.0f, 0.0f, 1.0f},  // data
   {0, 1.41421, 0, 7.87401, 0, 7.87401, 0, 1.41421},  // dists
   {0, 3, 1, 0, 2, 0, 3, 0},                          // inds
   2,                                                 // k
   2,                                                 // batch_size_index
   2,                                                 // batch_size_query
   raft::distance::DistanceType::L2SqrtExpanded}};
typedef SparseKNNCOOTest SparseKNNCOOTestF;
TEST_P(SparseKNNCOOTestF, Result) { compare(); }
INSTANTIATE_TEST_CASE_P(SparseKNNCOOTest, SparseKNNCOOTestF, ::testing::ValuesIn(inputs_i32_f));

};  // end namespace selection
};  // end namespace sparse
};  // end namespace raft
// NOTE(review): template parameter/argument lists are not visible in this view of the
// file; declarations are reproduced with exactly the tokens that are shown.
namespace raft {
namespace sparse {
namespace selection {

using namespace raft;
using namespace raft::sparse;

/**
 * @brief Parameters of one sparse k-NN test case.
 *
 * The input matrix is described in CSR form (indptr_h / indices_h / data_h) together
 * with the expected distances and neighbor indices.
 */
template
struct SparseKNNInputs {
  value_idx n_cols;  // number of columns of the input matrix

  std::vector indptr_h;   // CSR row offsets (n_rows + 1 entries)
  std::vector indices_h;  // CSR column indices
  std::vector data_h;     // CSR values

  std::vector out_dists_ref_h;    // expected distances, n_rows * k entries
  std::vector out_indices_ref_h;  // expected neighbor indices, n_rows * k entries

  int k;  // number of neighbors to query

  int batch_size_index = 2;
  int batch_size_query = 2;

  raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded;
};

/** @brief Stream operator required by gtest for parameter printing; prints nothing. */
template
::std::ostream& operator<<(::std::ostream& os, const SparseKNNInputs& dims)
{
  return os;
}

/**
 * @brief Runs raft::sparse::neighbors::brute_force_knn on a CSR matrix (used both as
 * index and as query) and compares the result with the reference values.
 */
template
class SparseKNNCSRTest : public ::testing::TestWithParam> {
 public:
  // All device buffers start empty; they are sized in make_data().
  SparseKNNCSRTest()
    : params(::testing::TestWithParam>::GetParam()),
      indptr(0, resource::get_cuda_stream(handle)),
      indices(0, resource::get_cuda_stream(handle)),
      data(0, resource::get_cuda_stream(handle)),
      out_indices(0, resource::get_cuda_stream(handle)),
      out_dists(0, resource::get_cuda_stream(handle)),
      out_indices_ref(0, resource::get_cuda_stream(handle)),
      out_dists_ref(0, resource::get_cuda_stream(handle))
  {
  }

 protected:
  /**
   * @brief Uploads the test data, wraps it in a CSR matrix, runs the k-NN search and
   * copies the results into out_indices / out_dists.
   */
  void SetUp() override
  {
    n_rows               = params.indptr_h.size() - 1;
    nnz                  = params.indices_h.size();
    k                    = params.k;
    auto out_indices_dev = raft::make_device_vector(handle, n_rows * k);
    auto out_dists_dev   = raft::make_device_vector(handle, n_rows * k);

    cudaStream_t stream = raft::resource::get_cuda_stream(handle);

    // Must run after n_rows and k are set: it sizes the output buffers with n_rows * k.
    make_data();
    auto csr_struct_view = raft::make_device_compressed_structure_view(
      indptr.data(), indices.data(), n_rows, params.n_cols, int(data.size()));
    auto c_matrix = raft::make_device_csr_matrix(handle, csr_struct_view);

    // `data` is already a device buffer, so this fills the matrix values from device
    // memory.
    raft::update_device(
      c_matrix.view().get_elements().data(), data.data(), data.size(), stream);

    // The same matrix is passed as index and as query.
    raft::sparse::neighbors::brute_force_knn(c_matrix,
                                             c_matrix,
                                             out_indices_dev.view(),
                                             out_dists_dev.view(),
                                             k,
                                             handle,
                                             params.batch_size_index,
                                             params.batch_size_query,
                                             params.metric);

    raft::copy(out_indices.data(), out_indices_dev.data_handle(), out_indices_dev.size(), stream);
    raft::copy(out_dists.data(), out_dists_dev.data_handle(), out_dists_dev.size(), stream);

    // out_indices_dev / out_dists_dev go out of scope at the end of SetUp(); wait for
    // the copies issued on the stream before that happens.
    RAFT_CUDA_TRY(cudaStreamSynchronize(resource::get_cuda_stream(handle)));
  }

  /**
   * @brief Asserts that distances match the reference within 1e-4 and that neighbor
   * indices match exactly.
   */
  void compare()
  {
    ASSERT_TRUE(devArrMatch(
      out_dists_ref.data(), out_dists.data(), n_rows * k, CompareApprox(1e-4)));
    ASSERT_TRUE(
      devArrMatch(out_indices_ref.data(), out_indices.data(), n_rows * k, Compare()));
  }

 protected:
  /**
   * @brief Resizes the device buffers and uploads the CSR inputs and the reference
   * outputs from `params`.
   */
  void make_data()
  {
    std::vector indptr_h  = params.indptr_h;
    std::vector indices_h = params.indices_h;
    std::vector data_h      = params.data_h;

    auto stream = resource::get_cuda_stream(handle);
    indptr.resize(indptr_h.size(), stream);
    indices.resize(indices_h.size(), stream);
    data.resize(data_h.size(), stream);

    update_device(indptr.data(), indptr_h.data(), indptr_h.size(), stream);
    update_device(indices.data(), indices_h.data(), indices_h.size(), stream);
    update_device(data.data(), data_h.data(), data_h.size(), stream);

    std::vector out_dists_ref_h     = params.out_dists_ref_h;
    std::vector out_indices_ref_h = params.out_indices_ref_h;

    out_indices_ref.resize(out_indices_ref_h.size(), stream);
    out_dists_ref.resize(out_dists_ref_h.size(), stream);

    update_device(
      out_indices_ref.data(), out_indices_ref_h.data(), out_indices_ref_h.size(), stream);
    update_device(out_dists_ref.data(), out_dists_ref_h.data(), out_dists_ref_h.size(), stream);

    // Result buffers: k entries per row.
    out_dists.resize(n_rows * k, stream);
    out_indices.resize(n_rows * k, stream);
  }

  raft::resources handle;

  int n_rows, nnz, k;

  // input data
  rmm::device_uvector indptr, indices;
  rmm::device_uvector data;

  // output data
  rmm::device_uvector out_indices;
  rmm::device_uvector out_dists;

  // expected output data
  rmm::device_uvector out_indices_ref;
  rmm::device_uvector out_dists_ref;

  SparseKNNInputs params;
};

// Single test case: a 4-row, 9-column matrix with two non-zeros per row, k = 2.
const std::vector> inputs_i32_f = {
  {9,                                                 // ncols
   {0, 2, 4, 6, 8},                                   // indptr
   {0, 4, 0, 3, 0, 2, 0, 8},                          // indices
   {0.0f, 1.0f, 5.0f, 6.0f, 5.0f, 6.0f, 0.0f, 1.0f},  // data
   {0, 1.41421, 0, 7.87401, 0, 7.87401, 0, 1.41421},  // dists
   {0, 3, 1, 0, 2, 0, 3, 0},                          // inds
   2,                                                 // k
   2,                                                 // batch_size_index
   2,                                                 // batch_size_query
   raft::distance::DistanceType::L2SqrtExpanded}};
typedef SparseKNNCSRTest SparseKNNCSRTestF;
TEST_P(SparseKNNCSRTestF, Result) { compare(); }
INSTANTIATE_TEST_CASE_P(SparseKNNCSRTest, SparseKNNCSRTestF, ::testing::ValuesIn(inputs_i32_f));

};  // end namespace selection
};  // end namespace sparse
};  // end namespace raft
// NOTE(review): template parameter/argument lists are not visible in this view of the
// file; the code below is left exactly as shown and only comments were added.
namespace raft {
namespace sparse {

/**
 * @brief Parameters of one preprocessing test case.
 *
 * n_rows and n_cols are exponents: the test uses 2^n_rows rows and 2^n_cols columns.
 */
template
struct SparsePreprocessInputs {
  int n_rows;     // log2 of the number of rows
  int n_cols;     // log2 of the number of columns
  int nnz_edges;  // number of random (row, column, value) triplets to generate
};

/**
 * @brief Checks sparse::matrix::encode_bm25 / encode_tfidf on a COO matrix against the
 * dense reference implementation raft::util::calc_tfidf_bm25.
 */
template
class SparsePreprocessCoo
  : public ::testing::TestWithParam> {
 public:
  SparsePreprocessCoo()
    : params(::testing::TestWithParam>::GetParam()),
      stream(resource::get_cuda_stream(handle))
  {
  }

 protected:
  void SetUp() override {}

  /**
   * @brief Generates a random data set, encodes it and compares with the reference.
   *
   * @param bm25_on true to test BM25 encoding, false to test TF-IDF encoding
   */
  void Run(bool bm25_on)
  {
    // Shadows the `stream` member; both are obtained from the same handle.
    cudaStream_t stream = raft::resource::get_cuda_stream(handle);

    int num_rows = pow(2, params.n_rows);
    int num_cols = pow(2, params.n_cols);

    auto rows    = raft::make_device_vector(handle, params.nnz_edges);
    auto columns = raft::make_device_vector(handle, params.nnz_edges);
    auto values  = raft::make_device_vector(handle, params.nnz_edges);

    rmm::device_uvector rows_uvec(rows.size(), stream);
    rmm::device_uvector cols_uvec(rows.size(), stream);
    rmm::device_uvector vals_uvec(rows.size(), stream);

    // The exponents (not 2^n) are forwarded to the generator; values are drawn with an
    // upper argument of 5.
    raft::util::create_dataset(
      handle, rows.view(), columns.view(), values.view(), 5, params.n_rows, params.n_cols);

    raft::sparse::op::coo_sort(int(rows.size()),
                               int(columns.size()),
                               int(values.size()),
                               rows.data_handle(),
                               columns.data_handle(),
                               values.data_handle(),
                               stream);

    raft::copy(rows_uvec.data(), rows.data_handle(), rows.size(), stream);
    raft::copy(cols_uvec.data(), columns.data_handle(), columns.size(), stream);
    raft::copy(vals_uvec.data(), values.data_handle(), values.size(), stream);

    // Collapse repeated (row, column) pairs; `coo` holds the de-duplicated matrix and
    // its nnz is what every later buffer is sized with.
    raft::sparse::COO coo(stream);
    raft::sparse::op::max_duplicates(handle,
                                     coo,
                                     rows_uvec.data(),
                                     cols_uvec.data(),
                                     vals_uvec.data(),
                                     params.nnz_edges,
                                     num_rows,
                                     num_cols);

    auto rows_csr = raft::make_device_vector(handle, num_rows + 1);

    raft::sparse::convert::sorted_coo_to_csr(
      coo.rows(), coo.nnz, rows_csr.data_handle(), num_rows + 1, stream);

    // CSR copy of the data: only consumed by the reference implementation below.
    auto csr_struct_view = raft::make_device_compressed_structure_view(
      rows_csr.data_handle(), coo.cols(), num_rows, num_cols, coo.nnz);

    auto csr_matrix =
      raft::make_device_csr_matrix(handle, csr_struct_view);
    raft::update_device(
      csr_matrix.view().get_elements().data(), coo.vals(), coo.nnz, stream);

    // COO matrix: the input of the functions under test.
    auto coo_struct_view = raft::make_device_coordinate_structure_view(
      coo.rows(), coo.cols(), num_rows, num_cols, int(coo.nnz));
    auto c_matrix =
      raft::make_device_coo_matrix(handle, coo_struct_view);
    raft::update_device(c_matrix.view().get_elements().data(), coo.vals(), coo.nnz, stream);

    auto result = raft::make_device_vector(handle, coo.nnz);

    // NOTE(review): the reference path builds dense num_rows * num_cols buffers, so the
    // "Big" parameter set is memory hungry — confirm it fits the CI GPUs.
    if (bm25_on) {
      auto bm25_vals = raft::make_device_vector(handle, coo.nnz);
      sparse::matrix::encode_bm25(handle, c_matrix.view(), result.view());
      raft::util::calc_tfidf_bm25(handle, csr_matrix.view(), bm25_vals.view());
      ASSERT_TRUE(raft::devArrMatch(bm25_vals.data_handle(),
                                    result.data_handle(),
                                    result.size(),
                                    raft::CompareApprox(2e-5),
                                    stream));
    } else {
      auto tfidf_vals = raft::make_device_vector(handle, coo.nnz);
      sparse::matrix::encode_tfidf(handle, c_matrix.view(), result.view());
      raft::util::calc_tfidf_bm25(
        handle, csr_matrix.view(), tfidf_vals.view(), true);
      ASSERT_TRUE(raft::devArrMatch(tfidf_vals.data_handle(),
                                    result.data_handle(),
                                    result.size(),
                                    raft::CompareApprox(2e-5),
                                    stream));
    }

    RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
  }

 protected:
  raft::resources handle;
  cudaStream_t stream;

  SparsePreprocessInputs params;
};

// One fixture alias per (encoding, data-set size) combination.
using SparsePreprocessTfidfCoo = SparsePreprocessCoo;
TEST_P(SparsePreprocessTfidfCoo, Result) { Run(false); }

using SparsePreprocessBm25Coo = SparsePreprocessCoo;
TEST_P(SparsePreprocessBm25Coo, Result) { Run(true); }

using SparsePreprocessTfidfCooBig = SparsePreprocessCoo;
TEST_P(SparsePreprocessTfidfCooBig, Result) { Run(false); }

using SparsePreprocessBm25CooBig = SparsePreprocessCoo;
TEST_P(SparsePreprocessBm25CooBig, Result) { Run(true); }

// Small data set: 2^10 x 2^10 matrix, 1000 generated edges.
const std::vector> sparse_preprocess_inputs = {
  {
    10,   // n_rows_factor
    10,   // n_cols_factor
    1000  // nnz_edges
  },
};

// Large data set: 2^15 x 2^15 matrix, 1000000 generated edges.
const std::vector> sparse_preprocess_inputs_big = {
  {
    15,      // n_rows_factor
    15,      // n_cols_factor
    1000000  // nnz_edges
  },
};

INSTANTIATE_TEST_CASE_P(SparsePreprocessCoo,
                        SparsePreprocessTfidfCoo,
                        ::testing::ValuesIn(sparse_preprocess_inputs));
INSTANTIATE_TEST_CASE_P(SparsePreprocessCoo,
                        SparsePreprocessBm25Coo,
                        ::testing::ValuesIn(sparse_preprocess_inputs));

INSTANTIATE_TEST_CASE_P(SparsePreprocessCoo,
                        SparsePreprocessTfidfCooBig,
                        ::testing::ValuesIn(sparse_preprocess_inputs_big));
INSTANTIATE_TEST_CASE_P(SparsePreprocessCoo,
                        SparsePreprocessBm25CooBig,
                        ::testing::ValuesIn(sparse_preprocess_inputs_big));

}  // namespace sparse
}  // namespace raft
// NOTE(review): template parameter/argument lists are not visible in this view of the
// file; declarations are reproduced with exactly the tokens that are shown.
namespace raft {
namespace sparse {

/**
 * @brief Parameters of one preprocessing test case.
 *
 * n_rows and n_cols are exponents: the test uses 2^n_rows rows and 2^n_cols columns.
 */
template
struct SparsePreprocessInputs {
  int n_rows;     // log2 of the number of rows
  int n_cols;     // log2 of the number of columns
  int nnz_edges;  // number of random (row, column, value) triplets to generate
};

/**
 * @brief Checks sparse::matrix::encode_bm25 / encode_tfidf on a CSR matrix against the
 * dense reference implementation raft::util::calc_tfidf_bm25.
 */
template
class SparsePreprocessCSR
  : public ::testing::TestWithParam> {
 public:
  SparsePreprocessCSR()
    : params(::testing::TestWithParam>::GetParam()),
      stream(resource::get_cuda_stream(handle))
  {
  }

 protected:
  void SetUp() override {}

  /**
   * @brief Generates a random data set, encodes it and compares with the reference.
   *
   * @param bm25_on true to test BM25 encoding, false to test TF-IDF encoding
   */
  void Run(bool bm25_on)
  {
    // Shadows the `stream` member; both are obtained from the same handle.
    cudaStream_t stream = raft::resource::get_cuda_stream(handle);

    int num_rows = pow(2, params.n_rows);
    int num_cols = pow(2, params.n_cols);

    auto rows    = raft::make_device_vector(handle, params.nnz_edges);
    auto columns = raft::make_device_vector(handle, params.nnz_edges);
    auto values  = raft::make_device_vector(handle, params.nnz_edges);

    rmm::device_uvector rows_uvec(rows.size(), stream);
    rmm::device_uvector cols_uvec(rows.size(), stream);
    rmm::device_uvector vals_uvec(rows.size(), stream);

    // The exponents (not 2^n) are forwarded to the generator; values are drawn with an
    // upper argument of 5.
    raft::util::create_dataset(
      handle, rows.view(), columns.view(), values.view(), 5, params.n_rows, params.n_cols);

    raft::sparse::op::coo_sort(int(rows.size()),
                               int(columns.size()),
                               int(values.size()),
                               rows.data_handle(),
                               columns.data_handle(),
                               values.data_handle(),
                               stream);

    raft::copy(rows_uvec.data(), rows.data_handle(), rows.size(), stream);
    raft::copy(cols_uvec.data(), columns.data_handle(), columns.size(), stream);
    raft::copy(vals_uvec.data(), values.data_handle(), values.size(), stream);

    // Collapse repeated (row, column) pairs; `coo` holds the de-duplicated matrix and
    // its nnz is what every later buffer is sized with.
    raft::sparse::COO coo(stream);
    raft::sparse::op::max_duplicates(handle,
                                     coo,
                                     rows_uvec.data(),
                                     cols_uvec.data(),
                                     vals_uvec.data(),
                                     params.nnz_edges,
                                     num_rows,
                                     num_cols);

    auto rows_csr = raft::make_device_vector(handle, num_rows + 1);

    raft::sparse::convert::sorted_coo_to_csr(
      coo.rows(), coo.nnz, rows_csr.data_handle(), num_rows + 1, stream);
    auto csr_struct_view = raft::make_device_compressed_structure_view(
      rows_csr.data_handle(), coo.cols(), num_rows, num_cols, coo.nnz);
    auto c_matrix =
      raft::make_device_csr_matrix(handle, csr_struct_view);

    raft::update_device(c_matrix.view().get_elements().data(), coo.vals(), coo.nnz, stream);

    auto result = raft::make_device_vector(handle, coo.nnz);

    // The reference buffer is allocated inside the branch that uses it, so each run
    // only pays for one extra nnz-sized device allocation.
    if (bm25_on) {
      auto bm25_vals = raft::make_device_vector(handle, coo.nnz);
      sparse::matrix::encode_bm25(handle, c_matrix.view(), result.view());
      raft::util::calc_tfidf_bm25(handle, c_matrix.view(), bm25_vals.view());
      ASSERT_TRUE(raft::devArrMatch(bm25_vals.data_handle(),
                                    result.data_handle(),
                                    result.size(),
                                    raft::CompareApprox(2e-5),
                                    stream));
    } else {
      auto tfidf_vals = raft::make_device_vector(handle, coo.nnz);
      sparse::matrix::encode_tfidf(handle, c_matrix.view(), result.view());
      raft::util::calc_tfidf_bm25(handle, c_matrix.view(), tfidf_vals.view(), true);
      ASSERT_TRUE(raft::devArrMatch(tfidf_vals.data_handle(),
                                    result.data_handle(),
                                    result.size(),
                                    raft::CompareApprox(2e-5),
                                    stream));
    }

    RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
  }

 protected:
  raft::resources handle;
  cudaStream_t stream;

  SparsePreprocessInputs params;
};

// One fixture alias per (encoding, data-set size) combination.
using SparsePreprocessTfidfCsr = SparsePreprocessCSR;
TEST_P(SparsePreprocessTfidfCsr, Result) { Run(false); }

using SparsePreprocessBm25Csr = SparsePreprocessCSR;
TEST_P(SparsePreprocessBm25Csr, Result) { Run(true); }

using SparsePreprocessTfidfCsrBig = SparsePreprocessCSR;
TEST_P(SparsePreprocessTfidfCsrBig, Result) { Run(false); }

using SparsePreprocessBm25CsrBig = SparsePreprocessCSR;
TEST_P(SparsePreprocessBm25CsrBig, Result) { Run(true); }

// Small data set: 2^7 x 2^5 matrix, 10 generated edges.
const std::vector> sparse_preprocess_inputs = {
  {
    7,  // n_rows_factor
    5,  // n_cols_factor
    10  // num nnz values
  },
};

// Large data set: 2^12 x 2^12 matrix, 1000000 generated edges.
const std::vector> sparse_preprocess_inputs_big = {
  {
    12,      // n_rows_factor
    12,      // n_cols_factor
    1000000  // nnz_edges - 6475
  },
};

INSTANTIATE_TEST_CASE_P(SparsePreprocessCSR,
                        SparsePreprocessTfidfCsr,
                        ::testing::ValuesIn(sparse_preprocess_inputs));
INSTANTIATE_TEST_CASE_P(SparsePreprocessCSR,
                        SparsePreprocessBm25Csr,
                        ::testing::ValuesIn(sparse_preprocess_inputs));

INSTANTIATE_TEST_CASE_P(SparsePreprocessCSR,
                        SparsePreprocessTfidfCsrBig,
                        ::testing::ValuesIn(sparse_preprocess_inputs_big));
INSTANTIATE_TEST_CASE_P(SparsePreprocessCSR,
                        SparsePreprocessBm25CsrBig,
                        ::testing::ValuesIn(sparse_preprocess_inputs_big));

}  // namespace sparse
}  // namespace raft