Skip to content

Commit

Permalink
Add alpha_model option to modules and decouple model files from repo.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Jul 31, 2024
1 parent aa337f0 commit 417014e
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 35 deletions.
103 changes: 103 additions & 0 deletions emle/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#######################################################################
# EMLE-Engine: https://github.com/chemle/emle-engine
#
# Copyright: 2023-2024
#
# Authors: Lester Hedges <[email protected]>
# Kirill Zinovjev <[email protected]>
#
# EMLE-Engine is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# EMLE-Engine is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with EMLE-Engine. If not, see <http://www.gnu.org/licenses/>.
#####################################################################

"""EMLE utilities."""

__author__ = "Lester Hedges"
__email__ = "[email protected]"


def _fetch_resources():
"""Fetch resources required for EMLE."""

import os as _os
import pygit2 as _pygit2

# Create the name for the expected resources directory.
resource_dir = _os.path.dirname(_os.path.abspath(__file__)) + "/resources"

# Check if the resources directory exists.
if not _os.path.exists(resource_dir):
# If it doesn't, clone the resources repository.
print("Downloading EMLE resources...")
_pygit2.clone_repository(
"https://github.com/chemle/emle-models.git", resource_dir
)
else:
# If it does, open the repository and pull the latest changes.
repo = _pygit2.Repository(resource_dir)
_pull(repo)


def _pull(repo, remote_name="origin", branch="main"):
"""
Pull the latest changes from the remote repository.
Taken from:
https://github.com/MichaelBoselowitz/pygit2-examples/blob/master/examples.py
"""

import pygit2 as _pygit2

for remote in repo.remotes:
if remote.name == remote_name:
remote.fetch()
remote_master_id = repo.lookup_reference(
"refs/remotes/origin/%s" % (branch)
).target
merge_result, _ = repo.merge_analysis(remote_master_id)
# Up to date, do nothing
if merge_result & _pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
return
# We can just fastforward
elif merge_result & _pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
print("Updating EMLE resources...")
repo.checkout_tree(repo.get(remote_master_id))
try:
master_ref = repo.lookup_reference("refs/heads/%s" % (branch))
master_ref.set_target(remote_master_id)
except KeyError:
repo.create_branch(branch, repo.get(remote_master_id))
repo.head.set_target(remote_master_id)
elif merge_result & _pygit2.GIT_MERGE_ANALYSIS_NORMAL:
print("Updating EMLE resources...")
repo.merge(remote_master_id)

if repo.index.conflicts is not None:
for conflict in repo.index.conflicts:
print("Conflicts found in:", conflict[0].path)
raise AssertionError("Conflicts!")

user = repo.default_signature
tree = repo.index.write_tree()
commit = repo.create_commit(
"HEAD",
user,
user,
"Merge!",
tree,
[repo.head.target, remote_master_id],
)
# We need to do this or git CLI will think we are still merging.
repo.state_cleanup()
else:
raise AssertionError("Unknown merge analysis result")
26 changes: 18 additions & 8 deletions emle/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,15 @@ class EMLECalculator:

# Class attributes.

# Get the directory of this module file.
_module_dir = _os.path.dirname(_os.path.abspath(__file__))
# Store the expected path to the resources directory.
_resource_dir = _os.path.join(
_os.path.dirname(_os.path.abspath(__file__)), "resources"
)

# Create the name of the default model file for each alpha mode.
_default_models = {
"species": _os.path.join(_module_dir, "/resources/emle_qm7_aev.mat"),
"reference": _os.path.join(_module_dir, "/resources/emle_qm7_aev_alphagpr.mat"),
"species": _os.path.join(_resource_dir, "emle_qm7_aev.mat"),
"reference": _os.path.join(_resource_dir, "emle_qm7_aev_alphagpr.mat"),
}

# Store the list of supported species.
Expand Down Expand Up @@ -444,6 +446,11 @@ def __init__(
the calculator.
"""

from ._utils import _fetch_resources

# Fetch or update the resources.
_fetch_resources()

# Validate input.

# First handle the logger.
Expand Down Expand Up @@ -489,10 +496,13 @@ def __init__(

# Validate the alpha mode first so that we can choose an appropriate
# default model.
if not isinstance(alpha_mode, str):
msg = "'alpha_mode' must be of type 'str'"
_logger.error(msg)
raise TypeError(msg)
if alpha_mode is not None:
if not isinstance(alpha_mode, str):
msg = "'alpha_mode' must be of type 'str'"
_logger.error(msg)
raise TypeError(msg)
else:
alpha_mode = "species"

# Convert to lower case and strip whitespace.
alpha_mode = alpha_mode.lower().replace(" ", "")
Expand Down
Loading

0 comments on commit 417014e

Please sign in to comment.