From 9162ca5bc82d517faaa3eb1c2f868d94dab786aa Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 18 Jan 2024 01:59:40 -0500 Subject: [PATCH] refactor DeepPot and support AutoBatchSize Signed-off-by: Jinzhe Zeng --- deepmd_pt/infer/deep_eval.py | 267 +++++++++++++++-------------- deepmd_pt/train/wrapper.py | 5 +- deepmd_pt/utils/auto_batch_size.py | 26 +++ tests/test_deeppot.py | 43 +++++ 4 files changed, 215 insertions(+), 126 deletions(-) create mode 100644 deepmd_pt/utils/auto_batch_size.py create mode 100644 tests/test_deeppot.py diff --git a/deepmd_pt/infer/deep_eval.py b/deepmd_pt/infer/deep_eval.py index 0161897c..2107004d 100644 --- a/deepmd_pt/infer/deep_eval.py +++ b/deepmd_pt/infer/deep_eval.py @@ -6,14 +6,16 @@ from deepmd_pt.train.wrapper import ModelWrapper from deepmd_pt.utils.preprocess import Region3D, normalize_coord, make_env_mat from deepmd_pt.utils.dataloader import collate_batch -from typing import Optional, Union, List +from typing import Callable, Optional, Tuple, Union, List from deepmd_pt.utils import env +from deepmd_pt.utils.auto_batch_size import AutoBatchSize class DeepEval: def __init__( self, - model_file: "Path" + model_file: "Path", + auto_batch_size: Union[bool, int, AutoBatchSize] = True, ): self.model_path = model_file state_dict = torch.load(model_file, map_location=env.DEVICE) @@ -29,6 +31,17 @@ def __init__( self.dp.load_state_dict(state_dict) self.rcut = self.dp.model['Default'].descriptor.get_rcut() self.sec = np.cumsum(self.dp.model['Default'].descriptor.get_sel()) + if isinstance(auto_batch_size, bool): + if auto_batch_size: + self.auto_batch_size = AutoBatchSize() + else: + self.auto_batch_size = None + elif isinstance(auto_batch_size, int): + self.auto_batch_size = AutoBatchSize(auto_batch_size) + elif isinstance(auto_batch_size, AutoBatchSize): + self.auto_batch_size = auto_batch_size + else: + raise TypeError("auto_batch_size should be bool, int, or AutoBatchSize") def eval( self, @@ -36,140 +49,146 @@ def eval( cells: Optional[Union[np.ndarray, torch.Tensor]], atom_types: Union[np.ndarray, torch.Tensor, List[int]], atomic: bool = False, - infer_batch_size: int = 2, ): raise NotImplementedError class DeepPot(DeepEval): def __init__( - self, - model_file: "Path" + self, + model_file: "Path", + auto_batch_size: Union[bool, int, AutoBatchSize] = True, + neighbor_list=None, ): - super(DeepPot, self).__init__(model_file) + if neighbor_list is not None: + raise NotImplementedError + super(DeepPot, self).__init__( + model_file, + auto_batch_size=auto_batch_size, + ) def eval( - self, - coords: Union[np.ndarray, torch.Tensor], - cells: Optional[Union[np.ndarray, torch.Tensor]], - atom_types: Union[np.ndarray, torch.Tensor, List[int]], - atomic: bool = False, - infer_batch_size: int = 2, + self, + coords: np.ndarray, + cells: np.ndarray, + atom_types: List[int], + atomic: bool = False, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + efield: Optional[np.ndarray] = None, + mixed_type: bool = False, ): - return eval_model(self.dp, coords, cells, atom_types, atomic, infer_batch_size) - - -def eval_model( - model, - coords: Union[np.ndarray, torch.Tensor], - cells: Optional[Union[np.ndarray, torch.Tensor]], - atom_types: Union[np.ndarray, torch.Tensor, List[int]], - atomic: bool = False, - infer_batch_size: int = 2, - denoise: bool = False, -): - model = model.to(DEVICE) - energy_out = [] - atomic_energy_out = [] - force_out = [] - virial_out = [] - atomic_virial_out = [] - updated_coord_out = [] - logits_out = [] - 
err_msg = f"All inputs should be the same format, " \ - f"but found {type(coords)}, {type(cells)}, {type(atom_types)} instead! " - return_tensor = True - if isinstance(coords, torch.Tensor): - if cells is not None: - assert isinstance(cells, torch.Tensor), err_msg - assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list) - atom_types = torch.tensor(atom_types, dtype=torch.long).to(DEVICE) - elif isinstance(coords, np.ndarray): - if cells is not None: - assert isinstance(cells, np.ndarray), err_msg - assert isinstance(atom_types, np.ndarray) or isinstance(atom_types, list) + if fparam is not None or aparam is not None or efield is not None: + raise NotImplementedError + # convert all of the input to numpy array atom_types = np.array(atom_types, dtype=np.int32) - return_tensor = False + coords = np.array(coords) + if cells is not None: + cells = np.array(cells) + natoms, numb_test = self._get_natoms_and_nframes(coords, atom_types, len(atom_types.shape) > 1) + return self._eval_func(self._eval_model, numb_test, natoms)(coords, cells, atom_types, atomic) + + def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Callable: + """Wrapper method with auto batch size. + + Parameters + ---------- + inner_func : Callable + the method to be wrapped + numb_test : int + number of tests + natoms : int + number of atoms + + Returns + ------- + Callable + the wrapper + """ + if self.auto_batch_size is not None: + + def eval_func(*args, **kwargs): + return self.auto_batch_size.execute_all( + inner_func, numb_test, natoms, *args, **kwargs + ) - nframes = coords.shape[0] - if len(atom_types.shape) == 1: - natoms = len(atom_types) - if isinstance(atom_types, torch.Tensor): - atom_types = torch.tile(atom_types.unsqueeze(0), [nframes, 1]).reshape(nframes, -1) else: - atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) - else: - natoms = len(atom_types[0]) - - coord_input = torch.tensor(coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE) - type_input = torch.tensor(atom_types, dtype=torch.long).to(DEVICE) - box_input = None - if cells is None: - pbc = False - else: - pbc = True - box_input = torch.tensor(cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE) - num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size) - - for ii in range(num_iter): - batch_coord = coord_input[ii * infer_batch_size:(ii + 1) * infer_batch_size] - batch_atype = type_input[ii * infer_batch_size:(ii + 1) * infer_batch_size] - batch_box = None - if pbc: - batch_box = box_input[ii * infer_batch_size:(ii + 1) * infer_batch_size] - batch_output = model(batch_coord, batch_atype, box=batch_box) - if isinstance(batch_output, tuple): - batch_output = batch_output[0] - if not return_tensor: - if 'energy' in batch_output: - energy_out.append(batch_output['energy'].detach().cpu().numpy()) - if 'atom_energy' in batch_output: - atomic_energy_out.append(batch_output['atom_energy'].detach().cpu().numpy()) - if 'force' in batch_output: - force_out.append(batch_output['force'].detach().cpu().numpy()) - if 'virial' in batch_output: - virial_out.append(batch_output['virial'].detach().cpu().numpy()) - if 'atomic_virial' in batch_output: - atomic_virial_out.append(batch_output['atomic_virial'].detach().cpu().numpy()) - if 'updated_coord' in batch_output: - updated_coord_out.append(batch_output['updated_coord'].detach().cpu().numpy()) - if 'logits' in batch_output: - logits_out.append(batch_output['logits'].detach().cpu().numpy()) + eval_func = inner_func + 
return eval_func + + def _get_natoms_and_nframes( + self, + coords: np.ndarray, + atom_types: Union[List[int], np.ndarray], + mixed_type: bool = False, + ) -> Tuple[int, int]: + if mixed_type: + natoms = len(atom_types[0]) + else: + natoms = len(atom_types) + if natoms == 0: + assert coords.size == 0 + else: + coords = np.reshape(np.array(coords), [-1, natoms * 3]) + nframes = coords.shape[0] + return natoms, nframes + + def _eval_model( + self, + coords: np.ndarray, + cells: Optional[np.ndarray], + atom_types: np.ndarray, + atomic: bool = False, + denoise: bool = False, + ): + model = self.dp.to(DEVICE) + energy_out = None + atomic_energy_out = None + force_out = None + virial_out = None + atomic_virial_out = None + updated_coord_out = None + logits_out = None + + nframes = coords.shape[0] + if len(atom_types.shape) == 1: + natoms = len(atom_types) + if isinstance(atom_types, torch.Tensor): + atom_types = torch.tile(atom_types.unsqueeze(0), [nframes, 1]).reshape(nframes, -1) + else: + atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) else: - if 'energy' in batch_output: - energy_out.append(batch_output['energy']) - if 'atom_energy' in batch_output: - atomic_energy_out.append(batch_output['atom_energy']) - if 'force' in batch_output: - force_out.append(batch_output['force']) - if 'virial' in batch_output: - virial_out.append(batch_output['virial']) - if 'atomic_virial' in batch_output: - atomic_virial_out.append(batch_output['atomic_virial']) - if 'updated_coord' in batch_output: - updated_coord_out.append(batch_output['updated_coord']) - if 'logits' in batch_output: - logits_out.append(batch_output['logits']) - if not return_tensor: - energy_out = np.concatenate(energy_out) if energy_out else np.zeros([nframes, 1]) - atomic_energy_out = np.concatenate(atomic_energy_out) if atomic_energy_out else np.zeros([nframes, natoms, 1]) - force_out = np.concatenate(force_out) if force_out else np.zeros([nframes, natoms, 3]) - virial_out = np.concatenate(virial_out) if virial_out else np.zeros([nframes, 3, 3]) - atomic_virial_out = np.concatenate(atomic_virial_out) if atomic_virial_out else np.zeros([nframes, natoms, 3, 3]) - updated_coord_out = np.concatenate(updated_coord_out) if updated_coord_out else None - logits_out = np.concatenate(logits_out) if logits_out else None - else: - energy_out = torch.cat(energy_out) if energy_out else torch.zeros([nframes, 1], dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE) - atomic_energy_out = torch.cat(atomic_energy_out) if atomic_energy_out else torch.zeros([nframes, natoms, 1], dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE) - force_out = torch.cat(force_out) if force_out else torch.zeros([nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE) - virial_out = torch.cat(virial_out) if virial_out else torch.zeros([nframes, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE) - atomic_virial_out = torch.cat(atomic_virial_out) if atomic_virial_out else torch.zeros([nframes, natoms, 3, 3], dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE) - updated_coord_out = torch.cat(updated_coord_out) if updated_coord_out else None - logits_out = torch.cat(logits_out) if logits_out else None - if denoise: - return updated_coord_out, logits_out - else: - if not atomic: - return energy_out, force_out, virial_out + natoms = len(atom_types[0]) + + coord_input = torch.tensor(coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE) + type_input = torch.tensor(atom_types, dtype=torch.long).to(DEVICE) + if cells is not None: + box_input = 
torch.tensor(cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION).to(DEVICE) else: - return energy_out, force_out, virial_out, atomic_energy_out, atomic_virial_out + box_input = None + + batch_output = model(coord_input, type_input, box=box_input, do_atomic_virial=atomic) + if isinstance(batch_output, tuple): + batch_output = batch_output[0] + if 'energy' in batch_output: + energy_out = batch_output['energy'].detach().cpu().numpy() + if 'atom_energy' in batch_output: + atomic_energy_out = batch_output['atom_energy'].detach().cpu().numpy() + if 'force' in batch_output: + force_out = batch_output['force'].detach().cpu().numpy() + if 'virial' in batch_output: + virial_out = batch_output['virial'].detach().cpu().numpy() + if 'atomic_virial' in batch_output: + atomic_virial_out = batch_output['atomic_virial'].detach().cpu().numpy() + if 'updated_coord' in batch_output: + updated_coord_out = batch_output['updated_coord'].detach().cpu().numpy() + if 'logits' in batch_output: + logits_out = batch_output['logits'].detach().cpu().numpy() + + if denoise: + return updated_coord_out, logits_out + else: + if not atomic: + return energy_out, force_out, virial_out + else: + return energy_out, force_out, virial_out, atomic_energy_out, atomic_virial_out diff --git a/deepmd_pt/train/wrapper.py b/deepmd_pt/train/wrapper.py index 1efd97cd..aff8ca5e 100644 --- a/deepmd_pt/train/wrapper.py +++ b/deepmd_pt/train/wrapper.py @@ -115,13 +115,14 @@ def share_params(self, shared_links, resume=False): def forward(self, coord, atype, box: Optional[torch.Tensor] = None, cur_lr: Optional[torch.Tensor] = None, label: Optional[torch.Tensor] = None, - task_key: Optional[torch.Tensor] = None, inference_only=False): + task_key: Optional[torch.Tensor] = None, inference_only=False, + do_atomic_virial=False): if not self.multi_task: task_key = "Default" else: assert task_key is not None, \ f"Multitask model must specify the inference task! Supported tasks are {list(self.model.keys())}." - model_pred = self.model[task_key](coord, atype, box=box) + model_pred = self.model[task_key](coord, atype, box=box, do_atomic_virial=do_atomic_virial) natoms = atype.shape[-1] if not self.inference_only and not inference_only: loss, more_loss = self.loss[task_key](model_pred, label, natoms=natoms, learning_rate=cur_lr) diff --git a/deepmd_pt/utils/auto_batch_size.py b/deepmd_pt/utils/auto_batch_size.py new file mode 100644 index 00000000..049bad4c --- /dev/null +++ b/deepmd_pt/utils/auto_batch_size.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch + +from deepmd_utils.utils.batch_size import AutoBatchSize as AutoBatchSizeBase + + +class AutoBatchSize(AutoBatchSizeBase): + def is_gpu_available(self) -> bool: + """Check if GPU is available. + + Returns + ------- + bool + True if GPU is available + """ + return torch.cuda.is_available() + + def is_oom_error(self, e: Exception) -> bool: + """Check if the exception is an OOM error. + + Parameters + ---------- + e : Exception + Exception + """ + return isinstance(e, RuntimeError) and "CUDA out of memory." 
in e.args[0]
diff --git a/tests/test_deeppot.py b/tests/test_deeppot.py
new file mode 100644
index 00000000..3a83d71a
--- /dev/null
+++ b/tests/test_deeppot.py
@@ -0,0 +1,43 @@
+from copy import deepcopy
+import json
+import unittest
+from pathlib import Path
+
+import numpy as np
+from deepmd_pt.entrypoints.main import get_trainer
+from deepmd_pt.infer.deep_eval import DeepPot
+
+
+class TestDeepPot(unittest.TestCase):
+    def setUp(self):
+        input_json = str(Path(__file__).parent / "water/se_atten.json")
+        with open(input_json, "r") as f:
+            self.config = json.load(f)
+        self.config["training"]["numb_steps"] = 1
+        self.config["training"]["save_freq"] = 1
+        self.config["training"]["training_data"]["systems"] = [str(Path(__file__).parent / "water/data/single")]
+        self.config["training"]["validation_data"]["systems"] = [str(Path(__file__).parent / "water/data/single")]
+        self.input_json = "test_dp_test.json"
+        with open(self.input_json, "w") as fp:
+            json.dump(self.config, fp, indent=4)
+
+        trainer = get_trainer(deepcopy(self.config))
+        trainer.run()
+
+        input_dict, label_dict, _ = trainer.get_data(is_train=False)
+        trainer.wrapper(**input_dict, label=label_dict, cur_lr=1.0)
+        self.model = Path(__file__).parent / "model.pt"
+
+    def test_dp_test(self):
+        dp = DeepPot(str(self.model))
+        cell = np.array([
+            5.122106549439247480e+00,4.016537340154059388e-01,6.951654033828678081e-01,
+            4.016537340154059388e-01,6.112136112297989143e+00,8.178091365465004481e-01,
+            6.951654033828678081e-01,8.178091365465004481e-01,6.159552512682983760e+00,
+        ]).reshape(1, 3, 3)
+        coord = np.array([
+            2.978060152121375648e+00,3.588469695887098077e+00,2.792459820604495491e+00,3.895592322591093115e+00,2.712091020667753760e+00,1.366836847133650501e+00,9.955616170888935690e-01,4.121324820711413039e+00,1.817239061889086571e+00,3.553661462345699906e+00,5.313046969500791583e+00,6.635182659098815883e+00,6.088601018589653080e+00,6.575011420004332585e+00,6.825240650611076099e+00
+        ]).reshape(1, -1, 3)
+        atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1)
+
+        e, f, v, ae, av = dp.eval(coord, cell, atype, atomic=True)
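
For reference, a minimal usage sketch of the refactored interface follows. It is not part of the patch: the model path ("model.pt") and the input arrays are placeholders, and only the DeepPot constructor and eval() signatures introduced in the diff above are assumed.

import numpy as np

from deepmd_pt.infer.deep_eval import DeepPot

# "model.pt" is a placeholder for a trained checkpoint; adjust the path as needed.
dp = DeepPot("model.pt", auto_batch_size=True)       # True -> default AutoBatchSize
# auto_batch_size also accepts an int (initial batch size) or an AutoBatchSize object.

# One frame of five atoms; the numbers below are illustrative only.
cell = (10.0 * np.eye(3)).reshape(1, 3, 3)           # (nframes, 3, 3); pass None for non-PBC systems
coord = np.arange(15, dtype=float).reshape(1, 5, 3)  # (nframes, natoms, 3)
atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1)     # per-frame atom types

e, f, v = dp.eval(coord, cell, atype)                        # energy, force, virial
e, f, v, ae, av = dp.eval(coord, cell, atype, atomic=True)   # plus per-atom energy / virial

With auto_batch_size enabled, eval() is wrapped by AutoBatchSize.execute_all, which splits the frames into batches at run time and backs off when a CUDA out-of-memory error is detected (see is_oom_error above); passing auto_batch_size=False restores single-shot evaluation.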