add optuna crossval sweep and fix region of preprocess
relativityhd committed Dec 23, 2024
1 parent 9a2b7c8 commit 0840ffc
Showing 10 changed files with 414 additions and 32 deletions.
60 changes: 60 additions & 0 deletions darts-acquisition/src/darts_acquisition/admin.py
@@ -0,0 +1,60 @@
"""Download of admin level files for the regions."""

import io
import logging
import time
import zipfile
from pathlib import Path

import requests

logger = logging.getLogger(__name__)


def _download_zip(url: str, admin_dir: Path):
    tick_fstart = time.perf_counter()
    response = requests.get(url)

    # Get the downloaded data as a byte string
    data = response.content

    tick_download = time.perf_counter()
    logger.debug(f"Downloaded {len(data)} bytes in {tick_download - tick_fstart:.2f} seconds")

    # Create a bytesIO object
    with io.BytesIO(data) as buffer:
        # Create a zipfile.ZipFile object and extract the files to a directory
        admin_dir.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(buffer, "r") as zip_ref:
            # Extract the files to the specified directory
            zip_ref.extractall(admin_dir)

    tick_extract = time.perf_counter()
    logger.debug(f"Extraction completed in {tick_extract - tick_download:.2f} seconds")


def download_admin_files(admin_dir: Path):
    """Download the admin files for the regions.

    Files will be stored under [admin_dir]/adm1.shp and [admin_dir]/adm2.shp.

    Args:
        admin_dir (Path): The path to the admin files.

    """
    tick_fstart = time.perf_counter()

    # Download the admin files
    admin_1_url = "https://github.com/wmgeolab/geoBoundaries/raw/main/releaseData/CGAZ/geoBoundariesCGAZ_ADM1.zip"
    admin_2_url = "https://github.com/wmgeolab/geoBoundaries/raw/main/releaseData/CGAZ/geoBoundariesCGAZ_ADM2.zip"

    admin_dir.mkdir(exist_ok=True, parents=True)

    logger.debug(f"Downloading {admin_1_url} to {admin_dir.resolve()}")
    _download_zip(admin_1_url, admin_dir)

    logger.debug(f"Downloading {admin_2_url} to {admin_dir.resolve()}")
    _download_zip(admin_2_url, admin_dir)

    tick_fend = time.perf_counter()
    logger.info(f"Downloaded admin files in {tick_fend - tick_fstart:.2f} seconds")
Empty file.
16 changes: 9 additions & 7 deletions darts-segmentation/src/darts_segmentation/training/data.py
@@ -11,7 +11,8 @@
import albumentations as A # noqa: N812
import lightning as L # noqa: N812
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Dataset, Subset

logger = logging.getLogger(__name__.replace("darts_", "darts."))

@@ -120,6 +121,7 @@ def __init__(
self,
data_dir: Path,
batch_size: int,
current_fold: int = 0,
augment: bool = True,
num_workers: int = 0,
in_memory: bool = False,
@@ -128,6 +130,7 @@
self.save_hyperparameters()
self.data_dir = data_dir
self.batch_size = batch_size
self.current_fold = current_fold
self.augment = augment
self.num_workers = num_workers
self.in_memory = in_memory
@@ -138,15 +141,14 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = None):
if not self.in_memory
else DartsDatasetInMemory(self.data_dir, self.augment)
)
splits = [0.8, 0.1, 0.1]
generator = torch.Generator().manual_seed(42)
self.train, self.val, self.test = random_split(dataset, splits, generator)

kf = KFold(n_splits=5)
train_idx, val_idx = list(kf.split(dataset))[self.current_fold]
self.train = Subset(dataset, train_idx)
self.val = Subset(dataset, val_idx)

def train_dataloader(self):
return DataLoader(self.train, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

def val_dataloader(self):
return DataLoader(self.val, batch_size=self.batch_size, num_workers=self.num_workers)

def test_dataloader(self):
return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.num_workers)
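With this change the data module derives its train/val split from scikit-learn's KFold instead of a random 80/10/10 split, so each value of current_fold selects a different fifth of the samples for validation (KFold is left unshuffled here, i.e. folds are contiguous index blocks). A minimal sketch of that selection with a stand-in dataset, not the real DartsDataset:

import torch
from sklearn.model_selection import KFold
from torch.utils.data import Subset, TensorDataset

# Stand-in dataset with 10 samples; the real module wraps DartsDataset / DartsDatasetInMemory.
dataset = TensorDataset(torch.arange(10))

kf = KFold(n_splits=5)
for current_fold, (train_idx, val_idx) in enumerate(kf.split(dataset)):
    train, val = Subset(dataset, train_idx), Subset(dataset, val_idx)
    print(current_fold, len(train), len(val))  # 8 train / 2 val samples in each fold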
14 changes: 11 additions & 3 deletions darts/src/darts/cli.py
@@ -15,7 +15,13 @@
run_native_sentinel2_pipeline,
run_native_sentinel2_pipeline_fast,
)
from darts.legacy_training import convert_lightning_checkpoint, preprocess_s2_train_data, sweep_smp, train_smp
from darts.legacy_training import (
convert_lightning_checkpoint,
optuna_sweep_smp,
preprocess_s2_train_data,
train_smp,
wandb_sweep_smp,
)
from darts.utils.config import ConfigParser
from darts.utils.logging import LoggingManager

@@ -72,7 +78,8 @@ def env_info():
app.command(group=train_group)(preprocess_s2_train_data)
app.command(group=train_group)(train_smp)
app.command(group=train_group)(convert_lightning_checkpoint)
app.command(group=train_group)(sweep_smp)
app.command(group=train_group)(wandb_sweep_smp)
app.command(group=train_group)(optuna_sweep_smp)


# Custom wrapper for the create_arcticdem_vrt function, which dodges the loading of all the heavy modules
@@ -96,9 +103,10 @@ def launcher( # noqa: D103
*tokens: Annotated[str, cyclopts.Parameter(show=False, allow_leading_hyphen=True)],
log_dir: Path = Path("logs"),
config_file: Path = Path("config.toml"),
tracebacks_show_locals: bool = False,
):
command, bound, _ = app.parse_args(tokens)
LoggingManager.add_logging_handlers(command.__name__, console, log_dir)
LoggingManager.add_logging_handlers(command.__name__, console, log_dir, tracebacks_show_locals)
logger.debug(f"Running on Python version {sys.version} from {__name__} ({root_file})")
return command(*bound.args, **bound.kwargs)

3 changes: 2 additions & 1 deletion darts/src/darts/legacy_training/__init__.py
@@ -1,6 +1,7 @@
"""Legacy training module for DARTS."""

from darts.legacy_training.preprocess import preprocess_s2_train_data as preprocess_s2_train_data
from darts.legacy_training.train import sweep_smp as sweep_smp
from darts.legacy_training.train import optuna_sweep_smp as optuna_sweep_smp
from darts.legacy_training.train import train_smp as train_smp
from darts.legacy_training.train import wandb_sweep_smp as wandb_sweep_smp
from darts.legacy_training.util import convert_lightning_checkpoint as convert_lightning_checkpoint
49 changes: 39 additions & 10 deletions darts/src/darts/legacy_training/preprocess.py
@@ -7,13 +7,11 @@
from pathlib import Path
from typing import Literal

import toml

logger = logging.getLogger(__name__)


def split_dataset_paths(

[GitHub Actions lint annotation — Ruff (D417): darts/src/darts/legacy_training/preprocess.py:13:5: Missing argument description in the docstring for `split_dataset_paths`: `admin_dir`]
s2_paths: list[Path], train_data_dir: Path, test_val_split: float, test_regions: list[str] | None
s2_paths: list[Path], train_data_dir: Path, test_val_split: float, test_regions: list[str] | None, admin_dir: Path
):
"""Split the dataset into a cross-val, a val-test and a test dataset.
@@ -33,6 +31,7 @@ """
"""
# Import here to avoid long loading times when running other commands
import geopandas as gpd
from darts_acquisition.admin import download_admin_files
from darts_acquisition.s2 import parse_s2_tile_id
from sklearn.model_selection import train_test_split

@@ -45,11 +44,32 @@
test_paths: list[Path] = []
training_paths: list[Path] = []
if test_regions:
# Download admin files if they do not exist
admin1_fpath = admin_dir / "geoBoundariesCGAZ_ADM1.shp"
admin2_fpath = admin_dir / "geoBoundariesCGAZ_ADM2.shp"

if not admin1_fpath.exists() or not admin2_fpath.exists():
download_admin_files(admin_dir)

# Load the admin files
admin1 = gpd.read_file(admin1_fpath)
admin2 = gpd.read_file(admin2_fpath)

# Get the regions from the admin files
test_region_geometries_adm1 = admin1[admin1["shapeName"].isin(test_regions)]
test_region_geometries_adm2 = admin2[admin2["shapeName"].isin(test_regions)]

logger.debug(f"Found {len(test_region_geometries_adm1)} admin1-regions in {admin1_fpath}")
logger.debug(f"Found {len(test_region_geometries_adm2)} admin2-regions in {admin2_fpath}")

for fpath in s2_paths:
_, s2_tile_id, _ = parse_s2_tile_id(fpath)
labels = gpd.read_file(fpath / f"{s2_tile_id}.shp")
# If any of the regions is in the test region, add to the test set
if labels["region"].isin(test_regions).any():
labels = gpd.read_file(fpath / f"{s2_tile_id}.shp").to_crs("EPSG:4326")
# Check if any label is intersecting with the test regions
adm1_intersects = labels.overlay(test_region_geometries_adm1, how="intersection")
adm2_intersects = labels.overlay(test_region_geometries_adm2, how="intersection")

if (len(adm1_intersects.index) > 0) or (len(adm2_intersects.index) > 0):
test_paths.append(fpath)
else:
training_paths.append(fpath)
@@ -92,6 +112,7 @@ def preprocess_s2_train_data(
train_data_dir: Path,
arcticdem_dir: Path,
tcvis_dir: Path,
admin_dir: Path,
preprocess_cache: Path | None = None,
device: Literal["cuda", "cpu", "auto"] | int | None = None,
dask_worker: int = min(16, mp.cpu_count() - 1),
@@ -116,6 +137,7 @@
arcticdem_dir (Path): The directory containing the ArcticDEM data (the datacube and the extent files).
Will be created and downloaded if it does not exist.
tcvis_dir (Path): The directory containing the TCVis data.
admin_dir (Path): The directory containing the admin files.
preprocess_cache (Path, optional): The directory to store the preprocessed data. Defaults to None.
device (Literal["cuda", "cpu"] | int, optional): The device to run the model on.
If "cuda" take the first device (0), if int take the specified device.
@@ -144,6 +166,7 @@
# Import here to avoid long loading times when running other commands
import geopandas as gpd
import pandas as pd
import toml
import torch
import xarray as xr
from darts_acquisition.arcticdem import load_arcticdem_tile
@@ -189,14 +212,19 @@

# Find all Sentinel 2 scenes and split into train+val (cross-val), val-test (variance) and test (region)
n_patches = 0
n_patches_by_mode = {"cross-val": 0, "val-test": 0, "test": 0}
joint_lables = []
s2_paths = sorted(sentinel2_dir.glob("*/"))
logger.info(f"Found {len(s2_paths)} Sentinel 2 scenes in {sentinel2_dir}")
path_gen = split_dataset_paths(s2_paths, train_data_dir, test_val_split, test_regions)
path_gen = split_dataset_paths(s2_paths, train_data_dir, test_val_split, test_regions, admin_dir)
for i, (fpath, output_dir, mode) in enumerate(path_gen):
try:
_, s2_tile_id, tile_id = parse_s2_tile_id(fpath)

logger.debug(
f"Processing sample {i + 1} of {len(s2_paths)} '{fpath.resolve()}' ({tile_id=}) to split '{mode}'"
)

# Check for a cached preprocessed file
if preprocess_cache and (preprocess_cache / f"{tile_id}.nc").exists():
cache_file = preprocess_cache / f"{tile_id}.nc"
@@ -251,14 +279,15 @@
outdir_y = output_dir / "y"
outdir_x.mkdir(exist_ok=True, parents=True)
outdir_y.mkdir(exist_ok=True, parents=True)
n_patches = 0
patch_id = None
for patch_id, (x, y) in enumerate(gen):
torch.save(x, outdir_x / f"{tile_id}_pid{patch_id}.pt")
torch.save(y, outdir_y / f"{tile_id}_pid{patch_id}.pt")
n_patches += 1
n_patches_by_mode[mode] += 1
if n_patches > 0 and len(labels) > 0:
labels["mode"] = mode
joint_lables.append(labels)
joint_lables.append(labels.to_crs("EPSG:3413"))

logger.info(
f"Processed sample {i + 1} of {len(s2_paths)} '{fpath.resolve()}'"
@@ -300,4 +329,4 @@
with open(train_data_dir / "config.toml", "w") as f:
toml.dump(config, f)

logger.info(f"Saved {n_patches} patches to {train_data_dir}")
logger.info(f"Saved {n_patches} ({n_patches_by_mode}) patches to {train_data_dir}")