Skip to content

Commit

Permalink
Switch to using autograd.grad directly.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Jun 26, 2024
1 parent 7d3e5e5 commit 316c70f
Showing 1 changed file with 56 additions and 49 deletions.
105 changes: 56 additions & 49 deletions emle/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,6 @@

import torch as _torch

try:
from torch.func import grad_and_value as _grad_and_value
except:
from func_torch import grad_and_value as _grad_and_value


_ANGSTROM_TO_BOHR = 1.0 / _ase.units.Bohr
_NANOMETER_TO_BOHR = 10.0 / _ase.units.Bohr
_BOHR_TO_ANGSTROM = _ase.units.Bohr
Expand Down Expand Up @@ -1341,7 +1335,6 @@ def __init__(
self._params["n_ref"],
1e-3,
)
self._get_E_with_grad = _grad_and_value(self._get_E, argnums=(1, 2, 3, 4))

# Initialise the maximum number of MM atom that have been seen.
self._max_mm_atoms = 0
Expand Down Expand Up @@ -1577,26 +1570,37 @@ def run(self, path=None):

# Convert inputs to Torch tensors.
xyz_qm_bohr = _torch.tensor(
xyz_qm_bohr, dtype=_torch.float32, device=self._device
xyz_qm_bohr, dtype=_torch.float32, device=self._device, requires_grad=True
)
xyz_mm_bohr = _torch.tensor(
xyz_mm_bohr, dtype=_torch.float32, device=self._device
xyz_mm_bohr, dtype=_torch.float32, device=self._device, requires_grad=True
)
charges_mm = _torch.tensor(
charges_mm, dtype=_torch.float32, device=self._device
)
s = _torch.tensor(s, dtype=_torch.float32, device=self._device)
chi = _torch.tensor(chi, dtype=_torch.float32, device=self._device)

# Compute gradients and energy.
grads, E = self._get_E_with_grad(charges_mm, xyz_qm_bohr, xyz_mm_bohr, s, chi)
dE_dxyz_qm_bohr_part, dE_dxyz_mm_bohr, dE_ds, dE_dchi = grads
dE_dxyz_qm_bohr = (
dE_dxyz_qm_bohr_part.cpu().numpy()
+ dE_ds.cpu().numpy() @ ds_dxyz_qm_bohr.swapaxes(0, 1)
+ dE_dchi.cpu().numpy() @ dchi_dxyz_qm_bohr.swapaxes(0, 1)
s = _torch.tensor(
s, dtype=_torch.float32, device=self._device, requires_grad=True
)
chi = _torch.tensor(
chi, dtype=_torch.float32, device=self._device, requires_grad=True
)

# Compute energy and gradients.
E = self._get_E(charges_mm, xyz_qm_bohr, xyz_mm_bohr, s, chi)
if self._method == "mm":
dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad(
E, (xyz_qm_bohr, xyz_mm_bohr)
)
dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy()
else:
grads = _torch.autograd.grad(E, (xyz_qm_bohr, xyz_mm_bohr, s, chi))
dE_dxyz_qm_bohr_part, dE_dxyz_mm_bohr, dE_ds, dE_dchi = grads
dE_dxyz_qm_bohr = (
dE_dxyz_qm_bohr_part.cpu().numpy()
+ dE_ds.cpu().numpy() @ ds_dxyz_qm_bohr.swapaxes(0, 1)
+ dE_dchi.cpu().numpy() @ dchi_dxyz_qm_bohr.swapaxes(0, 1)
)

# Compute the total energy and gradients.
E_tot = E + E_vac
grad_qm = dE_dxyz_qm_bohr + grad_vac
Expand All @@ -1620,16 +1624,12 @@ def run(self, path=None):
method = self._method
self._method = "mm"

# Recompute the gradients and energy.
grads, E = self._get_E_with_grad(
charges_mm, xyz_qm_bohr, xyz_mm_bohr, s, chi
)
dE_dxyz_qm_bohr_part, dE_dxyz_mm_bohr, dE_ds, dE_dchi = grads
dE_dxyz_qm_bohr = (
dE_dxyz_qm_bohr_part.cpu().numpy()
+ dE_ds.cpu().numpy() @ ds_dxyz_qm_bohr.swapaxes(0, 1)
+ dE_dchi.cpu().numpy() @ dchi_dxyz_qm_bohr.swapaxes(0, 1)
# Recompute the energy and gradients.
E = self._get_E(charges_mm, xyz_qm_bohr, xyz_mm_bohr, s, chi)
dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad(
E, (xyz_qm_bohr, xyz_mm_bohr)
)
dE_dxyz_qm_bohr = dE_dxyz_qm_bohr.cpu().numpy()
dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy()

# Restore the method.
Expand Down Expand Up @@ -1983,26 +1983,37 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):

# Convert inputs to Torch tensors.
xyz_qm_bohr = _torch.tensor(
xyz_qm_bohr, dtype=_torch.float32, device=self._device
xyz_qm_bohr, dtype=_torch.float32, device=self._device, requires_grad=True
)
xyz_mm_bohr = _torch.tensor(
xyz_mm_bohr, dtype=_torch.float32, device=self._device
xyz_mm_bohr, dtype=_torch.float32, device=self._device, requires_grad=True
)
charges_mm = _torch.tensor(
charges_mm, dtype=_torch.float32, device=self._device
)
s = _torch.tensor(s, dtype=_torch.float32, device=self._device)
chi = _torch.tensor(chi, dtype=_torch.float32, device=self._device)

# Compute gradients and energy.
grads, E = self._get_E_with_grad(charges_mm, xyz_qm_bohr, xyz_mm_bohr, s, chi)
dE_dxyz_qm_bohr_part, dE_dxyz_mm_bohr, dE_ds, dE_dchi = grads
dE_dxyz_qm_bohr = (
dE_dxyz_qm_bohr_part.cpu().numpy()
+ dE_ds.cpu().numpy() @ ds_dxyz_qm_bohr.swapaxes(0, 1)
+ dE_dchi.cpu().numpy() @ dchi_dxyz_qm_bohr.swapaxes(0, 1)
s = _torch.tensor(
s, dtype=_torch.float32, device=self._device, requires_grad=True
)
chi = _torch.tensor(
chi, dtype=_torch.float32, device=self._device, requires_grad=True
)

# Compute energy and gradients.
E = self._get_E(charges_mm, xyz_qm_bohr, xyz_mm_bohr, s, chi)
if self._method == "mm":
dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad(
E, (xyz_qm_bohr, xyz_mm_bohr)
)
dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy()
else:
grads = _torch.autograd.grad(E, (xyz_qm_bohr, xyz_mm_bohr, s, chi))
dE_dxyz_qm_bohr_part, dE_dxyz_mm_bohr, dE_ds, dE_dchi = grads
dE_dxyz_qm_bohr = (
dE_dxyz_qm_bohr_part.cpu().numpy()
+ dE_ds.cpu().numpy() @ ds_dxyz_qm_bohr.swapaxes(0, 1)
+ dE_dchi.cpu().numpy() @ dchi_dxyz_qm_bohr.swapaxes(0, 1)
)

# Compute the total energy and gradients.
E_tot = E + E_vac
grad_qm = dE_dxyz_qm_bohr + grad_vac
Expand Down Expand Up @@ -2030,16 +2041,12 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
method = self._method
self._method = "mm"

# Recompute the gradients and energy.
grads, E = self._get_E_with_grad(
charges_mm, xyz_qm_bohr, xyz_mm_bohr, s, chi
)
dE_dxyz_qm_bohr_part, dE_dxyz_mm_bohr, dE_ds, dE_dchi = grads
dE_dxyz_qm_bohr = (
dE_dxyz_qm_bohr_part.cpu().numpy()
+ dE_ds.cpu().numpy() @ ds_dxyz_qm_bohr.swapaxes(0, 1)
+ dE_dchi.cpu().numpy() @ dchi_dxyz_qm_bohr.swapaxes(0, 1)
# Recompute the energy and gradients.
E = self._get_E(charges_mm, xyz_qm_bohr, xyz_mm_bohr, s, chi)
dE_dxyz_qm_bohr, dE_dxyz_mm_bohr = _torch.autograd.grad(
E, (xyz_qm_bohr, xyz_mm_bohr)
)
dE_dxyz_qm_bohr = dE_dxyz_qm_bohr.cpu().numpy()
dE_dxyz_mm_bohr = dE_dxyz_mm_bohr.cpu().numpy()

# Restore the method.
Expand Down

0 comments on commit 316c70f

Please sign in to comment.