From 417014edc2e3590353059042ac40eeeda5d2d18b Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Wed, 31 Jul 2024 12:38:39 +0100 Subject: [PATCH] Add alpha_model option to modules and decouple model files from repo. --- emle/_utils.py | 103 ++++++++++++++++++++++++++++++++++++++ emle/calculator.py | 26 +++++++--- emle/models.py | 122 +++++++++++++++++++++++++++++++++++---------- 3 files changed, 216 insertions(+), 35 deletions(-) create mode 100644 emle/_utils.py diff --git a/emle/_utils.py b/emle/_utils.py new file mode 100644 index 0000000..42d3b7e --- /dev/null +++ b/emle/_utils.py @@ -0,0 +1,103 @@ +####################################################################### +# EMLE-Engine: https://github.com/chemle/emle-engine +# +# Copyright: 2023-2024 +# +# Authors: Lester Hedges +# Kirill Zinovjev +# +# EMLE-Engine is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# EMLE-Engine is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with EMLE-Engine. If not, see . +##################################################################### + +"""EMLE utilities.""" + +__author__ = "Lester Hedges" +__email__ = "lester.hedges@gmail.com" + + +def _fetch_resources(): + """Fetch resources required for EMLE.""" + + import os as _os + import pygit2 as _pygit2 + + # Create the name for the expected resources directory. + resource_dir = _os.path.dirname(_os.path.abspath(__file__)) + "/resources" + + # Check if the resources directory exists. + if not _os.path.exists(resource_dir): + # If it doesn't, clone the resources repository. + print("Downloading EMLE resources...") + _pygit2.clone_repository( + "https://github.com/chemle/emle-models.git", resource_dir + ) + else: + # If it does, open the repository and pull the latest changes. + repo = _pygit2.Repository(resource_dir) + _pull(repo) + + +def _pull(repo, remote_name="origin", branch="main"): + """ + Pull the latest changes from the remote repository. + + Taken from: + https://github.com/MichaelBoselowitz/pygit2-examples/blob/master/examples.py + """ + + import pygit2 as _pygit2 + + for remote in repo.remotes: + if remote.name == remote_name: + remote.fetch() + remote_master_id = repo.lookup_reference( + "refs/remotes/origin/%s" % (branch) + ).target + merge_result, _ = repo.merge_analysis(remote_master_id) + # Up to date, do nothing + if merge_result & _pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE: + return + # We can just fastforward + elif merge_result & _pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD: + print("Updating EMLE resources...") + repo.checkout_tree(repo.get(remote_master_id)) + try: + master_ref = repo.lookup_reference("refs/heads/%s" % (branch)) + master_ref.set_target(remote_master_id) + except KeyError: + repo.create_branch(branch, repo.get(remote_master_id)) + repo.head.set_target(remote_master_id) + elif merge_result & _pygit2.GIT_MERGE_ANALYSIS_NORMAL: + print("Updating EMLE resources...") + repo.merge(remote_master_id) + + if repo.index.conflicts is not None: + for conflict in repo.index.conflicts: + print("Conflicts found in:", conflict[0].path) + raise AssertionError("Conflicts!") + + user = repo.default_signature + tree = repo.index.write_tree() + commit = repo.create_commit( + "HEAD", + user, + user, + "Merge!", + tree, + [repo.head.target, remote_master_id], + ) + # We need to do this or git CLI will think we are still merging. + repo.state_cleanup() + else: + raise AssertionError("Unknown merge analysis result") diff --git a/emle/calculator.py b/emle/calculator.py index 7cbc1c2..9d3bedd 100644 --- a/emle/calculator.py +++ b/emle/calculator.py @@ -221,13 +221,15 @@ class EMLECalculator: # Class attributes. - # Get the directory of this module file. - _module_dir = _os.path.dirname(_os.path.abspath(__file__)) + # Store the expected path to the resources directory. + _resource_dir = _os.path.join( + _os.path.dirname(_os.path.abspath(__file__)), "resources" + ) # Create the name of the default model file for each alpha mode. _default_models = { - "species": _os.path.join(_module_dir, "/resources/emle_qm7_aev.mat"), - "reference": _os.path.join(_module_dir, "/resources/emle_qm7_aev_alphagpr.mat"), + "species": _os.path.join(_resource_dir, "emle_qm7_aev.mat"), + "reference": _os.path.join(_resource_dir, "emle_qm7_aev_alphagpr.mat"), } # Store the list of supported species. @@ -444,6 +446,11 @@ def __init__( the calculator. """ + from ._utils import _fetch_resources + + # Fetch or update the resources. + _fetch_resources() + # Validate input. # First handle the logger. @@ -489,10 +496,13 @@ def __init__( # Validate the alpha mode first so that we can choose an appropriate # default model. - if not isinstance(alpha_mode, str): - msg = "'alpha_mode' must be of type 'str'" - _logger.error(msg) - raise TypeError(msg) + if alpha_mode is not None: + if not isinstance(alpha_mode, str): + msg = "'alpha_mode' must be of type 'str'" + _logger.error(msg) + raise TypeError(msg) + else: + alpha_mode = "species" # Convert to lower case and strip whitespace. alpha_mode = alpha_mode.lower().replace(" ", "") diff --git a/emle/models.py b/emle/models.py index 3b2c258..435c7cc 100644 --- a/emle/models.py +++ b/emle/models.py @@ -62,13 +62,23 @@ class EMLE(_torch.nn.Module): embedding. """ - def __init__(self, device=None, dtype=None, create_aev_calculator=True): + def __init__( + self, alpha_mode="species", device=None, dtype=None, create_aev_calculator=True + ): """ Constructor Parameters ---------- + alpha_mode: str + How atomic polarizabilities are calculated. + "species": + one volume scaling factor is used for each species + "reference": + scaling factors are obtained with GPR using the values learned + for each reference environment + device: torch.device The device on which to run the model. @@ -82,11 +92,27 @@ def __init__(self, device=None, dtype=None, create_aev_calculator=True): # Call the base class constructor. super().__init__() - # Get the directory of this module file. - module_dir = _os.path.dirname(_os.path.abspath(__file__)) + from ._utils import _fetch_resources + + # Fetch or update the resources. + _fetch_resources() - # Create the name of the default model file. - model = _os.path.join(module_dir, "emle_qm7_aev_masked.mat") + # Store the expected path to the resources directory. + resource_dir = _os.path.join( + _os.path.dirname(_os.path.abspath(__file__)), "resources" + ) + + if not isinstance(alpha_mode, str): + raise TypeError("'alpha_mode' must be of type 'str'") + if alpha_mode not in ["species", "reference"]: + raise ValueError("'alpha_mode' must be 'species' or 'reference'") + self._alpha_mode = alpha_mode + + # Choose the model based on the alpha_mode. + if alpha_mode == "species": + model = _os.path.join(resource_dir, "emle_qm7_aev.mat") + else: + model = _os.path.join(resource_dir, "emle_qm7_aev_alphagpr.mat") if device is not None: if not isinstance(device, _torch.device): @@ -132,13 +158,16 @@ def __init__(self, device=None, dtype=None, create_aev_calculator=True): q_core = _torch.tensor(params["q_core"], dtype=dtype, device=device) a_QEq = _torch.tensor(params["a_QEq"], dtype=dtype, device=device) a_Thole = _torch.tensor(params["a_Thole"], dtype=dtype, device=device) - k_Z = _torch.tensor(params["k_Z"], dtype=dtype, device=device) + if self._alpha_mode == "species": + k = _torch.tensor(params["k_Z"], dtype=dtype, device=device) + else: + k = _torch.tensor(params["sqrtk_ref"], dtype=dtype, device=device) q_total = _torch.tensor( params.get("total_charge", 0), dtype=dtype, device=device ) # Extract the reference features. - ref_features = _torch.tensor(params["ref_soap"], dtype=dtype, device=device) + ref_features = _torch.tensor(params["ref_aev"], dtype=dtype, device=device) # Extract the reference values for the MBIS valence shell widths. ref_values_s = _torch.tensor(params["s_ref"], dtype=dtype, device=device) @@ -152,7 +181,7 @@ def __init__(self, device=None, dtype=None, create_aev_calculator=True): ref_shifted = ref_values_s - ref_mean_s[:, None] c_s = (Kinv @ ref_shifted[:, :, None]).squeeze() - # Exctract the reference values for the electronegativities. + # Extract the reference values for the electronegativities. ref_values_chi = _torch.tensor(params["chi_ref"], dtype=dtype, device=device) # Store additional attributes for the electronegativity GPR model. @@ -160,6 +189,15 @@ def __init__(self, device=None, dtype=None, create_aev_calculator=True): ref_shifted = ref_values_chi - ref_mean_chi[:, None] c_chi = (Kinv @ ref_shifted[:, :, None]).squeeze() + # Extract the reference values for the polarizabilities. + if self._alpha_mode == "reference": + ref_mean_k = _torch.sum(k, dim=1) / n_ref + ref_shifted = k - ref_mean_k[:, None] + c_k = (Kinv @ ref_shifted[:, :, None]).squeeze() + else: + ref_mean_k = _torch.empty(0, dtype=dtype, device=device) + c_k = _torch.empty(0, dtype=dtype, device=device) + # Store the current device. self._device = device @@ -169,7 +207,7 @@ def __init__(self, device=None, dtype=None, create_aev_calculator=True): self.register_buffer("_q_core", q_core) self.register_buffer("_a_QEq", a_QEq) self.register_buffer("_a_Thole", a_Thole) - self.register_buffer("_k_Z", k_Z) + self.register_buffer("_k", k) self.register_buffer("_q_total", q_total) self.register_buffer("_ref_features", ref_features) self.register_buffer("_n_ref", n_ref) @@ -179,6 +217,8 @@ def __init__(self, device=None, dtype=None, create_aev_calculator=True): self.register_buffer("_ref_mean_chi", ref_mean_chi) self.register_buffer("_c_s", c_s) self.register_buffer("_c_chi", c_chi) + self.register_buffer("_ref_mean_k", ref_mean_k) + self.register_buffer("_c_k", c_k) def to(self, *args, **kwargs): """ @@ -191,7 +231,7 @@ def to(self, *args, **kwargs): self._q_core = self._q_core.to(*args, **kwargs) self._a_QEq = self._a_QEq.to(*args, **kwargs) self._a_Thole = self._a_Thole.to(*args, **kwargs) - self._k_Z = self._k_Z.to(*args, **kwargs) + self._k = self._k.to(*args, **kwargs) self._q_total = self._q_total.to(*args, **kwargs) self._ref_features = self._ref_features.to(*args, **kwargs) self._n_ref = self._n_ref.to(*args, **kwargs) @@ -201,6 +241,8 @@ def to(self, *args, **kwargs): self._ref_mean_chi = self._ref_mean_chi.to(*args, **kwargs) self._c_s = self._c_s.to(*args, **kwargs) self._c_chi = self._c_chi.to(*args, **kwargs) + self._ref_mean_k = self._ref_mean_k.to(*args, **kwargs) + self._c_k = self._c_k.to(*args, **kwargs) # Check for a device type in args and update the device attribute. for arg in args: @@ -221,7 +263,7 @@ def cuda(self, **kwargs): self._q_core = self._q_core.cuda(**kwargs) self._a_QEq = self._a_QEq.cuda(**kwargs) self._a_Thole = self._a_Thole.cuda(**kwargs) - self._k_Z = self._k_Z.cuda(**kwargs) + self._k = self._k.cuda(**kwargs) self._q_total = self._q_total.cuda(**kwargs) self._ref_features = self._ref_features.cuda(**kwargs) self._n_ref = self._n_ref.cuda(**kwargs) @@ -231,6 +273,8 @@ def cuda(self, **kwargs): self._ref_mean_chi = self._ref_mean_chi.cuda(**kwargs) self._c_s = self._c_s.cuda(**kwargs) self._c_chi = self._c_chi.cuda(**kwargs) + self._ref_mean_k = self._ref_mean_k.cuda(**kwargs) + self._c_k = self._c_k.cuda(**kwargs) # Update the device attribute. self._device = self._species_map.device @@ -248,7 +292,7 @@ def cpu(self, **kwargs): self._q_core = self._q_core.cpu(**kwargs) self._a_QEq = self._a_QEq.cpu(**kwargs) self._a_Thole = self._a_Thole.cpu(**kwargs) - self._k_Z = self._k_Z.cpu(**kwargs) + self._k = self._k.cpu(**kwargs) self._q_total = self._q_total.cpu(**kwargs) self._ref_features = self._ref_features.cpu(**kwargs) self._n_ref = self._n_ref.cpu(**kwargs) @@ -258,6 +302,8 @@ def cpu(self, **kwargs): self._ref_mean_chi = self._ref_mean_chi.cpu(**kwargs) self._c_s = self._c_s.cpu(**kwargs) self._c_chi = self._c_chi.cpu(**kwargs) + self._ref_mean_k = self._ref_mean_k.cpu(**kwargs) + self._c_k = self._c_k.cpu(**kwargs) # Update the device attribute. self._device = self._species_map.device @@ -273,7 +319,7 @@ def double(self): self._q_core = self._q_core.double() self._a_QEq = self._a_QEq.double() self._a_Thole = self._a_Thole.double() - self._k_Z = self._k_Z.double() + self._k = self._k.double() self._q_total = self._q_total.double() self._ref_features = self._ref_features.double() self._ref_values_s = self._ref_values_s.double() @@ -282,6 +328,8 @@ def double(self): self._ref_mean_chi = self._ref_mean_chi.double() self._c_s = self._c_s.double() self._c_chi = self._c_chi.double() + self._ref_mean_k = self._ref_mean_k.double() + self._c_k = self._c_k.double() return self def float(self): @@ -293,7 +341,7 @@ def float(self): self._q_core = self._q_core.float() self._a_QEq = self._a_QEq.float() self._a_Thole = self._a_Thole.float() - self._k_Z = self._k_Z.float() + self._k = self._k.float() self._q_total = self._q_total.float() self._ref_features = self._ref_features.float() self._ref_values_s = self._ref_values_s.float() @@ -302,6 +350,8 @@ def float(self): self._ref_mean_chi = self._ref_mean_chi.float() self._c_s = self._c_s.float() self._c_chi = self._c_chi.float() + self._ref_mean_k = self._ref_mean_k.float() + self._c_k = self._c_k.float() return self def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): @@ -360,12 +410,15 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): # Compute the static energy. q_core = self._q_core[species_id] - k_Z = self._k_Z[species_id] + if self._alpha_mode == "species": + k = self._k[species_id] + else: + k = self._gpr(aev, self._ref_mean_k, self._c_k, species_id) r_data = self._get_r_data(xyz_qm_bohr) mesh_data = self._get_mesh_data(xyz_qm_bohr, xyz_mm_bohr, s) q = self._get_q(r_data, s, chi) q_val = q - q_core - mu_ind = self._get_mu_ind(r_data, mesh_data, charges_mm, s, q_val, k_Z) + mu_ind = self._get_mu_ind(r_data, mesh_data, charges_mm, s, q_val, k) vpot_q_core = self._get_vpot_q(q_core, mesh_data[0]) vpot_q_val = self._get_vpot_q(q_val, mesh_data[1]) vpot_static = vpot_q_core + vpot_q_val @@ -527,7 +580,7 @@ def _get_mu_ind( q, s, q_val, - k_Z, + k, ): """ Internal method, calculates induced atomic dipoles @@ -549,7 +602,7 @@ def _get_mu_ind( q_val: torch.Tensor (N_QM_ATOMS,) MBIS valence charges. - k_Z: torch.Tensor (N_Z) + k: torch.Tensor (N_Z) Scaling factors for polarizabilities. Returns @@ -558,7 +611,7 @@ def _get_mu_ind( result: torch.Tensor (N_ATOMS, 3) Array of induced dipoles """ - A = self._get_A_thole(r_data, s, q_val, k_Z) + A = self._get_A_thole(r_data, s, q_val, k) r = 1.0 / mesh_data[0] f1 = self._get_f1_slater(r, s[:, None] * 2.0) @@ -568,9 +621,7 @@ def _get_mu_ind( E_ind = mu_ind @ fields * 0.5 return mu_ind.reshape((-1, 3)) - def _get_A_thole( - self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, q_val, k_Z - ): + def _get_A_thole(self, r_data: Tuple[Tensor, Tensor, Tensor, Tensor], s, q_val, k): """ Internal method, generates A matrix for induced dipoles prediction (Eq. 20 in 10.1021/acs.jctc.2c00914) @@ -586,7 +637,7 @@ def _get_A_thole( q_val: torch.Tensor (N_ATOMS,) MBIS charges. - k_Z: torch.Tensor (N_Z) + k: torch.Tensor (N_Z) Scaling factors for polarizabilities. Returns @@ -596,7 +647,7 @@ def _get_A_thole( The A matrix for induced dipoles prediction. """ v = -60 * q_val * s**3 - alpha = v * k_Z + alpha = v * k alphap = alpha * self._a_Thole alphap_mat = alphap[:, None] * alphap[None, :] @@ -864,6 +915,7 @@ def _lambda5(au3): class ANI2xEMLE(EMLE): def __init__( self, + alpha_mode="species", model_index=None, ani2x_model=None, atomic_numbers=None, @@ -876,6 +928,14 @@ def __init__( Parameters ---------- + alpha_mode: str + How atomic polarizabilities are calculated. + "species": + one volume scaling factor is used for each species + "reference": + scaling factors are obtained with GPR using the values learned + for each reference environment + model_index: int The index of the model to use. If None, then the full 8 model ensemble will be used. @@ -927,7 +987,12 @@ def __init__( self._atomic_numbers = None # Call the base class constructor. - super().__init__(device=device, dtype=dtype, create_aev_calculator=False) + super().__init__( + alpha_mode=alpha_mode, + device=device, + dtype=dtype, + create_aev_calculator=False, + ) if ani2x_model is not None: # Add the base ANI2x model and ensemble. @@ -1124,12 +1189,15 @@ def forward(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm): # Compute the static energy. q_core = self._q_core[species_id] - k_Z = self._k_Z[species_id] + if self._alpha_mode == "species": + k = self._k[species_id] + else: + k = self._gpr(aev, self._ref_mean_k, self._c_k, species_id) r_data = self._get_r_data(xyz_qm_bohr) mesh_data = self._get_mesh_data(xyz_qm_bohr, xyz_mm_bohr, s) q = self._get_q(r_data, s, chi) q_val = q - q_core - mu_ind = self._get_mu_ind(r_data, mesh_data, charges_mm, s, q_val, k_Z) + mu_ind = self._get_mu_ind(r_data, mesh_data, charges_mm, s, q_val, k) vpot_q_core = self._get_vpot_q(q_core, mesh_data[0]) vpot_q_val = self._get_vpot_q(q_val, mesh_data[1]) vpot_static = vpot_q_core + vpot_q_val