From aec22fc62f3f570e18cf6773999452853ad4c16c Mon Sep 17 00:00:00 2001 From: Farhan Ahmed Date: Wed, 8 Nov 2023 16:26:41 -0800 Subject: [PATCH] update pytorch detr unit tests Signed-off-by: Farhan Ahmed --- .../pytorch_object_detector.py | 3 +- tests/estimators/object_detection/conftest.py | 45 ++ .../test_pytorch_detection_transformer.py | 458 +++++++++--------- .../object_detection/test_pytorch_yolo.py | 1 - 4 files changed, 270 insertions(+), 237 deletions(-) diff --git a/art/estimators/object_detection/pytorch_object_detector.py b/art/estimators/object_detection/pytorch_object_detector.py index e6b1eebd23..bdaea4dbcb 100644 --- a/art/estimators/object_detection/pytorch_object_detector.py +++ b/art/estimators/object_detection/pytorch_object_detector.py @@ -544,7 +544,8 @@ def compute_losses( loss_components, _ = self._get_losses(x=x, y=y) output = {} for key, value in loss_components.items(): - output[key] = value.detach().cpu().numpy() + if key in self.attack_losses: + output[key] = value.detach().cpu().numpy() return output def compute_loss( # type: ignore diff --git a/tests/estimators/object_detection/conftest.py b/tests/estimators/object_detection/conftest.py index 5e4f600b71..8ca9c1c812 100644 --- a/tests/estimators/object_detection/conftest.py +++ b/tests/estimators/object_detection/conftest.py @@ -246,3 +246,48 @@ def forward(self, x, targets=None): ] yield object_detector, x_test, y_test + + +@pytest.fixture() +def get_pytorch_detr(get_default_cifar10_subset): + import cv2 + + from art.estimators.object_detection.pytorch_detection_transformer import PyTorchDetectionTransformer + + MEAN = [0.485, 0.456, 0.406] + STD = [0.229, 0.224, 0.225] + + object_detector = PyTorchDetectionTransformer( + input_shape=(3, 800, 800), + clip_values=(0, 1), + preprocessing=(MEAN, STD), + channels_first=True, + attack_losses=("loss_ce", "loss_bbox", "loss_giou"), + ) + + (_, _), (x_test_cifar10, _) = get_default_cifar10_subset + + x_test = cv2.resize( + x_test_cifar10[0].transpose((1, 2, 0)), dsize=(800, 800), interpolation=cv2.INTER_CUBIC + ).transpose((2, 0, 1)) + x_test = np.expand_dims(x_test, axis=0) + x_test = np.repeat(x_test, repeats=2, axis=0) + + # Create labels + + result = object_detector.predict(x=x_test) + + y_test = [ + { + "boxes": result[0]["boxes"], + "labels": result[0]["labels"], + "scores": np.ones_like(result[0]["labels"]), + }, + { + "boxes": result[1]["boxes"], + "labels": result[1]["labels"], + "scores": np.ones_like(result[1]["labels"]), + }, + ] + + yield object_detector, x_test, y_test diff --git a/tests/estimators/object_detection/test_pytorch_detection_transformer.py b/tests/estimators/object_detection/test_pytorch_detection_transformer.py index 308d119c1f..07ccaaa124 100644 --- a/tests/estimators/object_detection/test_pytorch_detection_transformer.py +++ b/tests/estimators/object_detection/test_pytorch_detection_transformer.py @@ -22,287 +22,275 @@ import numpy as np import pytest +from tests.utils import ARTTestException + logger = logging.getLogger(__name__) -@pytest.fixture() -@pytest.mark.skip_framework("tensorflow", "tensorflow2v1", "keras", "kerastf", "mxnet", "non_dl_frameworks") -def get_pytorch_detr(): - from art.utils import load_dataset - from art.estimators.object_detection.pytorch_detection_transformer import PyTorchDetectionTransformer +@pytest.mark.only_with_platform("pytorch") +def test_predict(art_warning, get_pytorch_detr): + try: + object_detector, x_test, _ = get_pytorch_detr + + result = object_detector.predict(x=x_test) + + assert list(result[0].keys()) == ["boxes", "labels", "scores"] + + assert result[0]["boxes"].shape == (100, 4) + expected_detection_boxes = np.asarray([-0.12423098, 361.80136, 82.385345, 795.50305]) + np.testing.assert_array_almost_equal(result[0]["boxes"][2, :], expected_detection_boxes, decimal=1) + + assert result[0]["scores"].shape == (100,) + expected_detection_scores = np.asarray( + [ + 0.00105285, + 0.00261505, + 0.00060220, + 0.00121928, + 0.00154554, + 0.00021678, + 0.00077083, + 0.00045684, + 0.00180561, + 0.00067704, + ] + ) + np.testing.assert_array_almost_equal(result[0]["scores"][:10], expected_detection_scores, decimal=1) - MEAN = [0.485, 0.456, 0.406] - STD = [0.229, 0.224, 0.225] - INPUT_SHAPE = (3, 32, 32) + assert result[0]["labels"].shape == (100,) + expected_detection_classes = np.asarray([1, 23, 23, 1, 1, 23, 23, 23, 1, 1]) + np.testing.assert_array_almost_equal(result[0]["labels"][:10], expected_detection_classes, decimal=1) - object_detector = PyTorchDetectionTransformer( - input_shape=INPUT_SHAPE, clip_values=(0, 1), preprocessing=(MEAN, STD) - ) + except ARTTestException as e: + art_warning(e) - n_test = 2 - (_, _), (x_test, y_test), _, _ = load_dataset("cifar10") - x_test = x_test.transpose(0, 3, 1, 2).astype(np.float32) - x_test = x_test[:n_test] - # Create labels +@pytest.mark.only_with_platform("pytorch") +def test_fit(art_warning, get_pytorch_yolo): + try: + import torch - result = object_detector.predict(x=x_test) + object_detector, x_test, y_test = get_pytorch_yolo - y_test = [ - { - "boxes": result[0]["boxes"], - "labels": result[0]["labels"], - "scores": np.ones_like(result[0]["labels"]), - }, - { - "boxes": result[1]["boxes"], - "labels": result[1]["labels"], - "scores": np.ones_like(result[1]["labels"]), - }, - ] + # Create optimizer + params = [p for p in object_detector.model.parameters() if p.requires_grad] + optimizer = torch.optim.SGD(params, lr=0.01) + object_detector.set_params(optimizer=optimizer) - yield object_detector, x_test, y_test + # Compute loss before training + loss1 = object_detector.compute_loss(x=x_test, y=y_test) + # Train for one epoch + object_detector.fit(x_test, y_test, nb_epochs=1) -@pytest.mark.only_with_platform("pytorch") -def test_predict(get_pytorch_detr): - - object_detector, x_test, _ = get_pytorch_detr - - result = object_detector.predict(x=x_test) - - assert list(result[0].keys()) == ["boxes", "labels", "scores"] - - assert result[0]["boxes"].shape == (100, 4) - expected_detection_boxes = np.asarray([-5.9490204e-03, 1.1947733e01, 3.1993944e01, 3.1925127e01]) - np.testing.assert_array_almost_equal(result[0]["boxes"][2, :], expected_detection_boxes, decimal=1) - - assert result[0]["scores"].shape == (100,) - expected_detection_scores = np.asarray( - [ - 0.00679839, - 0.0250559, - 0.07205943, - 0.01115368, - 0.03321039, - 0.10407761, - 0.00113309, - 0.01442852, - 0.00527624, - 0.01240906, - ] - ) - np.testing.assert_array_almost_equal(result[0]["scores"][:10], expected_detection_scores, decimal=1) + # Compute loss after training + loss2 = object_detector.compute_loss(x=x_test, y=y_test) - assert result[0]["labels"].shape == (100,) - expected_detection_classes = np.asarray([17, 17, 33, 17, 17, 17, 74, 17, 17, 17]) - np.testing.assert_array_almost_equal(result[0]["labels"][:10], expected_detection_classes, decimal=5) + assert loss1 != loss2 - -@pytest.mark.only_with_platform("pytorch") -def test_loss_gradient(get_pytorch_detr): - - object_detector, x_test, y_test = get_pytorch_detr - - grads = object_detector.loss_gradient(x=x_test, y=y_test) - - assert grads.shape == (2, 3, 800, 800) - - expected_gradients1 = np.asarray( - [ - -0.00061366, - 0.00322502, - -0.00039866, - -0.00807413, - -0.00476555, - 0.00181204, - 0.01007765, - 0.00415828, - -0.00073114, - 0.00018387, - -0.00146992, - -0.00119636, - -0.00098966, - -0.00295517, - -0.0024271, - -0.00131314, - -0.00149217, - -0.00104926, - -0.00154239, - -0.00110989, - 0.00092887, - 0.00049146, - -0.00292508, - -0.00124526, - 0.00140347, - 0.00019833, - 0.00191074, - -0.00117537, - -0.00080604, - 0.00057427, - -0.00061728, - -0.00206535, - ] - ) - - np.testing.assert_array_almost_equal(grads[0, 0, 10, :32], expected_gradients1, decimal=2) - - expected_gradients2 = np.asarray( - [ - -1.1787530e-03, - -2.8500680e-03, - 5.0884970e-03, - 6.4504531e-04, - -6.8841036e-05, - 2.8184296e-03, - 3.0257765e-03, - 2.8565727e-04, - -1.0701057e-04, - 1.2945699e-03, - 7.3593057e-04, - 1.0177144e-03, - -2.4692707e-03, - -1.3801848e-03, - 6.3182280e-04, - -4.2305476e-04, - 4.4307750e-04, - 8.5821096e-04, - -7.1204413e-04, - -3.1404425e-03, - -1.5964351e-03, - -1.9222996e-03, - -5.3157361e-04, - -9.9202688e-04, - -1.5815455e-03, - 2.0060266e-04, - -2.0584739e-03, - 6.6960667e-04, - 9.7393827e-04, - -1.6040013e-03, - -6.9741381e-04, - 1.4657658e-04, - ] - ) - np.testing.assert_array_almost_equal(grads[1, 0, 10, :32], expected_gradients2, decimal=2) + except ARTTestException as e: + art_warning(e) @pytest.mark.only_with_platform("pytorch") -def test_errors(): - - from torch import hub +def test_loss_gradient(art_warning, get_pytorch_detr): + try: + object_detector, x_test, y_test = get_pytorch_detr + + grads = object_detector.loss_gradient(x=x_test, y=y_test) + + assert grads.shape == (2, 3, 800, 800) + + expected_gradients1 = np.asarray( + [ + -0.00757495, + -0.00101332, + 0.00368362, + 0.00283334, + -0.00096027, + 0.00873749, + 0.00546095, + -0.00823532, + -0.00710872, + 0.00389713, + -0.00966289, + 0.00448294, + 0.00754991, + -0.00934104, + -0.00350194, + -0.00541577, + -0.00395624, + 0.00147651, + 0.0105616, + 0.01231265, + -0.00148831, + -0.0043609, + 0.00093031, + 0.00884939, + -0.00356749, + 0.00093475, + -0.00353712, + -0.0060132, + -0.00067899, + -0.00886974, + 0.00108483, + -0.00052412, + ] + ) - from art.estimators.object_detection.pytorch_detection_transformer import PyTorchDetectionTransformer + np.testing.assert_array_almost_equal(grads[0, 0, 10, :32], expected_gradients1, decimal=2) + + expected_gradients2 = np.asarray( + [ + -0.00757495, + -0.00101332, + 0.00368362, + 0.00283334, + -0.00096027, + 0.00873749, + 0.00546095, + -0.00823532, + -0.00710872, + 0.00389713, + -0.00966289, + 0.00448294, + 0.00754991, + -0.00934104, + -0.00350194, + -0.00541577, + -0.00395624, + 0.00147651, + 0.0105616, + 0.01231265, + -0.00148831, + -0.0043609, + 0.00093031, + 0.00884939, + -0.00356749, + 0.00093475, + -0.00353712, + -0.0060132, + -0.00067899, + -0.00886974, + 0.00108483, + -0.00052412, + ] + ) + np.testing.assert_array_almost_equal(grads[1, 0, 10, :32], expected_gradients2, decimal=2) - model = hub.load("facebookresearch/detr", "detr_resnet50", pretrained=True) + except ARTTestException as e: + art_warning(e) - with pytest.raises(ValueError): - PyTorchDetectionTransformer( - model=model, - clip_values=(1, 2), - attack_losses=("loss_ce", "loss_bbox", "loss_giou"), - ) - with pytest.raises(ValueError): - PyTorchDetectionTransformer( - model=model, - clip_values=(-1, 1), - attack_losses=("loss_ce", "loss_bbox", "loss_giou"), - ) +@pytest.mark.only_with_platform("pytorch") +def test_errors(art_warning): + try: + from torch import hub - from art.defences.postprocessor.rounded import Rounded + from art.estimators.object_detection.pytorch_detection_transformer import PyTorchDetectionTransformer - post_def = Rounded() - with pytest.raises(ValueError): - PyTorchDetectionTransformer( - model=model, - clip_values=(0, 1), - attack_losses=("loss_ce", "loss_bbox", "loss_giou"), - postprocessing_defences=post_def, - ) + model = hub.load("facebookresearch/detr", "detr_resnet50", pretrained=True) + with pytest.raises(ValueError): + PyTorchDetectionTransformer( + model=model, + clip_values=(1, 2), + attack_losses=("loss_ce", "loss_bbox", "loss_giou"), + ) -@pytest.mark.only_with_platform("pytorch") -def test_preprocessing_defences(get_pytorch_detr): + with pytest.raises(ValueError): + PyTorchDetectionTransformer( + model=model, + clip_values=(-1, 1), + attack_losses=("loss_ce", "loss_bbox", "loss_giou"), + ) - object_detector, x_test, _ = get_pytorch_detr + from art.defences.postprocessor.rounded import Rounded - from art.defences.preprocessor.spatial_smoothing_pytorch import SpatialSmoothingPyTorch + post_def = Rounded() + with pytest.raises(ValueError): + PyTorchDetectionTransformer( + model=model, + clip_values=(0, 1), + attack_losses=("loss_ce", "loss_bbox", "loss_giou"), + postprocessing_defences=post_def, + ) - pre_def = SpatialSmoothingPyTorch() + except ARTTestException as e: + art_warning(e) - object_detector.set_params(preprocessing_defences=pre_def) - # Create labels - result = object_detector.predict(x=x_test) +@pytest.mark.only_with_platform("pytorch") +def test_preprocessing_defences(art_warning, get_pytorch_detr): + try: + object_detector, x_test, _ = get_pytorch_detr + + from art.defences.preprocessor.spatial_smoothing_pytorch import SpatialSmoothingPyTorch + + pre_def = SpatialSmoothingPyTorch() + + object_detector.set_params(preprocessing_defences=pre_def) + + # Create labels + result = object_detector.predict(x=x_test) + + y = [ + { + "boxes": result[0]["boxes"], + "labels": result[0]["labels"], + "scores": np.ones_like(result[0]["labels"]), + }, + { + "boxes": result[1]["boxes"], + "labels": result[1]["labels"], + "scores": np.ones_like(result[1]["labels"]), + }, + ] - y = [ - { - "boxes": result[0]["boxes"], - "labels": result[0]["labels"], - "scores": np.ones_like(result[0]["labels"]), - }, - { - "boxes": result[1]["boxes"], - "labels": result[1]["labels"], - "scores": np.ones_like(result[1]["labels"]), - }, - ] + # Compute gradients + grads = object_detector.loss_gradient(x=x_test, y=y) - # Compute gradients - grads = object_detector.loss_gradient(x=x_test, y=y) + assert grads.shape == (2, 3, 800, 800) - assert grads.shape == (2, 3, 800, 800) + except ARTTestException as e: + art_warning(e) @pytest.mark.only_with_platform("pytorch") -def test_compute_losses(get_pytorch_detr): +def test_compute_losses(art_warning, get_pytorch_detr): + try: + object_detector, x_test, y_test = get_pytorch_detr + losses = object_detector.compute_losses(x=x_test, y=y_test) + assert len(losses) == 3 - object_detector, x_test, y_test = get_pytorch_detr - losses = object_detector.compute_losses(x=x_test, y=y_test) - assert len(losses) == 3 + except ARTTestException as e: + art_warning(e) @pytest.mark.only_with_platform("pytorch") -def test_compute_loss(get_pytorch_detr): +def test_compute_loss(art_warning, get_pytorch_detr): + try: + object_detector, x_test, y_test = get_pytorch_detr - object_detector, x_test, _ = get_pytorch_detr - # Create labels - result = object_detector.predict(x_test) + # Compute loss + loss = object_detector.compute_loss(x=x_test, y=y_test) - y = [ - { - "boxes": result[0]["boxes"], - "labels": result[0]["labels"], - "scores": np.ones_like(result[0]["labels"]), - }, - { - "boxes": result[1]["boxes"], - "labels": result[1]["labels"], - "scores": np.ones_like(result[1]["labels"]), - }, - ] + assert pytest.approx(6.7767677, abs=0.1) == float(loss) - # Compute loss - loss = object_detector.compute_loss(x=x_test, y=y) - - assert pytest.approx(3.9634, abs=0.01) == float(loss) + except ARTTestException as e: + art_warning(e) @pytest.mark.only_with_platform("pytorch") -def test_pgd(get_pytorch_detr): - - object_detector, x_test, y_test = get_pytorch_detr +def test_pgd(art_warning, get_pytorch_detr): + try: + from art.attacks.evasion import ProjectedGradientDescent - from art.attacks.evasion import ProjectedGradientDescent - from PIL import Image + object_detector, x_test, y_test = get_pytorch_detr - imgs = [] - for i in x_test: - img = Image.fromarray((i * 255).astype(np.uint8).transpose(1, 2, 0)) - img = img.resize(size=(800, 800)) - imgs.append(np.array(img)) - x_test = np.array(imgs).transpose(0, 3, 1, 2) + attack = ProjectedGradientDescent(estimator=object_detector, max_iter=2) + x_test_adv = attack.generate(x=x_test, y=y_test) + np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, x_test_adv, x_test) - attack = ProjectedGradientDescent(estimator=object_detector, max_iter=2) - x_test_adv = attack.generate(x=x_test, y=y_test) - np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, x_test_adv, x_test) + except ARTTestException as e: + art_warning(e) diff --git a/tests/estimators/object_detection/test_pytorch_yolo.py b/tests/estimators/object_detection/test_pytorch_yolo.py index a4d88e11bf..13c70ba92f 100644 --- a/tests/estimators/object_detection/test_pytorch_yolo.py +++ b/tests/estimators/object_detection/test_pytorch_yolo.py @@ -67,7 +67,6 @@ def test_predict(art_warning, get_pytorch_yolo): @pytest.mark.only_with_platform("pytorch") def test_fit(art_warning, get_pytorch_yolo): - try: object_detector, x_test, y_test = get_pytorch_yolo