Skip to content

Commit

Permalink
Add numba-cuda debug info
Browse files (browse the repository at this point in the history)
  • Loading branch information
relativityhd committed Dec 19, 2024
1 parent 81cf102 commit 5174ba4
Showing 1 changed file with 34 additions and 11 deletions.
45 changes: 34 additions & 11 deletions darts/src/darts/utils/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,16 @@

def debug_info():
"""Print debug information about the CUDA devices and library installations."""
import os

import torch
from xrspatial.utils import has_cuda_and_cupy

logger.debug("=== CUDA DEBUG INFO ===")
logger.debug(f"PyTorch version: {torch.__version__}")
logger.debug(f"PyTorch CUDA available: {torch.cuda.is_available()}")

logger.debug(f"Cupy+Numba CUDA available: {has_cuda_and_cupy()}")

try:
import cupy

logger.debug(f"Cupy version: {cupy.__version__}")
logger.debug(f"Cupy CUDA version: {cupy.cuda.get_local_runtime_version()}")
logger.debug(f"Cupy CUDA driver version: {cupy.cuda.runtime.driverGetVersion()}")
logger.debug(f"Cupy CUDA runtime version: {cupy.cuda.runtime.runtimeGetVersion()}")
except ImportError:
logger.debug("Module 'cupy' not found, darts is probably installed without CUDA support.")
logger.debug(f"LD_LIBRARY_PATH: {os.environ.get("LD_LIBRARY_PATH")}")

try:
from pynvml import (
Expand Down Expand Up @@ -56,6 +49,36 @@ def debug_info():
except ImportError:
logger.debug("Module 'pynvml' not found, darts is probably installed without CUDA support.")

try:
import cupy

logger.debug(f"Cupy version: {cupy.__version__}")
# This is the version which is installed (dynamically linked via PATH or LD_LIBRARY_PATH) in the environment
env_runtime_version = cupy.cuda.get_local_runtime_version()
# This is the version which is used by cupy (statically linked)
cupy_runtime_version = cupy.cuda.runtime.runtimeGetVersion()
if env_runtime_version != cupy_runtime_version:
logger.warning(
"Cupy CUDA runtime versions don't match!\n"
f"Got {env_runtime_version} as local (dynamically linked) runtime version.\n"
f"Got {cupy_runtime_version} as by cupy statically linked runtime version.\n"
"Cupy will use the statically linked runtime version!"
)
else:
logger.debug(f"Cupy CUDA runtime version: {cupy_runtime_version}")
logger.debug(f"Cupy CUDA driver version: {cupy.cuda.runtime.driverGetVersion()}")
except ImportError:
logger.debug("Module 'cupy' not found, darts is probably installed without CUDA support.")

try:
import numba.cuda

logger.debug(f"Numba CUDA runtime: {numba.cuda.runtime.get_version()}")
logger.debug(f"Numba CUDA is available: {numba.cuda.is_available()}")
logger.debug(f"Numba CUDA has supported devices: {numba.cuda.detect()}")
except ImportError:
logger.debug("Module 'numba.cuda' not found, darts is probably installed without CUDA support.")


def decide_device(device: Literal["cuda", "cpu", "auto"] | int | None) -> Literal["cuda", "cpu"] | int:
"""Decide the device based on the input.
Expand Down

0 comments on commit 5174ba4

Please sign in to comment.