Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MMFlood dataset #2450

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ L8 Biome

.. autoclass:: L8BiomeDataModule

MMFlood
^^^^^^^^

.. autoclass:: MMFloodDataModule

NAIP
^^^^

Expand Down
4 changes: 4 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ Landsat
.. autoclass:: Landsat2
.. autoclass:: Landsat1

MMFlood
^^^^^^^
.. autoclass:: MMFlood

NAIP
^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/datasets/geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Dataset,Type,Source,License,Size (px),Resolution (m)
`L8 Biome`_,"Imagery, Masks",Landsat,"CC0-1.0","8,900x8,900","15, 30"
`LandCover.ai Geo`_,"Imagery, Masks",Aerial,"CC-BY-NC-SA-4.0","4,200--9,500",0.25--0.5
`Landsat`_,Imagery,Landsat,"public domain","8,900x8,900",30
`MMFlood`_,"Imagery,DEM,Masks","Sentinel, MapZen/TileZen, OpenStreetMap",MIT,"2,147x2,313",20
`NAIP`_,Imagery,Aerial,"public domain","6,100x7,600",0.3--2
`NCCM`_,Masks,Sentinel-2,"CC-BY-4.0",-,10
`NLCD`_,Masks,Landsat,"public domain",-,30
Expand Down
19 changes: 19 additions & 0 deletions tests/conf/mmflood.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: 'ce'
model: 'unet'
backbone: 'resnet18'
in_channels: 4
num_classes: 2
num_filters: 1
ignore_index: 255
data:
class_path: MMFloodDataModule
init_args:
batch_size: 1
dict_kwargs:
root: 'tests/data/mmflood'
patch_size: 8
include_dem: True
include_hydro: True
1 change: 1 addition & 0 deletions tests/data/mmflood/activations.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"EMSR000": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "train", "delineations": ["EMSR000_00"]}, "EMSR001": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "train", "delineations": ["EMSR001_00"]}, "EMSR003": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "val", "delineations": ["EMSR003_00"]}, "EMSR004": {"title": "Test flood", "type": "Flood", "country": "N/A", "start": "2014-11-06T17:57:00", "end": "2015-01-29T12:47:04", "lat": 45.82427031690563, "lon": 14.484407562009336, "subset": "test", "delineations": ["EMSR004_00"]}}
Binary file added tests/data/mmflood/activations.tar.000.gz.part
Binary file not shown.
Binary file added tests/data/mmflood/activations.tar.001.gz.part
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
135 changes: 135 additions & 0 deletions tests/data/mmflood/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import json
import os
import tarfile

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine


def generate_data(
path: str, filename: str, height: int, width: int, include_hydro: bool = False
) -> None:
max_value = 1000.0
min_value = 0.0
interval = max_value - min_value
folders = ['s1_raw', 'DEM', 'mask', 'hydro']
profile = {
'driver': 'GTiff',
'dtype': 'float32',
'nodata': None,
'crs': CRS.from_epsg(4326),
'transform': Affine(
0.0001287974837883981,
0.0,
14.438064999669106,
0.0,
-8.989523639880024e-05,
45.71617928533084,
),
'blockysize': 1,
'tiled': False,
'interleave': 'pixel',
'height': height,
'width': width,
}
data = {
's1_raw': np.random.rand(2, height, width).astype(np.float32) * interval
- min_value,
'DEM': np.random.rand(1, height, width).astype(np.float32) * interval
- min_value,
'mask': np.random.randint(low=0, high=2, size=(1, height, width)).astype(
np.uint8
),
}

if include_hydro:
data['hydro'] = (
np.random.rand(1, height, width).astype(np.float32) * interval - min_value
)

for folder in folders:
folder_path = os.path.join(path, folder)
os.makedirs(folder_path, exist_ok=True)
filepath = os.path.join(folder_path, filename)
profile2 = profile.copy()
profile2['count'] = 2 if folder == 's1_raw' else 1
if folder in data:
with rasterio.open(filepath, mode='w', **profile2) as src:
src.write(data[folder])


def generate_tar_gz(src: str, dst: str) -> None:
with tarfile.open(dst, 'w:gz') as tar:
tar.add(src, arcname=src)


def split_tar(path: str, dst: str, nparts: int) -> None:
fstats = os.stat(path)
size = fstats.st_size
chunk = size // nparts

with open(path, 'rb') as fp:
for idx in range(nparts):
part_path = os.path.join(dst, f'activations.tar.{idx:03}.gz.part')

bytes_to_write = chunk if idx < nparts - 1 else size - fp.tell()
with open(part_path, 'wb') as dst_fp:
dst_fp.write(fp.read(bytes_to_write))


def generate_folders_and_metadata(datapath: str, metadatapath: str) -> None:
folders_splits = [
('EMSR000', 'train'),
('EMSR001', 'train'),
('EMSR003', 'val'),
('EMSR004', 'test'),
]
num_files = {'EMSR000': 3, 'EMSR001': 2, 'EMSR003': 2, 'EMSR004': 1}
num_hydro = {'EMSR001': 2, 'EMSR003': 1, 'EMSR004': 1}
metadata = {}
for folder, split in folders_splits:
data = {}
data['title'] = 'Test flood'
data['type'] = 'Flood'
data['country'] = 'N/A'
data['start'] = '2014-11-06T17:57:00'
data['end'] = '2015-01-29T12:47:04'
data['lat'] = 45.82427031690563
data['lon'] = 14.484407562009336
data['subset'] = split
data['delineations'] = [f'{folder}_00']

count_hydro = 0

dst_folder = os.path.join(datapath, f'{folder}-0')
for idx in range(num_files[folder]):
include_hydro = count_hydro < num_hydro.get(folder, 0)
generate_data(
dst_folder,
filename=f'{folder}-{idx}.tif',
height=16,
width=16,
include_hydro=include_hydro,
)
if include_hydro:
count_hydro += 1

metadata[folder] = data

generate_tar_gz(src='activations', dst='activations.tar.gz')
split_tar(path='activations.tar.gz', dst='.', nparts=2)
os.remove('activations.tar.gz')
with open(os.path.join(metadatapath, 'activations.json'), 'w') as fp:
json.dump(metadata, fp)


if __name__ == '__main__':
datapath = os.path.join(os.getcwd(), 'activations')
metadatapath = os.getcwd()

generate_folders_and_metadata(datapath, metadatapath)
111 changes: 111 additions & 0 deletions tests/datasets/test_mmflood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from itertools import product
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
MMFlood,
UnionDataset,
)


class TestMMFlood:
@pytest.fixture(
params=product([True, False], [True, False], ['train', 'val', 'test'])
)
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> MMFlood:
url = os.path.join('tests', 'data', 'mmflood') + os.sep

monkeypatch.setattr(MMFlood, 'url', url)
monkeypatch.setattr(MMFlood, '_nparts', 2)

include_dem, include_hydro, split = request.param
root = tmp_path
return MMFlood(
root,
split=split,
include_dem=include_dem,
include_hydro=include_hydro,
transforms=nn.Identity(),
download=True,
checksum=True,
)

def test_getitem(self, dataset: MMFlood) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x['crs'], CRS)
assert isinstance(x['image'], torch.Tensor)
assert isinstance(x['mask'], torch.Tensor)
nchannels = 2

# If DEM is included and hydro is included, check if 4 channels are present,
# If only one between DEM or hydro is included, check if 3 channels are present
# 2 otherwise
if dataset.include_dem:
nchannels += 1
if dataset.include_hydro:
nchannels += 1
assert x['image'].size(0) == nchannels

def test_len(self, dataset: MMFlood) -> None:
if dataset.split == 'train':
if not dataset.include_hydro:
assert len(dataset) == 5
else:
assert len(dataset) == 2
elif dataset.split == 'val':
if not dataset.include_hydro:
assert len(dataset) == 2
else:
assert len(dataset) == 1
else:
assert len(dataset) == 1

def test_and(self, dataset: MMFlood) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)

def test_or(self, dataset: MMFlood) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)

def test_already_downloaded(self, dataset: MMFlood) -> None:
MMFlood(root=dataset.root)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
MMFlood(tmp_path)

def test_plot(self, dataset: MMFlood) -> None:
x = dataset[dataset.bounds]
dataset.plot(x, suptitle='Test')
plt.close()

def test_plot_prediction(self, dataset: MMFlood) -> None:
x = dataset[dataset.bounds]
x['prediction'] = x['mask'].clone()
dataset.plot(x, suptitle='Prediction')
plt.close()

def test_invalid_query(self, dataset: MMFlood) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match='query: .* not found in index with bounds:'
):
dataset[query]
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class TestSemanticSegmentationTask:
'landcoverai',
'landcoverai100',
'loveda',
'mmflood',
'naipchesapeake',
'potsdam2d',
'sen12ms_all',
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .landcoverai import LandCoverAI100DataModule, LandCoverAIDataModule
from .levircd import LEVIRCDDataModule, LEVIRCDPlusDataModule
from .loveda import LoveDADataModule
from .mmflood import MMFloodDataModule
from .naip import NAIPChesapeakeDataModule
from .nasa_marine_debris import NASAMarineDebrisDataModule
from .oscd import OSCDDataModule
Expand Down Expand Up @@ -87,6 +88,7 @@
'LandCoverAI100DataModule',
'LandCoverAIDataModule',
'LoveDADataModule',
'MMFloodDataModule',
'MisconfigurationException',
'NAIPChesapeakeDataModule',
'NASAMarineDebrisDataModule',
Expand Down
Loading
Loading