diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5a50ba8d..63d189c1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.1 + rev: v0.2.2 hooks: - id: ruff args: [--fix] @@ -46,7 +46,7 @@ repos: - svelte - repo: https://github.com/pre-commit/mirrors-eslint - rev: v9.0.0-alpha.2 + rev: v9.0.0-beta.0 hooks: - id: eslint types: [file] diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index 97b9743a..35a8819b 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -216,6 +216,7 @@ def relax( loginterval: int | None = 1, crystal_feas_save_path: str | None = None, verbose: bool = True, + assign_magmoms: bool = True, **kwargs, ) -> dict[str, Structure | TrajectoryObserver]: """Relax the Structure/Atoms until maximum force is smaller than fmax. @@ -242,6 +243,8 @@ def relax( Default = None verbose (bool): Whether to print the output of the ASE optimizer. Default = True + assign_magmoms (bool): Whether to assign magnetic moments to the final + structure. Default = True **kwargs: Additional parameters for the optimizer. Returns: @@ -260,8 +263,8 @@ def relax( ase_filter = "ExpCellFilter" print( "Failed to import ase.filters. Default filter to ExpCellFilter. " - "For better relaxation accuracy with the new FrechetCellFilter," - "Run pip install git+https://gitlab.com/ase/ase" + "For better relaxation accuracy with the new FrechetCellFilter, " + "run pip install git+https://gitlab.com/ase/ase" ) valid_filter_names = [ name @@ -291,7 +294,7 @@ def relax( if relax_cell: atoms = ase_filter(atoms) - optimizer = self.optimizer_class(atoms, **kwargs) + optimizer: Optimizer = self.optimizer_class(atoms, **kwargs) optimizer.attach(obs, interval=loginterval) if crystal_feas_save_path: @@ -309,11 +312,13 @@ def relax( if isinstance(atoms, Filter): atoms = atoms.atoms struct = AseAtomsAdaptor.get_structure(atoms) - for key in struct.site_properties: - struct.remove_site_property(property_name=key) - struct.add_site_property( - "magmom", [float(magmom) for magmom in atoms.get_magnetic_moments()] - ) + + if assign_magmoms: + for key in struct.site_properties: + struct.remove_site_property(property_name=key) + struct.add_site_property( + "magmom", [float(magmom) for magmom in atoms.get_magnetic_moments()] + ) return {"final_structure": struct, "trajectory": obs} @@ -336,7 +341,7 @@ def __init__(self, atoms: Atoms) -> None: self.atom_positions: list[np.ndarray] = [] self.cells: list[np.ndarray] = [] - def __call__(self): + def __call__(self) -> None: """The logic for saving the properties of an Atoms during the relaxation.""" self.energies.append(self.compute_energy()) self.forces.append(self.atoms.get_forces()) @@ -792,7 +797,7 @@ def fit( steps: int | None = 500, verbose: bool | None = False, **kwargs, - ): + ) -> None: """Relax the Structure/Atoms and fit the Birch-Murnaghan equation of state. Args: @@ -839,7 +844,7 @@ def fit( self.bm.fit() self.fitted = True - def get_bulk_modulus(self, unit: str = "eV/A^3"): + def get_bulk_modulus(self, unit: str = "eV/A^3") -> float: """Get the bulk modulus of from the fitted Birch-Murnaghan equation of state. Args: @@ -859,7 +864,7 @@ def get_bulk_modulus(self, unit: str = "eV/A^3"): return self.bm.b0_GPa raise NotImplementedError("unit has to be eV/A^3 or GPa") - def get_compressibility(self, unit: str = "A^3/eV"): + def get_compressibility(self, unit: str = "A^3/eV") -> float: """Get the bulk modulus of from the fitted Birch-Murnaghan equation of state. Args: @@ -878,6 +883,6 @@ def get_compressibility(self, unit: str = "A^3/eV"): return 1 / self.bm.b0 if unit == "GPa^-1": return 1 / self.bm.b0_GPa - if unit in ["Pa^-1", "m^2/N"]: + if unit in ("Pa^-1", "m^2/N"): return 1 / (self.bm.b0_GPa * 1e9) raise NotImplementedError("unit has to be one of A^3/eV, GPa^-1 Pa^-1 or m^2/N") diff --git a/site/package.json b/site/package.json index 48cb61a4..0975d0a1 100644 --- a/site/package.json +++ b/site/package.json @@ -15,29 +15,29 @@ "changelog": "npx auto-changelog --package --output ../changelog.md --hide-credit --commit-limit false" }, "devDependencies": { - "@sveltejs/adapter-static": "^2.0.3", - "@sveltejs/kit": "^1.27.2", - "@sveltejs/vite-plugin-svelte": "^2.4.6", - "@typescript-eslint/eslint-plugin": "^6.9.1", - "@typescript-eslint/parser": "^6.9.1", - "eslint": "^8.52.0", - "eslint-plugin-svelte": "^2.34.0", - "hastscript": "^8.0.0", + "@sveltejs/adapter-static": "^3.0.1", + "@sveltejs/kit": "^2.5.0", + "@sveltejs/vite-plugin-svelte": "^3.0.2", + "@typescript-eslint/eslint-plugin": "^7.0.1", + "@typescript-eslint/parser": "^7.0.1", + "eslint": "^8.56.0", + "eslint-plugin-svelte": "^2.35.1", + "hastscript": "^9.0.0", "mdsvex": "^0.11.0", - "prettier": "^3.0.3", - "prettier-plugin-svelte": "^3.0.3", - "rehype-autolink-headings": "^7.0.0", + "prettier": "^3.2.5", + "prettier-plugin-svelte": "^3.2.1", + "rehype-autolink-headings": "^7.1.0", "rehype-slug": "^6.0.0", - "svelte": "^4.2.2", - "svelte-check": "^3.5.2", + "svelte": "^4.2.11", + "svelte-check": "^3.6.4", "svelte-multiselect": "^10.2.0", - "svelte-preprocess": "^5.0.4", - "svelte-toc": "^0.5.6", - "svelte-zoo": "^0.4.9", - "svelte2tsx": "^0.6.23", + "svelte-preprocess": "^5.1.3", + "svelte-toc": "^0.5.7", + "svelte-zoo": "^0.4.10", + "svelte2tsx": "^0.7.1", "tslib": "^2.6.2", - "typescript": "^5.2.2", - "vite": "^4.5.0" + "typescript": "^5.3.3", + "vite": "^5.1.3" }, "prettier": { "semi": false, diff --git a/site/src/routes/+layout.svelte b/site/src/routes/+layout.svelte index 91cc8b2c..e0ee939e 100644 --- a/site/src/routes/+layout.svelte +++ b/site/src/routes/+layout.svelte @@ -12,14 +12,14 @@ $page.url.pathname === `${base}/api` ? `h1, h2, h3, h4` : `h2` }):not(.toc-exclude)` - const file_routes = Object.keys(import.meta.glob(`./**/+page.{svx,svelte,md}`)) + const file_routes = Object.keys(import.meta.glob(`./*/+page.{svx,svelte,md}`)) .filter((key) => !key.includes(`/[`)) .map((filename) => { const parts = filename.split(`/`) return `/` + parts.slice(1, -1).join(`/`) }) - const actions = file_routes.map((name) => { + const actions = [`/`, ...file_routes].map((name) => { return { label: name, action: () => goto(`${base}${name.toLowerCase()}`) } }) diff --git a/site/src/routes/+page.svelte b/site/src/routes/+page.svelte index 56f39a5b..7e2c6975 100644 --- a/site/src/routes/+page.svelte +++ b/site/src/routes/+page.svelte @@ -20,4 +20,8 @@ :global(a:has(img[alt='Docs'])) { display: none; } + /* hide proprietary models */ + :global(table tr:has(.proprietary)) { + display: none; + } diff --git a/site/vite.config.ts b/site/vite.config.ts index e4122deb..c0cd4a6d 100644 --- a/site/vite.config.ts +++ b/site/vite.config.ts @@ -4,7 +4,7 @@ import type { UserConfig } from 'vite' // fetch latest Matbench Discovery metrics table at build time and save to src/ dir await fetch( - `https://github.com/janosh/matbench-discovery/raw/main/site/src/figs/metrics-table.svelte`, + `https://github.com/janosh/matbench-discovery/raw/main/site/src/figs/metrics-table-uniq-protos.svelte`, ) .then((res) => res.text()) .then((text) => { diff --git a/tests/test_relaxation.py b/tests/test_relaxation.py index e25e84d7..4f85f6f2 100644 --- a/tests/test_relaxation.py +++ b/tests/test_relaxation.py @@ -2,25 +2,27 @@ import os import re -from typing import TYPE_CHECKING, Literal +from typing import Literal import pytest import torch from ase.filters import ExpCellFilter, Filter, FrechetCellFilter +from pymatgen.core import Structure from pytest import approx, mark, param from chgnet.graph import CrystalGraphConverter from chgnet.model import CHGNet, StructOptimizer -if TYPE_CHECKING: - from pymatgen.core import Structure - @pytest.mark.parametrize( - "algorithm, ase_filter", [("legacy", FrechetCellFilter), ("fast", ExpCellFilter)] + "algorithm, ase_filter, assign_magmoms", + [("legacy", FrechetCellFilter, True), ("fast", ExpCellFilter, False)], ) def test_relaxation( - algorithm: Literal["legacy", "fast"], ase_filter: Filter, li_mn_o2: Structure + algorithm: Literal["legacy", "fast"], + ase_filter: Filter, + assign_magmoms: bool, + li_mn_o2: Structure, ) -> None: chgnet = CHGNet.load() converter = CrystalGraphConverter( @@ -30,10 +32,21 @@ def test_relaxation( chgnet.graph_converter = converter relaxer = StructOptimizer(model=chgnet) - result = relaxer.relax(li_mn_o2, verbose=True, ase_filter=ase_filter) + result = relaxer.relax( + li_mn_o2, verbose=True, ase_filter=ase_filter, assign_magmoms=assign_magmoms + ) assert list(result) == ["final_structure", "trajectory"] + final_struct, traj = result["final_structure"], result["trajectory"] + assert isinstance(final_struct, Structure) + if assign_magmoms: + assert isinstance(final_struct.site_properties["magmom"], list) + assert len(final_struct.site_properties["magmom"]) == len(final_struct) + assert all( + isinstance(mm, float) for mm in final_struct.site_properties["magmom"] + ) + else: + assert "magmom" not in final_struct.site_properties - traj = result["trajectory"] # make sure trajectory has expected attributes assert {*traj.__dict__} == { *"atoms energies forces stresses magmoms atom_positions cells".split()