Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add assign_magmoms=True keyword to StructOptimizer.relax() #124

Merged
merged 4 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.1
rev: v0.2.2
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -46,7 +46,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.0.0-alpha.2
rev: v9.0.0-beta.0
hooks:
- id: eslint
types: [file]
Expand Down
31 changes: 18 additions & 13 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def relax(
loginterval: int | None = 1,
crystal_feas_save_path: str | None = None,
verbose: bool = True,
assign_magmoms: bool = True,
Copy link
Collaborator Author

@janosh janosh Feb 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BowenD-UCB let me know if assign_magmoms makes sense or you'd like a different name or API here

**kwargs,
) -> dict[str, Structure | TrajectoryObserver]:
"""Relax the Structure/Atoms until maximum force is smaller than fmax.
Expand All @@ -242,6 +243,8 @@ def relax(
Default = None
verbose (bool): Whether to print the output of the ASE optimizer.
Default = True
assign_magmoms (bool): Whether to assign magnetic moments to the final
structure. Default = True
**kwargs: Additional parameters for the optimizer.

Returns:
Expand All @@ -260,8 +263,8 @@ def relax(
ase_filter = "ExpCellFilter"
print(
"Failed to import ase.filters. Default filter to ExpCellFilter. "
"For better relaxation accuracy with the new FrechetCellFilter,"
"Run pip install git+https://gitlab.com/ase/ase"
"For better relaxation accuracy with the new FrechetCellFilter, "
"run pip install git+https://gitlab.com/ase/ase"
)
valid_filter_names = [
name
Expand Down Expand Up @@ -291,7 +294,7 @@ def relax(

if relax_cell:
atoms = ase_filter(atoms)
optimizer = self.optimizer_class(atoms, **kwargs)
optimizer: Optimizer = self.optimizer_class(atoms, **kwargs)
optimizer.attach(obs, interval=loginterval)

if crystal_feas_save_path:
Expand All @@ -309,11 +312,13 @@ def relax(
if isinstance(atoms, Filter):
atoms = atoms.atoms
struct = AseAtomsAdaptor.get_structure(atoms)
for key in struct.site_properties:
struct.remove_site_property(property_name=key)
struct.add_site_property(
"magmom", [float(magmom) for magmom in atoms.get_magnetic_moments()]
)

if assign_magmoms:
for key in struct.site_properties:
struct.remove_site_property(property_name=key)
struct.add_site_property(
"magmom", [float(magmom) for magmom in atoms.get_magnetic_moments()]
)
return {"final_structure": struct, "trajectory": obs}


Expand All @@ -336,7 +341,7 @@ def __init__(self, atoms: Atoms) -> None:
self.atom_positions: list[np.ndarray] = []
self.cells: list[np.ndarray] = []

def __call__(self):
def __call__(self) -> None:
"""The logic for saving the properties of an Atoms during the relaxation."""
self.energies.append(self.compute_energy())
self.forces.append(self.atoms.get_forces())
Expand Down Expand Up @@ -792,7 +797,7 @@ def fit(
steps: int | None = 500,
verbose: bool | None = False,
**kwargs,
):
) -> None:
"""Relax the Structure/Atoms and fit the Birch-Murnaghan equation of state.

Args:
Expand Down Expand Up @@ -839,7 +844,7 @@ def fit(
self.bm.fit()
self.fitted = True

def get_bulk_modulus(self, unit: str = "eV/A^3"):
def get_bulk_modulus(self, unit: str = "eV/A^3") -> float:
"""Get the bulk modulus of from the fitted Birch-Murnaghan equation of state.

Args:
Expand All @@ -859,7 +864,7 @@ def get_bulk_modulus(self, unit: str = "eV/A^3"):
return self.bm.b0_GPa
raise NotImplementedError("unit has to be eV/A^3 or GPa")

def get_compressibility(self, unit: str = "A^3/eV"):
def get_compressibility(self, unit: str = "A^3/eV") -> float:
"""Get the bulk modulus of from the fitted Birch-Murnaghan equation of state.

Args:
Expand All @@ -878,6 +883,6 @@ def get_compressibility(self, unit: str = "A^3/eV"):
return 1 / self.bm.b0
if unit == "GPa^-1":
return 1 / self.bm.b0_GPa
if unit in ["Pa^-1", "m^2/N"]:
if unit in ("Pa^-1", "m^2/N"):
return 1 / (self.bm.b0_GPa * 1e9)
raise NotImplementedError("unit has to be one of A^3/eV, GPa^-1 Pa^-1 or m^2/N")
38 changes: 19 additions & 19 deletions site/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,29 @@
"changelog": "npx auto-changelog --package --output ../changelog.md --hide-credit --commit-limit false"
},
"devDependencies": {
"@sveltejs/adapter-static": "^2.0.3",
"@sveltejs/kit": "^1.27.2",
"@sveltejs/vite-plugin-svelte": "^2.4.6",
"@typescript-eslint/eslint-plugin": "^6.9.1",
"@typescript-eslint/parser": "^6.9.1",
"eslint": "^8.52.0",
"eslint-plugin-svelte": "^2.34.0",
"hastscript": "^8.0.0",
"@sveltejs/adapter-static": "^3.0.1",
"@sveltejs/kit": "^2.5.0",
"@sveltejs/vite-plugin-svelte": "^3.0.2",
"@typescript-eslint/eslint-plugin": "^7.0.1",
"@typescript-eslint/parser": "^7.0.1",
"eslint": "^8.56.0",
"eslint-plugin-svelte": "^2.35.1",
"hastscript": "^9.0.0",
"mdsvex": "^0.11.0",
"prettier": "^3.0.3",
"prettier-plugin-svelte": "^3.0.3",
"rehype-autolink-headings": "^7.0.0",
"prettier": "^3.2.5",
"prettier-plugin-svelte": "^3.2.1",
"rehype-autolink-headings": "^7.1.0",
"rehype-slug": "^6.0.0",
"svelte": "^4.2.2",
"svelte-check": "^3.5.2",
"svelte": "^4.2.11",
"svelte-check": "^3.6.4",
"svelte-multiselect": "^10.2.0",
"svelte-preprocess": "^5.0.4",
"svelte-toc": "^0.5.6",
"svelte-zoo": "^0.4.9",
"svelte2tsx": "^0.6.23",
"svelte-preprocess": "^5.1.3",
"svelte-toc": "^0.5.7",
"svelte-zoo": "^0.4.10",
"svelte2tsx": "^0.7.1",
"tslib": "^2.6.2",
"typescript": "^5.2.2",
"vite": "^4.5.0"
"typescript": "^5.3.3",
"vite": "^5.1.3"
},
"prettier": {
"semi": false,
Expand Down
4 changes: 2 additions & 2 deletions site/src/routes/+layout.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
$page.url.pathname === `${base}/api` ? `h1, h2, h3, h4` : `h2`
}):not(.toc-exclude)`

const file_routes = Object.keys(import.meta.glob(`./**/+page.{svx,svelte,md}`))
const file_routes = Object.keys(import.meta.glob(`./*/+page.{svx,svelte,md}`))
.filter((key) => !key.includes(`/[`))
.map((filename) => {
const parts = filename.split(`/`)
return `/` + parts.slice(1, -1).join(`/`)
})

const actions = file_routes.map((name) => {
const actions = [`/`, ...file_routes].map((name) => {
return { label: name, action: () => goto(`${base}${name.toLowerCase()}`) }
})
</script>
Expand Down
4 changes: 4 additions & 0 deletions site/src/routes/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,8 @@
:global(a:has(img[alt='Docs'])) {
display: none;
}
/* hide proprietary models */
:global(table tr:has(.proprietary)) {
display: none;
}
</style>
2 changes: 1 addition & 1 deletion site/vite.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import type { UserConfig } from 'vite'

// fetch latest Matbench Discovery metrics table at build time and save to src/ dir
await fetch(
`https://github.com/janosh/matbench-discovery/raw/main/site/src/figs/metrics-table.svelte`,
`https://github.com/janosh/matbench-discovery/raw/main/site/src/figs/metrics-table-uniq-protos.svelte`,
)
.then((res) => res.text())
.then((text) => {
Expand Down
29 changes: 21 additions & 8 deletions tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,27 @@

import os
import re
from typing import TYPE_CHECKING, Literal
from typing import Literal

import pytest
import torch
from ase.filters import ExpCellFilter, Filter, FrechetCellFilter
from pymatgen.core import Structure
from pytest import approx, mark, param

from chgnet.graph import CrystalGraphConverter
from chgnet.model import CHGNet, StructOptimizer

if TYPE_CHECKING:
from pymatgen.core import Structure


@pytest.mark.parametrize(
"algorithm, ase_filter", [("legacy", FrechetCellFilter), ("fast", ExpCellFilter)]
"algorithm, ase_filter, assign_magmoms",
[("legacy", FrechetCellFilter, True), ("fast", ExpCellFilter, False)],
)
def test_relaxation(
algorithm: Literal["legacy", "fast"], ase_filter: Filter, li_mn_o2: Structure
algorithm: Literal["legacy", "fast"],
ase_filter: Filter,
assign_magmoms: bool,
li_mn_o2: Structure,
) -> None:
chgnet = CHGNet.load()
converter = CrystalGraphConverter(
Expand All @@ -30,10 +32,21 @@ def test_relaxation(

chgnet.graph_converter = converter
relaxer = StructOptimizer(model=chgnet)
result = relaxer.relax(li_mn_o2, verbose=True, ase_filter=ase_filter)
result = relaxer.relax(
li_mn_o2, verbose=True, ase_filter=ase_filter, assign_magmoms=assign_magmoms
)
assert list(result) == ["final_structure", "trajectory"]
final_struct, traj = result["final_structure"], result["trajectory"]
assert isinstance(final_struct, Structure)
if assign_magmoms:
assert isinstance(final_struct.site_properties["magmom"], list)
assert len(final_struct.site_properties["magmom"]) == len(final_struct)
assert all(
isinstance(mm, float) for mm in final_struct.site_properties["magmom"]
)
else:
assert "magmom" not in final_struct.site_properties

traj = result["trajectory"]
# make sure trajectory has expected attributes
assert {*traj.__dict__} == {
*"atoms energies forces stresses magmoms atom_positions cells".split()
Expand Down
Loading