Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for BEYOND [ICML-2024] #2489

Closed
wants to merge 16 commits into from
Closed
2 changes: 2 additions & 0 deletions art/defences/detector/evasion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
from art.defences.detector.evasion.binary_input_detector import BinaryInputDetector
from art.defences.detector.evasion.binary_activation_detector import BinaryActivationDetector
from art.defences.detector.evasion.subsetscanning.detector import SubsetScanningDetector
from art.defences.detector.evasion.beyond_detector import BeyondDetector

167 changes: 167 additions & 0 deletions art/defences/detector/evasion/beyond_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2024
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements the BEYOND detector for adversarial examples detection.

| Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3
"""
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from art.defences.detector.evasion.evasion_detector import EvasionDetector

if TYPE_CHECKING:
from art.utils import CLASSIFIER_NEURALNETWORK_TYPE


class BeyondDetector(EvasionDetector):
"""
BEYOND detector for adversarial samples detection.
This detector uses a combination of SSL and target model predictions to detect adversarial examples.

| Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3
"""

defence_params = ["target_model", "ssl_model", "augmentations", "aug_num", "alpha", "K", "percentile"]

def __init__(self,
target_model: "CLASSIFIER_NEURALNETWORK_TYPE",
ssl_model: "CLASSIFIER_NEURALNETWORK_TYPE",
augmentations: Callable | None,
aug_num: int=50,
alpha: float=0.8,
K:int=20,
percentile:int=5) -> None:
"""
Initialize the BEYOND detector.

:param target_model: The target model to be protected
:param ssl_model: The self-supervised learning model used for feature extraction
:param augmentation: data augmentations for generating neighborhoods
:param aug_num: Number of augmentations to apply to each sample (default: 50)
:param alpha: Weight factor for combining label and representation similarities (default: 0.8)
:param K: Number of top similarities to consider (default: 20)
:param percentile: using to calculate the threshold
"""
super().__init__()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.target_model = target_model.to(self.device)
self.ssl_model = ssl_model.to(self.device)
self.aug_num = aug_num
self.alpha = alpha
self.K = K

self.backbone = ssl_model.backbone
self.classifier = ssl_model.classifier
self.projector = ssl_model.projector

self.img_augmentations = augmentations

self.percentile = percentile # determinate the threshold
self.threshold = None



def _multi_transform(self, img: torch.Tensor) -> torch.Tensor:
return torch.stack([self.img_augmentations(img) for _ in range(self.aug_num)], dim=1)

def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> tuple[dict, np.ndarray]:
"""
Calculate similarities that combining label consistency and representation similarity for given samples

:param x: Input samples
:param batch_size: Batch size for processing
:return: A report similarities
"""
samples = torch.from_numpy(x).to(self.device)

self.target_model.eval()
self.backbone.eval()
self.classifier.eval()
self.projector.eval()

number_batch = int(math.ceil(len(samples) / batch_size))

similarities = []

with torch.no_grad():
for index in range(number_batch):
start = index * batch_size
end = min((index + 1) * batch_size, len(samples))

batch_samples = samples[start:end]
b, c, h, w = batch_samples.shape

trans_images = self._multi_transform(batch_samples).to(self.device)
ssl_backbone_out = self.backbone(batch_samples)

ssl_repre = self.projector(ssl_backbone_out)
ssl_pred = self.classifier(ssl_backbone_out)
ssl_label = torch.max(ssl_pred, -1)[1]

aug_backbone_out = self.backbone(trans_images.reshape(-1, c, h, w))
aug_repre = self.projector(aug_backbone_out)
aug_pred = self.classifier(aug_backbone_out)
aug_pred = aug_pred.reshape(b, self.aug_num, -1)

sim_repre = F.cosine_similarity(ssl_repre.unsqueeze(dim=1), aug_repre.reshape(b, self.aug_num, -1), dim=2)
sim_preds = F.cosine_similarity(F.one_hot(torch.argmax(ssl_label, dim=1), num_classes=ssl_pred.shape[-1]).unsqueeze(dim=1), aug_pred, dim=2)

similarities.append((self.alpha * sim_preds + (1-self.alpha)*sim_repre).sort(descending=True)[0].cpu().numpy())

similarities = np.concatenate(similarities, axis=0)

return similarities


def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
"""
Determine a threshold that covers 95% of clean samples.

:param x: Clean sample data
:param y: Clean sample labels (not used in this method)
:param batch_size: Batch size for processing
:param nb_epochs: Number of training epochs (not used in this method)
"""
k_minus_one_metrics = clean_metrics[:, self.K-1]

self.threshold = np.percentile(k_minus_one_metrics, self.threshold)

def detect(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[dict, np.ndarray]:
"""
Detect whether given samples are adversarial

:param x: Input samples
:param batch_size: Batch size for processing
:return: (report, is_adversarial):
where report containing detection results
where is_adversarial is a boolean list indicating whether samples are adversarial or not
"""
if self.threshold is None:
raise ValueError("Detector has not been fitted. Call fit() before detect().")

similarities = self._get_metrics(x, batch_size)

report = similarities[:, self.K-1]
is_adversarial = report < self.threshold

return report, is_adversarial
157 changes: 157 additions & 0 deletions tests/defences/detector/evasion/test_beyond_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2024
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import absolute_import, division, print_function, unicode_literals

import pytest
import numpy as np

import torch.nn as nn
from torchvision import models, transforms

from art.attacks.evasion.fast_gradient import FastGradientMethod
from art.defences.detector.evasion import BeyondDetector

from tests.utils import ARTTestException


class SimSiamWithCls(nn.Module):
'''
SimSiam with Classifier
'''
def __init__(self, arch='resnet18', feat_dim=2048, num_proj_layers=2):

super(SimSiamWithCls, self).__init__()
self.backbone = models.resnet18()
out_dim = self.backbone.fc.weight.shape[1]
self.backbone.conv1 = nn.Conv2d(
3, 64, kernel_size=3, stride=1, padding=2, bias=False
)
self.backbone.maxpool = nn.Identity()
self.backbone.fc = nn.Identity()
self.classifier = nn.Linear(out_dim, 10)

pred_hidden_dim = int(feat_dim / 4)

self.projector = nn.Sequential(
nn.Linear(out_dim, feat_dim, bias=False),
nn.BatchNorm1d(feat_dim),
nn.ReLU(),
nn.Linear(feat_dim, feat_dim, bias=False),
nn.BatchNorm1d(feat_dim),
nn.ReLU(),
nn.Linear(feat_dim, feat_dim),
nn.BatchNorm1d(feat_dim, affine=False),
)
self.projector[6].bias.requires_grad = False

self.predictor = nn.Sequential(
nn.Linear(feat_dim, pred_hidden_dim, bias=False),
nn.BatchNorm1d(pred_hidden_dim),
nn.ReLU(),
nn.Linear(pred_hidden_dim, feat_dim),
)

def forward(self, img, im_aug1=None, im_aug2=None):

r_ori = self.backbone(img)
if im_aug1 is None and im_aug2 is None:
cls = self.classifier(r_ori)
rep = self.projector(r_ori)
return {'cls': cls, 'rep':rep}
else:

r1 = self.backbone(im_aug1)
r2 = self.backbone(im_aug2)

z1 = self.projector(r1)
z2 = self.projector(r2)

p1 = self.predictor(z1)
p2 = self.predictor(z2)

return {'z1': z1, 'z2': z2, 'p1': p1, 'p2': p2}


@pytest.fixture
def get_ssl_model(weights_path):
"""
Loads the SSL model (SimSiamWithCls).
"""
model = SimSiamWithCls()
model.load_state_dict(torch.load(weights_path))
return model

@pytest.mark.only_with_platform("pytorch")
def test_beyond_detector(art_warning, load_cifar10, get_ssl_model):
try:
# Load CIFAR10 data
(x_train, y_train), (x_test, y_test), _, _ = load_cifar10()

# Load models
# Download pretrained weights from https://drive.google.com/drive/folders/1ieEdd7hOj2CIl1FQfu4-3RGZmEj-mesi?usp=sharing
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How large are the downloaded files? Can we store them in the ART repo?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pre-trained model has over 100 M

target_model = models.resnet18()
target_model.load_state_dict(torch.load("./resnet_c10.pth"))
ssl_model = get_ssl_model()
ssl_model.load_state_dict(torch.load("./simsiam_c10.pth"))


# Generate adversarial samples
attack = FastGradientMethod(estimator=target_model, eps=0.05)
x_test_adv = attack.generate(x_test)

img_augmentations = transforms.Compose([
transforms.RandomResizedCrop(32, scale=(0.2, 1.)),
transforms.RandomHorizontalFlip(),
transforms.RandomApply([
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened
], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Initialize BeyondDetector
detector = BeyondDetector(
target_model=target_model,
ssl_model=ssl_model,
augmentations=img_augmentations,
aug_num=50,
alpha=0.8,
K=20,
percentile=5
)

# Fit the detector
detector.fit(x_train, y_train, batch_size=128)

# Apply detector on clean and adversarial test data
_, test_detection = detector.detect(x_test)
_, test_adv_detection = detector.detect(x_test_adv)

# Assert there is at least one true positive and negative
nb_true_positives = np.sum(test_adv_detection)

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nb_true_positives is not used.
nb_true_negatives = len(test_detection) - np.sum(test_detection)

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable nb_true_negatives is not used.

clean_accuracy = 1 - np.mean(test_detection)

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable clean_accuracy is not used.
adv_accuracy = np.mean(test_adv_detection)

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable adv_accuracy is not used.

except ARTTestException as e:
art_warning(e)

if __name__ == "__main__":

test_beyond_detector()