Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
Add database_uri as a potential input parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasteuwen committed Apr 5, 2024
1 parent 20abf8f commit f64b654
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions ahcore/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,10 @@ def __len__(self) -> int:
return self.cumulative_sizes[-1]

@overload
def __getitem__(self, index: int) -> DlupDatasetSample:
...
def __getitem__(self, index: int) -> DlupDatasetSample: ...

@overload
def __getitem__(self, index: slice) -> list[DlupDatasetSample]:
...
def __getitem__(self, index: slice) -> list[DlupDatasetSample]: ...

def __getitem__(self, index: Union[int, slice]) -> DlupDatasetSample | list[DlupDatasetSample]:
"""Returns the sample at the given index."""
Expand All @@ -112,6 +110,7 @@ def __init__(
self,
data_description: DataDescription,
pre_transform: Callable[[bool], Callable[[DlupDatasetSample], DlupDatasetSample]],
database_uri: str | None = None,
batch_size: int = 32, # noqa,pylint: disable=unused-argument
validate_batch_size: int | None = None, # noqa,pylint: disable=unused-argument
num_workers: int = 16,
Expand All @@ -129,6 +128,8 @@ def __init__(
A pre-transform is a callable which is directly applied to the output of the dataset before collation in
the dataloader. The transforms typically convert the image in the output to a tensor, convert the
`WsiAnnotations` to a mask or similar.
database_uri : str, optional
The URI to the database. If not provided, the URI from the data description will be used.
batch_size : int
The batch size of the data loader.
validate_batch_size : int, optional
Expand Down Expand Up @@ -160,7 +161,9 @@ def __init__(
# Data settings
self.data_description: DataDescription = data_description

self._data_manager = DataManager(database_uri=data_description.manifest_database_uri)
self._data_manager = DataManager(
database_uri=data_description.manifest_database_uri if not database_uri else database_uri
)

self._batch_size = self.hparams.batch_size # type: ignore
self._validate_batch_size = self.hparams.validate_batch_size # type: ignore
Expand Down

0 comments on commit f64b654

Please sign in to comment.