Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option for pool size in AutoAttack #2534

Merged
merged 3 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions art/attacks/evasion/auto_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
batch_size: int = 32,
estimator_orig: "CLASSIFIER_TYPE" | None = None,
targeted: bool = False,
parallel: bool = False,
parallel_pool_size: int = 0,
):
"""
Create a :class:`.AutoAttack` instance.
Expand All @@ -93,7 +93,8 @@
:param estimator_orig: Original estimator to be attacked by adversarial examples.
:param targeted: If False run only untargeted attacks, if True also run targeted attacks against each possible
target.
:param parallel: If True run attacks in parallel.
:param parallel_pool_size: Number of parallel threads / pool size in multiprocessing. If parallel_pool_size=0
computation runs without multiprocessing.
"""
super().__init__(estimator=estimator)

Expand Down Expand Up @@ -151,7 +152,7 @@
self.estimator_orig = estimator

self._targeted = targeted
self.parallel = parallel
self.parallel_pool_size = parallel_pool_size

Check warning on line 155 in art/attacks/evasion/auto_attack.py

View check run for this annotation

Codecov / codecov/patch

art/attacks/evasion/auto_attack.py#L155

Added line #L155 was not covered by tests
self.best_attacks: np.ndarray = np.array([])
self._check_params()

Expand Down Expand Up @@ -199,7 +200,7 @@
if attack.targeted:
attack.set_params(targeted=False)

if self.parallel:
if self.parallel_pool_size > 0:
args.append(
(
deepcopy(x_adv),
Expand Down Expand Up @@ -253,7 +254,7 @@
targeted_labels[:, i], nb_classes=self.estimator.nb_classes
)

if self.parallel:
if self.parallel_pool_size > 0:
args.append(
(
deepcopy(x_adv),
Expand Down Expand Up @@ -287,8 +288,8 @@
except ValueError as error:
logger.warning("Error completing attack: %s}", str(error))

if self.parallel:
with multiprocess.get_context("spawn").Pool() as pool:
if self.parallel_pool_size > 0:
with multiprocess.get_context("spawn").Pool(processes=self.parallel_pool_size) as pool:
# Results come back in the order that they were issued
results = pool.starmap(run_attack, args)
perturbations = []
Expand Down Expand Up @@ -320,15 +321,16 @@
This method returns a summary of the best performing (lowest perturbation in the parallel case) attacks
per image passed to the AutoAttack class.
"""
if self.parallel:
if self.parallel_pool_size > 0:
best_attack_meta = "\n".join(
[
f"image {i+1}: {str(self.args[idx][3])}" if idx != 0 else f"image {i+1}: n/a"
for i, idx in enumerate(self.best_attacks)
]
)
auto_attack_meta = (
f"AutoAttack(targeted={self.targeted}, parallel={self.parallel}, num_attacks={len(self.args)})"
f"AutoAttack(targeted={self.targeted}, parallel_pool_size={self.parallel_pool_size}, "
+ "num_attacks={len(self.args)})"
)
return f"{auto_attack_meta}\nBestAttacks:\n{best_attack_meta}"

Expand All @@ -339,7 +341,8 @@
]
)
auto_attack_meta = (
f"AutoAttack(targeted={self.targeted}, parallel={self.parallel}, num_attacks={len(self.attacks)})"
f"AutoAttack(targeted={self.targeted}, parallel_pool_size={self.parallel_pool_size}, "
+ "num_attacks={len(self.attacks)})"
)
return f"{auto_attack_meta}\nBestAttacks:\n{best_attack_meta}"

Expand Down
6 changes: 3 additions & 3 deletions tests/attacks/evasion/test_auto_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def test_generate_parallel(art_warning, fix_get_mnist_subset, image_dl_estimator
batch_size=batch_size,
estimator_orig=None,
targeted=False,
parallel=True,
parallel_pool_size=3,
)

attack_noparallel = AutoAttack(
Expand All @@ -285,7 +285,7 @@ def test_generate_parallel(art_warning, fix_get_mnist_subset, image_dl_estimator
batch_size=batch_size,
estimator_orig=None,
targeted=False,
parallel=False,
parallel_pool_size=0,
)

x_train_mnist_adv = attack.generate(x=x_train_mnist, y=y_train_mnist)
Expand All @@ -310,7 +310,7 @@ def test_generate_parallel(art_warning, fix_get_mnist_subset, image_dl_estimator
batch_size=batch_size,
estimator_orig=None,
targeted=True,
parallel=True,
parallel_pool_size=3,
)

x_train_mnist_adv = attack.generate(x=x_train_mnist, y=y_train_mnist)
Expand Down
Loading