-
Notifications
You must be signed in to change notification settings - Fork 385
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DataModules: add configurable args to dataloader #2333
base: main
Are you sure you want to change the base?
DataModules: add configurable args to dataloader #2333
Conversation
How much of a speed difference do you notice? At the moment, the datamodules are all designed to pass additional kwargs to the dataset class. So we should technically add these new args to every datamodule, not just to the base class. We could design things such that all kwargs are either passed to the data loader or the dataset class, but I'm worried users might get confused by this. Of course, they don't need to understand the feature if they don't want to use it. I guess my concern is, why stop at only these variables? |
Varying all of those three parameters I was able to at least half the time for I agree, it would be better to solve for all arguments, and that there is a trade off between full configurability and simplicity. I just added those variables since those were the ones I wanted to vary. Could it be an idea to do something like what is done in # ...
DATALOADER_KWARG_PREFIX = "dataloader_"
# ...
dataloader_kwargs = {
k.replace(DATALOADER_KWARG_PREFIX, ""): v
for k, v in kwargs.items()
if k.startswith(DATALOADER_KWARG_PREFIX)
}
# ...
DataLoader(
dataset=dataset,
batch_size=batch_size,
sampler=sampler,
batch_sampler=batch_sampler,
num_workers=self.num_workers,
collate_fn=self.collate_fn,
persistent_workers=self.num_workers > 0,
**dataloader_kwargs
) |
1af0682
to
1f56161
Compare
1f56161
to
8257b8e
Compare
I changed it to use kwargs now self.dataloader_kwargs, self.kwargs = split_prefixed_kwargs(
'dataloader_', **kwargs
) |
@microsoft-github-policy-service agree company="Kongsberg Satellite Services" |
I think it woul be better for kwargs to be a JSON-like dict, something like {"dataset": ..., "dataloader": ...}. We could even turn this into a typed dict or a dataclass so that it can be clearly documented and tested. I think the transformers library likes to use these config structures a lot |
8fc54ef
to
b5291a6
Compare
9a78733
to
88cedbc
Compare
Yes, we do something similar in our combined GeoDatasets (NAIP + Chesapeake, Sentinel-2 + CDL/EuroCrops/NCCM/SAS).
The downside of this is that it would be backwards-incompatible and affect all datamodules. I'm still thinking about this idea. I wish there was a way to predict a good default value. I'm curious if any other maintainers have experimented with these variables and how important they are to tune. In the meantime, anyone can do this themselves by overriding the respective methods like so: from torchgeo.datamodules import TropicalCycloneDataModule
class MyTropicalCycloneDataModule(TropicalCycloneDataModule):
def _dataloader_factory(self, split):
dataset = self._valid_attribute(f'{split}_dataset', 'dataset')
batch_size = self._valid_attribute(f'{split}_batch_size', 'batch_size')
return DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=split == 'train',
num_workers=self.num_workers,
collate_fn=self.collate_fn,
persistent_workers=self.num_workers > 0,
...
) |
Closes #2332.