diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py index b4158a796d7..5d11f4da869 100644 --- a/tests/datasets/test_nasa_marine_debris.py +++ b/tests/datasets/test_nasa_marine_debris.py @@ -52,10 +52,10 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NASAMarineDebris: def test_getitem(self, dataset: NASAMarineDebris) -> None: x = dataset[0] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - assert isinstance(x["bbox_xyxy"], torch.Tensor) - assert x["image"].shape[0] == 3 - assert x["bbox_xyxy"].shape[-1] == 4 + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['bbox_xyxy'], torch.Tensor) + assert x['image'].shape[0] == 3 + assert x['bbox_xyxy'].shape[-1] == 4 def test_len(self, dataset: NASAMarineDebris) -> None: assert len(dataset) == 4 @@ -99,6 +99,6 @@ def test_plot(self, dataset: NASAMarineDebris) -> None: plt.close() dataset.plot(x, show_titles=False) plt.close() - x["prediction_boxes"] = x["bbox_xyxy"].clone() + x['prediction_boxes'] = x['bbox_xyxy'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index 28d22f91772..898d1296c23 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -48,12 +48,12 @@ def test_getitem(self, dataset: VHR10) -> None: for i in range(2): x = dataset[i] assert isinstance(x, dict) - assert isinstance(x["image"], torch.Tensor) - if dataset.split == "positive": - assert isinstance(x["class"], torch.Tensor) - assert isinstance(x["bbox_xyxy"], torch.Tensor) - if "mask" in x: - assert isinstance(x["mask"], torch.Tensor) + assert isinstance(x['image'], torch.Tensor) + if dataset.split == 'positive': + assert isinstance(x['class'], torch.Tensor) + assert isinstance(x['bbox_xyxy'], torch.Tensor) + if 'mask' in x: + assert isinstance(x['mask'], torch.Tensor) def test_len(self, dataset: VHR10) -> None: if dataset.split == 'positive': @@ -91,10 +91,10 @@ def test_plot(self, dataset: VHR10) -> None: scores = [0.7, 0.3, 0.7] for i in range(3): x = dataset[i] - x["prediction_labels"] = x["class"] - x["prediction_boxes"] = x["bbox_xyxy"] - x["prediction_scores"] = torch.Tensor([scores[i]]) - if "masks" in x: - x["prediction_masks"] = x["mask"] - dataset.plot(x, show_feats="masks") + x['prediction_labels'] = x['class'] + x['prediction_boxes'] = x['bbox_xyxy'] + x['prediction_scores'] = torch.Tensor([scores[i]]) + if 'masks' in x: + x['prediction_masks'] = x['mask'] + dataset.plot(x, show_feats='masks') plt.close() diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 05b2bd62537..c5c86d75756 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -20,7 +20,6 @@ GridGeoSampler, RandomBatchGeoSampler, ) -from ..transforms import AugmentationSequential from .utils import MisconfigurationException diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index 73c18c43dee..2fcba2236c3 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -11,7 +11,7 @@ from ..datasets import NASAMarineDebris from .geo import NonGeoDataModule -from .utils import collate_fn_detection, dataset_split +from .utils import collate_fn_detection class NASAMarineDebrisDataModule(NonGeoDataModule): diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index 95ac01e2ded..5ae7127ab68 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -5,14 +5,11 @@ import math from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any import numpy as np import torch -from torch import Generator, Tensor -from torch.utils.data import Subset, TensorDataset, random_split - -from ..datasets import NonGeoDataset +from torch import Tensor # Based on lightning_lite.utilities.exceptions @@ -32,23 +29,23 @@ def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]: .. versionadded:: 0.6 """ output: dict[str, Any] = {} - output["image"] = torch.stack([sample["image"] for sample in batch]) + output['image'] = torch.stack([sample['image'] for sample in batch]) # Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"} - bbox_key = "boxes" + bbox_key = 'boxes' for key in batch[0].keys(): - if key in {"bbox", "bbox_xyxy", "bbox_xywh"}: + if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}: bbox_key = key output[bbox_key] = [sample[bbox_key].float() for sample in batch] - if "class" in batch[0].keys(): - output["class"] = [sample["class"] for sample in batch] + if 'class' in batch[0].keys(): + output['class'] = [sample['class'] for sample in batch] else: - output["class"] = [ + output['class'] = [ torch.tensor([1] * len(sample[bbox_key])) for sample in batch ] - if "mask" in batch[0]: - output["mask"] = [sample["mask"] for sample in batch] + if 'mask' in batch[0]: + output['mask'] = [sample['mask'] for sample in batch] return output diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index 1c35834b7e5..49e7b90241b 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -72,10 +72,10 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - self.kwargs["transforms"] = K.AugmentationSequential( + self.kwargs['transforms'] = K.AugmentationSequential( K.Resize(self.patch_size), data_keys=None, keepdim=True ) - self.kwargs["transforms"].keepdim = True + self.kwargs['transforms'].keepdim = True self.dataset = VHR10(**self.kwargs) generator = torch.Generator().manual_seed(0) self.train_dataset, self.val_dataset, self.test_dataset = random_split( diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index d38a221e531..e3738963508 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -100,15 +100,15 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Returns: data and labels at that index """ - image = self._load_image(self.files[index]["image"]) - boxes = self._load_target(self.files[index]["target"]) - sample = {"image": image, "bbox_xyxy": boxes} + image = self._load_image(self.files[index]['image']) + boxes = self._load_target(self.files[index]['target']) + sample = {'image': image, 'bbox_xyxy': boxes} # Filter invalid boxes - w_check = (sample["bbox_xyxy"][:, 2] - sample["bbox_xyxy"][:, 0]) > 0 - h_check = (sample["bbox_xyxy"][:, 3] - sample["bbox_xyxy"][:, 1]) > 0 + w_check = (sample['bbox_xyxy'][:, 2] - sample['bbox_xyxy'][:, 0]) > 0 + h_check = (sample['bbox_xyxy'][:, 3] - sample['bbox_xyxy'][:, 1]) > 0 indices = w_check & h_check - sample["bbox_xyxy"] = sample["bbox_xyxy"][indices] + sample['bbox_xyxy'] = sample['bbox_xyxy'][indices] if self.transforms is not None: sample = self.transforms(sample) @@ -234,11 +234,11 @@ def plot( """ ncols = 1 - sample["image"] = sample["image"].byte() - image = sample["image"] - if "bbox_xyxy" in sample and len(sample["bbox_xyxy"]): + sample['image'] = sample['image'].byte() + image = sample['image'] + if 'bbox_xyxy' in sample and len(sample['bbox_xyxy']): image = draw_bounding_boxes( - image=sample["image"], boxes=sample["bbox_xyxy"] + image=sample['image'], boxes=sample['bbox_xyxy'] ) image = image.permute((1, 2, 0)).numpy() diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index aab66dbe7ee..078a20912a5 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -249,10 +249,10 @@ def __getitem__(self, index: int) -> dict[str, Any]: if sample['label']['annotations']: sample = self.coco_convert(sample) - sample["class"] = sample["label"]["labels"] - sample["bbox_xyxy"] = sample["label"]["boxes"] - sample["mask"] = sample["label"]["masks"].float() - del sample["label"] + sample['class'] = sample['label']['labels'] + sample['bbox_xyxy'] = sample['label']['boxes'] + sample['mask'] = sample['label']['masks'].float() + del sample['label'] if self.transforms is not None: sample = self.transforms(sample) @@ -401,11 +401,11 @@ def plot( if show_feats != 'boxes': skimage = lazy_import('skimage') - image = sample["image"].permute(1, 2, 0).numpy() - boxes = sample["bbox_xyxy"].cpu().numpy() - labels = sample["class"].cpu().numpy() - if "mask" in sample: - masks = [mask.squeeze().cpu().numpy() for mask in sample["mask"]] + boxes = sample['bbox_xyxy'].cpu().numpy() + labels = sample['class'].cpu().numpy() + + if 'mask' in sample: + masks = [mask.squeeze().cpu().numpy() for mask in sample['mask']] n_gt = len(boxes) diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 4269da6a582..673d266de26 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -239,10 +239,10 @@ def training_step( batch_size = x.shape[0] # Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"} for key in batch.keys(): - if key in {"bbox", "bbox_xyxy", "bbox_xywh"}: + if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}: bbox_key = key y = [ - {"boxes": batch[bbox_key][i], "labels": batch["class"][i]} + {'boxes': batch[bbox_key][i], 'labels': batch['class'][i]} for i in range(batch_size) ] loss_dict = self(x, y) @@ -264,10 +264,10 @@ def validation_step( batch_size = x.shape[0] # Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"} for key in batch.keys(): - if key in {"bbox", "bbox_xyxy", "bbox_xywh"}: + if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}: bbox_key = key y = [ - {"boxes": batch[bbox_key][i], "labels": batch["class"][i]} + {'boxes': batch[bbox_key][i], 'labels': batch['class'][i]} for i in range(batch_size) ] y_hat = self(x) @@ -322,10 +322,10 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None batch_size = x.shape[0] # Get bbox key as it can be one of {"bbox", "bbox_xyxy", "bbox_xywh"} for key in batch.keys(): - if key in {"bbox", "bbox_xyxy", "bbox_xywh"}: + if key in {'bbox', 'bbox_xyxy', 'bbox_xywh'}: bbox_key = key y = [ - {"boxes": batch[bbox_key][i], "labels": batch["class"][i]} + {'boxes': batch[bbox_key][i], 'labels': batch['class'][i]} for i in range(batch_size) ] y_hat = self(x)