Skip to content

Commit

Permalink
improve testing coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Aug 9, 2024
1 parent 9c1c4a9 commit 24a8360
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 21 deletions.
21 changes: 0 additions & 21 deletions acegen/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import warnings

from torchrl.envs import Compose


def adapt_state_dict(source_state_dict: dict, target_state_dict: dict):
"""Adapt the source state dict to the target state dict.
Expand Down Expand Up @@ -34,22 +32,3 @@ def adapt_state_dict(source_state_dict: dict, target_state_dict: dict):
target_state_dict[key_target] = value_source

return target_state_dict


def get_primers_from_module(module):
"""Get all tensordict primers from all submodules of a module."""
primers = []

def make_primers(submodule):
if hasattr(submodule, "make_tensordict_primer"):
primers.append(submodule.make_tensordict_primer())

module.apply(make_primers)
if not primers:
import warnings

raise warnings.warn("No primers found in the module.")
elif len(primers) == 1:
return primers[0]
else:
return Compose(primers)
19 changes: 19 additions & 0 deletions tests/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
smiles_to_tensordict,
SMILESDataset,
)
from acegen.data.chem_utils import fraction_valid
from acegen.vocabulary.tokenizers import SMILESTokenizerChEMBL
from acegen.vocabulary.vocabulary import Vocabulary
from tensordict import TensorDict
Expand Down Expand Up @@ -102,3 +103,21 @@ def test_load_dataset(randomize_smiles):
data_batch = dataloader.__iter__().__next__()
assert isinstance(data_batch, TensorDict)
shutil.rmtree(temp_dir)


def test_fraction_valid():

multiple_smiles = [
"CCO", # Ethanol (C2H5OH)
"CCN(CC)CC", # Triethylamine (C6H15N)
"CC(=O)OC(C)C", # Diethyl carbonate (C7H14O3)
"CC(C)C", # Isobutane (C4H10)
"CC1=CC=CC=C1", # Toluene (C7H8)
]

assert fraction_valid(multiple_smiles) == 1.0

# add an invalid SMILES
multiple_smiles.append("invalid")

assert fraction_valid(multiple_smiles) == 5 / 6

0 comments on commit 24a8360

Please sign in to comment.