From 6d62f9d860e2c9ea92eb00b2268c8b8b88ed7d21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Mon, 16 Dec 2024 23:01:04 +0100 Subject: [PATCH] Make datacubes mp-safe --- .../darts_acquisition/arcticdem/datacube.py | 25 +++++++++---------- .../src/darts_acquisition/tcvis.py | 7 +++++- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/darts-acquisition/src/darts_acquisition/arcticdem/datacube.py b/darts-acquisition/src/darts_acquisition/arcticdem/datacube.py index 962685a..c261cb5 100644 --- a/darts-acquisition/src/darts_acquisition/arcticdem/datacube.py +++ b/darts-acquisition/src/darts_acquisition/arcticdem/datacube.py @@ -2,6 +2,7 @@ import io import logging +import multiprocessing as mp import time import zipfile from pathlib import Path @@ -66,6 +67,11 @@ # "datamask": {"dtype": "bool", "_FillValue": False}, } +# Lock for downloading the data +# This will block other processes from downloading or processing the data, until the current download is finished. +# This may result in halting a tile-process for a while, even if it's data was already downloaded. +download_lock = mp.Lock() + def download_arcticdem_extent(dem_data_dir: Path): """Download the ArcticDEM mosaic extent info from the provided URL and extracts it to the specified directory. @@ -173,10 +179,6 @@ def procedural_download_datacube(storage: zarr.storage.Store, tiles: gpd.GeoData arcticdem_datacube = xr.open_zarr(storage, mask_and_scale=False) loaded_tiles: list[str] = arcticdem_datacube.attrs.get("loaded_tiles", []).copy() - # TODO: Add another attribute called "loading_tiles" to keep track of the tiles that are currently being loaded - # This is necessary for parallel loading of the data - # Maybe it would be better to turn this into a class which is meant to be used as singleton and can store - # the loaded tiles as state # Collect all tiles which should be downloaded new_tiles = tiles[~tiles.dem_id.isin(loaded_tiles)] @@ -270,14 +272,9 @@ def load_arcticdem_tile( Final dataset is NOT matched to the reference CRS and resolution. Warning: - 1. This function is not thread-safe. Thread-safety might be added in the future. - 2. Geobox must be in a meter based CRS. + Geobox must be in a meter based CRS. """ - # TODO: Thread-safety concers: - # - How can we ensure that the same arcticdem tile is not downloaded twice at the same time? - # - How can we ensure that the extent is not downloaded twice at the same time? - tick_fstart = time.perf_counter() datacube_fpath = data_dir / f"datacube_{resolution}m_v4.1.zarr" @@ -305,8 +302,9 @@ def load_arcticdem_tile( # because of the network overhead, hence we use the extent file # Download the extent, download if the file does not exist extent_fpath = data_dir / f"ArcticDEM_Mosaic_Index_v4_1_{resolution}m.parquet" - if not extent_fpath.exists(): - download_arcticdem_extent(data_dir) + with download_lock: + if not extent_fpath.exists(): + download_arcticdem_extent(data_dir) extent = gpd.read_parquet(extent_fpath) # Add a buffer around the geobox to get the adjacent tiles @@ -314,7 +312,8 @@ def load_arcticdem_tile( adjacent_tiles = extent[extent.intersects(reference_geobox.extent.geom)] # Download the adjacent tiles (if necessary) - procedural_download_datacube(storage, adjacent_tiles) + with download_lock: + procedural_download_datacube(storage, adjacent_tiles) # Load the datacube and set the spatial_ref since it is set as a coordinate within the zarr format chunks = None if persist else "auto" diff --git a/darts-acquisition/src/darts_acquisition/tcvis.py b/darts-acquisition/src/darts_acquisition/tcvis.py index 9ba687d..e9d1584 100644 --- a/darts-acquisition/src/darts_acquisition/tcvis.py +++ b/darts-acquisition/src/darts_acquisition/tcvis.py @@ -1,6 +1,7 @@ """Landsat Trends related Data Loading. Should be used temporary and maybe moved to the acquisition package.""" import logging +import multiprocessing as mp import time import warnings from pathlib import Path @@ -43,6 +44,9 @@ "tc_wetness": {"dtype": "uint8"}, } +# Lock for downloading the data +download_lock = mp.Lock() + def procedural_download_datacube(storage: zarr.storage.Store, geobox: GeoBox): """Download the TCVIS data procedurally and add it to the datacube. @@ -195,7 +199,8 @@ def load_tcvis( # Download the adjacent tiles (if necessary) reference_geobox = geobox.to_crs("epsg:4326", resolution=DATA_EXTENT.resolution.x).pad(buffer) - procedural_download_datacube(storage, reference_geobox) + with download_lock: + procedural_download_datacube(storage, reference_geobox) # Load the datacube and set the spatial_ref since it is set as a coordinate within the zarr format chunks = None if persist else "auto"