From 5174ba4842dc88bef2953b859016cfded7d9499c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Thu, 19 Dec 2024 13:53:24 +0100 Subject: [PATCH] Add numba-cuda debug info --- darts/src/darts/utils/cuda.py | 45 ++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/darts/src/darts/utils/cuda.py b/darts/src/darts/utils/cuda.py index 920ff89..820f7a7 100644 --- a/darts/src/darts/utils/cuda.py +++ b/darts/src/darts/utils/cuda.py @@ -8,23 +8,16 @@ def debug_info(): """Print debug information about the CUDA devices and library installations.""" + import os + import torch from xrspatial.utils import has_cuda_and_cupy + logger.debug("=== CUDA DEBUG INFO ===") logger.debug(f"PyTorch version: {torch.__version__}") logger.debug(f"PyTorch CUDA available: {torch.cuda.is_available()}") - logger.debug(f"Cupy+Numba CUDA available: {has_cuda_and_cupy()}") - - try: - import cupy - - logger.debug(f"Cupy version: {cupy.__version__}") - logger.debug(f"Cupy CUDA version: {cupy.cuda.get_local_runtime_version()}") - logger.debug(f"Cupy CUDA driver version: {cupy.cuda.runtime.driverGetVersion()}") - logger.debug(f"Cupy CUDA runtime version: {cupy.cuda.runtime.runtimeGetVersion()}") - except ImportError: - logger.debug("Module 'cupy' not found, darts is probably installed without CUDA support.") + logger.debug(f"LD_LIBRARY_PATH: {os.environ.get("LD_LIBRARY_PATH")}") try: from pynvml import ( @@ -56,6 +49,36 @@ def debug_info(): except ImportError: logger.debug("Module 'pynvml' not found, darts is probably installed without CUDA support.") + try: + import cupy + + logger.debug(f"Cupy version: {cupy.__version__}") + # This is the version which is installed (dynamically linked via PATH or LD_LIBRARY_PATH) in the environment + env_runtime_version = cupy.cuda.get_local_runtime_version() + # This is the version which is used by cupy (statically linked) + cupy_runtime_version = cupy.cuda.runtime.runtimeGetVersion() + if env_runtime_version != cupy_runtime_version: + logger.warning( + "Cupy CUDA runtime versions don't match!\n" + f"Got {env_runtime_version} as local (dynamically linked) runtime version.\n" + f"Got {cupy_runtime_version} as by cupy statically linked runtime version.\n" + "Cupy will use the statically linked runtime version!" + ) + else: + logger.debug(f"Cupy CUDA runtime version: {cupy_runtime_version}") + logger.debug(f"Cupy CUDA driver version: {cupy.cuda.runtime.driverGetVersion()}") + except ImportError: + logger.debug("Module 'cupy' not found, darts is probably installed without CUDA support.") + + try: + import numba.cuda + + logger.debug(f"Numba CUDA runtime: {numba.cuda.runtime.get_version()}") + logger.debug(f"Numba CUDA is available: {numba.cuda.is_available()}") + logger.debug(f"Numba CUDA has supported devices: {numba.cuda.detect()}") + except ImportError: + logger.debug("Module 'numba.cuda' not found, darts is probably installed without CUDA support.") + def decide_device(device: Literal["cuda", "cpu", "auto"] | int | None) -> Literal["cuda", "cpu"] | int: """Decide the device based on the input.