From 6b678894e6a13d2e052ac3d88544e2aeb7885551 Mon Sep 17 00:00:00 2001 From: Beat Buesser Date: Fri, 13 Dec 2024 01:19:41 +0100 Subject: [PATCH] Add option for pool size in AutoAttack Signed-off-by: Beat Buesser --- art/attacks/evasion/auto_attack.py | 23 +++++++++++++---------- tests/attacks/evasion/test_auto_attack.py | 6 +++--- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/art/attacks/evasion/auto_attack.py b/art/attacks/evasion/auto_attack.py index 01a4046ec7..a148e43b20 100644 --- a/art/attacks/evasion/auto_attack.py +++ b/art/attacks/evasion/auto_attack.py @@ -78,7 +78,7 @@ def __init__( batch_size: int = 32, estimator_orig: "CLASSIFIER_TYPE" | None = None, targeted: bool = False, - parallel: bool = False, + parallel_pool_size: int = 0, ): """ Create a :class:`.AutoAttack` instance. @@ -93,7 +93,8 @@ def __init__( :param estimator_orig: Original estimator to be attacked by adversarial examples. :param targeted: If False run only untargeted attacks, if True also run targeted attacks against each possible target. - :param parallel: If True run attacks in parallel. + :param parallel_pool_size: Number of parallel threads / pool size in multiprocessing. If parallel_pool_size=0 + computation runs without multiprocessing. """ super().__init__(estimator=estimator) @@ -151,7 +152,7 @@ def __init__( self.estimator_orig = estimator self._targeted = targeted - self.parallel = parallel + self.parallel_pool_size = parallel_pool_size self.best_attacks: np.ndarray = np.array([]) self._check_params() @@ -199,7 +200,7 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n if attack.targeted: attack.set_params(targeted=False) - if self.parallel: + if self.parallel_pool_size > 0: args.append( ( deepcopy(x_adv), @@ -253,7 +254,7 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n targeted_labels[:, i], nb_classes=self.estimator.nb_classes ) - if self.parallel: + if self.parallel_pool_size > 0: args.append( ( deepcopy(x_adv), @@ -287,8 +288,8 @@ def generate(self, x: np.ndarray, y: np.ndarray | None = None, **kwargs) -> np.n except ValueError as error: logger.warning("Error completing attack: %s}", str(error)) - if self.parallel: - with multiprocess.get_context("spawn").Pool() as pool: + if self.parallel_pool_size > 0: + with multiprocess.get_context("spawn").Pool(processes=self.parallel_pool_size) as pool: # Results come back in the order that they were issued results = pool.starmap(run_attack, args) perturbations = [] @@ -320,7 +321,7 @@ def __repr__(self) -> str: This method returns a summary of the best performing (lowest perturbation in the parallel case) attacks per image passed to the AutoAttack class. """ - if self.parallel: + if self.parallel_pool_size > 0: best_attack_meta = "\n".join( [ f"image {i+1}: {str(self.args[idx][3])}" if idx != 0 else f"image {i+1}: n/a" @@ -328,7 +329,8 @@ def __repr__(self) -> str: ] ) auto_attack_meta = ( - f"AutoAttack(targeted={self.targeted}, parallel={self.parallel}, num_attacks={len(self.args)})" + f"AutoAttack(targeted={self.targeted}, parallel_pool_size={self.parallel_pool_size}, " + + "num_attacks={len(self.args)})" ) return f"{auto_attack_meta}\nBestAttacks:\n{best_attack_meta}" @@ -339,7 +341,8 @@ def __repr__(self) -> str: ] ) auto_attack_meta = ( - f"AutoAttack(targeted={self.targeted}, parallel={self.parallel}, num_attacks={len(self.attacks)})" + f"AutoAttack(targeted={self.targeted}, parallel_pool_size={self.parallel_pool_size}, " + + "num_attacks={len(self.attacks)})" ) return f"{auto_attack_meta}\nBestAttacks:\n{best_attack_meta}" diff --git a/tests/attacks/evasion/test_auto_attack.py b/tests/attacks/evasion/test_auto_attack.py index 52e76274a6..f98847bfd0 100644 --- a/tests/attacks/evasion/test_auto_attack.py +++ b/tests/attacks/evasion/test_auto_attack.py @@ -273,7 +273,7 @@ def test_generate_parallel(art_warning, fix_get_mnist_subset, image_dl_estimator batch_size=batch_size, estimator_orig=None, targeted=False, - parallel=True, + parallel_pool_size=3, ) attack_noparallel = AutoAttack( @@ -285,7 +285,7 @@ def test_generate_parallel(art_warning, fix_get_mnist_subset, image_dl_estimator batch_size=batch_size, estimator_orig=None, targeted=False, - parallel=False, + parallel_pool_size=0, ) x_train_mnist_adv = attack.generate(x=x_train_mnist, y=y_train_mnist) @@ -310,7 +310,7 @@ def test_generate_parallel(art_warning, fix_get_mnist_subset, image_dl_estimator batch_size=batch_size, estimator_orig=None, targeted=True, - parallel=True, + parallel_pool_size=3, ) x_train_mnist_adv = attack.generate(x=x_train_mnist, y=y_train_mnist)