Skip to content

Commit

Permalink
Merge pull request #2557 from Trusted-AI/development_typing_docs
Browse files Browse the repository at this point in the history
Improve typing and documentation of AdversarialPatchPyTorch
  • Loading branch information
beat-buesser authored Jan 21, 2025
2 parents d3280eb + a9b698e commit e20d091
Showing 1 changed file with 36 additions and 13 deletions.
49 changes: 36 additions & 13 deletions art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import logging
import math
from packaging.version import parse
from typing import Any, TYPE_CHECKING
from typing import Any, cast, TYPE_CHECKING

import numpy as np
from tqdm.auto import trange
Expand All @@ -41,7 +41,7 @@

import torch

from art.utils import CLASSIFIER_NEURALNETWORK_TYPE
from art.utils import CLASSIFIER_NEURALNETWORK_TYPE, PYTORCH_OBJECT_DETECTOR_TYPE

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -72,7 +72,7 @@ class AdversarialPatchPyTorch(EvasionAttack):

def __init__(
self,
estimator: "CLASSIFIER_NEURALNETWORK_TYPE",
estimator: "CLASSIFIER_NEURALNETWORK_TYPE | PYTORCH_OBJECT_DETECTOR_TYPE",
rotation_max: float = 22.5,
scale_min: float = 0.1,
scale_max: float = 1.0,
Expand All @@ -91,7 +91,7 @@ def __init__(
"""
Create an instance of the :class:`.AdversarialPatchPyTorch`.
:param estimator: A trained estimator.
:param estimator: A trained PyTorch estimator for classification or object detection.
:param rotation_max: The maximum rotation applied to random patches. The value is expected to be in the
range `[0, 180]`.
:param scale_min: The minimum scaling applied to random patches. The value should be in the range `[0, 1]`,
Expand Down Expand Up @@ -183,7 +183,10 @@ def __init__(
self._optimizer = torch.optim.Adam([self._patch], lr=self.learning_rate)

def _train_step(
self, images: "torch.Tensor", target: "torch.Tensor", mask: "torch.Tensor" | None = None
self,
images: "torch.Tensor",
target: "torch.Tensor" | list[dict[str, "torch.Tensor"]],
mask: "torch.Tensor" | None = None,
) -> "torch.Tensor":
import torch

Expand Down Expand Up @@ -227,7 +230,12 @@ def _predictions(

return predictions, target

def _loss(self, images: "torch.Tensor", target: "torch.Tensor", mask: "torch.Tensor" | None) -> "torch.Tensor":
def _loss(
self,
images: "torch.Tensor",
target: "torch.Tensor" | list[dict[str, "torch.Tensor"]],
mask: "torch.Tensor" | None,
) -> "torch.Tensor":
import torch

if isinstance(target, torch.Tensor):
Expand Down Expand Up @@ -475,13 +483,17 @@ def _random_overlay(
return patched_images

def generate( # type: ignore
self, x: np.ndarray, y: np.ndarray | None = None, **kwargs
self, x: np.ndarray, y: np.ndarray | list[dict[str, np.ndarray | "torch.Tensor"]] | None = None, **kwargs
) -> tuple[np.ndarray, np.ndarray]:
"""
Generate an adversarial patch and return the patch and its mask in arrays.
:param x: An array with the original input images of shape NCHW or input videos of shape NFCHW.
:param y: An array with the original true labels.
:param y: True or target labels of format `list[dict[str, Union[np.ndarray, torch.Tensor]]]`, one for each
input image. The fields of the dict are as follows:
- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
- labels [N]: the labels for each image.
:param mask: A boolean array of shape equal to the shape of a single samples (1, H, W) or the shape of `x`
(N, H, W) without their channel dimensions. Any features for which the mask is True can be the
center location of the patch during sampling.
Expand All @@ -499,12 +511,19 @@ def generate( # type: ignore
if self.patch_location is not None and mask is not None:
raise ValueError("Masks can only be used if the `patch_location` is `None`.")

if y is None: # pragma: no cover
logger.info("Setting labels to estimator predictions and running untargeted attack because `y=None`.")
y = to_categorical(np.argmax(self.estimator.predict(x=x), axis=1), nb_classes=self.estimator.nb_classes)

if hasattr(self.estimator, "nb_classes"):
y = check_and_transform_label_format(labels=y, nb_classes=self.estimator.nb_classes)

y_array: np.ndarray

if y is None: # pragma: no cover
logger.info("Setting labels to estimator classification predictions.")
y_array = to_categorical(
np.argmax(self.estimator.predict(x=x), axis=1), nb_classes=self.estimator.nb_classes
)
else:
y_array = cast(np.ndarray, y)

Check warning on line 524 in art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py

View check run for this annotation

Codecov / codecov/patch

art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py#L524

Added line #L524 was not covered by tests

y = check_and_transform_label_format(labels=y_array, nb_classes=self.estimator.nb_classes)

Check warning on line 526 in art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py

View check run for this annotation

Codecov / codecov/patch

art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py#L526

Added line #L526 was not covered by tests

# check if logits or probabilities
y_pred = self.estimator.predict(x=x[[0]])
Expand All @@ -513,6 +532,10 @@ def generate( # type: ignore
self.use_logits = False
else:
self.use_logits = True
else:
if y is None: # pragma: no cover
logger.info("Setting labels to estimator object detection predictions.")
y = self.estimator.predict(x=x)

if isinstance(y, np.ndarray):
x_tensor = torch.Tensor(x)
Expand Down

0 comments on commit e20d091

Please sign in to comment.