Skip to content

Commit

Permalink
Merge pull request #74 from awi-response/refactor/native-pipes
Browse files Browse the repository at this point in the history
Refactor native
  • Loading branch information
relativityhd authored Dec 13, 2024
2 parents ff615a1 + ea19c05 commit 81f069b
Show file tree
Hide file tree
Showing 9 changed files with 574 additions and 625 deletions.
3 changes: 0 additions & 3 deletions darts/src/darts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,4 @@

from importlib.metadata import version

from darts.native import run_native_planet_pipeline as run_native_planet_pipeline
from darts.native import run_native_sentinel2_pipeline as run_native_sentinel2_pipeline

__version__ = version("darts-nextgen")
6 changes: 3 additions & 3 deletions darts/src/darts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rich.console import Console

from darts import __version__
from darts.native import (
from darts.legacy_pipeline import (
run_native_planet_pipeline,
run_native_planet_pipeline_fast,
run_native_sentinel2_pipeline,
Expand All @@ -28,8 +28,8 @@
version=__version__,
console=console,
config=config_parser,
help_format="rich",
version_format="rich",
help_format="plaintext",
version_format="plaintext",
)

pipeline_group = cyclopts.Group.create_ordered("Pipeline Commands")
Expand Down
6 changes: 6 additions & 0 deletions darts/src/darts/legacy_pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Legacy pipeline module."""

from darts.legacy_pipeline.planet import run_native_planet_pipeline as run_native_planet_pipeline
from darts.legacy_pipeline.planet_fast import run_native_planet_pipeline_fast as run_native_planet_pipeline_fast
from darts.legacy_pipeline.s2 import run_native_sentinel2_pipeline as run_native_sentinel2_pipeline
from darts.legacy_pipeline.s2_fast import run_native_sentinel2_pipeline_fast as run_native_sentinel2_pipeline_fast
187 changes: 187 additions & 0 deletions darts/src/darts/legacy_pipeline/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import logging
import multiprocessing as mp
from collections import namedtuple
from collections.abc import Generator
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

logger = logging.getLogger(__name__)

AquisitionData = namedtuple("AquisitionData", ["optical", "arcticdem", "tcvis", "data_masks"])


@dataclass
class _BasePipeline:
"""Base class for all legacy pipelines.
This class provides the run method which is the main entry point for all pipelines.
This class is meant to be subclassed by the specific pipelines.
These specific pipelines must implement the following methods:
- "_path_generator" which generates the paths to the data (e.g. through Source Mixin)
- "_get_data" which loads the data for a given path
- "_preprocess" which preprocesses the data (e.g. through Processing Mixin)
It is possible to implement these functions, by subclassing other mixins, e.g. _S2Mixin.
The main class must be also a dataclass, to fully inherit all parameter of this class (and the mixins).
"""

output_data_dir: Path = Path("data/output")
tcvis_dir: Path = Path("data/download/tcvis")
model_dir: Path = Path("models")
tcvis_model_name: str = "RTS_v6_tcvis_s2native.pt"
notcvis_model_name: str = "RTS_v6_notcvis_s2native.pt"
device: Literal["cuda", "cpu", "auto"] | int | None = None
ee_project: str | None = None
ee_use_highvolume: bool = True
patch_size: int = 1024
overlap: int = 256
batch_size: int = 8
reflection: int = 0
binarization_threshold: float = 0.5
mask_erosion_size: int = 10
min_object_size: int = 32
use_quality_mask: bool = False
write_model_outputs: bool = False

def _path_generator(self) -> Generator[tuple[Path, Path]]:
raise NotImplementedError

def _get_data(self, fpath: Path) -> AquisitionData:
raise NotImplementedError

def _preprocess(self, aqdata: AquisitionData):
raise NotImplementedError

def run(self):
import torch
from darts_ensemble.ensemble_v1 import EnsembleV1
from darts_export.inference import InferenceResultWriter
from darts_postprocessing import prepare_export
from dask.distributed import Client, LocalCluster
from odc.stac import configure_rio

from darts.utils.cuda import debug_info, decide_device
from darts.utils.earthengine import init_ee

debug_info()
self.device = decide_device(self.device)
init_ee(self.ee_project, self.ee_use_highvolume)

ensemble = EnsembleV1(
self.model_dir / self.tcvis_model_name,
self.model_dir / self.notcvis_model_name,
device=torch.device(self.device),
)

# Init Dask stuff with a context manager
with LocalCluster(n_workers=mp.cpu_count() - 1) as cluster, Client(cluster) as client:
logger.info(f"Using Dask client: {client}")
configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client)
logger.info("Configured Rasterio with Dask")

# Iterate over all the data (_path_generator)
for fpath, outpath in self._path_generator():
try:
aqdata = self._get_data(fpath)
tile = self._preprocess(aqdata)

tile = ensemble.segment_tile(
tile,
patch_size=self.patch_size,
overlap=self.overlap,
batch_size=self.batch_size,
reflection=self.reflection,
keep_inputs=self.write_model_outputs,
)
tile = prepare_export(
tile,
self.binarization_threshold,
self.mask_erosion_size,
self.min_object_size,
self.use_quality_mask,
self.device,
)

outpath.mkdir(parents=True, exist_ok=True)
writer = InferenceResultWriter(tile)
writer.export_probabilities(outpath)
writer.export_binarized(outpath)
writer.export_polygonized(outpath)
except KeyboardInterrupt:
logger.warning("Keyboard interrupt detected.\nExiting...")
break
except Exception as e:
logger.warning(f"Could not process folder '{fpath.resolve()}'.\nSkipping...")
logger.exception(e)


# =============================================================================
# Processing mixins (they provide _preprocess method)
# =============================================================================
@dataclass
class _VRTMixin:
arcticdem_slope_vrt: Path = Path("data/input/ArcticDEM/slope.vrt")
arcticdem_elevation_vrt: Path = Path("data/input/ArcticDEM/elevation.vrt")

def _preprocess(self, aqdata: AquisitionData):
from darts_preprocessing import preprocess_legacy

return preprocess_legacy(aqdata.optical, aqdata.arcticdem, aqdata.tcvis, aqdata.data_masks)


@dataclass
class _FastMixin:
arcticdem_dir: Path = Path("data/download/arcticdem")
tpi_outer_radius: int = 100
tpi_inner_radius: int = 0

def _preprocess(self, aqdata: AquisitionData):
from darts_preprocessing import preprocess_legacy_fast

return preprocess_legacy_fast(
aqdata.optical,
aqdata.arcticdem,
aqdata.tcvis,
aqdata.data_masks,
self.tpi_outer_radius,
self.tpi_inner_radius,
self.device,
)


# =============================================================================
# Source mixins (they provide _path_generator method)
# =============================================================================
@dataclass
class _PlanetMixin:
orthotiles_dir: Path = Path("data/input/planet/PSOrthoTile")
scenes_dir: Path = Path("data/input/planet/PSScene")

def _path_generator(self):
# Find all PlanetScope orthotiles
for fpath in self.orthotiles_dir.glob("*/*/"):
tile_id = fpath.parent.name
scene_id = fpath.name
outpath = self.output_data_dir / tile_id / scene_id
yield fpath, outpath

# Find all PlanetScope scenes
for fpath in self.scenes_dir.glob("*/"):
scene_id = fpath.name
outpath = self.output_data_dir / scene_id
yield fpath, outpath


@dataclass
class _S2Mixin:
sentinel2_dir: Path = Path("data/input/sentinel2")

def _path_generator(self):
for fpath in self.sentinel2_dir.glob("*/"):
scene_id = fpath.name
outpath = self.output_data_dir / scene_id
yield fpath, outpath
132 changes: 132 additions & 0 deletions darts/src/darts/legacy_pipeline/planet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""Legacy pipeline for Planet data."""

from dataclasses import dataclass
from pathlib import Path
from typing import Annotated

from cyclopts import Parameter

from darts.legacy_pipeline._base import AquisitionData, _BasePipeline, _PlanetMixin, _VRTMixin


@dataclass
class LegacyNativePlanetPipeline(_PlanetMixin, _VRTMixin, _BasePipeline):
"""Pipeline for Planet data.
Args:
orthotiles_dir (Path): The directory containing the PlanetScope orthotiles.
Defaults to Path("data/input/planet/PSOrthoTile").
scenes_dir (Path): The directory containing the PlanetScope scenes.
Defaults to Path("data/input/planet/PSScene").
output_data_dir (Path): The "output" directory. Defaults to Path("data/output").
arcticdem_slope_vrt (Path): The path to the ArcticDEM slope VRT file.
Defaults to Path("data/input/ArcticDEM/slope.vrt").
arcticdem_elevation_vrt (Path): The path to the ArcticDEM elevation VRT file.
Defaults to Path("data/input/ArcticDEM/elevation.vrt").
tcvis_dir (Path): The directory containing the TCVis data. Defaults to Path("data/download/tcvis").
model_dir (Path): The path to the models to use for segmentation. Defaults to Path("models").
tcvis_model_name (str, optional): The name of the model to use for TCVis. Defaults to "RTS_v6_tcvis.pt".
notcvis_model_name (str, optional): The name of the model to use for not TCVis. Defaults to "RTS_v6_notcvis.pt".
device (Literal["cuda", "cpu"] | int, optional): The device to run the model on.
If "cuda" take the first device (0), if int take the specified device.
If "auto" try to automatically select a free GPU (<50% memory usage).
Defaults to "cuda" if available, else "cpu".
ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if
project is defined within persistent API credentials obtained via `earthengine authenticate`.
ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).
patch_size (int, optional): The patch size to use for inference. Defaults to 1024.
overlap (int, optional): The overlap to use for inference. Defaults to 16.
batch_size (int, optional): The batch size to use for inference. Defaults to 8.
reflection (int, optional): The reflection padding to use for inference. Defaults to 0.
binarization_threshold (float, optional): The threshold to binarize the probabilities. Defaults to 0.5.
mask_erosion_size (int, optional): The size of the disk to use for mask erosion and the edge-cropping.
Defaults to 10.
min_object_size (int, optional): The minimum object size to keep in pixel. Defaults to 32.
use_quality_mask (bool, optional): Whether to use the "quality" mask instead of the "valid" mask
to mask the output.
write_model_outputs (bool, optional): Also save the model outputs, not only the ensemble result.
Defaults to False.
Examples:
### PS Orthotile
Data directory structure:
```sh
data/input
├── ArcticDEM
│ ├── elevation.vrt
│ ├── slope.vrt
│ ├── relative_elevation
│ │ └── 4372514_relative_elevation_100.tif
│ └── slope
│ └── 4372514_slope.tif
└── planet
└── PSOrthoTile
└── 4372514/5790392_4372514_2022-07-16_2459
├── 5790392_4372514_2022-07-16_2459_BGRN_Analytic_metadata.xml
├── 5790392_4372514_2022-07-16_2459_BGRN_DN_udm.tif
├── 5790392_4372514_2022-07-16_2459_BGRN_SR.tif
├── 5790392_4372514_2022-07-16_2459_metadata.json
└── 5790392_4372514_2022-07-16_2459_udm2.tif
```
then the config should be
```
...
orthotiles_dir: data/input/planet/PSOrthoTile
arcticdem_slope_vrt: data/input/ArcticDEM/slope.vrt
arcticdem_elevation_vrt: data/input/ArcticDEM/elevation.vrt
```
### PS Scene
Data directory structure:
```sh
data/input
├── ArcticDEM
│ ├── elevation.vrt
│ ├── slope.vrt
│ ├── relative_elevation
│ │ └── 4372514_relative_elevation_100.tif
│ └── slope
│ └── 4372514_slope.tif
└── planet
└── PSScene
└── 20230703_194241_43_2427
├── 20230703_194241_43_2427_3B_AnalyticMS_metadata.xml
├── 20230703_194241_43_2427_3B_AnalyticMS_SR.tif
├── 20230703_194241_43_2427_3B_udm2.tif
├── 20230703_194241_43_2427_metadata.json
└── 20230703_194241_43_2427.json
```
then the config should be
```
...
scenes_dir: data/input/planet/PSScene
arcticdem_slope_vrt: data/input/ArcticDEM/slope.vrt
arcticdem_elevation_vrt: data/input/ArcticDEM/elevation.vrt
```
"""

def _get_data(self, fpath: Path):
from darts_acquisition.arcticdem import load_arcticdem_from_vrt
from darts_acquisition.planet import load_planet_masks, load_planet_scene
from darts_acquisition.tcvis import load_tcvis

optical = load_planet_scene(fpath)
arcticdem = load_arcticdem_from_vrt(self.arcticdem_slope_vrt, self.arcticdem_elevation_vrt, optical)
tcvis = load_tcvis(optical.odc.geobox, self.tcvis_dir)
data_masks = load_planet_masks(fpath)
aqdata = AquisitionData(optical, arcticdem, tcvis, data_masks)
return aqdata


def run_native_planet_pipeline(*, pipeline: Annotated[LegacyNativePlanetPipeline, Parameter("*")]):
"""Search for all PlanetScope scenes in the given directory and runs the segmentation pipeline on them."""
pipeline.run()
Loading

0 comments on commit 81f069b

Please sign in to comment.