diff --git a/darts-segmentation/src/darts_segmentation/segment.py b/darts-segmentation/src/darts_segmentation/segment.py index ba24054..865d750 100644 --- a/darts-segmentation/src/darts_segmentation/segment.py +++ b/darts-segmentation/src/darts_segmentation/segment.py @@ -2,6 +2,7 @@ from pathlib import Path +import numpy as np import segmentation_models_pytorch as smp import torch import torch.nn as nn @@ -58,7 +59,10 @@ def segment_tile(self, tile: xr.Dataset) -> xr.Dataset: Input tile augmented by a predicted `probabilities` layer of type uint8 and a `binarized` layer of type bool. """ + tensor_tile = self.tile2tensor(tile) + + predictions = tile["ndvi"].copy(data=tensor_tile[:1].numpy()) # TODO: Missing implementation - tile["probabilities"] = tile["ndvi"] # Highly sophisticated DL-based predictor - tile["binarized"] = tile["ndvi"] > 0 # Highly sophisticated DL-based predictor + tile["probabilities"] = (predictions / 255).astype(np.uint8) # Highly sophisticated DL-based predictor + tile["binarized"] = tile["ndvi"] > 0.5 # Highly sophisticated DL-based predictor return tile