Skip to content

Commit

Permalink
GlobalBuildingMap: add new dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Dec 16, 2024
1 parent 60ba6cb commit 460556c
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ Chesapeake Land Cover
.. autoclass:: ChesapeakeWV
.. autoclass:: ChesapeakeCVPR

GlobalBuildingMap
^^^^^^^^^^^^^^^^^

.. autoclass:: GlobalBuildingMap

Global Mangrove Distribution
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/datasets/geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Dataset,Type,Source,License,Size (px),Resolution (m)
`Aster Global DEM`_,DEM,Aster,"public domain","3,601x3,601",30
`Canadian Building Footprints`_,Geometries,Bing Imagery,"ODbL-1.0",-,-
`Chesapeake Land Cover`_,"Imagery, Masks",NAIP,"CC0-1.0",-,1
`GlobalBuildingMap`_,Masks,PlanetScope,CC-BY-4.0,180K,3
`Global Mangrove Distribution`_,Masks,"Remote Sensing, In Situ Measurements","public domain",-,3
`Cropland Data Layer`_,Masks,Landsat,"public domain",-,30
`EDDMapS`_,Points,Citizen Scientists,-,-,-
Expand Down
Binary file added tests/data/gbm/GBM_v1_e000_n10_e005_n05.tif
Binary file not shown.
28 changes: 28 additions & 0 deletions tests/data/gbm/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import numpy as np
import rasterio
from rasterio import Affine
from rasterio.crs import CRS

SIZE = 36

np.random.seed(0)

profile = {
'driver': 'GTiff',
'dtype': 'uint8',
'width': SIZE,
'height': SIZE,
'count': 1,
'crs': CRS.from_epsg(3857),
'transform': Affine(3.0, 0.0, -1333.7802539161332, 0.0, -3.0, 1120234.423089223),
}

Z = np.random.choice([0, 255], size=(SIZE, SIZE))

with rasterio.open('GBM_v1_e000_n10_e005_n05.tif', 'w', **profile) as src:
src.write(Z, 1)
59 changes: 59 additions & 0 deletions tests/datasets/test_gbm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
GlobalBuildingMap,
IntersectionDataset,
UnionDataset,
)


class TestGlobalBuildingMap:
@pytest.fixture
def dataset(self) -> GlobalBuildingMap:
paths = os.path.join('tests', 'data', 'gbm')
return GlobalBuildingMap(paths)

def test_getitem(self, dataset: GlobalBuildingMap) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x['crs'], CRS)
assert isinstance(x['mask'], torch.Tensor)

def test_len(self, dataset: GlobalBuildingMap) -> None:
assert len(dataset) == 1

def test_and(self, dataset: GlobalBuildingMap) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)

def test_or(self, dataset: GlobalBuildingMap) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)

def test_plot(self, dataset: GlobalBuildingMap) -> None:
sample = dataset[dataset.bounds]
sample['prediction'] = sample['mask']
dataset.plot(sample, suptitle='Test')
plt.close()

def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
GlobalBuildingMap(tmp_path)

def test_invalid_query(self, dataset: GlobalBuildingMap) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match='query: .* not found in index with bounds:'
):
dataset[query]
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from .forestdamage import ForestDamage
from .ftw import FieldsOfTheWorld
from .gbif import GBIF
from .gbm import GlobalBuildingMap
from .geo import (
GeoDataset,
IntersectionDataset,
Expand Down Expand Up @@ -216,6 +217,7 @@
'GeoDataset',
'GeoNRW',
'GlobBiomass',
'GlobalBuildingMap',
'HySpecNet11k',
'IDTReeS',
'INaturalist',
Expand Down
69 changes: 69 additions & 0 deletions torchgeo/datasets/gbm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""GlobalBuildingMap."""

from typing import Any

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from .geo import RasterDataset


class GlobalBuildingMap(RasterDataset):
"""GlobalBuildingMap dataset.
The GlobalBuildingMap (GBM) dataset provides the highest resolution and highest
accuracy building footprint map on a global scale ever created. GBM was generated
by training and applying modern deep neural networks on nearly 800,000 satellite
images. The dataset is stored in 5 by 5 degree tiles in geotiff format.
The GlobalBuildingMap is generated by applying an ensemble of deep neural networks
on nearly 800,000 satellite images of about 3m resolution. The deep neural networks
were trained with manually inspected training samples generated from OpenStreetMap.
If you use this dataset in your research, please cite the following paper:
* https://arxiv.org/abs/2404.13911
.. versionadded:: 0.7
"""

filename_glob = 'GBM_v1_*'
is_image = False

def plot(
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: str | None = None,
) -> Figure:
"""Plot a sample from the dataset.
Args:
sample: A sample returned by :meth:`RasterDataset.__getitem__`.
show_titles: Flag indicating whether to show titles above each panel.
suptitle: Optional string to use as a suptitle.
Returns:
A matplotlib Figure with the rendered sample.
"""
ncols = 2 if 'prediction' in sample else 1
fig, axs = plt.subplots(ncols=ncols, squeeze=False)

axs[0, 0].imshow(sample['mask'])
axs[0, 0].axis('off')
if show_titles:
axs[0, 0].set_title('Mask')

if 'prediction' in sample:
axs[0, 1].imshow(sample['prediction'])
axs[0, 1].axis('off')
if show_titles:
axs[0, 1].set_title('Prediction')

if suptitle is not None:
plt.suptitle(suptitle)

return fig

0 comments on commit 460556c

Please sign in to comment.