diff --git a/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py b/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py index 319435575cc..8fed467bf6d 100644 --- a/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py +++ b/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# Copyright (c) 2021-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -35,11 +35,9 @@ from cugraph.structure.number_map import NumberMap from cugraph.structure.symmetrize import symmetrize from cugraph.dask.common.part_utils import ( - get_persisted_df_worker_map, persist_dask_df_equal_parts_per_worker, ) from cugraph.dask.common.mg_utils import run_gc_on_dask_cluster -from cugraph.dask import get_n_workers import cugraph.dask.comms.comms as Comms @@ -825,12 +823,13 @@ def get_two_hop_neighbors(self, start_vertices=None): _client = default_client() def _call_plc_two_hop_neighbors(sID, mg_graph_x, start_vertices): - return pylibcugraph_get_two_hop_neighbors( + results_ = pylibcugraph_get_two_hop_neighbors( resource_handle=ResourceHandle(Comms.get_handle(sID).getHandle()), graph=mg_graph_x, start_vertices=start_vertices, do_expensive_check=False, ) + return results_ if isinstance(start_vertices, int): start_vertices = [start_vertices] @@ -845,31 +844,31 @@ def _call_plc_two_hop_neighbors(sID, mg_graph_x, start_vertices): else: start_vertices_type = self.input_df.dtypes[0] - if not isinstance(start_vertices, (dask_cudf.Series)): - start_vertices = dask_cudf.from_cudf( + start_vertices = start_vertices.astype(start_vertices_type) + + def create_iterable_args( + session_id, input_graph, start_vertices=None, npartitions=None + ): + session_id_it = [session_id] * npartitions + graph_it = input_graph.values() + start_vertices = cp.array_split(start_vertices.values, npartitions) + return [ + session_id_it, + graph_it, start_vertices, - npartitions=min(self._npartitions, len(start_vertices)), - ) - start_vertices = start_vertices.astype(start_vertices_type) + ] - n_workers = get_n_workers() - start_vertices = start_vertices.repartition(npartitions=n_workers) - start_vertices = persist_dask_df_equal_parts_per_worker( - start_vertices, _client + result = _client.map( + _call_plc_two_hop_neighbors, + *create_iterable_args( + Comms.get_session_id(), + self._plc_graph, + start_vertices, + self._npartitions, + ), + pure=False, ) - start_vertices = get_persisted_df_worker_map(start_vertices, _client) - result = [ - _client.submit( - _call_plc_two_hop_neighbors, - Comms.get_session_id(), - self._plc_graph[w], - start_vertices[w][0], - workers=[w], - allow_other_workers=False, - ) - for w in start_vertices.keys() - ] else: result = [ _client.submit( @@ -896,7 +895,8 @@ def convert_to_cudf(cp_arrays): return df cudf_result = [ - _client.submit(convert_to_cudf, cp_arrays) for cp_arrays in result + _client.submit(convert_to_cudf, cp_arrays, pure=False) + for cp_arrays in result ] wait(cudf_result)