Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to ART 1.15.0 #2207

Merged
merged 127 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from 122 commits
Commits
Show all changes
127 commits
Select commit Hold shift + click to select a range
22ad7e8
add optimizer parameter to tensorflow v2 classifier
f4str Apr 24, 2023
a45603c
add default train step for tensorflow classifiers
f4str Apr 24, 2023
b8ad7d0
refactor tensorflow tests to use default train step
f4str Apr 25, 2023
57c51d4
refactor tensorflow notebooks to use default train step
f4str Apr 25, 2023
7e0d212
bug fix with new tensorflow optimizer parameter
f4str Apr 25, 2023
f190b9f
fix tensorflow model repr test case
f4str Apr 25, 2023
343a21a
implement resizing for numpy images
f4str Apr 27, 2023
ec032db
implement square padding for numpy images
f4str Apr 27, 2023
4a9ef03
Remove incorrectly duplicated queries in SignOPT attack
Lodour May 2, 2023
1280fd4
unit tests for numpy resizing
f4str May 3, 2023
1dd90fc
unit tests for numpy square padding
f4str May 3, 2023
2eaff48
implement square padding for pytorch images
f4str May 4, 2023
af791d2
implement square padding for tensorflow images
f4str May 4, 2023
8fa00ed
update square pad type checking and docstrings
f4str May 4, 2023
92aa1be
implement resizing for pytorch images
f4str May 5, 2023
dc20075
unit tests for pytorch resizing
f4str May 5, 2023
61fb591
implement resizing for tensorflow images
f4str May 6, 2023
65520d6
fix poison backdoor image trigger bug
f4str May 8, 2023
90350ef
minor backdoor image insert optimizations
f4str May 8, 2023
0535c79
Merge pull request #2129 from Lodour/fix_signopt_duplicated_queries
beat-buesser May 9, 2023
1f784ed
Merge pull request #2143 from f4str/image-trigger-bug
beat-buesser May 9, 2023
a87de7c
Merge branch 'dev_1.15.0' into object-detector-resize
beat-buesser May 9, 2023
8ccefa2
fix Tensor Device Inconsistencies in pgd
May 5, 2023
92dead2
Merge branch 'dev_1.15.0' into tf-default-fit
beat-buesser May 10, 2023
f1d5e0f
adding trades adversarial training
May 10, 2023
092d59c
address review comments
f4str May 10, 2023
614a436
address review comments
f4str May 13, 2023
9120a44
Merge remote-tracking branch 'origin/main' into dev_1.15.0
May 15, 2023
012b940
Merge branch 'main' into signedqiu
beat-buesser May 15, 2023
bed6945
Merge branch 'dev_1.15.0' into signedqiu
beat-buesser May 15, 2023
d6f190f
Merge remote-tracking branch 'origin/main' into dev_1.15.0
May 15, 2023
60744cf
Merge branch 'dev_1.15.0' into signedqiu
beat-buesser May 15, 2023
5adf58f
Merge branch 'dev_1.15.0' into tf-default-fit
f4str May 15, 2023
ae4ab18
Merge branch 'dev_1.15.0' into object-detector-resize
f4str May 15, 2023
37a976f
changes after review
May 19, 2023
bbb92cf
Merge pull request #2135 from SignedQiu/signedqiu
beat-buesser May 23, 2023
51b78df
changes for type checking style checks
May 23, 2023
4d6f9c1
correcting parameters description in TRADES trainer base class
May 24, 2023
6678060
Merge branch 'dev_1.15.0' into tf-default-fit
f4str May 24, 2023
b1d12cf
Merge branch 'dev_1.15.0' into object-detector-resize
f4str May 24, 2023
3fe948d
Merge branch 'dev_1.15.0' into trades_adv
beat-buesser May 25, 2023
36bc03d
Merge pull request #2131 from Zaid-Hameed/trades_adv
beat-buesser May 26, 2023
ede42ff
Merge branch 'dev_1.15.0' into tf-default-fit
f4str May 26, 2023
da64aa6
Merge branch 'dev_1.15.0' into object-detector-resize
f4str May 26, 2023
76ae22a
Merge pull request #2124 from f4str/tf-default-fit
beat-buesser May 29, 2023
658a22d
Merge branch 'dev_1.15.0' into object-detector-resize
beat-buesser May 29, 2023
3e9defa
extend bad det gma for arbitrary sizes
f4str May 27, 2023
f7eaa37
extend bad det oga for arbitrary sizes
f4str May 27, 2023
4ee410b
extend bad det rma for arbitrary sizes
f4str May 27, 2023
a4f403c
extend bad det oda for arbitrary sizes
f4str May 27, 2023
9c69398
update bad det demo notebook
f4str May 27, 2023
9c23192
optimized tensorflow predictions
f4str May 30, 2023
d1fb071
optimize pytorch classifier
f4str May 30, 2023
292910b
optimize pytorch classification and regression
f4str May 30, 2023
b2c8c2a
bug fix for prediction
f4str May 30, 2023
a95b42b
optimize pytorch yolo loops
f4str May 31, 2023
5fe8e36
optimize pytorch object detector
f4str Jun 1, 2023
67e89b3
linting and style checks
f4str Jun 1, 2023
50e0be1
use torch dataloader for randomized smoothing
f4str Jun 6, 2023
2819e27
fix style checks and unit tests
f4str Jun 6, 2023
a698532
fix mypy typings
f4str Jun 8, 2023
43d13fc
Merge remote-tracking branch 'origin/main' into dev_1.15.0
Jun 9, 2023
7a6eac3
Merge branch 'dev_1.15.0' into object-detector-resize
f4str Jun 9, 2023
f1b6a50
Merge pull request #2138 from f4str/object-detector-resize
beat-buesser Jun 13, 2023
cb9fc56
Merge branch 'dev_1.15.0' into torch-dataloaders
f4str Jun 13, 2023
d8b837d
Merge branch 'dev_1.15.0' into baddet-extension
f4str Jun 13, 2023
312c9f6
fix typo in baddet demo notebook
f4str Jun 13, 2023
ada0830
revert tensorflow classifier changes
f4str Jun 14, 2023
442bada
Fixing compatibility issue between PyTorch YOLO and AdversarialPatchP…
kieranfraser May 29, 2023
b8f7c74
Fixing compatibility issue between PyTorch YOLO and AdversarialPatchP…
kieranfraser May 29, 2023
0319e0a
Removing fix for spurious YOLO predictions as generates nans due to a…
kieranfraser Jun 14, 2023
7e46923
Formatting fix
kieranfraser Jun 14, 2023
790dea2
Fixing failing attack test as targets not defined
kieranfraser Jun 14, 2023
cb610a8
ViT backbone object detector with fasterRCNN example
kieranfraser Mar 2, 2023
dc2987a
ViT backbone object detector with fasterRCNN example
kieranfraser Mar 2, 2023
36e5e6e
training pipeline for fasterrcnn with vit backbone working. requires …
kieranfraser Mar 6, 2023
cc4894e
Adding pytorch detr. Working example demonstrating object detection a…
kieranfraser Mar 15, 2023
2c5bf42
DETR with original source methods attributed
kieranfraser Mar 20, 2023
4b97069
DETR with changes to original src
kieranfraser Mar 20, 2023
fd09793
Removed unused misc files. Updated example notebook demonstrating ViT…
kieranfraser Mar 20, 2023
5fa5001
Completed tests. Added method to freeze multihead-attention module. U…
kieranfraser Apr 13, 2023
f943d67
Adding constructor for detr
kieranfraser Apr 13, 2023
61a814e
Moved notebook to correct folder for adversarial patch attack. Update…
kieranfraser Apr 13, 2023
1627d37
Updated formatting
kieranfraser Apr 19, 2023
b9d7fc7
Refactored loss classes to prevent tests for other frameworks failing
kieranfraser Apr 19, 2023
bedcc31
Refactored loss classes to prevent tests for other frameworks failing
kieranfraser Apr 19, 2023
c639bbd
Fix for static methods and styling
kieranfraser Apr 20, 2023
317aada
Framework check for detr tests
kieranfraser Apr 20, 2023
b48dff0
Updated class name, added typing and other minor fixes.
kieranfraser May 11, 2023
46f3958
Updated class name, added typing and other minor fixes.
kieranfraser May 11, 2023
8f62b12
Added test call to github workflow
kieranfraser May 11, 2023
478d9a7
fix Tensor Device Inconsistencies in pgd
May 5, 2023
64db977
Updates to DETR: cleaned up resizing; correct clipping. Updates to no…
kieranfraser Jun 13, 2023
8e4c89d
Fixing formatting
kieranfraser Jun 14, 2023
8248092
updated detection transformer notebook
kieranfraser Jun 14, 2023
a1757e0
Remove irrelevant PGD
kieranfraser Jun 14, 2023
cacc829
Merge remote-tracking branch 'upstream/dev_1.15.0' into dev_detection…
kieranfraser Jun 14, 2023
ef88ed2
Fixed pylint, mypy issues
kieranfraser Jun 14, 2023
07b195e
Merge remote-tracking branch 'upstream/dev_1.15.0' into dev_issue_214…
kieranfraser Jun 14, 2023
a51b614
Remove print line
kieranfraser Jun 14, 2023
ddc09f7
fix type checking for object detection utils
f4str Jun 14, 2023
0ab98d0
Adding Apache License to original DETR functions
kieranfraser Jun 15, 2023
7a96e2c
Updated notebook with stronger adversarial patch attacks - targeted a…
kieranfraser Jun 15, 2023
0d15d2f
Removing comments to fix pylint test
kieranfraser Jun 15, 2023
40070ea
Adding missing license to functions
kieranfraser Jun 15, 2023
6db3e5c
Fix ragged nested sequence warning
Foxglove144 Jun 16, 2023
5a41925
Remove old code
Foxglove144 Jun 16, 2023
011ab1e
Merge pull request #2189 from f4str/baddet-extension
beat-buesser Jun 20, 2023
f3fcf19
Merge branch 'dev_1.15.0' into torch-dataloaders
beat-buesser Jun 20, 2023
4bfed67
Merge pull request #2180 from f4str/torch-dataloaders
beat-buesser Jun 20, 2023
04cb8aa
Merge branch 'dev_1.15.0' into dev_issue_2148_yolo
beat-buesser Jun 20, 2023
d91b3e0
Merge branch 'main' into patch-1
Foxglove144 Jun 26, 2023
28a3f8c
Merge remote-tracking branch 'origin/main' into dev_1.15.0
Jun 27, 2023
df3e298
Merge branch 'dev_1.15.0' into dev_detection_transformer
beat-buesser Jun 27, 2023
c67cd57
Merge branch 'dev_1.15.0' into dev_issue_2148_yolo
beat-buesser Jun 27, 2023
8cb3607
Merge pull request #2169 from kieranfraser/dev_issue_2148_yolo
beat-buesser Jun 27, 2023
3e250a1
Standalone detr.py file for utility code from FB repo
kieranfraser Jun 28, 2023
496fcd3
Merge remote-tracking branch 'origin/dev_detection_transformer' into …
kieranfraser Jun 28, 2023
482b277
Removing duplicate license reference
kieranfraser Jun 28, 2023
d6ed99b
Updated reference to adapted detr functions under Apache 2.0
kieranfraser Jun 28, 2023
35f1d5a
Updated detr.py docstring with list of changes to Apache 2.0 code
kieranfraser Jun 28, 2023
3a97e66
Updated device in pytorch_detection_transformer.py and detr.py. Updat…
kieranfraser Jun 28, 2023
84c9e2b
mypy fix - .to should not be called if np.array
kieranfraser Jun 28, 2023
81408e5
Fix for black formatting
kieranfraser Jun 28, 2023
82f8fa2
Merge pull request #2192 from kieranfraser/dev_detection_transformer
beat-buesser Jun 28, 2023
6250986
Merge branch 'dev_1.15.0' into patch-1
beat-buesser Jun 28, 2023
e3708eb
Merge pull request #2195 from Foxglove144/patch-1
beat-buesser Jun 30, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci-pytorch-object-detectors.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ jobs:
run: pytest --cov-report=xml --cov=art --cov-append -q -vv tests/estimators/object_detection/test_pytorch_object_detector.py --framework=pytorch --durations=0
- name: Run Test Action - test_pytorch_faster_rcnn
run: pytest --cov-report=xml --cov=art --cov-append -q -vv tests/estimators/object_detection/test_pytorch_faster_rcnn.py --framework=pytorch --durations=0
- name: Run Test Action - test_pytorch_detection_transformer
run: pytest --cov-report=xml --cov=art --cov-append -q -vv tests/estimators/object_detection/test_pytorch_detection_transformer.py --framework=pytorch --durations=0
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
Expand Down
4 changes: 2 additions & 2 deletions art/attacks/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,10 @@ def __init__(self):
@abc.abstractmethod
def poison(
self,
x: np.ndarray,
x: Union[np.ndarray, List[np.ndarray]],
y: List[Dict[str, np.ndarray]],
**kwargs,
) -> Tuple[np.ndarray, List[Dict[str, np.ndarray]]]:
) -> Tuple[Union[np.ndarray, List[np.ndarray]], List[Dict[str, np.ndarray]]]:
"""
Generate poisoning examples and return them as an array. This method should be overridden by all concrete
poisoning attack implementations.
Expand Down
32 changes: 23 additions & 9 deletions art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,9 +575,9 @@ def __getitem__(self, idx):
img = torch.from_numpy(self.x[idx])

target = {}
target["boxes"] = torch.from_numpy(y[idx]["boxes"])
target["labels"] = torch.from_numpy(y[idx]["labels"])
target["scores"] = torch.from_numpy(y[idx]["scores"])
target["boxes"] = torch.from_numpy(self.y[idx]["boxes"])
target["labels"] = torch.from_numpy(self.y[idx]["labels"])
target["scores"] = torch.from_numpy(self.y[idx]["scores"])
mask_i = torch.from_numpy(self.mask[idx])

return img, target, mask_i
Expand All @@ -602,19 +602,33 @@ def __getitem__(self, idx):
if isinstance(target, torch.Tensor):
target = target.to(self.estimator.device)
else:
target["boxes"] = target["boxes"].to(self.estimator.device)
target["labels"] = target["labels"].to(self.estimator.device)
target["scores"] = target["scores"].to(self.estimator.device)
targets = []
for idx in range(target["boxes"].shape[0]):
targets.append(
{
"boxes": target["boxes"][idx].to(self.estimator.device),
"labels": target["labels"][idx].to(self.estimator.device),
"scores": target["scores"][idx].to(self.estimator.device),
}
)
target = targets
_ = self._train_step(images=images, target=target, mask=None)
else:
for images, target, mask_i in data_loader:
images = images.to(self.estimator.device)
if isinstance(target, torch.Tensor):
target = target.to(self.estimator.device)
else:
target["boxes"] = target["boxes"].to(self.estimator.device)
target["labels"] = target["labels"].to(self.estimator.device)
target["scores"] = target["scores"].to(self.estimator.device)
targets = []
for idx in range(target["boxes"].shape[0]):
targets.append(
{
"boxes": target["boxes"][idx].to(self.estimator.device),
"labels": target["labels"][idx].to(self.estimator.device),
"scores": target["scores"][idx].to(self.estimator.device),
}
)
target = targets
mask_i = mask_i.to(self.estimator.device)
_ = self._train_step(images=images, target=target, mask=mask_i)

Expand Down
3 changes: 2 additions & 1 deletion art/attacks/evasion/auto_conjugate_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor, *args, **kwargs) -> tf.
nb_classes=estimator.nb_classes,
input_shape=estimator.input_shape,
loss_object=_loss_object_tf,
train_step=estimator._train_step,
optimizer=estimator.optimizer,
train_step=estimator.train_step,
channels_first=estimator.channels_first,
clip_values=estimator.clip_values,
preprocessing_defences=estimator.preprocessing_defences,
Expand Down
3 changes: 2 additions & 1 deletion art/attacks/evasion/auto_projected_gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor, *args, **kwargs) -> tf.
nb_classes=estimator.nb_classes,
input_shape=estimator.input_shape,
loss_object=_loss_object_tf,
train_step=estimator._train_step,
optimizer=estimator.optimizer,
train_step=estimator.train_step,
channels_first=estimator.channels_first,
clip_values=estimator.clip_values,
preprocessing_defences=estimator.preprocessing_defences,
Expand Down
3 changes: 2 additions & 1 deletion art/attacks/evasion/brendel_bethge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2055,7 +2055,8 @@ def logits_difference(y_true, y_pred):
nb_classes=estimator.nb_classes,
input_shape=estimator.input_shape,
loss_object=self._loss_object,
train_step=estimator._train_step,
optimizer=estimator.optimizer,
train_step=estimator.train_step,
channels_first=estimator.channels_first,
clip_values=estimator.clip_values,
preprocessing_defences=estimator.preprocessing_defences,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def _generate_batch(
inputs = x.to(self.estimator.device)
targets = targets.to(self.estimator.device)
adv_x = torch.clone(inputs)
momentum = torch.zeros(inputs.shape)
momentum = torch.zeros(inputs.shape).to(self.estimator.device)

if mask is not None:
mask = mask.to(self.estimator.device)
Expand Down
16 changes: 8 additions & 8 deletions art/attacks/evasion/sign_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,14 +306,14 @@ def _fine_grained_binary_search_local(
lbd = initial_lbd
# For targeted: we want to expand(x1.01) boundary away from targeted dataset
# For untargeted, we want to slim(x0.99) the boundary toward the original dataset
if (not self._is_label(x_0 + lbd * theta, target) and self.targeted) or (
self._is_label(x_0 + lbd * theta, y_0) and not self.targeted
if (self.targeted and not self._is_label(x_0 + lbd * theta, target)) or (
not self.targeted and self._is_label(x_0 + lbd * theta, y_0)
):
lbd_lo = lbd
lbd_hi = lbd * 1.01
nquery += 1
while (not self._is_label(x_0 + lbd_hi * theta, target) and self.targeted) or (
self._is_label(x_0 + lbd_hi * theta, y_0) and not self.targeted
while (self.targeted and not self._is_label(x_0 + lbd_hi * theta, target)) or (
not self.targeted and self._is_label(x_0 + lbd_hi * theta, y_0)
):
lbd_hi = lbd_hi * 1.01
nquery += 1
Expand All @@ -323,17 +323,17 @@ def _fine_grained_binary_search_local(
lbd_hi = lbd
lbd_lo = lbd * 0.99
nquery += 1
while (self._is_label(x_0 + lbd_lo * theta, target) and self.targeted) or (
not self._is_label(x_0 + lbd_lo * theta, y_0) and not self.targeted
while (self.targeted and self._is_label(x_0 + lbd_lo * theta, target)) or (
not self.targeted and not self._is_label(x_0 + lbd_lo * theta, y_0)
):
lbd_lo = lbd_lo * 0.99
nquery += 1

while (lbd_hi - lbd_lo) > tol:
lbd_mid = (lbd_lo + lbd_hi) / 2.0
nquery += 1
if (self._is_label(x_0 + lbd_mid * theta, target) and self.targeted) or (
not self._is_label(x_0 + lbd_mid * theta, y_0) and not self.targeted
if (self.targeted and self._is_label(x_0 + lbd_mid * theta, target)) or (
not self.targeted and not self._is_label(x_0 + lbd_mid * theta, y_0)
):
lbd_hi = lbd_mid
else:
Expand Down
41 changes: 24 additions & 17 deletions art/attacks/poisoning/bad_det/bad_det_gma.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union

import numpy as np
from tqdm.auto import tqdm
Expand Down Expand Up @@ -77,36 +77,39 @@ def __init__(

def poison( # pylint: disable=W0221
self,
x: np.ndarray,
x: Union[np.ndarray, List[np.ndarray]],
y: List[Dict[str, np.ndarray]],
**kwargs,
) -> Tuple[np.ndarray, List[Dict[str, np.ndarray]]]:
) -> Tuple[Union[np.ndarray, List[np.ndarray]], List[Dict[str, np.ndarray]]]:
"""
Generate poisoning examples by inserting the backdoor onto the input `x` and changing the classification
for labels `y`.

:param x: Sample images of shape `NCHW` or `NHWC`.
:param x: Sample images of shape `NCHW` or `NHWC` or a list of sample images of any size.
:param y: True labels of type `List[Dict[np.ndarray]]`, one dictionary per input image. The keys and values
of the dictionary are:

- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
- labels [N]: the labels for each image.
- scores [N]: the scores or each prediction.
:return: An tuple holding the `(poisoning_examples, poisoning_labels)`.
"""
x_ndim = len(x.shape)
if isinstance(x, np.ndarray):
x_ndim = len(x.shape)
else:
x_ndim = len(x[0].shape) + 1

if x_ndim != 4:
raise ValueError("Unrecognized input dimension. BadDet GMA can only be applied to image data.")

if self.channels_first:
# NCHW --> NHWC
x = np.transpose(x, (0, 2, 3, 1))

x_poison = x.copy()
y_poison: List[Dict[str, np.ndarray]] = []
# copy images
x_poison: Union[np.ndarray, List[np.ndarray]]
if isinstance(x, np.ndarray):
x_poison = x.copy()
else:
x_poison = [x_i.copy() for x_i in x]

# copy labels
y_poison: List[Dict[str, np.ndarray]] = []
for y_i in y:
target_dict = {k: v.copy() for k, v in y_i.items()}
y_poison.append(target_dict)
Expand All @@ -120,18 +123,22 @@ def poison( # pylint: disable=W0221
image = x_poison[i]
labels = y_poison[i]["labels"]

if self.channels_first:
image = np.transpose(image, (1, 2, 0))

# insert backdoor into the image
# add an additional dimension to create a batch of size 1
poisoned_input, _ = self.backdoor.poison(image[np.newaxis], labels)
x_poison[i] = poisoned_input[0]
image = poisoned_input[0]

# replace the original image with the poisoned image
if self.channels_first:
image = np.transpose(image, (2, 0, 1))
x_poison[i] = image

# change all labels to the target label
y_poison[i]["labels"] = np.full(labels.shape, self.class_target)

if self.channels_first:
# NHWC --> NCHW
x_poison = np.transpose(x_poison, (0, 3, 1, 2))

return x_poison, y_poison

def _check_params(self) -> None:
Expand Down
40 changes: 23 additions & 17 deletions art/attacks/poisoning/bad_det/bad_det_oda.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union

import numpy as np
from tqdm.auto import tqdm
Expand Down Expand Up @@ -77,36 +77,39 @@ def __init__(

def poison( # pylint: disable=W0221
self,
x: np.ndarray,
x: Union[np.ndarray, List[np.ndarray]],
y: List[Dict[str, np.ndarray]],
**kwargs,
) -> Tuple[np.ndarray, List[Dict[str, np.ndarray]]]:
) -> Tuple[Union[np.ndarray, List[np.ndarray]], List[Dict[str, np.ndarray]]]:
"""
Generate poisoning examples by inserting the backdoor onto the input `x` and changing the classification
for labels `y`.

:param x: Sample images of shape `NCHW` or `NHWC`.
:param x: Sample images of shape `NCHW` or `NHWC` or a list of sample images of any size.
:param y: True labels of type `List[Dict[np.ndarray]]`, one dictionary per input image. The keys and values
of the dictionary are:

- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
- labels [N]: the labels for each image.
- scores [N]: the scores or each prediction.
:return: An tuple holding the `(poisoning_examples, poisoning_labels)`.
"""
x_ndim = len(x.shape)
if isinstance(x, np.ndarray):
x_ndim = len(x.shape)
else:
x_ndim = len(x[0].shape) + 1

if x_ndim != 4:
raise ValueError("Unrecognized input dimension. BadDet ODA can only be applied to image data.")

if self.channels_first:
# NCHW --> NHWC
x = np.transpose(x, (0, 2, 3, 1))

x_poison = x.copy()
y_poison: List[Dict[str, np.ndarray]] = []
# copy images
x_poison: Union[np.ndarray, List[np.ndarray]]
if isinstance(x, np.ndarray):
x_poison = x.copy()
else:
x_poison = [x_i.copy() for x_i in x]

# copy labels and find indices of the source class
y_poison: List[Dict[str, np.ndarray]] = []
source_indices = []
for i, y_i in enumerate(y):
target_dict = {k: v.copy() for k, v in y_i.items()}
Expand All @@ -121,10 +124,12 @@ def poison( # pylint: disable=W0221

for i in tqdm(selected_indices, desc="BadDet ODA iteration", disable=not self.verbose):
image = x_poison[i]

boxes = y_poison[i]["boxes"]
labels = y_poison[i]["labels"]

if self.channels_first:
image = np.transpose(image, (1, 2, 0))

keep_indices = []

for j, (box, label) in enumerate(zip(boxes, labels)):
Expand All @@ -140,13 +145,14 @@ def poison( # pylint: disable=W0221
else:
keep_indices.append(j)

# replace the original image with the poisoned image
if self.channels_first:
image = np.transpose(image, (2, 0, 1))
x_poison[i] = image

# remove labels for poisoned bounding boxes
y_poison[i] = {k: v[keep_indices] for k, v in y_poison[i].items()}

if self.channels_first:
# NHWC --> NCHW
x_poison = np.transpose(x_poison, (0, 3, 1, 2))

return x_poison, y_poison

def _check_params(self) -> None:
Expand Down
Loading