Skip to content

Commit

Permalink
support heterogeneous graph
Browse files: browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Aug 7, 2024
1 parent 29be85c commit 270541a
Showing 1 changed file with 50 additions and 49 deletions.
99 changes: 50 additions & 49 deletions python/cugraph-dgl/cugraph_dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def _vertex_offsets(self) -> Dict[str, int]:

return dict(self.__vertex_offsets)

def __get_edgelist(self) -> Dict[str, "torch.Tensor"]:
def __get_edgelist(self, prob_attr=None) -> Dict[str, "torch.Tensor"]:
"""
This function always returns src/dst labels with respect
to the out direction.
Expand Down Expand Up @@ -431,63 +431,69 @@ def __get_edgelist(self) -> Dict[str, "torch.Tensor"]:
)
)

num_edges_t = torch.tensor(
[self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda"
)

if self.is_multi_gpu:
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

num_edges_t = torch.tensor(
[self.__edge_indices[et].shape[1] for et in sorted_keys], device="cuda"
)
num_edges_all_t = torch.empty(
world_size, num_edges_t.numel(), dtype=torch.int64, device="cuda"
)
torch.distributed.all_gather_into_tensor(num_edges_all_t, num_edges_t)

if rank > 0:
start_offsets = num_edges_all_t[:rank].T.sum(axis=1)
edge_id_array = torch.concat(
start_offsets = num_edges_all_t[:rank].T.sum(axis=1)

else:
rank = 0
start_offsets = torch.zeros(
(len(sorted_keys),), dtype=torch.int64, device="cuda"
)
num_edges_all_t = num_edges_t.reshape((1, num_edges_t.numel()))

# Use pinned memory here for fast access to CPU/WG storage
edge_id_array_per_type = [
torch.arange(
start_offsets[i],
start_offsets[i] + num_edges_all_t[rank][i],
dtype=torch.int64,
device="cpu",
).pin_memory()
for i in range(len(sorted_keys))
]

# Retrieve the weights from the appropriate feature(s)
# DGL implicitly requires all edge types use the same
# feature name.
if prob_attr is None:
weights = None
else:
if len(sorted_keys) > 1:
weights = torch.concat(
[
torch.arange(
start_offsets[i],
start_offsets[i] + num_edges_all_t[rank][i],
dtype=torch.int64,
device="cuda",
)
for i in range(len(sorted_keys))
self.edata[prob_attr][sorted_keys[i]][ix]
for i, ix in enumerate(edge_id_array_per_type)
]
)
else:
edge_id_array = torch.concat(
[
torch.arange(
self.__edge_indices[et].shape[1],
dtype=torch.int64,
device="cuda",
)
for et in sorted_keys
]
)
weights = self.edata[prob_attr][edge_id_array_per_type[0]]

else:
# single GPU
edge_id_array = torch.concat(
[
torch.arange(
self.__edge_indices[et].shape[1],
dtype=torch.int64,
device="cuda",
)
for et in sorted_keys
]
)
edge_id_array = torch.concat(edge_id_array_per_type).cuda()

return {
edgelist_dict = {
"src": edge_index[0],
"dst": edge_index[1],
"etp": edge_type_array,
"eid": edge_id_array,
}

if weights is not None:
edgelist_dict["wgt"] = weights

return edgelist_dict

@property
def is_homogeneous(self):
# The graph is reported as homogeneous when both the edge-count and the
# node-count dictionaries hold at most one entry (presumably one entry per
# edge type / node type -- confirm against where these dicts are built).
# Empty dictionaries also satisfy the check.
return len(self.__num_edges_dict) <= 1 and len(self.__num_nodes_dict) <= 1
Expand Down Expand Up @@ -533,16 +539,7 @@ def _graph(

if self.__graph is None:
src_col, dst_col = ("src", "dst") if direction == "out" else ("dst", "src")
edgelist_dict = self.__get_edgelist()

# FIXME this is invalid for a heterogeneous graph
# FIXME this should be part of the edgelist
weights = (
None
if prob_attr is None
else cupy.asarray(self.edata[prob_attr][edgelist_dict["eid"].cpu()])
)
weights = cupy.array([1, 1, 2, 0, 0, 0, 2, 1], dtype="float32")
edgelist_dict = self.__get_edgelist(prob_attr=prob_attr)

if self.is_multi_gpu:
rank = torch.distributed.get_rank()
Expand All @@ -559,7 +556,9 @@ def _graph(
vertices_array=[vertices_array],
edge_id_array=[cupy.asarray(edgelist_dict["eid"])],
edge_type_array=[cupy.asarray(edgelist_dict["etp"])],
weight_array=[weights],
weight_array=[cupy.asarray(edgelist_dict["wgt"])]
if "wgt" in edgelist_dict
else None,
)
else:
graph = pylibcugraph.SGGraph(
Expand All @@ -570,7 +569,9 @@ def _graph(
vertices_array=cupy.arange(self.num_nodes(), dtype="int64"),
edge_id_array=cupy.asarray(edgelist_dict["eid"]),
edge_type_array=cupy.asarray(edgelist_dict["etp"]),
weight_array=weights,
weight_array=cupy.asarray(edgelist_dict["wgt"])
if "wgt" in edgelist_dict
else None,
)

self.__graph = {"graph": graph, "direction": direction, "prob_attr": prob_attr}
Expand Down

0 comments on commit 270541a

Please sign in to comment.