Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Dec 2, 2024
1 parent d7a05a5 commit 7b7c648
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 22 deletions.
66 changes: 44 additions & 22 deletions python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,14 +631,20 @@ def __init__(
with_replacement: bool = False,
biased: bool = False,
heterogeneous: bool = False,
vertex_type_offsets: Optional[TensorType] = None,
num_edge_types: int = 1,
):

self.__fanout = fanout
self.__prior_sources_behavior = prior_sources_behavior
self.__deduplicate_sources = deduplicate_sources
self.__compress_per_hop = compress_per_hop
self.__compression = compression
self.__with_replacement = with_replacement
self.__func_kwargs = {
"h_fan_out": np.asarray(fanout, dtype="int32"),
"prior_sources_behavior": prior_sources_behavior,
"retain_seeds": retain_original_seeds,
"deduplicate_sources": deduplicate_sources,
"compress_per_hop": compress_per_hop,
"compression": compression,
"with_replacement": with_replacement,
}

# It is currently required that graphs are weighted for biased
# sampling. So setting the function here is safe. In the future,
Expand All @@ -647,11 +653,17 @@ def __init__(
# TODO allow func to be a call to a future remote sampling API
# if the provided graph is in another process (rapidsai/cugraph#4623).
if heterogeneous:
if vertex_type_offsets is None:
raise ValueError("Heterogeneous sampling requires vertex type offsets.")
self.__func = (
pylibcugraph.heterogeneous_biased_neighbor_sample
if biased
else pylibcugraph.heterogeneous_uniform_neighbor_sample
)
self.__func_kwargs["num_edge_types"] = num_edge_types
self.__func_kwargs["vertex_type_offsets"] = cupy.asarray(
vertex_type_offsets
)
else:
self.__func = (
pylibcugraph.homogeneous_biased_neighbor_sample
Expand All @@ -661,27 +673,45 @@ def __init__(

if num_edge_types > 1 and not heterogeneous:
raise ValueError(
"Heterogeneous sampling must be selected if there is > 1 edge type"
"Heterogeneous sampling must be selected if there is > 1 edge type."
)

self.__num_edge_types = num_edge_types

super().__init__(
graph,
writer,
local_seeds_per_call=self.__calc_local_seeds_per_call(local_seeds_per_call),
local_seeds_per_call=self.__calc_local_seeds_per_call(
local_seeds_per_call,
heterogeneous=heterogeneous,
num_edge_types=num_edge_types,
),
retain_original_seeds=retain_original_seeds,
)

def __calc_local_seeds_per_call(self, local_seeds_per_call: Optional[int] = None):
def __calc_local_seeds_per_call(
self,
local_seeds_per_call: Optional[int] = None,
heterogeneous: bool = False,
num_edge_types: int = 1,
):
torch = import_optional("torch")

fanout = self.__fanout

if local_seeds_per_call is None:
if len([x for x in self.__fanout if x <= 0]) > 0:
if len([x for x in fanout if x <= 0]) > 0:
return NeighborSampler.UNKNOWN_VERTICES_DEFAULT

if heterogeneous:
if len(fanout) % num_edge_types != 0:
raise ValueError(f"Illegal fanout for {num_edge_types} edge types.")
num_hops = len(fanout) // num_edge_types
fanout = [
sum([fanout[t * num_hops + h] for t in range(num_edge_types)])
for h in range(num_hops)
]

total_memory = torch.cuda.get_device_properties(0).total_memory
fanout_prod = reduce(lambda x, y: x * y, self.__fanout)
fanout_prod = reduce(lambda x, y: x * y, fanout)
return int(
NeighborSampler.BASE_VERTICES_PER_BYTE * total_memory / fanout_prod
)
Expand All @@ -702,25 +732,17 @@ def sample_batches(
"input_graph": self._graph,
"start_vertex_list": cupy.asarray(seeds),
"starting_vertex_label_offsets": cupy.asarray(batch_id_offsets),
"h_fan_out": np.asarray(self.__fanout, dtype="int32"),
"with_replacement": self.__with_replacement,
"renumber": True,
"return_hops": True,
"do_expensive_check": False,
"prior_sources_behavior": self.__prior_sources_behavior,
"deduplicate_sources": self.__deduplicate_sources,
"retain_seeds": self._retain_original_seeds,
"compression": self.__compression,
"compress_per_hop": self.__compress_per_hop,
"random_state": random_state + rank,
}

if self.__num_edge_types > 1:
kwargs["num_edge_types"] = self.__num_edge_types
kwargs.update(self.__func_kwargs)

print(kwargs)

sampling_results_dict = self.__func(**kwargs)
print(sampling_results_dict)

sampling_results_dict["fanout"] = cupy.array(self.__fanout, dtype="int32")
return sampling_results_dict
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ cdef extern from "cugraph_c/sampling_algorithms.h":
cugraph_graph_t* graph,
const cugraph_type_erased_device_array_view_t* start_vertices,
const cugraph_type_erased_device_array_view_t* starting_vertex_label_offsets,
const cugraph_type_erased_device_array_view_t* vertex_type_offsets,
const cugraph_type_erased_host_array_view_t* fan_out,
int num_edge_types,
const cugraph_sampling_options_t* options,
Expand All @@ -88,6 +89,7 @@ cdef extern from "cugraph_c/sampling_algorithms.h":
const cugraph_edge_property_view_t* edge_biases,
const cugraph_type_erased_device_array_view_t* start_vertices,
const cugraph_type_erased_device_array_view_t* starting_vertex_label_offsets,
const cugraph_type_erased_device_array_view_t* vertex_type_offsets,
const cugraph_type_erased_host_array_view_t* fan_out,
int num_edge_types,
const cugraph_sampling_options_t* options,
Expand Down

0 comments on commit 7b7c648

Please sign in to comment.