From a8b5829901ff2c178bbf051e25e18c1b037ef2f8 Mon Sep 17 00:00:00 2001
From: Jonas Teuwen <2347927+jonasteuwen@users.noreply.github.com>
Date: Tue, 30 Jul 2024 13:15:45 +0200
Subject: [PATCH] Add offset to writer metadata (#102)

---
 ahcore/backends.py                            |   2 +-
 ahcore/callbacks/milvus_vector_db_callback.py |   6 +-
 ahcore/data/dataset.py                        |   6 +-
 ahcore/readers.py                             |   9 +-
 ahcore/utils/callbacks.py                     |   7 +-
 ahcore/utils/database_models.py               |   2 +-
 ahcore/utils/io.py                            |  17 ++
 ahcore/writers.py                             |  20 ++-
 setup.cfg                                     |   4 +-
 tests/test_readers/test_readers.py            |   6 +-
 tests/test_writers/test_h5_writer.py          | 150 ++++++++++++++----
 11 files changed, 181 insertions(+), 48 deletions(-)

diff --git a/ahcore/backends.py b/ahcore/backends.py
index c969726..94e66d9 100644
--- a/ahcore/backends.py
+++ b/ahcore/backends.py
@@ -31,7 +31,7 @@ def vendor(self) -> str:
 
     @property
     def properties(self) -> dict[str, Any]:
-        return self._reader._metadata
+        return self._reader.metadata
 
     @property
     def magnification(self):
diff --git a/ahcore/callbacks/milvus_vector_db_callback.py b/ahcore/callbacks/milvus_vector_db_callback.py
index 3223a42..84848ef 100644
--- a/ahcore/callbacks/milvus_vector_db_callback.py
+++ b/ahcore/callbacks/milvus_vector_db_callback.py
@@ -63,9 +63,9 @@ def __init__(
         self._collection_name = collection_name
         self._collection_type = CollectionType[collection_type.upper()]
         self._collection: Collection | None = None
-        self._prepare_data: Callable[
-            [Dict[str, Any], Dict[str, Any]], PrepareDataTypePathCLS | PrepareDataTypeDebug
-        ] | None = None
+        self._prepare_data: (
+            Callable[[Dict[str, Any], Dict[str, Any]], PrepareDataTypePathCLS | PrepareDataTypeDebug] | None
+        ) = None
 
         self.embedding_dim = embedding_dim
 
diff --git a/ahcore/data/dataset.py b/ahcore/data/dataset.py
index ea2d7a9..f4c7d13 100644
--- a/ahcore/data/dataset.py
+++ b/ahcore/data/dataset.py
@@ -88,10 +88,12 @@ def __len__(self) -> int:
         return self.cumulative_sizes[-1]
 
     @overload
-    def __getitem__(self, index: int) -> DlupDatasetSample: ...
+    def __getitem__(self, index: int) -> DlupDatasetSample:
+        ...
 
     @overload
-    def __getitem__(self, index: slice) -> list[DlupDatasetSample]: ...
+    def __getitem__(self, index: slice) -> list[DlupDatasetSample]:
+        ...
 
     def __getitem__(self, index: Union[int, slice]) -> DlupDatasetSample | list[DlupDatasetSample]:
         """Returns the sample at the given index."""
diff --git a/ahcore/readers.py b/ahcore/readers.py
index 1de2260..9c3b583 100644
--- a/ahcore/readers.py
+++ b/ahcore/readers.py
@@ -149,6 +149,13 @@ def _empty_tile(self) -> GenericNumberArray:
             self.__empty_tile = np.zeros((self._num_channels, *self._tile_size), dtype=self._dtype)
         return self.__empty_tile
 
+    @property
+    def metadata(self) -> dict[str, Any]:
+        if not self._metadata:
+            self._open_file()
+        assert self._metadata
+        return self._metadata
+
     def _decompress_data(self, tile: GenericNumberArray) -> GenericNumberArray:
         if self._is_binary:
             with PIL.Image.open(io.BytesIO(tile)) as img:
@@ -156,7 +163,7 @@ def _decompress_data(self, tile: GenericNumberArray) -> GenericNumberArray:
         else:
             return tile
 
-    def read_region(self, location: tuple[int, int], level: int, size: tuple[int, int]) -> GenericNumberArray:
+    def read_region(self, location: tuple[int, int], level: int, size: tuple[int, int]) -> pyvips.Image:
         """
         Reads a region in the stored h5 file. This function stitches the regions as saved in the cache file.
         Doing this it takes into account:
diff --git a/ahcore/utils/callbacks.py b/ahcore/utils/callbacks.py
index d332858..125b4f1 100644
--- a/ahcore/utils/callbacks.py
+++ b/ahcore/utils/callbacks.py
@@ -1,7 +1,7 @@
 """Ahcore's callbacks"""
 from __future__ import annotations
-import pyvips
+
 import hashlib
 import logging
 from pathlib import Path
@@ -9,6 +9,7 @@
 
 import numpy as np
 import numpy.typing as npt
+import pyvips
 from dlup import SlideImage
 from dlup.annotations import WsiAnnotations
 from dlup.data.transforms import convert_annotations, rename_labels
@@ -138,7 +139,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
         sample = {}
         coordinates = self._regions[idx]
 
-        sample["prediction"] = self._get_h5_region(coordinates)
+        sample["prediction"] = self._get_backend_file_region(coordinates)
 
         if self._annotations is not None:
             target, roi = self._get_annotation_data(coordinates)
@@ -150,7 +151,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
 
         return sample
 
-    def _get_h5_region(self, coordinates: tuple[int, int]) -> npt.NDArray[np.uint8 | np.uint16 | np.float32 | np.bool_]:
+    def _get_backend_file_region(self, coordinates: tuple[int, int]) -> pyvips.Image:
         x, y = coordinates
         width, height = self._region_size
 
diff --git a/ahcore/utils/database_models.py b/ahcore/utils/database_models.py
index 21538b7..943e449 100644
--- a/ahcore/utils/database_models.py
+++ b/ahcore/utils/database_models.py
@@ -202,7 +202,7 @@ class Split(Base):
     # pylint: disable=E1102
     created = Column(DateTime(timezone=True), default=func.now())
     last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
-    category: Column[CategoryEnum] = Column(Enum(CategoryEnum), nullable=False)
+    category: Column[str] = Column(Enum(CategoryEnum), nullable=False)
 
     patient_id = Column(Integer, ForeignKey("patient.id"), nullable=False)
     split_definition_id = Column(Integer, ForeignKey("split_definitions.id"), nullable=False)
diff --git a/ahcore/utils/io.py b/ahcore/utils/io.py
index 850120c..bc7b2dd 100644
--- a/ahcore/utils/io.py
+++ b/ahcore/utils/io.py
@@ -11,6 +11,7 @@
 import logging
 import os
 import pathlib
+import subprocess
 import warnings
 from enum import Enum
 from pathlib import Path
@@ -273,3 +274,19 @@ def validate_checkpoint_paths(config: DictConfig) -> DictConfig:
     if checkpoint_path and not os.path.isabs(checkpoint_path):
         config.trainer.resume_from_checkpoint = pathlib.Path(hydra.utils.get_original_cwd()) / checkpoint_path
     return config
+
+
+def get_git_hash():
+    try:
+        # Check if we're in a git repository
+        subprocess.run(["git", "rev-parse", "--is-inside-work-tree"], check=True, capture_output=True, text=True)
+
+        # Get the git hash
+        result = subprocess.run(["git", "rev-parse", "HEAD"], check=True, capture_output=True, text=True)
+        return result.stdout.strip()
+    except subprocess.CalledProcessError:
+        # This will be raised if we're not in a git repo
+        return None
+    except FileNotFoundError:
+        # This will be raised if git is not installed or not in PATH
+        return None
diff --git a/ahcore/writers.py b/ahcore/writers.py
index 0581ca7..1a59997 100644
--- a/ahcore/writers.py
+++ b/ahcore/writers.py
@@ -15,6 +15,7 @@
 from pathlib import Path
 from typing import Any, Generator, NamedTuple, Optional
 
+import dlup
 import h5py
 import numcodecs
 import numpy as np
@@ -23,7 +24,8 @@
 import zarr
 from dlup.tiling import Grid, GridOrder, TilingMode
 
-from ahcore.utils.io import get_logger
+import ahcore
+from ahcore.utils.io import get_git_hash, get_logger
 from ahcore.utils.types import GenericNumberArray, InferencePrecision
 
 logger = get_logger(__name__)
@@ -65,6 +67,8 @@ class WriterMetadata(NamedTuple):
     format: str | None
     num_channels: int
     dtype: str
+    grid_offset: tuple[int, int] | None
+    versions: dict[str, str]
 
 
 class Writer(abc.ABC):
@@ -97,7 +101,7 @@ def __init__(
         self._progress = progress
 
         self._grid_coordinates: Optional[npt.NDArray[np.int_]] = None
-        self._grid_offset: npt.NDArray[np.int_] | None = None
+        self._grid_offset: tuple[int, int] | None = None
         self._current_index: int = 0
 
         self._tiles_seen = 0
@@ -173,7 +177,14 @@ def get_writer_metadata(self, first_batch: GenericNumberArray) -> WriterMetadata
 
         _dtype = str(first_batch.dtype)
 
-        return WriterMetadata(_mode, _format, _num_channels, _dtype)
+        return WriterMetadata(
+            mode=_mode,
+            format=_format,
+            num_channels=_num_channels,
+            dtype=_dtype,
+            grid_offset=self._grid_offset,
+            versions={"dlup": dlup.__version__, "ahcore": ahcore.__version__, "ahcore_git_hash": get_git_hash()},
+        )
 
     def set_grid(self) -> None:
         # TODO: We only support a single Grid
@@ -211,6 +222,7 @@ def construct_metadata(self, writer_metadata: WriterMetadata) -> dict[str, Any]:
             "format": writer_metadata.format,
             "dtype": writer_metadata.dtype,
             "is_binary": self._is_compressed_image,
+            "grid_offset": writer_metadata.grid_offset,
             "precision": self._precision.value if self._precision else str(InferencePrecision.FP32),
             "multiplier": (
                 self._precision.get_multiplier() if self._precision else InferencePrecision.FP32.get_multiplier()
@@ -302,13 +314,13 @@ def consume(self, batch_generator: Generator[tuple[GenericNumberArray, GenericNu
 
     def init_writer(self, first_coordinates: GenericNumberArray, first_batch: GenericNumberArray, file: Any) -> Any:
         """Initializes the image_dataset based on the first tile."""
+        self._grid_offset = (int(first_coordinates[0][0]), int(first_coordinates[0][1]))
         writer_metadata = self.get_writer_metadata(first_batch)
 
         self._current_index = 0
         # The grid can be smaller than the actual image when slide bounds are given.
         # As the grid should cover the image, the offset is given by the first tile.
-        self._grid_offset = np.array(first_coordinates[0])
 
         self._coordinates_dataset = self.create_dataset(
             file, name="coordinates", shape=(self._num_samples, 2), dtype=np.int_, compression="gzip"
diff --git a/setup.cfg b/setup.cfg
index 18ce0d2..22dedcd 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -3,14 +3,14 @@ current_version = 0.1.1
 commit = True
 tag = False
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
-serialize = 
+serialize =
 	{major}.{minor}.{patch}-{release}{build}
 	{major}.{minor}.{patch}
 
 [bumpversion:part:release]
 optional_value = prod
 first_value = dev
-values = 
+values =
 	dev
 	prod
 
diff --git a/tests/test_readers/test_readers.py b/tests/test_readers/test_readers.py
index 675c385..26278d2 100644
--- a/tests/test_readers/test_readers.py
+++ b/tests/test_readers/test_readers.py
@@ -1,16 +1,16 @@
 import errno
 import os
 from pathlib import Path
-from typing import Generator, Any
+from typing import Any, Generator
 from unittest.mock import patch
 
 import h5py
 import numpy as np
+import numpy.typing as npt
 import pytest
 
 from ahcore.readers import FileImageReader, StitchingMode
 from ahcore.utils.types import GenericNumberArray
-import numpy.typing as npt
 
 
 def colorize(image_array: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
@@ -204,7 +204,7 @@ def test_stitching_modes(temp_h5_file: Path) -> None:
 
     with patch("pyvips.Image.new_from_array") as mock_pyvips:
         # Reading a region that covers all 4 tiles
-        region = reader.read_region((0, 0), 0, (350, 350))
+        _ = reader.read_region((0, 0), 0, (350, 350))
 
         mock_pyvips.assert_called_once()
         stitched_image = mock_pyvips.call_args[0][0]
diff --git a/tests/test_writers/test_h5_writer.py b/tests/test_writers/test_h5_writer.py
index 5b28bf0..34537e8 100644
--- a/tests/test_writers/test_h5_writer.py
+++ b/tests/test_writers/test_h5_writer.py
@@ -1,11 +1,12 @@
+import json
 from pathlib import Path
-from typing import Generator
+from typing import Any, Generator
 
 import h5py
 import numpy as np
 import pytest
 
-from ahcore.utils.types import GenericNumberArray
+from ahcore.utils.types import GenericNumberArray, InferencePrecision
 from ahcore.writers import H5FileImageWriter
 
 
@@ -13,30 +14,34 @@ def temp_h5_file(tmp_path: Path) -> Generator[Path, None, None]:
     h5_file_path = tmp_path / "test_data.h5"
     yield h5_file_path
+    if h5_file_path.exists():
+        h5_file_path.unlink()
 
 
-def test_h5_file_image_writer(temp_h5_file: Path) -> None:
-    """
-    This test the H5FileImageWriter class for the following case:
+@pytest.fixture
+def dummy_batch_data() -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]:
+    dummy_coordinates = np.array([[0, 0]])
+    dummy_batch = np.random.rand(1, 3, 200, 200).astype(np.float32)
+    yield dummy_coordinates, dummy_batch
+
+
+@pytest.fixture
+def dummy_batch_generator(
+    dummy_batch_data: tuple[GenericNumberArray, GenericNumberArray]
+) -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]:
+    def generator() -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]:
+        yield dummy_batch_data
 
-    Assuming that we have a tile of size (200, 200) and we want to write it to an H5 file.
-    This test writes the tile to the H5 file and then reads it back to perform assertions.
+    return generator
 
-    Parameters
-    ----------
-    temp_h5_file
-
-    Returns
-    -------
-    None
-    """
+
+def test_h5_file_image_writer_creation(temp_h5_file: Path) -> None:
     size = (200, 200)
     mpp = 0.5
     tile_size = (200, 200)
     tile_overlap = (0, 0)
     num_samples = 1
 
-    # Create an instance of H5FileImageWriter
     writer = H5FileImageWriter(
         filename=temp_h5_file,
         size=size,
@@ -46,28 +51,117 @@ def test_h5_file_image_writer(temp_h5_file: Path) -> None:
         num_samples=num_samples,
     )
 
-    # Dummy batch data for testing
-    dummy_coordinates = np.random.randint(0, 255, (1, 2))
-    dummy_batch = np.random.rand(1, 3, 200, 200)
+    assert writer._filename == temp_h5_file
+    assert writer._size == size
+    assert writer._mpp == mpp
+    assert writer._tile_size == tile_size
+    assert writer._tile_overlap == tile_overlap
+    assert writer._num_samples == num_samples
 
-    # Use a generator to yield dummy batches
-    def batch_generator() -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]:
-        for i in range(num_samples):
-            yield dummy_coordinates, dummy_batch
 
-    # Write data to the H5 file
-    writer.consume(batch_generator())
+def test_h5_file_image_writer_consume(temp_h5_file: Path, dummy_batch_generator: Any) -> None:
+    size = (200, 200)
+    mpp = 0.5
+    tile_size = (200, 200)
+    tile_overlap = (0, 0)
+    num_samples = 1
+
+    writer = H5FileImageWriter(
+        filename=temp_h5_file,
+        size=size,
+        mpp=mpp,
+        tile_size=tile_size,
+        tile_overlap=tile_overlap,
+        num_samples=num_samples,
+    )
+
+    writer.consume(dummy_batch_generator())
 
-    # Perform assertions
     with h5py.File(temp_h5_file, "r") as h5file:
         assert "data" in h5file
         assert "coordinates" in h5file
         assert np.array(h5file["data"]).shape == (num_samples, 3, 200, 200)
         assert np.array(h5file["coordinates"]).shape == (num_samples, 2)
+
+        dummy_coordinates, dummy_batch = next(dummy_batch_generator())
         assert np.allclose(h5file["data"], dummy_batch)
         assert np.allclose(h5file["coordinates"], dummy_coordinates)
 
 
-# Run the tests
-if __name__ == "__main__":
-    pytest.main()
+def test_h5_file_image_writer_metadata(temp_h5_file: Path, dummy_batch_generator: Any) -> None:
+    size = (200, 200)
+    mpp = 0.5
+    tile_size = (200, 200)
+    tile_overlap = (0, 0)
+    num_samples = 1
+    num_tiles = 1
+    grid_order = "C"
+    tiling_mode = "overflow"
+    format = "RAW"
+    dtype = "float32"
+    is_binary = False
+    grid_offset = (0, 0)
+    precision = InferencePrecision.FP32
+    multiplier = 1.0
+
+    writer = H5FileImageWriter(
+        filename=temp_h5_file,
+        size=size,
+        mpp=mpp,
+        tile_size=tile_size,
+        tile_overlap=tile_overlap,
+        num_samples=num_samples,
+        precision=precision,
+    )
+
+    writer.consume(dummy_batch_generator())
+
+    with h5py.File(temp_h5_file, "r") as h5file:
+        metadata = json.loads(h5file.attrs["metadata"])
+        assert metadata["size"] == list(size)
+        assert metadata["mpp"] == mpp
+        assert metadata["num_samples"] == num_samples
+        assert metadata["tile_size"] == list(tile_size)
+        assert metadata["tile_overlap"] == list(tile_overlap)
+        assert metadata["num_tiles"] == num_tiles
+        assert metadata["grid_order"] == grid_order
+        assert metadata["tiling_mode"] == tiling_mode
+        assert metadata["format"] == format
+        assert metadata["dtype"] == dtype
+        assert metadata["is_binary"] == is_binary
+        assert metadata["grid_offset"] == list(grid_offset)
+        assert metadata["precision"] == precision
+        assert metadata["multiplier"] == multiplier
+
+
+def test_h5_file_image_writer_multiple_tiles(temp_h5_file: Path) -> None:
+    size = (400, 400)
+    mpp = 0.5
+    tile_size = (200, 200)
+    tile_overlap = (0, 0)
+    num_samples = 2
+
+    writer = H5FileImageWriter(
+        filename=temp_h5_file,
+        size=size,
+        mpp=mpp,
+        tile_size=tile_size,
+        tile_overlap=tile_overlap,
+        num_samples=num_samples,
+    )
+
+    def multiple_tile_generator() -> Generator[tuple[GenericNumberArray, GenericNumberArray], None, None]:
+        for i in range(num_samples):
+            coordinates = np.array([[i * 200, 0]])
+            batch = np.random.rand(1, 3, 200, 200).astype(np.float32)
+            yield coordinates, batch
+
+    writer.consume(multiple_tile_generator())
+
+    with h5py.File(temp_h5_file, "r") as h5file:
+        assert "data" in h5file
+        assert "coordinates" in h5file
+        assert h5file["data"].shape == (num_samples, 3, 200, 200)
+        assert h5file["coordinates"].shape == (num_samples, 2)
+        for i in range(num_samples):
+            assert np.allclose(h5file["coordinates"][i], [i * 200, 0])
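
Note: after this patch, the grid offset (and, via WriterMetadata, the dlup/ahcore versions) travel with the written file. A minimal read-back sketch, mirroring the assertions in test_h5_file_image_writer_metadata above; the filename here is hypothetical:

    import json

    import h5py

    with h5py.File("test_data.h5", "r") as h5file:
        # The writer stores its metadata as a JSON string in the "metadata" attribute.
        metadata = json.loads(h5file.attrs["metadata"])
        # "grid_offset" is the field added by this patch: the coordinates of the
        # first tile, recorded because the grid can be smaller than the image
        # when slide bounds are given.
        print(metadata["grid_offset"], metadata["mpp"], metadata["precision"])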