This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit ea8a330
Bugfixes to make the writer work
moerlemans committed Sep 24, 2024
1 parent ae580cb commit ea8a330
Showing 4 changed files with 27 additions and 16 deletions.
ahcore/callbacks/file_writer_callback.py (2 changes: 1 addition & 1 deletion)
@@ -155,7 +155,7 @@ def _get_writer_data_args(
         tile_size = (1, 1)
         tile_overlap = (0, 0)
         num_samples = num_samples
-        grid = None  # just let the writer make a new grid
+        grid = current_dataset._grids[0][0]  # give grid, bc doesn't work otherwise

     else:
         raise NotImplementedError(f"Data format {data_format} is not yet supported.")
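The hunk above hands the dataset's existing grid to the writer instead of letting it build a new one. A minimal sketch of that lookup, assuming current_dataset is a dlup tiled dataset; only _grids[0][0] being the grid is confirmed by the diff, and the helper name is made up:

    def _grid_from_dataset(current_dataset):
        # Reuse the grid the dataset was tiled with, so the writer's output layout
        # matches the tiles it receives instead of a freshly derived grid.
        return current_dataset._grids[0][0]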
ahcore/transforms/pre_transforms.py (20 changes: 11 additions & 9 deletions)
@@ -5,7 +5,6 @@

 from __future__ import annotations

-import logging
 from typing import Any, Callable

 import numpy as np
@@ -26,7 +25,9 @@


 class PreTransformTaskFactory:
-    def __init__(self, transforms: list[PreTransformCallable]):
+    def __init__(
+        self, transforms: list[PreTransformCallable], data_description: DataDescription, requires_target: bool
+    ) -> None:
         """
         Pre-transforms are transforms that are applied to the samples directly originating from the dataset.
         These transforms are typically the same for the specific tasks (e.g., segmentation,
@@ -71,7 +72,7 @@ def for_segmentation(
         """
         transforms: list[PreTransformCallable] = []
         if not requires_target:
-            return cls(transforms)
+            return cls(transforms, data_description, requires_target)

         if data_description.index_map is None:
             raise ConfigurationError("`index_map` is required for segmentation models when the target is required.")
@@ -88,7 +89,7 @@ def for_segmentation(
         if not multiclass:
             transforms.append(OneHotEncodeMask(index_map=data_description.index_map))

-        return cls(transforms)
+        return cls(transforms, data_description, requires_target)

     @classmethod
     def for_wsi_classification(
@@ -99,7 +100,7 @@
         transforms.append(SampleNFeatures(n=1000))

         if not requires_target:
-            return cls(transforms)
+            return cls(transforms, data_description, requires_target)

         index_map = data_description.index_map
         if index_map is None:
@@ -112,7 +113,7 @@

         transforms.append(LabelToClassIndex(index_map=index_map))

-        return cls(transforms)
+        return cls(transforms, data_description, requires_target)

     def __call__(self, data: DlupDatasetSample) -> DlupDatasetSample:
         for transform in self._transforms:
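With this change every construction path of PreTransformTaskFactory carries the data description and the target flag. A hedged usage sketch based only on the constructor signature shown above; data_description is assumed to be an already-configured DataDescription and sample a single dataset sample:

    # Direct construction now requires the description and the target flag.
    factory = PreTransformTaskFactory(
        transforms=[],
        data_description=data_description,
        requires_target=False,
    )
    sample = factory(sample)  # __call__ applies the chained pre-transforms to one sample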
@@ -311,9 +312,10 @@ def __call__(self, sample: DlupDatasetSample) -> dict[str, DlupDatasetSample]:

         if "labels" in sample.keys() and sample["labels"] is not None:
             for key, value in sample["labels"].items():
-                sample["labels"][key] = torch.tensor(value, dtype=torch.float32)
-                if sample["labels"][key].dim() == 0:
-                    sample["labels"][key] = sample["labels"][key].unsqueeze(0)
+                if isinstance(value, float) or isinstance(value, int):
+                    sample["labels"][key] = torch.tensor(value, dtype=torch.float32)
+                    if sample["labels"][key].dim() == 0:
+                        sample["labels"][key] = sample["labels"][key].unsqueeze(0)

         # annotation_data is added by the ConvertPolygonToMask transform.
         if "annotation_data" not in sample:
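The hunk above guards the label conversion so that only plain numeric labels are turned into tensors, while other label types pass through untouched. A small standalone sketch of that behavior (the example label values are made up):

    import torch

    labels = {"grade": 2, "center": "hospital_a"}  # hypothetical label dict
    for key, value in labels.items():
        if isinstance(value, (float, int)):
            tensor = torch.tensor(value, dtype=torch.float32)
            if tensor.dim() == 0:
                tensor = tensor.unsqueeze(0)  # scalar -> shape (1,) so labels collate cleanly
            labels[key] = tensor
    # labels["grade"] is now tensor([2.]); labels["center"] is still the string "hospital_a".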
ahcore/utils/data.py (1 change: 1 addition & 0 deletions)
@@ -69,3 +69,4 @@ class DataDescription(BaseModel):
     convert_mask_to_rois: bool = True
     use_roi: bool = True
     apply_color_profile: bool = False
+    tiling_mode: Optional[str] = None
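A minimal standalone sketch of what the new field adds; the real DataDescription has many more fields, and the class below is a hypothetical stand-in. The optional string defaults to None, which manifest.py below maps onto dlup's TilingMode:

    from typing import Optional

    from pydantic import BaseModel

    class _TilingConfigSketch(BaseModel):  # hypothetical stand-in for DataDescription
        tiling_mode: Optional[str] = None

    print(_TilingConfigSketch().tiling_mode)                    # None -> manifest falls back to overflow
    print(_TilingConfigSketch(tiling_mode="skip").tiling_mode)  # "skip"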
ahcore/utils/manifest.py (20 changes: 14 additions & 6 deletions)
@@ -62,7 +62,10 @@ class _AnnotationReadersDict(TypedDict):
     "DARWIN_JSON": WsiAnnotations.from_darwin_json,
     "GEOJSON": WsiAnnotations.from_geojson,
     "PYVIPS": functools.partial(SlideImage.from_file_path, backend=DLUPImageBackend.PYVIPS),
-    "TIFFFILE": functools.partial(SlideImage.from_file_path, backend=DLUPImageBackend.TIFFFILE),
+    "TIFFFILE": functools.partial(
+        SlideImage.from_file_path,
+        backend=DLUPImageBackend.TIFFFILE,
+    ),
     "OPENSLIDE": functools.partial(SlideImage.from_file_path, backend=DLUPImageBackend.OPENSLIDE),
 }

@@ -200,11 +203,12 @@ def get_relevant_feature_info_from_record(
     return image_path, mpp, tile_size, tile_overlap, backend, overwrite_mpp


-def _get_rois(mask: Optional[WsiAnnotations], data_description: DataDescription, stage: str) -> Optional[Rois]:
+def _get_rois(mask: Optional[_AnnotationReturnTypes], data_description: DataDescription, stage: str) -> Optional[Rois]:
     if (mask is None) or (stage != "fit") or (not data_description.convert_mask_to_rois):
         return None

     assert data_description.training_grid is not None
+    assert isinstance(mask, WsiAnnotations)  # this is necessary for the compute_rois to work

     tile_size = data_description.training_grid.tile_size
     tile_overlap = data_description.training_grid.tile_overlap
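The widened parameter type matches what get_mask_and_annotations_from_record can return, and the added assert narrows it back to WsiAnnotations before ROIs are computed. A hedged sketch of that narrowing; the import path and helper name are assumptions:

    from dlup.annotations import WsiAnnotations  # import path assumed

    def _require_wsi_annotations(mask):
        # A SlideImage mask reaching this point raises, and static type checkers
        # treat mask as WsiAnnotations afterwards.
        assert isinstance(mask, WsiAnnotations)  # necessary for compute_rois to work
        return mask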
@@ -485,7 +489,7 @@ def get_image_info(
         raise ValueError(f"Grid (for stage {stage}) is not defined in the data description.")

     mask, annotations = get_mask_and_annotations_from_record(annotations_root, image)
-    assert isinstance(mask, WsiAnnotations) or (mask is None)
+    assert isinstance(mask, WsiAnnotations) or (mask is None) or isinstance(mask, SlideImage)
     mask_threshold = 0.0 if stage != "fit" else data_description.mask_threshold
     rois = _get_rois(mask, data_description, stage)

@@ -495,7 +499,11 @@
     backend = DLUPImageBackend[str(image.reader)]
     mpp = getattr(grid_description, "mpp", 1.0)
     overwrite_mpp = float(image.mpp)
-    tile_mode = TilingMode.overflow
+    tile_mode = (
+        TilingMode(data_description.tiling_mode)
+        if data_description.tiling_mode is not None
+        else TilingMode.overflow
+    )
     output_tile_size = getattr(grid_description, "output_tile_size", None)

     # Update the dictionary with the actual values
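A hedged sketch of the new tile-mode selection: the optional string from the data description is mapped onto dlup's TilingMode enum, with the previous hard-coded overflow kept as the fallback (import path assumed, requested stands in for data_description.tiling_mode):

    from dlup.tiling import TilingMode  # import path assumed

    requested = None  # e.g. data_description.tiling_mode
    tile_mode = TilingMode(requested) if requested is not None else TilingMode.overflow
    print(tile_mode)  # TilingMode.overflow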
@@ -535,7 +543,7 @@ def datasets_from_data_description(
         split_category=stage,
     )

-    for patient in patients:
+    for patient_idx, patient in enumerate(patients):
         patient_labels = get_labels_from_record(patient)

         for image in patient.images:
Expand Down Expand Up @@ -583,7 +591,7 @@ def datasets_from_data_description(
and all(isinstance(i, int) for i in output_tile_size) # pylint: disable=not-an-iterable
or (output_tile_size is None)
)
assert isinstance(mask, WsiAnnotations) or (mask is None)
assert isinstance(mask, WsiAnnotations) or (mask is None) or isinstance(mask, SlideImage)
assert isinstance(mask_threshold, float) or mask_threshold is None
assert isinstance(annotations, WsiAnnotations) or (annotations is None)

