From 2ad73fe2d64d05734e6ef6adf9c1be535412d9c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sun, 1 Dec 2024 21:19:54 +0100 Subject: [PATCH] Add visualizations to training --- .../src/darts_segmentation/training/data.py | 17 +--- .../src/darts_segmentation/training/module.py | 97 ++++++++++++++++--- .../training/prepare_training.py | 13 +++ .../src/darts_segmentation/training/viz.py | 89 +++++++++++++++++ darts/src/darts/training.py | 95 +++++++++++------- 5 files changed, 252 insertions(+), 59 deletions(-) create mode 100644 darts-segmentation/src/darts_segmentation/training/viz.py diff --git a/darts-segmentation/src/darts_segmentation/training/data.py b/darts-segmentation/src/darts_segmentation/training/data.py index 4e779eb..983bea5 100644 --- a/darts-segmentation/src/darts_segmentation/training/data.py +++ b/darts-segmentation/src/darts_segmentation/training/data.py @@ -10,7 +10,6 @@ import albumentations as A # noqa: N812 import lightning as L # noqa: N812 -import numpy as np import torch from torch.utils.data import DataLoader, Dataset, random_split @@ -18,8 +17,7 @@ class DartsDataset(Dataset): - def __init__(self, data_dir: Path, augment: bool, norm_factors: list[float]): - self.norm_factors = np.array(norm_factors, dtype=np.float32).reshape(-1, 1, 1) + def __init__(self, data_dir: Path, augment: bool): self.x_files = sorted((data_dir / "x").glob("*.pt")) self.y_files = sorted((data_dir / "y").glob("*.pt")) @@ -33,7 +31,7 @@ def __init__(self, data_dir: Path, augment: bool, norm_factors: list[float]): A.HorizontalFlip(), A.VerticalFlip(), A.RandomRotate90(), - A.Blur(), + # A.Blur(), A.RandomBrightnessContrast(), A.MultiplicativeNoise(per_channel=True, elementwise=True), # ToTensorV2(), @@ -54,10 +52,6 @@ def __getitem__(self, idx): x = torch.load(xfile).numpy() y = torch.load(yfile).int().numpy() - # Normalize the bands and clip the values - x = x * self.norm_factors - x = np.clip(x, 0, 1) - # Apply augmentations if self.transform is not None: augmented = self.transform(image=x.transpose(1, 2, 0), mask=y) @@ -68,19 +62,16 @@ def __getitem__(self, idx): class DartsDataModule(L.LightningDataModule): - def __init__( - self, data_dir: Path, norm_factors: list[float], batch_size: int, augment: bool = True, num_workers: int = 0 - ): + def __init__(self, data_dir: Path, batch_size: int, augment: bool = True, num_workers: int = 0): super().__init__() self.save_hyperparameters() self.data_dir = data_dir self.batch_size = batch_size self.augment = augment - self.norm_factors = norm_factors self.num_workers = num_workers def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = None): - dataset = DartsDataset(self.data_dir, self.augment, self.norm_factors) + dataset = DartsDataset(self.data_dir, self.augment) splits = [0.8, 0.1, 0.1] generator = torch.Generator().manual_seed(42) self.train, self.val, self.test = random_split(dataset, splits, generator) diff --git a/darts-segmentation/src/darts_segmentation/training/module.py b/darts-segmentation/src/darts_segmentation/training/module.py index 7eaa9fb..781c4bf 100644 --- a/darts-segmentation/src/darts_segmentation/training/module.py +++ b/darts-segmentation/src/darts_segmentation/training/module.py @@ -6,30 +6,42 @@ """Training script for DARTS segmentation.""" +from pathlib import Path + import lightning as L # noqa: N812 import segmentation_models_pytorch as smp -import torch.nn as nn import torch.optim as optim +import wandb +from lightning.pytorch.loggers import CSVLogger, WandbLogger from torchmetrics import ( + AUROC, + ROC, Accuracy, + AveragePrecision, CohenKappa, + ConfusionMatrix, F1Score, HammingDistance, JaccardIndex, MetricCollection, Precision, + PrecisionRecallCurve, Recall, Specificity, ) +from wandb.sdk.wandb_run import Run from darts_segmentation.segment import SMPSegmenterConfig +from darts_segmentation.training.viz import plot_sample class SMPSegmenter(L.LightningModule): - def __init__(self, config: SMPSegmenterConfig): + def __init__(self, config: SMPSegmenterConfig, learning_rate: float = 1e-5, gamma: float = 0.9): super().__init__() self.save_hyperparameters() - self.model = smp.create_model(**config["model"]) + self.model = smp.create_model(**config["model"], activation="sigmoid") + + self.loss_fn = smp.losses.FocalLoss(mode="binary") metrics = MetricCollection( { @@ -43,15 +55,24 @@ def __init__(self, config: SMPSegmenterConfig): "HammingDistance": HammingDistance(task="binary", validate_args=False), } ) - self.train_metrics = metrics.clone(prefix="train_") - self.val_metrics = metrics.clone(prefix="val_") + self.train_metrics = metrics.clone(prefix="train/") + self.val_metrics = metrics.clone(prefix="val/") + self.val_metrics.add_metrics( + { + "AUROC": AUROC(task="binary", thresholds=20, validate_args=False), + "AveragePrecision": AveragePrecision(task="binary", thresholds=20, validate_args=False), + } + ) + self.val_roc = ROC(task="binary", thresholds=20, validate_args=False) + self.val_prc = PrecisionRecallCurve(task="binary", thresholds=20, validate_args=False) + self.val_cmx = ConfusionMatrix(task="binary", normalize="true", validate_args=False) def training_step(self, batch, batch_idx): x, y = batch - y_hat = self.model(x) - loss = nn.functional.mse_loss(y_hat, y.unsqueeze(1).float()) + y_hat = self.model(x).squeeze(1) + loss = self.loss_fn(y_hat, y.long()) self.train_metrics(y_hat, y) - self.log("train_loss", loss) + self.log("train/loss", loss) self.log_dict(self.train_metrics, on_step=True, on_epoch=False) return loss @@ -60,16 +81,66 @@ def on_train_epoch_end(self): def validation_step(self, batch, batch_idx): x, y = batch - y_hat = self.model(x) - loss = nn.functional.mse_loss(y_hat, y.unsqueeze(1).float()) - self.log("val_loss", loss) + y_hat = self.model(x).squeeze(1) + loss = self.loss_fn(y_hat, y.long()) + self.log("val/loss", loss) + self.val_metrics.update(y_hat, y) + self.val_roc.update(y_hat, y) + self.val_prc.update(y_hat, y) + self.val_cmx.update(y_hat, y) + + # Create figures for the samples + for i in range(x.shape[0]): + fig, _ = plot_sample(x[i], y[i], y_hat[i], self.hparams.config["input_combination"]) + for logger in self.loggers: + if isinstance(logger, CSVLogger): + fig_dir = Path(logger.log_dir) / "figures" + fig_dir.mkdir(exist_ok=True) + fig.savefig(fig_dir / f"sample_{self.global_step}_{batch_idx}_{i}.png") + if isinstance(logger, WandbLogger): + wandb_run: Run = logger.experiment + wandb_run.log({f"val/sample_{batch_idx}_{i}": wandb.Image(fig)}, step=self.global_step) + fig.clear() + return loss def on_validation_epoch_end(self): self.log_dict(self.val_metrics.compute()) + + self.val_cmx.compute() + self.val_roc.compute() + self.val_prc.compute() + + # Plot roc, prc and confusion matrix to disk and wandb + fig_cmx, _ = self.val_cmx.plot(cmap="Blues") + fig_roc, _ = self.val_roc.plot(score=True) + fig_prc, _ = self.val_prc.plot(score=True) + + # Check for a wandb or csv logger to log the images + for logger in self.loggers: + if isinstance(logger, CSVLogger): + fig_dir = Path(logger.log_dir) / "figures" + fig_dir.mkdir(exist_ok=True) + fig_cmx.savefig(fig_dir / f"cmx_{self.global_step}png") + fig_roc.savefig(fig_dir / f"roc_{self.global_step}png") + fig_prc.savefig(fig_dir / f"prc_{self.global_step}.png") + if isinstance(logger, WandbLogger): + wandb_run: Run = logger.experiment + wandb_run.log({"val/cmx": wandb.Image(fig_cmx)}, step=self.global_step) + wandb_run.log({"val/roc": wandb.Image(fig_roc)}, step=self.global_step) + wandb_run.log({"val/prc": wandb.Image(fig_prc)}, step=self.global_step) + + fig_cmx.clear() + fig_roc.clear() + fig_prc.clear() + self.val_metrics.reset() + self.val_roc.reset() + self.val_prc.reset() + self.val_cmx.reset() def configure_optimizers(self): - optimizer = optim.AdamW(self.parameters(), lr=1e-3) - return optimizer + optimizer = optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) + scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.hparams.gamma) + return [optimizer], [scheduler] diff --git a/darts-segmentation/src/darts_segmentation/training/prepare_training.py b/darts-segmentation/src/darts_segmentation/training/prepare_training.py index fd1e925..25ae160 100644 --- a/darts-segmentation/src/darts_segmentation/training/prepare_training.py +++ b/darts-segmentation/src/darts_segmentation/training/prepare_training.py @@ -14,6 +14,7 @@ def create_training_patches( tile: xr.Dataset, labels: gpd.GeoDataFrame, bands: list[str], + norm_factors: dict[str, float], patch_size: int, overlap: int, include_allzero: bool, @@ -25,6 +26,7 @@ def create_training_patches( tile (xr.Dataset): The input tile, containing preprocessed, harmonized data. labels (gpd.GeoDataFrame): The labels to be used for training. bands (list[str]): The bands to be used for training. Must be present in the tile. + norm_factors (dict[str, float]): The normalization factors for the bands. patch_size (int): The size of the patches. overlap (int): The size of the overlap. include_allzero (bool): Whether to include patches where the labels are all zero. @@ -34,6 +36,9 @@ def create_training_patches( Generator[tuple[torch.tensor, torch.tensor]]: A tuple containing the input and the labels as pytorch tensors. The input has the format (C, H, W), the labels (H, W). + Raises: + ValueError: If a band is not found in the preprocessed data. + """ # Rasterize the labels labels_rasterized = 1 - make_geocube(labels, measurements=["id"], like=tile).id.isnull() # noqa: PD003 @@ -41,6 +46,14 @@ def create_training_patches( # Filter out the nodata values labels_rasterized = xr.where(tile["valid_data_mask"], labels_rasterized, 0) + # Normalize the bands and clip the values + for band in bands: + if band not in tile: + raise ValueError(f"Band '{band}' not found in the preprocessed data.") + with xr.set_options(keep_attrs=True): + tile[band] = tile[band] * norm_factors[band] + tile[band] = tile[band].clip(0, 1) + # Replace invalid values with nan (used for nan check later on) tile = xr.where(tile["valid_data_mask"], tile, float("nan")) diff --git a/darts-segmentation/src/darts_segmentation/training/viz.py b/darts-segmentation/src/darts_segmentation/training/viz.py new file mode 100644 index 0000000..2266863 --- /dev/null +++ b/darts-segmentation/src/darts_segmentation/training/viz.py @@ -0,0 +1,89 @@ +"""Visualization utilities for the training module.""" + +import matplotlib.colors as mcolors +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +import torch + + +def plot_sample(x, y, y_pred, input_combinations: list[str]): + """Plot a single sample with the input, the ground truth and the prediction. + + Args: + x (torch.Tensor): The input tensor [C, H, W] (float). + y (torch.Tensor): The ground truth tensor [H, W] (int). + y_pred (torch.Tensor): The prediction tensor [H, W] (float). + input_combinations (list[str]): The combinations of the input bands. + + Returns: + tuple[Figure, dict[str, Axes]]: The figure and the axes of the plot. + + """ + x = x.cpu() + y = y.cpu() + y_pred = y_pred.detach().cpu() + + classification_labels = (y_pred > 0.5).int() + y * 2 + classification_labels = classification_labels.where(classification_labels != 0, torch.nan) + + # Calculate accuracy and iou + true_positive = (classification_labels == 3).sum() + false_positive = (classification_labels == 1).sum() + false_negative = (classification_labels == 2).sum() + acc = true_positive / (true_positive + false_positive + false_negative) + + cmap = mcolors.ListedColormap(["#cd43b2", "#3e0f2f", "#6cd875"]) + fig, axs = plt.subplot_mosaic([["a", "a", "b", "c"], ["a", "a", "d", "e"]], layout="constrained", figsize=(16, 8)) + + # RGB Plot + red_band = input_combinations.index("red") + green_band = input_combinations.index("green") + blue_band = input_combinations.index("blue") + rgb = x[[red_band, green_band, blue_band]].transpose(0, 2).transpose(0, 1) + ax_rgb = axs["a"] + ax_rgb.imshow(rgb ** (1 / 1.4)) + ax_rgb.imshow(classification_labels, alpha=0.6, cmap=cmap, vmin=1, vmax=3) + # Add a legend + patches = [ + mpatches.Patch(color="#6cd875", label="True Positive"), + mpatches.Patch(color="#3e0f2f", label="False Negative"), + mpatches.Patch(color="#cd43b2", label="False Positive"), + ] + ax_rgb.legend(handles=patches, loc="upper left") + # disable axis + ax_rgb.axis("off") + ax_rgb.set_title(f"Accuracy: {acc:.1%}") + + # NIR Plot + nir_band = input_combinations.index("nir") + nir = x[nir_band] + ax_nir = axs["b"] + ax_nir.imshow(nir, vmin=0, vmax=1) + ax_nir.axis("off") + ax_nir.set_title("NIR") + + # TCVIS Plot + tcb_band = input_combinations.index("tc_brightness") + tcg_band = input_combinations.index("tc_greenness") + tcw_band = input_combinations.index("tc_wetness") + tcvis = x[[tcb_band, tcg_band, tcw_band]].transpose(0, 2).transpose(0, 1) + ax_tcv = axs["c"] + ax_tcv.imshow(tcvis) + ax_tcv.axis("off") + ax_tcv.set_title("TCVIS") + + # NDVI Plot + ndvi_band = input_combinations.index("ndvi") + ndvi = x[ndvi_band] + ax_ndvi = axs["d"] + ax_ndvi.imshow(ndvi, vmin=0, vmax=1) + ax_ndvi.axis("off") + ax_ndvi.set_title("NDVI") + + # Prediction Plot + ax_mask = axs["e"] + ax_mask.imshow(y_pred, vmin=0, vmax=1) + ax_mask.axis("off") + ax_mask.set_title("Prediction") + + return fig, axs diff --git a/darts/src/darts/training.py b/darts/src/darts/training.py index 6b63762..47846e8 100644 --- a/darts/src/darts/training.py +++ b/darts/src/darts/training.py @@ -11,6 +11,21 @@ logger = logging.getLogger(__name__) +# We hardcode these because they depend on the preprocessing used +NORM_FACTORS = { + "red": 1 / 3000, + "green": 1 / 3000, + "blue": 1 / 3000, + "nir": 1 / 3000, + "ndvi": 1 / 20000, + "relative_elevation": 1 / 30000, + "slope": 1 / 90, + "tc_brightness": 1 / 255, + "tc_greenness": 1 / 255, + "tc_wetness": 1 / 255, +} + + def preprocess_s2_train_data( *, sentinel2_dir: Path, @@ -86,6 +101,22 @@ def preprocess_s2_train_data( outpath_x.mkdir(exist_ok=True, parents=True) outpath_y.mkdir(exist_ok=True, parents=True) + # We hardcode these because they depend on the preprocessing used + norm_factors = { + "red": 1 / 3000, + "green": 1 / 3000, + "blue": 1 / 3000, + "nir": 1 / 3000, + "ndvi": 1 / 20000, + "relative_elevation": 1 / 30000, + "slope": 1 / 90, + "tc_brightness": 1 / 255, + "tc_greenness": 1 / 255, + "tc_wetness": 1 / 255, + } + # Filter out bands that are not in the specified bands + norm_factors = {k: v for k, v in norm_factors.items() if k in bands} + # Find all Sentinel 2 scenes n_patches = 0 for fpath in sentinel2_dir.glob("*/"): @@ -111,7 +142,16 @@ def preprocess_s2_train_data( # Save the patches tile_id = optical.attrs["tile_id"] - gen = create_training_patches(tile, labels, bands, patch_size, overlap, include_allzero, include_nan_edges) + gen = create_training_patches( + tile, + labels, + bands, + norm_factors, + patch_size, + overlap, + include_allzero, + include_nan_edges, + ) for patch_id, (x, y) in enumerate(gen): torch.save(x, outpath_x / f"{tile_id}_pid{patch_id}.pt") torch.save(y, outpath_y / f"{tile_id}_pid{patch_id}.pt") @@ -134,6 +174,7 @@ def preprocess_s2_train_data( "arcticdem_dir": arcticdem_dir, "tcvis_dir": tcvis_dir, "bands": bands, + "norm_factors": norm_factors, "device": device, "ee_project": ee_project, "ee_use_highvolume": ee_use_highvolume, @@ -156,8 +197,12 @@ def preprocess_s2_train_data( def train_smp( + *, train_data_dir: Path, artifact_dir: Path = Path("lightning_logs"), + model_arch: str = "Unet", + model_encoder: str = "dpn107", + model_encoder_weights: str | None = None, augment: bool = True, batch_size: int = 8, num_workers: int = 0, @@ -167,10 +212,15 @@ def train_smp( ): """Run the training of the SMP model. + Please see https://smp.readthedocs.io/en/latest/index.html for model configurations. + Args: train_data_dir (Path): Path to the training data directory. artifact_dir (Path, optional): Path to the training output directory. Will contain checkpoints and metrics. Defaults to Path("lightning_logs"). + model_arch (str, optional): Model architecture to use. Defaults to "Unet". + model_encoder (str, optional): Encoder to use. Defaults to "dpn107". + model_encoder_weights (str | None, optional): Path to the encoder weights. Defaults to None. augment (bool, optional): Weather to apply augments or not. Defaults to True. batch_size (int, optional): Batch Size. Defaults to 8. num_workers (int, optional): Number of Dataloader workers. Defaults to 0. @@ -192,43 +242,21 @@ def train_smp( torch.set_float32_matmul_precision("medium") + preprocess_config = toml.load(train_data_dir / "config.toml")["darts"] + config = SMPSegmenterConfig( - input_combination=[ - "blue", - "green", - "red", - "nir", - "ndvi", - "tc_brightness", - "tc_greenness", - "tc_wetness", - "relative_elevation", - "slope", - ], + input_combination=preprocess_config["bands"], model={ - "arch": "Unet", - "encoder_name": "dpn107", - "encoder_weights": "imagenet+5k", - "in_channels": 10, + "arch": model_arch, + "encoder_name": model_encoder, + "encoder_weights": model_encoder_weights, + "in_channels": len(preprocess_config["bands"]), "classes": 1, }, - norm_factors={ - "red": 1 / 3000, - "green": 1 / 3000, - "blue": 1 / 3000, - "nir": 1 / 3000, - "ndvi": 1 / 20000, - "relative_elevation": 1 / 30000, - "slope": 1 / 90, - "tc_brightness": 1 / 255, - "tc_greenness": 1 / 255, - "tc_wetness": 1 / 255, - }, + norm_factors=preprocess_config["norm_factors"], ) - norm_factors_ordered = [config["norm_factors"][band] for band in config["input_combination"]] - - datamodule = DartsDataModule(train_data_dir, norm_factors_ordered, batch_size, augment, num_workers) + datamodule = DartsDataModule(train_data_dir, batch_size, augment, num_workers) model = SMPSegmenter(config) @@ -238,7 +266,7 @@ def train_smp( if wandb_entity and wandb_project: wandb_logger = WandbLogger(save_dir=artifact_dir, name=run_name, project=wandb_project, entity=wandb_entity) trainer_loggers.append(wandb_logger) - early_stopping = EarlyStopping(monitor="val_loss") + early_stopping = EarlyStopping(monitor="val/JaccardIndex", mode="max", patience=5) callbacks = [ early_stopping, RichProgressBar(), @@ -249,5 +277,6 @@ def train_smp( callbacks=callbacks, log_every_n_steps=10, logger=trainer_loggers, + check_val_every_n_epoch=3, ) trainer.fit(model, datamodule)