Skip to content

Commit

Permalink
refactor DeepPot and support AutoBatchSize
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Jan 18, 2024
1 parent bb1c02a commit 9162ca5
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 126 deletions.
267 changes: 143 additions & 124 deletions deepmd_pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
from deepmd_pt.train.wrapper import ModelWrapper
from deepmd_pt.utils.preprocess import Region3D, normalize_coord, make_env_mat
from deepmd_pt.utils.dataloader import collate_batch
from typing import Optional, Union, List
from typing import Callable, Optional, Tuple, Union, List
from deepmd_pt.utils import env
from deepmd_pt.utils.auto_batch_size import AutoBatchSize


class DeepEval:
def __init__(
self,
model_file: "Path"
model_file: "Path",
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
):
self.model_path = model_file
state_dict = torch.load(model_file, map_location=env.DEVICE)
Expand All @@ -29,147 +31,164 @@ def __init__(
self.dp.load_state_dict(state_dict)
self.rcut = self.dp.model['Default'].descriptor.get_rcut()
self.sec = np.cumsum(self.dp.model['Default'].descriptor.get_sel())
if isinstance(auto_batch_size, bool):
if auto_batch_size:
self.auto_batch_size = AutoBatchSize()
else:
self.auto_batch_size = None
elif isinstance(auto_batch_size, int):
self.auto_batch_size = AutoBatchSize(auto_batch_size)
elif isinstance(auto_batch_size, AutoBatchSize):
self.auto_batch_size = auto_batch_size
else:
raise TypeError("auto_batch_size should be bool, int, or AutoBatchSize")

def eval(
self,
coords: Union[np.ndarray, torch.Tensor],
cells: Optional[Union[np.ndarray, torch.Tensor]],
atom_types: Union[np.ndarray, torch.Tensor, List[int]],
atomic: bool = False,
infer_batch_size: int = 2,
):
raise NotImplementedError


class DeepPot(DeepEval):
def __init__(
self,
model_file: "Path"
self,
model_file: "Path",
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list=None,
):
super(DeepPot, self).__init__(model_file)
if neighbor_list is not None:
raise NotImplementedError
super(DeepPot, self).__init__(
model_file,
auto_batch_size=auto_batch_size,
)

def eval(
self,
coords: Union[np.ndarray, torch.Tensor],
cells: Optional[Union[np.ndarray, torch.Tensor]],
atom_types: Union[np.ndarray, torch.Tensor, List[int]],
atomic: bool = False,
infer_batch_size: int = 2,
self,
coords: np.ndarray,
cells: np.ndarray,
atom_types: List[int],
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
efield: Optional[np.ndarray] = None,
mixed_type: bool = False,
):
return eval_model(self.dp, coords, cells, atom_types, atomic, infer_batch_size)


def eval_model(
model,
coords: Union[np.ndarray, torch.Tensor],
cells: Optional[Union[np.ndarray, torch.Tensor]],
atom_types: Union[np.ndarray, torch.Tensor, List[int]],
atomic: bool = False,
infer_batch_size: int = 2,
denoise: bool = False,
):
model = model.to(DEVICE)
energy_out = []
atomic_energy_out = []
force_out = []
virial_out = []
atomic_virial_out = []
updated_coord_out = []
logits_out = []
err_msg = f"All inputs should be the same format, " \
f"but found {type(coords)}, {type(cells)}, {type(atom_types)} instead! "
return_tensor = True
if isinstance(coords, torch.Tensor):
if cells is not None:
assert isinstance(cells, torch.Tensor), err_msg
assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list)
atom_types = torch.tensor(atom_types, dtype=torch.long).to(DEVICE)
elif isinstance(coords, np.ndarray):
if cells is not None:
assert isinstance(cells, np.ndarray), err_msg
assert isinstance(atom_types, np.ndarray) or isinstance(atom_types, list)
if fparam is not None or aparam is not None or efield is not None:
raise NotImplementedError
# convert all of the input to numpy array
atom_types = np.array(atom_types, dtype=np.int32)
return_tensor = False
coords = np.array(coords)
if cells is not None:
cells = np.array(cells)
natoms, numb_test = self._get_natoms_and_nframes(coords, atom_types, len(atom_types.shape) > 1)
return self._eval_func(self._eval_model, numb_test, natoms)(coords, cells, atom_types, atomic)

def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Callable:
"""Wrapper method with auto batch size.
Parameters
----------
inner_func : Callable
the method to be wrapped
numb_test : int
number of tests
natoms : int
number of atoms
Returns
-------
Callable
the wrapper
"""
if self.auto_batch_size is not None:

def eval_func(*args, **kwargs):
return self.auto_batch_size.execute_all(
inner_func, numb_test, natoms, *args, **kwargs
)

nframes = coords.shape[0]
if len(atom_types.shape) == 1:
natoms = len(atom_types)
if isinstance(atom_types, torch.Tensor):
atom_types = torch.tile(atom_types.unsqueeze(0), [nframes, 1]).reshape(nframes, -1)
else:
atom_types = np.tile(atom_types, nframes).reshape(nframes, -1)
else:
natoms = len(atom_types[0])

coord_input = torch.tensor(coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE)
type_input = torch.tensor(atom_types, dtype=torch.long).to(DEVICE)
box_input = None
if cells is None:
pbc = False
else:
pbc = True
box_input = torch.tensor(cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE)
num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)

for ii in range(num_iter):
batch_coord = coord_input[ii * infer_batch_size:(ii + 1) * infer_batch_size]
batch_atype = type_input[ii * infer_batch_size:(ii + 1) * infer_batch_size]
batch_box = None
if pbc:
batch_box = box_input[ii * infer_batch_size:(ii + 1) * infer_batch_size]
batch_output = model(batch_coord, batch_atype, box=batch_box)
if isinstance(batch_output, tuple):
batch_output = batch_output[0]
if not return_tensor:
if 'energy' in batch_output:
energy_out.append(batch_output['energy'].detach().cpu().numpy())
if 'atom_energy' in batch_output:
atomic_energy_out.append(batch_output['atom_energy'].detach().cpu().numpy())
if 'force' in batch_output:
force_out.append(batch_output['force'].detach().cpu().numpy())
if 'virial' in batch_output:
virial_out.append(batch_output['virial'].detach().cpu().numpy())
if 'atomic_virial' in batch_output:
atomic_virial_out.append(batch_output['atomic_virial'].detach().cpu().numpy())
if 'updated_coord' in batch_output:
updated_coord_out.append(batch_output['updated_coord'].detach().cpu().numpy())
if 'logits' in batch_output:
logits_out.append(batch_output['logits'].detach().cpu().numpy())
eval_func = inner_func
return eval_func

def _get_natoms_and_nframes(
self,
coords: np.ndarray,
atom_types: Union[List[int], np.ndarray],
mixed_type: bool = False,
) -> Tuple[int, int]:
if mixed_type:
natoms = len(atom_types[0])
else:
natoms = len(atom_types)
if natoms == 0:
assert coords.size == 0
else:
coords = np.reshape(np.array(coords), [-1, natoms * 3])
nframes = coords.shape[0]
return natoms, nframes

def _eval_model(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
denoise: bool = False,
):
model = self.dp.to(DEVICE)
energy_out = None
atomic_energy_out = None
force_out = None
virial_out = None
atomic_virial_out = None
updated_coord_out = None
logits_out = None

nframes = coords.shape[0]
if len(atom_types.shape) == 1:
natoms = len(atom_types)
if isinstance(atom_types, torch.Tensor):
atom_types = torch.tile(atom_types.unsqueeze(0), [nframes, 1]).reshape(nframes, -1)
else:
atom_types = np.tile(atom_types, nframes).reshape(nframes, -1)
else:
if 'energy' in batch_output:
energy_out.append(batch_output['energy'])
if 'atom_energy' in batch_output:
atomic_energy_out.append(batch_output['atom_energy'])
if 'force' in batch_output:
force_out.append(batch_output['force'])
if 'virial' in batch_output:
virial_out.append(batch_output['virial'])
if 'atomic_virial' in batch_output:
atomic_virial_out.append(batch_output['atomic_virial'])
if 'updated_coord' in batch_output:
updated_coord_out.append(batch_output['updated_coord'])
if 'logits' in batch_output:
logits_out.append(batch_output['logits'])
if not return_tensor:
energy_out = np.concatenate(energy_out) if energy_out else np.zeros([nframes, 1])
atomic_energy_out = np.concatenate(atomic_energy_out) if atomic_energy_out else np.zeros([nframes, natoms, 1])
force_out = np.concatenate(force_out) if force_out else np.zeros([nframes, natoms, 3])
virial_out = np.concatenate(virial_out) if virial_out else np.zeros([nframes, 3, 3])
atomic_virial_out = np.concatenate(atomic_virial_out) if atomic_virial_out else np.zeros([nframes, natoms, 3, 3])
updated_coord_out = np.concatenate(updated_coord_out) if updated_coord_out else None
logits_out = np.concatenate(logits_out) if logits_out else None
else:
energy_out = torch.cat(energy_out) if energy_out else torch.zeros([nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE)
atomic_energy_out = torch.cat(atomic_energy_out) if atomic_energy_out else torch.zeros([nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE)
force_out = torch.cat(force_out) if force_out else torch.zeros([nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE)
virial_out = torch.cat(virial_out) if virial_out else torch.zeros([nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE)
atomic_virial_out = torch.cat(atomic_virial_out) if atomic_virial_out else torch.zeros([nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE)
updated_coord_out = torch.cat(updated_coord_out) if updated_coord_out else None
logits_out = torch.cat(logits_out) if logits_out else None
if denoise:
return updated_coord_out, logits_out
else:
if not atomic:
return energy_out, force_out, virial_out
natoms = len(atom_types[0])

coord_input = torch.tensor(coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE)
type_input = torch.tensor(atom_types, dtype=torch.long).to(DEVICE)
if cells is not None:
box_input = torch.tensor(cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE)
else:
return energy_out, force_out, virial_out, atomic_energy_out, atomic_virial_out
box_input = None

batch_output = model(coord_input, type_input, box=box_input, do_atomic_virial=atomic)
if isinstance(batch_output, tuple):
batch_output = batch_output[0]
if 'energy' in batch_output:
energy_out = batch_output['energy'].detach().cpu().numpy()
if 'atom_energy' in batch_output:
atomic_energy_out = batch_output['atom_energy'].detach().cpu().numpy()
if 'force' in batch_output:
force_out = batch_output['force'].detach().cpu().numpy()
if 'virial' in batch_output:
virial_out = batch_output['virial'].detach().cpu().numpy()
if 'atomic_virial' in batch_output:
atomic_virial_out = batch_output['atomic_virial'].detach().cpu().numpy()
if 'updated_coord' in batch_output:
updated_coord_out = batch_output['updated_coord'].detach().cpu().numpy()
if 'logits' in batch_output:
logits_out = batch_output['logits'].detach().cpu().numpy()

if denoise:
return updated_coord_out, logits_out
else:
if not atomic:
return energy_out, force_out, virial_out
else:
return energy_out, force_out, virial_out, atomic_energy_out, atomic_virial_out
5 changes: 3 additions & 2 deletions deepmd_pt/train/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,14 @@ def share_params(self, shared_links, resume=False):

def forward(self, coord, atype, box: Optional[torch.Tensor] = None,
cur_lr: Optional[torch.Tensor] = None, label: Optional[torch.Tensor] = None,
task_key: Optional[torch.Tensor] = None, inference_only=False):
task_key: Optional[torch.Tensor] = None, inference_only=False,
do_atomic_virial=False):
if not self.multi_task:
task_key = "Default"
else:
assert task_key is not None, \
f"Multitask model must specify the inference task! Supported tasks are {list(self.model.keys())}."
model_pred = self.model[task_key](coord, atype, box=box)
model_pred = self.model[task_key](coord, atype, box=box, do_atomic_virial=do_atomic_virial)
natoms = atype.shape[-1]
if not self.inference_only and not inference_only:
loss, more_loss = self.loss[task_key](model_pred, label, natoms=natoms, learning_rate=cur_lr)
Expand Down
26 changes: 26 additions & 0 deletions deepmd_pt/utils/auto_batch_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import torch

from deepmd_utils.utils.batch_size import AutoBatchSize as AutoBatchSizeBase


class AutoBatchSize(AutoBatchSizeBase):
def is_gpu_available(self) -> bool:
"""Check if GPU is available.
Returns
-------
bool
True if GPU is available
"""
return torch.cuda.is_available()

def is_oom_error(self, e: Exception) -> bool:
"""Check if the exception is an OOM error.
Parameters
----------
e : Exception
Exception
"""
return isinstance(e, RuntimeError) and "CUDA out of memory." in e.args[0]
43 changes: 43 additions & 0 deletions tests/test_deeppot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from copy import deepcopy
import json
import unittest
from pathlib import Path

import numpy as np
from deepmd_pt.entrypoints.main import get_trainer
from deepmd_pt.infer.deep_eval import DeepPot


class TestDeepPot(unittest.TestCase):
def setUp(self):
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json, "r") as f:
self.config = json.load(f)
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1
self.config["training"]["training_data"]["systems"] = [str(Path(__file__).parent / "water/data/single")]
self.config["training"]["validation_data"]["systems"] = [str(Path(__file__).parent / "water/data/single")]
self.input_json = "test_dp_test.json"
with open(self.input_json, "w") as fp:
json.dump(self.config, fp, indent=4)

trainer = get_trainer(deepcopy(self.config))
trainer.run()

input_dict, label_dict, _ = trainer.get_data(is_train=False)
trainer.wrapper(**input_dict, label=label_dict, cur_lr=1.0)
self.model = Path(__file__).parent / "model.pt"

def test_dp_test(self):
dp = DeepPot(str(self.model))
cell = np.array([
5.122106549439247480e+00,4.016537340154059388e-01,6.951654033828678081e-01,
4.016537340154059388e-01,6.112136112297989143e+00,8.178091365465004481e-01,
6.951654033828678081e-01,8.178091365465004481e-01,6.159552512682983760e+00,
]).reshape(1, 3, 3)
coord = np.array([
2.978060152121375648e+00,3.588469695887098077e+00,2.792459820604495491e+00,3.895592322591093115e+00,2.712091020667753760e+00,1.366836847133650501e+00,9.955616170888935690e-01,4.121324820711413039e+00,1.817239061889086571e+00,3.553661462345699906e+00,5.313046969500791583e+00,6.635182659098815883e+00,6.088601018589653080e+00,6.575011420004332585e+00,6.825240650611076099e+00
]).reshape(1, -1, 3)
atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1)

e, f, v, ae, av = dp.eval(coord, cell, atype, atomic=True)

0 comments on commit 9162ca5

Please sign in to comment.