Added support for BEYOND [ICML-2024] #2489
@@ -0,0 +1,167 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2024
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements the BEYOND detector for adversarial examples detection.

| Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3
"""
from __future__ import annotations

import math
from collections.abc import Callable
from typing import TYPE_CHECKING

import numpy as np
import torch
import torch.nn.functional as F

from art.defences.detector.evasion.evasion_detector import EvasionDetector

if TYPE_CHECKING:
    from art.utils import CLASSIFIER_NEURALNETWORK_TYPE

class BeyondDetector(EvasionDetector):
    """
    BEYOND detector for adversarial samples detection.
    This detector uses a combination of SSL and target model predictions to detect adversarial examples.

    | Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3
    """

    defence_params = ["target_model", "ssl_model", "augmentations", "aug_num", "alpha", "K", "percentile"]

    def __init__(
        self,
        target_model: "CLASSIFIER_NEURALNETWORK_TYPE",
        ssl_model: "CLASSIFIER_NEURALNETWORK_TYPE",
        augmentations: Callable | None,
        aug_num: int = 50,
        alpha: float = 0.8,
        K: int = 20,
        percentile: int = 5,
    ) -> None:
""" | ||
Initialize the BEYOND detector. | ||
|
||
:param target_model: The target model to be protected | ||
:param ssl_model: The self-supervised learning model used for feature extraction | ||
:param augmentation: data augmentations for generating neighborhoods | ||
:param aug_num: Number of augmentations to apply to each sample (default: 50) | ||
:param alpha: Weight factor for combining label and representation similarities (default: 0.8) | ||
:param K: Number of top similarities to consider (default: 20) | ||
:param percentile: using to calculate the threshold | ||
""" | ||
super().__init__() | ||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
self.target_model = target_model.to(self.device) | ||
self.ssl_model = ssl_model.to(self.device) | ||
self.aug_num = aug_num | ||
self.alpha = alpha | ||
self.K = K | ||
|
||
self.backbone = ssl_model.backbone | ||
self.classifier = ssl_model.classifier | ||
self.projector = ssl_model.projector | ||
|
||
self.img_augmentations = augmentations | ||
|
||
self.percentile = percentile # determinate the threshold | ||
self.threshold = None | ||
|
||
|
||
|
||
    def _multi_transform(self, img: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.img_augmentations(img) for _ in range(self.aug_num)], dim=1)

    def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> np.ndarray:
        """
        Calculate similarities that combine label consistency and representation similarity for the given samples.

        :param x: Input samples.
        :param batch_size: Batch size for processing.
        :return: Sorted similarity scores for each sample.
        """
        samples = torch.from_numpy(x).to(self.device)

        self.target_model.eval()
        self.backbone.eval()
        self.classifier.eval()
        self.projector.eval()

        number_batch = int(math.ceil(len(samples) / batch_size))

        similarities = []

        with torch.no_grad():
            for index in range(number_batch):
                start = index * batch_size
                end = min((index + 1) * batch_size, len(samples))

                batch_samples = samples[start:end]
                b, c, h, w = batch_samples.shape

                trans_images = self._multi_transform(batch_samples).to(self.device)
                ssl_backbone_out = self.backbone(batch_samples)

                ssl_repre = self.projector(ssl_backbone_out)
                ssl_pred = self.classifier(ssl_backbone_out)
                ssl_label = torch.max(ssl_pred, -1)[1]

                aug_backbone_out = self.backbone(trans_images.reshape(-1, c, h, w))
                aug_repre = self.projector(aug_backbone_out)
                aug_pred = self.classifier(aug_backbone_out)
                aug_pred = aug_pred.reshape(b, self.aug_num, -1)

                # Representation similarity between each sample and its augmented neighbors
                sim_repre = F.cosine_similarity(
                    ssl_repre.unsqueeze(dim=1), aug_repre.reshape(b, self.aug_num, -1), dim=2
                )
                # Label consistency: ssl_label already holds class indices, so it is one-hot encoded
                # directly and cast to float for the cosine similarity
                sim_preds = F.cosine_similarity(
                    F.one_hot(ssl_label, num_classes=ssl_pred.shape[-1]).unsqueeze(dim=1).float(),
                    aug_pred,
                    dim=2,
                )

                similarities.append(
                    (self.alpha * sim_preds + (1 - self.alpha) * sim_repre).sort(descending=True)[0].cpu().numpy()
                )

        similarities = np.concatenate(similarities, axis=0)

        return similarities

    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
        """
        Determine a detection threshold that covers (100 - percentile)% of clean samples.

        :param x: Clean sample data.
        :param y: Clean sample labels (not used in this method).
        :param batch_size: Batch size for processing.
        :param nb_epochs: Number of training epochs (not used in this method).
        """
        clean_metrics = self._get_metrics(x, batch_size)
        k_minus_one_metrics = clean_metrics[:, self.K - 1]

        self.threshold = np.percentile(k_minus_one_metrics, self.percentile)

    def detect(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[np.ndarray, np.ndarray]:
        """
        Detect whether the given samples are adversarial.

        :param x: Input samples.
        :param batch_size: Batch size for processing.
        :return: (report, is_adversarial):
                 where report contains the K-th largest combined similarity score per sample,
                 and is_adversarial is a boolean array indicating whether each sample is adversarial.
        """
        if self.threshold is None:
            raise ValueError("Detector has not been fitted. Call fit() before detect().")

        similarities = self._get_metrics(x, batch_size)

        report = similarities[:, self.K - 1]
        is_adversarial = report < self.threshold

        return report, is_adversarial
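For orientation, a minimal usage sketch (not part of the diff): it assumes float32 CIFAR-10-style inputs in NCHW format, a pretrained target_model, and a SimSiam-style ssl_model exposing backbone, classifier and projector attributes, as in the test below. The augmentation pipeline shown here is an illustrative subset of the one used in the test.

    import numpy as np
    from torchvision import transforms
    from art.defences.detector.evasion import BeyondDetector

    # x_train / x_test: float32 arrays of shape (N, 3, 32, 32); target_model and ssl_model
    # are assumed to be pretrained PyTorch modules as described above.
    augmentations = transforms.Compose(
        [
            transforms.RandomResizedCrop(32, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
    )

    detector = BeyondDetector(
        target_model=target_model,
        ssl_model=ssl_model,
        augmentations=augmentations,
        aug_num=50,
        alpha=0.8,
        K=20,
        percentile=5,
    )

    detector.fit(x_train, y_train, batch_size=128)  # sets the detection threshold on clean data
    report, is_adversarial = detector.detect(x_test, batch_size=128)

Samples whose combined similarity at rank K falls below the fitted threshold are flagged as adversarial.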
@@ -0,0 +1,157 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2024
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import absolute_import, division, print_function, unicode_literals

import pytest
import numpy as np

import torch
import torch.nn as nn
from torchvision import models, transforms

from art.attacks.evasion.fast_gradient import FastGradientMethod
from art.defences.detector.evasion import BeyondDetector

from tests.utils import ARTTestException


class SimSiamWithCls(nn.Module):
    """
    SimSiam with Classifier
    """

    def __init__(self, arch="resnet18", feat_dim=2048, num_proj_layers=2):
        super(SimSiamWithCls, self).__init__()
        self.backbone = models.resnet18()
        out_dim = self.backbone.fc.weight.shape[1]
        self.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
        self.backbone.maxpool = nn.Identity()
        self.backbone.fc = nn.Identity()
        self.classifier = nn.Linear(out_dim, 10)

        pred_hidden_dim = int(feat_dim / 4)

        self.projector = nn.Sequential(
            nn.Linear(out_dim, feat_dim, bias=False),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, feat_dim, bias=False),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(),
            nn.Linear(feat_dim, feat_dim),
            nn.BatchNorm1d(feat_dim, affine=False),
        )
        self.projector[6].bias.requires_grad = False

        self.predictor = nn.Sequential(
            nn.Linear(feat_dim, pred_hidden_dim, bias=False),
            nn.BatchNorm1d(pred_hidden_dim),
            nn.ReLU(),
            nn.Linear(pred_hidden_dim, feat_dim),
        )

    def forward(self, img, im_aug1=None, im_aug2=None):
        r_ori = self.backbone(img)
        if im_aug1 is None and im_aug2 is None:
            cls = self.classifier(r_ori)
            rep = self.projector(r_ori)
            return {"cls": cls, "rep": rep}
        else:
            r1 = self.backbone(im_aug1)
            r2 = self.backbone(im_aug2)

            z1 = self.projector(r1)
            z2 = self.projector(r2)

            p1 = self.predictor(z1)
            p2 = self.predictor(z2)

            return {"z1": z1, "z2": z2, "p1": p1, "p2": p2}


@pytest.fixture
def get_ssl_model():
    """
    Provides the SSL model (SimSiamWithCls); the pretrained weights are loaded in the test itself.
    """
    return SimSiamWithCls()


@pytest.mark.only_with_platform("pytorch")
def test_beyond_detector(art_warning, load_cifar10, get_ssl_model):
    try:
        # Load CIFAR10 data
        (x_train, y_train), (x_test, y_test), _, _ = load_cifar10()

        # Load models
        # Download pretrained weights from https://drive.google.com/drive/folders/1ieEdd7hOj2CIl1FQfu4-3RGZmEj-mesi?usp=sharing
        target_model = models.resnet18()
        target_model.load_state_dict(torch.load("./resnet_c10.pth"))
        ssl_model = get_ssl_model
        ssl_model.load_state_dict(torch.load("./simsiam_c10.pth"))

        # Generate adversarial samples
        # NOTE: FastGradientMethod expects an ART estimator rather than a raw torch module;
        # see the wrapping sketch after this file.
        attack = FastGradientMethod(estimator=target_model, eps=0.05)
        x_test_adv = attack.generate(x_test)

        img_augmentations = transforms.Compose(
            [
                transforms.RandomResizedCrop(32, scale=(0.2, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),  # not strengthened
                transforms.RandomGrayscale(p=0.2),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ]
        )

        # Initialize BeyondDetector
        detector = BeyondDetector(
            target_model=target_model,
            ssl_model=ssl_model,
            augmentations=img_augmentations,
            aug_num=50,
            alpha=0.8,
            K=20,
            percentile=5,
        )

        # Fit the detector on clean training data
        detector.fit(x_train, y_train, batch_size=128)

        # Apply detector on clean and adversarial test data
        _, test_detection = detector.detect(x_test)
        _, test_adv_detection = detector.detect(x_test_adv)

        # Assert there is at least one true positive and one true negative
        nb_true_positives = np.sum(test_adv_detection)
        nb_true_negatives = len(test_detection) - np.sum(test_detection)
        assert nb_true_positives > 0
        assert nb_true_negatives > 0

        # Detection accuracy on clean and adversarial data (sanity-checked to lie in [0, 1])
        clean_accuracy = 1 - np.mean(test_detection)
        adv_accuracy = np.mean(test_adv_detection)
        assert 0.0 <= clean_accuracy <= 1.0
        assert 0.0 <= adv_accuracy <= 1.0

    except ARTTestException as e:
        art_warning(e)


if __name__ == "__main__":
    test_beyond_detector()
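ART evasion attacks such as FastGradientMethod operate on ART estimators rather than raw torch modules, so in practice the resnet18 above would first be wrapped. The sketch below is an illustration, not part of this PR: the loss function, input shape, number of classes, and clip values are assumptions for a CIFAR-10 setup, and the checkpoint path is the one used in the test.

    import torch
    import torch.nn as nn
    from torchvision import models
    from art.estimators.classification import PyTorchClassifier
    from art.attacks.evasion.fast_gradient import FastGradientMethod

    # Same pretrained weights as in the test above (assumed CIFAR-10 checkpoint)
    target_model = models.resnet18()
    target_model.load_state_dict(torch.load("./resnet_c10.pth"))

    # Wrap the raw PyTorch model so the attack can query logits and loss gradients
    classifier = PyTorchClassifier(
        model=target_model,
        loss=nn.CrossEntropyLoss(),
        input_shape=(3, 32, 32),  # CIFAR-10 images in NCHW format (assumption)
        nb_classes=10,
        clip_values=(0.0, 1.0),
    )

    attack = FastGradientMethod(estimator=classifier, eps=0.05)
    x_test_adv = attack.generate(x_test)  # x_test as in the test above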
How large are the downloaded files? Can we store them in the ART repo?