diff --git a/chgnet/data/dataset.py b/chgnet/data/dataset.py index 1d8ba35b..087dd156 100644 --- a/chgnet/data/dataset.py +++ b/chgnet/data/dataset.py @@ -48,7 +48,8 @@ def __init__( magmoms (list[list[float]], optional): [data_size, n_atoms, 1] structure_ids (list[str], optional): a list of ids to track the structures graph_converter (CrystalGraphConverter, optional): Converts the structures - to graphs. If None, it will be set to CHGNet default converter. + to graphs. If None, it will be set to CHGNet 0.3.0 converter + with AtomGraph cutoff = 6A. Raises: RuntimeError: if the length of structures and labels (energies, forces, @@ -74,7 +75,7 @@ def __init__( random.shuffle(self.keys) print(f"{len(structures)} structures imported") self.graph_converter = graph_converter or CrystalGraphConverter( - atom_graph_cutoff=5, bond_graph_cutoff=3 + atom_graph_cutoff=6, bond_graph_cutoff=3 ) self.failed_idx: list[int] = [] self.failed_graph_id: dict[str, str] = {} @@ -157,9 +158,9 @@ def __init__( labels (str, dict): the path or dictionary of labels targets ("ef" | "efs" | "efm" | "efsm"): The training targets. Default = "efsm" - graph_converter (CrystalGraphConverter, optional): - a CrystalGraphConverter to convert the structures, - if None, it will be set to CHGNet default converter + graph_converter (CrystalGraphConverter, optional): Converts the structures + to graphs. If None, it will be set to CHGNet 0.3.0 converter + with AtomGraph cutoff = 6A. energy_key (str, optional): the key of energy in the labels. Default = "energy_per_atom". force_key (str, optional): the key of force in the labels. @@ -175,7 +176,7 @@ def __init__( random.shuffle(self.cif_ids) print(f"{cif_path}: {len(self.cif_ids):,} structures imported") self.graph_converter = graph_converter or CrystalGraphConverter( - atom_graph_cutoff=5, bond_graph_cutoff=3 + atom_graph_cutoff=6, bond_graph_cutoff=3 ) self.energy_key = energy_key diff --git a/chgnet/model/model.py b/chgnet/model/model.py index 6b0bd0e1..0850807d 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -650,15 +650,26 @@ def from_file(cls, path, **kwargs): return CHGNet.from_dict(state["model"], **kwargs) @classmethod - def load(cls, model_name="MPtrj-efsm"): + def load(cls, model_name="0.3.0"): """Load pretrained CHGNet.""" current_dir = os.path.dirname(os.path.abspath(__file__)) - if model_name == "MPtrj-efsm": + if model_name == "0.3.0": return cls.from_file( - os.path.join(current_dir, "../pretrained/e30f77s348m32.pth.tar"), + os.path.join( + current_dir, + "../pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar", + ) + ) + elif model_name == "0.2.0": # noqa: RET505 + return cls.from_file( + os.path.join( + current_dir, + "../pretrained/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar", + ), mlp_out_bias=True, ) - raise ValueError(f"Unknown {model_name=}") + else: + raise ValueError(f"Unknown {model_name=}") @dataclass diff --git a/chgnet/pretrained/0.2.0/README.md b/chgnet/pretrained/0.2.0/README.md new file mode 100755 index 00000000..87ba34c0 --- /dev/null +++ b/chgnet/pretrained/0.2.0/README.md @@ -0,0 +1,74 @@ +## Model 0.2.0 + +This is the pretrained weights published with CHGNet Nature Machine Intelligence paper. +All the experiments and results shown in the paper were performed with this version of weights. + +Date: 2/24/2023 + +Author: Bowen Deng + +## Model Parameters + +```python +model = CHGNet( + atom_fea_dim=64, + bond_fea_dim=64, + angle_fea_dim=64, + composition_model="MPtrj", + num_radial=9, + num_angular=9, + n_conv=4, + atom_conv_hidden_dim=64, + update_bond=True, + bond_conv_hidden_dim=64, + update_angle=True, + angle_layer_hidden_dim=0, + conv_dropout=0, + read_out="ave", + mlp_hidden_dims=[64, 64], + mlp_first=True, + is_intensive=True, + non_linearity="silu", + atom_graph_cutoff=5, + bond_graph_cutoff=3, + graph_converter_algorithm="fast", + cutoff_coeff=5, + learnable_rbf=True, + mlp_out_bias=True, +) +``` + +## Dataset Used + +MPtrj dataset with 8-1-1 train-val-test splitting + +## Trainer + +```python +trainer = Trainer( + model=model, + targets='efsm', + energy_loss_ratio=1, + force_loss_ratio=1, + stress_loss_ratio=0.1, + mag_loss_ratio=0.1, + optimizer='Adam', + weight_decay=0, + scheduler='CosLR', + criterion='Huber', + delta=0.1, + epochs=20, + starting_epoch=0, + learning_rate=1e-3, + use_device='cuda', + print_freq=1000 +) +``` + +## Mean Absolute Error (MAE) logs + +| partition | Energy (meV/atom) | Force (meV/A) | stress (GPa) | magmom (muB) | +| ---------- | ----------------- | ------------- | ------------ | ------------ | +| Train | 22 | 59 | 0.246 | 0.030 | +| Validation | 20 | 75 | 0.350 | 0.033 | +| Test | 30 | 77 | 0.348 | 0.032 | diff --git a/chgnet/pretrained/e30f77s348m32.pth.tar b/chgnet/pretrained/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar similarity index 100% rename from chgnet/pretrained/e30f77s348m32.pth.tar rename to chgnet/pretrained/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar diff --git a/chgnet/pretrained/0.3.0/README.md b/chgnet/pretrained/0.3.0/README.md new file mode 100755 index 00000000..50223bbc --- /dev/null +++ b/chgnet/pretrained/0.3.0/README.md @@ -0,0 +1,80 @@ +## Model 0.3.0 + +Major changes: + +1. Increased AtomGraph cutoff to 6A +2. Resolved discontinuity issue when no BondGraph presents +3. Added some normalization layers +4. Slight improvements on energy, force, stress accuracies + +Date: 10/22/2023 + +Author: Bowen Deng + +## Model Parameters + +```python +model = CHGNet( + atom_fea_dim=64, + bond_fea_dim=64, + angle_fea_dim=64, + composition_model="MPtrj", + num_radial=31, + num_angular=31, + n_conv=4, + atom_conv_hidden_dim=64, + update_bond=True, + bond_conv_hidden_dim=64, + update_angle=True, + angle_layer_hidden_dim=0, + conv_dropout=0, + read_out="ave", + gMLP_norm='layer', + readout_norm='layer', + mlp_hidden_dims=[64, 64, 64], + mlp_first=True, + is_intensive=True, + non_linearity="silu", + atom_graph_cutoff=6, + bond_graph_cutoff=3, + graph_converter_algorithm="fast", + cutoff_coeff=8, + learnable_rbf=True, +) +``` + +## Dataset Used + +MPtrj dataset with 9-0.5-0.5 train-val-test splitting + +## Trainer + +```python +trainer = Trainer( + model=model, + targets='efsm', + energy_loss_ratio=1, + force_loss_ratio=1, + stress_loss_ratio=0.1, + mag_loss_ratio=0.1, + optimizer='Adam', + weight_decay=0, + scheduler='CosLR', + scheduler_params={'decay_fraction': 0.5e-2}, + criterion='Huber', + delta=0.1, + epochs=30, + starting_epoch=0, + learning_rate=5e-3, + use_device='cuda', + print_freq=1000 +) +``` + +## Mean Absolute Error (MAE) logs + +| partition | Energy (meV/atom) | Force (meV/A) | stress (GPa) | magmom (muB) | +| ---------- | ----------------- | ------------- | ------------ | ------------ | +| Train | 26 | 60 | 0.266 | 0.037 | +| Validation | 29 | 70 | 0.308 | 0.037 | +| Test | 29 | 68 | 0.314 | 0.037 | diff --git a/chgnet/pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar b/chgnet/pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar new file mode 100644 index 00000000..20ce3228 Binary files /dev/null and b/chgnet/pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar differ diff --git a/tests/test_md.py b/tests/test_md.py index 1235a5b8..044591fe 100644 --- a/tests/test_md.py +++ b/tests/test_md.py @@ -35,14 +35,15 @@ def test_eos(): eos = EquationOfState() eos.fit(atoms=structure) - assert eos.get_bulk_modulus() == approx(0.66012829210838, rel=1e-4) - assert eos.get_bulk_modulus(unit="GPa") == approx(105.76421250583728, rel=1e-4) - assert eos.get_compressibility() == approx(1.51485, rel=1e-4) - assert eos.get_compressibility(unit="GPa^-1") == approx(0.0094549940505, rel=1e-4) + + assert eos.get_bulk_modulus() == approx(0.6501, rel=1e-4) + assert eos.get_bulk_modulus(unit="GPa") == approx(104.16, rel=1e-4) + assert eos.get_compressibility() == approx(1.5381, rel=1e-4) + assert eos.get_compressibility(unit="GPa^-1") == approx(0.00960, rel=1e-4) @pytest.mark.parametrize("algorithm", ["legacy", "fast"]) -def test_md_nvt( +def test_md_nvt_berendsen( tmp_path: Path, monkeypatch: MonkeyPatch, algorithm: Literal["legacy", "fast"] ): monkeypatch.chdir(tmp_path) # run MD in temporary directory @@ -77,17 +78,17 @@ def test_md_nvt( logs = log_file.read() logs = np.fromstring(logs, sep=" ") ref = np.fromstring( - "0.0000 -58.9727 -58.9727 0.0000 0.0\n" - "0.0200 -58.9723 -58.9731 0.0009 0.8\n" - "0.0400 -58.9672 -58.9727 0.0055 5.4\n" - "0.0600 -58.9427 -58.9663 0.0235 22.8\n" - "0.0800 -58.8605 -58.9352 0.0747 72.2\n" - "0.1000 -58.7651 -58.8438 0.0786 76.0\n" - "0.1200 -58.6684 -58.7268 0.0584 56.4\n" - "0.1400 -58.5703 -58.6202 0.0499 48.2\n" - "0.1600 -58.4724 -58.5531 0.0807 78.1\n" - "0.1800 -58.3891 -58.8077 0.4186 404.8\n" - "0.2000 -58.3398 -58.9244 0.5846 565.4\n", + "0.0000 -58.8678 -58.8678 0.0000 0.0\n" + "0.0200 -58.8665 -58.8692 0.0027 2.6\n" + "0.0400 -58.8650 -58.8846 0.0196 18.9\n" + "0.0600 -58.7870 -58.8671 0.0801 77.5\n" + "0.0800 -58.7024 -58.8023 0.0999 96.7\n" + "0.1000 -58.6080 -58.6803 0.0723 69.9\n" + "0.1200 -58.5487 -58.5849 0.0362 35.0\n" + "0.1400 -58.4648 -58.5285 0.0637 61.6\n" + "0.1600 -58.3202 -58.5693 0.2491 240.9\n" + "0.1800 -58.2515 -58.7861 0.5346 517.0\n" + "0.2000 -58.2441 -58.8199 0.5758 556.8\n", sep=" ", ) assert_allclose(logs, ref, rtol=2.1e-3, atol=1e-8) @@ -113,10 +114,11 @@ def test_md_nve(tmp_path: Path, monkeypatch: MonkeyPatch): assert set(os.listdir()) == {"md_out.log", "md_out.traj"} with open("md_out.log") as log_file: logs = log_file.read() + print("nve logs", logs) assert logs == ( "Time[ps] Etot[eV] Epot[eV] Ekin[eV] T[K]\n" - "0.0000 -58.9727 -58.9727 0.0000 0.0\n" - "0.0100 -58.9727 -58.9728 0.0001 0.1\n" + "0.0000 -58.9415 -58.9415 0.0000 0.0\n" + "0.0100 -58.9415 -58.9417 0.0002 0.2\n" ) @@ -139,7 +141,7 @@ def test_md_npt_inhomogeneous_berendsen(tmp_path: Path, monkeypatch: MonkeyPatch assert isinstance(md.atoms, Atoms) assert isinstance(md.atoms.calc, CHGNetCalculator) assert isinstance(md.dyn, Inhomogeneous_NPTBerendsen) - assert md.bulk_modulus == approx(105.764, rel=1e-2) + assert md.bulk_modulus == approx(104.16, rel=1e-2) assert md.dyn.pressure == approx(6.324e-07, rel=1e-4) assert set(os.listdir()) == {"md_out.log", "md_out.traj"} with open("md_out.log") as log_file: @@ -147,17 +149,17 @@ def test_md_npt_inhomogeneous_berendsen(tmp_path: Path, monkeypatch: MonkeyPatch logs = log_file.read() logs = np.fromstring(logs, sep=" ") ref = np.fromstring( - "0.0000 -58.9727 -58.9727 0.0000 0.0\n" - "0.0200 -58.9723 -58.9731 0.0009 0.8\n" - "0.0400 -58.9672 -58.9727 0.0055 5.3\n" - "0.0600 -58.9427 -58.9663 0.0235 22.7\n" - "0.0800 -58.8605 -58.9352 0.0747 72.2\n" - "0.1000 -58.7652 -58.8438 0.0786 76.0\n" - "0.1200 -58.6686 -58.7269 0.0584 56.4\n" - "0.1400 -58.5707 -58.6205 0.0499 48.2\n" - "0.1600 -58.4731 -58.5533 0.0802 77.6\n" - "0.1800 -58.3897 -58.8064 0.4167 402.9\n" - "0.2000 -58.3404 -58.9253 0.5849 565.6\n", + "0.0000 -58.9415 -58.9415 0.0000 0.0\n" + "0.0200 -58.9407 -58.9423 0.0016 1.6\n" + "0.0400 -58.9310 -58.9415 0.0105 10.1\n" + "0.0600 -58.8819 -58.9315 0.0495 47.9\n" + "0.0800 -58.7860 -58.8800 0.0940 90.9\n" + "0.1000 -58.6916 -58.7694 0.0778 75.2\n" + "0.1200 -58.5945 -58.6458 0.0513 49.6\n" + "0.1400 -58.4972 -58.5543 0.0571 55.2\n" + "0.1600 -58.4008 -58.5540 0.1532 148.1\n" + "0.1800 -58.3292 -58.8330 0.5038 487.2\n" + "0.2000 -58.2842 -58.8526 0.5684 549.7\n", sep=" ", ) assert_allclose(logs, ref, rtol=2.1e-3, atol=1e-8) @@ -200,16 +202,16 @@ def test_md_nvt_nose_hoover(tmp_path: Path, monkeypatch: MonkeyPatch): logs = log_file.read() logs = np.fromstring(logs, sep=" ") ref = np.fromstring( - "0.0200 -199.2479 -199.3994 0.1515 36.6\n" - "0.0400 -199.2459 -199.3440 0.0981 23.7\n" - "0.0600 -199.2394 -199.2669 0.0275 6.6\n" - "0.0800 -199.2348 -199.4143 0.1795 43.4\n" - "0.1000 -199.2274 -199.2774 0.0500 12.1\n" - "0.1200 -199.2123 -199.3001 0.0878 21.2\n" - "0.1400 -199.2040 -199.4000 0.1961 47.4\n" - "0.1600 -199.1856 -199.2181 0.0325 7.9\n" - "0.1800 -199.1603 -199.3266 0.1662 40.2\n" - "0.2000 -199.1455 -199.3490 0.2035 49.2\n", + "0.0200 -199.3047 -199.3890 0.0844 20.4\n" + "0.0400 -199.3036 -199.3510 0.0475 11.5\n" + "0.0600 -199.2999 -199.3219 0.0221 5.3\n" + "0.0800 -199.2974 -199.4012 0.1038 25.1\n" + "0.1000 -199.2927 -199.3097 0.0170 4.1\n" + "0.1200 -199.2847 -199.3522 0.0675 16.3\n" + "0.1400 -199.2802 -199.3789 0.0988 23.9\n" + "0.1600 -199.2681 -199.2785 0.0104 2.5\n" + "0.1800 -199.2565 -199.3830 0.1265 30.6\n" + "0.2000 -199.2463 -199.3190 0.0727 17.6\n", sep=" ", ) assert_allclose(logs, ref, rtol=1e-2, atol=1e-7) @@ -241,7 +243,7 @@ def test_md_npt_nose_hoover(tmp_path: Path, monkeypatch: MonkeyPatch): assert isinstance(md.atoms, Atoms) assert isinstance(md.atoms.calc, CHGNetCalculator) assert isinstance(md.dyn, NPT) - assert md.bulk_modulus == approx(102.977, rel=1e-2) + assert md.bulk_modulus == approx(88.6389, rel=1e-2) assert_allclose( md.dyn.externalstress, [-6.324e-07, -6.324e-07, -6.324e-07, 0.0, 0.0, 0.0], @@ -254,18 +256,19 @@ def test_md_npt_nose_hoover(tmp_path: Path, monkeypatch: MonkeyPatch): logs = log_file.read() logs = np.fromstring(logs, sep=" ") ref = np.fromstring( - "0.0200 -199.2480 -199.3994 0.1514 36.6\n" - "0.0400 -199.2460 -199.3442 0.0982 23.7\n" - "0.0600 -199.2397 -199.2672 0.0275 6.7\n" - "0.0800 -199.2355 -199.4148 0.1793 43.3\n" - "0.1000 -199.2282 -199.2782 0.0500 12.1\n" - "0.1200 -199.2135 -199.3017 0.0882 21.3\n" - "0.1400 -199.2060 -199.4014 0.1954 47.2\n" - "0.1600 -199.1878 -199.2201 0.0323 7.8\n" - "0.1800 -199.1630 -199.3306 0.1675 40.5\n" - "0.2000 -199.1496 -199.3506 0.2010 48.6\n", + "0.0200 -199.3048 -199.3891 0.0843 20.4\n" + "0.0400 -199.3038 -199.3513 0.0475 11.5\n" + "0.0600 -199.3005 -199.3226 0.0221 5.4\n" + "0.0800 -199.2988 -199.4024 0.1036 25.0\n" + "0.1000 -199.2945 -199.3115 0.0170 4.1\n" + "0.1200 -199.2872 -199.3550 0.0679 16.4\n" + "0.1400 -199.2841 -199.3822 0.0981 23.7\n" + "0.1600 -199.2729 -199.2833 0.0105 2.5\n" + "0.1800 -199.2622 -199.3895 0.1273 30.8\n" + "0.2000 -199.2539 -199.3247 0.0708 17.1\n", sep=" ", ) + assert_allclose(logs, ref, rtol=1e-2, atol=1e-7) @@ -296,7 +299,9 @@ def test_md_crystal_feas_log( assert isinstance(crystal_feas, list) assert len(crystal_feas) == 101 assert len(crystal_feas[0]) == 64 - assert crystal_feas[0][0] == approx(1.4411131, rel=1e-5) - assert crystal_feas[0][1] == approx(2.652704, rel=1e-5) - assert crystal_feas[10][0] == approx(1.4390125, rel=1e-5) - assert crystal_feas[10][1] == approx(2.6525214, rel=1e-5) + print(crystal_feas[0][0], crystal_feas[0][1]) + print(crystal_feas[10][0], crystal_feas[10][1]) + assert crystal_feas[0][0] == approx(-0.002082636, abs=1e-5) + assert crystal_feas[0][1] == approx(-1.4285042, abs=1e-5) + assert crystal_feas[10][0] == approx(-0.0020592688, abs=1e-5) + assert crystal_feas[10][1] == approx(-1.4284436, abs=1e-5) diff --git a/tests/test_model.py b/tests/test_model.py index 55ace5b2..127b0aed 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -63,7 +63,6 @@ def test_predict_structure() -> None: return_atom_feas=True, return_crystal_feas=True, ) - assert sorted(out) == [ "atom_fea", "crystal_fea", @@ -73,48 +72,57 @@ def test_predict_structure() -> None: "s", "site_energies", ] - assert out["e"] == pytest.approx(-7.37159, rel=1e-4, abs=1e-4) + assert out["e"] == pytest.approx(-7.36769, rel=1e-4, abs=1e-4) forces = [ - [4.4703484e-08, -4.2840838e-08, 2.4071064e-02], - [-4.4703484e-08, -1.4551915e-08, -2.4071217e-02], - [-1.7881393e-07, 1.0244548e-08, 2.5402933e-02], - [5.9604645e-08, -2.3283064e-08, -2.5402665e-02], - [-1.1920929e-07, 6.6356733e-08, -2.1660209e-02], - [2.3543835e-06, -8.0077443e-06, 9.5508099e-03], - [-2.2947788e-06, 7.9898164e-06, -9.5513463e-03], - [-5.9604645e-08, -0.0000000e00, 2.1660626e-02], + [1.34110451e-07, -2.92202458e-08, 2.38135569e-02], + [5.96046448e-08, 4.63332981e-08, -2.38130391e-02], + [8.94069672e-08, -2.06753612e-07, 9.25870836e-02], + [-1.49011612e-07, -1.06170774e-07, -9.25877392e-02], + [5.96046448e-08, 2.00234354e-08, -2.43449211e-03], + [-1.19209290e-06, -4.74974513e-08, -1.30698681e-02], + [1.40070915e-06, 1.64378434e-07, 1.30702555e-02], + [-5.96046448e-08, 1.66241080e-07, 2.43446976e-03], ] assert out["f"] == pytest.approx(np.array(forces), rel=1e-4, abs=1e-4) stress = [ - [3.3677614e-01, -1.9665707e-07, -5.6416429e-06], - [4.9939729e-07, 2.4675032e-01, 1.8549043e-05], - [-4.0414070e-06, 1.9096897e-05, 4.0323928e-02], + [-3.0366361e-01, -3.7709856e-07, 2.2964025e-06], + [-1.2128221e-06, 2.2305478e-01, -3.2104114e-07], + [1.3322200e-06, -8.3219516e-07, -1.0736181e-01], ] assert out["s"] == pytest.approx(np.array(stress), rel=1e-4, abs=1e-4) - magmom = [0.00521, 0.00521, 3.85728, 3.85729, 0.02538, 0.03706, 0.03706, 0.02538] + magmom = [ + 3.0495524e-03, + 3.0494630e-03, + 3.8694179e00, + 3.8694181e00, + 4.4136152e-02, + 3.8622141e-02, + 3.8622111e-02, + 4.4136211e-02, + ] assert out["m"] == pytest.approx(magmom, rel=1e-4, abs=1e-4) site_energies = [ - -3.8090043, - -3.8090036, - -10.2737875, - -10.2737875, - -7.659066, - -7.744509, - -7.744509, - -7.659066, + -3.6264274, + -3.6264274, + -9.634681, + -9.634682, + -8.024935, + -8.184724, + -8.184724, + -8.024935, ] assert out["site_energies"] == pytest.approx(site_energies, rel=1e-4, abs=1e-4) assert out["site_energies"].shape == (8,) assert np.sum(out["site_energies"]) / len(structure) == pytest.approx( out["e"], rel=1e-4, abs=1e-6 ) - assert out["crystal_fea"].mean() == pytest.approx(0.27905, rel=1e-4, abs=1e-4) + assert out["crystal_fea"].mean() == pytest.approx(0.26999, rel=1e-4, abs=1e-4) assert out["crystal_fea"].shape == (64,) - assert out["atom_fea"].mean() == pytest.approx(0.01606, rel=1e-4, abs=1e-4) + assert out["atom_fea"].mean() == pytest.approx(-0.09668, rel=1e-4, abs=1e-4) assert out["atom_fea"].shape == (8, 64) @@ -202,26 +210,15 @@ def test_predict_batched_structures() -> None: ) -model_arg_keys = frozenset( - "atom_fea_dim bond_fea_dim angle_fea_dim composition_model num_radial num_angular n_conv " - "atom_conv_hidden_dim update_bond bond_conv_hidden_dim update_angle angle_layer_hidden_dim" - " conv_dropout read_out mlp_hidden_dims mlp_dropout mlp_first is_intensive non_linearity " - "atom_graph_cutoff bond_graph_cutoff graph_converter_algorithm cutoff_coeff learnable_rbf " - "skip_connection conv_norm gMLP_norm readout_norm".split() -) - - def test_as_to_from_dict() -> None: dct = model.as_dict() assert {*dct} == {"model_args", "state_dict"} - assert {*dct["model_args"]} >= model_arg_keys model_2 = CHGNet.from_dict(dct) assert model_2.as_dict()["model_args"] == dct["model_args"] to_dict = model.todict() assert {*to_dict} == {"model_name", "model_args"} - assert {*to_dict["model_args"]} >= model_arg_keys model_3 = CHGNet(**to_dict["model_args"]) assert model_3.todict() == to_dict diff --git a/tests/test_relaxation.py b/tests/test_relaxation.py index adfbe35d..a2a53b9b 100644 --- a/tests/test_relaxation.py +++ b/tests/test_relaxation.py @@ -19,7 +19,7 @@ def test_relaxation(algorithm: Literal["legacy", "fast"]): chgnet = CHGNet.load() converter = CrystalGraphConverter( - atom_graph_cutoff=5, bond_graph_cutoff=3, algorithm=algorithm + atom_graph_cutoff=6, bond_graph_cutoff=3, algorithm=algorithm ) assert converter.algorithm == algorithm @@ -37,7 +37,7 @@ def test_relaxation(algorithm: Literal["legacy", "fast"]): # make sure final structure is more relaxed than initial one assert traj.energies[0] > traj.energies[-1] - assert traj.energies[-1] == approx(-58.972927) + assert traj.energies[-1] == approx(-58.94209, rel=1e-4) no_cuda = mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")