Skip to content

Commit

Permalink
Add visualizations to training
Browse files Browse the repository at this point in the history
  • Loading branch information
relativityhd committed Dec 1, 2024
1 parent ddcf79d commit 2ad73fe
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 59 deletions.
17 changes: 4 additions & 13 deletions darts-segmentation/src/darts_segmentation/training/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@

import albumentations as A # noqa: N812
import lightning as L # noqa: N812
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, random_split

logger = logging.getLogger(__name__.replace("darts_", "darts."))


class DartsDataset(Dataset):
def __init__(self, data_dir: Path, augment: bool, norm_factors: list[float]):
self.norm_factors = np.array(norm_factors, dtype=np.float32).reshape(-1, 1, 1)
def __init__(self, data_dir: Path, augment: bool):
self.x_files = sorted((data_dir / "x").glob("*.pt"))
self.y_files = sorted((data_dir / "y").glob("*.pt"))

Expand All @@ -33,7 +31,7 @@ def __init__(self, data_dir: Path, augment: bool, norm_factors: list[float]):
A.HorizontalFlip(),
A.VerticalFlip(),
A.RandomRotate90(),
A.Blur(),
# A.Blur(),
A.RandomBrightnessContrast(),
A.MultiplicativeNoise(per_channel=True, elementwise=True),
# ToTensorV2(),
Expand All @@ -54,10 +52,6 @@ def __getitem__(self, idx):
x = torch.load(xfile).numpy()
y = torch.load(yfile).int().numpy()

# Normalize the bands and clip the values
x = x * self.norm_factors
x = np.clip(x, 0, 1)

# Apply augmentations
if self.transform is not None:
augmented = self.transform(image=x.transpose(1, 2, 0), mask=y)
Expand All @@ -68,19 +62,16 @@ def __getitem__(self, idx):


class DartsDataModule(L.LightningDataModule):
def __init__(
self, data_dir: Path, norm_factors: list[float], batch_size: int, augment: bool = True, num_workers: int = 0
):
def __init__(self, data_dir: Path, batch_size: int, augment: bool = True, num_workers: int = 0):
super().__init__()
self.save_hyperparameters()
self.data_dir = data_dir
self.batch_size = batch_size
self.augment = augment
self.norm_factors = norm_factors
self.num_workers = num_workers

def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = None):
dataset = DartsDataset(self.data_dir, self.augment, self.norm_factors)
dataset = DartsDataset(self.data_dir, self.augment)
splits = [0.8, 0.1, 0.1]
generator = torch.Generator().manual_seed(42)
self.train, self.val, self.test = random_split(dataset, splits, generator)
Expand Down
97 changes: 84 additions & 13 deletions darts-segmentation/src/darts_segmentation/training/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,42 @@

"""Training script for DARTS segmentation."""

from pathlib import Path

import lightning as L # noqa: N812
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.optim as optim
import wandb
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from torchmetrics import (
AUROC,
ROC,
Accuracy,
AveragePrecision,
CohenKappa,
ConfusionMatrix,
F1Score,
HammingDistance,
JaccardIndex,
MetricCollection,
Precision,
PrecisionRecallCurve,
Recall,
Specificity,
)
from wandb.sdk.wandb_run import Run

from darts_segmentation.segment import SMPSegmenterConfig
from darts_segmentation.training.viz import plot_sample


class SMPSegmenter(L.LightningModule):
def __init__(self, config: SMPSegmenterConfig):
def __init__(self, config: SMPSegmenterConfig, learning_rate: float = 1e-5, gamma: float = 0.9):
super().__init__()
self.save_hyperparameters()
self.model = smp.create_model(**config["model"])
self.model = smp.create_model(**config["model"], activation="sigmoid")

self.loss_fn = smp.losses.FocalLoss(mode="binary")

metrics = MetricCollection(
{
Expand All @@ -43,15 +55,24 @@ def __init__(self, config: SMPSegmenterConfig):
"HammingDistance": HammingDistance(task="binary", validate_args=False),
}
)
self.train_metrics = metrics.clone(prefix="train_")
self.val_metrics = metrics.clone(prefix="val_")
self.train_metrics = metrics.clone(prefix="train/")
self.val_metrics = metrics.clone(prefix="val/")
self.val_metrics.add_metrics(
{
"AUROC": AUROC(task="binary", thresholds=20, validate_args=False),
"AveragePrecision": AveragePrecision(task="binary", thresholds=20, validate_args=False),
}
)
self.val_roc = ROC(task="binary", thresholds=20, validate_args=False)
self.val_prc = PrecisionRecallCurve(task="binary", thresholds=20, validate_args=False)
self.val_cmx = ConfusionMatrix(task="binary", normalize="true", validate_args=False)

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = nn.functional.mse_loss(y_hat, y.unsqueeze(1).float())
y_hat = self.model(x).squeeze(1)
loss = self.loss_fn(y_hat, y.long())
self.train_metrics(y_hat, y)
self.log("train_loss", loss)
self.log("train/loss", loss)
self.log_dict(self.train_metrics, on_step=True, on_epoch=False)
return loss

Expand All @@ -60,16 +81,66 @@ def on_train_epoch_end(self):

def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = nn.functional.mse_loss(y_hat, y.unsqueeze(1).float())
self.log("val_loss", loss)
y_hat = self.model(x).squeeze(1)
loss = self.loss_fn(y_hat, y.long())
self.log("val/loss", loss)

self.val_metrics.update(y_hat, y)
self.val_roc.update(y_hat, y)
self.val_prc.update(y_hat, y)
self.val_cmx.update(y_hat, y)

# Create figures for the samples
for i in range(x.shape[0]):
fig, _ = plot_sample(x[i], y[i], y_hat[i], self.hparams.config["input_combination"])
for logger in self.loggers:
if isinstance(logger, CSVLogger):
fig_dir = Path(logger.log_dir) / "figures"
fig_dir.mkdir(exist_ok=True)
fig.savefig(fig_dir / f"sample_{self.global_step}_{batch_idx}_{i}.png")
if isinstance(logger, WandbLogger):
wandb_run: Run = logger.experiment
wandb_run.log({f"val/sample_{batch_idx}_{i}": wandb.Image(fig)}, step=self.global_step)
fig.clear()

return loss

def on_validation_epoch_end(self):
self.log_dict(self.val_metrics.compute())

self.val_cmx.compute()
self.val_roc.compute()
self.val_prc.compute()

# Plot roc, prc and confusion matrix to disk and wandb
fig_cmx, _ = self.val_cmx.plot(cmap="Blues")
fig_roc, _ = self.val_roc.plot(score=True)
fig_prc, _ = self.val_prc.plot(score=True)

# Check for a wandb or csv logger to log the images
for logger in self.loggers:
if isinstance(logger, CSVLogger):
fig_dir = Path(logger.log_dir) / "figures"
fig_dir.mkdir(exist_ok=True)
fig_cmx.savefig(fig_dir / f"cmx_{self.global_step}png")
fig_roc.savefig(fig_dir / f"roc_{self.global_step}png")
fig_prc.savefig(fig_dir / f"prc_{self.global_step}.png")
if isinstance(logger, WandbLogger):
wandb_run: Run = logger.experiment
wandb_run.log({"val/cmx": wandb.Image(fig_cmx)}, step=self.global_step)
wandb_run.log({"val/roc": wandb.Image(fig_roc)}, step=self.global_step)
wandb_run.log({"val/prc": wandb.Image(fig_prc)}, step=self.global_step)

fig_cmx.clear()
fig_roc.clear()
fig_prc.clear()

self.val_metrics.reset()
self.val_roc.reset()
self.val_prc.reset()
self.val_cmx.reset()

def configure_optimizers(self):
optimizer = optim.AdamW(self.parameters(), lr=1e-3)
return optimizer
optimizer = optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.hparams.gamma)
return [optimizer], [scheduler]
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def create_training_patches(
tile: xr.Dataset,
labels: gpd.GeoDataFrame,
bands: list[str],
norm_factors: dict[str, float],
patch_size: int,
overlap: int,
include_allzero: bool,
Expand All @@ -25,6 +26,7 @@ def create_training_patches(
tile (xr.Dataset): The input tile, containing preprocessed, harmonized data.
labels (gpd.GeoDataFrame): The labels to be used for training.
bands (list[str]): The bands to be used for training. Must be present in the tile.
norm_factors (dict[str, float]): The normalization factors for the bands.
patch_size (int): The size of the patches.
overlap (int): The size of the overlap.
include_allzero (bool): Whether to include patches where the labels are all zero.
Expand All @@ -34,13 +36,24 @@ def create_training_patches(
Generator[tuple[torch.tensor, torch.tensor]]: A tuple containing the input and the labels as pytorch tensors.
The input has the format (C, H, W), the labels (H, W).
Raises:
ValueError: If a band is not found in the preprocessed data.
"""
# Rasterize the labels
labels_rasterized = 1 - make_geocube(labels, measurements=["id"], like=tile).id.isnull() # noqa: PD003

# Filter out the nodata values
labels_rasterized = xr.where(tile["valid_data_mask"], labels_rasterized, 0)

# Normalize the bands and clip the values
for band in bands:
if band not in tile:
raise ValueError(f"Band '{band}' not found in the preprocessed data.")
with xr.set_options(keep_attrs=True):
tile[band] = tile[band] * norm_factors[band]
tile[band] = tile[band].clip(0, 1)

# Replace invalid values with nan (used for nan check later on)
tile = xr.where(tile["valid_data_mask"], tile, float("nan"))

Expand Down
89 changes: 89 additions & 0 deletions darts-segmentation/src/darts_segmentation/training/viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Visualization utilities for the training module."""

import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import torch


def plot_sample(x, y, y_pred, input_combinations: list[str]):
"""Plot a single sample with the input, the ground truth and the prediction.
Args:
x (torch.Tensor): The input tensor [C, H, W] (float).
y (torch.Tensor): The ground truth tensor [H, W] (int).
y_pred (torch.Tensor): The prediction tensor [H, W] (float).
input_combinations (list[str]): The combinations of the input bands.
Returns:
tuple[Figure, dict[str, Axes]]: The figure and the axes of the plot.
"""
x = x.cpu()
y = y.cpu()
y_pred = y_pred.detach().cpu()

classification_labels = (y_pred > 0.5).int() + y * 2
classification_labels = classification_labels.where(classification_labels != 0, torch.nan)

# Calculate accuracy and iou
true_positive = (classification_labels == 3).sum()
false_positive = (classification_labels == 1).sum()
false_negative = (classification_labels == 2).sum()
acc = true_positive / (true_positive + false_positive + false_negative)

cmap = mcolors.ListedColormap(["#cd43b2", "#3e0f2f", "#6cd875"])
fig, axs = plt.subplot_mosaic([["a", "a", "b", "c"], ["a", "a", "d", "e"]], layout="constrained", figsize=(16, 8))

# RGB Plot
red_band = input_combinations.index("red")
green_band = input_combinations.index("green")
blue_band = input_combinations.index("blue")
rgb = x[[red_band, green_band, blue_band]].transpose(0, 2).transpose(0, 1)
ax_rgb = axs["a"]
ax_rgb.imshow(rgb ** (1 / 1.4))
ax_rgb.imshow(classification_labels, alpha=0.6, cmap=cmap, vmin=1, vmax=3)
# Add a legend
patches = [
mpatches.Patch(color="#6cd875", label="True Positive"),
mpatches.Patch(color="#3e0f2f", label="False Negative"),
mpatches.Patch(color="#cd43b2", label="False Positive"),
]
ax_rgb.legend(handles=patches, loc="upper left")
# disable axis
ax_rgb.axis("off")
ax_rgb.set_title(f"Accuracy: {acc:.1%}")

# NIR Plot
nir_band = input_combinations.index("nir")
nir = x[nir_band]
ax_nir = axs["b"]
ax_nir.imshow(nir, vmin=0, vmax=1)
ax_nir.axis("off")
ax_nir.set_title("NIR")

# TCVIS Plot
tcb_band = input_combinations.index("tc_brightness")
tcg_band = input_combinations.index("tc_greenness")
tcw_band = input_combinations.index("tc_wetness")
tcvis = x[[tcb_band, tcg_band, tcw_band]].transpose(0, 2).transpose(0, 1)
ax_tcv = axs["c"]
ax_tcv.imshow(tcvis)
ax_tcv.axis("off")
ax_tcv.set_title("TCVIS")

# NDVI Plot
ndvi_band = input_combinations.index("ndvi")
ndvi = x[ndvi_band]
ax_ndvi = axs["d"]
ax_ndvi.imshow(ndvi, vmin=0, vmax=1)
ax_ndvi.axis("off")
ax_ndvi.set_title("NDVI")

# Prediction Plot
ax_mask = axs["e"]
ax_mask.imshow(y_pred, vmin=0, vmax=1)
ax_mask.axis("off")
ax_mask.set_title("Prediction")

return fig, axs
Loading

0 comments on commit 2ad73fe

Please sign in to comment.