def test_split_prefixed_kwargs() -> None:
    """Keyword arguments are routed by prefix; unmatched ones are kept as-is."""
    params = {
        'testprefix1_param1': 10,
        'testprefix1_param2': 20,
        'testprefix2_param3': 30,
        'other_param': 40,
    }

    # One dict per prefix (prefix stripped from keys), then the leftovers.
    split = split_prefixed_kwargs('testprefix1_', 'testprefix2_', **params)

    assert split == (
        {'param1': 10, 'param2': 20},
        {'param3': 30},
        {'other_param': 40},
    )
""" super().__init__() self.dataset_class = dataset_class self.batch_size = batch_size self.num_workers = num_workers - self.kwargs = kwargs + self.dataloader_kwargs, self.kwargs = split_prefixed_kwargs( + 'dataloader_', **kwargs + ) # Datasets self.dataset: Dataset[dict[str, Tensor]] | None = None @@ -287,6 +290,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: num_workers=self.num_workers, collate_fn=self.collate_fn, persistent_workers=self.num_workers > 0, + **self.dataloader_kwargs, ) def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: @@ -431,6 +435,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: num_workers=self.num_workers, collate_fn=self.collate_fn, persistent_workers=self.num_workers > 0, + **self.dataloader_kwargs, ) def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index 4c3aab63b61..a79e7b3c887 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -169,3 +169,29 @@ def group_shuffle_split( test_idxs.append(i) return train_idxs, test_idxs + + +def split_prefixed_kwargs(*prefixes: str, **kwargs: Any) -> tuple[dict[str, Any], ...]: + """Split kwargs into prefixed and other kwargs. + + Args: + *prefixes: Prefixes to filter kwargs by. + **kwargs: Keyword arguments to filter. + + Returns: + Tuple of prefixed kwargs and other kwargs. + """ + prefixed_kwargs: list[dict[str, Any]] = [{} for _ in prefixes] + other_kwargs: dict[str, Any] = {} + + for key, value in kwargs.items(): + matched = False + for i, prefix in enumerate(prefixes): + if key.startswith(prefix): + prefixed_kwargs[i][key[len(prefix) :]] = value + matched = True + break + if not matched: + other_kwargs[key] = value + + return *prefixed_kwargs, other_kwargs