Skip to content

Commit

Permalink
fixed cuda mem checking for trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
bowen-bd committed Apr 1, 2024
1 parent 4b4a489 commit 0b90f0c
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)

from chgnet.model.model import CHGNet
from chgnet.utils import AverageMeter, cuda_devices_sorted_by_free_mem, mae, write_json
from chgnet.utils import AverageMeter, determine_device, mae, write_json

if TYPE_CHECKING:
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -48,6 +48,7 @@ def __init__(
torch_seed: int | None = None,
data_seed: int | None = None,
use_device: str | None = None,
check_cuda_mem: bool = True,
**kwargs,
) -> None:
"""Initialize all hyper-parameters for trainer.
Expand Down Expand Up @@ -80,9 +81,12 @@ def __init__(
Default = None
data_seed (int): random seed for random
Default = None
use_device (str, optional): device name to train the CHGNet.
Can be "cuda", "cpu"
use_device (str, optional): The device to be used for predictions,
either "cpu", "cuda", or "mps". If not specified, the default device is
automatically selected based on the available options.
Default = None
check_cuda_mem (bool): Whether to use cuda with most available memory
Default = True
**kwargs (dict): additional hyper-params for optimizer, scheduler, etc.
"""
# Store trainer args for reproducibility
Expand Down Expand Up @@ -180,16 +184,9 @@ def __init__(
self.starting_epoch = starting_epoch

# Determine the device to use
if use_device is not None:
self.device = use_device
elif torch.cuda.is_available():
self.device = "cuda"
else:
self.device = "cpu"
if self.device == "cuda":
# Determine cuda device with most available memory
device_with_most_available_memory = cuda_devices_sorted_by_free_mem()[-1]
self.device = f"cuda:{device_with_most_available_memory}"
self.device = determine_device(
use_device=use_device, check_cuda_mem=check_cuda_mem
)

self.print_freq = print_freq
self.training_history: dict[
Expand Down

0 comments on commit 0b90f0c

Please sign in to comment.