Skip to content

Commit

Permalink
switch workplace
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jan 22, 2024
1 parent 5939f3a commit 964be79
Show file tree
Hide file tree
Showing 4 changed files with 740 additions and 74 deletions.
30 changes: 23 additions & 7 deletions deepmd_pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from .se_atten import analyze_descrpt
from .se_atten import DescrptBlockSeAtten
from deepmd_pt.model.network.mlp import EmbdLayer
from IPython import embed

@Descriptor.register("dpa1")
@Descriptor.register("se_atten")
Expand All @@ -41,14 +43,17 @@ def __init__(
post_ln=True,
ffn=False,
ffn_embed_dim=1024,
activation="tanh",
activation_function="tanh",
precision: str = "float64",
resnet_dt: bool = False,
scaling_factor=1.0,
head_num=1,
normalize=True,
temperature=None,
return_rot=False,
concat_output_tebd: bool = True,
type: Optional[str] = None,
old_impl: bool = False,
):
super(DescrptDPA1, self).__init__()
del type
Expand All @@ -63,17 +68,23 @@ def __init__(
attn_layer=attn_layer,
attn_dotr=attn_dotr,
attn_mask=attn_mask,
post_ln=post_ln,
ffn=ffn,
ffn_embed_dim=ffn_embed_dim,
activation=activation,
activation_function=activation_function,
precision=precision,
resnet_dt=resnet_dt,
scaling_factor=scaling_factor,
head_num=head_num,
normalize=normalize,
temperature=temperature,
return_rot=return_rot,
old_impl=old_impl,
)
self.type_embedding = TypeEmbedNet(ntypes, tebd_dim)
self.type_embedding_old = None
self.type_embedding = None
self.old_impl = old_impl
if self.old_impl:
self.type_embedding_old = TypeEmbedNet(ntypes, tebd_dim)
else:
self.type_embedding = EmbdLayer(ntypes, tebd_dim, padding=True, precision=precision)
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd

Expand Down Expand Up @@ -147,7 +158,12 @@ def forward(
del mapping
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
g1_ext = self.type_embedding(extended_atype)
if self.old_impl:
assert self.type_embedding_old is not None
g1_ext = self.type_embedding_old(extended_atype)
else:
assert self.type_embedding is not None
g1_ext = self.type_embedding(extended_atype)
g1_inp = g1_ext[:,:nloc,:]
g1, env_mat, diff, rot_mat, sw = self.se_atten(
nlist,
Expand Down
Loading

0 comments on commit 964be79

Please sign in to comment.