export: add polygonized export +test
iona5 committed Oct 14, 2024
1 parent 5cd7b6d commit 52381ea
Showing 3 changed files with 98 additions and 2 deletions.
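For context, a minimal usage sketch of the new export (illustrative output path; assumes `probabilities` is an xarray Dataset whose `binarized_segmentation` variable carries rioxarray CRS/transform metadata, as in the test fixture):

    from pathlib import Path

    from darts_export import inference

    writer = inference.InferenceResultWriter(probabilities)
    writer.export_polygonized(Path("out"), filename_prefix="pred_segments")
    # writes out/pred_segments.gpkg, plus out/pred_segments.parquet when the
    # GDAL build ships the Parquet driver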
5 changes: 5 additions & 0 deletions darts-export/pyproject.toml
@@ -22,6 +22,11 @@ authors = [

[project.optional-dependencies]
test = ["pytest"]
gdal39 = ["gdal==3.9.2"]
gdal38 = ["gdal==3.8.5"]
gdal384 = ["gdal==3.8.4"]
gdal37 = ["gdal==3.7.3"]
gdal36 = ["gdal==3.6.4"]

[build-system]
requires = ["hatchling"]
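The pinned extras above allow testing against specific GDAL releases. A quick, purely illustrative runtime check that the installed bindings match the chosen pin:

    from osgeo import gdal

    # e.g. "3.9.2" when installed via the gdal39 extra
    print(gdal.VersionInfo("RELEASE_NAME"))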
73 changes: 71 additions & 2 deletions darts-export/src/darts_export/inference.py
@@ -1,8 +1,13 @@
"""Darts export module for inference results."""

import logging
from pathlib import Path

import xarray
from osgeo import gdal, gdal_array, ogr

gdal.UseExceptions()
L = logging.getLogger("darts.export")


class InferenceResultWriter:
@@ -45,5 +50,69 @@ def export_binarized(self, path: Path, filename="pred_binarized.tif", tags={}):
        self.ds.binarized_segmentation.rio.to_raster(file_path, driver="GTiff", tags=tags, compress="LZW")
        return file_path

    # def export_vectors(self, path: Path, filename_prefix="pred_segments"):
    #     pass
    def export_polygonized(self, path: Path, filename_prefix="pred_segments"):
"""Export the binarized probabilities as a vector dataset in GeoPackage format.
If the gdal installation supports GeoParquet files, an additional file will be written in this format.
Args:
path (Path): The path where to export the files
filename_prefix (str, optional): the file prefix of the exported files. Defaults to "pred_segments".
"""
        # extract the layer with the binarized probabilities and create an in-memory
        # gdal dataset from it:
        layer = self.ds.binarized_segmentation
        dta = gdal_array.OpenArray(layer.values)

        # copy over the geodata
        # the transform object of rasterio has to be converted into a tuple
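        # GDAL geotransform order: (x_origin, x_res, row_rotation, y_origin, col_rotation, y_res)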
        affine_transform = layer.rio.transform()
        geotransform = (
            affine_transform.c,
            affine_transform.a,
            affine_transform.b,
            affine_transform.f,
            affine_transform.d,
            affine_transform.e,
        )
        dta.SetGeoTransform(geotransform)
        dta.SetProjection(layer.rio.crs.to_wkt())

        # create the vector output datasets
        gpkg_drv = ogr.GetDriverByName("GPKG")
        gpkg_ds = gpkg_drv.CreateDataSource(path / f"{filename_prefix}.gpkg")
        output_layer = gpkg_ds.CreateLayer(filename_prefix, geom_type=ogr.wkbPolygon, srs=dta.GetSpatialRef())

        # add the field that will store the raster value (DN) of each polygon
        field = ogr.FieldDefn("DN", ogr.OFTInteger)
        output_layer.CreateField(field)

        # do the polygonization
        gdal.Polygonize(
            dta.GetRasterBand(1),
            None,  # no masking, polygonize everything
            output_layer,  # where to write the vector data to
            0,  # index of the field receiving the pixel value ("DN")
        )

        # remove all features where DN is not 1
        output_layer.SetAttributeFilter("DN != 1")
        for feature in output_layer:
            feature_id = feature.GetFID()
            output_layer.DeleteFeature(feature_id)

        # clear the attribute filter so subsequent operations see all remaining features
        output_layer.SetAttributeFilter(None)

        parquet_drv = ogr.GetDriverByName("Parquet")
        if parquet_drv is not None:
            parquet_ds = parquet_drv.CreateDataSource(path / f"{filename_prefix}.parquet")
            parquet_ds.CopyLayer(output_layer, filename_prefix)
        else:
            L.warning(
                "skipping export of polygonized inference data as GeoParquet,"
                " because the GDAL installation does not support the Parquet driver."
            )

        # dereference the layer and dataset to flush and close the GeoPackage
        output_layer = None
        gpkg_ds = None
22 changes: 22 additions & 0 deletions darts-export/test/test_inference.py
@@ -1,7 +1,9 @@
from pathlib import Path # noqa: I001

import geopandas as gpd
import rasterio
from darts_export import inference
from osgeo import ogr
from xarray import Dataset


@@ -28,3 +30,23 @@ def test_writeBinarization(probabilities: Dataset, tmp_path: Path):
    assert rio_ds.width == 16
    assert rio_ds.height == 16
    assert rio_ds.dtypes[0] == "uint8"


def test_writeVectors(probabilities: Dataset, tmp_path: Path):
    gdal_has_parquet = ogr.GetDriverByName("Parquet") is not None

    ds = inference.InferenceResultWriter(probabilities)
    ds.export_polygonized(tmp_path)

    assert (tmp_path / "pred_segments.gpkg").is_file()
    if gdal_has_parquet:
        assert (tmp_path / "pred_segments.parquet").is_file()

    gdf = gpd.read_file(tmp_path / "pred_segments.gpkg")
    assert gdf.shape == (1, 2)  # one feature, two columns (DN, geometry)
    assert gdf.iloc[0].geometry.wkt == (
        "POLYGON (("
        "500007.5 4500003.5, 500007.5 4500004.5, 500006.5 4500004.5, 500006.5 4500008.5, 500007.5 4500008.5, "
        "500007.5 4500009.5, 500009.5 4500009.5, 500009.5 4500008.5, 500010.5 4500008.5, 500010.5 4500004.5, "
        "500009.5 4500004.5, 500009.5 4500003.5, 500007.5 4500003.5))"
    )
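For reference, a sketch of reading the optional GeoParquet output back (assumes a geopandas installation with pyarrow available; the path is illustrative):

    import geopandas as gpd

    gdf_pq = gpd.read_parquet("out/pred_segments.parquet")
    assert (gdf_pq["DN"] == 1).all()  # only DN == 1 polygons survive the filtering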
