Skip to content

Commit

Permalink
Address code review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Lei Hsiung <[email protected]>
  • Loading branch information
twweeb committed Sep 14, 2023
1 parent 54c5c79 commit cd803c3
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 69 deletions.
2 changes: 1 addition & 1 deletion art/attacks/evasion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from art.attacks.evasion.brendel_bethge import BrendelBethgeAttack

from art.attacks.evasion.boundary import BoundaryAttack
from art.attacks.evasion.composite_adversarial_attack import CompositeAdversarialAttack
from art.attacks.evasion.composite_adversarial_attack import CompositeAdversarialAttackPyTorch
from art.attacks.evasion.carlini import CarliniL2Method, CarliniLInfMethod, CarliniL0Method
from art.attacks.evasion.decision_tree_attack import DecisionTreeAttack
from art.attacks.evasion.deepfool import DeepFool
Expand Down
186 changes: 144 additions & 42 deletions art/attacks/evasion/composite_adversarial_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import logging

from typing import Optional, Tuple, TYPE_CHECKING
from typing import Optional, Tuple, List, TYPE_CHECKING

import numpy as np
from tqdm.auto import tqdm
Expand All @@ -45,19 +45,19 @@
if TYPE_CHECKING:
# pylint: disable=C0412
import torch
import torch.nn.functional as F
from art.estimators.classification.pytorch import PyTorchClassifier
from math import pi

from math import pi

logger = logging.getLogger(__name__)


class CompositeAdversarialAttackPyTorch(EvasionAttack):
"""
Implementation of the composite adversarial attack on image classifiers in PyTorch. The attack is constructed by adversarially
perturbing the hue component of the inputs. It uses the iterative gradient sign method to optimise the semantic
perturbations (see `FastGradientMethod` and `BasicIterativeMethod`). This implementation extends the original
optimisation method to other norms as well.
Implementation of the composite adversarial attack on image classifiers in PyTorch. The attack is constructed by
adversarially perturbing the hue component of the inputs. It uses order scheduling to search for the attack sequence
and uses the iterative gradient sign method to optimize the perturbations in semantic space and Lp-ball (see
`FastGradientMethod` and `BasicIterativeMethod`).
Note that this attack is intended for only PyTorch image classifiers with RGB images in the range [0, 1] as inputs.
Expand Down Expand Up @@ -86,12 +86,12 @@ def __init__(
classifier: "PyTorchClassifier",
enabled_attack: Tuple = (0, 1, 2, 3, 4, 5),
# Default: Full Attacks; 0: Hue, 1: Saturation, 2: Rotation, 3: Brightness, 4: Contrast, 5: PGD (L-infinity)
hue_epsilon: Tuple = (-pi, pi),
sat_epsilon: Tuple = (0.7, 1.3),
rot_epsilon: Tuple = (-10, 10),
bri_epsilon: Tuple = (-0.2, 0.2),
con_epsilon: Tuple = (0.7, 1.3),
pgd_epsilon: Tuple = (-8 / 255, 8 / 255), # L-infinity
hue_epsilon: List = [-pi, pi],
sat_epsilon: List = [0.7, 1.3],
rot_epsilon: List = [-10, 10],
bri_epsilon: List = [-0.2, 0.2],
con_epsilon: List = [0.7, 1.3],
pgd_epsilon: List = [-8 / 255, 8 / 255], # L-infinity
early_stop: bool = True,
max_iter: int = 5,
max_inner_iter: int = 10,
Expand All @@ -103,7 +103,11 @@ def __init__(
Create an instance of the :class:`.CompositeAdversarialAttackPyTorch`.
:param classifier: A trained PyTorch classifier.
:param enabled_attack: The norm of the adversarial perturbation. Possible values: `"inf"`, `np.inf`, `1` or `2`.
:param enabled_attack: Attack pool selection, and attack order designation for fixed order. For simplicity,
we use the following abbreviations to specify each attack types. 0: Hue, 1: Saturation,
2: Rotation, 3: Brightness, 4: Contrast, 5: PGD(L-infinity). Therefore, `(0,1,2)` means
that the attack combines hue, saturation, and rotation; `(0,1,2,3,4)` means the
semantic attacks; `(0,1,2,3,4,5)` means the full attacks.
:param hue_epsilon: The boundary of the hue perturbation. The value is expected to be in the interval
`[-pi, pi]`. Perturbation of `0` means no shift and `-pi` and `pi` give a complete reversal
of the hue channel in the HSV colour space in the positive and negative directions,
Expand Down Expand Up @@ -142,39 +146,86 @@ def __init__(
self.device = next(self.model.parameters()).device
self.fixed_order = enabled_attack
self.enabled_attack = tuple(sorted(enabled_attack))
self.seq_num = len(enabled_attack) # attack_num
self.epsilons = [hue_epsilon, sat_epsilon, rot_epsilon, bri_epsilon, con_epsilon, pgd_epsilon]
self.early_stop = early_stop
self.linf_idx = self.enabled_attack.index(5) if 5 in self.enabled_attack else None
self.eps_pool = torch.tensor(
[hue_epsilon, sat_epsilon, rot_epsilon, bri_epsilon, con_epsilon, pgd_epsilon], device=self.device)
self.attack_order = attack_order
self.max_inner_iter = max_inner_iter
self.max_iter = max_iter if self.attack_order == 'scheduled' else 1
self.max_inner_iter = max_inner_iter
self.targeted = False
self.batch_size = batch_size
self.verbose = verbose
self.attack_pool = (
self.caa_hue, self.caa_saturation, self.caa_rotation, self.caa_brightness, self.caa_contrast, self.caa_linf)
self._check_params()

import kornia
self.seq_num = len(self.enabled_attack) # attack_num
self.linf_idx = self.enabled_attack.index(5) if 5 in self.enabled_attack else None
self.attack_pool = (
self.caa_hue, self.caa_saturation, self.caa_rotation, self.caa_brightness, self.caa_contrast, self.caa_linf)
self.eps_pool = torch.tensor(self.epsilons, device=self.device)
self.attack_pool_base = (
kornia.enhance.adjust_hue, kornia.enhance.adjust_saturation, kornia.geometry.transform.rotate,
kornia.enhance.adjust_brightness, kornia.enhance.adjust_contrast, self.get_linf_perturbation)
self.attack_dict = tuple([self.attack_pool[i] for i in self.enabled_attack])
self.step_size_pool = [2.5 * ((eps[1] - eps[0]) / 2) / self.max_inner_iter for eps in
self.eps_pool] # 2.5 * ε-test / num_steps

self._check_params()
self._description = "Composite Adversarial Attack"
self._is_scheduling = False
self.adv_val_pool = self.eps_space = self.adv_val_space = self.curr_dsm = \
self.curr_seq = self.is_attacked = self.is_not_attacked = None

def _check_params(self) -> None:
super()._check_params()
if not isinstance(self.enabled_attack, tuple) or not all(
value in [0, 1, 2, 3, 4, 5] for value in self.enabled_attack):
raise ValueError(
"The parameter `enabled_attack` must be a tuple specifying the attack to launch. For simplicity, we use"
+ " the following abbreviations to specify each attack types. 0: Hue, 1: Saturation, 2: Rotation, 3: Br"
+ "ightness, 4: Contrast, 5: PGD(L-infinity). Therefore, `(0,1,2)` means that the attack combines hue, "
+ "saturation, and rotation; `(0,1,2,3,4)` means the all semantic attacks; `(0,1,2,3,4,5)` means the fu"
+ "ll attacks.")
_epsilons_range = [["hue_epsilon", [-np.pi, np.pi], "[-np.pi, np.pi]"],
["sat_epsilon", [0, np.inf], "[0, np.inf]"], ["rot_epsilon", [-360, 360], "[-360, 360]"],
["bri_epsilon", [-1, 1], "[-1, 1]"], ["con_epsilon", [0, np.inf], "[0, np.inf]"],
["pgd_epsilon", [-1, 1], "[-1, 1]"]]
for i in range(6):
if (not isinstance(self.epsilons[i], list) or
not len(self.epsilons[i]) == 2 or
not _epsilons_range[i][1][0] <= self.epsilons[i][0] <= self.epsilons[i][1] <= _epsilons_range[i][1][1]):
logger.info(
"The argument `" + _epsilons_range[i][0] + "` must be an interval within " + _epsilons_range[i][2]
+ " of type list.")
raise ValueError(
"The argument `" + _epsilons_range[i][0] + "` must be an interval within " + _epsilons_range[i][2]
+ " of type list.")

if not isinstance(self.early_stop, bool):
logger.info("The flag `early_stop` has to be of type bool.")
raise ValueError("The flag `early_stop` has to be of type bool.")

if not isinstance(self.targeted, bool):
logger.info("The flag `targeted` has to be of type bool.")
raise ValueError("The flag `targeted` has to be of type bool.")

if not isinstance(self.max_iter, int) or self.max_iter <= 0:
logger.info("The argument `max_iter` must be positive of type int.")
raise ValueError("The argument `max_iter` must be positive of type int.")

if not isinstance(self.max_inner_iter, int):
logger.info("The argument `max_inner_iter` must be positive of type int.")
raise TypeError("The argument `max_inner_iter` must be positive of type int.")

if self.attack_order not in ('fixed', 'random', 'scheduled'):
logger.info("attack_order: {}, should be either 'fixed', 'random', or 'scheduled'.".format(self.attack_order))
raise ValueError
logger.info("The argument `attack_order` should be either `fixed`, `random`, or `scheduled`.")
raise ValueError("The argument `attack_order` should be either `fixed`, `random`, or `scheduled`.")

if self.batch_size <= 0:
logger.info("The batch size has to be positive.")
raise ValueError("The batch size has to be positive.")

if not isinstance(self.verbose, bool):
logger.info("The argument `verbose` has to be a Boolean.")
raise ValueError("The argument `verbose` has to be a Boolean.")

def _set_targets(
self,
Expand Down Expand Up @@ -214,7 +265,7 @@ def _set_targets(
return targets

def _setup_attack(self):
import torch
import torch

hue_space = torch.rand(self.batch_size, device=self.device) * (
self.eps_pool[0][1] - self.eps_pool[0][0]) + self.eps_pool[0][0]
Expand All @@ -238,7 +289,7 @@ def generate(
y: Optional[np.ndarray] = None,
**kwargs
) -> np.ndarray:
import torch
import torch

targets = self._set_targets(x, y)
dataset = torch.utils.data.TensorDataset(
Expand Down Expand Up @@ -298,9 +349,16 @@ def _generate_batch(
x, y = x.to(self.device), y.to(self.device)

return self.caa_attack(x, y).cpu().detach().numpy()

def _comp_pgd(self, data, labels, attack_idx, attack_parameter, ori_is_attacked):
import torch
def _comp_pgd(
self,
data: "torch.Tensor",
labels: "torch.Tensor",
attack_idx: "torch.Tensor",
attack_parameter: "torch.Tensor",
ori_is_attacked: "torch.Tensor"
) -> Tuple["torch.Tensor", "torch.Tensor"]:
import torch
import torch.nn.functional as F

adv_data = self.attack_pool_base[attack_idx](data, attack_parameter)
for _ in range(self.max_inner_iter):
Expand All @@ -323,7 +381,12 @@ def _comp_pgd(self, data, labels, attack_idx, attack_parameter, ori_is_attacked)

return adv_data, attack_parameter

def caa_hue(self, data, hue, labels):
def caa_hue(
self,
data: "torch.Tensor",
hue: "torch.Tensor",
labels: "torch.Tensor"
) -> Tuple["torch.Tensor", "torch.Tensor"]:
hue = hue.detach().clone()
hue[self.is_attacked] = 0
hue.requires_grad_()
Expand All @@ -332,7 +395,12 @@ def caa_hue(self, data, hue, labels):
return self._comp_pgd(data=sur_data, labels=labels, attack_idx=0, attack_parameter=hue,
ori_is_attacked=self.is_attacked.clone())

def caa_saturation(self, data, saturation, labels):
def caa_saturation(
self,
data: "torch.Tensor",
saturation: "torch.Tensor",
labels: "torch.Tensor"
) -> Tuple["torch.Tensor", "torch.Tensor"]:
saturation = saturation.detach().clone()
saturation[self.is_attacked] = 1
saturation.requires_grad_()
Expand All @@ -341,7 +409,12 @@ def caa_saturation(self, data, saturation, labels):
return self._comp_pgd(data=sur_data, labels=labels, attack_idx=1, attack_parameter=saturation,
ori_is_attacked=self.is_attacked.clone())

def caa_rotation(self, data, theta, labels):
def caa_rotation(
self,
data: "torch.Tensor",
theta: "torch.Tensor",
labels: "torch.Tensor"
) -> Tuple["torch.Tensor", "torch.Tensor"]:
theta = theta.detach().clone()
theta[self.is_attacked] = 0
theta.requires_grad_()
Expand All @@ -350,7 +423,12 @@ def caa_rotation(self, data, theta, labels):
return self._comp_pgd(data=sur_data, labels=labels, attack_idx=2, attack_parameter=theta,
ori_is_attacked=self.is_attacked.clone())

def caa_brightness(self, data, brightness, labels):
def caa_brightness(
self,
data: "torch.Tensor",
brightness: "torch.Tensor",
labels: "torch.Tensor"
) -> Tuple["torch.Tensor", "torch.Tensor"]:
brightness = brightness.detach().clone()
brightness[self.is_attacked] = 0
brightness.requires_grad_()
Expand All @@ -359,7 +437,12 @@ def caa_brightness(self, data, brightness, labels):
return self._comp_pgd(data=sur_data, labels=labels, attack_idx=3, attack_parameter=brightness,
ori_is_attacked=self.is_attacked.clone())

def caa_contrast(self, data, contrast, labels):
def caa_contrast(
self,
data: "torch.Tensor",
contrast: "torch.Tensor",
labels: "torch.Tensor"
) -> Tuple["torch.Tensor", "torch.Tensor"]:
contrast = contrast.detach().clone()
contrast[self.is_attacked] = 1
contrast.requires_grad_()
Expand All @@ -368,8 +451,13 @@ def caa_contrast(self, data, contrast, labels):
return self._comp_pgd(data=sur_data, labels=labels, attack_idx=4, attack_parameter=contrast,
ori_is_attacked=self.is_attacked.clone())

def caa_linf(self, data, labels):
import torch
def caa_linf(
self,
data: "torch.Tensor",
labels: "torch.Tensor"
) -> "torch.Tensor":
import torch
import torch.nn.functional as F

sur_data = data.detach()
adv_data = data.detach().requires_grad_()
Expand All @@ -393,13 +481,23 @@ def caa_linf(self, data, labels):

return adv_data

def get_linf_perturbation(self, data, noise):
import torch
def get_linf_perturbation(
self,
data: "torch.Tensor",
noise: "torch.Tensor"
) -> "torch.Tensor":
import torch

return torch.clamp(data + noise, 0.0, 1.0)

def update_attack_order(self, images, labels, adv_val=None):
import torch
def update_attack_order(
self,
images: "torch.Tensor",
labels: "torch.Tensor",
adv_val: Optional["torch.Tensor"] = None
) -> None:
import torch
import torch.nn.functional as F

def hungarian(matrix_batch):
sol = torch.tensor([-i for i in range(1, matrix_batch.shape[0] + 1)], dtype=torch.int32)
Expand Down Expand Up @@ -458,8 +556,12 @@ def sinkhorn_normalization(ori_dsm, n_iters=20):
else:
raise ValueError()

def caa_attack(self, images, labels):
import torch
def caa_attack(
self,
images: "torch.Tensor",
labels: "torch.Tensor"
) -> "torch.Tensor":
import torch

attack = self.attack_dict
adv_img = images.detach().clone()
Expand Down
Loading

0 comments on commit cd803c3

Please sign in to comment.