diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f4307d2f..5ede3515 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,15 +4,11 @@ default_install_hook_types: [pre-commit, commit-msg]
 
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.1.1
+    rev: v0.1.3
     hooks:
       - id: ruff
         args: [--fix]
-
-  - repo: https://github.com/psf/black
-    rev: 23.10.0
-    hooks:
-      - id: black-jupyter
+      - id: ruff-format
 
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.5.0
@@ -31,6 +27,7 @@ repos:
     hooks:
       - id: codespell
         stages: [commit, commit-msg]
+        args: [--check-filenames]
 
   - repo: https://github.com/kynan/nbstripout
     rev: 0.6.1
diff --git a/chgnet/data/dataset.py b/chgnet/data/dataset.py
index 087dd156..ee686fc0 100644
--- a/chgnet/data/dataset.py
+++ b/chgnet/data/dataset.py
@@ -125,7 +125,8 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict]:
 
             return crystal_graph, targets
 
-        # Omit structures with isolated atoms. Return another randomly selected structure
+        # Omit structures with isolated atoms. Return another randomly selected
+        # structure
         except Exception:
             struct = self.structures[graph_id]
             self.failed_graph_id[graph_id] = struct.composition.formula
@@ -491,7 +492,8 @@ def __init__(
 
         Args:
             data (str | dict): file path or dir name that contain all the JSONs
-            graph_converter (CrystalGraphConverter): Converts pymatgen.core.Structure to graph
+            graph_converter (CrystalGraphConverter): Converts pymatgen.core.Structure
+                to CrystalGraph object.
             targets ("ef" | "efs" | "efm" | "efsm"): The training targets.
                 Default = "efsm"
             energy_key (str, optional): the key of energy in the labels.
@@ -575,7 +577,8 @@ def __getitem__(self, idx):
                     targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
             return crystal_graph, targets
 
-        # Omit structures with isolated atoms. Return another randomly selected structure
+        # Omit structures with isolated atoms. Return another randomly selected
+        # structure
        except Exception:
             structure = Structure.from_dict(self.data[mp_id][graph_id]["structure"])
             self.failed_graph_id[graph_id] = structure.composition.formula
diff --git a/chgnet/graph/crystalgraph.py b/chgnet/graph/crystalgraph.py
index 67b85cdd..cd993e1e 100644
--- a/chgnet/graph/crystalgraph.py
+++ b/chgnet/graph/crystalgraph.py
@@ -30,8 +30,8 @@ def __init__(
     ) -> None:
         """Initialize the crystal graph.
 
-        Attention! This data class is not intended to be created manually. CrystalGraph should
-        be returned by a CrystalGraphConverter
+        Attention! This data class is not intended to be created manually. CrystalGraph
+        should be returned by a CrystalGraphConverter
 
         Args:
             atomic_number (Tensor): the atomic numbers of atoms in the structure
@@ -92,7 +92,8 @@ def __init__(
         self.composition = composition
         if len(directed2undirected) != 2 * len(undirected2directed):
             raise ValueError(
-                f"{graph_id} number of directed indices != 2 * number of undirected indices!"
+                f"{graph_id} number of directed indices ({len(directed2undirected)}) !="
+                f" 2 * number of undirected indices ({2 * len(undirected2directed)})!"
             )
 
     def to(self, device: str = "cpu") -> CrystalGraph:
diff --git a/chgnet/graph/graph.py b/chgnet/graph/graph.py
index 90c93508..72e6668b 100644
--- a/chgnet/graph/graph.py
+++ b/chgnet/graph/graph.py
@@ -89,8 +89,8 @@ def __eq__(self, other: object) -> bool:
             other (DirectedEdge): another DirectedEdge to compare to
 
         Returns:
-            bool: True if other is the same directed edge, or if other is the directed edge
-                with reverse direction of self, else False.
+            bool: True if other is the same directed edge, or if other is the directed
+                edge with reverse direction of self, else False.
         """
         self_img = (self.info or {}).get("image")
         other_img = (other.info or {}).get("image")
diff --git a/chgnet/model/basis.py b/chgnet/model/basis.py
index 1451c236..2468d2d5 100644
--- a/chgnet/model/basis.py
+++ b/chgnet/model/basis.py
@@ -172,8 +172,8 @@ def __init__(self, cutoff: float = 5, cutoff_coeff: float = 5) -> None:
                 Default = 5
             cutoff_coeff (float): the strength of soft-Cutoff
                 0 will disable the cutoff, returning 1 at every r
-                for positive numbers > 0, the smaller cutoff_coeff is, the faster this function
-                decays. Default = 5.
+                for positive numbers > 0, the smaller cutoff_coeff is, the faster this
+                function decays. Default = 5.
         """
         super().__init__()
         self.cutoff = cutoff
diff --git a/chgnet/model/composition_model.py b/chgnet/model/composition_model.py
index 01e1bb77..7f2c2825 100644
--- a/chgnet/model/composition_model.py
+++ b/chgnet/model/composition_model.py
@@ -46,7 +46,8 @@ def _get_energy(self, composition_feas: Tensor) -> Tensor:
         """Predict the energy given composition encoding.
 
         Args:
-            composition_feas: batched atom feature matrix [batch_size, total_num_elements].
+            composition_feas: batched atom feature matrix of shape
+                [batch_size, total_num_elements].
 
         Returns:
             prediction associated with each composition [batchsize].
@@ -111,7 +112,8 @@ def _get_energy(self, composition_feas: Tensor) -> Tensor:
         """Predict the energy given composition encoding.
 
         Args:
-            composition_feas: batched atom feature matrix [batch_size, total_num_elements].
+            composition_feas: batched atom feature matrix of shape
+                [batch_size, total_num_elements].
 
         Returns:
             prediction associated with each composition [batchsize].
diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py
index 4d12f765..2348a050 100644
--- a/chgnet/model/dynamics.py
+++ b/chgnet/model/dynamics.py
@@ -424,8 +424,8 @@ def __init__(
                 Default = None
             loginterval (int): write to log file every interval steps
                 Default = 1
-            crystal_feas_logfile (str): open this file for recording crystal features during MD
-                Default = None
+            crystal_feas_logfile (str): open this file for recording crystal features
+                during MD. Default = None
             append_trajectory (bool): Whether to append to prev trajectory.
                 If false, previous trajectory gets overwritten
                 Default = False
@@ -541,8 +541,8 @@ def __init__(
             bulk_modulus_au = eos.get_bulk_modulus(unit="eV/A^3")
             compressibility_au = eos.get_compressibility(unit="A^3/eV")
             print(
-                f"Done bulk modulus calculation: "
-                f"k = {round(bulk_modulus, 3)}GPa, {round(bulk_modulus_au, 3)}eV/A^3"
+                f"Completed bulk modulus calculation: "
+                f"k = {bulk_modulus:.3}GPa, {bulk_modulus_au:.3}eV/A^3"
             )
         except Exception:
             bulk_modulus_au = 2 / 160.2176
@@ -667,8 +667,8 @@ def upper_triangular_cell(self, verbose: bool | None = False):
         while ASE's canonical description is lower-triangular cell.
 
         Args:
-            verbose (bool): Whether to notify user about upper-triangular cell transformation.
-                Default = False
+            verbose (bool): Whether to notify user about upper-triangular cell
+                transformation. Default = False
         """
         if not NPT._isuppertriangular(self.atoms.get_cell()):
             a, b, c, alpha, beta, gamma = self.atoms.cell.cellpar()
diff --git a/chgnet/model/encoders.py b/chgnet/model/encoders.py
index e8035718..d2eb4059 100644
--- a/chgnet/model/encoders.py
+++ b/chgnet/model/encoders.py
@@ -14,7 +14,8 @@ def __init__(self, atom_feature_dim: int, max_num_elements: int = 94) -> None:
 
         Args:
             atom_feature_dim (int): dimension of atomic embedding.
-            max_num_elements (int): maximum number of elements in the dataset. Default = 94
+            max_num_elements (int): maximum number of elements in the dataset.
+                Default = 94
         """
         super().__init__()
         self.embedding = nn.Embedding(max_num_elements, atom_feature_dim)
@@ -32,7 +33,9 @@ def forward(self, atomic_numbers: Tensor) -> Tensor:
 
 
 class BondEncoder(nn.Module):
-    """Encode a chemical bond given the position of two atoms using Gaussian Distance."""
+    """Encode a chemical bond given the positions of two atoms using Gaussian
+    distance.
+    """
 
     def __init__(
         self,
diff --git a/chgnet/model/model.py b/chgnet/model/model.py
index 521c1b7d..f593502c 100644
--- a/chgnet/model/model.py
+++ b/chgnet/model/model.py
@@ -534,8 +534,8 @@ def predict_structure(
         """Predict from pymatgen.core.Structure.
 
         Args:
-            structure (Structure | Sequence[Structure]): structure or a list of structures
-                to predict.
+            structure (Structure | Sequence[Structure]): structure or a list of
+                structures to predict.
             task (str): can be 'e' 'ef', 'em', 'efs', 'efsm'
                 Default = "efsm"
             return_site_energies (bool): whether to return per-site energies.
@@ -552,7 +552,8 @@ def predict_structure(
             e (Tensor) : energy of structures float in eV/atom
             f (Tensor) : force on atoms [num_atoms, 3] in eV/A
             s (Tensor) : stress of structure [3, 3] in GPa
-            m (Tensor) : magnetic moments of sites [num_atoms, 3] in Bohr magneton mu_B
+            m (Tensor) : magnetic moments of sites [num_atoms, 3] in Bohr
+                magneton mu_B
         """
         if self.graph_converter is None:
             raise ValueError("graph_converter cannot be None!")
@@ -598,7 +599,8 @@ def predict_graph(
             e (Tensor) : energy of structures float in eV/atom
             f (Tensor) : force on atoms [num_atoms, 3] in eV/A
             s (Tensor) : stress of structure [3, 3] in GPa
-            m (Tensor) : magnetic moments of sites [num_atoms, 3] in Bohr magneton mu_B
+            m (Tensor) : magnetic moments of sites [num_atoms, 3] in Bohr
+                magneton mu_B
         """
         if not isinstance(graph, (CrystalGraph, Sequence)):
             raise ValueError(
diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py
index 466baa0c..b1d9bc5a 100644
--- a/chgnet/trainer/trainer.py
+++ b/chgnet/trainer/trainer.py
@@ -63,8 +63,8 @@ def __init__(
                 Default = 0.1
             mag_loss_ratio (float): magmom loss ratio in loss function
                 Default = 0.1
-            optimizer (str): optimizer to update model. Can be "Adam", "SGD", "AdamW", "RAdam"
-                Default = 'Adam'
+            optimizer (str): optimizer to update model. Can be "Adam", "SGD", "AdamW",
+                "RAdam". Default = 'Adam'
             scheduler (str): learning rate scheduler. Can be "CosLR", "ExponentialLR",
                 "CosRestartLR". Default = 'CosLR'
             criterion (str): loss function criterion. Can be "MSE", "Huber", "MAE"
@@ -216,7 +216,8 @@ def train(
                 Default = None
             save_dir (str): the dir name to save the trained weights
                 Default = None
-            save_test_result (bool): whether to save the test set prediction in a json file
+            save_test_result (bool): Whether to save the test set prediction in a JSON
+                file. Default = False
             train_composition_model (bool): whether to train the composition model
                 (AtomRef), this is suggested when the fine-tuning dataset has large
                 elemental energy shift from the pretrained CHGNet, which typically comes
diff --git a/pyproject.toml b/pyproject.toml
index 8113797b..15aeb786 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ dependencies = [
     "cython>=0.29.26",
     "numpy>=1.21.6",
     "nvidia-ml-py3>=7.352.0",
-    "pymatgen",
+    "pymatgen>=2023.10.11",
     "torch>=1.11.0",
 ]
 classifiers = [
@@ -49,7 +49,6 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] }
 
 [tool.ruff]
 target-version = "py39"
-line-length = 95
 include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"]
 select = [
     "B",   # flake8-bugbear
diff --git a/tests/test_crystal_graph.py b/tests/test_crystal_graph.py
index edddef76..d24d139a 100644
--- a/tests/test_crystal_graph.py
+++ b/tests/test_crystal_graph.py
@@ -345,9 +345,8 @@ def test_crystal_graph_stability_fast():
 def test_crystal_graph_repr():
     graph = converter_legacy(structure)
     assert (
-        repr(graph)
-        == "CrystalGraph(composition='Li2 Mn2 O4', atom_graph_cutoff=5, bond_graph_cutoff=3, "
-        "n_atoms=8, atom_graph_len=384, bond_graph_len=744)"
+        repr(graph) == "CrystalGraph(composition='Li2 Mn2 O4', atom_graph_cutoff=5, "
+        "bond_graph_cutoff=3, n_atoms=8, atom_graph_len=384, bond_graph_len=744)"
     )
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 503ac7e4..736a6be8 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -82,5 +82,6 @@ def test_structure_data_inconsistent_length():
 
     assert (
         str(exc.value)
-        == f"Inconsistent number of structures and labels: {len(structures)=}, {len(forces)=}"
+        == f"Inconsistent number of structures and labels: {len(structures)=}, "
+        f"{len(forces)=}"
     )