Skip to content

Commit

Permalink
Merge pull request #2534 from Trusted-AI/development_issue_2529
Browse files Browse the repository at this point in the history
Add option for pool size in AutoAttack
  • Loading branch information
beat-buesser authored Dec 13, 2024
2 parents 21f1923 + 7903726 commit bf196eb
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
23 changes: 13 additions & 10 deletions art/attacks/evasion/auto_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(
batch_size: int = 32,
estimator_orig: "CLASSIFIER_TYPE" | None = None,
targeted: bool = False,
parallel: bool = False,
parallel_pool_size: int = 0,
):
"""
Create a :class:`.AutoAttack` instance.
Expand All @@ -93,7 +93,8 @@ def __init__(
:param estimator_orig: Original estimator to be attacked by adversarial examples.
:param targeted: If False run only untargeted attacks, if True also run targeted attacks against each possible
target.
:param parallel: If True run attacks in parallel.
:param parallel_pool_size: Number of parallel threads / pool size in multiprocessing. If parallel_pool_size=0
computation runs without multiprocessing.
"""
super().__init__(estimator=estimator)

Expand Down Expand Up @@ -151,7 +152,7 @@ def __init__(
self.estimator_orig = estimator

self._targeted = targeted
self.parallel = parallel
self.parallel_pool_size = parallel_pool_size
self.best_attacks: np.ndarray = np.array([])
self._check_params()

Expand Down Expand Up @@ -199,7 +200,7 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n
if attack.targeted:
attack.set_params(targeted=False)

if self.parallel:
if self.parallel_pool_size > 0:
args.append(
(
deepcopy(x_adv),
Expand Down Expand Up @@ -253,7 +254,7 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n
targeted_labels[:, i], nb_classes=self.estimator.nb_classes
)

if self.parallel:
if self.parallel_pool_size > 0:
args.append(
(
deepcopy(x_adv),
Expand Down Expand Up @@ -287,8 +288,8 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n
except ValueError as error:
logger.warning("Error completing attack: %s}", str(error))

if self.parallel:
with multiprocess.get_context("spawn").Pool() as pool:
if self.parallel_pool_size > 0:
with multiprocess.get_context("spawn").Pool(processes=self.parallel_pool_size) as pool:
# Results come back in the order that they were issued
results = pool.starmap(run_attack, args)
perturbations = []
Expand Down Expand Up @@ -320,15 +321,16 @@ def __repr__(self) -> str:
This method returns a summary of the best performing (lowest perturbation in the parallel case) attacks
per image passed to the AutoAttack class.
"""
if self.parallel:
if self.parallel_pool_size > 0:
best_attack_meta = "\n".join(
[
f"image {i+1}: {str(self.args[idx][3])}" if idx != 0 else f"image {i+1}: n/a"
for i, idx in enumerate(self.best_attacks)
]
)
auto_attack_meta = (
f"AutoAttack(targeted={self.targeted}, parallel={self.parallel}, num_attacks={len(self.args)})"
f"AutoAttack(targeted={self.targeted}, parallel_pool_size={self.parallel_pool_size}, "
+ "num_attacks={len(self.args)})"
)
return f"{auto_attack_meta}\nBestAttacks:\n{best_attack_meta}"

Expand All @@ -339,7 +341,8 @@ def __repr__(self) -> str:
]
)
auto_attack_meta = (
f"AutoAttack(targeted={self.targeted}, parallel={self.parallel}, num_attacks={len(self.attacks)})"
f"AutoAttack(targeted={self.targeted}, parallel_pool_size={self.parallel_pool_size}, "
+ "num_attacks={len(self.attacks)})"
)
return f"{auto_attack_meta}\nBestAttacks:\n{best_attack_meta}"

Expand Down
6 changes: 3 additions & 3 deletions tests/attacks/evasion/test_auto_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def test_generate_parallel(art_warning, fix_get_mnist_subset, image_dl_estimator
batch_size=batch_size,
estimator_orig=None,
targeted=False,
parallel=True,
parallel_pool_size=3,
)

attack_noparallel = AutoAttack(
Expand All @@ -285,7 +285,7 @@ def test_generate_parallel(art_warning, fix_get_mnist_subset, image_dl_estimator
batch_size=batch_size,
estimator_orig=None,
targeted=False,
parallel=False,
parallel_pool_size=0,
)

x_train_mnist_adv = attack.generate(x=x_train_mnist, y=y_train_mnist)
Expand All @@ -310,7 +310,7 @@ def test_generate_parallel(art_warning, fix_get_mnist_subset, image_dl_estimator
batch_size=batch_size,
estimator_orig=None,
targeted=True,
parallel=True,
parallel_pool_size=3,
)

x_train_mnist_adv = attack.generate(x=x_train_mnist, y=y_train_mnist)
Expand Down

0 comments on commit bf196eb

Please sign in to comment.