-
Notifications
You must be signed in to change notification settings - Fork 385
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
60ba6cb
commit 460556c
Showing
7 changed files
with
164 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import numpy as np | ||
import rasterio | ||
from rasterio import Affine | ||
from rasterio.crs import CRS | ||
|
||
SIZE = 36 | ||
|
||
np.random.seed(0) | ||
|
||
profile = { | ||
'driver': 'GTiff', | ||
'dtype': 'uint8', | ||
'width': SIZE, | ||
'height': SIZE, | ||
'count': 1, | ||
'crs': CRS.from_epsg(3857), | ||
'transform': Affine(3.0, 0.0, -1333.7802539161332, 0.0, -3.0, 1120234.423089223), | ||
} | ||
|
||
Z = np.random.choice([0, 255], size=(SIZE, SIZE)) | ||
|
||
with rasterio.open('GBM_v1_e000_n10_e005_n05.tif', 'w', **profile) as src: | ||
src.write(Z, 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
from pathlib import Path | ||
|
||
import matplotlib.pyplot as plt | ||
import pytest | ||
import torch | ||
from rasterio.crs import CRS | ||
|
||
from torchgeo.datasets import ( | ||
BoundingBox, | ||
DatasetNotFoundError, | ||
GlobalBuildingMap, | ||
IntersectionDataset, | ||
UnionDataset, | ||
) | ||
|
||
|
||
class TestGlobalBuildingMap: | ||
@pytest.fixture | ||
def dataset(self) -> GlobalBuildingMap: | ||
paths = os.path.join('tests', 'data', 'gbm') | ||
return GlobalBuildingMap(paths) | ||
|
||
def test_getitem(self, dataset: GlobalBuildingMap) -> None: | ||
x = dataset[dataset.bounds] | ||
assert isinstance(x, dict) | ||
assert isinstance(x['crs'], CRS) | ||
assert isinstance(x['mask'], torch.Tensor) | ||
|
||
def test_len(self, dataset: GlobalBuildingMap) -> None: | ||
assert len(dataset) == 1 | ||
|
||
def test_and(self, dataset: GlobalBuildingMap) -> None: | ||
ds = dataset & dataset | ||
assert isinstance(ds, IntersectionDataset) | ||
|
||
def test_or(self, dataset: GlobalBuildingMap) -> None: | ||
ds = dataset | dataset | ||
assert isinstance(ds, UnionDataset) | ||
|
||
def test_plot(self, dataset: GlobalBuildingMap) -> None: | ||
sample = dataset[dataset.bounds] | ||
sample['prediction'] = sample['mask'] | ||
dataset.plot(sample, suptitle='Test') | ||
plt.close() | ||
|
||
def test_no_data(self, tmp_path: Path) -> None: | ||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'): | ||
GlobalBuildingMap(tmp_path) | ||
|
||
def test_invalid_query(self, dataset: GlobalBuildingMap) -> None: | ||
query = BoundingBox(0, 0, 0, 0, 0, 0) | ||
with pytest.raises( | ||
IndexError, match='query: .* not found in index with bounds:' | ||
): | ||
dataset[query] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
"""GlobalBuildingMap.""" | ||
|
||
from typing import Any | ||
|
||
import matplotlib.pyplot as plt | ||
from matplotlib.figure import Figure | ||
|
||
from .geo import RasterDataset | ||
|
||
|
||
class GlobalBuildingMap(RasterDataset): | ||
"""GlobalBuildingMap dataset. | ||
The GlobalBuildingMap (GBM) dataset provides the highest resolution and highest | ||
accuracy building footprint map on a global scale ever created. GBM was generated | ||
by training and applying modern deep neural networks on nearly 800,000 satellite | ||
images. The dataset is stored in 5 by 5 degree tiles in geotiff format. | ||
The GlobalBuildingMap is generated by applying an ensemble of deep neural networks | ||
on nearly 800,000 satellite images of about 3m resolution. The deep neural networks | ||
were trained with manually inspected training samples generated from OpenStreetMap. | ||
If you use this dataset in your research, please cite the following paper: | ||
* https://arxiv.org/abs/2404.13911 | ||
.. versionadded:: 0.7 | ||
""" | ||
|
||
filename_glob = 'GBM_v1_*' | ||
is_image = False | ||
|
||
def plot( | ||
self, | ||
sample: dict[str, Any], | ||
show_titles: bool = True, | ||
suptitle: str | None = None, | ||
) -> Figure: | ||
"""Plot a sample from the dataset. | ||
Args: | ||
sample: A sample returned by :meth:`RasterDataset.__getitem__`. | ||
show_titles: Flag indicating whether to show titles above each panel. | ||
suptitle: Optional string to use as a suptitle. | ||
Returns: | ||
A matplotlib Figure with the rendered sample. | ||
""" | ||
ncols = 2 if 'prediction' in sample else 1 | ||
fig, axs = plt.subplots(ncols=ncols, squeeze=False) | ||
|
||
axs[0, 0].imshow(sample['mask']) | ||
axs[0, 0].axis('off') | ||
if show_titles: | ||
axs[0, 0].set_title('Mask') | ||
|
||
if 'prediction' in sample: | ||
axs[0, 1].imshow(sample['prediction']) | ||
axs[0, 1].axis('off') | ||
if show_titles: | ||
axs[0, 1].set_title('Prediction') | ||
|
||
if suptitle is not None: | ||
plt.suptitle(suptitle) | ||
|
||
return fig |