
Batched inputs are not supported for NNPOps optimised models.
lohedges committed Jan 6, 2025
1 parent 2ed7116 commit 8143cd2
Showing 2 changed files with 18 additions and 7 deletions.
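For context, "batched" here means a leading batch dimension on every input tensor: the forward pass adds that dimension itself for a single structure, while already-batched input now triggers an error for NNPOps-optimised models. A rough sketch of the shapes involved, inferred from the unsqueeze/repeat calls in this diff (sizes are hypothetical):

import torch

# Single structure (unbatched): the model unsqueezes these internally.
atomic_numbers = torch.tensor([8, 1, 1])             # (N,)
charges_mm = torch.zeros(20)                          # (M,)
xyz_qm = torch.zeros(3, 3)                            # (N, 3)
xyz_mm = torch.zeros(20, 3)                           # (M, 3)

# Batched input (leading batch dimension of 2), as built in the updated test.
# With an NNPOps-optimised model this now raises an exception.
atomic_numbers_b = atomic_numbers.unsqueeze(0).repeat(2, 1)   # (2, N)
xyz_qm_b = xyz_qm.unsqueeze(0).repeat(2, 1, 1)                # (2, N, 3)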
10 changes: 10 additions & 0 deletions emle/models/_ani.py
@@ -182,6 +182,9 @@ def __init__(
create_aev_calculator=False,
)

# Initialise the NNPOps flag.
self._is_nnpops = False

if ani2x_model is not None:
# Add the base ANI2x model and ensemble.
allowed_types = [
@@ -229,6 +232,9 @@ def __init__(
self._ani2x = _NNPOps.OptimizedTorchANI(
self._ani2x, atomic_numbers
).to(device)

# Flag that the model has been optimised with NNPOps.
self._is_nnpops = True
except Exception as e:
raise RuntimeError(
"Failed to optimise the ANI2x model with NNPOps."
@@ -373,6 +379,10 @@ def forward(
xyz_qm = xyz_qm.unsqueeze(0)
xyz_mm = xyz_mm.unsqueeze(0)
charges_mm = charges_mm.unsqueeze(0)
elif self._is_nnpops:
raise RuntimeError(
"Batched inputs are not supported when using NNPOps optimised models."
)

# Get the in vacuo energy.
E_vac = self._ani2x((atomic_numbers, xyz_qm)).energies
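The guard added to forward() follows a simple pattern: two-dimensional coordinates are promoted with unsqueeze(0), while a three-dimensional tensor combined with the _is_nnpops flag raises a RuntimeError. A minimal, self-contained sketch of that pattern (the _BatchGuard class is hypothetical, not part of emle; it only illustrates the control flow and is TorchScript-scriptable like the real model):

import torch

class _BatchGuard(torch.nn.Module):
    def __init__(self, is_nnpops: bool = True):
        super().__init__()
        self._is_nnpops = is_nnpops

    def forward(self, xyz_qm: torch.Tensor) -> torch.Tensor:
        if xyz_qm.dim() == 2:
            # Single structure: add a leading batch dimension.
            xyz_qm = xyz_qm.unsqueeze(0)
        elif self._is_nnpops:
            # NNPOps-optimised models only handle a single structure.
            raise RuntimeError(
                "Batched inputs are not supported when using NNPOps optimised models."
            )
        return xyz_qm

guard = torch.jit.script(_BatchGuard())
guard(torch.zeros(5, 3))       # fine: (N, 3) is promoted to (1, N, 3)
# guard(torch.zeros(2, 5, 3))  # raises, since the input is already batched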
15 changes: 8 additions & 7 deletions tests/test_models.py
@@ -146,13 +146,14 @@ def test_ani2x_nnpops(alpha_mode, atomic_numbers, charges_mm, xyz_qm, xyz_mm):
energy = model(atomic_numbers, charges_mm, xyz_qm, xyz_mm)
grad_qm, grad_mm = torch.autograd.grad(energy.sum(), (xyz_qm, xyz_mm))

# Test batched inputs.
energy = model(
atomic_numbers.unsqueeze(0).repeat(2, 1),
charges_mm.unsqueeze(0).repeat(2, 1),
xyz_qm.unsqueeze(0).repeat(2, 1, 1),
xyz_mm.unsqueeze(0).repeat(2, 1, 1),
)
# Make sure that batched inputs raise an exception.
with pytest.raises(torch.jit.Error):
energy = model(
atomic_numbers.unsqueeze(0).repeat(2, 1),
charges_mm.unsqueeze(0).repeat(2, 1),
xyz_qm.unsqueeze(0).repeat(2, 1, 1),
xyz_mm.unsqueeze(0).repeat(2, 1, 1),
)


@pytest.mark.skipif(not has_mace, reason="mace-torch not installed")
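A note on the expected exception type: the updated test uses pytest.raises(torch.jit.Error) rather than RuntimeError, which suggests the model under test runs through TorchScript — an exception raised inside scripted code is re-raised on the Python side as torch.jit.Error. A small illustration of that behaviour (the reject_batched function is hypothetical):

import pytest
import torch

@torch.jit.script
def reject_batched(xyz_qm: torch.Tensor) -> torch.Tensor:
    if xyz_qm.dim() != 2:
        raise RuntimeError("Batched inputs are not supported.")
    return xyz_qm

reject_batched(torch.zeros(5, 3))            # single structure: accepted

# The RuntimeError raised inside TorchScript surfaces as torch.jit.Error.
with pytest.raises(torch.jit.Error):
    reject_batched(torch.zeros(2, 5, 3))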
