diff --git a/darts-postprocessing/src/darts_postprocessing/prepare_export.py b/darts-postprocessing/src/darts_postprocessing/prepare_export.py index 273323c..41607fd 100644 --- a/darts-postprocessing/src/darts_postprocessing/prepare_export.py +++ b/darts-postprocessing/src/darts_postprocessing/prepare_export.py @@ -18,11 +18,11 @@ CUCIM_AVAILABLE = True DEFAULT_DEVICE = "cuda" - logger.debug("GPU-accelerated xrspatial functions are available.") + logger.debug("GPU-accelerated cucim functions are available.") except ImportError: CUCIM_AVAILABLE = False DEFAULT_DEVICE = "cpu" - logger.debug("GPU-accelerated xrspatial functions are not available.") + logger.debug("GPU-accelerated cucim functions are not available.") def erode_mask(mask: xr.DataArray, size: int, device: Literal["cuda", "cpu"] | int) -> xr.DataArray: diff --git a/darts-segmentation/src/darts_segmentation/training/module.py b/darts-segmentation/src/darts_segmentation/training/module.py index 781c4bf..0022713 100644 --- a/darts-segmentation/src/darts_segmentation/training/module.py +++ b/darts-segmentation/src/darts_segmentation/training/module.py @@ -9,6 +9,7 @@ from pathlib import Path import lightning as L # noqa: N812 +import matplotlib.pyplot as plt import segmentation_models_pytorch as smp import torch.optim as optim import wandb @@ -41,31 +42,39 @@ def __init__(self, config: SMPSegmenterConfig, learning_rate: float = 1e-5, gamm self.save_hyperparameters() self.model = smp.create_model(**config["model"], activation="sigmoid") - self.loss_fn = smp.losses.FocalLoss(mode="binary") + # Assumes that the training preparation was done with setting invalid pixels in the mask to 2 + self.loss_fn = smp.losses.FocalLoss(mode="binary", ignore_index=2) + metric_kwargs = {"task": "binary", "validate_args": False, "ignore_index": 2} metrics = MetricCollection( { - "Accuracy": Accuracy(task="binary", validate_args=False), - "Precision": Precision(task="binary", validate_args=False), - "Specificity": Specificity(task="binary", validate_args=False), - "Recall": Recall(task="binary", validate_args=False), - "F1Score": F1Score(task="binary", validate_args=False), - "JaccardIndex": JaccardIndex(task="binary", validate_args=False), - "CohenKappa": CohenKappa(task="binary", validate_args=False), - "HammingDistance": HammingDistance(task="binary", validate_args=False), + "Accuracy": Accuracy(**metric_kwargs), + "Precision": Precision(**metric_kwargs), + "Specificity": Specificity(**metric_kwargs), + "Recall": Recall(**metric_kwargs), + "F1Score": F1Score(**metric_kwargs), + "JaccardIndex": JaccardIndex(**metric_kwargs), + "CohenKappa": CohenKappa(**metric_kwargs), + "HammingDistance": HammingDistance(**metric_kwargs), } ) self.train_metrics = metrics.clone(prefix="train/") self.val_metrics = metrics.clone(prefix="val/") self.val_metrics.add_metrics( { - "AUROC": AUROC(task="binary", thresholds=20, validate_args=False), - "AveragePrecision": AveragePrecision(task="binary", thresholds=20, validate_args=False), + "AUROC": AUROC(thresholds=20, **metric_kwargs), + "AveragePrecision": AveragePrecision(thresholds=20, **metric_kwargs), } ) - self.val_roc = ROC(task="binary", thresholds=20, validate_args=False) - self.val_prc = PrecisionRecallCurve(task="binary", thresholds=20, validate_args=False) - self.val_cmx = ConfusionMatrix(task="binary", normalize="true", validate_args=False) + self.val_roc = ROC(thresholds=20, **metric_kwargs) + self.val_prc = PrecisionRecallCurve(thresholds=20, **metric_kwargs) + self.val_cmx = ConfusionMatrix(normalize="true", **metric_kwargs) + self.plot_every_n_val_epochs = 5 + + @property + def is_val_plot_epoch(self): + n = self.plot_every_n_val_epochs * self.trainer.check_val_every_n_epoch + return ((self.current_epoch + 1) % n) == 0 or self.current_epoch == 0 def training_step(self, batch, batch_idx): x, y = batch @@ -91,49 +100,54 @@ def validation_step(self, batch, batch_idx): self.val_cmx.update(y_hat, y) # Create figures for the samples - for i in range(x.shape[0]): - fig, _ = plot_sample(x[i], y[i], y_hat[i], self.hparams.config["input_combination"]) - for logger in self.loggers: - if isinstance(logger, CSVLogger): - fig_dir = Path(logger.log_dir) / "figures" - fig_dir.mkdir(exist_ok=True) - fig.savefig(fig_dir / f"sample_{self.global_step}_{batch_idx}_{i}.png") - if isinstance(logger, WandbLogger): - wandb_run: Run = logger.experiment - wandb_run.log({f"val/sample_{batch_idx}_{i}": wandb.Image(fig)}, step=self.global_step) - fig.clear() + if self.is_val_plot_epoch and batch_idx % 2 == 0: + for i in range(x.shape[0]): + fig, _ = plot_sample(x[i], y[i], y_hat[i], self.hparams.config["input_combination"]) + for logger in self.loggers: + if isinstance(logger, CSVLogger): + fig_dir = Path(logger.log_dir) / "figures" + fig_dir.mkdir(exist_ok=True) + fig.savefig(fig_dir / f"sample_{self.global_step}_{batch_idx}_{i}.png") + if isinstance(logger, WandbLogger): + wandb_run: Run = logger.experiment + wandb_run.log({f"val-samples/sample_{batch_idx}_{i}": wandb.Image(fig)}, step=self.global_step) + fig.clear() + plt.close(fig) return loss def on_validation_epoch_end(self): self.log_dict(self.val_metrics.compute()) - self.val_cmx.compute() - self.val_roc.compute() - self.val_prc.compute() - - # Plot roc, prc and confusion matrix to disk and wandb - fig_cmx, _ = self.val_cmx.plot(cmap="Blues") - fig_roc, _ = self.val_roc.plot(score=True) - fig_prc, _ = self.val_prc.plot(score=True) - - # Check for a wandb or csv logger to log the images - for logger in self.loggers: - if isinstance(logger, CSVLogger): - fig_dir = Path(logger.log_dir) / "figures" - fig_dir.mkdir(exist_ok=True) - fig_cmx.savefig(fig_dir / f"cmx_{self.global_step}png") - fig_roc.savefig(fig_dir / f"roc_{self.global_step}png") - fig_prc.savefig(fig_dir / f"prc_{self.global_step}.png") - if isinstance(logger, WandbLogger): - wandb_run: Run = logger.experiment - wandb_run.log({"val/cmx": wandb.Image(fig_cmx)}, step=self.global_step) - wandb_run.log({"val/roc": wandb.Image(fig_roc)}, step=self.global_step) - wandb_run.log({"val/prc": wandb.Image(fig_prc)}, step=self.global_step) - - fig_cmx.clear() - fig_roc.clear() - fig_prc.clear() + # Only do this every self.plot_every_n_val_epochs epochs + if self.is_val_plot_epoch: + self.val_cmx.compute() + self.val_roc.compute() + self.val_prc.compute() + + # Plot roc, prc and confusion matrix to disk and wandb + fig_cmx, _ = self.val_cmx.plot(cmap="Blues") + fig_roc, _ = self.val_roc.plot(score=True) + fig_prc, _ = self.val_prc.plot(score=True) + + # Check for a wandb or csv logger to log the images + for logger in self.loggers: + if isinstance(logger, CSVLogger): + fig_dir = Path(logger.log_dir) / "figures" + fig_dir.mkdir(exist_ok=True) + fig_cmx.savefig(fig_dir / f"cmx_{self.global_step}png") + fig_roc.savefig(fig_dir / f"roc_{self.global_step}png") + fig_prc.savefig(fig_dir / f"prc_{self.global_step}.png") + if isinstance(logger, WandbLogger): + wandb_run: Run = logger.experiment + wandb_run.log({"val/cmx": wandb.Image(fig_cmx)}, step=self.global_step) + wandb_run.log({"val/roc": wandb.Image(fig_roc)}, step=self.global_step) + wandb_run.log({"val/prc": wandb.Image(fig_prc)}, step=self.global_step) + + fig_cmx.clear() + fig_roc.clear() + fig_prc.clear() + plt.close("all") self.val_metrics.reset() self.val_roc.reset() diff --git a/darts-segmentation/src/darts_segmentation/training/prepare_training.py b/darts-segmentation/src/darts_segmentation/training/prepare_training.py index 25ae160..792c2cb 100644 --- a/darts-segmentation/src/darts_segmentation/training/prepare_training.py +++ b/darts-segmentation/src/darts_segmentation/training/prepare_training.py @@ -1,14 +1,21 @@ """Functions to prepare the training data for the segmentation model training.""" +import logging from collections.abc import Generator +from typing import Literal import geopandas as gpd import torch import xarray as xr + +# TODO: move erode_mask to darts_utils, since uasge is not limited to prepare_export +from darts_postprocessing.prepare_export import erode_mask from geocube.api.core import make_geocube from darts_segmentation.utils import create_patches +logger = logging.getLogger(__name__.replace("darts_", "darts.")) + def create_training_patches( tile: xr.Dataset, @@ -17,8 +24,10 @@ def create_training_patches( norm_factors: dict[str, float], patch_size: int, overlap: int, - include_allzero: bool, + include_allzero: bool, # TODO: rename to include_nopositive include_nan_edges: bool, + device: Literal["cuda", "cpu"] | int, + mask_erosion_size: int, ) -> Generator[tuple[torch.tensor, torch.tensor]]: """Create training patches from a tile and labels. @@ -31,6 +40,8 @@ def create_training_patches( overlap (int): The size of the overlap. include_allzero (bool): Whether to include patches where the labels are all zero. include_nan_edges (bool): Whether to include patches where the input data has nan values at the edges. + device (Literal["cuda", "cpu"] | int): The device to use for the erosion. + mask_erosion_size (int): The size of the disk to use for erosion. Yields: Generator[tuple[torch.tensor, torch.tensor]]: A tuple containing the input and the labels as pytorch tensors. @@ -43,8 +54,10 @@ def create_training_patches( # Rasterize the labels labels_rasterized = 1 - make_geocube(labels, measurements=["id"], like=tile).id.isnull() # noqa: PD003 - # Filter out the nodata values - labels_rasterized = xr.where(tile["valid_data_mask"], labels_rasterized, 0) + # Filter out the nodata values (class 2 -> invalid data) + mask = erode_mask(tile["valid_data_mask"], mask_erosion_size, device) + mask = tile["valid_data_mask"] + labels_rasterized = xr.where(mask, labels_rasterized, 2) # Normalize the bands and clip the values for band in bands: @@ -83,13 +96,22 @@ def create_training_patches( x = tensor_patches[i] y = tensor_labels[i] - if not include_allzero and y.sum() == 0: + if not include_allzero and not (y == 1).any(): continue if not include_nan_edges and torch.isnan(x).any(): continue + # Skip where there are less than 10% visible pixel + if ((y != 2).sum() / y.numel()) < 0.1: + continue + + # Skip patches where everything is nan + if torch.isnan(x).all(): + continue + # Convert all nan values to 0 x[torch.isnan(x)] = 0 + logger.debug(f"Yielding patch {i} with\n\t{x=}\n\t{y=}") yield x, y diff --git a/darts-segmentation/src/darts_segmentation/training/viz.py b/darts-segmentation/src/darts_segmentation/training/viz.py index 2266863..3738960 100644 --- a/darts-segmentation/src/darts_segmentation/training/viz.py +++ b/darts-segmentation/src/darts_segmentation/training/viz.py @@ -1,14 +1,29 @@ """Visualization utilities for the training module.""" +import itertools + import matplotlib.colors as mcolors import matplotlib.patches as mpatches import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns import torch +# TODO: +# New Plot: Threshold vs. F1-Score and IoU + def plot_sample(x, y, y_pred, input_combinations: list[str]): """Plot a single sample with the input, the ground truth and the prediction. + This function does a few expections on the input: + - The input is expected to be normalized to 0-1. + - The prediction is expected to be converted from logits to prediction. + - The target is expected to be a int or long tensor with values of: + 0 (negative class) + 1 (positive class) and + 2 (invalid pixels). + Args: x (torch.Tensor): The input tensor [C, H, W] (float). y (torch.Tensor): The ground truth tensor [H, W] (int). @@ -23,24 +38,40 @@ def plot_sample(x, y, y_pred, input_combinations: list[str]): y = y.cpu() y_pred = y_pred.detach().cpu() + # Make y class 2 invalids (replace 2 with nan) + x = x.where(y != 2, torch.nan) + y_pred = y_pred.where(y != 2, torch.nan) + y = y.where(y != 2, torch.nan) + classification_labels = (y_pred > 0.5).int() + y * 2 classification_labels = classification_labels.where(classification_labels != 0, torch.nan) - # Calculate accuracy and iou + # Calculate f1 and iou true_positive = (classification_labels == 3).sum() false_positive = (classification_labels == 1).sum() false_negative = (classification_labels == 2).sum() - acc = true_positive / (true_positive + false_positive + false_negative) + true_negative = (classification_labels == 0).sum() + acc = (true_positive + true_negative) / (true_positive + true_negative + false_positive + false_negative) + f1 = 2 * true_positive / (2 * true_positive + false_positive + false_negative) + iou = true_positive / (true_positive + false_positive + false_negative) cmap = mcolors.ListedColormap(["#cd43b2", "#3e0f2f", "#6cd875"]) - fig, axs = plt.subplot_mosaic([["a", "a", "b", "c"], ["a", "a", "d", "e"]], layout="constrained", figsize=(16, 8)) + fig, axs = plt.subplot_mosaic( + # [["rgb", "rgb", "ndvi", "tcvis", "stats"], ["rgb", "rgb", "pred", "slope", "elev"]], + [["rgb", "rgb", "pred", "tcvis"], ["rgb", "rgb", "ndvi", "slope"], ["none", "stats", "stats", "stats"]], + # layout="constrained", + figsize=(11, 8), + ) + + # Disable none plot + axs["none"].axis("off") # RGB Plot red_band = input_combinations.index("red") green_band = input_combinations.index("green") blue_band = input_combinations.index("blue") rgb = x[[red_band, green_band, blue_band]].transpose(0, 2).transpose(0, 1) - ax_rgb = axs["a"] + ax_rgb = axs["rgb"] ax_rgb.imshow(rgb ** (1 / 1.4)) ax_rgb.imshow(classification_labels, alpha=0.6, cmap=cmap, vmin=1, vmax=3) # Add a legend @@ -52,38 +83,78 @@ def plot_sample(x, y, y_pred, input_combinations: list[str]): ax_rgb.legend(handles=patches, loc="upper left") # disable axis ax_rgb.axis("off") - ax_rgb.set_title(f"Accuracy: {acc:.1%}") - - # NIR Plot - nir_band = input_combinations.index("nir") - nir = x[nir_band] - ax_nir = axs["b"] - ax_nir.imshow(nir, vmin=0, vmax=1) - ax_nir.axis("off") - ax_nir.set_title("NIR") - - # TCVIS Plot - tcb_band = input_combinations.index("tc_brightness") - tcg_band = input_combinations.index("tc_greenness") - tcw_band = input_combinations.index("tc_wetness") - tcvis = x[[tcb_band, tcg_band, tcw_band]].transpose(0, 2).transpose(0, 1) - ax_tcv = axs["c"] - ax_tcv.imshow(tcvis) - ax_tcv.axis("off") - ax_tcv.set_title("TCVIS") + ax_rgb.set_title(f"Acc: {acc:.1%} F1: {f1:.1%} IoU: {iou:.1%}") # NDVI Plot ndvi_band = input_combinations.index("ndvi") ndvi = x[ndvi_band] - ax_ndvi = axs["d"] - ax_ndvi.imshow(ndvi, vmin=0, vmax=1) + ax_ndvi = axs["ndvi"] + ax_ndvi.imshow(ndvi, vmin=0, vmax=1, cmap="RdYlGn") ax_ndvi.axis("off") ax_ndvi.set_title("NDVI") + # TCVIS Plot + ax_tcv = axs["tcvis"] + ax_tcv.axis("off") + try: + tcb_band = input_combinations.index("tc_brightness") + tcg_band = input_combinations.index("tc_greenness") + tcw_band = input_combinations.index("tc_wetness") + tcvis = x[[tcb_band, tcg_band, tcw_band]].transpose(0, 2).transpose(0, 1) + ax_tcv.imshow(tcvis) + ax_tcv.set_title("TCVIS") + except ValueError: + ax_tcv.set_title("No TCVIS values are provided!") + + # Statistics Plot + ax_stat = axs["stats"] + if (y == 1).sum() > 0: + n_bands = x.shape[0] + n_pixel = x.shape[1] * x.shape[2] + x_flat = x.flatten().cpu() + y_flat = y.flatten().repeat(n_bands).cpu() + bands = list(itertools.chain.from_iterable([input_combinations[i]] * n_pixel for i in range(n_bands))) + plot_data = pd.DataFrame({"x": x_flat, "y": y_flat, "band": bands}) + if len(plot_data) > 50000: + plot_data = plot_data.sample(50000) + plot_data = plot_data.sort_values("band") + sns.violinplot( + x="x", + y="band", + hue="y", + data=plot_data, + split=True, + inner="quart", + fill=False, + palette={1: "g", 0: ".35"}, + density_norm="width", + ax=ax_stat, + ) + ax_stat.set_title("Band Statistics") + else: + ax_stat.set_title("No positive labels in this sample!") + ax_stat.axis("off") + # Prediction Plot - ax_mask = axs["e"] + ax_mask = axs["pred"] ax_mask.imshow(y_pred, vmin=0, vmax=1) ax_mask.axis("off") - ax_mask.set_title("Prediction") + ax_mask.set_title("Model Output") + + # Slope Plot + slope_band = input_combinations.index("slope") + slope = x[slope_band] + ax_slope = axs["slope"] + ax_slope.imshow(slope, cmap="cividis") + ax_slope.axis("off") + ax_slope.set_title("Slope") + + # Relative Elevation Plot + # rel_elev_band = input_combinations.index("relative_elevation") + # rel_elev = x[rel_elev_band] + # ax_rel_elev = axs["elev"] + # ax_rel_elev.imshow(rel_elev, cmap="cividis") + # ax_rel_elev.axis("off") + # ax_rel_elev.set_title("Relative Elevation") return fig, axs diff --git a/darts/src/darts/training.py b/darts/src/darts/training.py index 2c9d291..ab61bcb 100644 --- a/darts/src/darts/training.py +++ b/darts/src/darts/training.py @@ -43,6 +43,7 @@ def preprocess_s2_train_data( overlap: int = 16, include_allzero: bool = False, include_nan_edges: bool = True, + mask_erosion_size: int = 10, ): """Preprocess Sentinel 2 data for training. @@ -69,6 +70,7 @@ def preprocess_s2_train_data( include_allzero (bool, optional): Whether to include patches where the labels are all zero. Defaults to False. include_nan_edges (bool, optional): Whether to include patches where the input data has nan values at the edges. Defaults to True. + mask_erosion_size (int, optional): The size of the disk to use for mask erosion and the edge-cropping. """ # Import here to avoid long loading times when running other commands @@ -80,11 +82,13 @@ def preprocess_s2_train_data( from darts_preprocessing import preprocess_legacy_fast from darts_segmentation.training.prepare_training import create_training_patches from dask.distributed import Client, LocalCluster + from lovely_tensors import monkey_patch from odc.stac import configure_rio from darts.utils.cuda import debug_info, decide_device from darts.utils.earthengine import init_ee + monkey_patch() debug_info() device = decide_device(device) init_ee(ee_project, ee_use_highvolume) @@ -152,6 +156,8 @@ def preprocess_s2_train_data( overlap, include_allzero, include_nan_edges, + device, + mask_erosion_size, ) for patch_id, (x, y) in enumerate(gen): torch.save(x, outpath_x / f"{tile_id}_pid{patch_id}.pt") diff --git a/pyproject.toml b/pyproject.toml index a3650af..832165c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "albumentations>=1.4.21", "wandb>=0.18.7", "torchmetrics>=1.6.0", + "seaborn>=0.13.2", ] readme = "README.md" requires-python = ">= 3.11"