diff --git a/.gitignore b/.gitignore
index ff73966..be1b799 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,9 @@ wheels/
 models
 data
 logs
+wandb
+lightning_logs
+artifacts
 
 # Mkdocs
 .cache
diff --git a/darts-segmentation/src/darts_segmentation/training/data.py b/darts-segmentation/src/darts_segmentation/training/data.py
new file mode 100644
index 0000000..4e779eb
--- /dev/null
+++ b/darts-segmentation/src/darts_segmentation/training/data.py
@@ -0,0 +1,95 @@
+# ruff: noqa: D101
+# ruff: noqa: D102
+# ruff: noqa: D105
+# ruff: noqa: D107
+"""Training script for DARTS segmentation."""
+
+import logging
+from pathlib import Path
+from typing import Literal
+
+import albumentations as A  # noqa: N812
+import lightning as L  # noqa: N812
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, Dataset, random_split
+
+logger = logging.getLogger(__name__.replace("darts_", "darts."))
+
+
+class DartsDataset(Dataset):
+    def __init__(self, data_dir: Path, augment: bool, norm_factors: list[float]):
+        self.norm_factors = np.array(norm_factors, dtype=np.float32).reshape(-1, 1, 1)
+        self.x_files = sorted((data_dir / "x").glob("*.pt"))
+        self.y_files = sorted((data_dir / "y").glob("*.pt"))
+
+        assert len(self.x_files) == len(
+            self.y_files
+        ), f"Dataset corrupted! Got {len(self.x_files)=} and {len(self.y_files)=}!"
+
+        self.transform = (
+            A.Compose(
+                [
+                    A.HorizontalFlip(),
+                    A.VerticalFlip(),
+                    A.RandomRotate90(),
+                    A.Blur(),
+                    A.RandomBrightnessContrast(),
+                    A.MultiplicativeNoise(per_channel=True, elementwise=True),
+                    # ToTensorV2(),
+                ]
+            )
+            if augment
+            else None
+        )
+
+    def __len__(self):
+        return len(self.x_files)
+
+    def __getitem__(self, idx):
+        xfile = self.x_files[idx]
+        yfile = self.y_files[idx]
+        assert xfile.stem == yfile.stem, f"Dataset corrupted! Files must have the same name, but got {xfile=} {yfile=}!"
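
Note: DartsDataset expects data_dir to contain matching x/*.pt and y/*.pt patch files with identical stems, where each x tensor is a channels-first (C, H, W) float tensor and each y tensor is an integer mask. A minimal sketch of that on-disk layout; the path data/train, the 10-band count, and the 512x512 patch size are illustrative assumptions, not part of this change:

    from pathlib import Path

    import torch

    data_dir = Path("data/train")  # hypothetical location
    (data_dir / "x").mkdir(parents=True, exist_ok=True)
    (data_dir / "y").mkdir(parents=True, exist_ok=True)

    # One dummy 10-band 512x512 patch and its binary mask, saved under the same stem.
    torch.save(torch.rand(10, 512, 512), data_dir / "x" / "patch_0000.pt")
    torch.save(torch.randint(0, 2, (512, 512)), data_dir / "y" / "patch_0000.pt")
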
+
+        x = torch.load(xfile).numpy()
+        y = torch.load(yfile).int().numpy()
+
+        # Normalize the bands and clip the values
+        x = x * self.norm_factors
+        x = np.clip(x, 0, 1)
+
+        # Apply augmentations
+        if self.transform is not None:
+            augmented = self.transform(image=x.transpose(1, 2, 0), mask=y)
+            x = augmented["image"].transpose(2, 0, 1)
+            y = augmented["mask"]
+
+        return x, y
+
+
+class DartsDataModule(L.LightningDataModule):
+    def __init__(
+        self, data_dir: Path, norm_factors: list[float], batch_size: int, augment: bool = True, num_workers: int = 0
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+        self.data_dir = data_dir
+        self.batch_size = batch_size
+        self.augment = augment
+        self.norm_factors = norm_factors
+        self.num_workers = num_workers
+
+    def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = None):
+        dataset = DartsDataset(self.data_dir, self.augment, self.norm_factors)
+        splits = [0.8, 0.1, 0.1]
+        generator = torch.Generator().manual_seed(42)
+        self.train, self.val, self.test = random_split(dataset, splits, generator)
+
+    def train_dataloader(self):
+        return DataLoader(self.train, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
+
+    def val_dataloader(self):
+        return DataLoader(self.val, batch_size=self.batch_size, num_workers=self.num_workers)
+
+    def test_dataloader(self):
+        return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.num_workers)
diff --git a/darts-segmentation/src/darts_segmentation/training/module.py b/darts-segmentation/src/darts_segmentation/training/module.py
new file mode 100644
index 0000000..7eaa9fb
--- /dev/null
+++ b/darts-segmentation/src/darts_segmentation/training/module.py
@@ -0,0 +1,75 @@
+# ruff: noqa: D100
+# ruff: noqa: D101
+# ruff: noqa: D102
+# ruff: noqa: D105
+# ruff: noqa: D107
+
+"""Training script for DARTS segmentation."""
+
+import lightning as L  # noqa: N812
+import segmentation_models_pytorch as smp
+import torch.nn as nn
+import torch.optim as optim
+from torchmetrics import (
+    Accuracy,
+    CohenKappa,
+    F1Score,
+    HammingDistance,
+    JaccardIndex,
+    MetricCollection,
+    Precision,
+    Recall,
+    Specificity,
+)
+
+from darts_segmentation.segment import SMPSegmenterConfig
+
+
+class SMPSegmenter(L.LightningModule):
+    def __init__(self, config: SMPSegmenterConfig):
+        super().__init__()
+        self.save_hyperparameters()
+        self.model = smp.create_model(**config["model"])
+
+        metrics = MetricCollection(
+            {
+                "Accuracy": Accuracy(task="binary", validate_args=False),
+                "Precision": Precision(task="binary", validate_args=False),
+                "Specificity": Specificity(task="binary", validate_args=False),
+                "Recall": Recall(task="binary", validate_args=False),
+                "F1Score": F1Score(task="binary", validate_args=False),
+                "JaccardIndex": JaccardIndex(task="binary", validate_args=False),
+                "CohenKappa": CohenKappa(task="binary", validate_args=False),
+                "HammingDistance": HammingDistance(task="binary", validate_args=False),
+            }
+        )
+        self.train_metrics = metrics.clone(prefix="train_")
+        self.val_metrics = metrics.clone(prefix="val_")
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.model(x)
+        loss = nn.functional.mse_loss(y_hat, y.unsqueeze(1).float())
+        self.train_metrics(y_hat, y)
+        self.log("train_loss", loss)
+        self.log_dict(self.train_metrics, on_step=True, on_epoch=False)
+        return loss
+
+    def on_train_epoch_end(self):
+        self.train_metrics.reset()
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.model(x)
+        loss = nn.functional.mse_loss(y_hat, y.unsqueeze(1).float())
+        self.log("val_loss", loss)
+        self.val_metrics.update(y_hat, y)
+        return loss
+
+    def on_validation_epoch_end(self):
+        self.log_dict(self.val_metrics.compute())
+        self.val_metrics.reset()
+
+    def configure_optimizers(self):
+        optimizer = optim.AdamW(self.parameters(), lr=1e-3)
+        return optimizer
diff --git a/darts-segmentation/src/darts_segmentation/prepare_training.py b/darts-segmentation/src/darts_segmentation/training/prepare_training.py
similarity index 100%
rename from darts-segmentation/src/darts_segmentation/prepare_training.py
rename to darts-segmentation/src/darts_segmentation/training/prepare_training.py
diff --git a/darts/src/darts/cli.py b/darts/src/darts/cli.py
index ab77e4d..7e4cc85 100644
--- a/darts/src/darts/cli.py
+++ b/darts/src/darts/cli.py
@@ -15,7 +15,7 @@
     run_native_sentinel2_pipeline,
     run_native_sentinel2_pipeline_fast,
 )
-from darts.training import preprocess_s2_train_data
+from darts.training import preprocess_s2_train_data, train_smp
 from darts.utils.config import ConfigParser
 from darts.utils.logging import add_logging_handlers, setup_logging
 
@@ -70,6 +70,7 @@ def env_info():
 
 app.command(group=pipeline_group)(run_native_sentinel2_pipeline_fast)
 app.command(group=train_group)(preprocess_s2_train_data)
+app.command(group=train_group)(train_smp)
 
 
 # Custom wrapper for the create_arcticdem_vrt function, which dodges the loading of all the heavy modules
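
Note: the new SMPSegmenter only consumes config["model"] at construction time (it is forwarded to smp.create_model), while the full config is stored via save_hyperparameters(). A small smoke-test sketch; the resnet18 encoder, the 3-band input, and the reduced config are illustrative assumptions and not what train_smp configures below:

    import torch

    from darts_segmentation.segment import SMPSegmenterConfig
    from darts_segmentation.training.module import SMPSegmenter

    config = SMPSegmenterConfig(
        input_combination=["red", "green", "blue"],
        model={"arch": "Unet", "encoder_name": "resnet18", "encoder_weights": None, "in_channels": 3, "classes": 1},
        norm_factors={"red": 1 / 3000, "green": 1 / 3000, "blue": 1 / 3000},
    )
    module = SMPSegmenter(config)

    # Forward pass on a dummy batch: one 3-band 256x256 patch -> one logit map.
    with torch.no_grad():
        out = module.model(torch.rand(1, 3, 256, 256))
    print(out.shape)  # expected: torch.Size([1, 1, 256, 256])
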
diff --git a/darts/src/darts/training.py b/darts/src/darts/training.py
index c2b3a73..6b63762 100644
--- a/darts/src/darts/training.py
+++ b/darts/src/darts/training.py
@@ -62,7 +62,7 @@ def preprocess_s2_train_data(
     from darts_acquisition.s2 import load_s2_masks, load_s2_scene
     from darts_acquisition.tcvis import load_tcvis
     from darts_preprocessing import preprocess_legacy_fast
-    from darts_segmentation.prepare_training import create_training_patches
+    from darts_segmentation.training.prepare_training import create_training_patches
     from dask.distributed import Client, LocalCluster
     from odc.stac import configure_rio
 
@@ -153,3 +153,101 @@ def preprocess_s2_train_data(
 
     client.close()
     cluster.close()
+
+
+def train_smp(
+    train_data_dir: Path,
+    artifact_dir: Path = Path("lightning_logs"),
+    augment: bool = True,
+    batch_size: int = 8,
+    num_workers: int = 0,
+    wandb_entity: str | None = None,
+    wandb_project: str | None = None,
+    run_name: str | None = None,
+):
+    """Run the training of the SMP model.
+
+    Args:
+        train_data_dir (Path): Path to the training data directory.
+        artifact_dir (Path, optional): Path to the training output directory.
+            Will contain checkpoints and metrics. Defaults to Path("lightning_logs").
+        augment (bool, optional): Whether to apply augmentations or not. Defaults to True.
+        batch_size (int, optional): Batch size. Defaults to 8.
+        num_workers (int, optional): Number of dataloader workers. Defaults to 0.
+        wandb_entity (str | None, optional): Weights and Biases entity. Defaults to None.
+        wandb_project (str | None, optional): Weights and Biases project. Defaults to None.
+        run_name (str | None, optional): Name of this run, as a further grouping method for logs etc. Defaults to None.
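
Note: besides the CLI registration above, train_smp can be called directly with the same defaults. A sketch of such a call; the paths, worker count, and run name are placeholders, and wandb logging stays disabled because no entity/project is given:

    from pathlib import Path

    from darts.training import train_smp

    train_smp(
        train_data_dir=Path("data/train"),    # directory containing the x/ and y/ patch folders
        artifact_dir=Path("lightning_logs"),  # checkpoints and CSV metrics end up here
        batch_size=8,
        num_workers=4,
        run_name="smp-unet-dpn107",           # hypothetical run name
    )
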
+
+    """
+    import lightning as L  # noqa: N812
+    import lovely_tensors
+    import torch
+    from darts_segmentation.segment import SMPSegmenterConfig
+    from darts_segmentation.training.data import DartsDataModule
+    from darts_segmentation.training.module import SMPSegmenter
+    from lightning.pytorch.callbacks import EarlyStopping, RichProgressBar
+    from lightning.pytorch.loggers import CSVLogger, WandbLogger
+
+    lovely_tensors.monkey_patch()
+
+    torch.set_float32_matmul_precision("medium")
+
+    config = SMPSegmenterConfig(
+        input_combination=[
+            "blue",
+            "green",
+            "red",
+            "nir",
+            "ndvi",
+            "tc_brightness",
+            "tc_greenness",
+            "tc_wetness",
+            "relative_elevation",
+            "slope",
+        ],
+        model={
+            "arch": "Unet",
+            "encoder_name": "dpn107",
+            "encoder_weights": "imagenet+5k",
+            "in_channels": 10,
+            "classes": 1,
+        },
+        norm_factors={
+            "red": 1 / 3000,
+            "green": 1 / 3000,
+            "blue": 1 / 3000,
+            "nir": 1 / 3000,
+            "ndvi": 1 / 20000,
+            "relative_elevation": 1 / 30000,
+            "slope": 1 / 90,
+            "tc_brightness": 1 / 255,
+            "tc_greenness": 1 / 255,
+            "tc_wetness": 1 / 255,
+        },
+    )
+
+    norm_factors_ordered = [config["norm_factors"][band] for band in config["input_combination"]]
+
+    datamodule = DartsDataModule(train_data_dir, norm_factors_ordered, batch_size, augment, num_workers)
+
+    model = SMPSegmenter(config)
+
+    trainer_loggers = [
+        CSVLogger(save_dir=artifact_dir, name=run_name),
+    ]
+    if wandb_entity and wandb_project:
+        wandb_logger = WandbLogger(save_dir=artifact_dir, name=run_name, project=wandb_project, entity=wandb_entity)
+        trainer_loggers.append(wandb_logger)
+    early_stopping = EarlyStopping(monitor="val_loss")
+    callbacks = [
+        early_stopping,
+        RichProgressBar(),
+    ]
+
+    trainer = L.Trainer(
+        max_epochs=100,
+        callbacks=callbacks,
+        log_every_n_steps=10,
+        logger=trainer_loggers,
+    )
+    trainer.fit(model, datamodule)
diff --git a/darts/src/darts/utils/logging.py b/darts/src/darts/utils/logging.py
index fe1730b..038f136 100644
--- a/darts/src/darts/utils/logging.py
+++ b/darts/src/darts/utils/logging.py
@@ -46,11 +46,17 @@ def add_logging_handlers(command: str, console: Console, log_dir: Path):
         log_dir (Path): The directory to save the logs to.
 
     """
+    import lightning as L  # noqa: N812
+    import torch
+    import torch.utils.data
+
     log_dir.mkdir(parents=True, exist_ok=True)
     current_time = time.strftime("%Y-%m-%d_%H-%M-%S")
 
     # Configure the rich console handler
-    rich_handler = RichHandler(console=console, rich_tracebacks=True, tracebacks_suppress=[cyclopts])
+    rich_handler = RichHandler(
+        console=console, rich_tracebacks=True, tracebacks_suppress=[cyclopts, L, torch, torch.utils.data]
+    )
     rich_handler.setFormatter(
         logging.Formatter(
             "%(message)s",
diff --git a/pyproject.toml b/pyproject.toml
index eb9829c..eda1728 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -43,6 +43,10 @@ dependencies = [
     "bokeh>=3.5.2",
     "jupyter-bokeh>=4.0.5",
     "setuptools>=75.5.0",
+    "lightning>=2.4.0",
+    "albumentations>=1.4.21",
+    "wandb>=0.18.7",
+    "torchmetrics>=1.6.0",
 ]
 readme = "README.md"
 requires-python = ">= 3.11"