diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c52825f1..569afb1e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.7 + rev: v0.4.1 hooks: - id: ruff args: [--fix] @@ -46,7 +46,7 @@ repos: - svelte - repo: https://github.com/pre-commit/mirrors-eslint - rev: v9.0.0 + rev: v9.1.1 hooks: - id: eslint types: [file] diff --git a/chgnet/graph/graph.py b/chgnet/graph/graph.py index 9b23ca58..25383348 100644 --- a/chgnet/graph/graph.py +++ b/chgnet/graph/graph.py @@ -20,7 +20,7 @@ def __init__(self, index: int, info: dict | None = None) -> None: self.info = info self.neighbors: dict[int, list[DirectedEdge | UndirectedEdge]] = {} - def add_neighbor(self, index, edge): + def add_neighbor(self, index, edge) -> None: """Draw an directed edge between self and the node specified by index. Args: @@ -44,7 +44,7 @@ def __init__( self.index = index self.info = info - def __repr__(self): + def __repr__(self) -> str: """String representation of this edge.""" nodes, index, info = self.nodes, self.index, self.info return f"{type(self).__name__}({nodes=}, {index=}, {info=})" @@ -336,7 +336,7 @@ def as_dict(self): "undirected_edges_list": self.undirected_edges_list, } - def to(self, filename="graph.json"): + def to(self, filename="graph.json") -> None: """Save graph dictionary to file.""" write_json(self.as_dict(), filename) diff --git a/chgnet/model/composition_model.py b/chgnet/model/composition_model.py index d9d6f544..135018d1 100644 --- a/chgnet/model/composition_model.py +++ b/chgnet/model/composition_model.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from collections.abc import Sequence + from pathlib import Path from chgnet.graph.crystalgraph import CrystalGraph @@ -199,7 +200,7 @@ def get_site_energies(self, graphs: list[CrystalGraph]): for graph in graphs ] - def initialize_from(self, dataset: str): + def initialize_from(self, dataset: str) -> None: """Initialize pre-fitted weights from a dataset.""" if dataset in ["MPtrj", "MPtrj_e"]: self.initialize_from_MPtrj() @@ -208,7 +209,7 @@ def initialize_from(self, dataset: str): else: raise NotImplementedError(f"{dataset=} not supported yet") - def initialize_from_MPtrj(self): + def initialize_from_MPtrj(self) -> None: """Initialize pre-fitted weights from MPtrj dataset.""" state_dict = collections.OrderedDict() state_dict["weight"] = torch.tensor( @@ -313,7 +314,7 @@ def initialize_from_MPtrj(self): self.is_intensive = True self.fitted = True - def initialize_from_MPF(self): + def initialize_from_MPF(self) -> None: """Initialize pre-fitted weights from MPF dataset.""" state_dict = collections.OrderedDict() state_dict["weight"] = torch.tensor( @@ -418,7 +419,7 @@ def initialize_from_MPF(self): self.is_intensive = False self.fitted = True - def initialize_from_numpy(self, file_name): + def initialize_from_numpy(self, file_name: str | Path) -> None: """Initialize pre-fitted weights from numpy file.""" atom_ref_np = np.load(file_name) state_dict = collections.OrderedDict() diff --git a/examples/crystaltoolkit_relax_viewer.ipynb b/examples/crystaltoolkit_relax_viewer.ipynb index e8e762a6..f25440fd 100644 --- a/examples/crystaltoolkit_relax_viewer.ipynb +++ b/examples/crystaltoolkit_relax_viewer.ipynb @@ -386,7 +386,7 @@ " return structure, fig\n", "\n", "\n", - "app.run(height=800, use_reloader=False)\n" + "app.run(height=800, use_reloader=False)" ] } ], diff --git a/tests/test_dataset.py b/tests/test_dataset.py index d1e52578..26a2e210 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -92,33 +92,18 @@ def test_structure_data_inconsistent_length(): def test_dataset_no_shuffling(): - structures, energies, forces, stresses, magmoms, structure_ids = ( - [], - [], - [], - [], - [], - [], - ) - for index in range(100): - struct = NaCl.copy() - struct.perturb(0.1) - structures.append(struct) - energies.append(np.random.random(1)) - forces.append(np.random.random([2, 3])) - stresses.append(np.random.random([3, 3])) - magmoms.append(np.random.random([2, 1])) - structure_ids.append(index) + n_samples = 100 + structure_ids = list(range(n_samples)) + structure_data = StructureData( - structures=structures, - energies=energies, - forces=forces, - stresses=stresses, - magmoms=magmoms, + structures=[NaCl.copy().perturb(0.1) for _ in range(n_samples)], + energies=np.random.random(n_samples), + forces=np.random.random([n_samples, 2, 3]), + stresses=np.random.random([n_samples, 3, 3]), + magmoms=np.random.random([n_samples, 2, 1]), structure_ids=structure_ids, shuffle=False, ) - - assert structure_data[0][0].mp_id == 0 - assert structure_data[1][0].mp_id == 1 - assert structure_data[2][0].mp_id == 2 + sample_ids = [data[0].mp_id for data in structure_data] + # shuffle=False means structure_ids should be in order + assert sample_ids == structure_ids