From 0b90f0cd6a6cef03229ece4797dc3a7b4b6ca51b Mon Sep 17 00:00:00 2001 From: BowenD-UCB <84425382+BowenD-UCB@users.noreply.github.com> Date: Mon, 1 Apr 2024 15:49:54 -0700 Subject: [PATCH] fixed cuda mem checking for trainer --- chgnet/trainer/trainer.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index e606fbdb..a9188a6d 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -19,7 +19,7 @@ ) from chgnet.model.model import CHGNet -from chgnet.utils import AverageMeter, cuda_devices_sorted_by_free_mem, mae, write_json +from chgnet.utils import AverageMeter, determine_device, mae, write_json if TYPE_CHECKING: from torch.utils.data import DataLoader @@ -48,6 +48,7 @@ def __init__( torch_seed: int | None = None, data_seed: int | None = None, use_device: str | None = None, + check_cuda_mem: bool = True, **kwargs, ) -> None: """Initialize all hyper-parameters for trainer. @@ -80,9 +81,12 @@ def __init__( Default = None data_seed (int): random seed for random Default = None - use_device (str, optional): device name to train the CHGNet. - Can be "cuda", "cpu" + use_device (str, optional): The device to be used for predictions, + either "cpu", "cuda", or "mps". If not specified, the default device is + automatically selected based on the available options. Default = None + check_cuda_mem (bool): Whether to use cuda with most available memory + Default = True **kwargs (dict): additional hyper-params for optimizer, scheduler, etc. """ # Store trainer args for reproducibility @@ -180,16 +184,9 @@ def __init__( self.starting_epoch = starting_epoch # Determine the device to use - if use_device is not None: - self.device = use_device - elif torch.cuda.is_available(): - self.device = "cuda" - else: - self.device = "cpu" - if self.device == "cuda": - # Determine cuda device with most available memory - device_with_most_available_memory = cuda_devices_sorted_by_free_mem()[-1] - self.device = f"cuda:{device_with_most_available_memory}" + self.device = determine_device( + use_device=use_device, check_cuda_mem=check_cuda_mem + ) self.print_freq = print_freq self.training_history: dict[