Add a minimal training method with lightning

Commit ddcf79d, 1 parent: 0c78f74
Showing 8 changed files with 285 additions and 3 deletions.
.gitignore (3 additions, 0 deletions)

```diff
@@ -14,6 +14,9 @@ wheels/
 models
 data
 logs
+wandb
+lightning_logs
+artifacts

 # Mkdocs
 .cache
```
darts-segmentation/src/darts_segmentation/training/data.py (new file, 95 additions)

```python
# ruff: noqa: D101
# ruff: noqa: D102
# ruff: noqa: D105
# ruff: noqa: D107
"""Training script for DARTS segmentation."""

import logging
from pathlib import Path
from typing import Literal

import albumentations as A  # noqa: N812
import lightning as L  # noqa: N812
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, random_split

logger = logging.getLogger(__name__.replace("darts_", "darts."))


class DartsDataset(Dataset):
    def __init__(self, data_dir: Path, augment: bool, norm_factors: list[float]):
        self.norm_factors = np.array(norm_factors, dtype=np.float32).reshape(-1, 1, 1)
        self.x_files = sorted((data_dir / "x").glob("*.pt"))
        self.y_files = sorted((data_dir / "y").glob("*.pt"))

        assert len(self.x_files) == len(
            self.y_files
        ), f"Dataset corrupted! Got {len(self.x_files)=} and {len(self.y_files)=}!"

        self.transform = (
            A.Compose(
                [
                    A.HorizontalFlip(),
                    A.VerticalFlip(),
                    A.RandomRotate90(),
                    A.Blur(),
                    A.RandomBrightnessContrast(),
                    A.MultiplicativeNoise(per_channel=True, elementwise=True),
                    # ToTensorV2(),
                ]
            )
            if augment
            else None
        )

    def __len__(self):
        return len(self.x_files)

    def __getitem__(self, idx):
        xfile = self.x_files[idx]
        yfile = self.y_files[idx]
        assert xfile.stem == yfile.stem, f"Dataset corrupted! Files must have the same name, but got {xfile=} {yfile=}!"

        x = torch.load(xfile).numpy()
        y = torch.load(yfile).int().numpy()

        # Normalize the bands and clip the values
        x = x * self.norm_factors
        x = np.clip(x, 0, 1)

        # Apply augmentations
        if self.transform is not None:
            augmented = self.transform(image=x.transpose(1, 2, 0), mask=y)
            x = augmented["image"].transpose(2, 0, 1)
            y = augmented["mask"]

        return x, y


class DartsDataModule(L.LightningDataModule):
    def __init__(
        self, data_dir: Path, norm_factors: list[float], batch_size: int, augment: bool = True, num_workers: int = 0
    ):
        super().__init__()
        self.save_hyperparameters()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.augment = augment
        self.norm_factors = norm_factors
        self.num_workers = num_workers

    def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = None):
        dataset = DartsDataset(self.data_dir, self.augment, self.norm_factors)
        splits = [0.8, 0.1, 0.1]
        generator = torch.Generator().manual_seed(42)
        self.train, self.val, self.test = random_split(dataset, splits, generator)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.num_workers)
```
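For orientation, a minimal sketch of how this datamodule might be instantiated. The `x/` and `y/` layout of `.pt` tensor files follows the dataset code above; the concrete path, band count, and normalization factors below are hypothetical placeholders, not values from this commit.

```python
from pathlib import Path

# Sketch only: path, band count, and scale factors are assumed for illustration.
datamodule = DartsDataModule(
    data_dir=Path("data/train"),  # expects data/train/x/*.pt and data/train/y/*.pt
    norm_factors=[1 / 3000] * 7,  # one scale factor per input band (7 bands assumed)
    batch_size=8,
    num_workers=4,
)
datamodule.setup("fit")  # builds the fixed 80/10/10 train/val/test split (seed 42)
train_loader = datamodule.train_dataloader()
```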
darts-segmentation/src/darts_segmentation/training/module.py (new file, 75 additions)

```python
# ruff: noqa: D100
# ruff: noqa: D101
# ruff: noqa: D102
# ruff: noqa: D105
# ruff: noqa: D107

"""Training script for DARTS segmentation."""

import lightning as L  # noqa: N812
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.optim as optim
from torchmetrics import (
    Accuracy,
    CohenKappa,
    F1Score,
    HammingDistance,
    JaccardIndex,
    MetricCollection,
    Precision,
    Recall,
    Specificity,
)

from darts_segmentation.segment import SMPSegmenterConfig


class SMPSegmenter(L.LightningModule):
    def __init__(self, config: SMPSegmenterConfig):
        super().__init__()
        self.save_hyperparameters()
        self.model = smp.create_model(**config["model"])

        metrics = MetricCollection(
            {
                "Accuracy": Accuracy(task="binary", validate_args=False),
                "Precision": Precision(task="binary", validate_args=False),
                "Specificity": Specificity(task="binary", validate_args=False),
                "Recall": Recall(task="binary", validate_args=False),
                "F1Score": F1Score(task="binary", validate_args=False),
                "JaccardIndex": JaccardIndex(task="binary", validate_args=False),
                "CohenKappa": CohenKappa(task="binary", validate_args=False),
                "HammingDistance": HammingDistance(task="binary", validate_args=False),
            }
        )
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = nn.functional.mse_loss(y_hat, y.unsqueeze(1).float())
        self.train_metrics(y_hat, y)
        self.log("train_loss", loss)
        self.log_dict(self.train_metrics, on_step=True, on_epoch=False)
        return loss

    def on_train_epoch_end(self):
        self.train_metrics.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = nn.functional.mse_loss(y_hat, y.unsqueeze(1).float())
        self.log("val_loss", loss)
        self.val_metrics.update(y_hat, y)
        return loss

    def on_validation_epoch_end(self):
        self.log_dict(self.val_metrics.compute())
        self.val_metrics.reset()

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer
```
File renamed without changes.
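Taken together, the two new classes could be wired up roughly as sketched below. The config contents are assumptions: the code above only requires that `config["model"]` hold keyword arguments for `smp.create_model`, and `datamodule` refers to the earlier hypothetical sketch.

```python
import lightning as L

# Hypothetical config: only the "model" key is read by SMPSegmenter above,
# and its contents are forwarded verbatim to smp.create_model.
config = {
    "model": {
        "arch": "unet",
        "encoder_name": "resnet34",
        "in_channels": 7,  # must match the number of input bands
        "classes": 1,  # single logit channel for binary segmentation
    },
}

model = SMPSegmenter(config)
trainer = L.Trainer(max_epochs=10)  # illustrative trainer settings
trainer.fit(model, datamodule=datamodule)  # datamodule from the sketch above
```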