From 84e8d55132b2242fad06f9b7f7706b35c1a7da7e Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 16 Nov 2024 21:14:27 +0000 Subject: [PATCH] CHGNetCalculator add kwarg task: PredTask = "efsm" (#215) --- .pre-commit-config.yaml | 10 +++++----- chgnet/model/dynamics.py | 36 +++++++++++++++++++++++------------- chgnet/model/model.py | 9 ++++++--- chgnet/trainer/trainer.py | 2 +- site/.gitignore | 1 - site/package.json | 36 ++++++++++++++++++------------------ site/src/routes/+page.svelte | 5 +---- site/vite.config.ts | 10 ---------- tests/test_md.py | 28 ++++++++++++++++++++++++++-- tests/test_relaxation.py | 4 +++- tests/test_trainer.py | 6 +++--- 11 files changed, 86 insertions(+), 61 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5f0a13d2..bc3acb2d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,10 @@ -default_stages: [commit] +default_stages: [pre-commit] default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.4 hooks: - id: ruff args: [--fix] @@ -28,11 +28,11 @@ repos: rev: v2.3.0 hooks: - id: codespell - stages: [commit, commit-msg] + stages: [pre-commit, commit-msg] args: [--check-filenames] - repo: https://github.com/kynan/nbstripout - rev: 0.7.1 + rev: 0.8.0 hooks: - id: nbstripout args: [--drop-empty-cells, --keep-output] @@ -48,7 +48,7 @@ repos: - svelte - repo: https://github.com/pre-commit/mirrors-eslint - rev: v9.12.0 + rev: v9.15.0 hooks: - id: eslint types: [file] diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index b5b01f97..8b03bf0b 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -33,6 +33,8 @@ from ase.optimize.optimize import Optimizer from typing_extensions import Self + from chgnet import PredTask + # We would like to thank M3GNet develop team for this module # source: https://github.com/materialsvirtuallab/m3gnet @@ -59,7 +61,7 @@ def __init__( *, use_device: str | None = None, check_cuda_mem: bool = False, - stress_weight: float | None = 1 / 160.21766208, + stress_weight: float = units.GPa, # GPa to eV/A^3 on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn", return_site_energies: bool = False, **kwargs, @@ -124,6 +126,7 @@ def calculate( atoms: Atoms | None = None, properties: list | None = None, system_changes: list | None = None, + task: PredTask = "efsm", ) -> None: """Calculate various properties of the atoms using CHGNet. @@ -133,6 +136,8 @@ def calculate( Default is all properties. system_changes (list | None): The changes made to the system. Default is all changes. + task (PredTask): The task to perform. One of "e", "ef", "em", "efs", "efsm". + Default = "efsm" """ properties = properties or all_properties system_changes = system_changes or all_changes @@ -147,23 +152,28 @@ def calculate( graph = self.model.graph_converter(structure) model_prediction = self.model.predict_graph( graph.to(self.device), - task="efsm", + task=task, return_crystal_feas=True, return_site_energies=self.return_site_energies, ) # Convert Result - factor = 1 if not self.model.is_intensive else structure.composition.num_atoms - self.results.update( - energy=model_prediction["e"] * factor, - forces=model_prediction["f"], - free_energy=model_prediction["e"] * factor, - magmoms=model_prediction["m"], - stress=model_prediction["s"] * self.stress_weight, - crystal_fea=model_prediction["crystal_fea"], + extensive_factor = len(structure) if self.model.is_intensive else 1 + key_map = dict( + e=("energy", extensive_factor), + f=("forces", 1), + m=("magmoms", 1), + s=("stress", self.stress_weight), ) + self.results |= { + long_key: model_prediction[key] * factor + for key, (long_key, factor) in key_map.items() + if key in model_prediction + } + self.results["free_energy"] = self.results["energy"] + self.results["crystal_fea"] = model_prediction["crystal_fea"] if self.return_site_energies: - self.results.update(energies=model_prediction["site_energies"]) + self.results["energies"] = model_prediction["site_energies"] class StructOptimizer: @@ -174,7 +184,7 @@ def __init__( model: CHGNet | CHGNetCalculator | None = None, optimizer_class: Optimizer | str | None = "FIRE", use_device: str | None = None, - stress_weight: float = 1 / 160.21766208, + stress_weight: float = units.GPa, on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn", ) -> None: """Provide a trained CHGNet model and an optimizer to relax crystal structures. @@ -773,7 +783,7 @@ def __init__( model: CHGNet | CHGNetCalculator | None = None, optimizer_class: Optimizer | str | None = "FIRE", use_device: str | None = None, - stress_weight: float = 1 / 160.21766208, + stress_weight: float = units.GPa, on_isolated_atoms: Literal["ignore", "warn", "error"] = "error", ) -> None: """Initialize a structure optimizer object for calculation of bulk modulus. diff --git a/chgnet/model/model.py b/chgnet/model/model.py index d42c61c9..c1bd58f8 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -4,12 +4,13 @@ import os from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, get_args import torch from pymatgen.core import Structure from torch import Tensor, nn +from chgnet import PredTask from chgnet.graph import CrystalGraph, CrystalGraphConverter from chgnet.graph.crystalgraph import TORCH_DTYPE from chgnet.model.composition_model import AtomRef @@ -27,7 +28,6 @@ if TYPE_CHECKING: from typing_extensions import Self - from chgnet import PredTask module_dir = os.path.dirname(os.path.abspath(__file__)) @@ -603,7 +603,7 @@ def predict_graph( Args: graph (CrystalGraph | Sequence[CrystalGraph]): CrystalGraph(s) to predict. - task (str): can be 'e' 'ef', 'em', 'efs', 'efsm' + task (PredTask): one of 'e', 'ef', 'em', 'efs', 'efsm' Default = "efsm" return_site_energies (bool): whether to return per-site energies. Default = False @@ -626,6 +626,9 @@ def predict_graph( raise TypeError( f"{type(graph)=} must be CrystalGraph or list of CrystalGraphs" ) + valid_tasks = get_args(PredTask) + if task not in valid_tasks: + raise ValueError(f"Invalid {task=}. Must be one of {valid_tasks}.") model_device = next(self.parameters()).device diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index e3637212..b742118a 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -858,7 +858,7 @@ def forward( for mag_pred, mag_target in zip(prediction["m"], targets["m"], strict=True): # exclude structures without magmom labels if self.allow_missing_labels: - if mag_target is not None and not np.isnan(mag_target).any(): + if mag_target is not None and not torch.isnan(mag_target).any(): mag_preds.append(mag_pred) mag_targets.append(mag_target) m_mae_size += mag_target.shape[0] diff --git a/site/.gitignore b/site/.gitignore index 59078f29..bded1f72 100644 --- a/site/.gitignore +++ b/site/.gitignore @@ -5,4 +5,3 @@ node_modules .svelte-kit build src/routes/api/*.md -src/MetricsTable.svelte diff --git a/site/package.json b/site/package.json index 3474e4be..2f8156fc 100644 --- a/site/package.json +++ b/site/package.json @@ -15,28 +15,28 @@ "changelog": "npx auto-changelog --package --output ../changelog.md --hide-credit --commit-limit false" }, "devDependencies": { - "@sveltejs/adapter-static": "^3.0.2", - "@sveltejs/kit": "^2.5.17", - "@sveltejs/vite-plugin-svelte": "^3.1.1", - "eslint": "^9.5.0", - "eslint-plugin-svelte": "^2.41.0", + "@sveltejs/adapter-static": "^3.0.6", + "@sveltejs/kit": "^2.8.1", + "@sveltejs/vite-plugin-svelte": "^4.0.1", + "eslint": "^9.15.0", + "eslint-plugin-svelte": "^2.46.0", "hastscript": "^9.0.0", - "mdsvex": "^0.11.2", - "prettier": "^3.3.2", - "prettier-plugin-svelte": "^3.2.5", + "mdsvex": "^0.12.3", + "prettier": "^3.3.3", + "prettier-plugin-svelte": "^3.2.8", "rehype-autolink-headings": "^7.1.0", "rehype-slug": "^6.0.0", - "svelte": "^4.2.18", - "svelte-check": "^3.8.4", - "svelte-multiselect": "^10.3.0", - "svelte-preprocess": "^6.0.1", + "svelte": "^5.2.1", + "svelte-check": "^4.0.8", + "svelte-multiselect": "11.0.0-rc.1", + "svelte-preprocess": "^6.0.3", "svelte-toc": "^0.5.9", - "svelte-zoo": "^0.4.10", - "svelte2tsx": "^0.7.13", - "tslib": "^2.6.3", - "typescript": "^5.5.2", - "typescript-eslint": "^7.14.1", - "vite": "^5.3.1" + "svelte-zoo": "^0.4.13", + "svelte2tsx": "^0.7.25", + "tslib": "^2.8.1", + "typescript": "^5.6.3", + "typescript-eslint": "^8.14.0", + "vite": "^5.4.11" }, "prettier": { "semi": false, diff --git a/site/src/routes/+page.svelte b/site/src/routes/+page.svelte index 7e2c6975..201fe721 100644 --- a/site/src/routes/+page.svelte +++ b/site/src/routes/+page.svelte @@ -1,12 +1,9 @@
- - - +