Skip to content

Commit

Permalink
Further refine training and visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
relativityhd committed Dec 3, 2024
1 parent c93a07f commit f56ccc8
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@

CUCIM_AVAILABLE = True
DEFAULT_DEVICE = "cuda"
logger.debug("GPU-accelerated xrspatial functions are available.")
logger.debug("GPU-accelerated cucim functions are available.")
except ImportError:
CUCIM_AVAILABLE = False
DEFAULT_DEVICE = "cpu"
logger.debug("GPU-accelerated xrspatial functions are not available.")
logger.debug("GPU-accelerated cucim functions are not available.")


def erode_mask(mask: xr.DataArray, size: int, device: Literal["cuda", "cpu"] | int) -> xr.DataArray:
Expand Down
116 changes: 65 additions & 51 deletions darts-segmentation/src/darts_segmentation/training/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pathlib import Path

import lightning as L # noqa: N812
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import torch.optim as optim
import wandb
Expand Down Expand Up @@ -41,31 +42,39 @@ def __init__(self, config: SMPSegmenterConfig, learning_rate: float = 1e-5, gamm
self.save_hyperparameters()
self.model = smp.create_model(**config["model"], activation="sigmoid")

self.loss_fn = smp.losses.FocalLoss(mode="binary")
# Assumes that the training preparation was done with setting invalid pixels in the mask to 2
self.loss_fn = smp.losses.FocalLoss(mode="binary", ignore_index=2)

metric_kwargs = {"task": "binary", "validate_args": False, "ignore_index": 2}
metrics = MetricCollection(
{
"Accuracy": Accuracy(task="binary", validate_args=False),
"Precision": Precision(task="binary", validate_args=False),
"Specificity": Specificity(task="binary", validate_args=False),
"Recall": Recall(task="binary", validate_args=False),
"F1Score": F1Score(task="binary", validate_args=False),
"JaccardIndex": JaccardIndex(task="binary", validate_args=False),
"CohenKappa": CohenKappa(task="binary", validate_args=False),
"HammingDistance": HammingDistance(task="binary", validate_args=False),
"Accuracy": Accuracy(**metric_kwargs),
"Precision": Precision(**metric_kwargs),
"Specificity": Specificity(**metric_kwargs),
"Recall": Recall(**metric_kwargs),
"F1Score": F1Score(**metric_kwargs),
"JaccardIndex": JaccardIndex(**metric_kwargs),
"CohenKappa": CohenKappa(**metric_kwargs),
"HammingDistance": HammingDistance(**metric_kwargs),
}
)
self.train_metrics = metrics.clone(prefix="train/")
self.val_metrics = metrics.clone(prefix="val/")
self.val_metrics.add_metrics(
{
"AUROC": AUROC(task="binary", thresholds=20, validate_args=False),
"AveragePrecision": AveragePrecision(task="binary", thresholds=20, validate_args=False),
"AUROC": AUROC(thresholds=20, **metric_kwargs),
"AveragePrecision": AveragePrecision(thresholds=20, **metric_kwargs),
}
)
self.val_roc = ROC(task="binary", thresholds=20, validate_args=False)
self.val_prc = PrecisionRecallCurve(task="binary", thresholds=20, validate_args=False)
self.val_cmx = ConfusionMatrix(task="binary", normalize="true", validate_args=False)
self.val_roc = ROC(thresholds=20, **metric_kwargs)
self.val_prc = PrecisionRecallCurve(thresholds=20, **metric_kwargs)
self.val_cmx = ConfusionMatrix(normalize="true", **metric_kwargs)
self.plot_every_n_val_epochs = 5

@property
def is_val_plot_epoch(self) -> bool:
    """Whether validation figures should be plotted in the current epoch.

    The plotting interval (in epochs) is the product of ``plot_every_n_val_epochs`` and
    ``trainer.check_val_every_n_epoch``. The first epoch (``current_epoch == 0``) also counts
    as a plot epoch.

    Returns:
        bool: True if the current epoch is a plot epoch, False otherwise.

    """
    # Lightning allows check_val_every_n_epoch=None (validation scheduled by batch count only).
    # Fall back to 1 in that case so the multiplication below does not raise a TypeError.
    val_every_n_epochs = self.trainer.check_val_every_n_epoch
    if val_every_n_epochs is None:
        val_every_n_epochs = 1

    n = self.plot_every_n_val_epochs * val_every_n_epochs
    return ((self.current_epoch + 1) % n) == 0 or self.current_epoch == 0

def training_step(self, batch, batch_idx):
x, y = batch
Expand All @@ -91,49 +100,54 @@ def validation_step(self, batch, batch_idx):
self.val_cmx.update(y_hat, y)

# Create figures for the samples
for i in range(x.shape[0]):
fig, _ = plot_sample(x[i], y[i], y_hat[i], self.hparams.config["input_combination"])
for logger in self.loggers:
if isinstance(logger, CSVLogger):
fig_dir = Path(logger.log_dir) / "figures"
fig_dir.mkdir(exist_ok=True)
fig.savefig(fig_dir / f"sample_{self.global_step}_{batch_idx}_{i}.png")
if isinstance(logger, WandbLogger):
wandb_run: Run = logger.experiment
wandb_run.log({f"val/sample_{batch_idx}_{i}": wandb.Image(fig)}, step=self.global_step)
fig.clear()
if self.is_val_plot_epoch and batch_idx % 2 == 0:
for i in range(x.shape[0]):
fig, _ = plot_sample(x[i], y[i], y_hat[i], self.hparams.config["input_combination"])
for logger in self.loggers:
if isinstance(logger, CSVLogger):
fig_dir = Path(logger.log_dir) / "figures"
fig_dir.mkdir(exist_ok=True)
fig.savefig(fig_dir / f"sample_{self.global_step}_{batch_idx}_{i}.png")
if isinstance(logger, WandbLogger):
wandb_run: Run = logger.experiment
wandb_run.log({f"val-samples/sample_{batch_idx}_{i}": wandb.Image(fig)}, step=self.global_step)
fig.clear()
plt.close(fig)

return loss

def on_validation_epoch_end(self):
self.log_dict(self.val_metrics.compute())

self.val_cmx.compute()
self.val_roc.compute()
self.val_prc.compute()

# Plot roc, prc and confusion matrix to disk and wandb
fig_cmx, _ = self.val_cmx.plot(cmap="Blues")
fig_roc, _ = self.val_roc.plot(score=True)
fig_prc, _ = self.val_prc.plot(score=True)

# Check for a wandb or csv logger to log the images
for logger in self.loggers:
if isinstance(logger, CSVLogger):
fig_dir = Path(logger.log_dir) / "figures"
fig_dir.mkdir(exist_ok=True)
fig_cmx.savefig(fig_dir / f"cmx_{self.global_step}png")
fig_roc.savefig(fig_dir / f"roc_{self.global_step}png")
fig_prc.savefig(fig_dir / f"prc_{self.global_step}.png")
if isinstance(logger, WandbLogger):
wandb_run: Run = logger.experiment
wandb_run.log({"val/cmx": wandb.Image(fig_cmx)}, step=self.global_step)
wandb_run.log({"val/roc": wandb.Image(fig_roc)}, step=self.global_step)
wandb_run.log({"val/prc": wandb.Image(fig_prc)}, step=self.global_step)

fig_cmx.clear()
fig_roc.clear()
fig_prc.clear()
# Only do this every self.plot_every_n_val_epochs epochs
if self.is_val_plot_epoch:
self.val_cmx.compute()
self.val_roc.compute()
self.val_prc.compute()

# Plot roc, prc and confusion matrix to disk and wandb
fig_cmx, _ = self.val_cmx.plot(cmap="Blues")
fig_roc, _ = self.val_roc.plot(score=True)
fig_prc, _ = self.val_prc.plot(score=True)

# Check for a wandb or csv logger to log the images
for logger in self.loggers:
if isinstance(logger, CSVLogger):
fig_dir = Path(logger.log_dir) / "figures"
fig_dir.mkdir(exist_ok=True)
fig_cmx.savefig(fig_dir / f"cmx_{self.global_step}png")
fig_roc.savefig(fig_dir / f"roc_{self.global_step}png")
fig_prc.savefig(fig_dir / f"prc_{self.global_step}.png")
if isinstance(logger, WandbLogger):
wandb_run: Run = logger.experiment
wandb_run.log({"val/cmx": wandb.Image(fig_cmx)}, step=self.global_step)
wandb_run.log({"val/roc": wandb.Image(fig_roc)}, step=self.global_step)
wandb_run.log({"val/prc": wandb.Image(fig_prc)}, step=self.global_step)

fig_cmx.clear()
fig_roc.clear()
fig_prc.clear()
plt.close("all")

self.val_metrics.reset()
self.val_roc.reset()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
"""Functions to prepare the training data for the segmentation model training."""

import logging
from collections.abc import Generator
from typing import Literal

import geopandas as gpd
import torch
import xarray as xr

# TODO: move erode_mask to darts_utils, since its usage is not limited to prepare_export
from darts_postprocessing.prepare_export import erode_mask
from geocube.api.core import make_geocube

from darts_segmentation.utils import create_patches

logger = logging.getLogger(__name__.replace("darts_", "darts."))


def create_training_patches(
tile: xr.Dataset,
Expand All @@ -17,8 +24,10 @@ def create_training_patches(
norm_factors: dict[str, float],
patch_size: int,
overlap: int,
include_allzero: bool,
include_allzero: bool, # TODO: rename to include_nopositive
include_nan_edges: bool,
device: Literal["cuda", "cpu"] | int,
mask_erosion_size: int,
) -> Generator[tuple[torch.tensor, torch.tensor]]:
"""Create training patches from a tile and labels.
Expand All @@ -31,6 +40,8 @@ def create_training_patches(
overlap (int): The size of the overlap.
include_allzero (bool): Whether to include patches where the labels are all zero.
include_nan_edges (bool): Whether to include patches where the input data has nan values at the edges.
device (Literal["cuda", "cpu"] | int): The device to use for the erosion.
mask_erosion_size (int): The size of the disk to use for erosion.
Yields:
Generator[tuple[torch.tensor, torch.tensor]]: A tuple containing the input and the labels as pytorch tensors.
Expand All @@ -43,8 +54,10 @@ def create_training_patches(
# Rasterize the labels
labels_rasterized = 1 - make_geocube(labels, measurements=["id"], like=tile).id.isnull() # noqa: PD003

# Filter out the nodata values
labels_rasterized = xr.where(tile["valid_data_mask"], labels_rasterized, 0)
# Filter out the nodata values (class 2 -> invalid data)
mask = erode_mask(tile["valid_data_mask"], mask_erosion_size, device)
mask = tile["valid_data_mask"]
labels_rasterized = xr.where(mask, labels_rasterized, 2)

# Normalize the bands and clip the values
for band in bands:
Expand Down Expand Up @@ -83,13 +96,22 @@ def create_training_patches(
x = tensor_patches[i]
y = tensor_labels[i]

if not include_allzero and y.sum() == 0:
if not include_allzero and not (y == 1).any():
continue

if not include_nan_edges and torch.isnan(x).any():
continue

# Skip patches where fewer than 10% of the pixels are visible
if ((y != 2).sum() / y.numel()) < 0.1:
continue

# Skip patches where everything is nan
if torch.isnan(x).all():
continue

# Convert all nan values to 0
x[torch.isnan(x)] = 0

logger.debug(f"Yielding patch {i} with\n\t{x=}\n\t{y=}")
yield x, y
Loading

0 comments on commit f56ccc8

Please sign in to comment.