Skip to content

Commit

Permalink
back porting ktruss CUDA 11.4 guard to 21.08 (#1813)
Browse files Browse the repository at this point in the history
  • Loading branch information
rlratzel authored Sep 16, 2021
1 parent b681f73 commit fa6f0f1
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 47 deletions.
39 changes: 2 additions & 37 deletions python/cugraph/community/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,42 +23,7 @@
)
from cugraph.community.subgraph_extraction import subgraph
from cugraph.community.triangle_count import triangles
from cugraph.community.ktruss_subgraph import ktruss_subgraph
from cugraph.community.ktruss_subgraph import k_truss
from cugraph.community.egonet import ego_graph
from cugraph.community.egonet import batched_ego_graphs

# FIXME: special case for ktruss on CUDA 11.4: an 11.4 bug causes ktruss to
# crash in that environment. Allow ktruss to import on non-11.4 systems, but
# replace ktruss with a __UnsupportedModule instance, which lazily raises an
# exception when referenced.
from numba import cuda
try:
__cuda_version = cuda.runtime.get_version()
except cuda.cudadrv.runtime.CudaRuntimeAPIError:
__cuda_version = "n/a"

__ktruss_unsupported_cuda_version = (11, 4)

class __UnsupportedModule:
def __init__(self, exception):
self.__exception = exception

def __getattr__(self, attr):
raise self.__exception

def __call__(self, *args, **kwargs):
raise self.__exception


if __cuda_version != __ktruss_unsupported_cuda_version:
from cugraph.community.ktruss_subgraph import ktruss_subgraph
from cugraph.community.ktruss_subgraph import k_truss
else:
__kuvs = ".".join([str(n) for n in __ktruss_unsupported_cuda_version])
k_truss = __UnsupportedModule(
NotImplementedError("k_truss is not currently supported in CUDA"
f" {__kuvs} environments.")
)
ktruss_subgraph = __UnsupportedModule(
NotImplementedError("ktruss_subgraph is not currently supported in CUDA"
f" {__kuvs} environments.")
)
27 changes: 17 additions & 10 deletions python/cugraph/community/ktruss_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,24 @@
from cugraph.utilities import check_nx_graph
from cugraph.utilities import cugraph_to_nx

from numba import cuda


# FIXME: special case for ktruss on CUDA 11.4: an 11.4 bug causes ktruss to
# crash in that environment. Allow ktruss to import on non-11.4 systems, but
# raise an exception if ktruss is directly imported on 11.4.
from numba import cuda
try:
__cuda_version = cuda.runtime.get_version()
except cuda.cudadrv.runtime.CudaRuntimeAPIError:
__cuda_version = "n/a"
def _ensure_compatible_cuda_version():
try:
cuda_version = cuda.runtime.get_version()
except cuda.cudadrv.runtime.CudaRuntimeAPIError:
cuda_version = "n/a"

__ktruss_unsupported_cuda_version = (11, 4)
unsupported_cuda_version = (11, 4)

if __cuda_version == __ktruss_unsupported_cuda_version:
__kuvs = ".".join([str(n) for n in __ktruss_unsupported_cuda_version])
raise NotImplementedError("k_truss is not currently supported in CUDA"
f" {__kuvs} environments.")
if cuda_version == unsupported_cuda_version:
ver_string = ".".join([str(n) for n in unsupported_cuda_version])
raise NotImplementedError("k_truss is not currently supported in CUDA"
f" {ver_string} environments.")


def k_truss(G, k):
Expand Down Expand Up @@ -62,6 +65,8 @@ def k_truss(G, k):
The networkx graph will NOT have all attributes copied over
"""

_ensure_compatible_cuda_version()

G, isNx = check_nx_graph(G)

if isNx is True:
Expand Down Expand Up @@ -137,6 +142,8 @@ def ktruss_subgraph(G, k, use_weights=True):
>>> k_subgraph = cugraph.ktruss_subgraph(G, 3)
"""

_ensure_compatible_cuda_version()

KTrussSubgraph = Graph()
if type(G) is not Graph:
raise Exception("input graph must be undirected")
Expand Down

0 comments on commit fa6f0f1

Please sign in to comment.