Commit b1c0c1d

Add train/val datasplitting to preprocessing

relativityhd committed Dec 21, 2024
1 parent 6c55265 commit b1c0c1d

Showing 10 changed files with 298 additions and 271 deletions.
@@ -12,6 +12,7 @@

if has_cuda_and_cupy():
import cupy as cp
import cupy_xarray # noqa: F401


def calculate_topographic_position_index(arcticdem_ds: xr.Dataset, outer_radius: int, inner_radius: int) -> xr.Dataset:
@@ -52,7 +53,7 @@ def calculate_topographic_position_index(arcticdem_ds: xr.Dataset, outer_radius: int, inner_radius: int) -> xr.Dataset:
f"Calculating Topographic Position Index with circle kernel of {outer_radius_px} ({outer_radius_m}) cells."
)

if arcticdem_ds.cupy.is_cupy:
if has_cuda_and_cupy() and arcticdem_ds.cupy.is_cupy:
kernel = cp.asarray(kernel)

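# TPI: each cell's elevation minus the mean elevation of its annulus neighborhood (the convolution normalized by the kernel sum).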
tpi = arcticdem_ds.dem - convolution.convolution_2d(arcticdem_ds.dem, kernel) / kernel.sum()
2 changes: 1 addition & 1 deletion darts/src/darts/cli.py
@@ -15,7 +15,7 @@
run_native_sentinel2_pipeline,
run_native_sentinel2_pipeline_fast,
)
from darts.training import convert_lightning_checkpoint, preprocess_s2_train_data, sweep_smp, train_smp
from darts.legacy_training import convert_lightning_checkpoint, preprocess_s2_train_data, sweep_smp, train_smp
from darts.utils.config import ConfigParser
from darts.utils.logging import LoggingManager

1 change: 1 addition & 0 deletions darts/src/darts/legacy_pipeline/planet.py
@@ -31,6 +31,7 @@ class LegacyNativePlanetPipeline(_PlanetMixin, _VRTMixin, _BasePipeline):
If "cuda" take the first device (0), if int take the specified device.
If "auto" try to automatically select a free GPU (<50% memory usage).
Defaults to "cuda" if available, else "cpu".
dask_worker (int, optional): The number of Dask workers to use. Defaults to min(16, mp.cpu_count() - 1).
ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if
project is defined within persistent API credentials obtained via `earthengine authenticate`.
ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).
1 change: 1 addition & 0 deletions darts/src/darts/legacy_pipeline/planet_fast.py
@@ -31,6 +31,7 @@ class LegacyNativePlanetPipelineFast(_FastMixin, _PlanetMixin, _BasePipeline):
If "cuda" take the first device (0), if int take the specified device.
If "auto" try to automatically select a free GPU (<50% memory usage).
Defaults to "cuda" if available, else "cpu".
dask_worker (int, optional): The number of Dask workers to use. Defaults to min(16, mp.cpu_count() - 1).
ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if
project is defined within persistent API credentials obtained via `earthengine authenticate`.
ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).
1 change: 1 addition & 0 deletions darts/src/darts/legacy_pipeline/s2.py
@@ -28,6 +28,7 @@ class LegacyNativeSentinel2Pipeline(_S2Mixin, _VRTMixin, _BasePipeline):
If "cuda" take the first device (0), if int take the specified device.
If "auto" try to automatically select a free GPU (<50% memory usage).
Defaults to "cuda" if available, else "cpu".
dask_worker (int, optional): The number of Dask workers to use. Defaults to min(16, mp.cpu_count() - 1).
ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if
project is defined within persistent API credentials obtained via `earthengine authenticate`.
ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).
1 change: 1 addition & 0 deletions darts/src/darts/legacy_pipeline/s2_fast.py
@@ -29,6 +29,7 @@ class LegacyNativeSentinel2PipelineFast(_FastMixin, _S2Mixin, _BasePipeline):
If "cuda" take the first device (0), if int take the specified device.
If "auto" try to automatically select a free GPU (<50% memory usage).
Defaults to "cuda" if available, else "cpu".
dask_worker (int, optional): The number of Dask workers to use. Defaults to min(16, mp.cpu_count() - 1).
ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if
project is defined within persistent API credentials obtained via `earthengine authenticate`.
ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).
6 changes: 6 additions & 0 deletions darts/src/darts/legacy_training/__init__.py
@@ -0,0 +1,6 @@
"""Legacy training module for DARTS."""

from darts.legacy_training.preprocess import preprocess_s2_train_data as preprocess_s2_train_data
from darts.legacy_training.train import sweep_smp as sweep_smp
from darts.legacy_training.train import train_smp as train_smp
from darts.legacy_training.util import convert_lightning_checkpoint as convert_lightning_checkpoint
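Note: the `import name as name` form marks each symbol as an explicit re-export, so linters and type checkers do not flag the imports as unused.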
228 changes: 228 additions & 0 deletions darts/src/darts/legacy_training/preprocess.py
@@ -0,0 +1,228 @@
"""Preprocessing functions for legacy training."""

import logging
import multiprocessing as mp
from itertools import chain, repeat
from math import ceil, sqrt
from pathlib import Path
from typing import Literal

import toml

logger = logging.getLogger(__name__)


def preprocess_s2_train_data(
*,
bands: list[str],
sentinel2_dir: Path,
train_data_dir: Path,
arcticdem_dir: Path,
tcvis_dir: Path,
preprocess_cache: Path | None = None,
device: Literal["cuda", "cpu", "auto"] | int | None = None,
dask_worker: int = min(16, mp.cpu_count() - 1),
ee_project: str | None = None,
ee_use_highvolume: bool = True,
tpi_outer_radius: int = 100,
tpi_inner_radius: int = 0,
patch_size: int = 1024,
overlap: int = 16,
exclude_nopositive: bool = False,
exclude_nan: bool = True,
mask_erosion_size: int = 10,
test_val_split: float = 0.05,
):
"""Preprocess Sentinel 2 data for training.

Args:
bands (list[str]): The bands to use for training. Must be present in the preprocessed data.
sentinel2_dir (Path): The directory containing the Sentinel 2 scenes.
train_data_dir (Path): The "output" directory where the tensors are written to.
arcticdem_dir (Path): The directory containing the ArcticDEM data (the datacube and the extent files).
Will be created and downloaded if it does not exist.
tcvis_dir (Path): The directory containing the TCVis data.
preprocess_cache (Path, optional): The directory to store the preprocessed data. Defaults to None.
device (Literal["cuda", "cpu", "auto"] | int, optional): The device to run the preprocessing on.
If "cuda" take the first device (0), if int take the specified device.
If "auto" try to automatically select a free GPU (<50% memory usage).
Defaults to "cuda" if available, else "cpu".
dask_worker (int, optional): The number of Dask workers to use. Defaults to min(16, mp.cpu_count() - 1).
ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if
project is defined within persistent API credentials obtained via `earthengine authenticate`.
ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).
tpi_outer_radius (int, optional): The outer radius of the annulus kernel for the tpi calculation
in m. Defaults to 100m.
tpi_inner_radius (int, optional): The inner radius of the annulus kernel for the tpi calculation
in m. Defaults to 0.
patch_size (int, optional): The size of the patches to create. Defaults to 1024.
overlap (int, optional): The overlap between neighboring patches. Defaults to 16.
exclude_nopositive (bool, optional): Whether to exclude patches where the labels do not contain positives.
Defaults to False.
exclude_nan (bool, optional): Whether to exclude patches where the input data has nan values.
Defaults to True.
mask_erosion_size (int, optional): The size of the disk to use for mask erosion and the edge-cropping.
Defaults to 10.
test_val_split (float, optional): The fraction of the data to hold out as the validation set. Defaults to 0.05.
"""
# Import here to avoid long loading times when running other commands
import geopandas as gpd
import torch
import xarray as xr
from darts_acquisition.arcticdem import load_arcticdem_tile
from darts_acquisition.s2 import load_s2_masks, load_s2_scene
from darts_acquisition.tcvis import load_tcvis
from darts_preprocessing import preprocess_legacy_fast
from darts_segmentation.training.prepare_training import create_training_patches
from dask.distributed import Client, LocalCluster
from lovely_tensors import monkey_patch
from odc.stac import configure_rio
from sklearn.model_selection import train_test_split

from darts.utils.cuda import debug_info, decide_device
from darts.utils.earthengine import init_ee

monkey_patch()
debug_info()
device = decide_device(device)
init_ee(ee_project, ee_use_highvolume)

with LocalCluster(n_workers=dask_worker) as cluster, Client(cluster) as client:
logger.info(f"Using Dask client: {client} on cluster {cluster}")
logger.info(f"Dashboard available at: {client.dashboard_link}")
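# configure_rio registers GDAL cloud-access defaults (incl. unsigned AWS S3 reads) on the Dask workers.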
configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client)
logger.info("Configured Rasterio with Dask")

# We hardcode these because they depend on the preprocessing used
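# Each factor rescales its band into roughly [0, 1] for training (e.g. the optical bands are scaled by 1/3000).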
norm_factors = {
"red": 1 / 3000,
"green": 1 / 3000,
"blue": 1 / 3000,
"nir": 1 / 3000,
"ndvi": 1 / 20000,
"relative_elevation": 1 / 30000,
"slope": 1 / 90,
"tc_brightness": 1 / 255,
"tc_greenness": 1 / 255,
"tc_wetness": 1 / 255,
}
# Filter out bands that are not in the specified bands
norm_factors = {k: v for k, v in norm_factors.items() if k in bands}

train_data_dir.mkdir(exist_ok=True, parents=True)
output_dir_train = train_data_dir / "train"
output_dir_val = train_data_dir / "val"

# Find all Sentinel 2 scenes
n_patches = 0
s2_paths = sorted(sentinel2_dir.glob("*/"))
logger.info(f"Found {len(s2_paths)} Sentinel 2 scenes in {sentinel2_dir}")
train_paths: list[Path]
val_paths: list[Path]
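# Scene-level split with a fixed seed: the same scenes land in train/val on every run.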
train_paths, val_paths = train_test_split(s2_paths, test_size=test_val_split, random_state=42)
logger.info(f"Split the data into {len(train_paths)} training and {len(val_paths)} validation samples.")

fpathgen = chain(train_paths, val_paths)
modegen = chain(repeat(output_dir_train, len(train_paths)), repeat(output_dir_val, len(val_paths)))
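# The loop below walks the train scenes first (written to output_dir_train), then the val scenes (written to output_dir_val).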
for i, (fpath, output_dir) in enumerate(zip(fpathgen, modegen)):
try:
optical = load_s2_scene(fpath)
logger.info(f"Found optical tile with size {optical.sizes}")
tile_id = optical.attrs["tile_id"]

# Check for a cached preprocessed file
if preprocess_cache and (preprocess_cache / f"{tile_id}.nc").exists():
cache_file = preprocess_cache / f"{tile_id}.nc"
logger.info(f"Loading preprocessed data from {cache_file.resolve()}")
tile = xr.open_dataset(cache_file, engine="h5netcdf").set_coords("spatial_ref")
else:
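# Buffer the ArcticDEM tile by the outer TPI radius in pixels (sqrt(2) covers the kernel diagonal) so border pixels keep full neighborhoods.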
arcticdem_res = 10
arcticdem_buffer = ceil(tpi_outer_radius / arcticdem_res * sqrt(2))
arcticdem = load_arcticdem_tile(
optical.odc.geobox, arcticdem_dir, resolution=arcticdem_res, buffer=arcticdem_buffer
)
tcvis = load_tcvis(optical.odc.geobox, tcvis_dir)
data_masks = load_s2_masks(fpath, optical.odc.geobox)

tile: xr.Dataset = preprocess_legacy_fast(
optical,
arcticdem,
tcvis,
data_masks,
tpi_outer_radius,
tpi_inner_radius,
device,
)
# Only cache if we have a cache directory
if preprocess_cache:
preprocess_cache.mkdir(exist_ok=True, parents=True)
cache_file = preprocess_cache / f"{tile_id}.nc"
logger.info(f"Caching preprocessed data to {cache_file.resolve()}")
tile.to_netcdf(cache_file, engine="h5netcdf")

labels = gpd.read_file(fpath / f"{optical.attrs['s2_tile_id']}.shp")

# Save the patches
gen = create_training_patches(
tile,
labels,
bands,
norm_factors,
patch_size,
overlap,
exclude_nopositive,
exclude_nan,
device,
mask_erosion_size,
)

outdir_x = output_dir / "x"
outdir_y = output_dir / "y"
outdir_x.mkdir(exist_ok=True, parents=True)
outdir_y.mkdir(exist_ok=True, parents=True)
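# One tensor pair per patch: x and y are saved as <tile_id>_pid<patch_id>.pt under x/ and y/.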
# Accumulate the global counter; keep a per-tile count for logging.
n_tile_patches = 0
for patch_id, (x, y) in enumerate(gen):
torch.save(x, outdir_x / f"{tile_id}_pid{patch_id}.pt")
torch.save(y, outdir_y / f"{tile_id}_pid{patch_id}.pt")
n_tile_patches += 1
n_patches += n_tile_patches

logger.info(
f"Processed sample {i + 1} of {len(s2_paths)} '{fpath.resolve()}' "
f"({tile_id=}) with {n_tile_patches} patches."
)
except KeyboardInterrupt:
logger.info("Interrupted by user.")
break

except Exception as e:
logger.warning(f"Could not process sample {i + 1} of {len(s2_paths)} '{fpath.resolve()}'. Skipping...")
logger.exception(e)

# Save a config file as toml
config = {
"darts": {
"sentinel2_dir": sentinel2_dir,
"train_data_dir": train_data_dir,
"arcticdem_dir": arcticdem_dir,
"tcvis_dir": tcvis_dir,
"bands": bands,
"norm_factors": norm_factors,
"device": device,
"ee_project": ee_project,
"ee_use_highvolume": ee_use_highvolume,
"tpi_outer_radius": tpi_outer_radius,
"tpi_inner_radius": tpi_inner_radius,
"patch_size": patch_size,
"overlap": overlap,
"exclude_nopositive": exclude_nopositive,
"exclude_nan": exclude_nan,
"n_patches": n_patches,
}
}
with open(train_data_dir / "config.toml", "w") as f:
toml.dump(config, f)

logger.info(f"Saved {n_patches} patches to {train_data_dir}")
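For context, a minimal invocation sketch of the new entry point. This is not part of the commit; the paths and band list are made-up placeholders, and the bands must match what the preprocessing produces:

from pathlib import Path

from darts.legacy_training import preprocess_s2_train_data

preprocess_s2_train_data(
    bands=["red", "green", "blue", "nir"],  # must exist in the preprocessed data
    sentinel2_dir=Path("data/sentinel2"),   # hypothetical input layout
    train_data_dir=Path("data/train"),      # receives train/ and val/ subfolders
    arcticdem_dir=Path("data/arcticdem"),
    tcvis_dir=Path("data/tcvis"),
    test_val_split=0.05,                    # 5% of the scenes go to validation
)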