From 0ba82c58f25b72141d1ecfa66640ecda2f3427c7 Mon Sep 17 00:00:00 2001 From: alex Date: Fri, 23 Aug 2024 18:08:23 -0400 Subject: [PATCH 1/4] ENH: add dispersion tutorial --- examples/dispersion.ipynb | 276 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 276 insertions(+) create mode 100644 examples/dispersion.ipynb diff --git a/examples/dispersion.ipynb b/examples/dispersion.ipynb new file mode 100644 index 00000000..8c51479a --- /dev/null +++ b/examples/dispersion.ipynb @@ -0,0 +1,276 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Adding dispersion to the CHGNet pre-trained model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook describes the process of adding a dispersion correction to the CHGNet pre-trained model. CHGNet is trained on PBE (GGA) DFT calculations; as such, it does not include a correction for van der Waals or dispersive forces. This kind of correction may be particularly useful for those studying porous materials, such as MOFs or zeolites, but who do not wish to fine-tune the pre-trained model on data that include a dispersion correction.\n", + "\n", + "This notebook uses both the [torch-dftd](https://github.com/pfnet-research/torch-dftd/tree/master) and [DFT-D4](https://dftd4.readthedocs.io/en/latest/reference/ase.html) repositories to add dispersion to CHGNet. The torch-dftd repository currently has DFT-D2 and DFT-D3 implementations and does not have the most recent DFT-D4 version, but is GPU-accelerated where DFT-D4 is not. The Grimme group has released a version of [DFT-D4 implemented in PyTorch](https://github.com/dftd4/tad-dftd4); however, this version does not have an ASE-compatible calculator available.\n", + "\n", + "You will need to install CHGNet, [ASE](https://wiki.fysik.dtu.dk/ase/install.html), [torch-dftd](https://github.com/pfnet-research/torch-dftd/tree/master?tab=readme-ov-file#install), and [DFT-D4](https://dftd4.readthedocs.io/en/latest/recipe/installation.html) to run this notebook (links are to their installation instructions)." + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "from ase.build import fcc111\n", + "from ase.calculators.mixing import SumCalculator\n", + "from dftd4.ase import DFTD4\n", + "from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator\n", + "\n", + "from chgnet.model.dynamics import CHGNetCalculator\n", + "from chgnet.model.model import CHGNet" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CHGNet v0.3.0 initialized with 412,525 parameters\n", + "CHGNet will run on cpu\n", + "CHGNet will run on cpu\n" + ] + } + ], + "source": [ + "# pre-trained chgnet model\n", + "chgpotential = CHGNet.load()\n", + "chgcalc = CHGNetCalculator(model=chgpotential)\n", + "\n", + "torchd3calc = TorchDFTD3Calculator() # uses PBE parameters by default\n", + "\n", + "d4calc = DFTD4(method=\"PBE\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## A simple example" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This example shows how to initialize an Atoms object (of a Cu(111) surface) and compute its energy with and without the dispersion correction." + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Disp calculator: sumcalculator\n", + "Cu4\n", + "E without dispersion: -12.876540184020996\n", + "E with DFT-D3 dispersion: -13.272150007989707\n", + "E with DFT-D4 dispersion: -13.573149493981\n" + ] + } + ], + "source": [ + "# Create a 2x2x1 fcc(111) Cu slab\n", + "atoms = fcc111(\"Cu\", (2, 2, 1), vacuum=10.0)\n", + "atoms.set_pbc([True, True, True])\n", + "\n", + "atoms_disp = atoms.copy()\n", + "atoms_d4 = atoms.copy()\n", + "\n", + "atoms.calc = chgcalc\n", + "\n", + "chgd3 = SumCalculator([chgcalc, torchd3calc])\n", + "chgd4 = SumCalculator([chgcalc, d4calc])\n", + "atoms_disp.calc = chgd3\n", + "atoms_d4.calc = chgd4\n", + "\n", + "e_chg = atoms.get_potential_energy()\n", + "e_disp = atoms_disp.get_potential_energy()\n", + "e_d4 = atoms_d4.get_potential_energy()\n", + "\n", + "print(f\"Disp calculator: {chgd3.name}\")\n", + "print(atoms.get_chemical_formula())\n", + "print(f\"E without dispersion: {e_chg}\")\n", + "print(f\"E with DFT-D3 dispersion: {e_disp}\")\n", + "print(f\"E with DFT-D4 dispersion: {e_d4}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optimization example" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Below is a simple example of an optimization of a Cu cell with a displaced atom and perturbed unit cell." + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "from ase.build import bulk\n", + "from ase.filters import FrechetCellFilter\n", + "from ase.optimize import BFGS\n", + "\n", + "atoms = bulk(\"Cu\", cubic=True)\n", + "\n", + "atoms[0].x += 0.1\n", + "atoms.cell[0] += 0.1" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Cell([[3.71, 0.1, 0.1], [0.0, 3.61, 0.0], [0.0, 0.0, 3.61]])" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "atoms.cell" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Step Time Energy fmax\n", + "BFGS: 0 17:55:31 -18.448409 0.718576\n", + "BFGS: 1 17:55:33 -18.479250 0.649700\n", + "BFGS: 2 17:55:34 -18.602240 1352.772997\n", + "BFGS: 3 17:55:35 -18.495526 0.666888\n", + "BFGS: 4 17:55:36 -18.505485 0.674898\n", + "BFGS: 5 17:55:37 -18.524050 0.707074\n", + "BFGS: 6 17:55:39 -18.524685 0.708438\n", + "BFGS: 7 17:55:40 -18.527744 0.722830\n", + "BFGS: 8 17:55:41 -18.529436 0.740565\n", + "BFGS: 9 17:55:42 -18.530648 0.755404\n", + "BFGS: 10 17:55:43 -18.530939 0.757174\n", + "BFGS: 11 17:55:44 -18.531477 0.756699\n", + "BFGS: 12 17:55:46 -18.532681 0.750869\n", + "BFGS: 13 17:55:47 -18.534685 0.741311\n", + "BFGS: 14 17:55:48 -18.537005 0.728779\n", + "BFGS: 15 17:55:49 -18.539452 0.706232\n", + "BFGS: 16 17:55:50 -18.540848 0.684654\n", + "BFGS: 17 17:55:52 -18.541594 0.680409\n", + "BFGS: 18 17:55:53 -18.544656 0.659496\n", + "BFGS: 19 17:55:54 -18.548455 0.625415\n", + "BFGS: 20 17:55:55 -18.554630 0.578256\n", + "BFGS: 21 17:55:56 -18.562180 0.520110\n", + "BFGS: 22 17:55:58 -18.568905 0.460355\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "atoms.calc = chgd4\n", + "ecf = FrechetCellFilter(atoms)\n", + "opt = BFGS(ecf, trajectory=\"Cu.traj\")\n", + "\n", + "opt.run(fmax=0.5, steps=100)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The output of this optimization can be viewed by running\n", + "\n", + "```bash\n", + "ase gui Cu.traj\n", + "```\n", + "\n", + "in the command line in this folder." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "htvs", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 2bb6b6cba7419229ca6443d1ab9ff85aa22364e0 Mon Sep 17 00:00:00 2001 From: alex Date: Fri, 23 Aug 2024 18:19:11 -0400 Subject: [PATCH 2/4] STY: format dispersion notebook --- examples/dispersion.ipynb | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/dispersion.ipynb b/examples/dispersion.ipynb index 8c51479a..ccd01fbd 100644 --- a/examples/dispersion.ipynb +++ b/examples/dispersion.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -39,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -54,7 +54,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -93,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -150,7 +150,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -166,7 +166,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -175,7 +175,7 @@ "Cell([[3.71, 0.1, 0.1], [0.0, 3.61, 0.0], [0.0, 0.0, 3.61]])" ] }, - "execution_count": 42, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -186,7 +186,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -225,7 +225,7 @@ "True" ] }, - "execution_count": 43, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -268,7 +268,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.5" } }, "nbformat": 4, From a4425ee1c62bceba7e7c35b4ed03a78d9d1c4575 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 24 Aug 2024 06:56:04 +0200 Subject: [PATCH 3/4] add extra deps set dispersion = ["dftd4>=3.6", "torch-dftd>=0.4"] also bump ruff --- .pre-commit-config.yaml | 4 ++-- pyproject.toml | 2 +- tests/conftest.py | 2 +- tests/test_converter.py | 2 +- tests/test_dataset.py | 2 +- tests/test_graph.py | 4 ++-- tests/test_trainer.py | 2 +- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e53470a0..fb6fd17d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.5 + rev: v0.6.2 hooks: - id: ruff args: [--fix] @@ -48,7 +48,7 @@ repos: - svelte - repo: https://github.com/pre-commit/mirrors-eslint - rev: v9.8.0 + rev: v9.9.0 hooks: - id: eslint types: [file] diff --git a/pyproject.toml b/pyproject.toml index 52bcac62..cec86fd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ test = ["pytest-cov>=4", "pytest>=8"] examples = ["crystal-toolkit>=2023.11.3", "pandas>=2.2"] docs = ["lazydocs>=0.4"] logging = ["wandb>=0.17"] +dispersion = ["dftd4>=3.6", "torch-dftd>=0.4"] [project.urls] Source = "https://github.com/CederGroupHub/chgnet" @@ -52,7 +53,6 @@ build-backend = "setuptools.build_meta" [tool.ruff] target-version = "py39" -extend-include = ["*.ipynb"] [tool.ruff.lint] select = ["ALL"] diff --git a/tests/conftest.py b/tests/conftest.py index 524d8d5a..ae9ccecb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,6 @@ from chgnet import ROOT -@pytest.fixture() +@pytest.fixture def li_mn_o2() -> Structure: return Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif") diff --git a/tests/test_converter.py b/tests/test_converter.py index 10a86381..46600ac1 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -17,7 +17,7 @@ NaCl = Structure(lattice, species, coords) -@pytest.fixture() +@pytest.fixture def _set_make_graph() -> Generator[None, None, None]: # fixture to force make_graph to be None and then restore it after test from chgnet.graph import converter diff --git a/tests/test_dataset.py b/tests/test_dataset.py index b6e9d275..475aba15 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -16,7 +16,7 @@ NaCl = Structure(lattice, species, coords) -@pytest.fixture() +@pytest.fixture def structure_data() -> StructureData: """Create a graph with 3 nodes and 3 directed edges.""" random.seed(42) diff --git a/tests/test_graph.py b/tests/test_graph.py index da4b3c25..c52f8836 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -6,7 +6,7 @@ from chgnet.graph.graph import DirectedEdge, Graph, Node, UndirectedEdge -@pytest.fixture() +@pytest.fixture def graph() -> Graph: """Create a graph with 3 nodes and 3 directed edges.""" nodes = [Node(index=idx) for idx in range(3)] @@ -50,7 +50,7 @@ def test_as_dict(graph: Graph) -> None: assert len(graph_dict["undirected_edges_list"]) == 3 -@pytest.fixture() +@pytest.fixture def bigraph() -> Graph: """Create a bi-directional graph with 3 nodes and 4 bi-directed edges.""" nodes = [Node(index=idx) for idx in range(3)] diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 6fdb7310..71fe1f8b 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -114,7 +114,7 @@ def test_trainer_composition_model(tmp_path: Path) -> None: assert torch.all(comparison == expect) -@pytest.fixture() +@pytest.fixture def mock_wandb(): with patch("chgnet.trainer.trainer.wandb") as mock: yield mock From d8f6d8688d6bbac13a543fb900d889dc0a770b09 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 24 Aug 2024 06:56:32 +0200 Subject: [PATCH 4/4] tweak dispersion.ipynb var names --- examples/dispersion.ipynb | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/examples/dispersion.ipynb b/examples/dispersion.ipynb index ccd01fbd..e7a76ad6 100644 --- a/examples/dispersion.ipynb +++ b/examples/dispersion.ipynb @@ -48,8 +48,7 @@ "from dftd4.ase import DFTD4\n", "from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator\n", "\n", - "from chgnet.model.dynamics import CHGNetCalculator\n", - "from chgnet.model.model import CHGNet" + "from chgnet.model.dynamics import CHGNetCalculator" ] }, { @@ -69,12 +68,11 @@ ], "source": [ "# pre-trained chgnet model\n", - "chgpotential = CHGNet.load()\n", - "chgcalc = CHGNetCalculator(model=chgpotential)\n", + "chgnet_calc = CHGNetCalculator()\n", "\n", - "torchd3calc = TorchDFTD3Calculator() # uses PBE parameters by default\n", + "d3_calc = TorchDFTD3Calculator() # uses PBE parameters by default\n", "\n", - "d4calc = DFTD4(method=\"PBE\")" + "d4_calc = DFTD4(method=\"PBE\")" ] }, { @@ -116,18 +114,18 @@ "atoms_disp = atoms.copy()\n", "atoms_d4 = atoms.copy()\n", "\n", - "atoms.calc = chgcalc\n", + "atoms.calc = chgnet_calc\n", "\n", - "chgd3 = SumCalculator([chgcalc, torchd3calc])\n", - "chgd4 = SumCalculator([chgcalc, d4calc])\n", - "atoms_disp.calc = chgd3\n", - "atoms_d4.calc = chgd4\n", + "chgnet_d3 = SumCalculator([chgnet_calc, d3_calc])\n", + "chgnet_d4 = SumCalculator([chgnet_calc, d4_calc])\n", + "atoms_disp.calc = chgnet_d3\n", + "atoms_d4.calc = chgnet_d4\n", "\n", "e_chg = atoms.get_potential_energy()\n", "e_disp = atoms_disp.get_potential_energy()\n", "e_d4 = atoms_d4.get_potential_energy()\n", "\n", - "print(f\"Disp calculator: {chgd3.name}\")\n", + "print(f\"Disp calculator: {chgnet_d3.name}\")\n", "print(atoms.get_chemical_formula())\n", "print(f\"E without dispersion: {e_chg}\")\n", "print(f\"E with DFT-D3 dispersion: {e_disp}\")\n", @@ -231,9 +229,9 @@ } ], "source": [ - "atoms.calc = chgd4\n", - "ecf = FrechetCellFilter(atoms)\n", - "opt = BFGS(ecf, trajectory=\"Cu.traj\")\n", + "atoms.calc = chgnet_d4\n", + "cell_filter = FrechetCellFilter(atoms)\n", + "opt = BFGS(cell_filter, trajectory=\"Cu.traj\")\n", "\n", "opt.run(fmax=0.5, steps=100)" ]