Skip to content

Commit

Permalink
better compatibility with ase versions
Browse files Browse the repository at this point in the history
  • Loading branch information
bowen-bd committed Feb 7, 2024
1 parent d5f5f9c commit 5c42e56
Showing 1 changed file with 27 additions and 25 deletions.
52 changes: 27 additions & 25 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,6 @@
from ase.io import Trajectory
from ase.optimize.optimize import Optimizer

try:
from ase.filters import Filter, FrechetCellFilter
except ImportError:
print(
"We recommend using ase's unreleased FrechetCellFilter over ExpCellFilter for "
"CHGNet structural relaxation. ExpCellFilter has a bug in its calculation "
"of cell gradients which was fixed in FrechetCellFilter. Otherwise the two "
"are identical. ExpCellFilter was kept only for backwards compatibility and "
"should no longer be used. Run pip install git+https://gitlab.com/ase/ase to "
"install from main branch."
)

# We would like to thank M3GNet develop team for this module
# source: https://github.com/materialsvirtuallab/m3gnet

Expand Down Expand Up @@ -223,7 +211,7 @@ def relax(
fmax: float | None = 0.1,
steps: int | None = 500,
relax_cell: bool | None = True,
ase_filter: str | Filter = FrechetCellFilter,
ase_filter: str | None = "FrechetCellFilter",
save_path: str | None = None,
loginterval: int | None = 1,
crystal_feas_save_path: str | None = None,
Expand Down Expand Up @@ -260,23 +248,37 @@ def relax(
dict[str, Structure | TrajectoryObserver]:
A dictionary with 'final_structure' and 'trajectory'.
"""
try:
import ase.filters as filter_classes
from ase.filters import Filter

except ImportWarning:
import ase.constraints as filter_classes
from ase.constraints import Filter

if ase_filter == "FrechetCellFilter":
ase_filter = "ExpCellFilter"
print(
"Failed to import ase.filters. Default filter to ExpCellFilter. "
"For better relaxation accuracy with the new FrechetCellFilter,"
"Run pip install git+https://gitlab.com/ase/ase"
)
valid_filter_names = [
name
for name, cls in inspect.getmembers(filter_classes, inspect.isclass)
if issubclass(cls, Filter)
]

if isinstance(ase_filter, str):
try:
import ase.filters

ase_filter = getattr(ase.filters, ase_filter)
except AttributeError as exc:
valid_filter_names = [
name
for name, cls in inspect.getmembers(ase.filters, inspect.isclass)
if issubclass(cls, Filter)
]
if ase_filter in valid_filter_names:
ase_filter = getattr(filter_classes, ase_filter)
else:
raise ValueError(
f"Invalid {ase_filter=}, must be one of {valid_filter_names}. "
) from exc
)

if isinstance(atoms, Structure):
atoms = AseAtomsAdaptor().get_atoms(atoms)
# atoms = atoms.to_ase_atoms()

atoms.calc = self.calculator # assign model used to predict forces

Expand Down

0 comments on commit 5c42e56

Please sign in to comment.