From ea8a330b558d3d226990dc4f0be8e8c1644f125d Mon Sep 17 00:00:00 2001 From: "Marek (on hp-zbook)" Date: Tue, 24 Sep 2024 13:21:40 +0200 Subject: [PATCH] bugfixes to make the writer work --- ahcore/callbacks/file_writer_callback.py | 2 +- ahcore/transforms/pre_transforms.py | 20 +++++++++++--------- ahcore/utils/data.py | 1 + ahcore/utils/manifest.py | 20 ++++++++++++++------ 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/ahcore/callbacks/file_writer_callback.py b/ahcore/callbacks/file_writer_callback.py index f2d2c97..4535584 100644 --- a/ahcore/callbacks/file_writer_callback.py +++ b/ahcore/callbacks/file_writer_callback.py @@ -155,7 +155,7 @@ def _get_writer_data_args( tile_size = (1, 1) tile_overlap = (0, 0) num_samples = num_samples - grid = None  # just let the writer make a new grid + grid = current_dataset._grids[0][0]  # pass the grid explicitly because the writer does not work otherwise else: raise NotImplementedError(f"Data format {data_format} is not yet supported.") diff --git a/ahcore/transforms/pre_transforms.py b/ahcore/transforms/pre_transforms.py index 45d00ee..b752cfb 100644 --- a/ahcore/transforms/pre_transforms.py +++ b/ahcore/transforms/pre_transforms.py @@ -5,7 +5,6 @@ from __future__ import annotations -import logging from typing import Any, Callable import numpy as np @@ -26,7 +25,9 @@ class PreTransformTaskFactory: - def __init__(self, transforms: list[PreTransformCallable]): + def __init__( + self, transforms: list[PreTransformCallable], data_description: DataDescription, requires_target: bool + ) -> None: """ Pre-transforms are transforms that are applied to the samples directly originating from the dataset. 
These transforms are typically the same for the specific tasks (e.g., segmentation, @@ -71,7 +72,7 @@ def for_segmentation( """ transforms: list[PreTransformCallable] = [] if not requires_target: - return cls(transforms) + return cls(transforms, data_description, requires_target) if data_description.index_map is None: raise ConfigurationError("`index_map` is required for segmentation models when the target is required.") @@ -88,7 +89,7 @@ def for_segmentation( if not multiclass: transforms.append(OneHotEncodeMask(index_map=data_description.index_map)) - return cls(transforms) + return cls(transforms, data_description, requires_target) @classmethod def for_wsi_classification( @@ -99,7 +100,7 @@ def for_wsi_classification( transforms.append(SampleNFeatures(n=1000)) if not requires_target: - return cls(transforms) + return cls(transforms, data_description, requires_target) index_map = data_description.index_map if index_map is None: @@ -112,7 +113,7 @@ def for_wsi_classification( transforms.append(LabelToClassIndex(index_map=index_map)) - return cls(transforms) + return cls(transforms, data_description, requires_target) def __call__(self, data: DlupDatasetSample) -> DlupDatasetSample: for transform in self._transforms: @@ -311,9 +312,10 @@ def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]: if "labels" in sample.keys() and sample["labels"] is not None: for key, value in sample["labels"].items(): - sample["labels"][key] = torch.tensor(value, dtype=torch.float32) - if sample["labels"][key].dim() == 0: - sample["labels"][key] = sample["labels"][key].unsqueeze(0) + if isinstance(value, float) or isinstance(value, int): + sample["labels"][key] = torch.tensor(value, dtype=torch.float32) + if sample["labels"][key].dim() == 0: + sample["labels"][key] = sample["labels"][key].unsqueeze(0) # annotation_data is added by the ConvertPolygonToMask transform. 
if "annotation_data" not in sample: diff --git a/ahcore/utils/data.py b/ahcore/utils/data.py index c5362ae..463d144 100644 --- a/ahcore/utils/data.py +++ b/ahcore/utils/data.py @@ -69,3 +69,4 @@ class DataDescription(BaseModel): convert_mask_to_rois: bool = True use_roi: bool = True apply_color_profile: bool = False + tiling_mode: Optional[str] = None diff --git a/ahcore/utils/manifest.py b/ahcore/utils/manifest.py index 40d8bab..9627f80 100644 --- a/ahcore/utils/manifest.py +++ b/ahcore/utils/manifest.py @@ -62,7 +62,10 @@ class _AnnotationReadersDict(TypedDict): "DARWIN_JSON": WsiAnnotations.from_darwin_json, "GEOJSON": WsiAnnotations.from_geojson, "PYVIPS": functools.partial(SlideImage.from_file_path, backend=DLUPImageBackend.PYVIPS), - "TIFFFILE": functools.partial(SlideImage.from_file_path, backend=DLUPImageBackend.TIFFFILE), + "TIFFFILE": functools.partial( + SlideImage.from_file_path, + backend=DLUPImageBackend.TIFFFILE, + ), "OPENSLIDE": functools.partial(SlideImage.from_file_path, backend=DLUPImageBackend.OPENSLIDE), } @@ -200,11 +203,12 @@ def get_relevant_feature_info_from_record( return image_path, mpp, tile_size, tile_overlap, backend, overwrite_mpp -def _get_rois(mask: Optional[WsiAnnotations], data_description: DataDescription, stage: str) -> Optional[Rois]: +def _get_rois(mask: Optional[_AnnotationReturnTypes], data_description: DataDescription, stage: str) -> Optional[Rois]: if (mask is None) or (stage != "fit") or (not data_description.convert_mask_to_rois): return None assert data_description.training_grid is not None + assert isinstance(mask, WsiAnnotations) # this is necessary for the compute_rois to work tile_size = data_description.training_grid.tile_size tile_overlap = data_description.training_grid.tile_overlap @@ -485,7 +489,7 @@ def get_image_info( raise ValueError(f"Grid (for stage {stage}) is not defined in the data description.") mask, annotations = get_mask_and_annotations_from_record(annotations_root, image) - assert isinstance(mask, 
WsiAnnotations) or (mask is None) + assert isinstance(mask, WsiAnnotations) or (mask is None) or isinstance(mask, SlideImage) mask_threshold = 0.0 if stage != "fit" else data_description.mask_threshold rois = _get_rois(mask, data_description, stage) @@ -495,7 +499,11 @@ def get_image_info( backend = DLUPImageBackend[str(image.reader)] mpp = getattr(grid_description, "mpp", 1.0) overwrite_mpp = float(image.mpp) - tile_mode = TilingMode.overflow + tile_mode = ( + TilingMode(data_description.tiling_mode) + if data_description.tiling_mode is not None + else TilingMode.overflow + ) output_tile_size = getattr(grid_description, "output_tile_size", None) # Update the dictionary with the actual values @@ -535,7 +543,7 @@ def datasets_from_data_description( split_category=stage, ) - for patient in patients: + for patient_idx, patient in enumerate(patients): patient_labels = get_labels_from_record(patient) for image in patient.images: @@ -583,7 +591,7 @@ def datasets_from_data_description( and all(isinstance(i, int) for i in output_tile_size) # pylint: disable=not-an-iterable or (output_tile_size is None) ) - assert isinstance(mask, WsiAnnotations) or (mask is None) + assert isinstance(mask, WsiAnnotations) or (mask is None) or isinstance(mask, SlideImage) assert isinstance(mask_threshold, float) or mask_threshold is None assert isinstance(annotations, WsiAnnotations) or (annotations is None)