Merge branch 'dev_1.17.0' into dev_1.17.0_no_labels_mi
beat-buesser authored Dec 18, 2023
2 parents b2deaee + cc03386 commit f4a4fa6
Showing 7 changed files with 1,186 additions and 17 deletions.
1 change: 1 addition & 0 deletions art/attacks/evasion/__init__.py
@@ -18,6 +18,7 @@
from art.attacks.evasion.brendel_bethge import BrendelBethgeAttack

from art.attacks.evasion.boundary import BoundaryAttack
from art.attacks.evasion.composite_adversarial_attack import CompositeAdversarialAttackPyTorch
from art.attacks.evasion.carlini import CarliniL2Method, CarliniLInfMethod, CarliniL0Method
from art.attacks.evasion.decision_tree_attack import DecisionTreeAttack
from art.attacks.evasion.deepfool import DeepFool
673 changes: 673 additions & 0 deletions art/attacks/evasion/composite_adversarial_attack.py

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions art/estimators/classification/hugging_face.py
@@ -156,6 +156,12 @@ def _make_model_wrapper(self, model: "torch.nn.Module") -> "torch.nn.Module":

input_shape = self._input_shape
input_for_hook = torch.rand(input_shape)
# self.device may not match the device on which the raw model was passed into ART.
# Check if the model is on CUDA; if so, move the hook input to the same device.
if next(model.parameters()).is_cuda:
cuda_idx = torch.cuda.current_device()
input_for_hook = input_for_hook.to(torch.device(f"cuda:{cuda_idx}"))

input_for_hook = torch.unsqueeze(input_for_hook, dim=0)

if self.processor is not None:
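For context, a standalone sketch of the device-matching pattern introduced above; the linear layer is just an illustrative stand-in for the wrapped model:

import torch

model = torch.nn.Linear(4, 2)  # illustrative stand-in for the wrapped model
input_for_hook = torch.rand((4,))

# Put the hook input on the same device as the model's parameters, since the
# estimator's device attribute may disagree with the device of the raw model.
if next(model.parameters()).is_cuda:
    cuda_idx = torch.cuda.current_device()
    input_for_hook = input_for_hook.to(torch.device(f"cuda:{cuda_idx}"))

input_for_hook = torch.unsqueeze(input_for_hook, dim=0)  # add a batch dimension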
5 changes: 5 additions & 0 deletions notebooks/README.md
@@ -108,6 +108,11 @@ demonstrates a MembershipInferenceBlackBox membership inference attack using sha
[label_only_membership_inference.ipynb](label_only_membership_inference.ipynb) [[on nbviewer](https://nbviewer.org/github/Trusted-AI/adversarial-robustness-toolbox/blob/main/notebooks/label_only_membership_inference.ipynb)]
demonstrates a LabelOnlyDecisionBoundary membership inference attack on a PyTorch classifier for the MNIST dataset.

[composite-adversarial-attack.ipynb](composite-adversarial-attack.ipynb) [[on nbviewer](https://nbviewer.org/github/Trusted-AI/adversarial-robustness-toolbox/blob/main/notebooks/composite-adversarial-attack.ipynb)]
shows how to launch a Composite Adversarial Attack (CAA) on a PyTorch-based model ([Hsiung et al., 2023](https://arxiv.org/abs/2202.04235)).
CAA composes perturbations in the Lp-ball and in semantic space (i.e., hue, saturation, rotation, brightness, and contrast),
and can optimize both the attack sequence and each attack component, thereby improving the efficiency and efficacy of the resulting adversarial examples.
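A minimal usage sketch, assuming ART's CIFAR-10 test classifier (the same helper the test suite in this commit uses) and illustrative random inputs:

```python
import numpy as np

from art.attacks.evasion import CompositeAdversarialAttackPyTorch
from tests.utils import get_cifar10_image_classifier_pt  # test-suite helper

# Illustrative batch of CIFAR-10-shaped inputs in [0, 1]; the layout is
# assumed to match what the test classifier expects.
x = np.random.rand(4, 32, 32, 3).astype(np.float32)
y = np.eye(10)[np.random.randint(0, 10, size=4)].astype(np.float32)

classifier = get_cifar10_image_classifier_pt(from_logits=False, load_init=True)
attack = CompositeAdversarialAttackPyTorch(classifier)
x_adv = attack.generate(x=x, y=y)  # same shape as x, values stay in [0, 1]
```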

## Metrics

[privacy_metric.ipynb](privacy_metric.ipynb) [[on nbviewer](https://nbviewer.jupyter.org/github/Trusted-AI/adversarial-robustness-toolbox/blob/main/notebooks/privacy_metric.ipynb)]
290 changes: 290 additions & 0 deletions notebooks/composite-adversarial-attack.ipynb

Large diffs are not rendered by default.

36 changes: 19 additions & 17 deletions notebooks/huggingface_notebook.ipynb

Large diffs are not rendered by default.

192 changes: 192 additions & 0 deletions tests/attacks/evasion/test_composite_adversarial_attack.py
@@ -0,0 +1,192 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import logging

import numpy as np
import pytest

from art.attacks.evasion import CompositeAdversarialAttackPyTorch
from art.estimators.estimator import BaseEstimator, LossGradientsMixin
from art.estimators.classification.classifier import ClassifierMixin

from tests.attacks.utils import backend_test_classifier_type_check_fail
from tests.utils import ARTTestException, get_cifar10_image_classifier_pt

logger = logging.getLogger(__name__)


@pytest.fixture()
def fix_get_cifar10_subset(get_cifar10_dataset):
(x_train_cifar10, y_train_cifar10), (x_test_cifar10, y_test_cifar10) = get_cifar10_dataset
n_train = 100
n_test = 11
yield x_train_cifar10[:n_train], y_train_cifar10[:n_train], x_test_cifar10[:n_test], y_test_cifar10[:n_test]


@pytest.mark.skip_framework(
"tensorflow1", "tensorflow2", "tensorflow2v1", "keras", "non_dl_frameworks", "mxnet", "kerastf", "huggingface"
)
def test_generate(art_warning, fix_get_cifar10_subset):
try:
(x_train, y_train, x_test, y_test) = fix_get_cifar10_subset

classifier = get_cifar10_image_classifier_pt(from_logits=False, load_init=True)
attack = CompositeAdversarialAttackPyTorch(classifier)

x_train_adv = attack.generate(x=x_train, y=y_train)
x_test_adv = attack.generate(x=x_test, y=y_test)

assert x_train.shape == x_train_adv.shape
assert np.min(x_train_adv) >= 0.0
assert np.max(x_train_adv) <= 1.0
assert x_test.shape == x_test_adv.shape
assert np.min(x_test_adv) >= 0.0
assert np.max(x_test_adv) <= 1.0

except ARTTestException as e:
art_warning(e)


@pytest.mark.skip_framework(
"tensorflow1", "tensorflow2", "tensorflow2v1", "keras", "non_dl_frameworks", "mxnet", "kerastf"
)
def test_check_params(art_warning):
try:
classifier = get_cifar10_image_classifier_pt(from_logits=False, load_init=True)

with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, enabled_attack=(0, 1, 2, 3, 4, 5, 6, 7))

with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, hue_epsilon=(-10.0, 0.0))
with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, hue_epsilon=(0.0, 10.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, hue_epsilon=(-1, 2.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, hue_epsilon=3.14)
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, hue_epsilon=(0.0, 10.0, 20.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, hue_epsilon=("1.0", 2.0))

with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, sat_epsilon=(-10.0, 0.0))
with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, sat_epsilon=(0.0, -10.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, sat_epsilon=(1, 2.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, sat_epsilon=2.0)
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, sat_epsilon=(0.0, 10.0, 20.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, sat_epsilon=("1.0", 2.0))

with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, rot_epsilon=(-450.0, 359.0))
with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, rot_epsilon=(10.0, -10.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, rot_epsilon=(1.0, 2))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, rot_epsilon=10)
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, rot_epsilon=(0.0, 10.0, 20.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, rot_epsilon=("10", 20.0))

with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, bri_epsilon=(-10.0, 0.0))
with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, bri_epsilon=(0.0, 10.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, bri_epsilon=(-1, 1.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, bri_epsilon=1.0)
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, bri_epsilon=(0.0, 10.0, 20.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, bri_epsilon=("1.0", 2.0))

with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, con_epsilon=(-10.0, 10.0))
with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, con_epsilon=(0.0, -10.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, con_epsilon=(1, 2.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, con_epsilon=2.0)
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, con_epsilon=(0.0, 10.0, 20.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, con_epsilon=("1.0", 2.0))

with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, pgd_epsilon=(-0.5, 2.0))
with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, pgd_epsilon=(8 / 255, -8 / 255))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, pgd_epsilon=(-2, 1))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, pgd_epsilon=8 / 255)
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, pgd_epsilon=(0.0, 10.0, 20.0))
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, pgd_epsilon=("2/255", 3 / 255))

with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, early_stop="true")
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, early_stop=1)

with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, max_iter="max")
with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, max_iter=-5)
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, max_iter=2.5)

with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, max_inner_iter="max")
with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, max_inner_iter=-5)
with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, max_inner_iter=2.5)

with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, attack_order="schedule")

with pytest.raises(ValueError):
_ = CompositeAdversarialAttackPyTorch(classifier, batch_size=-1)

with pytest.raises(TypeError):
_ = CompositeAdversarialAttackPyTorch(classifier, verbose="true")

except ARTTestException as e:
art_warning(e)


@pytest.mark.framework_agnostic
def test_classifier_type_check_fail(art_warning):
try:
backend_test_classifier_type_check_fail(
CompositeAdversarialAttackPyTorch, [BaseEstimator, LossGradientsMixin, ClassifierMixin]
)
except ARTTestException as e:
art_warning(e)

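Read together, the checks in test_check_params imply the parameter contract: each *_epsilon argument is a 2-tuple of floats (low, high) with low <= high and both bounds inside the component's valid range; max_iter and max_inner_iter are integers (negative values are rejected); early_stop and verbose are booleans. A hypothetical configuration consistent with those checks, with classifier and np as in the tests above (the chosen values are illustrative assumptions, not defaults confirmed by this diff):

# Hypothetical, illustrative configuration; parameter names come from
# test_check_params above, and the values are only assumed to fall inside
# each component's valid range.
attack = CompositeAdversarialAttackPyTorch(
    classifier,
    hue_epsilon=(-np.pi, np.pi),     # float 2-tuple, low <= high
    rot_epsilon=(-10.0, 10.0),       # rotation range in degrees
    pgd_epsilon=(-8 / 255, 8 / 255),
    early_stop=True,                 # must be bool, not "true" or 1
    max_iter=5,                      # integer iteration budget
    verbose=False,
)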