-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #74 from awi-response/refactor/native-pipes
Refactor native
- Loading branch information
Showing
9 changed files
with
574 additions
and
625 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
"""Legacy pipeline module.""" | ||
|
||
from darts.legacy_pipeline.planet import run_native_planet_pipeline as run_native_planet_pipeline | ||
from darts.legacy_pipeline.planet_fast import run_native_planet_pipeline_fast as run_native_planet_pipeline_fast | ||
from darts.legacy_pipeline.s2 import run_native_sentinel2_pipeline as run_native_sentinel2_pipeline | ||
from darts.legacy_pipeline.s2_fast import run_native_sentinel2_pipeline_fast as run_native_sentinel2_pipeline_fast |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
import logging | ||
import multiprocessing as mp | ||
from collections import namedtuple | ||
from collections.abc import Generator | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Literal | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
AquisitionData = namedtuple("AquisitionData", ["optical", "arcticdem", "tcvis", "data_masks"]) | ||
|
||
|
||
@dataclass | ||
class _BasePipeline: | ||
"""Base class for all legacy pipelines. | ||
This class provides the run method which is the main entry point for all pipelines. | ||
This class is meant to be subclassed by the specific pipelines. | ||
These specific pipelines must implement the following methods: | ||
- "_path_generator" which generates the paths to the data (e.g. through Source Mixin) | ||
- "_get_data" which loads the data for a given path | ||
- "_preprocess" which preprocesses the data (e.g. through Processing Mixin) | ||
It is possible to implement these functions, by subclassing other mixins, e.g. _S2Mixin. | ||
The main class must be also a dataclass, to fully inherit all parameter of this class (and the mixins). | ||
""" | ||
|
||
output_data_dir: Path = Path("data/output") | ||
tcvis_dir: Path = Path("data/download/tcvis") | ||
model_dir: Path = Path("models") | ||
tcvis_model_name: str = "RTS_v6_tcvis_s2native.pt" | ||
notcvis_model_name: str = "RTS_v6_notcvis_s2native.pt" | ||
device: Literal["cuda", "cpu", "auto"] | int | None = None | ||
ee_project: str | None = None | ||
ee_use_highvolume: bool = True | ||
patch_size: int = 1024 | ||
overlap: int = 256 | ||
batch_size: int = 8 | ||
reflection: int = 0 | ||
binarization_threshold: float = 0.5 | ||
mask_erosion_size: int = 10 | ||
min_object_size: int = 32 | ||
use_quality_mask: bool = False | ||
write_model_outputs: bool = False | ||
|
||
def _path_generator(self) -> Generator[tuple[Path, Path]]: | ||
raise NotImplementedError | ||
|
||
def _get_data(self, fpath: Path) -> AquisitionData: | ||
raise NotImplementedError | ||
|
||
def _preprocess(self, aqdata: AquisitionData): | ||
raise NotImplementedError | ||
|
||
def run(self): | ||
import torch | ||
from darts_ensemble.ensemble_v1 import EnsembleV1 | ||
from darts_export.inference import InferenceResultWriter | ||
from darts_postprocessing import prepare_export | ||
from dask.distributed import Client, LocalCluster | ||
from odc.stac import configure_rio | ||
|
||
from darts.utils.cuda import debug_info, decide_device | ||
from darts.utils.earthengine import init_ee | ||
|
||
debug_info() | ||
self.device = decide_device(self.device) | ||
init_ee(self.ee_project, self.ee_use_highvolume) | ||
|
||
ensemble = EnsembleV1( | ||
self.model_dir / self.tcvis_model_name, | ||
self.model_dir / self.notcvis_model_name, | ||
device=torch.device(self.device), | ||
) | ||
|
||
# Init Dask stuff with a context manager | ||
with LocalCluster(n_workers=mp.cpu_count() - 1) as cluster, Client(cluster) as client: | ||
logger.info(f"Using Dask client: {client}") | ||
configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client) | ||
logger.info("Configured Rasterio with Dask") | ||
|
||
# Iterate over all the data (_path_generator) | ||
for fpath, outpath in self._path_generator(): | ||
try: | ||
aqdata = self._get_data(fpath) | ||
tile = self._preprocess(aqdata) | ||
|
||
tile = ensemble.segment_tile( | ||
tile, | ||
patch_size=self.patch_size, | ||
overlap=self.overlap, | ||
batch_size=self.batch_size, | ||
reflection=self.reflection, | ||
keep_inputs=self.write_model_outputs, | ||
) | ||
tile = prepare_export( | ||
tile, | ||
self.binarization_threshold, | ||
self.mask_erosion_size, | ||
self.min_object_size, | ||
self.use_quality_mask, | ||
self.device, | ||
) | ||
|
||
outpath.mkdir(parents=True, exist_ok=True) | ||
writer = InferenceResultWriter(tile) | ||
writer.export_probabilities(outpath) | ||
writer.export_binarized(outpath) | ||
writer.export_polygonized(outpath) | ||
except KeyboardInterrupt: | ||
logger.warning("Keyboard interrupt detected.\nExiting...") | ||
break | ||
except Exception as e: | ||
logger.warning(f"Could not process folder '{fpath.resolve()}'.\nSkipping...") | ||
logger.exception(e) | ||
|
||
|
||
# ============================================================================= | ||
# Processing mixins (they provide _preprocess method) | ||
# ============================================================================= | ||
@dataclass | ||
class _VRTMixin: | ||
arcticdem_slope_vrt: Path = Path("data/input/ArcticDEM/slope.vrt") | ||
arcticdem_elevation_vrt: Path = Path("data/input/ArcticDEM/elevation.vrt") | ||
|
||
def _preprocess(self, aqdata: AquisitionData): | ||
from darts_preprocessing import preprocess_legacy | ||
|
||
return preprocess_legacy(aqdata.optical, aqdata.arcticdem, aqdata.tcvis, aqdata.data_masks) | ||
|
||
|
||
@dataclass | ||
class _FastMixin: | ||
arcticdem_dir: Path = Path("data/download/arcticdem") | ||
tpi_outer_radius: int = 100 | ||
tpi_inner_radius: int = 0 | ||
|
||
def _preprocess(self, aqdata: AquisitionData): | ||
from darts_preprocessing import preprocess_legacy_fast | ||
|
||
return preprocess_legacy_fast( | ||
aqdata.optical, | ||
aqdata.arcticdem, | ||
aqdata.tcvis, | ||
aqdata.data_masks, | ||
self.tpi_outer_radius, | ||
self.tpi_inner_radius, | ||
self.device, | ||
) | ||
|
||
|
||
# ============================================================================= | ||
# Source mixins (they provide _path_generator method) | ||
# ============================================================================= | ||
@dataclass | ||
class _PlanetMixin: | ||
orthotiles_dir: Path = Path("data/input/planet/PSOrthoTile") | ||
scenes_dir: Path = Path("data/input/planet/PSScene") | ||
|
||
def _path_generator(self): | ||
# Find all PlanetScope orthotiles | ||
for fpath in self.orthotiles_dir.glob("*/*/"): | ||
tile_id = fpath.parent.name | ||
scene_id = fpath.name | ||
outpath = self.output_data_dir / tile_id / scene_id | ||
yield fpath, outpath | ||
|
||
# Find all PlanetScope scenes | ||
for fpath in self.scenes_dir.glob("*/"): | ||
scene_id = fpath.name | ||
outpath = self.output_data_dir / scene_id | ||
yield fpath, outpath | ||
|
||
|
||
@dataclass | ||
class _S2Mixin: | ||
sentinel2_dir: Path = Path("data/input/sentinel2") | ||
|
||
def _path_generator(self): | ||
for fpath in self.sentinel2_dir.glob("*/"): | ||
scene_id = fpath.name | ||
outpath = self.output_data_dir / scene_id | ||
yield fpath, outpath |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
"""Legacy pipeline for Planet data.""" | ||
|
||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Annotated | ||
|
||
from cyclopts import Parameter | ||
|
||
from darts.legacy_pipeline._base import AquisitionData, _BasePipeline, _PlanetMixin, _VRTMixin | ||
|
||
|
||
@dataclass | ||
class LegacyNativePlanetPipeline(_PlanetMixin, _VRTMixin, _BasePipeline): | ||
"""Pipeline for Planet data. | ||
Args: | ||
orthotiles_dir (Path): The directory containing the PlanetScope orthotiles. | ||
Defaults to Path("data/input/planet/PSOrthoTile"). | ||
scenes_dir (Path): The directory containing the PlanetScope scenes. | ||
Defaults to Path("data/input/planet/PSScene"). | ||
output_data_dir (Path): The "output" directory. Defaults to Path("data/output"). | ||
arcticdem_slope_vrt (Path): The path to the ArcticDEM slope VRT file. | ||
Defaults to Path("data/input/ArcticDEM/slope.vrt"). | ||
arcticdem_elevation_vrt (Path): The path to the ArcticDEM elevation VRT file. | ||
Defaults to Path("data/input/ArcticDEM/elevation.vrt"). | ||
tcvis_dir (Path): The directory containing the TCVis data. Defaults to Path("data/download/tcvis"). | ||
model_dir (Path): The path to the models to use for segmentation. Defaults to Path("models"). | ||
tcvis_model_name (str, optional): The name of the model to use for TCVis. Defaults to "RTS_v6_tcvis.pt". | ||
notcvis_model_name (str, optional): The name of the model to use for not TCVis. Defaults to "RTS_v6_notcvis.pt". | ||
device (Literal["cuda", "cpu"] | int, optional): The device to run the model on. | ||
If "cuda" take the first device (0), if int take the specified device. | ||
If "auto" try to automatically select a free GPU (<50% memory usage). | ||
Defaults to "cuda" if available, else "cpu". | ||
ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if | ||
project is defined within persistent API credentials obtained via `earthengine authenticate`. | ||
ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com). | ||
patch_size (int, optional): The patch size to use for inference. Defaults to 1024. | ||
overlap (int, optional): The overlap to use for inference. Defaults to 16. | ||
batch_size (int, optional): The batch size to use for inference. Defaults to 8. | ||
reflection (int, optional): The reflection padding to use for inference. Defaults to 0. | ||
binarization_threshold (float, optional): The threshold to binarize the probabilities. Defaults to 0.5. | ||
mask_erosion_size (int, optional): The size of the disk to use for mask erosion and the edge-cropping. | ||
Defaults to 10. | ||
min_object_size (int, optional): The minimum object size to keep in pixel. Defaults to 32. | ||
use_quality_mask (bool, optional): Whether to use the "quality" mask instead of the "valid" mask | ||
to mask the output. | ||
write_model_outputs (bool, optional): Also save the model outputs, not only the ensemble result. | ||
Defaults to False. | ||
Examples: | ||
### PS Orthotile | ||
Data directory structure: | ||
```sh | ||
data/input | ||
├── ArcticDEM | ||
│ ├── elevation.vrt | ||
│ ├── slope.vrt | ||
│ ├── relative_elevation | ||
│ │ └── 4372514_relative_elevation_100.tif | ||
│ └── slope | ||
│ └── 4372514_slope.tif | ||
└── planet | ||
└── PSOrthoTile | ||
└── 4372514/5790392_4372514_2022-07-16_2459 | ||
├── 5790392_4372514_2022-07-16_2459_BGRN_Analytic_metadata.xml | ||
├── 5790392_4372514_2022-07-16_2459_BGRN_DN_udm.tif | ||
├── 5790392_4372514_2022-07-16_2459_BGRN_SR.tif | ||
├── 5790392_4372514_2022-07-16_2459_metadata.json | ||
└── 5790392_4372514_2022-07-16_2459_udm2.tif | ||
``` | ||
then the config should be | ||
``` | ||
... | ||
orthotiles_dir: data/input/planet/PSOrthoTile | ||
arcticdem_slope_vrt: data/input/ArcticDEM/slope.vrt | ||
arcticdem_elevation_vrt: data/input/ArcticDEM/elevation.vrt | ||
``` | ||
### PS Scene | ||
Data directory structure: | ||
```sh | ||
data/input | ||
├── ArcticDEM | ||
│ ├── elevation.vrt | ||
│ ├── slope.vrt | ||
│ ├── relative_elevation | ||
│ │ └── 4372514_relative_elevation_100.tif | ||
│ └── slope | ||
│ └── 4372514_slope.tif | ||
└── planet | ||
└── PSScene | ||
└── 20230703_194241_43_2427 | ||
├── 20230703_194241_43_2427_3B_AnalyticMS_metadata.xml | ||
├── 20230703_194241_43_2427_3B_AnalyticMS_SR.tif | ||
├── 20230703_194241_43_2427_3B_udm2.tif | ||
├── 20230703_194241_43_2427_metadata.json | ||
└── 20230703_194241_43_2427.json | ||
``` | ||
then the config should be | ||
``` | ||
... | ||
scenes_dir: data/input/planet/PSScene | ||
arcticdem_slope_vrt: data/input/ArcticDEM/slope.vrt | ||
arcticdem_elevation_vrt: data/input/ArcticDEM/elevation.vrt | ||
``` | ||
""" | ||
|
||
def _get_data(self, fpath: Path): | ||
from darts_acquisition.arcticdem import load_arcticdem_from_vrt | ||
from darts_acquisition.planet import load_planet_masks, load_planet_scene | ||
from darts_acquisition.tcvis import load_tcvis | ||
|
||
optical = load_planet_scene(fpath) | ||
arcticdem = load_arcticdem_from_vrt(self.arcticdem_slope_vrt, self.arcticdem_elevation_vrt, optical) | ||
tcvis = load_tcvis(optical.odc.geobox, self.tcvis_dir) | ||
data_masks = load_planet_masks(fpath) | ||
aqdata = AquisitionData(optical, arcticdem, tcvis, data_masks) | ||
return aqdata | ||
|
||
|
||
def run_native_planet_pipeline(*, pipeline: Annotated[LegacyNativePlanetPipeline, Parameter("*")]): | ||
"""Search for all PlanetScope scenes in the given directory and runs the segmentation pipeline on them.""" | ||
pipeline.run() |
Oops, something went wrong.