diff --git a/art/defences/detector/evasion/__init__.py b/art/defences/detector/evasion/__init__.py
index 26112a2afe..3546431c38 100644
--- a/art/defences/detector/evasion/__init__.py
+++ b/art/defences/detector/evasion/__init__.py
@@ -6,3 +6,4 @@
 from art.defences.detector.evasion.binary_input_detector import BinaryInputDetector
 from art.defences.detector.evasion.binary_activation_detector import BinaryActivationDetector
 from art.defences.detector.evasion.subsetscanning.detector import SubsetScanningDetector
+from art.defences.detector.evasion.beyond_detector import BeyondDetector
diff --git a/art/defences/detector/evasion/beyond_detector.py b/art/defences/detector/evasion/beyond_detector.py
new file mode 100644
index 0000000000..66b73abab3
--- /dev/null
+++ b/art/defences/detector/evasion/beyond_detector.py
@@ -0,0 +1,219 @@
+# MIT License
+#
+# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2024
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
+# persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
+# Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+"""
+This module implements the BEYOND detector for adversarial examples detection.
+
+| Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3
+"""
+from __future__ import annotations
+
+import math
+from typing import TYPE_CHECKING, Callable
+
+import numpy as np
+
+from art.defences.detector.evasion.evasion_detector import EvasionDetector
+
+if TYPE_CHECKING:
+    import torch
+
+    from art.utils import CLASSIFIER_NEURALNETWORK_TYPE
+
+
+class BeyondDetector(EvasionDetector):
+    """
+    BEYOND detector for adversarial samples detection.
+
+    Each sample is compared with `aug_num` augmented copies of itself using a self-supervised learning (SSL) model.
+    The score blends the agreement of the predicted labels with the similarity of the representations; samples whose
+    K-th largest score falls below a threshold calibrated on clean data are flagged as adversarial.
+
+    | Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3
+    """
+
+    defence_params = ["target_model", "ssl_model", "augmentations", "aug_num", "alpha", "K", "percentile"]
+
+    def __init__(
+        self,
+        target_model: "CLASSIFIER_NEURALNETWORK_TYPE | torch.nn.Module",
+        ssl_model: "CLASSIFIER_NEURALNETWORK_TYPE | torch.nn.Module",
+        augmentations: Callable | None,
+        aug_num: int = 50,
+        alpha: float = 0.8,
+        K: int = 20,
+        percentile: int = 5,
+    ) -> None:
+        """
+        Initialize the BEYOND detector.
+
+        :param target_model: The target model to be protected, either a `torch.nn.Module` or an ART PyTorch classifier
+                             wrapping one.
+        :param ssl_model: The self-supervised learning model used for feature extraction; its module must expose the
+                          sub-modules `backbone`, `classifier` and `projector`.
+        :param augmentations: Callable applied to a batch of images to generate its augmented neighborhood.
+        :param aug_num: Number of augmentations to apply to each sample (default: 50).
+        :param alpha: Weight factor for combining label and representation similarities (default: 0.8).
+        :param K: Rank of the sorted similarity used as detection score, between 1 and `aug_num` (default: 20).
+        :param percentile: Percentile of the clean scores used as detection threshold, in [0, 100] (default: 5).
+        :raises ValueError: If `aug_num`, `K` or `percentile` is out of range.
+        """
+        # Imported lazily so that the package stays importable without PyTorch installed
+        import torch
+
+        super().__init__()
+
+        if aug_num < 1:
+            raise ValueError("The number of augmentations `aug_num` must be positive.")
+        if not 1 <= K <= aug_num:
+            raise ValueError("`K` must be between 1 and `aug_num`.")
+        if not 0 <= percentile <= 100:
+            raise ValueError("`percentile` must be between 0 and 100.")
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.target_model = self._as_module(target_model).to(self.device)
+        self.ssl_model = self._as_module(ssl_model).to(self.device)
+        self.aug_num = aug_num
+        self.alpha = alpha
+        self.K = K
+
+        self.backbone = self.ssl_model.backbone
+        self.classifier = self.ssl_model.classifier
+        self.projector = self.ssl_model.projector
+
+        self.img_augmentations = augmentations
+
+        self.percentile = percentile  # determines the threshold
+        self.threshold: float | None = None
+
+    @staticmethod
+    def _as_module(model: "CLASSIFIER_NEURALNETWORK_TYPE | torch.nn.Module") -> "torch.nn.Module":
+        """
+        Return the `torch.nn.Module` behind `model`, unwrapping the `model` attribute of ART estimators.
+        """
+        import torch
+
+        if isinstance(model, torch.nn.Module):
+            return model
+        return getattr(model, "model", model)
+
+    def _multi_transform(self, img: "torch.Tensor") -> "torch.Tensor":
+        """
+        Stack `aug_num` independently augmented copies of a batch along a new dimension 1.
+
+        :param img: Batch of images.
+        :return: The augmented copies, stacked along dimension 1.
+        :raises ValueError: If no augmentation callable was provided.
+        """
+        import torch
+
+        if self.img_augmentations is None:
+            raise ValueError("`augmentations` must be a callable to generate the neighborhoods.")
+        return torch.stack([self.img_augmentations(img) for _ in range(self.aug_num)], dim=1)
+
+    def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> np.ndarray:
+        """
+        Calculate similarities combining label consistency and representation similarity for given samples.
+
+        :param x: Input samples of shape `(nb_samples, channels, height, width)`.
+        :param batch_size: Batch size for processing.
+        :return: Similarities of shape `(nb_samples, aug_num)`, each row sorted in descending order.
+        """
+        import torch
+        import torch.nn.functional as F
+
+        if len(x) == 0:
+            return np.empty((0, self.aug_num), dtype=np.float32)
+
+        self.target_model.eval()
+        self.backbone.eval()
+        self.classifier.eval()
+        self.projector.eval()
+
+        number_batch = int(math.ceil(len(x) / batch_size))
+        similarities = []
+
+        with torch.no_grad():
+            for index in range(number_batch):
+                start = index * batch_size
+                end = min((index + 1) * batch_size, len(x))
+
+                # Only the current batch is moved to the device to bound the memory footprint
+                batch_samples = torch.from_numpy(x[start:end]).to(self.device)
+                b, c, h, w = batch_samples.shape
+
+                trans_images = self._multi_transform(batch_samples).to(self.device)
+                ssl_backbone_out = self.backbone(batch_samples)
+
+                ssl_repre = self.projector(ssl_backbone_out)
+                ssl_pred = self.classifier(ssl_backbone_out)
+                ssl_label = torch.argmax(ssl_pred, dim=-1)
+
+                aug_backbone_out = self.backbone(trans_images.reshape(-1, c, h, w))
+                aug_repre = self.projector(aug_backbone_out).reshape(b, self.aug_num, -1)
+                aug_pred = self.classifier(aug_backbone_out).reshape(b, self.aug_num, -1)
+
+                sim_repre = F.cosine_similarity(ssl_repre.unsqueeze(dim=1), aug_repre, dim=2)
+                # `ssl_label` already holds class indices; cast its one-hot encoding to the prediction dtype
+                ssl_one_hot = F.one_hot(ssl_label, num_classes=ssl_pred.shape[-1]).to(aug_pred.dtype)
+                sim_preds = F.cosine_similarity(ssl_one_hot.unsqueeze(dim=1), aug_pred, dim=2)
+
+                combined = self.alpha * sim_preds + (1 - self.alpha) * sim_repre
+                similarities.append(combined.sort(descending=True)[0].cpu().numpy())
+
+        return np.concatenate(similarities, axis=0)
+
+    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
+        """
+        Calibrate the detection threshold as the `percentile`-th percentile of the scores of clean samples, so that
+        roughly `100 - percentile` percent of the clean samples are accepted.
+
+        :param x: Clean sample data.
+        :param y: Clean sample labels (not used in this method).
+        :param batch_size: Batch size for processing.
+        :param nb_epochs: Number of training epochs (not used in this method).
+        :raises ValueError: If `x` is empty.
+        """
+        if len(x) == 0:
+            raise ValueError("At least one clean sample is required to fit the detector.")
+
+        clean_metrics = self._get_metrics(x=x, batch_size=batch_size)
+        k_minus_one_metrics = clean_metrics[:, self.K - 1]
+
+        self.threshold = np.percentile(k_minus_one_metrics, q=self.percentile)
+
+    def detect(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[np.ndarray, np.ndarray]:
+        """
+        Detect whether given samples are adversarial.
+
+        :param x: Input samples.
+        :param batch_size: Batch size for processing.
+        :return: (report, is_adversarial):
+                 where report holds the K-th largest similarity of every sample
+                 where is_adversarial is a boolean array indicating whether samples are adversarial or not.
+        :raises ValueError: If the detector has not been fitted.
+        """
+        if self.threshold is None:
+            raise ValueError("Detector has not been fitted. Call fit() before detect().")
+
+        similarities = self._get_metrics(x, batch_size)
+
+        report = similarities[:, self.K - 1]
+        is_adversarial = report < self.threshold
+
+        return report, is_adversarial
diff --git a/tests/defences/detector/evasion/test_beyond_detector.py b/tests/defences/detector/evasion/test_beyond_detector.py
new file mode 100644
index 0000000000..b4f2ced48d
--- /dev/null
+++ b/tests/defences/detector/evasion/test_beyond_detector.py
@@ -0,0 +1,157 @@
+# MIT License
+#
+# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2024
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
+# persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
+# Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+from torchvision import models, transforms
+
+from art.attacks.evasion.fast_gradient import FastGradientMethod
+from art.defences.detector.evasion import BeyondDetector
+from art.estimators.classification import PyTorchClassifier
+
+from tests.utils import ARTTestException
+
+
+class SimSiamWithCls(nn.Module):
+    """
+    SimSiam with Classifier: a ResNet-18 backbone with a classification head, a projector and a predictor.
+
+    :param arch: Backbone architecture name; currently ignored, a ResNet-18 is always built.
+    :param feat_dim: Output dimension of the projector.
+    :param num_proj_layers: Currently ignored; the projector always stacks three linear layers.
+    """
+
+    def __init__(self, arch="resnet18", feat_dim=2048, num_proj_layers=2):
+        super().__init__()
+        self.backbone = models.resnet18()
+        out_dim = self.backbone.fc.weight.shape[1]
+        # Replace the stem so that small images are not down-sampled by the first convolution and max-pooling
+        self.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
+        self.backbone.maxpool = nn.Identity()
+        self.backbone.fc = nn.Identity()
+        self.classifier = nn.Linear(out_dim, 10)
+
+        pred_hidden_dim = feat_dim // 4
+
+        self.projector = nn.Sequential(
+            nn.Linear(out_dim, feat_dim, bias=False),
+            nn.BatchNorm1d(feat_dim),
+            nn.ReLU(),
+            nn.Linear(feat_dim, feat_dim, bias=False),
+            nn.BatchNorm1d(feat_dim),
+            nn.ReLU(),
+            nn.Linear(feat_dim, feat_dim),
+            nn.BatchNorm1d(feat_dim, affine=False),
+        )
+        self.projector[6].bias.requires_grad = False
+
+        self.predictor = nn.Sequential(
+            nn.Linear(feat_dim, pred_hidden_dim, bias=False),
+            nn.BatchNorm1d(pred_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(pred_hidden_dim, feat_dim),
+        )
+
+    def forward(self, img, im_aug1=None, im_aug2=None):
+        """
+        Return class logits and representation of `img`, or the SimSiam outputs of the two augmented views.
+        """
+        if im_aug1 is None and im_aug2 is None:
+            r_ori = self.backbone(img)
+            return {"cls": self.classifier(r_ori), "rep": self.projector(r_ori)}
+
+        z1 = self.projector(self.backbone(im_aug1))
+        z2 = self.projector(self.backbone(im_aug2))
+
+        return {"z1": z1, "z2": z2, "p1": self.predictor(z1), "p2": self.predictor(z2)}
+
+
+@pytest.fixture
+def get_ssl_model():
+    """
+    Provide a randomly initialised SSL model (SimSiamWithCls) in evaluation mode. Pretrained weights are available at
+    https://drive.google.com/drive/folders/1ieEdd7hOj2CIl1FQfu4-3RGZmEj-mesi?usp=sharing
+    """
+    torch.manual_seed(0)
+    return SimSiamWithCls().eval()
+
+
+@pytest.mark.only_with_platform("pytorch")
+def test_beyond_detector(art_warning, get_ssl_model):
+    try:
+        # Small synthetic CIFAR10-like batches keep the test fast and independent of downloads
+        rng = np.random.default_rng(seed=0)
+        x_train = rng.random((16, 3, 32, 32), dtype=np.float32)
+        x_test = rng.random((8, 3, 32, 32), dtype=np.float32)
+
+        # Generate adversarial samples; ART attacks require an ART estimator rather than a bare torch module
+        target_model = models.resnet18(num_classes=10).eval()
+        target_classifier = PyTorchClassifier(
+            model=target_model,
+            loss=nn.CrossEntropyLoss(),
+            input_shape=(3, 32, 32),
+            nb_classes=10,
+            clip_values=(0.0, 1.0),
+        )
+        attack = FastGradientMethod(estimator=target_classifier, eps=0.05)
+        x_test_adv = attack.generate(x_test)
+
+        img_augmentations = transforms.Compose(
+            [
+                transforms.RandomResizedCrop(32, scale=(0.2, 1.0)),
+                transforms.RandomHorizontalFlip(),
+                transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),  # not strengthened
+                transforms.RandomGrayscale(p=0.2),
+                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+            ]
+        )
+
+        # Initialize BeyondDetector
+        detector = BeyondDetector(
+            target_model=target_model,
+            ssl_model=get_ssl_model,
+            augmentations=img_augmentations,
+            aug_num=4,
+            alpha=0.8,
+            K=2,
+            percentile=5,
+        )
+
+        # Detection is impossible before a threshold has been calibrated
+        with pytest.raises(ValueError):
+            detector.detect(x_test)
+
+        # Fit the detector
+        detector.fit(x_train, np.zeros(len(x_train)), batch_size=8)
+        assert np.isfinite(detector.threshold)
+
+        # Apply detector on clean and adversarial test data
+        for samples in (x_test, x_test_adv):
+            report, is_adversarial = detector.detect(samples, batch_size=8)
+            assert report.shape == is_adversarial.shape == (len(samples),)
+            assert is_adversarial.dtype == bool
+
+    except ARTTestException as e:
+        art_warning(e)
+
+
+if __name__ == "__main__":
+    raise SystemExit(pytest.main([__file__]))