diff --git a/darts/src/darts/utils/cuda.py b/darts/src/darts/utils/cuda.py index 731e32e..a8065f4 100644 --- a/darts/src/darts/utils/cuda.py +++ b/darts/src/darts/utils/cuda.py @@ -73,12 +73,21 @@ def debug_info(): try: import numba.cuda - logger.debug(f"Numba CUDA runtime: {numba.cuda.runtime.get_version()}") - logger.debug(f"Numba CUDA is available: {numba.cuda.is_available()}") - logger.debug(f"Numba CUDA has supported devices: {numba.cuda.detect()}") + cuda_available = numba.cuda.is_available() + logger.debug(f"Numba CUDA is available: {cuda_available}") + if cuda_available: + logger.debug(f"Numba CUDA runtime: {numba.cuda.runtime.get_version()}") + # logger.debug(f"Numba CUDA has supported devices: {numba.cuda.detect()}") except ImportError: logger.debug("Module 'numba.cuda' not found, darts is probably installed without CUDA support.") + try: + import cucim + + logger.debug(f"Cucim version: {cucim.__version__}") + except ImportError: + logger.debug("Module 'cucim' not found, darts is probably installed without CUDA support.") + def decide_device(device: Literal["cuda", "cpu", "auto"] | int | None) -> Literal["cuda", "cpu"] | int: """Decide the device based on the input. diff --git a/pyproject.toml b/pyproject.toml index 0963f15..ba4681c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ # Utils "lovely-tensors>=0.1.17", "lovely-numpy>=0.2.13", + "setuptools>=75.5.0", # IO "h5netcdf>=1.3.0", "rasterio>=1.4.0", @@ -28,11 +29,16 @@ dependencies = [ "zarr[jupyter]>=2.18.3", # Training and Inference "segmentation-models-pytorch>=0.3.4", + "lightning>=2.4.0", + "albumentations>=1.4.21", + "wandb>=0.18.7", + "torchmetrics>=1.6.0", # Processing "scikit-image>=0.20", "scipy>=1.14.1", "xarray-spatial>=0.4.0", "dask>=2024.11.0", + "distributed>=2024.12.0", "geocube>=0.7.0", # Visualization "cartopy>=0.24.1", @@ -42,13 +48,7 @@ dependencies = [ "folium>=0.18.0", "bokeh>=3.5.2", "jupyter-bokeh>=4.0.5", - "setuptools>=75.5.0", - "lightning>=2.4.0", - "albumentations>=1.4.21", - "wandb>=0.18.7", - "torchmetrics>=1.6.0", "seaborn>=0.13.2", - "distributed>=2024.12.0", ] readme = "README.md" requires-python = ">= 3.11"