Commit

alexbarghi-nv committed Aug 9, 2024
1 parent d0c6920 commit 6fcb1f0
Showing 2 changed files with 36 additions and 25 deletions.
@@ -67,7 +67,6 @@ def run_test_dataloader_basic_homogeneous(rank, world_size, uid):
 
 @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
 @pytest.mark.skipif(isinstance(dgl, MissingModule), reason="dgl not available")
-@pytest.mark.skip(reason='blar')
 def test_dataloader_basic_homogeneous():
     uid = cugraph_comms_create_unique_id()
     # Limit the number of GPUs this rest is run with
@@ -83,9 +82,18 @@ def test_dataloader_basic_homogeneous():
     )
 
 
-def sample_dgl_graphs(g, train_nid, fanouts, batch_size=1, prob_attr=None,):
+def sample_dgl_graphs(
+    g,
+    train_nid,
+    fanouts,
+    batch_size=1,
+    prob_attr=None,
+):
     # Single fanout to match cugraph
-    sampler = dgl.dataloading.NeighborSampler(fanouts, prob=prob_attr,)
+    sampler = dgl.dataloading.NeighborSampler(
+        fanouts,
+        prob=prob_attr,
+    )
     dataloader = dgl.dataloading.DataLoader(
         g,
         train_nid,
@@ -106,8 +114,17 @@ def sample_dgl_graphs(g, train_nid, fanouts, batch_size=1, prob_attr=None,):
     return dgl_output
 
 
-def sample_cugraph_dgl_graphs(cugraph_g, train_nid, fanouts, batch_size=1,prob_attr=None,):
-    sampler = cugraph_dgl.dataloading.NeighborSampler(fanouts, prob=prob_attr,)
+def sample_cugraph_dgl_graphs(
+    cugraph_g,
+    train_nid,
+    fanouts,
+    batch_size=1,
+    prob_attr=None,
+):
+    sampler = cugraph_dgl.dataloading.NeighborSampler(
+        fanouts,
+        prob=prob_attr,
+    )
 
     dataloader = cugraph_dgl.dataloading.FutureDataLoader(
         cugraph_g,
@@ -170,7 +187,6 @@ def run_test_same_homogeneousgraph_results(rank, world_size, uid, ix, batch_size
 @pytest.mark.skipif(isinstance(dgl, MissingModule), reason="dgl not available")
 @pytest.mark.parametrize("ix", [[1], [1, 0]])
 @pytest.mark.parametrize("batch_size", [1, 2])
-@pytest.mark.skip(reason='blar')
 def test_same_homogeneousgraph_results_mg(ix, batch_size):
     uid = cugraph_comms_create_unique_id()
     # Limit the number of GPUs this rest is run with
@@ -188,27 +204,27 @@ def run_test_dataloader_biased_homogeneous(rank, world_size, uid):
 
     src = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) + (rank * 9)
     dst = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]) + (rank * 9)
-    wgt = torch.tensor([.1, .1, .2, 0, 0, 0, .2, .1]*world_size, dtype=torch.float32)
+    wgt = torch.tensor(
+        [0.1, 0.1, 0.2, 0, 0, 0, 0.2, 0.1] * world_size, dtype=torch.float32
+    )
 
-    train_nid = torch.tensor([0,1]) + (rank*9)
+    train_nid = torch.tensor([0, 1]) + (rank * 9)
     # Create a heterograph with 3 node types and 3 edge types.
     dgl_g = dgl.graph((src, dst))
-    dgl_g.edata['wgt'] = wgt[:8]
-
-    print(src, dst, flush=True,)
+    dgl_g.edata["wgt"] = wgt[:8]
 
     cugraph_g = cugraph_dgl.Graph(is_multi_gpu=True)
-    cugraph_g.add_nodes(9*world_size)
-    cugraph_g.add_edges(u=src, v=dst, data={'wgt': wgt})
+    cugraph_g.add_nodes(9 * world_size)
+    cugraph_g.add_edges(u=src, v=dst, data={"wgt": wgt})
 
-    dgl_output = sample_dgl_graphs(dgl_g, train_nid, [4], batch_size=2, prob_attr='wgt')
-    cugraph_output = sample_cugraph_dgl_graphs(cugraph_g, train_nid, [4], batch_size=2, prob_attr='wgt')
+    dgl_output = sample_dgl_graphs(dgl_g, train_nid, [4], batch_size=2, prob_attr="wgt")
+    cugraph_output = sample_cugraph_dgl_graphs(
+        cugraph_g, train_nid, [4], batch_size=2, prob_attr="wgt"
+    )
 
     cugraph_output_nodes = cugraph_output[0]["output_nodes"].cpu().numpy()
     dgl_output_nodes = dgl_output[0]["output_nodes"].cpu().numpy()
 
-    print(cugraph_output[0],flush=True,)
-
     np.testing.assert_array_equal(
         np.sort(cugraph_output_nodes), np.sort(dgl_output_nodes)
     )
@@ -220,7 +236,7 @@ def run_test_dataloader_biased_homogeneous(rank, world_size, uid):
         dgl_output[0]["blocks"][0].num_edges()
         == cugraph_output[0]["blocks"][0].num_edges()
     )
-
+    assert 5 == cugraph_output[0]["blocks"][0].num_edges()
 

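As a reading aid (not part of the commit), the sketch below condenses what the biased-sampling test above exercises into a single process: the same edge-weight attribute drives DGL's NeighborSampler through its prob argument and cugraph_dgl's NeighborSampler through the matching argument, and each batch's seed nodes and block edge counts are expected to agree. The single-GPU is_multi_gpu=False flag, the shuffle=False setting, and the assumption that both loaders yield (input_nodes, output_nodes, blocks) tuples are illustrative choices inferred from the test's structure, not statements from this diff.

# Hedged sketch of the biased (edge-weighted) neighbor-sampling comparison
# from run_test_dataloader_biased_homogeneous, reduced to one process.
# Assumes a CUDA-capable environment with dgl and cugraph_dgl installed.
import torch
import dgl
import cugraph_dgl

src = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
dst = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
wgt = torch.tensor([0.1, 0.1, 0.2, 0, 0, 0, 0.2, 0.1], dtype=torch.float32)
train_nid = torch.tensor([0, 1])
fanouts = [4]

# Reference DGL graph; "wgt" biases neighbor selection (0-weight edges are skipped).
dgl_g = dgl.graph((src, dst))
dgl_g.edata["wgt"] = wgt
dgl_sampler = dgl.dataloading.NeighborSampler(fanouts, prob="wgt")
dgl_loader = dgl.dataloading.DataLoader(
    dgl_g, train_nid, dgl_sampler, batch_size=2, shuffle=False
)

# cugraph_dgl graph built from the same edges (single GPU here for simplicity).
cugraph_g = cugraph_dgl.Graph(is_multi_gpu=False)
cugraph_g.add_nodes(9)
cugraph_g.add_edges(u=src, v=dst, data={"wgt": wgt})
cg_sampler = cugraph_dgl.dataloading.NeighborSampler(fanouts, prob="wgt")
cg_loader = cugraph_dgl.dataloading.FutureDataLoader(
    cugraph_g, train_nid, cg_sampler, batch_size=2
)

# Mirror the test's checks: same seed (output) nodes per batch and the same
# number of sampled edges in the first block.
for (_, dgl_out, dgl_blocks), (_, cg_out, cg_blocks) in zip(dgl_loader, cg_loader):
    assert sorted(dgl_out.tolist()) == sorted(cg_out.tolist())
    assert dgl_blocks[0].num_edges() == cg_blocks[0].num_edges()

With the weights used in the test, seed node 0 has three nonzero-weight in-edges and seed node 1 has two, so a fanout of 4 selects all five of them; that is consistent with the assert 5 == cugraph_output[0]["blocks"][0].num_edges() check the commit adds at the end of the test.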
9 changes: 2 additions & 7 deletions python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
@@ -638,11 +638,6 @@ def sample_from_nodes(
                 : len(current_seeds)
             ]
 
-            print(
-                current_seeds,
-                current_batches,
-                flush=True,
-            )
             minibatch_dict = self.sample_batches(
                 seeds=current_seeds,
                 batch_ids=current_batches,
@@ -781,8 +776,8 @@ def sample_batches(
                 label_to_output_comm_rank=cupy.asarray(label_to_output_comm_rank),
                 h_fan_out=np.array(self.__fanout, dtype="int32"),
                 with_replacement=self.__with_replacement,
-                do_expensive_check=True,
-                with_edge_properties=False,
+                do_expensive_check=False,
+                with_edge_properties=True,
                 random_state=random_state + rank,
                 prior_sources_behavior=self.__prior_sources_behavior,
                 deduplicate_sources=self.__deduplicate_sources,
