Skip to content

Commit

Permalink
Fix style check
Browse files Browse the repository at this point in the history
Signed-off-by: Lei Hsiung <[email protected]>
  • Loading branch information
twweeb committed Oct 16, 2023
1 parent 3528185 commit cf31929
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions art/attacks/evasion/composite_adversarial_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,15 @@ def _check_params(self) -> None:
_epsilons_range[i][1][0] <= self.epsilons[i][0] <= self.epsilons[i][1] <= _epsilons_range[i][1][1]
)
):
logger.info("The argument `%s` must be an interval within %s of type tuple.", _epsilons_range[i][0], _epsilons_range[i][2])
raise ValueError("The argument `{}` must be an interval within {} of type tuple.".format(_epsilons_range[i][0], _epsilons_range[i][2]))
logger.info(
"The argument `%s` must be an interval within %s of type tuple.",
_epsilons_range[i][0],
_epsilons_range[i][2],
)
raise ValueError(
f"The argument `{_epsilons_range[i][0]}` must be an interval "
f"within {_epsilons_range[i][2]} of type tuple."
)

if not isinstance(self.early_stop, bool):
logger.info("The flag `early_stop` has to be of type bool.")
Expand Down Expand Up @@ -525,9 +532,7 @@ def caa_linf(

return adv_data, eta

def update_attack_order(
self, images: "torch.Tensor", labels: "torch.Tensor", adv_val: List
) -> None:
def update_attack_order(self, images: "torch.Tensor", labels: "torch.Tensor", adv_val: List) -> None:
"""
Update the specified attack ordering.
:param images: A tensor of a batch of original inputs to be attacked.
Expand Down Expand Up @@ -600,7 +605,6 @@ def caa_attack(self, images: "torch.Tensor", labels: "torch.Tensor") -> "torch.T
"""
import torch

attack = self.attack_dict
adv_img = images.detach().clone()
adv_val_saved = torch.zeros((self.seq_num, self.batch_size), device=self.device)
adv_val = [self.adv_val_space[idx] for idx in range(self.seq_num)]
Expand All @@ -623,7 +627,7 @@ def caa_attack(self, images: "torch.Tensor", labels: "torch.Tensor") -> "torch.T

for tdx in range(self.seq_num):
idx = self.curr_seq[tdx]
adv_img, adv_val_updated = attack[idx](adv_img, adv_val[idx], labels)
adv_img, adv_val_updated = self.attack_dict[idx](adv_img, adv_val[idx], labels)
if idx != self.linf_idx:
adv_val[idx] = adv_val_updated

Expand Down

0 comments on commit cf31929

Please sign in to comment.