Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dispersion #192

Merged
merged 4 commits into from
Aug 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.5
rev: v0.6.2
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -48,7 +48,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.8.0
rev: v9.9.0
hooks:
- id: eslint
types: [file]
Expand Down
274 changes: 274 additions & 0 deletions examples/dispersion.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Adding dispersion to the CHGNet pre-trained model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook describes the process of adding a dispersion correction to the CHGNet pre-trained model. CHGNet is trained on PBE (GGA) DFT calculations; as such, it does not include a correction for van der Waals or dispersive forces. This kind of correction may be particularly useful for those studying porous materials, such as MOFs or zeolites, but who do not wish to fine-tune the pre-trained model on data that include a dispersion correction.\n",
"\n",
"This notebook uses both the [torch-dftd](https://github.com/pfnet-research/torch-dftd/tree/master) and [DFT-D4](https://dftd4.readthedocs.io/en/latest/reference/ase.html) repositories to add dispersion to CHGNet. The torch-dftd repository currently has DFT-D2 and DFT-D3 implementations and does not have the most recent DFT-D4 version, but is GPU-accelerated where DFT-D4 is not. The Grimme group has released a version of [DFT-D4 implemented in PyTorch](https://github.com/dftd4/tad-dftd4); however, this version does not have an ASE-compatible calculator available.\n",
"\n",
"You will need to install CHGNet, [ASE](https://wiki.fysik.dtu.dk/ase/install.html), [torch-dftd](https://github.com/pfnet-research/torch-dftd/tree/master?tab=readme-ov-file#install), and [DFT-D4](https://dftd4.readthedocs.io/en/latest/recipe/installation.html) to run this notebook (links are to their installation instructions)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ase.build import fcc111\n",
"from ase.calculators.mixing import SumCalculator\n",
"from dftd4.ase import DFTD4\n",
"from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator\n",
"\n",
"from chgnet.model.dynamics import CHGNetCalculator"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CHGNet v0.3.0 initialized with 412,525 parameters\n",
"CHGNet will run on cpu\n",
"CHGNet will run on cpu\n"
]
}
],
"source": [
"# pre-trained chgnet model\n",
"chgnet_calc = CHGNetCalculator()\n",
"\n",
"d3_calc = TorchDFTD3Calculator() # uses PBE parameters by default\n",
"\n",
"d4_calc = DFTD4(method=\"PBE\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A simple example"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This example shows how to initialize an Atoms object (of a Cu(111) surface) and compute its energy with and without the dispersion correction."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Disp calculator: sumcalculator\n",
"Cu4\n",
"E without dispersion: -12.876540184020996\n",
"E with DFT-D3 dispersion: -13.272150007989707\n",
"E with DFT-D4 dispersion: -13.573149493981\n"
]
}
],
"source": [
"# Create a 2x2x1 fcc(111) Cu slab\n",
"atoms = fcc111(\"Cu\", (2, 2, 1), vacuum=10.0)\n",
"atoms.set_pbc([True, True, True])\n",
"\n",
"atoms_disp = atoms.copy()\n",
"atoms_d4 = atoms.copy()\n",
"\n",
"atoms.calc = chgnet_calc\n",
"\n",
"chgnet_d3 = SumCalculator([chgnet_calc, d3_calc])\n",
"chgnet_d4 = SumCalculator([chgnet_calc, d4_calc])\n",
"atoms_disp.calc = chgnet_d3\n",
"atoms_d4.calc = chgnet_d4\n",
"\n",
"e_chg = atoms.get_potential_energy()\n",
"e_disp = atoms_disp.get_potential_energy()\n",
"e_d4 = atoms_d4.get_potential_energy()\n",
"\n",
"print(f\"Disp calculator: {chgnet_d3.name}\")\n",
"print(atoms.get_chemical_formula())\n",
"print(f\"E without dispersion: {e_chg}\")\n",
"print(f\"E with DFT-D3 dispersion: {e_disp}\")\n",
"print(f\"E with DFT-D4 dispersion: {e_d4}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Optimization example"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a simple example of an optimization of a Cu cell with a displaced atom and perturbed unit cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ase.build import bulk\n",
"from ase.filters import FrechetCellFilter\n",
"from ase.optimize import BFGS\n",
"\n",
"atoms = bulk(\"Cu\", cubic=True)\n",
"\n",
"atoms[0].x += 0.1\n",
"atoms.cell[0] += 0.1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Cell([[3.71, 0.1, 0.1], [0.0, 3.61, 0.0], [0.0, 0.0, 3.61]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"atoms.cell"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Step Time Energy fmax\n",
"BFGS: 0 17:55:31 -18.448409 0.718576\n",
"BFGS: 1 17:55:33 -18.479250 0.649700\n",
"BFGS: 2 17:55:34 -18.602240 1352.772997\n",
"BFGS: 3 17:55:35 -18.495526 0.666888\n",
"BFGS: 4 17:55:36 -18.505485 0.674898\n",
"BFGS: 5 17:55:37 -18.524050 0.707074\n",
"BFGS: 6 17:55:39 -18.524685 0.708438\n",
"BFGS: 7 17:55:40 -18.527744 0.722830\n",
"BFGS: 8 17:55:41 -18.529436 0.740565\n",
"BFGS: 9 17:55:42 -18.530648 0.755404\n",
"BFGS: 10 17:55:43 -18.530939 0.757174\n",
"BFGS: 11 17:55:44 -18.531477 0.756699\n",
"BFGS: 12 17:55:46 -18.532681 0.750869\n",
"BFGS: 13 17:55:47 -18.534685 0.741311\n",
"BFGS: 14 17:55:48 -18.537005 0.728779\n",
"BFGS: 15 17:55:49 -18.539452 0.706232\n",
"BFGS: 16 17:55:50 -18.540848 0.684654\n",
"BFGS: 17 17:55:52 -18.541594 0.680409\n",
"BFGS: 18 17:55:53 -18.544656 0.659496\n",
"BFGS: 19 17:55:54 -18.548455 0.625415\n",
"BFGS: 20 17:55:55 -18.554630 0.578256\n",
"BFGS: 21 17:55:56 -18.562180 0.520110\n",
"BFGS: 22 17:55:58 -18.568905 0.460355\n"
]
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"atoms.calc = chgnet_d4\n",
"cell_filter = FrechetCellFilter(atoms)\n",
"opt = BFGS(cell_filter, trajectory=\"Cu.traj\")\n",
"\n",
"opt.run(fmax=0.5, steps=100)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The output of this optimization can be viewed by running\n",
"\n",
"```bash\n",
"ase gui Cu.traj\n",
"```\n",
"\n",
"in the command line in this folder."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "htvs",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ test = ["pytest-cov>=4", "pytest>=8"]
examples = ["crystal-toolkit>=2023.11.3", "pandas>=2.2"]
docs = ["lazydocs>=0.4"]
logging = ["wandb>=0.17"]
dispersion = ["dftd4>=3.6", "torch-dftd>=0.4"]

[project.urls]
Source = "https://github.com/CederGroupHub/chgnet"
Expand All @@ -52,7 +53,6 @@ build-backend = "setuptools.build_meta"

[tool.ruff]
target-version = "py39"
extend-include = ["*.ipynb"]

[tool.ruff.lint]
select = ["ALL"]
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
from chgnet import ROOT


@pytest.fixture()
@pytest.fixture
def li_mn_o2() -> Structure:
return Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")
2 changes: 1 addition & 1 deletion tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
NaCl = Structure(lattice, species, coords)


@pytest.fixture()
@pytest.fixture
def _set_make_graph() -> Generator[None, None, None]:
# fixture to force make_graph to be None and then restore it after test
from chgnet.graph import converter
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
NaCl = Structure(lattice, species, coords)


@pytest.fixture()
@pytest.fixture
def structure_data() -> StructureData:
"""Create a graph with 3 nodes and 3 directed edges."""
random.seed(42)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from chgnet.graph.graph import DirectedEdge, Graph, Node, UndirectedEdge


@pytest.fixture()
@pytest.fixture
def graph() -> Graph:
"""Create a graph with 3 nodes and 3 directed edges."""
nodes = [Node(index=idx) for idx in range(3)]
Expand Down Expand Up @@ -50,7 +50,7 @@ def test_as_dict(graph: Graph) -> None:
assert len(graph_dict["undirected_edges_list"]) == 3


@pytest.fixture()
@pytest.fixture
def bigraph() -> Graph:
"""Create a bi-directional graph with 3 nodes and 4 bi-directed edges."""
nodes = [Node(index=idx) for idx in range(3)]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_trainer_composition_model(tmp_path: Path) -> None:
assert torch.all(comparison == expect)


@pytest.fixture()
@pytest.fixture
def mock_wandb():
with patch("chgnet.trainer.trainer.wandb") as mock:
yield mock
Expand Down
Loading