Skip to content

Commit

Permalink
Add preprocess caching to training preprocessing
Browse files (browse the repository at this point in the history)
  • Loading branch information
relativityhd committed Dec 15, 2024
1 parent f0158d0 commit 96e4493
Showing 1 changed file with 34 additions and 18 deletions.
52 changes: 34 additions & 18 deletions darts/src/darts/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@

def preprocess_s2_train_data(
*,
bands: list[str],
sentinel2_dir: Path,
train_data_dir: Path,
arcticdem_dir: Path,
tcvis_dir: Path,
bands: list[str],
preprocess_cache: Path | None = None,
device: Literal["cuda", "cpu", "auto"] | int | None = None,
ee_project: str | None = None,
ee_use_highvolume: bool = True,
Expand All @@ -48,12 +49,13 @@ def preprocess_s2_train_data(
"""Preprocess Sentinel 2 data for training.
Args:
bands (list[str]): The bands to be used for training. Must be present in the preprocessing.
sentinel2_dir (Path): The directory containing the Sentinel 2 scenes.
train_data_dir (Path): The "output" directory where the tensors are written to.
arcticdem_dir (Path): The directory containing the ArcticDEM data (the datacube and the extent files).
Will be created and downloaded if it does not exist.
tcvis_dir (Path): The directory containing the TCVis data.
bands (list[str]): The bands to be used for training. Must be present in the preprocessing.
preprocess_cache (Path, optional): The directory to store the preprocessed data. Defaults to None.
device (Literal["cuda", "cpu"] | int, optional): The device to run the model on.
If "cuda" take the first device (0), if int take the specified device.
If "auto" try to automatically select a free GPU (<50% memory usage).
Expand All @@ -76,6 +78,7 @@ def preprocess_s2_train_data(
# Import here to avoid long loading times when running other commands
import geopandas as gpd
import torch
import xarray as xr
from darts_acquisition.arcticdem import load_arcticdem_tile
from darts_acquisition.s2 import load_s2_masks, load_s2_scene
from darts_acquisition.tcvis import load_tcvis
Expand Down Expand Up @@ -129,26 +132,39 @@ def preprocess_s2_train_data(
for i, fpath in enumerate(s2_paths):
try:
optical = load_s2_scene(fpath)
arcticdem = load_arcticdem_tile(
optical.odc.geobox, arcticdem_dir, resolution=10, buffer=ceil(tpi_outer_radius / 10 * sqrt(2))
)
tcvis = load_tcvis(optical.odc.geobox, tcvis_dir)
data_masks = load_s2_masks(fpath, optical.odc.geobox)

tile = preprocess_legacy_fast(
optical,
arcticdem,
tcvis,
data_masks,
tpi_outer_radius,
tpi_inner_radius,
device,
)
tile_id = optical.attrs["tile_id"]

# Check for a cached preprocessed file
if preprocess_cache and (preprocess_cache / f"{tile_id}.nc").exists():
cache_file = preprocess_cache / f"{tile_id}.nc"
logger.info(f"Loading preprocessed data from {cache_file.resolve()}")
tile = xr.open_dataset(preprocess_cache / f"{tile_id}.nc", engine="h5netcdf").set_coords("spatial_ref")
else:
arcticdem = load_arcticdem_tile(
optical.odc.geobox, arcticdem_dir, resolution=10, buffer=ceil(tpi_outer_radius / 10 * sqrt(2))
)
tcvis = load_tcvis(optical.odc.geobox, tcvis_dir)
data_masks = load_s2_masks(fpath, optical.odc.geobox)

tile: xr.Dataset = preprocess_legacy_fast(
optical,
arcticdem,
tcvis,
data_masks,
tpi_outer_radius,
tpi_inner_radius,
device,
)
# Only cache if we have a cache directory
if preprocess_cache:
preprocess_cache.mkdir(exist_ok=True, parents=True)
cache_file = preprocess_cache / f"{tile_id}.nc"
logger.info(f"Caching preprocessed data to {cache_file.resolve()}")
tile.to_netcdf(cache_file, engine="h5netcdf")

labels = gpd.read_file(fpath / f"{optical.attrs['s2_tile_id']}.shp")

# Save the patches
tile_id = optical.attrs["tile_id"]
gen = create_training_patches(
tile,
labels,
Expand Down

0 comments on commit 96e4493

Please sign in to comment.