diff --git a/tests/datasets/test_fair1m.py b/tests/datasets/test_fair1m.py index fec6d309ab1..ac0a3048bb0 100644 --- a/tests/datasets/test_fair1m.py +++ b/tests/datasets/test_fair1m.py @@ -124,6 +124,6 @@ def test_plot(self, dataset: FAIR1M) -> None: plt.close() if dataset.split != 'test': - x['prediction_boxes'] = x['bbox_xyxy'].clone() + x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_forestdamage.py b/tests/datasets/test_forestdamage.py index 8e8fd7c1ab9..7ae2ebb9eba 100644 --- a/tests/datasets/test_forestdamage.py +++ b/tests/datasets/test_forestdamage.py @@ -67,6 +67,6 @@ def test_plot(self, dataset: ForestDamage) -> None: def test_plot_prediction(self, dataset: ForestDamage) -> None: x = dataset[0].copy() - x['prediction_boxes'] = x['bbox_xyxy'].clone() + x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone() dataset.plot(x, suptitle='Prediction') plt.close() diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py index 6e378c81c0c..39350bc630c 100644 --- a/tests/datasets/test_idtrees.py +++ b/tests/datasets/test_idtrees.py @@ -88,7 +88,7 @@ def test_plot(self, dataset: IDTReeS) -> None: plt.close() if 'bbox_xyxy' in x: - x['prediction_boxes'] = x['bbox_xyxy'] + x['prediction_bbox_xyxy'] = x['bbox_xyxy'] dataset.plot(x, show_titles=True) plt.close() if 'label' in x: diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py index e2544951c87..d5f94a20816 100644 --- a/tests/datasets/test_nasa_marine_debris.py +++ b/tests/datasets/test_nasa_marine_debris.py @@ -50,6 +50,6 @@ def test_plot(self, dataset: NASAMarineDebris) -> None: plt.close() dataset.plot(x, show_titles=False) plt.close() - x['prediction_boxes'] = x['bbox_xyxy'].clone() + x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_reforestree.py b/tests/datasets/test_reforestree.py index e3cf9948d82..e7c1079b0a3 100644 --- a/tests/datasets/test_reforestree.py +++ b/tests/datasets/test_reforestree.py @@ -67,6 +67,6 @@ def test_plot(self, dataset: ReforesTree) -> None: def test_plot_prediction(self, dataset: ReforesTree) -> None: x = dataset[0].copy() - x['prediction_boxes'] = x['bbox_xyxy'].clone() + x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone() dataset.plot(x, suptitle='Prediction') plt.close() diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index 7ac91c31b20..1d3b8eac440 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -83,7 +83,7 @@ def test_plot(self, dataset: VHR10) -> None: for i in range(3): x = dataset[i] x['prediction_labels'] = x['label'] - x['prediction_boxes'] = x['bbox_xyxy'] + x['prediction_bbox_xyxy'] = x['bbox_xyxy'] x['prediction_scores'] = torch.Tensor([scores[i]]) if 'mask' in x: x['prediction_masks'] = x['mask'] diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py index d01267d9d85..d8cb5b56d05 100644 --- a/torchgeo/datasets/fair1m.py +++ b/torchgeo/datasets/fair1m.py @@ -401,7 +401,7 @@ def plot( image = sample['image'].permute((1, 2, 0)).numpy() ncols = 1 - if 'prediction_boxes' in sample: + if 'prediction_bbox_xyxy' in sample: ncols += 1 fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) @@ -427,7 +427,7 @@ def plot( axs[1].axis('off') polygons = [ patches.Polygon(points, color='r', fill=False) - for points in sample['prediction_boxes'].numpy() + for points in sample['prediction_bbox_xyxy'].numpy() ] for polygon in polygons: axs[0].add_patch(polygon) diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index b7ea83504aa..d4f8e487b81 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -264,7 +264,7 @@ def plot( image = sample['image'].permute((1, 2, 0)).numpy() ncols = 1 - showing_predictions = 'prediction_boxes' in sample + showing_predictions = 'prediction_bbox_xyxy' in sample if showing_predictions: ncols += 1 @@ -305,7 +305,7 @@ def plot( edgecolor='r', facecolor='none', ) - for bbox in sample['prediction_boxes'].numpy() + for bbox in sample['prediction_bbox_xyxy'].numpy() ] for bbox in pred_bboxes: axs[1].add_patch(bbox) diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index 46e1a9084d4..0b2f0283805 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -521,7 +521,7 @@ def normalize(x: Tensor) -> Tensor: else: image = sample['image'].permute((1, 2, 0)).numpy() - if 'prediction_boxes' in sample and len(sample['prediction_boxes']): + if 'prediction_bbox_xyxy' in sample and len(sample['prediction_bbox_xyxy']): ncols += 1 labels = ( [self.idx2class[int(i)] for i in sample['prediction_label']] @@ -529,7 +529,7 @@ def normalize(x: Tensor) -> Tensor: else None ) preds = draw_bounding_boxes( - image=sample['image'], boxes=sample['prediction_boxes'], labels=labels + image=sample['image'], boxes=sample['prediction_bbox_xyxy'], labels=labels ) preds = preds.permute((1, 2, 0)).numpy() diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index 791c3cae81c..62f8df37fc2 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -167,10 +167,10 @@ def plot( ) image_arr = image.permute((1, 2, 0)).numpy() - if 'prediction_boxes' in sample and len(sample['prediction_boxes']): + if 'prediction_bbox_xyxy' in sample and len(sample['prediction_bbox_xyxy']): ncols += 1 preds = draw_bounding_boxes( - image=sample['image'], boxes=sample['prediction_boxes'] + image=sample['image'], boxes=sample['prediction_bbox_xyxy'] ) preds_arr = preds.permute((1, 2, 0)).numpy() diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py index 4caced4d477..0973442791a 100644 --- a/torchgeo/datasets/reforestree.py +++ b/torchgeo/datasets/reforestree.py @@ -219,7 +219,7 @@ def plot( """ image = sample['image'].permute((1, 2, 0)).numpy() ncols = 1 - showing_predictions = 'prediction_boxes' in sample + showing_predictions = 'prediction_bbox_xyxy' in sample if showing_predictions: ncols += 1 @@ -260,7 +260,7 @@ def plot( edgecolor='r', facecolor='none', ) - for bbox in sample['prediction_boxes'].numpy() + for bbox in sample['prediction_bbox_xyxy'].numpy() ] for bbox in pred_bboxes: axs[1].add_patch(bbox) diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index f0a73723e80..5628bcd16d6 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -416,8 +416,8 @@ def plot( show_pred_masks = False prediction_labels = sample['prediction_labels'].numpy() prediction_scores = sample['prediction_scores'].numpy() - if 'prediction_boxes' in sample: - prediction_boxes = sample['prediction_boxes'].numpy() + if 'prediction_bbox_xyxy' in sample: + prediction_bbox_xyxy = sample['prediction_bbox_xyxy'].numpy() show_pred_boxes = True if 'prediction_masks' in sample: prediction_masks = sample['prediction_masks'].numpy() @@ -485,7 +485,7 @@ def plot( if show_pred_boxes: # Add bounding boxes - x1, y1, x2, y2 = prediction_boxes[i] + x1, y1, x2, y2 = prediction_bbox_xyxy[i] r = patches.Rectangle( (x1, y1), x2 - x1, diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 290b6917b04..f7f73542933 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -282,7 +282,7 @@ def validation_step( and hasattr(self.logger.experiment, 'add_figure') ): datamodule = self.trainer.datamodule - batch['prediction_boxes'] = [b['boxes'].cpu() for b in y_hat] + batch['prediction_bbox_xyxy'] = [b['boxes'].cpu() for b in y_hat] batch['prediction_labels'] = [b['labels'].cpu() for b in y_hat] batch['prediction_scores'] = [b['scores'].cpu() for b in y_hat] batch['image'] = batch['image'].cpu()