Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
Add offset to writer metadata (#102)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasteuwen authored Jul 30, 2024
1 parent fc84395 commit a8b5829
Show file tree
Hide file tree
Showing 11 changed files with 181 additions and 48 deletions.
2 changes: 1 addition & 1 deletion ahcore/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def vendor(self) -> str:

@property
def properties(self) -> dict[str, Any]:
return self._reader._metadata
return self._reader.metadata

@property
def magnification(self):
Expand Down
6 changes: 3 additions & 3 deletions ahcore/callbacks/milvus_vector_db_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ def __init__(
self._collection_name = collection_name
self._collection_type = CollectionType[collection_type.upper()]
self._collection: Collection | None = None
self._prepare_data: Callable[
[Dict[str, Any], Dict[str, Any]], PrepareDataTypePathCLS | PrepareDataTypeDebug
] | None = None
self._prepare_data: (
Callable[[Dict[str, Any], Dict[str, Any]], PrepareDataTypePathCLS | PrepareDataTypeDebug] | None
) = None

self.embedding_dim = embedding_dim

Expand Down
6 changes: 4 additions & 2 deletions ahcore/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,12 @@ def __len__(self) -> int:
return self.cumulative_sizes[-1]

@overload
def __getitem__(self, index: int) -> DlupDatasetSample: ...
def __getitem__(self, index: int) -> DlupDatasetSample:
...

@overload
def __getitem__(self, index: slice) -> list[DlupDatasetSample]: ...
def __getitem__(self, index: slice) -> list[DlupDatasetSample]:
...

def __getitem__(self, index: Union[int, slice]) -> DlupDatasetSample | list[DlupDatasetSample]:
"""Returns the sample at the given index."""
Expand Down
9 changes: 8 additions & 1 deletion ahcore/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,21 @@ def _empty_tile(self) -> GenericNumberArray:
self.__empty_tile = np.zeros((self._num_channels, *self._tile_size), dtype=self._dtype)
return self.__empty_tile

@property
def metadata(self) -> dict[str, Any]:
if not self._metadata:
self._open_file()
assert self._metadata
return self._metadata

def _decompress_data(self, tile: GenericNumberArray) -> GenericNumberArray:
if self._is_binary:
with PIL.Image.open(io.BytesIO(tile)) as img:
return np.array(img).transpose(2, 0, 1)
else:
return tile

def read_region(self, location: tuple[int, int], level: int, size: tuple[int, int]) -> GenericNumberArray:
def read_region(self, location: tuple[int, int], level: int, size: tuple[int, int]) -> pyvips.Image:
"""
Reads a region in the stored h5 file. This function stitches the regions as saved in the cache file. Doing this
it takes into account:
Expand Down
7 changes: 4 additions & 3 deletions ahcore/utils/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Ahcore's callbacks"""

from __future__ import annotations
import pyvips

import hashlib
import logging
from pathlib import Path
from typing import Any, Iterator, Optional

import numpy as np
import numpy.typing as npt
import pyvips
from dlup import SlideImage
from dlup.annotations import WsiAnnotations
from dlup.data.transforms import convert_annotations, rename_labels
Expand Down Expand Up @@ -138,7 +139,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
sample = {}
coordinates = self._regions[idx]

sample["prediction"] = self._get_h5_region(coordinates)
sample["prediction"] = self._get_backend_file_region(coordinates)

if self._annotations is not None:
target, roi = self._get_annotation_data(coordinates)
Expand All @@ -150,7 +151,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]:

return sample

def _get_h5_region(self, coordinates: tuple[int, int]) -> npt.NDArray[np.uint8 | np.uint16 | np.float32 | np.bool_]:
def _get_backend_file_region(self, coordinates: tuple[int, int]) -> pyvips.Image:
x, y = coordinates
width, height = self._region_size

Expand Down
2 changes: 1 addition & 1 deletion ahcore/utils/database_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ class Split(Base):
# pylint: disable=E1102
created = Column(DateTime(timezone=True), default=func.now())
last_updated = Column(DateTime(timezone=True), default=func.now(), onupdate=func.now())
category: Column[CategoryEnum] = Column(Enum(CategoryEnum), nullable=False)
category: Column[str] = Column(Enum(CategoryEnum), nullable=False)
patient_id = Column(Integer, ForeignKey("patient.id"), nullable=False)
split_definition_id = Column(Integer, ForeignKey("split_definitions.id"), nullable=False)

Expand Down
17 changes: 17 additions & 0 deletions ahcore/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging
import os
import pathlib
import subprocess
import warnings
from enum import Enum
from pathlib import Path
Expand Down Expand Up @@ -273,3 +274,19 @@ def validate_checkpoint_paths(config: DictConfig) -> DictConfig:
if checkpoint_path and not os.path.isabs(checkpoint_path):
config.trainer.resume_from_checkpoint = pathlib.Path(hydra.utils.get_original_cwd()) / checkpoint_path
return config


def get_git_hash():
try:
# Check if we're in a git repository
subprocess.run(["git", "rev-parse", "--is-inside-work-tree"], check=True, capture_output=True, text=True)

# Get the git hash
result = subprocess.run(["git", "rev-parse", "HEAD"], check=True, capture_output=True, text=True)
return result.stdout.strip()
except subprocess.CalledProcessError:
# This will be raised if we're not in a git repo
return None
except FileNotFoundError:
# This will be raised if git is not installed or not in PATH
return None
20 changes: 16 additions & 4 deletions ahcore/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pathlib import Path
from typing import Any, Generator, NamedTuple, Optional

import dlup
import h5py
import numcodecs
import numpy as np
Expand All @@ -23,7 +24,8 @@
import zarr
from dlup.tiling import Grid, GridOrder, TilingMode

from ahcore.utils.io import get_logger
import ahcore
from ahcore.utils.io import get_git_hash, get_logger
from ahcore.utils.types import GenericNumberArray, InferencePrecision

logger = get_logger(__name__)
Expand Down Expand Up @@ -65,6 +67,8 @@ class WriterMetadata(NamedTuple):
format: str | None
num_channels: int
dtype: str
grid_offset: tuple[int, int] | None
versions: dict[str, str]


class Writer(abc.ABC):
Expand Down Expand Up @@ -97,7 +101,7 @@ def __init__(
self._progress = progress

self._grid_coordinates: Optional[npt.NDArray[np.int_]] = None
self._grid_offset: npt.NDArray[np.int_] | None = None
self._grid_offset: tuple[int, int] | None = None

self._current_index: int = 0
self._tiles_seen = 0
Expand Down Expand Up @@ -173,7 +177,14 @@ def get_writer_metadata(self, first_batch: GenericNumberArray) -> WriterMetadata

_dtype = str(first_batch.dtype)

return WriterMetadata(_mode, _format, _num_channels, _dtype)
return WriterMetadata(
mode=_mode,
format=_format,
num_channels=_num_channels,
dtype=_dtype,
grid_offset=self._grid_offset,
versions={"dlup": dlup.__version__, "ahcore": ahcore.__version__, "ahcore_git_hash": get_git_hash()},
)

def set_grid(self) -> None:
# TODO: We only support a single Grid
Expand Down Expand Up @@ -211,6 +222,7 @@ def construct_metadata(self, writer_metadata: WriterMetadata) -> dict[str, Any]:
"format": writer_metadata.format,
"dtype": writer_metadata.dtype,
"is_binary": self._is_compressed_image,
"grid_offset": writer_metadata.grid_offset,
"precision": self._precision.value if self._precision else str(InferencePrecision.FP32),
"multiplier": (
self._precision.get_multiplier() if self._precision else InferencePrecision.FP32.get_multiplier()
Expand Down Expand Up @@ -302,13 +314,13 @@ def consume(self, batch_generator: Generator[tuple[GenericNumberArray, GenericNu

def init_writer(self, first_coordinates: GenericNumberArray, first_batch: GenericNumberArray, file: Any) -> Any:
"""Initializes the image_dataset based on the first tile."""
self._grid_offset = (int(first_coordinates[0][0]), int(first_coordinates[0][1]))

writer_metadata = self.get_writer_metadata(first_batch)

self._current_index = 0
# The grid can be smaller than the actual image when slide bounds are given.
# As the grid should cover the image, the offset is given by the first tile.
self._grid_offset = np.array(first_coordinates[0])

self._coordinates_dataset = self.create_dataset(
file, name="coordinates", shape=(self._num_samples, 2), dtype=np.int_, compression="gzip"
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ current_version = 0.1.1
commit = True
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
serialize =
serialize =
{major}.{minor}.{patch}-{release}{build}
{major}.{minor}.{patch}

[bumpversion:part:release]
optional_value = prod
first_value = dev
values =
values =
dev
prod

Expand Down
6 changes: 3 additions & 3 deletions tests/test_readers/test_readers.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import errno
import os
from pathlib import Path
from typing import Generator, Any
from typing import Any, Generator
from unittest.mock import patch

import h5py
import numpy as np
import numpy.typing as npt
import pytest

from ahcore.readers import FileImageReader, StitchingMode
from ahcore.utils.types import GenericNumberArray
import numpy.typing as npt


def colorize(image_array: npt.NDArray[np.uint8]) -> npt.NDArray[np.uint8]:
Expand Down Expand Up @@ -204,7 +204,7 @@ def test_stitching_modes(temp_h5_file: Path) -> None:

with patch("pyvips.Image.new_from_array") as mock_pyvips:
# Reading a region that covers all 4 tiles
region = reader.read_region((0, 0), 0, (350, 350))
_ = reader.read_region((0, 0), 0, (350, 350))
mock_pyvips.assert_called_once()

stitched_image = mock_pyvips.call_args[0][0]
Expand Down
Loading

0 comments on commit a8b5829

Please sign in to comment.