diff --git a/darts-export/pyproject.toml b/darts-export/pyproject.toml index 47f5c4f..b05d305 100644 --- a/darts-export/pyproject.toml +++ b/darts-export/pyproject.toml @@ -4,11 +4,12 @@ version = "0.1.0" description = "Dataset export for the DARTS dataset." dependencies = [ "numpy>=1.26.3, <2", - "xarray>=2024.9.0", + "xarray>=2024", "rasterio>=1.4.0", "rioxarray>=0.17.0", "h5netcdf>=1.3.0", "geopandas>=1.0.1", + "pytest" ] readme = "README.md" requires-python = ">= 3.11" diff --git a/darts-export/src/darts_export/__init__.py b/darts-export/src/darts_export/__init__.py index 25ed7d0..e69de29 100644 --- a/darts-export/src/darts_export/__init__.py +++ b/darts-export/src/darts_export/__init__.py @@ -1,11 +0,0 @@ -"""Dataset export for the DARTS dataset.""" - - -def hello() -> str: - """Say hello to the user. - - Returns: - str: Greating message. - - """ - return "Hello from darts-export!" diff --git a/darts-export/src/darts_export/inference.py b/darts-export/src/darts_export/inference.py new file mode 100755 index 0000000..3af0c76 --- /dev/null +++ b/darts-export/src/darts_export/inference.py @@ -0,0 +1,22 @@ +from pathlib import Path +import xarray +import rioxarray + +class InferenceResultWriter: + + def __init__(self, ds) -> None: + self.ds:xarray.Dataset = ds + + def export_probabilities(self, path:Path, filename="pred_probabilities.tif", tags=dict()): + + # write the probability layer from the raster to a GeoTiff + self.ds.probabilities.rio.to_raster(path / filename, driver="GTiff", tags=tags, compress="LZW") + + def export_binarized(self, path:Path, filename="pred_binarized.tif", tags=dict()): + + self.ds.binarized_segmentation.rio.to_raster(path / filename, driver="GTiff", tags=tags, compress="LZW") + + + def export_vectors(self, path:Path, filename_prefix="pred_segments"): + pass + diff --git a/darts-export/test/conftest.py b/darts-export/test/conftest.py new file mode 100755 index 0000000..36f4924 --- /dev/null +++ b/darts-export/test/conftest.py @@ -0,0 +1,15 @@ +import pytest +import xarray, numpy as np +import rioxarray + +@pytest.fixture +def probabilities(): + + xds = xarray.DataArray(np.randint(0, 100, 16**2), + dims=("y", "x"), + coords={"x": range(500000, 500016), "y": range(4500000, 4500016)}, + ) + xds.rio.write_crs("EPSG:32601", inplace=True) + xds.rio.set_spatial_dims(x_dim='x', y_dim='y', inplace=True) + xds.rio.write_transform(inplace=True) + return xds \ No newline at end of file diff --git a/darts-export/test/test_inference.py b/darts-export/test/test_inference.py new file mode 100755 index 0000000..4726bd2 --- /dev/null +++ b/darts-export/test/test_inference.py @@ -0,0 +1,5 @@ +#import + + +def test_writeProbabilities(probabilities): + pass \ No newline at end of file