Skip to content

Commit

Permalink
added option to disable dataset shuffling
Browse files Browse the repository at this point in the history
  • Loading branch information
bowen-bd committed Apr 22, 2024
1 parent 3543115 commit 72f9108
Showing 1 changed file with 35 additions and 16 deletions.
51 changes: 35 additions & 16 deletions chgnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
magmoms: list[Sequence[Sequence[float]]] | None = None,
structure_ids: list[str] | None = None,
graph_converter: CrystalGraphConverter | None = None,
shuffle: bool = True,
) -> None:
"""Initialize the dataset.
Expand All @@ -45,11 +46,16 @@ def __init__(
energies (list[float]): [data_size, 1]
forces (list[list[float]]): [data_size, n_atoms, 3]
stresses (list[list[float]], optional): [data_size, 3, 3]
Default = None
magmoms (list[list[float]], optional): [data_size, n_atoms, 1]
Default = None
structure_ids (list[str], optional): a list of ids to track the structures
Default = None
graph_converter (CrystalGraphConverter, optional): Converts the structures
to graphs. If None, it will be set to CHGNet 0.3.0 converter
with AtomGraph cutoff = 6A.
shuffle (bool): whether to shuffle the sequence of dataset
Default = True
Raises:
RuntimeError: if the length of structures and labels (energies, forces,
Expand All @@ -72,7 +78,8 @@ def __init__(
self.magmoms = magmoms
self.structure_ids = structure_ids
self.keys = np.arange(len(structures))
random.shuffle(self.keys)
if shuffle:
random.shuffle(self.keys)
print(f"{len(structures)} structures imported")
self.graph_converter = graph_converter or CrystalGraphConverter(
atom_graph_cutoff=6, bond_graph_cutoff=3
Expand Down Expand Up @@ -151,6 +158,7 @@ def __init__(
force_key: str = "force",
stress_key: str = "stress",
magmom_key: str = "magmom",
shuffle: bool = True,
) -> None:
"""Initialize the dataset from a directory containing CIFs.
Expand All @@ -163,18 +171,21 @@ def __init__(
to graphs. If None, it will be set to CHGNet 0.3.0 converter
with AtomGraph cutoff = 6A.
energy_key (str, optional): the key of energy in the labels.
Default = "energy_per_atom".
Default = "energy_per_atom"
force_key (str, optional): the key of force in the labels.
Default = "force".
Default = "force"
stress_key (str, optional): the key of stress in the labels.
Default = "stress".
Default = "stress"
magmom_key (str, optional): the key of magmom in the labels.
Default = "magmom".
Default = "magmom"
shuffle (bool): whether to shuffle the sequence of dataset
Default = True
"""
self.data_dir = cif_path
self.data = utils.read_json(os.path.join(cif_path, labels))
self.cif_ids = list(self.data)
random.shuffle(self.cif_ids)
if shuffle:
random.shuffle(self.cif_ids)
print(f"{cif_path}: {len(self.cif_ids):,} structures imported")
self.graph_converter = graph_converter or CrystalGraphConverter(
atom_graph_cutoff=6, bond_graph_cutoff=3
Expand Down Expand Up @@ -262,6 +273,7 @@ def __init__(
force_key: str = "force",
stress_key: str = "stress",
magmom_key: str = "magmom",
shuffle: bool = True,
) -> None:
"""Initialize the dataset from a directory containing saved crystal graphs.
Expand All @@ -274,13 +286,15 @@ def __init__(
exclude (str, list | None): the path or list of excluded graphs.
Default = None
energy_key (str, optional): the key of energy in the labels.
Default = "energy_per_atom".
Default = "energy_per_atom"
force_key (str, optional): the key of force in the labels.
Default = "force".
Default = "force"
stress_key (str, optional): the key of stress in the labels.
Default = "stress".
Default = "stress"
magmom_key (str, optional): the key of magmom in the labels.
Default = "magmom".
Default = "magmom"
shuffle (bool): whether to shuffle the sequence of dataset
Default = True
"""
self.graph_path = graph_path
if isinstance(labels, str):
Expand All @@ -300,7 +314,8 @@ def __init__(
self.keys = [
(mp_id, graph_id) for mp_id, dic in self.labels.items() for graph_id in dic
]
random.shuffle(self.keys)
if shuffle:
random.shuffle(self.keys)
print(f"{len(self.labels)} mp_ids, {len(self)} frames imported")
if self.excluded_graph is not None:
print(f"{len(self.excluded_graph)} graphs are pre-excluded")
Expand Down Expand Up @@ -486,6 +501,7 @@ def __init__(
force_key: str = "force",
stress_key: str = "stress",
magmom_key: str = "magmom",
shuffle: bool = True,
) -> None:
"""Initialize the dataset by reading JSON files.
Expand All @@ -496,13 +512,15 @@ def __init__(
targets ("ef" | "efs" | "efm" | "efsm"): The training targets.
Default = "efsm"
energy_key (str, optional): the key of energy in the labels.
Default = "energy_per_atom".
Default = "energy_per_atom"
force_key (str, optional): the key of force in the labels.
Default = "force".
Default = "force"
stress_key (str, optional): the key of stress in the labels.
Default = "stress".
Default = "stress"
magmom_key (str, optional): the key of magmom in the labels.
Default = "magmom".
Default = "magmom"
shuffle (bool): whether to shuffle the sequence of dataset
Default = True
"""
if isinstance(data, str):
self.data = {}
Expand All @@ -522,7 +540,8 @@ def __init__(
self.keys = [
(mp_id, graph_id) for mp_id, dct in self.data.items() for graph_id in dct
]
random.shuffle(self.keys)
if shuffle:
random.shuffle(self.keys)
print(f"{len(self.data)} mp_ids, {len(self)} structures imported")
self.graph_converter = graph_converter
self.energy_key = energy_key
Expand Down

0 comments on commit 72f9108

Please sign in to comment.