diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py
index 98273b86..801cdbe8 100644
--- a/chgnet/model/dynamics.py
+++ b/chgnet/model/dynamics.py
@@ -3,13 +3,11 @@
 import contextlib
 import inspect
 import io
-import os
 import pickle
 import sys
 from typing import TYPE_CHECKING, Literal
 
 import numpy as np
-import torch
 from ase import Atoms, units
 from ase.calculators.calculator import Calculator, all_changes, all_properties
 from ase.md.npt import NPT
@@ -28,7 +26,7 @@
 from pymatgen.io.ase import AseAtomsAdaptor
 
 from chgnet.model.model import CHGNet
-from chgnet.utils import cuda_devices_sorted_by_free_mem
+from chgnet.utils import determine_device
 
 if TYPE_CHECKING:
     from ase.io import Trajectory
@@ -58,6 +56,7 @@ def __init__(
         self,
         model: CHGNet | None = None,
         use_device: str | None = None,
+        check_cuda_mem: bool = True,
         stress_weight: float | None = 1 / 160.21766208,
         on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
         **kwargs,
@@ -72,6 +71,8 @@
                 either "cpu", "cuda", or "mps". If not specified, the default device is
                 automatically selected based on the available options.
                 Default = None
+            check_cuda_mem (bool): Whether to use the CUDA device with the most free
+                memory. Default = True
             stress_weight (float): the conversion factor to convert GPa to eV/A^3.
                 Default = 1/160.21
             on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
@@ -82,13 +83,8 @@
         super().__init__(**kwargs)
 
         # Determine the device to use
-        use_device = use_device or os.getenv("CHGNET_DEVICE")
-        if use_device in ("mps", None) and torch.backends.mps.is_available():
-            self.device = "mps"
-        else:
-            self.device = use_device or ("cuda" if torch.cuda.is_available() else "cpu")
-            if self.device == "cuda":
-                self.device = f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}"
+        device = determine_device(use_device=use_device, check_cuda_mem=check_cuda_mem)
+        self.device = device
 
         # Move the model to the specified device
         self.model = (model or CHGNet.load(verbose=False)).to(self.device)
diff --git a/chgnet/model/model.py b/chgnet/model/model.py
index 0b76612b..c3fde070 100644
--- a/chgnet/model/model.py
+++ b/chgnet/model/model.py
@@ -22,7 +22,7 @@
     GraphAttentionReadOut,
     GraphPooling,
 )
-from chgnet.utils import cuda_devices_sorted_by_free_mem
+from chgnet.utils import determine_device
 
 if TYPE_CHECKING:
     from chgnet import PredTask
@@ -673,18 +673,21 @@
     def load(
         cls,
         model_name="0.3.0",
         use_device: str | None = None,
+        check_cuda_mem: bool = True,
         verbose: bool = True,
     ) -> CHGNet:
         """Load pretrained CHGNet model.
 
         Args:
             model_name (str): Default = "0.3.0".
             use_device (str): The device to be used for predictions,
                 either "cpu", "cuda", or "mps". If not specified, the default device is
                 automatically selected based on the available options.
                 Default = None
+            check_cuda_mem (bool): Whether to use the CUDA device with the most free
+                memory. Default = True
             verbose (bool): whether to print model device information
                 Default = True
 
         Raises:
@@ -707,13 +710,7 @@
             )
 
         # Determine the device to use
-        use_device = use_device or os.getenv("CHGNET_DEVICE")
-        if use_device in ("mps", None) and torch.backends.mps.is_available():
-            device = "mps"
-        else:
-            device = use_device or ("cuda" if torch.cuda.is_available() else "cpu")
-            if device == "cuda":
-                device = f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}"
+        device = determine_device(use_device=use_device, check_cuda_mem=check_cuda_mem)
 
         # Move the model to the specified device
         model = model.to(device)
diff --git a/chgnet/utils/__init__.py b/chgnet/utils/__init__.py
index f1de9229..d208d29b 100644
--- a/chgnet/utils/__init__.py
+++ b/chgnet/utils/__init__.py
@@ -3,6 +3,7 @@
 from chgnet.utils.common_utils import (
     AverageMeter,
     cuda_devices_sorted_by_free_mem,
+    determine_device,
     mae,
     mkdir,
     read_json,
diff --git a/chgnet/utils/common_utils.py b/chgnet/utils/common_utils.py
index dd05ac29..b30655ee 100644
--- a/chgnet/utils/common_utils.py
+++ b/chgnet/utils/common_utils.py
@@ -8,6 +8,30 @@
 from torch import Tensor
 
 
+def determine_device(
+    use_device: str | None = None,
+    check_cuda_mem: bool = True,
+) -> str:
+    """Determine the device to use for a torch model.
+
+    Args:
+        use_device (str): User-specified device name.
+        check_cuda_mem (bool): Whether to use the CUDA device with most free memory.
+
+    Returns:
+        device (str): device name to be passed to model.to(device)
+    """
+    use_device = use_device or os.getenv("CHGNET_DEVICE")
+    if use_device in ("mps", None) and torch.backends.mps.is_available():
+        device = "mps"
+    else:
+        device = use_device or ("cuda" if torch.cuda.is_available() else "cpu")
+        if device == "cuda" and check_cuda_mem:
+            device = f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}"
+
+    return device
+
+
 def cuda_devices_sorted_by_free_mem() -> list[int]:
     """List available CUDA devices sorted by increasing available memory.
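
The patch replaces two duplicated copies of the device-selection logic with the single determine_device helper and threads the new check_cuda_mem flag through both public entry points. A minimal usage sketch follows, assuming the patch above is applied and chgnet is installed with a CUDA-enabled torch build; the device strings in the comments are illustrative, not output from a real run:

    from chgnet.model.dynamics import CHGNetCalculator
    from chgnet.model.model import CHGNet
    from chgnet.utils import determine_device

    # Resolve a device string the same way CHGNet.load() and CHGNetCalculator()
    # now do internally. With check_cuda_mem=True (the default) and CUDA
    # available, this returns an indexed device such as "cuda:1" -- the GPU
    # with the most free memory; otherwise it falls back to "mps" or "cpu".
    device = determine_device()

    # Passing check_cuda_mem=False skips the free-memory query and keeps a
    # plain "cuda" (torch's current default GPU), which is sufficient on
    # single-GPU machines.
    model = CHGNet.load(use_device="cuda", check_cuda_mem=False, verbose=False)
    calc = CHGNetCalculator(model=model, use_device="cuda", check_cuda_mem=False)

Note that determine_device still honors the CHGNET_DEVICE environment variable: when use_device is not given, e.g. CHGNET_DEVICE=cpu overrides the automatic selection, exactly as the two removed inline code paths did.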