Skip to content

Commit

Permalink
Update the analyzer to use the Torch modules.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Dec 11, 2024
1 parent 5841ccb commit db76f82
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 155 deletions.
28 changes: 21 additions & 7 deletions bin/emle-analyze
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,20 @@ parser.add_argument(
parser.add_argument(
"--backend",
type=str,
choices=["deepmd", "ani2x"],
help="Gas phase ML backend ('deepmd' or 'ani2x')",
choices=["torchani", "mace", "deepmd"],
help="Gas phase ML backend",
)
parser.add_argument(
"--deepmd-model",
type=str,
metavar="name.pb",
help="Deepmd model file (for backend='deepmd')",
)
parser.add_argument(
"--mace-model",
type=str,
help="MACE model file (for backend='mace')",
)
parser.add_argument(
"--qm-xyz", type=str, metavar="name.xyz", required=True, help="QM xyz file"
)
Expand All @@ -64,19 +69,28 @@ args = parser.parse_args()

import scipy.io

from emle.models._emle import EMLE
from emle.models import EMLE
from emle._analyzer import EMLEAnalyzer
from emle._orca_parser import ORCAParser
from emle._analyzer import EMLEAnalyzer, ANI2xBackend, DeepMDBackend


if args.backend == "deepmd" and not args.deepmd_model:
parser.error("--deepmd-model is required when backend='deepmd'")
if args.backend == "mace" and not args.mace_model:
parser.error("--mace-model is required when backend='mace'")

backend = None
if args.backend == "ani2x":
backend = ANI2xBackend()
from emle._models import ANI2xEMLE

backend = ANI2xEMLE()
elif args.backend == "mace":
from emle._models import MACEEMLE

backend = MACEEMLE(mace_model=args.mace_model)
elif args.backend == "deepmd":
backend = DeepMDBackend(args.deepmd_model)
from emle._backends import DeepMD

backend = DeepMD(args.deepmd_model)

emle_base = EMLE(model=args.emle_model)._emle_base

Expand Down
159 changes: 11 additions & 148 deletions emle/_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,158 +24,16 @@
Analyser for EMLE simulation output.
"""

__all__ = ["ANI2xBackend", "DeepMDBackend", "EMLEAnalyzer"]
__all__ = ["EMLEAnalyzer"]


from abc import ABC as _ABC
from abc import abstractmethod as _abstractmethod

import ase as _ase
import ase.io as _ase_io
import os as _os
import numpy as _np
import torch as _torch

from ._units import _HARTREE_TO_KCAL_MOL, _ANGSTROM_TO_BOHR
from ._utils import pad_to_max as _pad_to_max

_EV_TO_KCALMOL = _ase.units.mol / _ase.units.kcal
_HARTREE_TO_KCALMOL = _ase.units.Hartree * _EV_TO_KCALMOL
_ANGSTROM_TO_BOHR = 1.0 / _ase.units.Bohr


class BaseBackend(_ABC):
    """
    Abstract base class for in vacuo backends used by the analyzer.

    Sub-classes implement ``eval``. Instances are callable: ``__call__``
    validates the NumPy inputs, optionally moves them onto a Torch device,
    delegates to ``eval`` and converts Torch results back to NumPy arrays.
    """

    def __init__(self, torch_device=None):
        """
        torch_device: torch.device or None
            The device on which ``eval`` should receive its inputs. When
            None, the NumPy arrays are passed to ``eval`` unchanged.
        """

        if torch_device is not None and not isinstance(torch_device, _torch.device):
            raise ValueError("Invalid device type. Must be a torch.device.")

        # Always define the attribute so that __call__ can test it, even for
        # backends that do not use Torch.
        self._device = torch_device

    def __call__(self, atomic_numbers, xyz, forces=False):
        """
        atomic_numbers: np.ndarray (N_BATCH, N_QM_ATOMS,)
            The atomic numbers of the atoms.
        xyz: np.ndarray (N_BATCH, N_QM_ATOMS, 3)
            The positions of the atoms.
        forces: bool
            Whether the forces should be calculated
        Returns energy in kcal/mol (and, optionally, forces in kcal/mol/A)
        as np.ndarrays

        Raises ValueError if either input is not a NumPy array of the
        expected dimensionality.
        """

        if not isinstance(atomic_numbers, _np.ndarray):
            raise ValueError("Invalid atomic_numbers type. Must be a numpy array.")
        if atomic_numbers.ndim != 2:
            raise ValueError(
                "Invalid atomic_numbers shape. Must a two-dimensional array."
            )

        if not isinstance(xyz, _np.ndarray):
            raise ValueError("Invalid xyz type. Must be a numpy array.")
        if xyz.ndim != 3:
            raise ValueError("Invalid xyz shape. Must a three-dimensional array.")

        # No Torch device: the backend works on NumPy arrays directly and its
        # result is returned as-is.
        if self._device is None:
            return self.eval(atomic_numbers, xyz, forces)

        atomic_numbers = _torch.tensor(atomic_numbers, device=self._device)
        # Track gradients on the coordinates only when forces are requested,
        # so that eval() can differentiate the energy with respect to them.
        xyz = _torch.tensor(xyz, device=self._device, requires_grad=bool(forces))

        result = self.eval(atomic_numbers, xyz, forces)

        # Convert the Torch results back to NumPy arrays on the host.
        if forces:
            e, f = result
            return e.detach().cpu().numpy(), f.detach().cpu().numpy()

        return result.detach().cpu().numpy()

    @_abstractmethod
    def eval(self, atomic_numbers, xyz, forces=False):
        """
        atomic_numbers: (N_BATCH, N_QM_ATOMS,)
            The atomic numbers of the atoms.
        xyz: (N_BATCH, N_QM_ATOMS, 3)
            The positions of the atoms.
        forces: bool
            Whether the gradient should be calculated
        Returns energy in kcal/mol (and, optionally, forces in kcal/mol/A)
        as either np.ndarrays or torch.Tensor
        """
        pass


class ANI2xBackend(BaseBackend):
    """
    In vacuo backend that evaluates energies (and, optionally, forces) with
    the TorchANI ANI2x model.
    """

    def __init__(self, device=None, ani2x_model_index=None):
        """
        device: torch.device or None
            The device on which to run the model. When None, "cuda" is used
            if available, otherwise "cpu".
        ani2x_model_index: int or None
            Passed to the ANI2x constructor as ``model_index``.

        Raises ValueError if either argument has an invalid type.
        """
        # Imported lazily so that torchani is only required when this backend
        # is actually used.
        import torchani as _torchani

        if device is None:
            device = _torch.device("cuda" if _torch.cuda.is_available() else "cpu")
        elif not isinstance(device, _torch.device):
            raise ValueError("Invalid device type. Must be a torch.device.")

        if ani2x_model_index is not None and not isinstance(ani2x_model_index, int):
            raise ValueError("Invalid model index type. Must be an integer.")

        super().__init__(device)

        self._ani2x = _torchani.models.ANI2x(
            periodic_table_index=True, model_index=ani2x_model_index
        ).to(device)

    def eval(self, atomic_numbers, xyz, forces=False):
        """
        atomic_numbers: torch.Tensor (N_BATCH, N_QM_ATOMS,)
            The atomic numbers of the atoms.
        xyz: torch.Tensor (N_BATCH, N_QM_ATOMS, 3)
            The positions of the atoms.
        forces: bool
            Whether the forces should be calculated
        Returns the energy scaled by the Hartree to kcal/mol factor and,
        optionally, the forces as its negative gradient with respect to xyz.
        """
        # The coordinates must track gradients for the energy to be
        # differentiated with respect to them.
        if forces and not xyz.requires_grad:
            xyz = xyz.detach().requires_grad_(True)

        energy = self._ani2x((atomic_numbers, xyz.float())).energies
        energy = energy * _HARTREE_TO_KCALMOL
        if not forces:
            return energy

        # The energy has already been converted, so its gradient must not be
        # multiplied by the conversion factor a second time.
        forces = -_torch.autograd.grad(energy.sum(), xyz)[0]
        return energy, forces


class DeepMDBackend(BaseBackend):
    """
    In vacuo backend that evaluates energies (and, optionally, forces) with
    a DeePMD potential loaded from a model file.
    """

    def __init__(self, model=None):
        """
        model: str
            Path to the DeePMD model file.

        Raises ValueError if the model file is missing or cannot be found,
        and RuntimeError if the DeePMD potential cannot be created from it.
        """

        super().__init__()

        # Check for None explicitly: os.path.isfile(None) raises a TypeError
        # rather than returning False.
        if model is None or not _os.path.isfile(model):
            raise ValueError(f"Unable to locate DeePMD model file: '{model}'")

        try:
            from deepmd.infer import DeepPot as _DeepPot

            self._dp = _DeepPot(model)
            # Map each element in the model's type map to its type index.
            self._z_map = {
                element: index for index, element in enumerate(self._dp.get_type_map())
            }
        except Exception as e:
            # Chain the original exception so the underlying cause is kept.
            raise RuntimeError(f"Unable to create the DeePMD potentials: {e}") from e

    def eval(self, atomic_numbers, xyz, forces=False):
        """
        atomic_numbers: np.ndarray (N_BATCH, N_QM_ATOMS,)
            The atomic numbers of the atoms.
        xyz: np.ndarray (N_BATCH, N_QM_ATOMS, 3)
            The positions of the atoms.
        forces: bool
            Whether the forces should be returned
        Returns the flattened energies and, optionally, the forces, both
        scaled by the eV to kcal/mol factor.
        """
        # Assuming all the frames are of the same system
        atom_types = [self._z_map[_ase.Atom(z).symbol] for z in atomic_numbers[0]]
        e, f, _ = self._dp.eval(xyz, cells=None, atom_types=atom_types)
        e, f = e.flatten() * _EV_TO_KCALMOL, f * _EV_TO_KCALMOL
        return (e, f) if forces else e


class EMLEAnalyzer:
"""
Expand Down Expand Up @@ -236,8 +94,13 @@ def __init__(
except Exception as e:
raise RuntimeError(f"Unable to parse PC xyz file: {e}")

# Store the in vacuo energies if a backend is provided.
if backend:
self.e_backend = backend(atomic_numbers, qm_xyz)
if isinstance(backend, _torch.nn.Module):
backend = backend.to(device).to(dtype)
atomic_numbers = _torch.tensor(atomic_numbers, device=device)
qm_xyz = _torch.tensor(qm_xyz, dtype=dtype, device=device)
self.e_backend = backend(atomic_numbers, qm_xyz) * _HARTREE_TO_KCAL_MOL

self.atomic_numbers = _torch.tensor(
atomic_numbers, dtype=_torch.int, device=device
Expand All @@ -261,13 +124,13 @@ def __init__(
emle_base.get_static_energy(
self.q_core, self.q_val, self.pc_charges, mesh_data
)
* _HARTREE_TO_KCALMOL
* _HARTREE_TO_KCAL_MOL
)
self.e_induced = (
emle_base.get_induced_energy(
self.A_thole, self.pc_charges, self.s, mesh_data
)
* _HARTREE_TO_KCALMOL
* _HARTREE_TO_KCAL_MOL
)

if parser:
Expand All @@ -278,7 +141,7 @@ def __init__(
self.pc_charges,
mesh_data,
)
* _HARTREE_TO_KCALMOL
* _HARTREE_TO_KCAL_MOL
)

for attr in (
Expand Down

0 comments on commit db76f82

Please sign in to comment.