Skip to content

Commit

Permalink
Add workaround for MACE/MACE-OFF version self-consistency issue.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Dec 10, 2024
1 parent 7c37d11 commit 9dab879
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 7 deletions.
19 changes: 15 additions & 4 deletions emle/models/_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,19 +199,30 @@ def __init__(
f"Unsupported MACE model: '{mace_model}'. Available MACE-OFF23 models are "
"'mace-off23-small', 'mace-off23-medium', 'mace-off23-large'"
)
self._mace = _mace_off(model=size, device=device, return_raw_model=True)
source_model = _mace_off(model=size, device=device, return_raw_model=True)
else:
# Assuming that the model is a local model.
if _os.path.exists(mace_model):
self._mace = _torch.load(mace_model, map_location=device)
source_model = _torch.load(mace_model, map_location=device)
else:
raise FileNotFoundError(f"MACE model file not found: {mace_model}")
else:
# If no MACE model is provided, use the default MACE-OFF23(S) model.
self._mace = _mace_off(model="small", device=device, return_raw_model=True)
source_model = _mace_off(model="small", device=device, return_raw_model=True)

from mace.tools.scripts_utils import extract_config_mace_model

# Extract the config from the model.
config = extract_config_mace_model(source_model)

# Create the target model.
target_model = source_model.__class__(**config).to(device)

# Load the state dict.
target_model.load_state_dict(source_model.state_dict())

# Compile the model.
self._mace = _e3nn_jit.compile(self._mace).to(self._dtype)
self._mace = _e3nn_jit.compile(target_model).to(self._dtype)

# Create the z_table of the MACE model.
self._z_table = [int(z.item()) for z in self._mace.atomic_numbers]
Expand Down
2 changes: 1 addition & 1 deletion environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ dependencies:
- torchani
- xtb-python
- pip:
- mace-torch < 0.3.9
- mace-torch
2 changes: 1 addition & 1 deletion environment_rascal.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ dependencies:
- xtb-python
- pip:
- git+https://github.com/lab-cosmo/librascal.git
- mace-torch < 0.3.9
- mace-torch
2 changes: 1 addition & 1 deletion environment_sire.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ dependencies:
- torchani
- xtb-python
- pip:
- mace-torch < 0.3.9
- mace-torch

0 comments on commit 9dab879

Please sign in to comment.