diff --git a/acegen/models/utils.py b/acegen/models/utils.py index 0225b23..1e70383 100644 --- a/acegen/models/utils.py +++ b/acegen/models/utils.py @@ -1,7 +1,5 @@ import warnings -from torchrl.envs import Compose - def adapt_state_dict(source_state_dict: dict, target_state_dict: dict): """Adapt the source state dict to the target state dict. @@ -34,22 +32,3 @@ def adapt_state_dict(source_state_dict: dict, target_state_dict: dict): target_state_dict[key_target] = value_source return target_state_dict - - -def get_primers_from_module(module): - """Get all tensordict primers from all submodules of a module.""" - primers = [] - - def make_primers(submodule): - if hasattr(submodule, "make_tensordict_primer"): - primers.append(submodule.make_tensordict_primer()) - - module.apply(make_primers) - if not primers: - import warnings - - raise warnings.warn("No primers found in the module.") - elif len(primers) == 1: - return primers[0] - else: - return Compose(primers) diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 0497b27..8fc3203 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -11,6 +11,7 @@ smiles_to_tensordict, SMILESDataset, ) +from acegen.data.chem_utils import fraction_valid from acegen.vocabulary.tokenizers import SMILESTokenizerChEMBL from acegen.vocabulary.vocabulary import Vocabulary from tensordict import TensorDict @@ -102,3 +103,21 @@ def test_load_dataset(randomize_smiles): data_batch = dataloader.__iter__().__next__() assert isinstance(data_batch, TensorDict) shutil.rmtree(temp_dir) + + +def test_fraction_valid(): + + multiple_smiles = [ + "CCO", # Ethanol (C2H5OH) + "CCN(CC)CC", # Triethylamine (C6H15N) + "CC(=O)OC(C)C", # Diethyl carbonate (C7H14O3) + "CC(C)C", # Isobutane (C4H10) + "CC1=CC=CC=C1", # Toluene (C7H8) + ] + + assert fraction_valid(multiple_smiles) == 1.0 + + # add an invalid SMILES + multiple_smiles.append("invalid") + + assert fraction_valid(multiple_smiles) == 5 / 6