Skip to content

Commit

Permalink
adding function to avoid sorting cuda memory
Browse files Browse the repository at this point in the history
  • Loading branch information
bowen-bd committed Mar 26, 2024
1 parent fe0ad9a commit dd0dd07
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 18 deletions.
16 changes: 6 additions & 10 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
import contextlib
import inspect
import io
import os
import pickle
import sys
from typing import TYPE_CHECKING, Literal

import numpy as np
import torch
from ase import Atoms, units
from ase.calculators.calculator import Calculator, all_changes, all_properties
from ase.md.npt import NPT
Expand All @@ -28,7 +26,7 @@
from pymatgen.io.ase import AseAtomsAdaptor

from chgnet.model.model import CHGNet
from chgnet.utils import cuda_devices_sorted_by_free_mem
from chgnet.utils import determine_device

if TYPE_CHECKING:
from ase.io import Trajectory
Expand Down Expand Up @@ -58,6 +56,7 @@ def __init__(
self,
model: CHGNet | None = None,
use_device: str | None = None,
check_cuda_mem: bool = True,
stress_weight: float | None = 1 / 160.21766208,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
**kwargs,
Expand All @@ -72,6 +71,8 @@ def __init__(
either "cpu", "cuda", or "mps". If not specified, the default device is
automatically selected based on the available options.
Default = None
check_cuda_mem (bool): Whether to use cuda with most available memory
Default = True
stress_weight (float): the conversion factor to convert GPa to eV/A^3.
Default = 1/160.21
on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
Expand All @@ -82,13 +83,8 @@ def __init__(
super().__init__(**kwargs)

# Determine the device to use
use_device = use_device or os.getenv("CHGNET_DEVICE")
if use_device in ("mps", None) and torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = use_device or ("cuda" if torch.cuda.is_available() else "cpu")
if self.device == "cuda":
self.device = f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}"
device = determine_device(use_device=use_device, check_cuda_mem=check_cuda_mem)
self.device = device

# Move the model to the specified device
self.model = (model or CHGNet.load(verbose=False)).to(self.device)
Expand Down
13 changes: 5 additions & 8 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
GraphAttentionReadOut,
GraphPooling,
)
from chgnet.utils import cuda_devices_sorted_by_free_mem
from chgnet.utils import determine_device

if TYPE_CHECKING:
from chgnet import PredTask
Expand Down Expand Up @@ -673,6 +673,7 @@ def load(
cls,
model_name="0.3.0",
use_device: str | None = None,
check_cuda_mem: bool = True,
verbose: bool = True,
) -> CHGNet:
"""Load pretrained CHGNet model.
Expand All @@ -684,6 +685,8 @@ def load(
either "cpu", "cuda", or "mps". If not specified, the default device is
automatically selected based on the available options.
Default = None
check_cuda_mem (bool): Whether to use cuda with most available memory
Default = True
verbose (bool): whether to print model device information
Default = True
Raises:
Expand All @@ -707,13 +710,7 @@ def load(
)

# Determine the device to use
use_device = use_device or os.getenv("CHGNET_DEVICE")
if use_device in ("mps", None) and torch.backends.mps.is_available():
device = "mps"
else:
device = use_device or ("cuda" if torch.cuda.is_available() else "cpu")
if device == "cuda":
device = f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}"
device = determine_device(use_device=use_device, check_cuda_mem=check_cuda_mem)

# Move the model to the specified device
model = model.to(device)
Expand Down
1 change: 1 addition & 0 deletions chgnet/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from chgnet.utils.common_utils import (
AverageMeter,
cuda_devices_sorted_by_free_mem,
determine_device,
mae,
mkdir,
read_json,
Expand Down
24 changes: 24 additions & 0 deletions chgnet/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,30 @@
from torch import Tensor


def determine_device(
    use_device: str | None = None,
    check_cuda_mem: bool = True,
):
    """Resolve the name of the torch device a model should be moved to.

    The device is looked up in this order: the explicit ``use_device``
    argument, the ``CHGNET_DEVICE`` environment variable, and finally
    auto-detection ("mps" if available, else "cuda" if available, else "cpu").

    Args:
        use_device (str | None): device name requested by the caller. A falsy
            value defers to the ``CHGNET_DEVICE`` environment variable.
            Default = None
        check_cuda_mem (bool): if True and the resolved device is exactly
            "cuda", return "cuda:<index>" instead, where the index is the last
            entry of ``cuda_devices_sorted_by_free_mem()`` (documented as
            sorted by increasing available memory).
            Default = True

    Returns:
        device (str): device name to be passed to model.to(device)
    """
    requested = use_device or os.getenv("CHGNET_DEVICE")

    # MPS wins when it was asked for explicitly or when nothing was requested.
    if requested in ("mps", None) and torch.backends.mps.is_available():
        return "mps"

    if requested:
        # Any other explicit request is passed through untouched.
        resolved = requested
    elif torch.cuda.is_available():
        resolved = "cuda"
    else:
        resolved = "cpu"

    # Only a bare "cuda" is pinned to an index; "cuda:N" is left as given.
    if check_cuda_mem and resolved == "cuda":
        best_index = cuda_devices_sorted_by_free_mem()[-1]
        resolved = f"cuda:{best_index}"

    return resolved


def cuda_devices_sorted_by_free_mem() -> list[int]:
"""List available CUDA devices sorted by increasing available memory.
Expand Down

0 comments on commit dd0dd07

Please sign in to comment.