Add support for restart simulations.
lohedges committed Nov 9, 2023
1 parent 2cc854e commit 1e446d5
Showing 2 changed files with 52 additions and 9 deletions.
30 changes: 30 additions & 0 deletions bin/emle-server
@@ -32,6 +32,24 @@ from glob import glob

from emle.socket import Socket


# Copied from distutils.util.strtobool, which is deprecated.
def strtobool(val):
"""Convert a string representation of truth to true (1) or false (0).
True values are case insensitive 'y', 'yes', 't', 'true', 'on', and '1'.
false values are case insensitive 'n', 'no', 'f', 'false', 'off', and '0'.
Raises ValueError if 'val' is anything else.
"""
val = val.replace(" ", "").lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return True
elif val in ("n", "no", "f", "false", "off", "0"):
return False
else:
raise ValueError("invalid truth value %r" % (val,))


# Check whether any EMLE environment variables are set.
config = os.getenv("EMLE_CONFIG")
host = os.getenv("EMLE_HOST")
@@ -69,6 +87,10 @@ except:
interpolate_steps = None
qm_indices = os.getenv("EMLE_QM_INDICES")
sqm_theory = os.getenv("EMLE_SQM_THEORY")
try:
restart = strtobool(os.getenv("EMLE_RESTART"))
except:
restart = False
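
As an aside, a minimal sketch (not part of the commit) of how the strtobool helper above maps typical values of EMLE_RESTART; the sample values are illustrative only:

import os

def strtobool(val):
    # Mirrors the helper added above in bin/emle-server.
    val = val.replace(" ", "").lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return True
    elif val in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (val,))

os.environ["EMLE_RESTART"] = "Yes"            # illustrative value
print(strtobool(os.environ["EMLE_RESTART"]))  # True
print(strtobool("off"))                       # False
# strtobool("maybe")                          # raises ValueError

If EMLE_RESTART is unset, os.getenv() returns None, strtobool(None) raises AttributeError, and the bare except above falls back to restart = False.
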
try:
retries = int(os.getenv("EMLE_RETRIES"))
except:
@@ -103,6 +125,7 @@ env = {
"parm7": parm7,
"qm_indices": qm_indices,
"sqm_theory": sqm_theory,
"restart": restart,
"log": log,
}

@@ -228,6 +251,11 @@ parser.add_argument(
help="the semi-empirical theory to use for the QM region when using the SQM backend",
required=False,
)
parser.add_argument(
"--restart",
action=argparse.BooleanOptionalAction,
required=False,
)
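
For reference, argparse.BooleanOptionalAction (available in Python 3.9 and later) generates a paired --restart/--no-restart flag; a small sketch outside the commit:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--restart", action=argparse.BooleanOptionalAction, required=False)

print(parser.parse_args(["--restart"]).restart)     # True
print(parser.parse_args(["--no-restart"]).restart)  # False
print(parser.parse_args([]).restart)                # None when the flag is omitted
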
parser.add_argument(
"--log",
type=int,
@@ -448,6 +476,8 @@ if args["deepmd_model"] is not None:
args["deepmd_model"] = models

if args["external_backend"] is None:
if backend is None:
args["backend"] = "torchani"
print(f"Starting ML-MM server using {args['backend']} backend...")
else:
print(f"Starting ML-MM server using external backend...")
31 changes: 22 additions & 9 deletions emle/emle.py
@@ -364,6 +364,7 @@ def __init__(
sqm_theory="DFTB3",
lambda_interpolate=None,
interpolate_steps=None,
restart=False,
device=None,
log=1,
):
@@ -440,6 +441,10 @@ def __init__(
The QM theory to use when using the SQM backend. See the AmberTools
manual for the supported theory levels for your version of AmberTools.
restart : bool
Whether this is a restart simulation with sander. If True, then energies
are logged immediately.
device : str
The name of the device to be used by PyTorch. Options are "cpu"
or "cuda".
@@ -704,6 +709,13 @@ def __init__(
# Flag that delta-learning corrections will be applied.
self._is_delta = True

if restart is not None:
if not isinstance(restart, bool):
raise TypeError("'restart' must be of type 'bool'")
else:
restart = False
self._restart = restart
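
Restated outside the class, for illustration only, the type check added above (the _validate_restart name is hypothetical, not part of the commit):

def _validate_restart(restart):
    # Mirrors the validation block above.
    if restart is not None:
        if not isinstance(restart, bool):
            raise TypeError("'restart' must be of type 'bool'")
    else:
        restart = False
    return restart

print(_validate_restart(True))   # True
print(_validate_restart(None))   # False
# _validate_restart("yes")       # raises TypeError
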

# Validate the interpolation lambda parameter.
if lambda_interpolate is not None:
if self._backend == "rascal":
@@ -899,10 +911,11 @@ def __init__(
# Initialise the number of steps. (Calls to the calculator.)
self._step = 0

- # Flag whether that this is the first step since lambda has been set.
- # This is used to avoid writing duplicate energy records since sander
- # will call orca on startup, i.e. not just after each integration step.
- self._is_first_step = True
+ # Flag whether to skip logging the first call to the server. This is
+ # used to avoid writing duplicate energy records since sander will call
+ # orca on startup when not performing a restart simulation, i.e. not
+ # just after each integration step.
+ self._is_first_step = not self._restart

# Store the settings as a dictionary.
self._settings = {
@@ -917,6 +930,7 @@
"sqm_theory": sqm_theory,
"lambda_interpolate": lambda_interpolate,
"interpolate_steps": interpolate_steps,
"restart": restart,
"device": device,
"plugin_path": plugin_path,
"log": log,
@@ -1152,8 +1166,9 @@ def run(self, path=None):
if len(self._lambda_interpolate) == 1:
lam = self._lambda_interpolate[0]
else:
+ offset = int(not self._restart)
lam = self._lambda_interpolate[0] + (
- (self._step / (self._interpolate_steps - 1))
+ (self._step / (self._interpolate_steps - offset))
) * (self._lambda_interpolate[1] - self._lambda_interpolate[0])
if lam < 0.0:
lam = 0.0
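
A small numerical illustration of what the offset changes (the lam_at helper and the values are hypothetical; only the arithmetic mirrors the code above). On a fresh run offset is 1 and the denominator is interpolate_steps - 1; on a restart offset is 0 and the denominator is interpolate_steps:

def lam_at(step, interpolate_steps, lambda_interpolate, restart):
    # Hypothetical helper mirroring the interpolation arithmetic above.
    offset = int(not restart)
    return lambda_interpolate[0] + (
        step / (interpolate_steps - offset)
    ) * (lambda_interpolate[1] - lambda_interpolate[0])

print(lam_at(2, 5, [0.0, 1.0], restart=False))  # 0.5 (denominator 4)
print(lam_at(2, 5, [0.0, 1.0], restart=True))   # 0.4 (denominator 5)
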
@@ -1199,9 +1214,7 @@ def run(self, path=None):
f"#{'Step':>9}{'λ':>22}{'E(λ) (Eh)':>22}{'E(λ=0) (Eh)':>22}{'E(λ=1) (Eh)':>22}\n"
)
else:
- f.write(
-     f"#{'Step':>9}{'E_vac (Eh)':>22}{'E_tot (Eh)':>22}\n"
- )
+ f.write(f"#{'Step':>9}{'E_vac (Eh)':>22}{'E_tot (Eh)':>22}\n")
# Write the record.
if self._is_interpolate:
f.write(
@@ -1270,7 +1283,7 @@ def set_lambda_interpolate(self, lambda_interpolate):
self._lambda_interpolate = [lambda_interpolate]

# Reset the first step flag.
- self._is_first_step = True
+ self._is_first_step = not self._restart

def _get_E(self, charges_mm, xyz_qm_bohr, xyz_mm_bohr, s, chi):
"""
