Skip to content

Commit

Permalink
update pytorch detr unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Farhan Ahmed <[email protected]>
  • Loading branch information
f4str committed Nov 9, 2023
1 parent 7256f9c commit aec22fc
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 237 deletions.
3 changes: 2 additions & 1 deletion art/estimators/object_detection/pytorch_object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,8 @@ def compute_losses(
loss_components, _ = self._get_losses(x=x, y=y)
output = {}
for key, value in loss_components.items():
output[key] = value.detach().cpu().numpy()
if key in self.attack_losses:
output[key] = value.detach().cpu().numpy()
return output

def compute_loss( # type: ignore
Expand Down
45 changes: 45 additions & 0 deletions tests/estimators/object_detection/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,48 @@ def forward(self, x, targets=None):
]

yield object_detector, x_test, y_test


@pytest.fixture()
def get_pytorch_detr(get_default_cifar10_subset):
import cv2

from art.estimators.object_detection.pytorch_detection_transformer import PyTorchDetectionTransformer

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

object_detector = PyTorchDetectionTransformer(
input_shape=(3, 800, 800),
clip_values=(0, 1),
preprocessing=(MEAN, STD),
channels_first=True,
attack_losses=("loss_ce", "loss_bbox", "loss_giou"),
)

(_, _), (x_test_cifar10, _) = get_default_cifar10_subset

x_test = cv2.resize(
x_test_cifar10[0].transpose((1, 2, 0)), dsize=(800, 800), interpolation=cv2.INTER_CUBIC
).transpose((2, 0, 1))
x_test = np.expand_dims(x_test, axis=0)
x_test = np.repeat(x_test, repeats=2, axis=0)

# Create labels

result = object_detector.predict(x=x_test)

y_test = [
{
"boxes": result[0]["boxes"],
"labels": result[0]["labels"],
"scores": np.ones_like(result[0]["labels"]),
},
{
"boxes": result[1]["boxes"],
"labels": result[1]["labels"],
"scores": np.ones_like(result[1]["labels"]),
},
]

yield object_detector, x_test, y_test
Loading

0 comments on commit aec22fc

Please sign in to comment.