diff --git a/python/cugraph/community/__init__.py b/python/cugraph/community/__init__.py index e8bea0cbaa0..9cc92637e20 100644 --- a/python/cugraph/community/__init__.py +++ b/python/cugraph/community/__init__.py @@ -23,42 +23,7 @@ ) from cugraph.community.subgraph_extraction import subgraph from cugraph.community.triangle_count import triangles +from cugraph.community.ktruss_subgraph import ktruss_subgraph +from cugraph.community.ktruss_subgraph import k_truss from cugraph.community.egonet import ego_graph from cugraph.community.egonet import batched_ego_graphs - -# FIXME: special case for ktruss on CUDA 11.4: an 11.4 bug causes ktruss to -# crash in that environment. Allow ktruss to import on non-11.4 systems, but -# replace ktruss with a __UnsupportedModule instance, which lazily raises an -# exception when referenced. -from numba import cuda -try: - __cuda_version = cuda.runtime.get_version() -except cuda.cudadrv.runtime.CudaRuntimeAPIError: - __cuda_version = "n/a" - -__ktruss_unsupported_cuda_version = (11, 4) - -class __UnsupportedModule: - def __init__(self, exception): - self.__exception = exception - - def __getattr__(self, attr): - raise self.__exception - - def __call__(self, *args, **kwargs): - raise self.__exception - - -if __cuda_version != __ktruss_unsupported_cuda_version: - from cugraph.community.ktruss_subgraph import ktruss_subgraph - from cugraph.community.ktruss_subgraph import k_truss -else: - __kuvs = ".".join([str(n) for n in __ktruss_unsupported_cuda_version]) - k_truss = __UnsupportedModule( - NotImplementedError("k_truss is not currently supported in CUDA" - f" {__kuvs} environments.") - ) - ktruss_subgraph = __UnsupportedModule( - NotImplementedError("ktruss_subgraph is not currently supported in CUDA" - f" {__kuvs} environments.") - ) diff --git a/python/cugraph/community/ktruss_subgraph.py b/python/cugraph/community/ktruss_subgraph.py index c80f65c1608..afa7d66d31d 100644 --- a/python/cugraph/community/ktruss_subgraph.py +++ b/python/cugraph/community/ktruss_subgraph.py @@ -16,21 +16,24 @@ from cugraph.utilities import check_nx_graph from cugraph.utilities import cugraph_to_nx +from numba import cuda + + # FIXME: special case for ktruss on CUDA 11.4: an 11.4 bug causes ktruss to # crash in that environment. Allow ktruss to import on non-11.4 systems, but # raise an exception if ktruss is directly imported on 11.4. -from numba import cuda -try: - __cuda_version = cuda.runtime.get_version() -except cuda.cudadrv.runtime.CudaRuntimeAPIError: - __cuda_version = "n/a" +def _ensure_compatible_cuda_version(): + try: + cuda_version = cuda.runtime.get_version() + except cuda.cudadrv.runtime.CudaRuntimeAPIError: + cuda_version = "n/a" -__ktruss_unsupported_cuda_version = (11, 4) + unsupported_cuda_version = (11, 4) -if __cuda_version == __ktruss_unsupported_cuda_version: - __kuvs = ".".join([str(n) for n in __ktruss_unsupported_cuda_version]) - raise NotImplementedError("k_truss is not currently supported in CUDA" - f" {__kuvs} environments.") + if cuda_version == unsupported_cuda_version: + ver_string = ".".join([str(n) for n in unsupported_cuda_version]) + raise NotImplementedError("k_truss is not currently supported in CUDA" + f" {ver_string} environments.") def k_truss(G, k): @@ -62,6 +65,8 @@ def k_truss(G, k): The networkx graph will NOT have all attributes copied over """ + _ensure_compatible_cuda_version() + G, isNx = check_nx_graph(G) if isNx is True: @@ -137,6 +142,8 @@ def ktruss_subgraph(G, k, use_weights=True): >>> k_subgraph = cugraph.ktruss_subgraph(G, 3) """ + _ensure_compatible_cuda_version() + KTrussSubgraph = Graph() if type(G) is not Graph: raise Exception("input graph must be undirected")