Skip to content

Commit

Permalink
Add CI to cugraph_dgl (#3312)
Browse files Browse the repository at this point in the history
This PR replaces #3293 .

Authors:
  - Vibhu Jawa (https://github.com/VibhuJawa)

Approvers:
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - AJ Schmidt (https://github.com/ajschmidt8)
  - Rick Ratzel (https://github.com/rlratzel)

URL: #3312
  • Loading branch information
VibhuJawa authored Mar 22, 2023
1 parent 393db1d commit d98c366
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 2 deletions.
54 changes: 54 additions & 0 deletions ci/test_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,60 @@ pytest \
cugraph/pytest-based/bench_algos.py
popd

# Run the cugraph-dgl test suite, but only on a single matrix entry:
# x86_64 + CUDA 11.8. All other entries log a skip message and move on.
if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then
if [[ "${RUNNER_ARCH}" != "ARM64" ]]; then
# we are only testing in a single cuda version
# because of pytorch and rapids compatibility problems
# (NOTE(review): ARM64 is also skipped — presumably the pytorch/dgl
# conda packages below are not published for that arch; confirm)
rapids-mamba-retry env create --force -f env.yaml -n test_cugraph_dgl

# activate test_cugraph_dgl environment for dgl
# `conda activate` references variables that may be unset, so nounset
# is disabled around it and restored immediately after
set +u
conda activate test_cugraph_dgl
set -u
# Install the freshly built RAPIDS packages plus the pytorch/dgl stack.
# Channel order matters: the CI-built CPP/PYTHON channels come first so
# the artifacts under test win over published packages.
rapids-mamba-retry install \
--channel "${CPP_CHANNEL}" \
--channel "${PYTHON_CHANNEL}" \
--channel pytorch \
--channel pytorch-nightly \
--channel dglteam/label/cu117 \
--channel nvidia \
libcugraph \
pylibcugraph \
cugraph \
cugraph-dgl \
'dgl>=1.0' \
'pytorch>=2.0' \
'pytorch-cuda>=11.8'

# Dump the resolved environment into the CI log for debugging
rapids-print-env

rapids-logger "pytest cugraph_dgl (single GPU)"
pushd python/cugraph-dgl/tests
# `mg` and `nn` subdirectories are excluded from this single-GPU run
# (NOTE(review): `mg` looks like multi-GPU tests and `nn` like tests
# needing extra deps — confirm against the cugraph-dgl test layout)
pytest \
--cache-clear \
--ignore=mg \
--ignore=nn \
--junitxml="${RAPIDS_TESTS_DIR}/junit-cugraph-dgl.xml" \
--cov-config=../../.coveragerc \
--cov=cugraph_dgl \
--cov-report=xml:"${RAPIDS_COVERAGE_DIR}/cugraph-dgl-coverage.xml" \
--cov-report=term \
.
popd

# Reactivate the test environment back
# (same nounset dance as above around the conda transition)
set +u
conda deactivate
conda activate test
set -u
else
rapids-logger "skipping cugraph_dgl pytest on ARM64"
fi
else
rapids-logger "skipping cugraph_dgl pytest on CUDA!=11.8"
fi


rapids-logger "pytest cugraph_pyg (single GPU)"
pushd python/cugraph-pyg/cugraph_pyg
# rmat is not tested because of multi-GPU testing
Expand Down
1 change: 0 additions & 1 deletion conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ dependencies:
- notebook>=0.5.0
- numpydoc
- nvcc_linux-64=11.8
- ogb
- openmpi
- pip
- pre-commit
Expand Down
1 change: 0 additions & 1 deletion dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ dependencies:
common:
- output_types: [conda, requirements]
packages:
- ogb
- py
- pytest
- pytest-cov
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@


def cast_to_tensor(ser: cudf.Series):
    """Convert a cudf ``Series`` into a torch tensor living on the GPU.

    The non-empty path hands the underlying device array straight to
    ``torch.as_tensor``. An empty series cannot be converted to a pytorch
    cuda tensor directly, so it is round-tripped through a host-side
    numpy copy and then moved back onto the device.
    """
    if len(ser):
        return torch.as_tensor(ser.values, device="cuda")
    # Empty series: ``.get()`` pulls the (zero-length) device array to
    # numpy, which torch can wrap before relocating to the GPU.
    return torch.from_numpy(ser.values.get()).to("cuda")


Expand Down
25 changes: 25 additions & 0 deletions python/cugraph-dgl/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cudf
import cupy as cp
import torch

from cugraph_dgl.dataloading.utils.sampling_helpers import cast_to_tensor


def test_casting_empty_array():
    """cast_to_tensor must accept a zero-length series and keep its dtype."""
    empty_series = cudf.Series(cp.zeros(shape=0, dtype=cp.int32))
    tensor = cast_to_tensor(empty_series)
    assert tensor.dtype == torch.int32

0 comments on commit d98c366

Please sign in to comment.