Commit 947308d
add definition for the atomic model. add function that makes a model out of an atomic model.
Han Wang committed on Jan 22, 2024
1 parent: 8d0bfbd
Showing 4 changed files with 351 additions and 205 deletions.
atomic_model.py (new file, 46 lines), defining the abstract atomic-model interface:
```python
from abc import ABC, abstractmethod
from typing import Dict, Optional

import torch

from deepmd_utils.model_format import FittingOutputDef
from deepmd_pt.model.task import Fitting


class AtomicModel(ABC):
    """Abstract interface for models that give per-atom predictions."""

    @abstractmethod
    def get_fitting_net(self) -> Fitting:
        """Return the fitting network."""
        raise NotImplementedError

    @abstractmethod
    def get_fitting_output_def(self) -> FittingOutputDef:
        """Return the output definition of the fitting network."""
        raise NotImplementedError

    @abstractmethod
    def get_rcut(self) -> float:
        """Return the cut-off radius."""
        raise NotImplementedError

    @abstractmethod
    def get_sel(self) -> int:
        """Return the number of selected neighbors."""
        raise NotImplementedError

    @abstractmethod
    def distinguish_types(self) -> bool:
        """Return whether the model distinguishes atom types in the neighbor list."""
        raise NotImplementedError

    @abstractmethod
    def forward_atomic(
        self,
        extended_coord,
        extended_atype,
        nlist,
        mapping: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Return the atomic contribution to the model prediction."""
        raise NotImplementedError
```
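The commit message also mentions a function that makes a model out of an atomic model; that function sits in one of the two changed files not shown on this page. A minimal, hypothetical sketch of what such a factory over the `AtomicModel` interface could look like (the name `make_model` and the wrapper's responsibilities are assumptions, not the commit's actual code):

```python
from typing import Dict, Optional, Type

import torch

from .atomic_model import AtomicModel  # the ABC defined above


def make_model(T_AtomicModel: Type[AtomicModel]):
    """Hypothetical sketch of a factory that turns a concrete AtomicModel
    class into a full model class; not the commit's implementation."""

    class CM(T_AtomicModel):
        def forward(
            self,
            extended_coord,
            extended_atype,
            nlist,
            mapping: Optional[torch.Tensor] = None,
            do_atomic_virial: bool = False,
        ) -> Dict[str, torch.Tensor]:
            # Delegate to the atomic model; a real factory would also
            # post-process the per-atom outputs (reduction, derivatives)
            # according to get_fitting_output_def().
            return self.forward_atomic(
                extended_coord,
                extended_atype,
                nlist,
                mapping=mapping,
                do_atomic_virial=do_atomic_virial,
            )

    return CM
```

Applied to the concrete class in the next file, `Model = make_model(DPAtomicModel)` would yield a model whose `forward` delegates to `forward_atomic`.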
Second new file (199 lines), defining the concrete DPAtomicModel:
```python
from typing import Dict, List, Optional

import torch

from deepmd_utils.model_format import FittingOutputDef
from deepmd_pt.model.descriptor import Descriptor
from deepmd_pt.model.task import DenoiseNet, Fitting

from .atomic_model import AtomicModel
from .model import BaseModel


class DPAtomicModel(BaseModel, AtomicModel):
    """Model giving the atomic prediction of some physical property.

    Parameters
    ----------
    descriptor
        Descriptor.
    fitting_net
        Fitting net.
    type_map
        Mapping from the atom type to its name (str).
        For example, `type_map[1]` gives the name of type 1.
    type_embedding
        Type embedding net.
    resuming
        Whether to resume/fine-tune from a checkpoint.
    stat_file_dir
        The directory of the statistics files.
    stat_file_path
        The path to the statistics files.
    sampled
        Sampled frames used to compute the statistics.
    """

    def __init__(
        self,
        descriptor: dict,
        fitting_net: dict,
        type_map: Optional[List[str]],
        type_embedding: Optional[dict] = None,
        resuming: bool = False,
        stat_file_dir=None,
        stat_file_path=None,
        sampled=None,
        **kwargs,
    ):
        """Based on the given components, construct an atomic model.

        Args:
        - descriptor: The Dict-like configuration of the descriptor.
        - fitting_net: The Dict-like configuration of the fitting net.
        - type_map: Mapping from atom type to type name.
        - sampled: The sampled dataset for statistics.
        """
        super().__init__()
        # Descriptor + type embedding net (optional)
        ntypes = len(type_map)
        self.type_map = type_map
        self.ntypes = ntypes
        descriptor['ntypes'] = ntypes
        self.combination = descriptor.get('combination', False)
        if self.combination:
            self.prefactor = descriptor.get('prefactor', [0.5, 0.5])
        self.descriptor_type = descriptor['type']

        # Only se_e2_a distinguishes atom types in the neighbor list.
        self.type_split = self.descriptor_type in ['se_e2_a']

        self.descriptor = Descriptor(**descriptor)
        self.rcut = self.descriptor.get_rcut()
        self.sel = self.descriptor.get_sel()
        self.split_nlist = False

        # Statistics
        self.compute_or_load_stat(
            fitting_net, ntypes,
            resuming=resuming,
            type_map=type_map,
            stat_file_dir=stat_file_dir,
            stat_file_path=stat_file_path,
            sampled=sampled,
        )

        # Fitting
        if fitting_net:
            fitting_net['type'] = fitting_net.get('type', 'ener')
            if self.descriptor_type not in ['se_e2_a']:
                fitting_net['ntypes'] = 1
            else:
                fitting_net['ntypes'] = self.descriptor.get_ntype()
                fitting_net['use_tebd'] = False
            fitting_net['embedding_width'] = self.descriptor.dim_out

            self.grad_force = 'direct' not in fitting_net['type']
            if not self.grad_force:
                fitting_net['out_dim'] = self.descriptor.dim_emb
                if 'ener' in fitting_net['type']:
                    fitting_net['return_energy'] = True
            self.fitting_net = Fitting(**fitting_net)
        else:
            # No fitting net configured: fall back to a denoising head.
            self.fitting_net = None
            self.grad_force = False
            if not self.split_nlist:
                self.coord_denoise_net = DenoiseNet(
                    self.descriptor.dim_out, self.ntypes - 1, self.descriptor.dim_emb)
            elif self.combination:
                self.coord_denoise_net = DenoiseNet(
                    self.descriptor.dim_out, self.ntypes - 1,
                    self.descriptor.dim_emb_list, self.prefactor)
            else:
                self.coord_denoise_net = DenoiseNet(
                    self.descriptor.dim_out, self.ntypes - 1, self.descriptor.dim_emb)

    def get_fitting_net(self) -> Fitting:
        """Return the fitting network, or the denoise net if no fitting net is set."""
        return (
            self.fitting_net
            if self.fitting_net is not None
            else self.coord_denoise_net
        )

    def get_fitting_output_def(self) -> FittingOutputDef:
        """Return the output definition of the fitting (or denoise) network."""
        return (
            self.fitting_net.output_def()
            if self.fitting_net is not None
            else self.coord_denoise_net.output_def()
        )

    def get_rcut(self) -> float:
        """Return the cut-off radius."""
        return self.rcut

    def get_sel(self) -> int:
        """Return the number of selected neighbors."""
        return self.sel

    def distinguish_types(self) -> bool:
        """Return whether the model distinguishes atom types in the neighbor list."""
        return self.type_split

    def forward_atomic(
        self,
        extended_coord,
        extended_atype,
        nlist,
        mapping: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Return the atomic prediction.

        Parameters
        ----------
        extended_coord
            coordinates in the extended region
        extended_atype
            atomic types in the extended region
        nlist
            neighbor list, shape nf x nloc x nsel
        mapping
            maps the extended indices to local indices
        do_atomic_virial
            whether to compute the atomic virial

        Returns
        -------
        result_dict
            the result dict, defined by the fitting net output def.
        """
        nframes, nloc, nnei = nlist.shape
        atype = extended_atype[:, :nloc]
        if self.grad_force:
            # Forces are obtained by differentiating w.r.t. the coordinates.
            extended_coord.requires_grad_(True)
        descriptor, env_mat, diff, rot_mat, sw = self.descriptor(
            extended_coord,
            extended_atype,
            nlist,
            mapping=mapping,
        )
        assert descriptor is not None
        # Energy, force
        if self.fitting_net is not None:
            fit_ret = self.fitting_net(
                descriptor, atype, atype_tebd=None, rot_mat=rot_mat)
        # Denoise
        else:
            nlist_list = [nlist]
            if not self.split_nlist:
                nnei_mask = nlist != -1
            elif self.combination:
                nnei_mask = []
                for item in nlist_list:
                    nnei_mask_item = item != -1
                    nnei_mask.append(nnei_mask_item)
            else:
                env_mat = env_mat[-1]
                diff = diff[-1]
                nnei_mask = nlist_list[-1] != -1
            fit_ret = self.coord_denoise_net(
                env_mat, diff, nnei_mask, descriptor, sw)
        return fit_ret
```
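For orientation, a sketch of how `DPAtomicModel` might be exercised. All configuration values (`rcut`, `rcut_smth`, `sel`, `neuron`, the `sampled` frames) are illustrative placeholders, not values taken from this commit:

```python
import torch

# Illustrative options only; the valid keys depend on the chosen
# descriptor and fitting-net implementations.
descriptor = {"type": "se_e2_a", "rcut": 6.0, "rcut_smth": 0.5, "sel": [46, 92]}
fitting_net = {"type": "ener", "neuron": [240, 240, 240]}

model = DPAtomicModel(
    descriptor=descriptor,
    fitting_net=fitting_net,
    type_map=["O", "H"],
    sampled=sampled,  # pre-sampled frames for the statistics (assumed given)
)

# Shapes: nf frames, nloc local atoms, nall extended atoms, nsel neighbors.
nf, nloc, nall, nsel = 1, 8, 16, 138
extended_coord = torch.rand(nf, nall, 3)
extended_atype = torch.randint(0, 2, (nf, nall))
nlist = torch.randint(-1, nloc, (nf, nloc, nsel))

out = model.forward_atomic(extended_coord, extended_atype, nlist)
print(model.get_rcut(), model.distinguish_types(), list(out.keys()))
```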