Skip to content

Commit

Permalink
Switch class -> label
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Jan 6, 2025
1 parent cdcf552 commit abc1d05
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions tests/datasets/test_vhr10.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_getitem(self, dataset: VHR10) -> None:
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
if dataset.split == 'positive':
assert isinstance(x['class'], torch.Tensor)
assert isinstance(x['label'], torch.Tensor)
assert isinstance(x['bbox_xyxy'], torch.Tensor)
if 'mask' in x:
assert isinstance(x['mask'], torch.Tensor)
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_plot(self, dataset: VHR10) -> None:
scores = [0.7, 0.3, 0.7]
for i in range(3):
x = dataset[i]
x['prediction_labels'] = x['class']
x['prediction_labels'] = x['label']
x['prediction_boxes'] = x['bbox_xyxy']
x['prediction_scores'] = torch.Tensor([scores[i]])
if 'mask' in x:
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def collate_fn_detection(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
output: dict[str, Any] = {}
output['image'] = torch.stack([sample['image'] for sample in batch])
output['bbox_xyxy'] = [sample['bbox_xyxy'].float() for sample in batch]
if 'class' in batch[0].keys():
output['class'] = [sample['class'] for sample in batch]
if 'label' in batch[0].keys():
output['label'] = [sample['label'] for sample in batch]
else:
output['class'] = [
output['label'] = [
torch.tensor([1] * len(sample['bbox_xyxy'])) for sample in batch
]

Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/vhr10.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
sample['class'] = sample['label']['labels']
sample['bbox_xyxy'] = sample['label']['boxes']
sample['mask'] = sample['label']['masks'].float()
del sample['label']
sample['label'] = sample.pop('class')

if self.transforms is not None:
sample = self.transforms(sample)
Expand Down Expand Up @@ -402,7 +402,7 @@ def plot(

image = sample['image'].permute(1, 2, 0).numpy()
boxes = sample['bbox_xyxy'].cpu().numpy()
labels = sample['class'].cpu().numpy()
labels = sample['label'].cpu().numpy()
if 'mask' in sample:
masks = [mask.squeeze().cpu().numpy() for mask in sample['mask']]

Expand Down
6 changes: 3 additions & 3 deletions torchgeo/trainers/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def training_step(
batch_size = x.shape[0]
assert 'bbox_xyxy' in batch, 'bbox_xyxy is required for object detection.'
y = [
{'boxes': batch['bbox_xyxy'][i], 'labels': batch['class'][i]}
{'boxes': batch['bbox_xyxy'][i], 'labels': batch['label'][i]}
for i in range(batch_size)
]
loss_dict = self(x, y)
Expand All @@ -262,7 +262,7 @@ def validation_step(
batch_size = x.shape[0]
assert 'bbox_xyxy' in batch, 'bbox_xyxy is required for object detection.'
y = [
{'boxes': batch['bbox_xyxy'][i], 'labels': batch['class'][i]}
{'boxes': batch['bbox_xyxy'][i], 'labels': batch['label'][i]}
for i in range(batch_size)
]
y_hat = self(x)
Expand Down Expand Up @@ -317,7 +317,7 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
batch_size = x.shape[0]
assert 'bbox_xyxy' in batch, 'bbox_xyxy is required for object detection.'
y = [
{'boxes': batch['bbox_xyxy'][i], 'labels': batch['class'][i]}
{'boxes': batch['bbox_xyxy'][i], 'labels': batch['label'][i]}
for i in range(batch_size)
]
y_hat = self(x)
Expand Down

0 comments on commit abc1d05

Please sign in to comment.