Add a minimal training method with lightning
relativityhd committed Nov 30, 2024
1 parent 0c78f74 commit ddcf79d
Showing 8 changed files with 285 additions and 3 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -14,6 +14,9 @@ wheels/
models
data
logs
wandb
lightning_logs
artifacts

# Mkdocs
.cache
95 changes: 95 additions & 0 deletions darts-segmentation/src/darts_segmentation/training/data.py
@@ -0,0 +1,95 @@
# ruff: noqa: D101
# ruff: noqa: D102
# ruff: noqa: D105
# ruff: noqa: D107
"""Training script for DARTS segmentation."""

import logging
from pathlib import Path
from typing import Literal

import albumentations as A # noqa: N812
import lightning as L # noqa: N812
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, random_split

logger = logging.getLogger(__name__.replace("darts_", "darts."))


class DartsDataset(Dataset):
def __init__(self, data_dir: Path, augment: bool, norm_factors: list[float]):
self.norm_factors = np.array(norm_factors, dtype=np.float32).reshape(-1, 1, 1)
self.x_files = sorted((data_dir / "x").glob("*.pt"))
self.y_files = sorted((data_dir / "y").glob("*.pt"))

assert len(self.x_files) == len(
self.y_files
), f"Dataset corrupted! Got {len(self.x_files)=} and {len(self.y_files)=}!"

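        # Geometric and photometric augmentations; only assembled when augment=True, otherwise no transform is applied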
self.transform = (
A.Compose(
[
A.HorizontalFlip(),
A.VerticalFlip(),
A.RandomRotate90(),
A.Blur(),
A.RandomBrightnessContrast(),
A.MultiplicativeNoise(per_channel=True, elementwise=True),
# ToTensorV2(),
]
)
if augment
else None
)

def __len__(self):
return len(self.x_files)

def __getitem__(self, idx):
xfile = self.x_files[idx]
yfile = self.y_files[idx]
assert xfile.stem == yfile.stem, f"Dataset corrupted! Files must have the same name, but got {xfile=} {yfile=}!"

x = torch.load(xfile).numpy()
y = torch.load(yfile).int().numpy()

# Normalize the bands and clip the values
x = x * self.norm_factors
x = np.clip(x, 0, 1)

# Apply augmentations
if self.transform is not None:
augmented = self.transform(image=x.transpose(1, 2, 0), mask=y)
x = augmented["image"].transpose(2, 0, 1)
y = augmented["mask"]

return x, y


class DartsDataModule(L.LightningDataModule):
def __init__(
self, data_dir: Path, norm_factors: list[float], batch_size: int, augment: bool = True, num_workers: int = 0
):
super().__init__()
self.save_hyperparameters()
self.data_dir = data_dir
self.batch_size = batch_size
self.augment = augment
self.norm_factors = norm_factors
self.num_workers = num_workers

def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = None):
dataset = DartsDataset(self.data_dir, self.augment, self.norm_factors)
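        # Deterministic 80/10/10 train/val/test split; the fixed seed keeps the partition identical across runs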
splits = [0.8, 0.1, 0.1]
generator = torch.Generator().manual_seed(42)
self.train, self.val, self.test = random_split(dataset, splits, generator)

def train_dataloader(self):
return DataLoader(self.train, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

def val_dataloader(self):
return DataLoader(self.val, batch_size=self.batch_size, num_workers=self.num_workers)

def test_dataloader(self):
return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.num_workers)
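
For orientation, a minimal sketch of driving the datamodule by hand. The directory layout (x/ and y/ subfolders with matching *.pt patches) follows DartsDataset above; the path and factor values are purely illustrative:

from pathlib import Path

from darts_segmentation.training.data import DartsDataModule

dm = DartsDataModule(
    data_dir=Path("data/train"),  # hypothetical; must contain x/ and y/ with equally named *.pt files
    norm_factors=[1 / 3000, 1 / 3000, 1 / 3000, 1 / 3000],  # one factor per band, in channel order
    batch_size=8,
)
dm.setup("fit")
x, y = next(iter(dm.train_dataloader()))  # x: [B, C, H, W] floats in [0, 1], y: [B, H, W] int mask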
75 changes: 75 additions & 0 deletions darts-segmentation/src/darts_segmentation/training/module.py
@@ -0,0 +1,75 @@
# ruff: noqa: D100
# ruff: noqa: D101
# ruff: noqa: D102
# ruff: noqa: D105
# ruff: noqa: D107

"""Training script for DARTS segmentation."""

import lightning as L # noqa: N812
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.optim as optim
from torchmetrics import (
Accuracy,
CohenKappa,
F1Score,
HammingDistance,
JaccardIndex,
MetricCollection,
Precision,
Recall,
Specificity,
)

from darts_segmentation.segment import SMPSegmenterConfig


class SMPSegmenter(L.LightningModule):
def __init__(self, config: SMPSegmenterConfig):
super().__init__()
self.save_hyperparameters()
self.model = smp.create_model(**config["model"])

metrics = MetricCollection(
{
"Accuracy": Accuracy(task="binary", validate_args=False),
"Precision": Precision(task="binary", validate_args=False),
"Specificity": Specificity(task="binary", validate_args=False),
"Recall": Recall(task="binary", validate_args=False),
"F1Score": F1Score(task="binary", validate_args=False),
"JaccardIndex": JaccardIndex(task="binary", validate_args=False),
"CohenKappa": CohenKappa(task="binary", validate_args=False),
"HammingDistance": HammingDistance(task="binary", validate_args=False),
}
)
self.train_metrics = metrics.clone(prefix="train_")
self.val_metrics = metrics.clone(prefix="val_")

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
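        # MSE between the raw logits and the float mask; y is unsqueezed from [B, H, W] to [B, 1, H, W] to match y_hat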
loss = nn.functional.mse_loss(y_hat, y.unsqueeze(1).float())
self.train_metrics(y_hat, y)
self.log("train_loss", loss)
self.log_dict(self.train_metrics, on_step=True, on_epoch=False)
return loss

def on_train_epoch_end(self):
self.train_metrics.reset()

def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = nn.functional.mse_loss(y_hat, y.unsqueeze(1).float())
self.log("val_loss", loss)
self.val_metrics.update(y_hat, y)
return loss

def on_validation_epoch_end(self):
self.log_dict(self.val_metrics.compute())
self.val_metrics.reset()

def configure_optimizers(self):
optimizer = optim.AdamW(self.parameters(), lr=1e-3)
return optimizer
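
As a quick shape check, the module can be built standalone. This sketch assumes SMPSegmenterConfig accepts the same keys used by train_smp below; the small encoder is an arbitrary stand-in:

import torch

from darts_segmentation.segment import SMPSegmenterConfig
from darts_segmentation.training.module import SMPSegmenter

config = SMPSegmenterConfig(
    input_combination=["red", "green", "blue"],
    model={"arch": "Unet", "encoder_name": "resnet18", "encoder_weights": None, "in_channels": 3, "classes": 1},
    norm_factors={"red": 1 / 3000, "green": 1 / 3000, "blue": 1 / 3000},
)
segmenter = SMPSegmenter(config)
logits = segmenter.model(torch.randn(2, 3, 64, 64))  # single logit channel: [2, 1, 64, 64]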
3 changes: 2 additions & 1 deletion darts/src/darts/cli.py
@@ -15,7 +15,7 @@
run_native_sentinel2_pipeline,
run_native_sentinel2_pipeline_fast,
)
-from darts.training import preprocess_s2_train_data
+from darts.training import preprocess_s2_train_data, train_smp
from darts.utils.config import ConfigParser
from darts.utils.logging import add_logging_handlers, setup_logging

@@ -70,6 +70,7 @@ def env_info():
app.command(group=pipeline_group)(run_native_sentinel2_pipeline_fast)

app.command(group=train_group)(preprocess_s2_train_data)
app.command(group=train_group)(train_smp)


# Custom wrapper for the create_arcticdem_vrt function, which dodges the loading of all the heavy modules
100 changes: 99 additions & 1 deletion darts/src/darts/training.py
@@ -62,7 +62,7 @@ def preprocess_s2_train_data(
from darts_acquisition.s2 import load_s2_masks, load_s2_scene
from darts_acquisition.tcvis import load_tcvis
from darts_preprocessing import preprocess_legacy_fast
-    from darts_segmentation.prepare_training import create_training_patches
+    from darts_segmentation.training.prepare_training import create_training_patches
from dask.distributed import Client, LocalCluster
from odc.stac import configure_rio

@@ -153,3 +153,101 @@ def preprocess_s2_train_data(

client.close()
cluster.close()


def train_smp(
train_data_dir: Path,
artifact_dir: Path = Path("lightning_logs"),
augment: bool = True,
batch_size: int = 8,
num_workers: int = 0,
wandb_entity: str | None = None,
wandb_project: str | None = None,
run_name: str | None = None,
):
"""Run the training of the SMP model.
Args:
train_data_dir (Path): Path to the training data directory.
artifact_dir (Path, optional): Path to the training output directory.
Will contain checkpoints and metrics. Defaults to Path("lightning_logs").
        augment (bool, optional): Whether to apply augmentations. Defaults to True.
batch_size (int, optional): Batch Size. Defaults to 8.
num_workers (int, optional): Number of Dataloader workers. Defaults to 0.
wandb_entity (str | None, optional): Weights and Biases Entity. Defaults to None.
wandb_project (str | None, optional): Weights and Biases Project. Defaults to None.
run_name (str | None, optional): Name of this run, as a further grouping method for logs etc. Defaults to None.
"""
import lightning as L # noqa: N812
import lovely_tensors
import torch
from darts_segmentation.segment import SMPSegmenterConfig
from darts_segmentation.training.data import DartsDataModule
from darts_segmentation.training.module import SMPSegmenter
from lightning.pytorch.callbacks import EarlyStopping, RichProgressBar
from lightning.pytorch.loggers import CSVLogger, WandbLogger

lovely_tensors.monkey_patch()

torch.set_float32_matmul_precision("medium")

config = SMPSegmenterConfig(
input_combination=[
"blue",
"green",
"red",
"nir",
"ndvi",
"tc_brightness",
"tc_greenness",
"tc_wetness",
"relative_elevation",
"slope",
],
model={
"arch": "Unet",
"encoder_name": "dpn107",
"encoder_weights": "imagenet+5k",
"in_channels": 10,
"classes": 1,
},
norm_factors={
"red": 1 / 3000,
"green": 1 / 3000,
"blue": 1 / 3000,
"nir": 1 / 3000,
"ndvi": 1 / 20000,
"relative_elevation": 1 / 30000,
"slope": 1 / 90,
"tc_brightness": 1 / 255,
"tc_greenness": 1 / 255,
"tc_wetness": 1 / 255,
},
)

norm_factors_ordered = [config["norm_factors"][band] for band in config["input_combination"]]

datamodule = DartsDataModule(train_data_dir, norm_factors_ordered, batch_size, augment, num_workers)

model = SMPSegmenter(config)

trainer_loggers = [
CSVLogger(save_dir=artifact_dir, name=run_name),
]
if wandb_entity and wandb_project:
wandb_logger = WandbLogger(save_dir=artifact_dir, name=run_name, project=wandb_project, entity=wandb_entity)
trainer_loggers.append(wandb_logger)
early_stopping = EarlyStopping(monitor="val_loss")
callbacks = [
early_stopping,
RichProgressBar(),
]

trainer = L.Trainer(
max_epochs=100,
callbacks=callbacks,
log_every_n_steps=10,
logger=trainer_loggers,
)
trainer.fit(model, datamodule)
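
Besides the cyclopts command registered in cli.py, train_smp can be called directly from Python; a sketch with placeholder paths:

from pathlib import Path

from darts.training import train_smp

train_smp(
    train_data_dir=Path("data/train"),  # hypothetical; the preprocessed x/ and y/ patch folders
    artifact_dir=Path("lightning_logs"),
    batch_size=8,
    num_workers=4,
    wandb_entity=None,  # set both entity and project to enable the WandbLogger
    wandb_project=None,
)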
8 changes: 7 additions & 1 deletion darts/src/darts/utils/logging.py
@@ -46,11 +46,17 @@ def add_logging_handlers(command: str, console: Console, log_dir: Path):
log_dir (Path): The directory to save the logs to.
"""
import lightning as L # noqa: N812
import torch
import torch.utils.data

log_dir.mkdir(parents=True, exist_ok=True)
current_time = time.strftime("%Y-%m-%d_%H-%M-%S")

# Configure the rich console handler
-    rich_handler = RichHandler(console=console, rich_tracebacks=True, tracebacks_suppress=[cyclopts])
+    rich_handler = RichHandler(
+        console=console, rich_tracebacks=True, tracebacks_suppress=[cyclopts, L, torch, torch.utils.data]
+    )
rich_handler.setFormatter(
logging.Formatter(
"%(message)s",
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -43,6 +43,10 @@ dependencies = [
"bokeh>=3.5.2",
"jupyter-bokeh>=4.0.5",
"setuptools>=75.5.0",
"lightning>=2.4.0",
"albumentations>=1.4.21",
"wandb>=0.18.7",
"torchmetrics>=1.6.0",
]
readme = "README.md"
requires-python = ">= 3.11"
