diff --git a/cpp/include/cugraph/algorithms.hpp b/cpp/include/cugraph/algorithms.hpp index 78846bc5766..8501eedce5c 100644 --- a/cpp/include/cugraph/algorithms.hpp +++ b/cpp/include/cugraph/algorithms.hpp @@ -464,51 +464,6 @@ k_truss_subgraph(raft::handle_t const& handle, size_t number_of_vertices, int k); -// FIXME: Internally distances is of int (signed 32-bit) data type, but current -// template uses data from VT, ET, WT from the legacy::GraphCSR View even if weights -// are not considered -/** - * @Synopsis Performs a breadth first search traversal of a graph starting from a vertex. - * - * @throws cugraph::logic_error with a custom message when an error occurs. - * - * @tparam VT Type of vertex identifiers. Supported value : int (signed, - * 32-bit) - * @tparam ET Type of edge identifiers. Supported value : int (signed, - * 32-bit) - * @tparam WT Type of edge weights. Supported values : int (signed, 32-bit) - * - * @param[in] handle Library handle (RAFT). If a communicator is set in the handle, - the multi GPU version will be selected. - * @param[in] graph cuGraph graph descriptor, should contain the connectivity - * information as a CSR - * - * @param[out] distances If set to a valid pointer, this is populated by distance of - * every vertex in the graph from the starting vertex - * - * @param[out] predecessors If set to a valid pointer, this is populated by bfs traversal - * predecessor of every vertex - * - * @param[out] sp_counters If set to a valid pointer, this is populated by bfs traversal - * shortest_path counter of every vertex - * - * @param[in] start_vertex The starting vertex for breadth first search traversal - * - * @param[in] directed Treat the input graph as directed - * - * @param[in] mg_batch If set to true use SG BFS path when comms are initialized. - * - */ -template -void bfs(raft::handle_t const& handle, - legacy::GraphCSRView const& graph, - VT* distances, - VT* predecessors, - double* sp_counters, - const VT start_vertex, - bool directed = true, - bool mg_batch = false); - /** * @brief Compute Hungarian algorithm on a weighted bipartite graph * diff --git a/cpp/include/cugraph/utilities/device_comm.hpp b/cpp/include/cugraph/utilities/device_comm.hpp index 7087724921a..990074e781b 100644 --- a/cpp/include/cugraph/utilities/device_comm.hpp +++ b/cpp/include/cugraph/utilities/device_comm.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
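[Editor's aside, not part of the patch] The algorithms.hpp hunk at the top of this diff removes the legacy CSR-based bfs() declaration together with its documentation. As a reader aid, the sketch below shows how a caller might invoke the non-legacy cugraph::bfs overload that remains in that header and operates on a graph_view_t; the parameter order follows that overload, but the surrounding setup (the handle, the graph_view, the single-source call shape) is assumed and should be checked against the current header.

// Illustrative migration sketch only.
// Assumes a single-GPU graph_view_t<int32_t, int32_t, false, false> named graph_view
// and a raft::handle_t named handle.
rmm::device_uvector<int32_t> distances(graph_view.local_vertex_partition_range_size(),
                                       handle.get_stream());
rmm::device_uvector<int32_t> predecessors(distances.size(), handle.get_stream());

int32_t start_vertex{0};  // hypothetical starting vertex
rmm::device_uvector<int32_t> sources(1, handle.get_stream());
raft::update_device(sources.data(), &start_vertex, 1, handle.get_stream());

cugraph::bfs(handle,
             graph_view,
             distances.data(),     // hop distance of every local vertex from the source
             predecessors.data(),  // BFS-tree predecessor of every local vertex
             sources.data(),
             size_t{1} /* number of source vertices */);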
@@ -806,9 +806,6 @@ device_sendrecv(raft::comms::comms_t const& comm, size_t constexpr tuple_size = thrust::tuple_size::value_type>::value; - // FIXME: NCCL 2.7 supports only one ncclSend and one ncclRecv for a source rank and destination - // rank inside ncclGroupStart/ncclGroupEnd, so we cannot place this inside - // ncclGroupStart/ncclGroupEnd, this restriction will be lifted in NCCL 2.8 detail::device_sendrecv_tuple_iterator_element_impl::value_type>::value; - // FIXME: NCCL 2.7 supports only one ncclSend and one ncclRecv for a source rank and destination - // rank inside ncclGroupStart/ncclGroupEnd, so we cannot place this inside - // ncclGroupStart/ncclGroupEnd, this restriction will be lifted in NCCL 2.8 detail::device_multicast_sendrecv_tuple_iterator_element_impl std::enable_if_t::value, std::vector> host_scalar_allgather( raft::comms::comms_t const& comm, T input, cudaStream_t stream) { - std::vector rx_counts(comm.get_size(), size_t{1}); - std::vector displacements(rx_counts.size(), size_t{0}); - std::iota(displacements.begin(), displacements.end(), size_t{0}); - rmm::device_uvector d_outputs(rx_counts.size(), stream); + rmm::device_uvector d_outputs(comm.get_size(), stream); raft::update_device(d_outputs.data() + comm.get_rank(), &input, 1, stream); - // FIXME: better use allgather - comm.allgatherv(d_outputs.data() + comm.get_rank(), - d_outputs.data(), - rx_counts.data(), - displacements.data(), - stream); - std::vector h_outputs(rx_counts.size()); - raft::update_host(h_outputs.data(), d_outputs.data(), rx_counts.size(), stream); + comm.allgather(d_outputs.data() + comm.get_rank(), d_outputs.data(), size_t{1}, stream); + std::vector h_outputs(d_outputs.size()); + raft::update_host(h_outputs.data(), d_outputs.data(), d_outputs.size(), stream); auto status = comm.sync_stream(stream); CUGRAPH_EXPECTS(status == raft::comms::status_t::SUCCESS, "sync_stream() failure."); return h_outputs; @@ -277,11 +269,6 @@ std::enable_if_t::value, std::vector::value; - std::vector rx_counts(comm.get_size(), tuple_size); - std::vector displacements(rx_counts.size(), size_t{0}); - for (size_t i = 0; i < displacements.size(); ++i) { - displacements[i] = i * tuple_size; - } std::vector h_tuple_scalar_elements(tuple_size); rmm::device_uvector d_allgathered_tuple_scalar_elements(comm.get_size() * tuple_size, stream); @@ -292,12 +279,10 @@ host_scalar_allgather(raft::comms::comms_t const& comm, T input, cudaStream_t st h_tuple_scalar_elements.data(), tuple_size, stream); - // FIXME: better use allgather - comm.allgatherv(d_allgathered_tuple_scalar_elements.data() + comm.get_rank() * tuple_size, - d_allgathered_tuple_scalar_elements.data(), - rx_counts.data(), - displacements.data(), - stream); + comm.allgather(d_allgathered_tuple_scalar_elements.data() + comm.get_rank() * tuple_size, + d_allgathered_tuple_scalar_elements.data(), + tuple_size, + stream); std::vector h_allgathered_tuple_scalar_elements(comm.get_size() * tuple_size); raft::update_host(h_allgathered_tuple_scalar_elements.data(), d_allgathered_tuple_scalar_elements.data(), @@ -318,6 +303,71 @@ host_scalar_allgather(raft::comms::comms_t const& comm, T input, cudaStream_t st return ret; } +template +std::enable_if_t::value, T> host_scalar_scatter( + raft::comms::comms_t const& comm, + std::vector const& inputs, // relevant only in root + int root, + cudaStream_t stream) +{ + CUGRAPH_EXPECTS( + ((comm.get_rank() == root) && (inputs.size() == static_cast(comm.get_size()))) || + ((comm.get_rank() != root) && (inputs.size() == 0)), + 
"inputs.size() should match with comm.get_size() in root and should be 0 otherwise."); + rmm::device_uvector d_outputs(comm.get_size(), stream); + if (comm.get_rank() == root) { + raft::update_device(d_outputs.data(), inputs.data(), inputs.size(), stream); + } + comm.bcast(d_outputs.data(), d_outputs.size(), root, stream); + T h_output{}; + raft::update_host(&h_output, d_outputs.data() + comm.get_rank(), 1, stream); + auto status = comm.sync_stream(stream); + CUGRAPH_EXPECTS(status == raft::comms::status_t::SUCCESS, "sync_stream() failure."); + return h_output; +} + +template +std::enable_if_t::value, T> host_scalar_scatter( + raft::comms::comms_t const& comm, + std::vector const& inputs, // relevant only in root + int root, + cudaStream_t stream) +{ + CUGRAPH_EXPECTS( + ((comm.get_rank() == root) && (inputs.size() == static_cast(comm.get_size()))) || + ((comm.get_rank() != root) && (inputs.size() == 0)), + "inputs.size() should match with comm.get_size() in root and should be 0 otherwise."); + size_t constexpr tuple_size = thrust::tuple_size::value; + rmm::device_uvector d_scatter_tuple_scalar_elements(comm.get_size() * tuple_size, + stream); + if (comm.get_rank() == root) { + for (int i = 0; i < comm.get_size(); ++i) { + std::vector h_tuple_scalar_elements(tuple_size); + detail::update_vector_of_tuple_scalar_elements_from_tuple_impl() + .update(h_tuple_scalar_elements, inputs[i]); + raft::update_device(d_scatter_tuple_scalar_elements.data() + i * tuple_size, + h_tuple_scalar_elements.data(), + tuple_size, + stream); + } + } + comm.bcast( + d_scatter_tuple_scalar_elements.data(), d_scatter_tuple_scalar_elements.size(), root, stream); + std::vector h_tuple_scalar_elements(tuple_size); + raft::update_host(h_tuple_scalar_elements.data(), + d_scatter_tuple_scalar_elements.data() + comm.get_rank() * tuple_size, + tuple_size, + stream); + auto status = comm.sync_stream(stream); + CUGRAPH_EXPECTS(status == raft::comms::status_t::SUCCESS, "sync_stream() failure."); + + T ret{}; + detail::update_tuple_from_vector_of_tuple_scalar_elements_impl().update( + ret, h_tuple_scalar_elements); + + return ret; +} + // Return value is valid only in root (return value may better be std::optional in C++17 or later) template std::enable_if_t::value, std::vector> host_scalar_gather( diff --git a/cpp/include/cugraph/utilities/shuffle_comm.cuh b/cpp/include/cugraph/utilities/shuffle_comm.cuh index 6a260144324..ab6a54cc1c0 100644 --- a/cpp/include/cugraph/utilities/shuffle_comm.cuh +++ b/cpp/include/cugraph/utilities/shuffle_comm.cuh @@ -80,7 +80,6 @@ compute_tx_rx_counts_offsets_ranks(raft::comms::comms_t const& comm, rmm::device_uvector d_rx_value_counts(comm_size, stream_view); - // FIXME: this needs to be replaced with AlltoAll once NCCL 2.8 is released. std::vector tx_counts(comm_size, size_t{1}); std::vector tx_offsets(comm_size); std::iota(tx_offsets.begin(), tx_offsets.end(), size_t{0}); @@ -835,7 +834,6 @@ auto shuffle_values(raft::comms::comms_t const& comm, allocate_dataframe_buffer::value_type>( rx_offsets.size() > 0 ? rx_offsets.back() + rx_counts.back() : size_t{0}, stream_view); - // FIXME: this needs to be replaced with AlltoAll once NCCL 2.8 is released // (if num_tx_dst_ranks == num_rx_src_ranks == comm_size). device_multicast_sendrecv(comm, tx_value_first, @@ -889,7 +887,6 @@ auto groupby_gpu_id_and_shuffle_values(raft::comms::comms_t const& comm, allocate_dataframe_buffer::value_type>( rx_offsets.size() > 0 ? 
rx_offsets.back() + rx_counts.back() : size_t{0}, stream_view); - // FIXME: this needs to be replaced with AlltoAll once NCCL 2.8 is released // (if num_tx_dst_ranks == num_rx_src_ranks == comm_size). device_multicast_sendrecv(comm, tx_value_first, @@ -946,7 +943,6 @@ auto groupby_gpu_id_and_shuffle_kv_pairs(raft::comms::comms_t const& comm, allocate_dataframe_buffer::value_type>( rx_keys.size(), stream_view); - // FIXME: this needs to be replaced with AlltoAll once NCCL 2.8 is released // (if num_tx_dst_ranks == num_rx_src_ranks == comm_size). device_multicast_sendrecv(comm, tx_key_first, @@ -959,7 +955,6 @@ auto groupby_gpu_id_and_shuffle_kv_pairs(raft::comms::comms_t const& comm, rx_src_ranks, stream_view); - // FIXME: this needs to be replaced with AlltoAll once NCCL 2.8 is released // (if num_tx_dst_ranks == num_rx_src_ranks == comm_size). device_multicast_sendrecv(comm, tx_value_first, diff --git a/cpp/src/centrality/katz_centrality_impl.cuh b/cpp/src/centrality/katz_centrality_impl.cuh index 202d00a5771..ac31043d862 100644 --- a/cpp/src/centrality/katz_centrality_impl.cuh +++ b/cpp/src/centrality/katz_centrality_impl.cuh @@ -74,8 +74,6 @@ void katz_centrality( CUGRAPH_EXPECTS(epsilon >= 0.0, "Invalid input argument: epsilon should be non-negative."); if (do_expensive_check) { - // FIXME: should I check for betas? - if (has_initial_guess) { auto num_negative_values = count_if_v(handle, pull_graph_view, katz_centralities, [] __device__(auto, auto val) { diff --git a/cpp/src/community/detail/common_methods.cuh b/cpp/src/community/detail/common_methods.cuh index b388ba53e81..f67d4d939ad 100644 --- a/cpp/src/community/detail/common_methods.cuh +++ b/cpp/src/community/detail/common_methods.cuh @@ -52,7 +52,7 @@ struct is_bitwise_comparable> : std::true_type {}; namespace cugraph { namespace detail { -// a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used +// FIXME: a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used template struct key_aggregated_edge_op_t { weight_t total_edge_weight{}; @@ -80,7 +80,7 @@ struct key_aggregated_edge_op_t { } }; -// a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used +// FIXME: a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used template struct reduce_op_t { using type = thrust::tuple; @@ -100,7 +100,28 @@ struct reduce_op_t { } }; -// a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used +// FIXME: a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used +template +struct count_updown_moves_op_t { + bool up_down{}; + __device__ auto operator()(thrust::tuple> p) const + { + vertex_t old_cluster = thrust::get<0>(p); + auto new_cluster_gain_pair = thrust::get<1>(p); + vertex_t new_cluster = thrust::get<0>(new_cluster_gain_pair); + weight_t delta_modularity = thrust::get<1>(new_cluster_gain_pair); + + auto result_assignment = + (delta_modularity > weight_t{0}) + ? (((new_cluster > old_cluster) != up_down) ? old_cluster : new_cluster) + : old_cluster; + + return (delta_modularity > weight_t{0}) + ? (((new_cluster > old_cluster) != up_down) ? 
false : true) + : false; + } +}; +// FIXME: a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used template struct cluster_update_op_t { bool up_down{}; @@ -115,7 +136,7 @@ struct cluster_update_op_t { } }; -// a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used +// FIXME: a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used template struct return_edge_weight_t { __device__ auto operator()( @@ -125,7 +146,7 @@ struct return_edge_weight_t { } }; -// a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used +// FIXME: a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used template struct return_one_t { __device__ auto operator()( @@ -394,6 +415,21 @@ rmm::device_uvector update_clustering_by_delta_modularity( detail::reduce_op_t{}, cugraph::get_dataframe_buffer_begin(output_buffer)); + int nr_moves = thrust::count_if( + handle.get_thrust_policy(), + thrust::make_zip_iterator(thrust::make_tuple( + next_clusters_v.begin(), cugraph::get_dataframe_buffer_begin(output_buffer))), + thrust::make_zip_iterator( + thrust::make_tuple(next_clusters_v.end(), cugraph::get_dataframe_buffer_end(output_buffer))), + detail::count_updown_moves_op_t{up_down}); + + if (multi_gpu) { + nr_moves = host_scalar_allreduce( + handle.get_comms(), nr_moves, raft::comms::op_t::SUM, handle.get_stream()); + } + + if (nr_moves == 0) { up_down = !up_down; } + thrust::transform(handle.get_thrust_policy(), next_clusters_v.begin(), next_clusters_v.end(), diff --git a/cpp/src/community/detail/refine_impl.cuh b/cpp/src/community/detail/refine_impl.cuh index 6b6470991bb..ebaae498d04 100644 --- a/cpp/src/community/detail/refine_impl.cuh +++ b/cpp/src/community/detail/refine_impl.cuh @@ -89,8 +89,9 @@ struct leiden_key_aggregated_edge_op_t { // E(Cr, S-Cr) > ||Cr||*(||S|| -||Cr||) bool is_dst_leiden_cluster_well_connected = - dst_leiden_cut_to_louvain > - resolution * dst_leiden_volume * (louvain_cluster_volume - dst_leiden_volume); + dst_leiden_cut_to_louvain > resolution * dst_leiden_volume * + (louvain_cluster_volume - dst_leiden_volume) / + total_edge_weight; // E(v, Cr-v) - ||v||* ||Cr-v||/||V(G)|| // aggregated_weight_to_neighboring_leiden_cluster == E(v, Cr-v)? @@ -98,11 +99,11 @@ struct leiden_key_aggregated_edge_op_t { weight_t mod_gain = -1.0; if (is_src_active > 0) { if ((louvain_of_dst_leiden_cluster == src_louvain_cluster) && - is_dst_leiden_cluster_well_connected) { + (dst_leiden_cluster_id != src_leiden_cluster) && is_dst_leiden_cluster_well_connected) { mod_gain = aggregated_weight_to_neighboring_leiden_cluster - - resolution * src_weighted_deg * (dst_leiden_volume - src_weighted_deg) / - total_edge_weight; - + resolution * src_weighted_deg * dst_leiden_volume / total_edge_weight; +// FIXME: Disable random moves in refinement phase for now. +#if 0 weight_t random_number{0.0}; if (mod_gain > 0.0) { auto flat_id = uint64_t{threadIdx.x + blockIdx.x * blockDim.x}; @@ -117,6 +118,8 @@ struct leiden_key_aggregated_edge_op_t { ? __expf(static_cast((2.0 * mod_gain) / (theta * total_edge_weight))) * random_number : -1.0; +#endif + mod_gain = mod_gain > 0.0 ? 
mod_gain : -1.0; } } @@ -240,11 +243,12 @@ refine_clustering( wcut_deg_and_cluster_vol_triple_begin, wcut_deg_and_cluster_vol_triple_end, singleton_and_connected_flags.begin(), - [resolution] __device__(auto wcut_wdeg_and_louvain_volume) { + [resolution, total_edge_weight] __device__(auto wcut_wdeg_and_louvain_volume) { auto wcut = thrust::get<0>(wcut_wdeg_and_louvain_volume); auto wdeg = thrust::get<1>(wcut_wdeg_and_louvain_volume); auto louvain_volume = thrust::get<2>(wcut_wdeg_and_louvain_volume); - return wcut > (resolution * wdeg * (louvain_volume - wdeg)); + return wcut > + (resolution * wdeg * (louvain_volume - wdeg) / total_edge_weight); }); edge_src_property_t src_louvain_cluster_weight_cache(handle); @@ -478,7 +482,7 @@ refine_clustering( auto values_for_leiden_cluster_keys = thrust::make_zip_iterator( thrust::make_tuple(refined_community_volumes.begin(), refined_community_cuts.begin(), - leiden_keys_used_in_edge_reduction.begin(), // redundant + leiden_keys_used_in_edge_reduction.begin(), louvain_of_leiden_keys_used_in_edge_reduction.begin())); using value_t = thrust::tuple; diff --git a/cpp/src/community/flatten_dendrogram.hpp b/cpp/src/community/flatten_dendrogram.hpp index 9a0c103c01f..eac20389765 100644 --- a/cpp/src/community/flatten_dendrogram.hpp +++ b/cpp/src/community/flatten_dendrogram.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,4 +59,31 @@ void partition_at_level(raft::handle_t const& handle, }); } +template +void leiden_partition_at_level(raft::handle_t const& handle, + Dendrogram const& dendrogram, + vertex_t* d_partition, + size_t level) +{ + vertex_t local_num_verts = dendrogram.get_level_size_nocheck(0); + raft::copy( + d_partition, dendrogram.get_level_ptr_nocheck(0), local_num_verts, handle.get_stream()); + + rmm::device_uvector local_vertex_ids_v(local_num_verts, handle.get_stream()); + + std::for_each( + thrust::make_counting_iterator(0), + thrust::make_counting_iterator((level - 1) / 2), + [&handle, &dendrogram, &local_vertex_ids_v, &d_partition, local_num_verts](size_t l) { + cugraph::relabel( + handle, + std::tuple(dendrogram.get_level_ptr_nocheck(2 * l + 1), + dendrogram.get_level_ptr_nocheck(2 * l + 2)), + dendrogram.get_level_size_nocheck(2 * l + 1), + d_partition, + local_num_verts, + false); + }); +} + } // namespace cugraph diff --git a/cpp/src/community/leiden_impl.cuh b/cpp/src/community/leiden_impl.cuh index a9faf2f2d82..b6e20272de9 100644 --- a/cpp/src/community/leiden_impl.cuh +++ b/cpp/src/community/leiden_impl.cuh @@ -43,6 +43,34 @@ void check_clustering(graph_view_t const& gr if (graph_view.local_vertex_partition_range_size() > 0) CUGRAPH_EXPECTS(clustering != nullptr, "Invalid input argument: clustering is null"); } +template +vertex_t remove_duplicates(raft::handle_t const& handle, rmm::device_uvector& input_array) +{ + thrust::sort(handle.get_thrust_policy(), input_array.begin(), input_array.end()); + + auto nr_unique_elements = static_cast(thrust::distance( + input_array.begin(), + thrust::unique(handle.get_thrust_policy(), input_array.begin(), input_array.end()))); + + input_array.resize(nr_unique_elements, handle.get_stream()); + + if constexpr (multi_gpu) { + input_array = cugraph::detail::shuffle_ext_vertices_to_local_gpu_by_vertex_partitioning( + handle, std::move(input_array)); + + 
thrust::sort(handle.get_thrust_policy(), input_array.begin(), input_array.end()); + + nr_unique_elements = static_cast(thrust::distance( + input_array.begin(), + thrust::unique(handle.get_thrust_policy(), input_array.begin(), input_array.end()))); + + input_array.resize(nr_unique_elements, handle.get_stream()); + + nr_unique_elements = host_scalar_allreduce( + handle.get_comms(), nr_unique_elements, raft::comms::op_t::SUM, handle.get_stream()); + } + return nr_unique_elements; +} template >, weight_t> leiden( rmm::device_uvector louvain_of_refined_graph(0, handle.get_stream()); // #V - while (dendrogram->num_levels() < max_level) { + while (dendrogram->num_levels() < 2 * max_level + 1) { // // Initialize every cluster to reference each vertex to itself // @@ -353,40 +381,8 @@ std::pair>, weight_t> leiden( dendrogram->current_level_begin(), dendrogram->current_level_begin() + dendrogram->current_level_size(), copied_louvain_partition.begin()); - - thrust::sort( - handle.get_thrust_policy(), copied_louvain_partition.begin(), copied_louvain_partition.end()); - auto nr_unique_louvain_clusters = - static_cast(thrust::distance(copied_louvain_partition.begin(), - thrust::unique(handle.get_thrust_policy(), - copied_louvain_partition.begin(), - copied_louvain_partition.end()))); - - copied_louvain_partition.resize(nr_unique_louvain_clusters, handle.get_stream()); - - if constexpr (graph_view_t::is_multi_gpu) { - copied_louvain_partition = - cugraph::detail::shuffle_ext_vertices_to_local_gpu_by_vertex_partitioning( - handle, std::move(copied_louvain_partition)); - - thrust::sort(handle.get_thrust_policy(), - copied_louvain_partition.begin(), - copied_louvain_partition.end()); - - nr_unique_louvain_clusters = - static_cast(thrust::distance(copied_louvain_partition.begin(), - thrust::unique(handle.get_thrust_policy(), - copied_louvain_partition.begin(), - copied_louvain_partition.end()))); - - copied_louvain_partition.resize(nr_unique_louvain_clusters, handle.get_stream()); - - nr_unique_louvain_clusters = host_scalar_allreduce(handle.get_comms(), - nr_unique_louvain_clusters, - raft::comms::op_t::SUM, - handle.get_stream()); - } + remove_duplicates(handle, copied_louvain_partition); terminate = terminate || (nr_unique_louvain_clusters == current_graph_view.number_of_vertices()); @@ -481,6 +477,15 @@ std::pair>, weight_t> leiden( (*cluster_assignment).data(), (*cluster_assignment).size(), false); + // louvain assignment of aggregated graph which is necessary to flatten dendrogram + dendrogram->add_level(current_graph_view.local_vertex_partition_range_first(), + current_graph_view.local_vertex_partition_range_size(), + handle.get_stream()); + + raft::copy(dendrogram->current_level_begin(), + (*cluster_assignment).begin(), + (*cluster_assignment).size(), + handle.get_stream()); louvain_of_refined_graph.resize(current_graph_view.local_vertex_partition_range_size(), handle.get_stream()); @@ -492,47 +497,6 @@ std::pair>, weight_t> leiden( } } - // Relabel dendrogram - vertex_t local_cluster_id_first{0}; - if constexpr (multi_gpu) { - auto unique_cluster_range_lasts = cugraph::partition_manager::compute_partition_range_lasts( - handle, static_cast(copied_louvain_partition.size())); - - auto& comm = handle.get_comms(); - auto const comm_size = comm.get_size(); - auto const comm_rank = comm.get_rank(); - auto& major_comm = handle.get_subcomm(cugraph::partition_manager::major_comm_name()); - auto const major_comm_size = major_comm.get_size(); - auto const major_comm_rank = major_comm.get_rank(); - auto& 
minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); - auto const minor_comm_size = minor_comm.get_size(); - auto const minor_comm_rank = minor_comm.get_rank(); - - auto vertex_partition_id = - partition_manager::compute_vertex_partition_id_from_graph_subcomm_ranks( - major_comm_size, minor_comm_size, major_comm_rank, minor_comm_rank); - - local_cluster_id_first = vertex_partition_id == 0 - ? vertex_t{0} - : unique_cluster_range_lasts[vertex_partition_id - 1]; - } - - rmm::device_uvector numbering_indices(copied_louvain_partition.size(), - handle.get_stream()); - detail::sequence_fill(handle.get_stream(), - numbering_indices.data(), - numbering_indices.size(), - local_cluster_id_first); - - relabel( - handle, - std::make_tuple(static_cast(copied_louvain_partition.begin()), - static_cast(numbering_indices.begin())), - copied_louvain_partition.size(), - dendrogram->current_level_begin(), - dendrogram->current_level_size(), - false); - copied_louvain_partition.resize(0, handle.get_stream()); copied_louvain_partition.shrink_to_fit(handle.get_stream()); @@ -550,23 +514,71 @@ std::pair>, weight_t> leiden( return std::make_pair(std::move(dendrogram), best_modularity); } -// FIXME: Can we have a common flatten_dendrogram to be used by both -// Louvain and Leiden, and possibly other clustering methods? +template +void relabel_cluster_ids(raft::handle_t const& handle, + rmm::device_uvector& unique_cluster_ids, + vertex_t* clustering, + size_t num_nodes) +{ + vertex_t local_cluster_id_first{0}; + if constexpr (multi_gpu) { + auto unique_cluster_range_lasts = cugraph::partition_manager::compute_partition_range_lasts( + handle, static_cast(unique_cluster_ids.size())); + + auto& comm = handle.get_comms(); + auto const comm_size = comm.get_size(); + auto const comm_rank = comm.get_rank(); + auto& major_comm = handle.get_subcomm(cugraph::partition_manager::major_comm_name()); + auto const major_comm_size = major_comm.get_size(); + auto const major_comm_rank = major_comm.get_rank(); + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + auto const minor_comm_size = minor_comm.get_size(); + auto const minor_comm_rank = minor_comm.get_rank(); + + auto vertex_partition_id = + partition_manager::compute_vertex_partition_id_from_graph_subcomm_ranks( + major_comm_size, minor_comm_size, major_comm_rank, minor_comm_rank); + + local_cluster_id_first = + vertex_partition_id == 0 ? 
vertex_t{0} : unique_cluster_range_lasts[vertex_partition_id - 1]; + } + + rmm::device_uvector numbering_indices(unique_cluster_ids.size(), handle.get_stream()); + detail::sequence_fill(handle.get_stream(), + numbering_indices.data(), + numbering_indices.size(), + local_cluster_id_first); + + relabel( + handle, + std::make_tuple(static_cast(unique_cluster_ids.begin()), + static_cast(numbering_indices.begin())), + unique_cluster_ids.size(), + clustering, + num_nodes, + false); +} + template -void flatten_dendrogram(raft::handle_t const& handle, - graph_view_t const& graph_view, - Dendrogram const& dendrogram, - vertex_t* clustering) +void flatten_leiden_dendrogram(raft::handle_t const& handle, + graph_view_t const& graph_view, + Dendrogram const& dendrogram, + vertex_t* clustering) { - rmm::device_uvector vertex_ids_v(graph_view.number_of_vertices(), handle.get_stream()); + leiden_partition_at_level( + handle, dendrogram, clustering, dendrogram.num_levels()); + + rmm::device_uvector unique_cluster_ids(graph_view.number_of_vertices(), + handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + clustering, + clustering + graph_view.number_of_vertices(), + unique_cluster_ids.begin()); - thrust::sequence(handle.get_thrust_policy(), - vertex_ids_v.begin(), - vertex_ids_v.end(), - graph_view.local_vertex_partition_range_first()); + remove_duplicates(handle, unique_cluster_ids); - partition_at_level( - handle, dendrogram, vertex_ids_v.data(), clustering, dendrogram.num_levels()); + relabel_cluster_ids( + handle, unique_cluster_ids, clustering, graph_view.number_of_vertices()); } } // namespace detail @@ -588,14 +600,14 @@ std::pair>, weight_t> leiden( } template -void flatten_dendrogram(raft::handle_t const& handle, - graph_view_t const& graph_view, - Dendrogram const& dendrogram, - vertex_t* clustering) +void flatten_leiden_dendrogram(raft::handle_t const& handle, + graph_view_t const& graph_view, + Dendrogram const& dendrogram, + vertex_t* clustering) { CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented."); - detail::flatten_dendrogram(handle, graph_view, dendrogram, clustering); + detail::flatten_leiden_dendrogram(handle, graph_view, dendrogram, clustering); } template @@ -620,7 +632,7 @@ std::pair leiden( std::tie(dendrogram, modularity) = detail::leiden(handle, rng_state, graph_view, edge_weight_view, max_level, resolution, theta); - detail::flatten_dendrogram(handle, graph_view, *dendrogram, clustering); + detail::flatten_leiden_dendrogram(handle, graph_view, *dendrogram, clustering); return std::make_pair(dendrogram->num_levels(), modularity); } diff --git a/cpp/src/components/weakly_connected_components_impl.cuh b/cpp/src/components/weakly_connected_components_impl.cuh index 615a50ded54..b7b6e139cfa 100644 --- a/cpp/src/components/weakly_connected_components_impl.cuh +++ b/cpp/src/components/weakly_connected_components_impl.cuh @@ -236,18 +236,16 @@ struct v_op_t { auto tag = thrust::get<1>(tagged_v); auto v_offset = vertex_partition.local_vertex_partition_offset_from_vertex_nocheck(thrust::get<0>(tagged_v)); - // FIXME: better switch to atomic_ref after - // https://github.com/nvidia/libcudacxx/milestone/2 - auto old = - atomicCAS(level_components + v_offset, invalid_component_id::value, tag); - if (old != invalid_component_id::value && old != tag) { // conflict + cuda::atomic_ref v_component(*(level_components + v_offset)); + auto old = invalid_component_id::value; + bool success = v_component.compare_exchange_strong(old, tag, cuda::std::memory_order_relaxed); + if 
(!success && (old != tag)) { // conflict return thrust::make_tuple(thrust::optional{bucket_idx_conflict}, thrust::optional{std::byte{0}} /* dummy */); } else { - auto update = (old == invalid_component_id::value); return thrust::make_tuple( - update ? thrust::optional{bucket_idx_next} : thrust::nullopt, - update ? thrust::optional{std::byte{0}} /* dummy */ : thrust::nullopt); + success ? thrust::optional{bucket_idx_next} : thrust::nullopt, + success ? thrust::optional{std::byte{0}} /* dummy */ : thrust::nullopt); } } @@ -457,33 +455,11 @@ void weakly_connected_components_impl(raft::handle_t const& handle, std::numeric_limits::max()); } - // FIXME: we need to add host_scalar_scatter -#if 1 - rmm::device_uvector d_counts(comm_size, handle.get_stream()); - raft::update_device(d_counts.data(), - init_max_new_root_counts.data(), - init_max_new_root_counts.size(), - handle.get_stream()); - device_bcast( - comm, d_counts.data(), d_counts.data(), d_counts.size(), int{0}, handle.get_stream()); - raft::update_host( - &init_max_new_roots, d_counts.data() + comm_rank, size_t{1}, handle.get_stream()); -#else init_max_new_roots = - host_scalar_scatter(comm, init_max_new_root_counts.data(), int{0}, handle.get_stream()); -#endif + host_scalar_scatter(comm, init_max_new_root_counts, int{0}, handle.get_stream()); } else { - // FIXME: we need to add host_scalar_scatter -#if 1 - rmm::device_uvector d_counts(comm_size, handle.get_stream()); - device_bcast( - comm, d_counts.data(), d_counts.data(), d_counts.size(), int{0}, handle.get_stream()); - raft::update_host( - &init_max_new_roots, d_counts.data() + comm_rank, size_t{1}, handle.get_stream()); -#else init_max_new_roots = - host_scalar_scatter(comm, init_max_new_root_counts.data(), int{0}, handle.get_stream()); -#endif + host_scalar_scatter(comm, std::vector{}, int{0}, handle.get_stream()); } handle.sync_stream(); diff --git a/cpp/tests/c_api/leiden_test.c b/cpp/tests/c_api/leiden_test.c index 9e91adf9f89..df206ebd1ed 100644 --- a/cpp/tests/c_api/leiden_test.c +++ b/cpp/tests/c_api/leiden_test.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -161,7 +161,7 @@ int test_leiden_no_weights() vertex_t h_src[] = {0, 1, 1, 2, 2, 2, 3, 4, 1, 3, 4, 0, 1, 3, 5, 5}; vertex_t h_dst[] = {1, 3, 4, 0, 1, 3, 5, 5, 0, 1, 1, 2, 2, 2, 3, 4}; vertex_t h_result[] = {1, 1, 1, 2, 0, 0}; - weight_t expected_modularity = 0.0859375; + weight_t expected_modularity = 0.125; // Louvain wants store_transposed = FALSE return generic_leiden_test(h_src, diff --git a/cpp/tests/c_api/louvain_test.c b/cpp/tests/c_api/louvain_test.c index e9ac5c9ff06..41d777545b2 100644 --- a/cpp/tests/c_api/louvain_test.c +++ b/cpp/tests/c_api/louvain_test.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
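[Editor's aside, not part of the patch] The weakly_connected_components_impl.cuh hunk above replaces the two "#if 1" device_bcast workarounds with the host_scalar_scatter utility introduced earlier in this patch. A minimal sketch of its contract, assuming it sits in the same host scalar comms header as host_scalar_allgather and is reachable through the cugraph namespace: the root rank supplies one value per rank, every rank receives only the element for its own rank, and non-root ranks pass an empty vector.

// Illustrative only; mirrors the call sites above.
// Assumes a raft::handle_t named handle with an initialized communicator.
auto& comm = handle.get_comms();
int root{0};

std::vector<size_t> per_rank_counts{};  // relevant only on the root rank
if (comm.get_rank() == root) {
  per_rank_counts.assign(comm.get_size(), size_t{16});  // dummy per-rank values
}

// Each rank receives per_rank_counts[comm.get_rank()] as provided by the root.
// Per the implementation shown earlier, the stream is synchronized internally
// before the host value is returned.
size_t my_count =
  cugraph::host_scalar_scatter(comm, per_rank_counts, root, handle.get_stream());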
@@ -46,22 +46,39 @@ int generic_louvain_test(vertex_t* h_src, cugraph_graph_t* p_graph = NULL; cugraph_hierarchical_clustering_result_t* p_result = NULL; - data_type_id_t vertex_tid = INT32; - data_type_id_t edge_tid = INT32; - data_type_id_t weight_tid = FLOAT32; + data_type_id_t vertex_tid = INT32; + data_type_id_t edge_tid = INT32; + data_type_id_t weight_tid = FLOAT32; data_type_id_t edge_id_tid = INT32; data_type_id_t edge_type_tid = INT32; p_handle = cugraph_create_resource_handle(NULL); TEST_ASSERT(test_ret_value, p_handle != NULL, "resource handle creation failed."); - ret_code = create_sg_test_graph(p_handle, vertex_tid, edge_tid, h_src, h_dst, weight_tid, h_wgt, edge_type_tid, NULL, edge_id_tid, NULL, num_edges, store_transposed, FALSE, FALSE, FALSE, &p_graph, &ret_error); + ret_code = create_sg_test_graph(p_handle, + vertex_tid, + edge_tid, + h_src, + h_dst, + weight_tid, + h_wgt, + edge_type_tid, + NULL, + edge_id_tid, + NULL, + num_edges, + store_transposed, + FALSE, + FALSE, + FALSE, + &p_graph, + &ret_error); TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "create_test_graph failed."); TEST_ALWAYS_ASSERT(ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); - ret_code = - cugraph_louvain(p_handle, p_graph, max_level, threshold, resolution, FALSE, &p_result, &ret_error); + ret_code = cugraph_louvain( + p_handle, p_graph, max_level, threshold, resolution, FALSE, &p_result, &ret_error); TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); TEST_ALWAYS_ASSERT(ret_code == CUGRAPH_SUCCESS, "cugraph_louvain failed."); @@ -141,10 +158,10 @@ int test_louvain_no_weight() weight_t threshold = 1e-7; weight_t resolution = 1.0; - vertex_t h_src[] = {0, 1, 1, 2, 2, 2, 3, 4, 1, 3, 4, 0, 1, 3, 5, 5}; - vertex_t h_dst[] = {1, 3, 4, 0, 1, 3, 5, 5, 0, 1, 1, 2, 2, 2, 3, 4}; - vertex_t h_result[] = {1, 1, 1, 2, 0, 0}; - weight_t expected_modularity = 0.0859375; + vertex_t h_src[] = {0, 1, 1, 2, 2, 2, 3, 4, 1, 3, 4, 0, 1, 3, 5, 5}; + vertex_t h_dst[] = {1, 3, 4, 0, 1, 3, 5, 5, 0, 1, 1, 2, 2, 2, 3, 4}; + vertex_t h_result[] = {1, 1, 1, 1, 0, 0}; + weight_t expected_modularity = 0.125; // Louvain wants store_transposed = FALSE return generic_louvain_test(h_src, diff --git a/cpp/tests/community/louvain_test.cpp b/cpp/tests/community/louvain_test.cpp index 1e1fb6d4c33..284dcc94b8c 100644 --- a/cpp/tests/community/louvain_test.cpp +++ b/cpp/tests/community/louvain_test.cpp @@ -317,72 +317,6 @@ TEST(louvain_legacy, success) } } -TEST(louvain_legacy_renumbered, success) -{ - raft::handle_t handle; - - auto stream = handle.get_stream(); - - std::vector off_h = {0, 16, 25, 30, 34, 38, 42, 44, 46, 48, 50, 52, - 54, 56, 73, 85, 95, 101, 107, 112, 117, 121, 125, 129, - 132, 135, 138, 141, 144, 147, 149, 151, 153, 155, 156}; - std::vector ind_h = { - 1, 3, 7, 11, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 30, 33, 0, 5, 11, 15, 16, 19, 21, - 25, 30, 4, 13, 14, 22, 27, 0, 9, 20, 24, 2, 13, 15, 26, 1, 13, 14, 18, 13, 15, 0, 16, - 13, 14, 3, 20, 13, 14, 0, 1, 13, 22, 2, 4, 5, 6, 8, 10, 12, 14, 17, 18, 19, 22, 25, - 28, 29, 31, 32, 2, 5, 8, 10, 13, 15, 17, 18, 22, 29, 31, 32, 0, 1, 4, 6, 14, 16, 18, - 19, 21, 28, 0, 1, 7, 15, 19, 21, 0, 13, 14, 26, 27, 28, 0, 5, 13, 14, 15, 0, 1, 13, - 16, 16, 0, 3, 9, 23, 0, 1, 15, 16, 2, 12, 13, 14, 0, 20, 24, 0, 3, 23, 0, 1, 13, - 4, 17, 27, 2, 17, 26, 13, 15, 17, 13, 14, 0, 1, 13, 14, 13, 14, 0}; - - std::vector w_h = { - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; - - int num_verts = off_h.size() - 1; - int num_edges = ind_h.size(); - - rmm::device_uvector offsets_v(num_verts + 1, stream); - rmm::device_uvector indices_v(num_edges, stream); - rmm::device_uvector weights_v(num_edges, stream); - rmm::device_uvector result_v(num_verts, stream); - - raft::update_device(offsets_v.data(), off_h.data(), off_h.size(), stream); - raft::update_device(indices_v.data(), ind_h.data(), ind_h.size(), stream); - raft::update_device(weights_v.data(), w_h.data(), w_h.size(), stream); - - cugraph::legacy::GraphCSRView G( - offsets_v.data(), indices_v.data(), weights_v.data(), num_verts, num_edges); - - float modularity{0.0}; - size_t num_level = 40; - - // "FIXME": remove this check once we drop support for Pascal - // - // Calling louvain on Pascal will throw an exception, we'll check that - // this is the behavior while we still support Pascal (device_prop.major < 7) - // - if (handle.get_device_properties().major < 7) { - EXPECT_THROW(cugraph::louvain(handle, G, result_v.data()), cugraph::logic_error); - } else { - std::tie(num_level, modularity) = cugraph::louvain(handle, G, result_v.data()); - - auto cluster_id = cugraph::test::to_host(handle, result_v); - - int min = *min_element(cluster_id.begin(), cluster_id.end()); - - ASSERT_GE(min, 0); - ASSERT_FLOAT_EQ(modularity, 0.41880345); - } -} - using Tests_Louvain_File = Tests_Louvain; using Tests_Louvain_File32 = Tests_Louvain; using Tests_Louvain_File64 = Tests_Louvain; @@ -390,11 +324,15 @@ using Tests_Louvain_Rmat = Tests_Louvain; using Tests_Louvain_Rmat32 = Tests_Louvain; using Tests_Louvain_Rmat64 = Tests_Louvain; +#if 0 +// FIXME: Reenable legacy tests once threshold parameter is exposed +// by louvain legacy API. 
TEST_P(Tests_Louvain_File, CheckInt32Int32FloatFloatLegacy) { run_legacy_test( override_File_Usecase_with_cmd_line_arguments(GetParam())); } +#endif TEST_P(Tests_Louvain_File, CheckInt32Int32FloatFloat) { @@ -458,11 +396,12 @@ TEST_P(Tests_Louvain_Rmat64, CheckInt64Int64FloatFloat) INSTANTIATE_TEST_SUITE_P( simple_test, Tests_Louvain_File, - ::testing::Combine( - ::testing::Values(Louvain_Usecase{std::nullopt, std::nullopt, std::nullopt, true, 3, 0.408695}, - Louvain_Usecase{20, double{1e-4}, std::nullopt, true, 3, 0.408695}, - Louvain_Usecase{100, double{1e-4}, double{0.8}, true, 3, 0.48336622}), - ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); + ::testing::Combine(::testing::Values( + Louvain_Usecase{ + std::nullopt, std::nullopt, std::nullopt, true, 3, 0.39907956}, + Louvain_Usecase{20, double{1e-3}, std::nullopt, true, 3, 0.39907956}, + Louvain_Usecase{100, double{1e-3}, double{0.8}, true, 3, 0.47547662}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); INSTANTIATE_TEST_SUITE_P( file_benchmark_test, /* note that the test filename can be overridden in benchmarking (with diff --git a/datasets/README.md b/datasets/README.md index e42413fc996..a23dc644081 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -120,9 +120,13 @@ The benchmark datasets are described below: | soc-twitter-2010 | 21,297,772 | 265,025,809 | No | No | **cit-Patents** : A citation graph that includes all citations made by patents granted between 1975 and 1999, totaling 16,522,438 citations. + **soc-LiveJournal** : A graph of the LiveJournal social network. + **europe_osm** : A graph of OpenStreetMap data for Europe. + **hollywood** : A graph of movie actors where vertices are actors, and two actors are joined by an edge whenever they appeared in a movie together. + **soc-twitter-2010** : A network of follower relationships from a snapshot of Twitter in 2010, where an edge from i to j indicates that j is a follower of i. 
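[Editor's aside, not part of the patch] The five benchmark datasets described above are also registered as Dataset objects in cugraph.datasets by the datasets/__init__.py hunk later in this diff, so they can be downloaded and loaded like the existing small datasets. A usage sketch, assuming the data.rapids.ai URLs in the new YAML metadata files are reachable:

# Illustrative only.
from cugraph.datasets import cit_patents

# The first call downloads the CSV into the default download directory
# (RAPIDS_DATASET_ROOT_DIR if set, otherwise ~/.cugraph/datasets).
G = cit_patents.get_graph(download=True)

# For the larger benchmark files the edgelist can also be read through pandas
# instead of cudf, using the new `reader` argument on get_edgelist().
el = cit_patents.get_edgelist(download=True, reader="pandas")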
_NOTE: the benchmark datasets were converted to a CSV format from their original format described in the reference URL below, and in doing so had edge weights and isolated vertices discarded._ diff --git a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py index edeeface4c4..14dc5d84f90 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py @@ -27,11 +27,12 @@ import cugraph import warnings -from cugraph.utilities.utils import import_optional, MissingModule +import dask.array as dar +import dask.dataframe as dd +import dask.distributed as distributed +import dask_cudf -dd = import_optional("dask.dataframe") -distributed = import_optional("dask.distributed") -dask_cudf = import_optional("dask_cudf") +from cugraph.utilities.utils import import_optional, MissingModule torch = import_optional("torch") torch_geometric = import_optional("torch_geometric") @@ -367,6 +368,13 @@ def __infer_offsets( } ) + def __dask_array_from_numpy(self, array: np.ndarray, npartitions: int): + return dar.from_array( + array, + meta=np.array([], dtype=array.dtype), + chunks=max(1, len(array) // npartitions), + ) + def __construct_graph( self, edge_info: Dict[Tuple[str, str, str], List[TensorType]], @@ -464,22 +472,32 @@ def __construct_graph( ] ) - df = pandas.DataFrame( - { - "src": pandas.Series(na_dst) - if order == "CSC" - else pandas.Series(na_src), - "dst": pandas.Series(na_src) - if order == "CSC" - else pandas.Series(na_dst), - "etp": pandas.Series(na_etp), - } - ) - vertex_dtype = df.src.dtype + vertex_dtype = na_src.dtype if multi_gpu: - nworkers = len(distributed.get_client().scheduler_info()["workers"]) - df = dd.from_pandas(df, npartitions=nworkers if len(df) > 32 else 1) + client = distributed.get_client() + nworkers = len(client.scheduler_info()["workers"]) + npartitions = nworkers * 4 + + src_dar = self.__dask_array_from_numpy(na_src, npartitions) + del na_src + + dst_dar = self.__dask_array_from_numpy(na_dst, npartitions) + del na_dst + + etp_dar = self.__dask_array_from_numpy(na_etp, npartitions) + del na_etp + + df = dd.from_dask_array(etp_dar, columns=["etp"]) + df["src"] = dst_dar if order == "CSC" else src_dar + df["dst"] = src_dar if order == "CSC" else dst_dar + + del src_dar + del dst_dar + del etp_dar + + if df.etp.dtype != "int32": + raise ValueError("Edge type must be int32!") # Ensure the dataframe is constructed on each partition # instead of adding additional synchronization head from potential @@ -487,9 +505,9 @@ def __construct_graph( def get_empty_df(): return cudf.DataFrame( { + "etp": cudf.Series([], dtype="int32"), "src": cudf.Series([], dtype=vertex_dtype), "dst": cudf.Series([], dtype=vertex_dtype), - "etp": cudf.Series([], dtype="int32"), } ) @@ -500,9 +518,23 @@ def get_empty_df(): if len(f) > 0 else get_empty_df(), meta=get_empty_df(), - ).reset_index(drop=True) + ).reset_index( + drop=True + ) # should be ok for dask else: - df = cudf.from_pandas(df).reset_index(drop=True) + df = pandas.DataFrame( + { + "src": pandas.Series(na_dst) + if order == "CSC" + else pandas.Series(na_src), + "dst": pandas.Series(na_src) + if order == "CSC" + else pandas.Series(na_dst), + "etp": pandas.Series(na_etp), + } + ) + df = cudf.from_pandas(df) + df.reset_index(drop=True, inplace=True) graph = cugraph.MultiGraph(directed=True) if multi_gpu: @@ -521,6 +553,7 @@ def get_empty_df(): edge_type="etp", ) + del df return graph @property diff --git 
a/python/cugraph-pyg/cugraph_pyg/tests/conftest.py b/python/cugraph-pyg/cugraph_pyg/tests/conftest.py index 083c4a2b37b..1512901822a 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/conftest.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/conftest.py @@ -24,7 +24,7 @@ import torch import numpy as np from cugraph.gnn import FeatureStore -from cugraph.experimental.datasets import karate +from cugraph.datasets import karate import tempfile diff --git a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_loader.py b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_loader.py index 55aebf305da..f5035a38621 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_loader.py @@ -15,7 +15,6 @@ from cugraph_pyg.loader import CuGraphNeighborLoader from cugraph_pyg.data import CuGraphStore - from cugraph.utilities.utils import import_optional, MissingModule torch = import_optional("torch") diff --git a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py index 13c9c90c7c2..be8f8245807 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py @@ -386,3 +386,29 @@ def test_mg_frame_handle(graph, dask_client): F, G, N = graph cugraph_store = CuGraphStore(F, G, N, multi_gpu=True) assert isinstance(cugraph_store._EXPERIMENTAL__CuGraphStore__graph._plc_graph, dict) + + +@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") +def test_cugraph_loader_large_index(dask_client): + large_index = ( + np.random.randint(0, 1_000_000, (100_000_000,)), + np.random.randint(0, 1_000_000, (100_000_000,)), + ) + + large_features = np.random.randint(0, 50, (1_000_000,)) + F = cugraph.gnn.FeatureStore(backend="torch") + F.add_data(large_features, "N", "f") + + store = CuGraphStore( + F, + {("N", "e", "N"): large_index}, + {"N": 1_000_000}, + multi_gpu=True, + ) + + graph = store._subgraph() + assert isinstance(graph, cugraph.Graph) + + el = graph.view_edge_list().compute() + assert (el["src"].values_host - large_index[0]).sum() == 0 + assert (el["dst"].values_host - large_index[1]).sum() == 0 diff --git a/python/cugraph-service/server/cugraph_service_server/testing/benchmark_server_extension.py b/python/cugraph-service/server/cugraph_service_server/testing/benchmark_server_extension.py index 5f9eac6b2a3..361226c8071 100644 --- a/python/cugraph-service/server/cugraph_service_server/testing/benchmark_server_extension.py +++ b/python/cugraph-service/server/cugraph_service_server/testing/benchmark_server_extension.py @@ -17,7 +17,7 @@ import cugraph from cugraph.experimental import PropertyGraph, MGPropertyGraph -from cugraph.experimental import datasets +from cugraph import datasets from cugraph.generators import rmat diff --git a/python/cugraph/cugraph/dask/community/leiden.py b/python/cugraph/cugraph/dask/community/leiden.py index 75582fa48f7..67bd0876ce6 100644 --- a/python/cugraph/cugraph/dask/community/leiden.py +++ b/python/cugraph/cugraph/dask/community/leiden.py @@ -125,7 +125,7 @@ def leiden( Examples -------- - >>> from cugraph.experimental.datasets import karate + >>> from cugraph.datasets import karate >>> G = karate.get_graph(fetch=True) >>> parts, modularity_score = cugraph.leiden(G) diff --git a/python/cugraph/cugraph/dask/community/louvain.py b/python/cugraph/cugraph/dask/community/louvain.py index 8efbbafaf7b..1b091817a1a 
100644 --- a/python/cugraph/cugraph/dask/community/louvain.py +++ b/python/cugraph/cugraph/dask/community/louvain.py @@ -129,7 +129,7 @@ def louvain( Examples -------- - >>> from cugraph.experimental.datasets import karate + >>> from cugraph.datasets import karate >>> G = karate.get_graph(fetch=True) >>> parts = cugraph.louvain(G) diff --git a/python/cugraph/cugraph/datasets/__init__.py b/python/cugraph/cugraph/datasets/__init__.py index 65a820f108b..ac18274d354 100644 --- a/python/cugraph/cugraph/datasets/__init__.py +++ b/python/cugraph/cugraph/datasets/__init__.py @@ -39,3 +39,13 @@ small_tree = Dataset(meta_path / "small_tree.yaml") toy_graph = Dataset(meta_path / "toy_graph.yaml") toy_graph_undirected = Dataset(meta_path / "toy_graph_undirected.yaml") + +# Benchmarking datasets: be mindful of memory usage +# 250 MB +soc_livejournal = Dataset(meta_path / "soc-livejournal1.yaml") +# 965 MB +cit_patents = Dataset(meta_path / "cit-patents.yaml") +# 1.8 GB +europe_osm = Dataset(meta_path / "europe_osm.yaml") +# 1.5 GB +hollywood = Dataset(meta_path / "hollywood.yaml") diff --git a/python/cugraph/cugraph/datasets/dataset.py b/python/cugraph/cugraph/datasets/dataset.py index 877eade7708..dd7aa0df00a 100644 --- a/python/cugraph/cugraph/datasets/dataset.py +++ b/python/cugraph/cugraph/datasets/dataset.py @@ -14,44 +14,45 @@ import cudf import yaml import os +import pandas as pd from pathlib import Path from cugraph.structure.graph_classes import Graph class DefaultDownloadDir: """ - Maintains the path to the download directory used by Dataset instances. + Maintains a path to be used as a default download directory. + + All DefaultDownloadDir instances are based on RAPIDS_DATASET_ROOT_DIR if + set, or _default_base_dir if not set. + Instances of this class are typically shared by several Dataset instances in order to allow for the download directory to be defined and updated by a single object. """ - def __init__(self): - self._path = Path( - os.environ.get("RAPIDS_DATASET_ROOT_DIR", Path.home() / ".cugraph/datasets") - ) + _default_base_dir = Path.home() / ".cugraph/datasets" - @property - def path(self): + def __init__(self, *, subdir=""): """ - If `path` is not set, set it to the environment variable - RAPIDS_DATASET_ROOT_DIR. If the variable is not set, default to the - user's home directory. + subdir can be specified to provide a specialized dir under the base dir. """ - if self._path is None: - self._path = Path( - os.environ.get( - "RAPIDS_DATASET_ROOT_DIR", Path.home() / ".cugraph/datasets" - ) - ) - return self._path + self._subdir = Path(subdir) + self.reset() + + @property + def path(self): + return self._path.absolute() @path.setter def path(self, new): self._path = Path(new) - def clear(self): - self._path = None + def reset(self): + self._basedir = Path( + os.environ.get("RAPIDS_DATASET_ROOT_DIR", self._default_base_dir) + ) + self._path = self._basedir / self._subdir default_download_dir = DefaultDownloadDir() @@ -159,7 +160,7 @@ def unload(self): """ self._edgelist = None - def get_edgelist(self, download=False): + def get_edgelist(self, download=False, reader="cudf"): """ Return an Edgelist @@ -168,6 +169,9 @@ def get_edgelist(self, download=False): download : Boolean (default=False) Automatically download the dataset from the 'url' location within the YAML file. + + reader : 'cudf' or 'pandas' (default='cudf') + The library used to read a CSV and return an edgelist DataFrame. 
""" if self._edgelist is None: full_path = self.get_path() @@ -180,14 +184,29 @@ def get_edgelist(self, download=False): " exist. Try setting download=True" " to download the datafile" ) + header = None if isinstance(self.metadata["header"], int): header = self.metadata["header"] - self._edgelist = cudf.read_csv( - full_path, + + if reader == "cudf": + self.__reader = cudf.read_csv + elif reader == "pandas": + self.__reader = pd.read_csv + else: + raise ValueError( + "reader must be a module with a read_csv function compatible with \ + cudf.read_csv" + ) + + self._edgelist = self.__reader( + filepath_or_buffer=full_path, delimiter=self.metadata["delim"], names=self.metadata["col_names"], - dtype=self.metadata["col_types"], + dtype={ + self.metadata["col_names"][i]: self.metadata["col_types"][i] + for i in range(len(self.metadata["col_types"])) + }, header=header, ) @@ -219,6 +238,10 @@ def get_graph( dataset -if present- will be applied to the Graph. If the dataset does not contain weights, the Graph returned will be unweighted regardless of ignore_weights. + + store_transposed: Boolean (default=False) + If True, stores the transpose of the adjacency matrix. Required + for certain algorithms, such as pagerank. """ if self._edgelist is None: self.get_edgelist(download) @@ -237,20 +260,19 @@ def get_graph( "(or subclass) type or instance, got: " f"{type(create_using)}" ) - if len(self.metadata["col_names"]) > 2 and not (ignore_weights): G.from_cudf_edgelist( self._edgelist, - source="src", - destination="dst", - edge_attr="wgt", + source=self.metadata["col_names"][0], + destination=self.metadata["col_names"][1], + edge_attr=self.metadata["col_names"][2], store_transposed=store_transposed, ) else: G.from_cudf_edgelist( self._edgelist, - source="src", - destination="dst", + source=self.metadata["col_names"][0], + destination=self.metadata["col_names"][1], store_transposed=store_transposed, ) return G @@ -331,7 +353,7 @@ def download_all(force=False): def set_download_dir(path): """ - Set the download location fors datasets + Set the download location for datasets Parameters ---------- @@ -339,10 +361,10 @@ def set_download_dir(path): Location used to store datafiles """ if path is None: - default_download_dir.clear() + default_download_dir.reset() else: default_download_dir.path = path def get_download_dir(): - return default_download_dir.path.absolute() + return default_download_dir.path diff --git a/python/cugraph/cugraph/datasets/metadata/cit-patents.yaml b/python/cugraph/cugraph/datasets/metadata/cit-patents.yaml new file mode 100644 index 00000000000..d5c4cf195bd --- /dev/null +++ b/python/cugraph/cugraph/datasets/metadata/cit-patents.yaml @@ -0,0 +1,22 @@ +name: cit-Patents +file_type: .csv +description: A citation graph that includes all citations made by patents granted between 1975 and 1999, totaling 16,522,438 citations. +author: NBER +refs: + J. Leskovec, J. Kleinberg and C. Faloutsos. Graphs over Time Densification Laws, Shrinking Diameters and Possible Explanations. + ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD), 2005. 
+delim: " " +header: None +col_names: + - src + - dst +col_types: + - int32 + - int32 +has_loop: true +is_directed: true +is_multigraph: false +is_symmetric: false +number_of_edges: 16518948 +number_of_nodes: 3774768 +url: https://data.rapids.ai/cugraph/datasets/cit-Patents.csv \ No newline at end of file diff --git a/python/cugraph/cugraph/datasets/metadata/europe_osm.yaml b/python/cugraph/cugraph/datasets/metadata/europe_osm.yaml new file mode 100644 index 00000000000..fe0e42a4b86 --- /dev/null +++ b/python/cugraph/cugraph/datasets/metadata/europe_osm.yaml @@ -0,0 +1,21 @@ +name: europe_osm +file_type: .csv +description: A graph of OpenStreetMap data for Europe. +author: M. Kobitzsh / Geofabrik GmbH +refs: + Rossi, Ryan. Ahmed, Nesreen. The Network Data Respoistory with Interactive Graph Analytics and Visualization. +delim: " " +header: None +col_names: + - src + - dst +col_types: + - int32 + - int32 +has_loop: false +is_directed: false +is_multigraph: false +is_symmetric: true +number_of_edges: 54054660 +number_of_nodes: 50912018 +url: https://data.rapids.ai/cugraph/datasets/europe_osm.csv \ No newline at end of file diff --git a/python/cugraph/cugraph/datasets/metadata/hollywood.yaml b/python/cugraph/cugraph/datasets/metadata/hollywood.yaml new file mode 100644 index 00000000000..2f09cf7679b --- /dev/null +++ b/python/cugraph/cugraph/datasets/metadata/hollywood.yaml @@ -0,0 +1,26 @@ +name: hollywood +file_type: .csv +description: + A graph of movie actors where vertices are actors, and two actors are + joined by an edge whenever they appeared in a movie together. +author: Laboratory for Web Algorithmics (LAW) +refs: + The WebGraph Framework I Compression Techniques, Paolo Boldi + and Sebastiano Vigna, Proc. of the Thirteenth International + World Wide Web Conference (WWW 2004), 2004, Manhattan, USA, + pp. 595--601, ACM Press. +delim: " " +header: None +col_names: + - src + - dst +col_types: + - int32 + - int32 +has_loop: false +is_directed: false +is_multigraph: false +is_symmetric: true +number_of_edges: 57515616 +number_of_nodes: 1139905 +url: https://data.rapids.ai/cugraph/datasets/hollywood.csv \ No newline at end of file diff --git a/python/cugraph/cugraph/datasets/metadata/soc-livejournal1.yaml b/python/cugraph/cugraph/datasets/metadata/soc-livejournal1.yaml new file mode 100644 index 00000000000..fafc68acb9b --- /dev/null +++ b/python/cugraph/cugraph/datasets/metadata/soc-livejournal1.yaml @@ -0,0 +1,22 @@ +name: soc-LiveJournal1 +file_type: .csv +description: A graph of the LiveJournal social network. +author: L. Backstrom, D. Huttenlocher, J. Kleinberg, X. Lan +refs: + L. Backstrom, D. Huttenlocher, J. Kleinberg, X. Lan. Group Formation in + Large Social Networks Membership, Growth, and Evolution. KDD, 2006. 
+delim: " " +header: None +col_names: + - src + - dst +col_types: + - int32 + - int32 +has_loop: true +is_directed: true +is_multigraph: false +is_symmetric: false +number_of_edges: 68993773 +number_of_nodes: 4847571 +url: https://data.rapids.ai/cugraph/datasets/soc-LiveJournal1.csv \ No newline at end of file diff --git a/python/cugraph/cugraph/datasets/metadata/soc-twitter-2010.yaml b/python/cugraph/cugraph/datasets/metadata/soc-twitter-2010.yaml new file mode 100644 index 00000000000..df5df5735af --- /dev/null +++ b/python/cugraph/cugraph/datasets/metadata/soc-twitter-2010.yaml @@ -0,0 +1,22 @@ +name: soc-twitter-2010 +file_type: .csv +description: A network of follower relationships from a snapshot of Twitter in 2010, where an edge from i to j indicates that j is a follower of i. +author: H. Kwak, C. Lee, H. Park, S. Moon +refs: + J. Yang, J. Leskovec. Temporal Variation in Online Media. ACM Intl. + Conf. on Web Search and Data Mining (WSDM '11), 2011. +delim: " " +header: None +col_names: + - src + - dst +col_types: + - int32 + - int32 +has_loop: false +is_directed: false +is_multigraph: false +is_symmetric: false +number_of_edges: 530051354 +number_of_nodes: 21297772 +url: https://data.rapids.ai/cugraph/datasets/soc-twitter-2010.csv \ No newline at end of file diff --git a/python/cugraph/cugraph/experimental/datasets/__init__.py b/python/cugraph/cugraph/experimental/datasets/__init__.py deleted file mode 100644 index 18220243df1..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/__init__.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from cugraph.experimental.datasets.dataset import ( - Dataset, - load_all, - set_download_dir, - get_download_dir, - default_download_dir, -) -from cugraph.experimental.datasets import metadata -from pathlib import Path - -from cugraph.utilities.api_tools import promoted_experimental_warning_wrapper - - -Dataset = promoted_experimental_warning_wrapper(Dataset) -load_all = promoted_experimental_warning_wrapper(load_all) -set_download_dir = promoted_experimental_warning_wrapper(set_download_dir) -get_download_dir = promoted_experimental_warning_wrapper(get_download_dir) - -meta_path = Path(__file__).parent / "metadata" - - -# individual dataset objects -karate = Dataset(meta_path / "karate.yaml") -karate_data = Dataset(meta_path / "karate_data.yaml") -karate_undirected = Dataset(meta_path / "karate_undirected.yaml") -karate_asymmetric = Dataset(meta_path / "karate_asymmetric.yaml") -karate_disjoint = Dataset(meta_path / "karate-disjoint.yaml") -dolphins = Dataset(meta_path / "dolphins.yaml") -polbooks = Dataset(meta_path / "polbooks.yaml") -netscience = Dataset(meta_path / "netscience.yaml") -cyber = Dataset(meta_path / "cyber.yaml") -small_line = Dataset(meta_path / "small_line.yaml") -small_tree = Dataset(meta_path / "small_tree.yaml") -toy_graph = Dataset(meta_path / "toy_graph.yaml") -toy_graph_undirected = Dataset(meta_path / "toy_graph_undirected.yaml") -email_Eu_core = Dataset(meta_path / "email-Eu-core.yaml") -ktruss_polbooks = Dataset(meta_path / "ktruss_polbooks.yaml") - - -# batches of datasets -DATASETS_UNDIRECTED = [karate, dolphins] - -DATASETS_UNDIRECTED_WEIGHTS = [netscience] - -DATASETS_UNRENUMBERED = [karate_disjoint] - -DATASETS = [dolphins, netscience, karate_disjoint] - -DATASETS_SMALL = [karate, dolphins, polbooks] - -STRONGDATASETS = [dolphins, netscience, email_Eu_core] - -DATASETS_KTRUSS = [(polbooks, ktruss_polbooks)] - -MEDIUM_DATASETS = [polbooks] - -SMALL_DATASETS = [karate, dolphins, netscience] - -RLY_SMALL_DATASETS = [small_line, small_tree] - -ALL_DATASETS = [karate, dolphins, netscience, polbooks, small_line, small_tree] - -ALL_DATASETS_WGT = [karate, dolphins, netscience, polbooks, small_line, small_tree] - -TEST_GROUP = [dolphins, netscience] diff --git a/python/cugraph/cugraph/experimental/datasets/dataset.py b/python/cugraph/cugraph/experimental/datasets/dataset.py deleted file mode 100644 index 6b395d50fef..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/dataset.py +++ /dev/null @@ -1,312 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import cudf -import yaml -import os -from pathlib import Path -from cugraph.structure.graph_classes import Graph - - -class DefaultDownloadDir: - """ - Maintains the path to the download directory used by Dataset instances. - Instances of this class are typically shared by several Dataset instances - in order to allow for the download directory to be defined and updated by - a single object. 
- """ - - def __init__(self): - self._path = Path( - os.environ.get("RAPIDS_DATASET_ROOT_DIR", Path.home() / ".cugraph/datasets") - ) - - @property - def path(self): - """ - If `path` is not set, set it to the environment variable - RAPIDS_DATASET_ROOT_DIR. If the variable is not set, default to the - user's home directory. - """ - if self._path is None: - self._path = Path( - os.environ.get( - "RAPIDS_DATASET_ROOT_DIR", Path.home() / ".cugraph/datasets" - ) - ) - return self._path - - @path.setter - def path(self, new): - self._path = Path(new) - - def clear(self): - self._path = None - - -default_download_dir = DefaultDownloadDir() - - -class Dataset: - """ - A Dataset Object, used to easily import edgelist data and cuGraph.Graph - instances. - - Parameters - ---------- - meta_data_file_name : yaml file - The metadata file for the specific graph dataset, which includes - information on the name, type, url link, data loading format, graph - properties - """ - - def __init__( - self, - metadata_yaml_file=None, - csv_file=None, - csv_header=None, - csv_delim=" ", - csv_col_names=None, - csv_col_dtypes=None, - ): - self._metadata_file = None - self._dl_path = default_download_dir - self._edgelist = None - self._path = None - - if metadata_yaml_file is not None and csv_file is not None: - raise ValueError("cannot specify both metadata_yaml_file and csv_file") - - elif metadata_yaml_file is not None: - with open(metadata_yaml_file, "r") as file: - self.metadata = yaml.safe_load(file) - self._metadata_file = Path(metadata_yaml_file) - - elif csv_file is not None: - if csv_col_names is None or csv_col_dtypes is None: - raise ValueError( - "csv_col_names and csv_col_dtypes must both be " - "not None when csv_file is specified." - ) - self._path = Path(csv_file) - if self._path.exists() is False: - raise FileNotFoundError(csv_file) - self.metadata = { - "name": self._path.with_suffix("").name, - "file_type": ".csv", - "url": None, - "header": csv_header, - "delim": csv_delim, - "col_names": csv_col_names, - "col_types": csv_col_dtypes, - } - - else: - raise ValueError("must specify either metadata_yaml_file or csv_file") - - def __str__(self): - """ - Use the basename of the meta_data_file the instance was constructed with, - without any extension, as the string repr. - """ - # The metadata file is likely to have a more descriptive file name, so - # use that one first if present. - # FIXME: this may need to provide a more unique or descriptive string repr - if self._metadata_file is not None: - return self._metadata_file.with_suffix("").name - else: - return self.get_path().with_suffix("").name - - def __download_csv(self, url): - """ - Downloads the .csv file from url to the current download path - (self._dl_path), updates self._path with the full path to the - downloaded file, and returns the latest value of self._path. - """ - self._dl_path.path.mkdir(parents=True, exist_ok=True) - - filename = self.metadata["name"] + self.metadata["file_type"] - if self._dl_path.path.is_dir(): - df = cudf.read_csv(url) - self._path = self._dl_path.path / filename - df.to_csv(self._path, index=False) - - else: - raise RuntimeError( - f"The directory {self._dl_path.path.absolute()}" "does not exist" - ) - return self._path - - def unload(self): - - """ - Remove all saved internal objects, forcing them to be re-created when - accessed. - - NOTE: This will cause calls to get_*() to re-read the dataset file from - disk. The caller should ensure the file on disk has not moved/been - deleted/changed. 
- """ - self._edgelist = None - - def get_edgelist(self, fetch=False): - """ - Return an Edgelist - - Parameters - ---------- - fetch : Boolean (default=False) - Automatically fetch for the dataset from the 'url' location within - the YAML file. - """ - if self._edgelist is None: - full_path = self.get_path() - if not full_path.is_file(): - if fetch: - full_path = self.__download_csv(self.metadata["url"]) - else: - raise RuntimeError( - f"The datafile {full_path} does not" - " exist. Try get_edgelist(fetch=True)" - " to download the datafile" - ) - header = None - if isinstance(self.metadata["header"], int): - header = self.metadata["header"] - self._edgelist = cudf.read_csv( - full_path, - delimiter=self.metadata["delim"], - names=self.metadata["col_names"], - dtype=self.metadata["col_types"], - header=header, - ) - - return self._edgelist - - def get_graph( - self, - fetch=False, - create_using=Graph, - ignore_weights=False, - store_transposed=False, - ): - """ - Return a Graph object. - - Parameters - ---------- - fetch : Boolean (default=False) - Downloads the dataset from the web. - - create_using: cugraph.Graph (instance or class), optional - (default=Graph) - Specify the type of Graph to create. Can pass in an instance to - create a Graph instance with specified 'directed' attribute. - - ignore_weights : Boolean (default=False) - Ignores weights in the dataset if True, resulting in an - unweighted Graph. If False (the default), weights from the - dataset -if present- will be applied to the Graph. If the - dataset does not contain weights, the Graph returned will - be unweighted regardless of ignore_weights. - """ - if self._edgelist is None: - self.get_edgelist(fetch) - - if create_using is None: - G = Graph() - elif isinstance(create_using, Graph): - # what about BFS if trnaposed is True - attrs = {"directed": create_using.is_directed()} - G = type(create_using)(**attrs) - elif type(create_using) is type: - G = create_using() - else: - raise TypeError( - "create_using must be a cugraph.Graph " - "(or subclass) type or instance, got: " - f"{type(create_using)}" - ) - - if len(self.metadata["col_names"]) > 2 and not (ignore_weights): - G.from_cudf_edgelist( - self._edgelist, - source="src", - destination="dst", - edge_attr="wgt", - store_transposed=store_transposed, - ) - else: - G.from_cudf_edgelist( - self._edgelist, - source="src", - destination="dst", - store_transposed=store_transposed, - ) - return G - - def get_path(self): - """ - Returns the location of the stored dataset file - """ - if self._path is None: - self._path = self._dl_path.path / ( - self.metadata["name"] + self.metadata["file_type"] - ) - - return self._path.absolute() - - -def load_all(force=False): - """ - Looks in `metadata` directory and fetches all datafiles from the the URLs - provided in each YAML file. - - Parameters - force : Boolean (default=False) - Overwrite any existing copies of datafiles. 
- """ - default_download_dir.path.mkdir(parents=True, exist_ok=True) - - meta_path = Path(__file__).parent.absolute() / "metadata" - for file in meta_path.iterdir(): - meta = None - if file.suffix == ".yaml": - with open(meta_path / file, "r") as metafile: - meta = yaml.safe_load(metafile) - - if "url" in meta: - filename = meta["name"] + meta["file_type"] - save_to = default_download_dir.path / filename - if not save_to.is_file() or force: - df = cudf.read_csv(meta["url"]) - df.to_csv(save_to, index=False) - - -def set_download_dir(path): - """ - Set the download directory for fetching datasets - - Parameters - ---------- - path : String - Location used to store datafiles - """ - if path is None: - default_download_dir.clear() - else: - default_download_dir.path = path - - -def get_download_dir(): - return default_download_dir.path.absolute() diff --git a/python/cugraph/cugraph/experimental/datasets/datasets_config.yaml b/python/cugraph/cugraph/experimental/datasets/datasets_config.yaml deleted file mode 100644 index 69a79db9cd9..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/datasets_config.yaml +++ /dev/null @@ -1,5 +0,0 @@ ---- -fetch: "False" -force: "False" -# path where datasets will be downloaded to and stored -download_dir: "datasets" diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/__init__.py b/python/cugraph/cugraph/experimental/datasets/metadata/__init__.py deleted file mode 100644 index 081b2ae8260..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/cyber.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/cyber.yaml deleted file mode 100644 index 93ab5345442..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/cyber.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: cyber -file_type: .csv -author: N/A -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/cyber.csv -refs: N/A -col_names: - - idx - - srcip - - dstip -col_types: - - int32 - - str - - str -delim: "," -header: 0 -has_loop: true -is_directed: true -is_multigraph: false -is_symmetric: false -number_of_edges: 2546575 -number_of_nodes: 706529 -number_of_lines: 2546576 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/dolphins.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/dolphins.yaml deleted file mode 100644 index e4951375321..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/dolphins.yaml +++ /dev/null @@ -1,25 +0,0 @@ -name: dolphins -file_type: .csv -author: D. Lusseau -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/dolphins.csv -refs: - D. Lusseau, K. Schneider, O. J. Boisseau, P. Haase, E. Slooten, and S. M. 
Dawson, - The bottlenose dolphin community of Doubtful Sound features a large proportion of - long-lasting associations, Behavioral Ecology and Sociobiology 54, 396-405 (2003). -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -delim: " " -header: None -has_loop: false -is_directed: true -is_multigraph: false -is_symmetric: false -number_of_edges: 318 -number_of_nodes: 62 -number_of_lines: 318 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/email-Eu-core.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/email-Eu-core.yaml deleted file mode 100644 index 97d0dc82ee3..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/email-Eu-core.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: email-Eu-core -file_type: .csv -author: null -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/email-Eu-core.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: false -is_multigraph: false -is_symmetric: true -number_of_edges: 25571 -number_of_nodes: 1005 -number_of_lines: 25571 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/karate-disjoint.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/karate-disjoint.yaml deleted file mode 100644 index 0c0eaf78b63..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/karate-disjoint.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: karate-disjoint -file_type: .csv -author: null -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/karate-disjoint.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: True -is_multigraph: false -is_symmetric: true -number_of_edges: 312 -number_of_nodes: 68 -number_of_lines: 312 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/karate.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/karate.yaml deleted file mode 100644 index 273381ed368..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/karate.yaml +++ /dev/null @@ -1,24 +0,0 @@ -name: karate -file_type: .csv -author: Zachary W. -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/karate.csv -refs: - W. W. Zachary, An information flow model for conflict and fission in small groups, - Journal of Anthropological Research 33, 452-473 (1977). -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: true -is_directed: true -is_multigraph: false -is_symmetric: true -number_of_edges: 156 -number_of_nodes: 34 -number_of_lines: 156 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/karate_asymmetric.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/karate_asymmetric.yaml deleted file mode 100644 index 3616b8fb3a5..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/karate_asymmetric.yaml +++ /dev/null @@ -1,24 +0,0 @@ -name: karate-asymmetric -file_type: .csv -author: Zachary W. -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/karate-asymmetric.csv -delim: " " -header: None -refs: - W. W. Zachary, An information flow model for conflict and fission in small groups, - Journal of Anthropological Research 33, 452-473 (1977). 
-col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: true -is_directed: false -is_multigraph: false -is_symmetric: false -number_of_edges: 78 -number_of_nodes: 34 -number_of_lines: 78 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/karate_data.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/karate_data.yaml deleted file mode 100644 index 9a8b27f21ae..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/karate_data.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: karate-data -file_type: .csv -author: Zachary W. -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/karate-data.csv -refs: - W. W. Zachary, An information flow model for conflict and fission in small groups, - Journal of Anthropological Research 33, 452-473 (1977). -delim: "\t" -header: None -col_names: - - src - - dst -col_types: - - int32 - - int32 -has_loop: true -is_directed: true -is_multigraph: false -is_symmetric: true -number_of_edges: 156 -number_of_nodes: 34 -number_of_lines: 156 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/karate_undirected.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/karate_undirected.yaml deleted file mode 100644 index 1b45f86caee..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/karate_undirected.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: karate_undirected -file_type: .csv -author: Zachary W. -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/karate_undirected.csv -refs: - W. W. Zachary, An information flow model for conflict and fission in small groups, - Journal of Anthropological Research 33, 452-473 (1977). -delim: "\t" -header: None -col_names: - - src - - dst -col_types: - - int32 - - int32 -has_loop: true -is_directed: false -is_multigraph: false -is_symmetric: true -number_of_edges: 78 -number_of_nodes: 34 -number_of_lines: 78 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/ktruss_polbooks.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/ktruss_polbooks.yaml deleted file mode 100644 index 1ef29b3917e..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/ktruss_polbooks.yaml +++ /dev/null @@ -1,23 +0,0 @@ -name: ktruss_polbooks -file_type: .csv -author: null -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/ref/ktruss/polbooks.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: true -is_multigraph: false -is_symmetric: false -number_of_edges: 233 -number_of_nodes: 58 -number_of_lines: 233 - diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/netscience.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/netscience.yaml deleted file mode 100644 index 2dca702df3d..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/netscience.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: netscience -file_type: .csv -author: Newman, Mark EJ -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/netscience.csv -refs: Finding community structure in networks using the eigenvectors of matrices. 
-delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: true -is_multigraph: false -is_symmetric: true -number_of_edges: 2742 -number_of_nodes: 1461 -number_of_lines: 5484 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/polbooks.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/polbooks.yaml deleted file mode 100644 index 5816e5672fd..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/polbooks.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: polbooks -file_type: .csv -author: V. Krebs -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/polbooks.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -is_directed: true -has_loop: null -is_multigraph: null -is_symmetric: true -number_of_edges: 882 -number_of_nodes: 105 -number_of_lines: 882 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/small_line.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/small_line.yaml deleted file mode 100644 index 5b724ac99fd..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/small_line.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: small_line -file_type: .csv -author: null -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/small_line.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: false -is_multigraph: false -is_symmetric: true -number_of_edges: 9 -number_of_nodes: 10 -number_of_lines: 8 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/small_tree.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/small_tree.yaml deleted file mode 100644 index 8eeac346d2a..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/small_tree.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: small_tree -file_type: .csv -author: null -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/small_tree.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: true -is_multigraph: false -is_symmetric: true -number_of_edges: 11 -number_of_nodes: 9 -number_of_lines: 11 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/toy_graph.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/toy_graph.yaml deleted file mode 100644 index 819aad06f6a..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/toy_graph.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: toy_graph -file_type: .csv -author: null -url: https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/toy_graph.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: false -is_multigraph: false -is_symmetric: true -number_of_edges: 16 -number_of_nodes: 6 -number_of_lines: 16 diff --git a/python/cugraph/cugraph/experimental/datasets/metadata/toy_graph_undirected.yaml b/python/cugraph/cugraph/experimental/datasets/metadata/toy_graph_undirected.yaml deleted file mode 100644 index c6e86bdf334..00000000000 --- a/python/cugraph/cugraph/experimental/datasets/metadata/toy_graph_undirected.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: toy_graph_undirected -file_type: .csv -author: null -url: 
https://raw.githubusercontent.com/rapidsai/cugraph/branch-22.08/datasets/toy_graph_undirected.csv -refs: null -delim: " " -header: None -col_names: - - src - - dst - - wgt -col_types: - - int32 - - int32 - - float32 -has_loop: false -is_directed: false -is_multigraph: false -is_symmetric: true -number_of_edges: 8 -number_of_nodes: 6 -number_of_lines: 8 diff --git a/python/cugraph/cugraph/testing/__init__.py b/python/cugraph/cugraph/testing/__init__.py index f5f0bcb06eb..2b4a4fd3ebf 100644 --- a/python/cugraph/cugraph/testing/__init__.py +++ b/python/cugraph/cugraph/testing/__init__.py @@ -19,7 +19,7 @@ Resultset, load_resultset, get_resultset, - results_dir_path, + default_resultset_download_dir, ) from cugraph.datasets import ( cyber, @@ -34,6 +34,11 @@ email_Eu_core, toy_graph, toy_graph_undirected, + soc_livejournal, + cit_patents, + europe_osm, + hollywood, + # twitter, ) # @@ -66,3 +71,4 @@ toy_graph_undirected, ] DEFAULT_DATASETS = [dolphins, netscience, karate_disjoint] +BENCHMARKING_DATASETS = [soc_livejournal, cit_patents, europe_osm, hollywood] diff --git a/python/cugraph/cugraph/testing/generate_resultsets.py b/python/cugraph/cugraph/testing/generate_resultsets.py index 9724aca32dc..2ae0f52d88b 100644 --- a/python/cugraph/cugraph/testing/generate_resultsets.py +++ b/python/cugraph/cugraph/testing/generate_resultsets.py @@ -20,8 +20,14 @@ import cudf import cugraph from cugraph.datasets import dolphins, netscience, karate_disjoint, karate -from cugraph.testing import utils, Resultset, SMALL_DATASETS, results_dir_path +# from cugraph.testing import utils, Resultset, SMALL_DATASETS, results_dir_path +from cugraph.testing import ( + utils, + Resultset, + SMALL_DATASETS, + default_resultset_download_dir, +) _resultsets = {} @@ -224,6 +230,7 @@ def add_resultset(result_data_dictionary, **kwargs): ] ) # Generating ALL results files + results_dir_path = default_resultset_download_dir.path if not results_dir_path.exists(): results_dir_path.mkdir(parents=True, exist_ok=True) diff --git a/python/cugraph/cugraph/testing/resultset.py b/python/cugraph/cugraph/testing/resultset.py index 490e3a7c4ff..9570d7f3e04 100644 --- a/python/cugraph/cugraph/testing/resultset.py +++ b/python/cugraph/cugraph/testing/resultset.py @@ -16,10 +16,12 @@ import urllib.request import cudf -from cugraph.testing import utils +from cugraph.datasets.dataset import ( + DefaultDownloadDir, + default_download_dir, +) - -results_dir_path = utils.RAPIDS_DATASET_ROOT_DIR_PATH / "tests" / "resultsets" +# results_dir_path = utils.RAPIDS_DATASET_ROOT_DIR_PATH / "tests" / "resultsets" class Resultset: @@ -48,6 +50,42 @@ def get_cudf_dataframe(self): _resultsets = {} +def get_resultset(resultset_name, **kwargs): + """ + Returns the golden results for a specific test. 
+ + Parameters + ---------- + resultset_name : String + Name of the test's module (currently just 'traversal' is supported) + + kwargs : + All distinct test details regarding the choice of algorithm, dataset, + and graph + """ + arg_dict = dict(kwargs) + arg_dict["resultset_name"] = resultset_name + # Example: + # {'a': 1, 'z': 9, 'c': 5, 'b': 2} becomes 'a-1-b-2-c-5-z-9' + resultset_key = "-".join( + [ + str(val) + for arg_dict_pair in sorted(arg_dict.items()) + for val in arg_dict_pair + ] + ) + uuid = _resultsets.get(resultset_key) + if uuid is None: + raise KeyError(f"results for {arg_dict} not found") + + results_dir_path = default_resultset_download_dir.path + results_filename = results_dir_path / (uuid + ".csv") + return cudf.read_csv(results_filename) + + +default_resultset_download_dir = DefaultDownloadDir(subdir="tests/resultsets") + + def load_resultset(resultset_name, resultset_download_url): """ Read a mapping file (.csv) in the _results_dir and save the @@ -56,17 +94,21 @@ def load_resultset(resultset_name, resultset_download_url): _results_dir, use resultset_download_url to download a file to install/unpack/etc. to _results_dir first. """ - mapping_file_path = results_dir_path / (resultset_name + "_mappings.csv") + # curr_resultset_download_dir = get_resultset_download_dir() + curr_resultset_download_dir = default_resultset_download_dir.path + # curr_download_dir = path + curr_download_dir = default_download_dir.path + mapping_file_path = curr_resultset_download_dir / (resultset_name + "_mappings.csv") if not mapping_file_path.exists(): # Downloads a tar gz from s3 bucket, then unpacks the results files - compressed_file_dir = utils.RAPIDS_DATASET_ROOT_DIR_PATH / "tests" + compressed_file_dir = curr_download_dir / "tests" compressed_file_path = compressed_file_dir / "resultsets.tar.gz" - if not results_dir_path.exists(): - results_dir_path.mkdir(parents=True, exist_ok=True) + if not curr_resultset_download_dir.exists(): + curr_resultset_download_dir.mkdir(parents=True, exist_ok=True) if not compressed_file_path.exists(): urllib.request.urlretrieve(resultset_download_url, compressed_file_path) tar = tarfile.open(str(compressed_file_path), "r:gz") - tar.extractall(str(results_dir_path)) + tar.extractall(str(curr_resultset_download_dir)) tar.close() # FIXME: This assumes separator is " ", but should this be configurable? @@ -102,35 +144,3 @@ def load_resultset(resultset_name, resultset_download_url): ) _resultsets[resultset_key] = uuid - - -def get_resultset(resultset_name, **kwargs): - """ - Returns the golden results for a specific test. 
- - Parameters - ---------- - resultset_name : String - Name of the test's module (currently just 'traversal' is supported) - - kwargs : - All distinct test details regarding the choice of algorithm, dataset, - and graph - """ - arg_dict = dict(kwargs) - arg_dict["resultset_name"] = resultset_name - # Example: - # {'a': 1, 'z': 9, 'c': 5, 'b': 2} becomes 'a-1-b-2-c-5-z-9' - resultset_key = "-".join( - [ - str(val) - for arg_dict_pair in sorted(arg_dict.items()) - for val in arg_dict_pair - ] - ) - uuid = _resultsets.get(resultset_key) - if uuid is None: - raise KeyError(f"results for {arg_dict} not found") - - results_filename = results_dir_path / (uuid + ".csv") - return cudf.read_csv(results_filename) diff --git a/python/cugraph/cugraph/tests/centrality/test_edge_betweenness_centrality_mg.py b/python/cugraph/cugraph/tests/centrality/test_edge_betweenness_centrality_mg.py index 4277f94a396..478b7e655d5 100644 --- a/python/cugraph/cugraph/tests/centrality/test_edge_betweenness_centrality_mg.py +++ b/python/cugraph/cugraph/tests/centrality/test_edge_betweenness_centrality_mg.py @@ -16,7 +16,7 @@ import dask_cudf from pylibcugraph.testing.utils import gen_fixture_params_product -from cugraph.experimental.datasets import DATASETS_UNDIRECTED +from cugraph.datasets import karate, dolphins import cugraph import cugraph.dask as dcg @@ -41,7 +41,7 @@ def setup_function(): # email_Eu_core is too expensive to test -datasets = DATASETS_UNDIRECTED +datasets = [karate, dolphins] # ============================================================================= diff --git a/python/cugraph/cugraph/tests/community/test_leiden.py b/python/cugraph/cugraph/tests/community/test_leiden.py index a06b0dd22c5..71117c4210f 100644 --- a/python/cugraph/cugraph/tests/community/test_leiden.py +++ b/python/cugraph/cugraph/tests/community/test_leiden.py @@ -22,8 +22,6 @@ from cugraph.testing import utils, UNDIRECTED_DATASETS from cugraph.datasets import karate_asymmetric -from cudf.testing.testing import assert_series_equal - # ============================================================================= # Test data @@ -43,8 +41,8 @@ "resolution": 1.0, "input_type": "COO", "expected_output": { - "partition": [1, 0, 1, 2, 2, 2], - "modularity_score": 0.1757322, + "partition": [0, 0, 0, 1, 1, 1], + "modularity_score": 0.215969, }, }, "data_2": { @@ -85,10 +83,10 @@ "input_type": "CSR", "expected_output": { # fmt: off - "partition": [6, 6, 3, 3, 1, 5, 5, 3, 0, 3, 1, 6, 3, 3, 4, 4, 5, 6, 4, 6, 4, - 6, 4, 4, 2, 2, 4, 4, 2, 4, 0, 2, 4, 4], + "partition": [3, 3, 3, 3, 2, 2, 2, 3, 1, 3, 2, 3, 3, 3, 1, 1, 2, 3, 1, 3, + 1, 3, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1], # fmt: on - "modularity_score": 0.3468113, + "modularity_score": 0.41880345, }, }, } @@ -138,7 +136,7 @@ def input_and_expected_output(request): # Create graph from csr offsets = src_or_offset_array indices = dst_or_index_array - G.from_cudf_adjlist(offsets, indices, weight) + G.from_cudf_adjlist(offsets, indices, weight, renumber=False) parts, mod = cugraph.leiden(G, max_level, resolution) @@ -223,9 +221,7 @@ def test_leiden_directed_graph(): @pytest.mark.sg def test_leiden_golden_results(input_and_expected_output): - expected_partition = cudf.Series( - input_and_expected_output["expected_output"]["partition"] - ) + expected_partition = input_and_expected_output["expected_output"]["partition"] expected_mod = input_and_expected_output["expected_output"]["modularity_score"] result_partition = input_and_expected_output["result_output"]["partition"] @@ -233,6 +229,10 @@ def 
test_leiden_golden_results(input_and_expected_output): assert abs(expected_mod - result_mod) < 0.0001 - assert_series_equal( - expected_partition, result_partition, check_dtype=False, check_names=False - ) + expected_to_result_map = {} + for e, r in zip(expected_partition, list(result_partition.to_pandas())): + if e in expected_to_result_map.keys(): + assert r == expected_to_result_map[e] + + else: + expected_to_result_map[e] = r diff --git a/python/cugraph/cugraph/tests/nx/test_compat_pr.py b/python/cugraph/cugraph/tests/nx/test_compat_pr.py index 9be3912a33f..45cab7a5674 100644 --- a/python/cugraph/cugraph/tests/nx/test_compat_pr.py +++ b/python/cugraph/cugraph/tests/nx/test_compat_pr.py @@ -24,7 +24,7 @@ import numpy as np from cugraph.testing import utils -from cugraph.experimental.datasets import karate +from cugraph.datasets import karate from pylibcugraph.testing.utils import gen_fixture_params_product diff --git a/python/cugraph/cugraph/tests/utils/test_dataset.py b/python/cugraph/cugraph/tests/utils/test_dataset.py index c2a4f7c6072..60bc6dbb45a 100644 --- a/python/cugraph/cugraph/tests/utils/test_dataset.py +++ b/python/cugraph/cugraph/tests/utils/test_dataset.py @@ -13,11 +13,10 @@ import os import gc -import sys -import warnings from pathlib import Path from tempfile import TemporaryDirectory +import pandas import pytest import cudf @@ -27,6 +26,7 @@ ALL_DATASETS, WEIGHTED_DATASETS, SMALL_DATASETS, + BENCHMARKING_DATASETS, ) from cugraph import datasets @@ -74,27 +74,14 @@ def setup(tmpdir): gc.collect() -@pytest.fixture() -def setup_deprecation_warning_tests(): - """ - Fixture used to set warning filters to 'default' and reload - experimental.datasets module if it has been previously - imported. Tests that import this fixture are expected to - import cugraph.experimental.datasets - """ - warnings.filterwarnings("default") - - if "cugraph.experimental.datasets" in sys.modules: - del sys.modules["cugraph.experimental.datasets"] - - yield - - ############################################################################### # Helpers # check if there is a row where src == dst -def has_loop(df): +def has_selfloop(dataset): + if not dataset.metadata["is_directed"]: + return False + df = dataset.get_edgelist(download=True) df.rename(columns={df.columns[0]: "src", df.columns[1]: "dst"}, inplace=True) res = df.where(df["src"] == df["dst"]) @@ -109,7 +96,13 @@ def is_symmetric(dataset): else: df = dataset.get_edgelist(download=True) df_a = df.sort_values("src") - df_b = df_a[["dst", "src", "wgt"]] + + # create df with swapped src/dst columns + df_b = None + if "wgt" in df_a.columns: + df_b = df_a[["dst", "src", "wgt"]] + else: + df_b = df_a[["dst", "src"]] df_b.rename(columns={"dst": "src", "src": "dst"}, inplace=True) # created a df by appending the two res = cudf.concat([df_a, df_b]) @@ -157,6 +150,27 @@ def test_download(dataset): assert dataset.get_path().is_file() +@pytest.mark.parametrize("dataset", SMALL_DATASETS) +def test_reader(dataset): + # defaults to using cudf.read_csv + E = dataset.get_edgelist(download=True) + + assert E is not None + assert isinstance(E, cudf.core.dataframe.DataFrame) + dataset.unload() + + # using pandas + E_pd = dataset.get_edgelist(download=True, reader="pandas") + + assert E_pd is not None + assert isinstance(E_pd, pandas.core.frame.DataFrame) + dataset.unload() + + with pytest.raises(ValueError): + dataset.get_edgelist(reader="fail") + dataset.get_edgelist(reader=None) + + @pytest.mark.parametrize("dataset", ALL_DATASETS) def 
test_get_edgelist(dataset): E = dataset.get_edgelist(download=True) @@ -172,7 +186,6 @@ def test_get_graph(dataset): @pytest.mark.parametrize("dataset", ALL_DATASETS) def test_metadata(dataset): M = dataset.metadata - assert M is not None @@ -310,10 +323,8 @@ def test_is_directed(dataset): @pytest.mark.parametrize("dataset", ALL_DATASETS) -def test_has_loop(dataset): - df = dataset.get_edgelist(download=True) - - assert has_loop(df) == dataset.metadata["has_loop"] +def test_has_selfloop(dataset): + assert has_selfloop(dataset) == dataset.metadata["has_loop"] @pytest.mark.parametrize("dataset", ALL_DATASETS) @@ -328,6 +339,25 @@ def test_is_multigraph(dataset): assert G.is_multigraph() == dataset.metadata["is_multigraph"] +# The datasets used for benchmarks are in their own test, since downloading them +# repeatedly would increase testing overhead significantly +@pytest.mark.parametrize("dataset", BENCHMARKING_DATASETS) +def test_benchmarking_datasets(dataset): + dataset_is_directed = dataset.metadata["is_directed"] + G = dataset.get_graph( + download=True, create_using=Graph(directed=dataset_is_directed) + ) + + assert G.is_directed() == dataset.metadata["is_directed"] + assert G.number_of_nodes() == dataset.metadata["number_of_nodes"] + assert G.number_of_edges() == dataset.metadata["number_of_edges"] + assert has_selfloop(dataset) == dataset.metadata["has_loop"] + assert is_symmetric(dataset) == dataset.metadata["is_symmetric"] + assert G.is_multigraph() == dataset.metadata["is_multigraph"] + + dataset.unload() + + @pytest.mark.parametrize("dataset", ALL_DATASETS) def test_object_getters(dataset): assert dataset.is_directed() == dataset.metadata["is_directed"] @@ -336,32 +366,3 @@ def test_object_getters(dataset): assert dataset.number_of_nodes() == dataset.metadata["number_of_nodes"] assert dataset.number_of_vertices() == dataset.metadata["number_of_nodes"] assert dataset.number_of_edges() == dataset.metadata["number_of_edges"] - - -# -# Test experimental for DeprecationWarnings -# -def test_experimental_dataset_import(setup_deprecation_warning_tests): - with pytest.deprecated_call(): - from cugraph.experimental.datasets import karate - - # unload() is called to pass flake8 - karate.unload() - - -def test_experimental_method_warnings(setup_deprecation_warning_tests): - from cugraph.experimental.datasets import ( - load_all, - set_download_dir, - get_download_dir, - ) - - warnings.filterwarnings("default") - tmpd = TemporaryDirectory() - - with pytest.deprecated_call(): - set_download_dir(tmpd.name) - get_download_dir() - load_all() - - tmpd.cleanup() diff --git a/python/cugraph/cugraph/tests/utils/test_resultset.py b/python/cugraph/cugraph/tests/utils/test_resultset.py new file mode 100644 index 00000000000..5c2298bedb7 --- /dev/null +++ b/python/cugraph/cugraph/tests/utils/test_resultset.py @@ -0,0 +1,71 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +from pathlib import Path +from tempfile import TemporaryDirectory + +import cudf +from cugraph.datasets.dataset import ( + set_download_dir, + get_download_dir, +) +from cugraph.testing.resultset import load_resultset, default_resultset_download_dir + +############################################################################### + + +def test_load_resultset(): + with TemporaryDirectory() as tmpd: + + set_download_dir(Path(tmpd)) + default_resultset_download_dir.path = Path(tmpd) / "tests" / "resultsets" + default_resultset_download_dir.path.mkdir(parents=True, exist_ok=True) + + datasets_download_dir = get_download_dir() + resultsets_download_dir = default_resultset_download_dir.path + assert "tests" in os.listdir(datasets_download_dir) + assert "resultsets.tar.gz" not in os.listdir(datasets_download_dir / "tests") + assert "traversal_mappings.csv" not in os.listdir(resultsets_download_dir) + + load_resultset( + "traversal", "https://data.rapids.ai/cugraph/results/resultsets.tar.gz" + ) + + assert "resultsets.tar.gz" in os.listdir(datasets_download_dir / "tests") + assert "traversal_mappings.csv" in os.listdir(resultsets_download_dir) + + +def test_verify_resultset_load(): + # This test is more detailed than test_load_resultset, where for each module, + # we check that every single resultset file is included along with the + # corresponding mapping file. + with TemporaryDirectory() as tmpd: + set_download_dir(Path(tmpd)) + default_resultset_download_dir.path = Path(tmpd) / "tests" / "resultsets" + default_resultset_download_dir.path.mkdir(parents=True, exist_ok=True) + + resultsets_download_dir = default_resultset_download_dir.path + + load_resultset( + "traversal", "https://data.rapids.ai/cugraph/results/resultsets.tar.gz" + ) + + resultsets = os.listdir(resultsets_download_dir) + downloaded_results = cudf.read_csv( + resultsets_download_dir / "traversal_mappings.csv", sep=" " + ) + downloaded_uuids = downloaded_results["#UUID"].values + for resultset_uuid in downloaded_uuids: + assert str(resultset_uuid) + ".csv" in resultsets
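The relocated resultset plumbing above (DefaultDownloadDir(subdir="tests/resultsets"), load_resultset, get_resultset) can also be exercised outside of pytest. The following is a minimal sketch using only calls that appear in this diff; the algo/graph_dataset keyword arguments passed to get_resultset are illustrative placeholders, since the real mapping keys are produced by generate_resultsets.py and are not shown here.

from pathlib import Path
from tempfile import TemporaryDirectory

from cugraph.datasets.dataset import set_download_dir
from cugraph.testing.resultset import (
    default_resultset_download_dir,
    get_resultset,
    load_resultset,
)

with TemporaryDirectory() as tmpd:
    # Redirect both download locations into a scratch directory, mirroring
    # the setup in test_resultset.py.
    set_download_dir(Path(tmpd))
    default_resultset_download_dir.path = Path(tmpd) / "tests" / "resultsets"

    # Fetches resultsets.tar.gz once, unpacks it under the resultset
    # directory, and fills the in-memory mapping consulted by get_resultset().
    load_resultset(
        "traversal", "https://data.rapids.ai/cugraph/results/resultsets.tar.gz"
    )

    # Hypothetical lookup: the actual key/value pairs depend on how the golden
    # results were generated (see generate_resultsets.py).
    golden = get_resultset(
        resultset_name="traversal", algo="bfs", graph_dataset="karate"
    )
    print(golden.head())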
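Likewise, the Dataset API paths exercised by the new test_reader and test_benchmarking_datasets tests can be tried directly. A small sketch, assuming the karate dataset and the default download directory (both referenced elsewhere in this diff):

import cudf
import pandas

from cugraph import Graph
from cugraph.datasets import karate

# Default reader is cudf; "pandas" is the alternative exercised by test_reader.
el_gpu = karate.get_edgelist(download=True)
assert isinstance(el_gpu, cudf.DataFrame)
karate.unload()

el_cpu = karate.get_edgelist(download=True, reader="pandas")
assert isinstance(el_cpu, pandas.DataFrame)
karate.unload()

# Graph construction honours the metadata's directedness, as in
# test_benchmarking_datasets.
G = karate.get_graph(
    download=True, create_using=Graph(directed=karate.metadata["is_directed"])
)
assert G.number_of_nodes() == karate.metadata["number_of_nodes"]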