Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use return type typing_extensions.Self for class methods #179

Merged
merged 3 commits into from
Jul 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.10
rev: v0.5.1
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -48,7 +48,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.5.0
rev: v9.6.0
hooks:
- id: eslint
types: [file]
Expand Down
4 changes: 3 additions & 1 deletion chgnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from typing_extensions import Self

from chgnet import TrainTask

warnings.filterwarnings("ignore")
Expand Down Expand Up @@ -97,7 +99,7 @@ def from_vasp(
save_path: str | None = None,
graph_converter: CrystalGraphConverter | None = None,
shuffle: bool = True,
) -> StructureData:
) -> Self:
"""Parse VASP output files into structures and labels and feed into the dataset.

Args:
Expand Down
5 changes: 3 additions & 2 deletions chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

if TYPE_CHECKING:
from pymatgen.core import Structure
from typing_extensions import Self

try:
from chgnet.graph.cygraph import make_graph
Expand Down Expand Up @@ -285,6 +286,6 @@ def as_dict(self) -> dict[str, str | float]:
}

@classmethod
def from_dict(cls, dct: dict) -> CrystalGraphConverter:
def from_dict(cls, dct: dict) -> Self:
"""Create converter from dictionary."""
return CrystalGraphConverter(**dct)
return cls(**dct)
11 changes: 7 additions & 4 deletions chgnet/graph/crystalgraph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import os
from typing import Any
from typing import TYPE_CHECKING, Any

import torch
from torch import Tensor

if TYPE_CHECKING:
from typing_extensions import Self

datatype = torch.float32


Expand Down Expand Up @@ -152,7 +155,7 @@ def save(self, fname: str | None = None, save_dir: str = ".") -> str:
return save_name

@classmethod
def from_file(cls, file_name: str) -> CrystalGraph:
def from_file(cls, file_name: str) -> Self:
"""Load a crystal graph from a file.

Args:
Expand All @@ -164,9 +167,9 @@ def from_file(cls, file_name: str) -> CrystalGraph:
return torch.load(file_name)

@classmethod
def from_dict(cls, dic: dict[str, Any]) -> CrystalGraph:
def from_dict(cls, dic: dict[str, Any]) -> Self:
"""Load a CrystalGraph from a dictionary."""
return CrystalGraph(**dic)
return cls(**dic)

def __repr__(self) -> str:
"""String representation of the graph."""
Expand Down
7 changes: 3 additions & 4 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
if TYPE_CHECKING:
from ase.io import Trajectory
from ase.optimize.optimize import Optimizer
from typing_extensions import Self

# We would like to thank M3GNet develop team for this module
# source: https://github.com/materialsvirtuallab/m3gnet
Expand Down Expand Up @@ -94,11 +95,9 @@ def __init__(
print(f"CHGNet will run on {self.device}")

@classmethod
def from_file(
cls, path: str, use_device: str | None = None, **kwargs
) -> CHGNetCalculator:
def from_file(cls, path: str, use_device: str | None = None, **kwargs) -> Self:
"""Load a user's CHGNet model and initialize the Calculator."""
return CHGNetCalculator(
return cls(
model=CHGNet.from_file(path),
use_device=use_device,
**kwargs,
Expand Down
14 changes: 8 additions & 6 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from chgnet.utils import determine_device

if TYPE_CHECKING:
from typing_extensions import Self

from chgnet import PredTask

module_dir = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -661,17 +663,17 @@ def todict(self) -> dict:
return {"model_name": type(self).__name__, "model_args": self.model_args}

@classmethod
def from_dict(cls, dct: dict, **kwargs) -> CHGNet:
def from_dict(cls, dct: dict, **kwargs) -> Self:
"""Build a CHGNet from a saved dictionary."""
chgnet = CHGNet(**dct["model_args"], **kwargs)
chgnet = cls(**dct["model_args"], **kwargs)
chgnet.load_state_dict(dct["state_dict"])
return chgnet

@classmethod
def from_file(cls, path: str, **kwargs) -> CHGNet:
def from_file(cls, path: str, **kwargs) -> Self:
"""Build a CHGNet from a saved file."""
state = torch.load(path, map_location=torch.device("cpu"))
return CHGNet.from_dict(state["model"], **kwargs)
return cls.from_dict(state["model"], **kwargs)

@classmethod
def load(
Expand All @@ -681,7 +683,7 @@ def load(
use_device: str | None = None,
check_cuda_mem: bool = False,
verbose: bool = True,
) -> CHGNet:
) -> Self:
"""Load pretrained CHGNet model.

Args:
Expand Down Expand Up @@ -777,7 +779,7 @@ def from_graphs(
angle_basis_expansion: nn.Module,
*,
compute_stress: bool = False,
) -> BatchedGraph:
) -> Self:
"""Featurize and assemble a list of graphs.

Args:
Expand Down
7 changes: 5 additions & 2 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

if TYPE_CHECKING:
from torch.utils.data import DataLoader
from typing_extensions import Self

from chgnet import TrainTask

Expand Down Expand Up @@ -645,14 +646,14 @@ def save_checkpoint(self, epoch: int, mae_error: dict, save_dir: str) -> None:
)

@classmethod
def load(cls, path: str) -> Trainer:
def load(cls, path: str) -> Self:
"""Load trainer state_dict."""
state = torch.load(path, map_location=torch.device("cpu"))
model = CHGNet.from_dict(state["model"])
print(f"Loaded model params = {sum(p.numel() for p in model.parameters()):,}")
# drop model from trainer_args if present
state["trainer_args"].pop("model", None)
trainer = Trainer(model=model, **state["trainer_args"])
trainer = cls(model=model, **state["trainer_args"])
trainer.model.to(trainer.device)
trainer.optimizer.load_state_dict(state["optimizer"])
trainer.scheduler.load_state_dict(state["scheduler"])
Expand Down Expand Up @@ -791,6 +792,8 @@ def forward(
out["s_MAE_size"] = stress_target.shape[0]

# Mag
print(f"{list(prediction)=}")
print(f"{list(targets)=}")
if "m" in self.target_str:
mag_preds, mag_targets = [], []
m_mae_size = 0
Expand Down
9 changes: 5 additions & 4 deletions examples/make_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ def make_graphs(
"""Make graphs from a StructureJsonData dataset.

Args:
data (StructureJsonData): a StructureJsonData
graph_dir (str): a directory to save the graphs
train_ratio (float): train ratio
val_ratio (float): val ratio
data (StructureJsonData | StructureData): Input structures to convert to graphs.
graph_dir (str): a directory to save the graphs and labels.
train_ratio (float): train ratio. Default = 0.8
val_ratio (float): val ratio. Default = 0.1. The test ratio is
1 - train_ratio - val_ratio
"""
os.makedirs(graph_dir, exist_ok=True)
random.shuffle(data.keys)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies = [
"nvidia-ml-py3>=7.352.0",
"pymatgen>=2023.10.11",
"torch>=1.11.0",
"typing-extensions>=4.12",
]
classifiers = [
"Intended Audience :: Science/Research",
Expand Down
Loading