Skip to content

Commit

Permalink
Switch prediction_boxes -> prediction_bbox_xyxy
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Jan 6, 2025
1 parent abc1d05 commit 2e15665
Show file tree
Hide file tree
Showing 13 changed files with 20 additions and 20 deletions.
2 changes: 1 addition & 1 deletion tests/datasets/test_fair1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,6 @@ def test_plot(self, dataset: FAIR1M) -> None:
plt.close()

if dataset.split != 'test':
x['prediction_boxes'] = x['bbox_xyxy'].clone()
x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone()
dataset.plot(x)
plt.close()
2 changes: 1 addition & 1 deletion tests/datasets/test_forestdamage.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,6 @@ def test_plot(self, dataset: ForestDamage) -> None:

def test_plot_prediction(self, dataset: ForestDamage) -> None:
x = dataset[0].copy()
x['prediction_boxes'] = x['bbox_xyxy'].clone()
x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone()
dataset.plot(x, suptitle='Prediction')
plt.close()
2 changes: 1 addition & 1 deletion tests/datasets/test_idtrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_plot(self, dataset: IDTReeS) -> None:
plt.close()

if 'bbox_xyxy' in x:
x['prediction_boxes'] = x['bbox_xyxy']
x['prediction_bbox_xyxy'] = x['bbox_xyxy']
dataset.plot(x, show_titles=True)
plt.close()
if 'label' in x:
Expand Down
2 changes: 1 addition & 1 deletion tests/datasets/test_nasa_marine_debris.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,6 @@ def test_plot(self, dataset: NASAMarineDebris) -> None:
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
x['prediction_boxes'] = x['bbox_xyxy'].clone()
x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone()
dataset.plot(x)
plt.close()
2 changes: 1 addition & 1 deletion tests/datasets/test_reforestree.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,6 @@ def test_plot(self, dataset: ReforesTree) -> None:

def test_plot_prediction(self, dataset: ReforesTree) -> None:
x = dataset[0].copy()
x['prediction_boxes'] = x['bbox_xyxy'].clone()
x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone()
dataset.plot(x, suptitle='Prediction')
plt.close()
2 changes: 1 addition & 1 deletion tests/datasets/test_vhr10.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_plot(self, dataset: VHR10) -> None:
for i in range(3):
x = dataset[i]
x['prediction_labels'] = x['label']
x['prediction_boxes'] = x['bbox_xyxy']
x['prediction_bbox_xyxy'] = x['bbox_xyxy']
x['prediction_scores'] = torch.Tensor([scores[i]])
if 'mask' in x:
x['prediction_masks'] = x['mask']
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/fair1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def plot(
image = sample['image'].permute((1, 2, 0)).numpy()

ncols = 1
if 'prediction_boxes' in sample:
if 'prediction_bbox_xyxy' in sample:
ncols += 1

fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
Expand All @@ -427,7 +427,7 @@ def plot(
axs[1].axis('off')
polygons = [
patches.Polygon(points, color='r', fill=False)
for points in sample['prediction_boxes'].numpy()
for points in sample['prediction_bbox_xyxy'].numpy()
]
for polygon in polygons:
axs[0].add_patch(polygon)
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/forestdamage.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def plot(
image = sample['image'].permute((1, 2, 0)).numpy()

ncols = 1
showing_predictions = 'prediction_boxes' in sample
showing_predictions = 'prediction_bbox_xyxy' in sample
if showing_predictions:
ncols += 1

Expand Down Expand Up @@ -305,7 +305,7 @@ def plot(
edgecolor='r',
facecolor='none',
)
for bbox in sample['prediction_boxes'].numpy()
for bbox in sample['prediction_bbox_xyxy'].numpy()
]
for bbox in pred_bboxes:
axs[1].add_patch(bbox)
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/idtrees.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,15 +521,15 @@ def normalize(x: Tensor) -> Tensor:
else:
image = sample['image'].permute((1, 2, 0)).numpy()

if 'prediction_boxes' in sample and len(sample['prediction_boxes']):
if 'prediction_bbox_xyxy' in sample and len(sample['prediction_bbox_xyxy']):
ncols += 1
labels = (
[self.idx2class[int(i)] for i in sample['prediction_label']]
if 'prediction_label' in sample
else None
)
preds = draw_bounding_boxes(
image=sample['image'], boxes=sample['prediction_boxes'], labels=labels
image=sample['image'], boxes=sample['prediction_bbox_xyxy'], labels=labels
)
preds = preds.permute((1, 2, 0)).numpy()

Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/nasa_marine_debris.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ def plot(
)
image_arr = image.permute((1, 2, 0)).numpy()

if 'prediction_boxes' in sample and len(sample['prediction_boxes']):
if 'prediction_bbox_xyxy' in sample and len(sample['prediction_bbox_xyxy']):
ncols += 1
preds = draw_bounding_boxes(
image=sample['image'], boxes=sample['prediction_boxes']
image=sample['image'], boxes=sample['prediction_bbox_xyxy']
)
preds_arr = preds.permute((1, 2, 0)).numpy()

Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/reforestree.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def plot(
"""
image = sample['image'].permute((1, 2, 0)).numpy()
ncols = 1
showing_predictions = 'prediction_boxes' in sample
showing_predictions = 'prediction_bbox_xyxy' in sample
if showing_predictions:
ncols += 1

Expand Down Expand Up @@ -260,7 +260,7 @@ def plot(
edgecolor='r',
facecolor='none',
)
for bbox in sample['prediction_boxes'].numpy()
for bbox in sample['prediction_bbox_xyxy'].numpy()
]
for bbox in pred_bboxes:
axs[1].add_patch(bbox)
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datasets/vhr10.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,8 @@ def plot(
show_pred_masks = False
prediction_labels = sample['prediction_labels'].numpy()
prediction_scores = sample['prediction_scores'].numpy()
if 'prediction_boxes' in sample:
prediction_boxes = sample['prediction_boxes'].numpy()
if 'prediction_bbox_xyxy' in sample:
prediction_bbox_xyxy = sample['prediction_bbox_xyxy'].numpy()
show_pred_boxes = True
if 'prediction_masks' in sample:
prediction_masks = sample['prediction_masks'].numpy()
Expand Down Expand Up @@ -485,7 +485,7 @@ def plot(

if show_pred_boxes:
# Add bounding boxes
x1, y1, x2, y2 = prediction_boxes[i]
x1, y1, x2, y2 = prediction_bbox_xyxy[i]
r = patches.Rectangle(
(x1, y1),
x2 - x1,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/trainers/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def validation_step(
and hasattr(self.logger.experiment, 'add_figure')
):
datamodule = self.trainer.datamodule
batch['prediction_boxes'] = [b['boxes'].cpu() for b in y_hat]
batch['prediction_bbox_xyxy'] = [b['boxes'].cpu() for b in y_hat]
batch['prediction_labels'] = [b['labels'].cpu() for b in y_hat]
batch['prediction_scores'] = [b['scores'].cpu() for b in y_hat]
batch['image'] = batch['image'].cpu()
Expand Down

0 comments on commit 2e15665

Please sign in to comment.