-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'allenhzy-main' into dev_1.19.0
- Loading branch information
Showing
6 changed files
with
368 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
# MIT License | ||
# | ||
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2024 | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated | ||
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the | ||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit | ||
# persons to whom the Software is furnished to do so, subject to the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the | ||
# Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE | ||
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, | ||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
# SOFTWARE. | ||
""" | ||
This module implements the BEYOND detector for adversarial examples detection. | ||
| Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3 | ||
""" | ||
from __future__ import annotations | ||
|
||
import math | ||
from typing import TYPE_CHECKING, Callable | ||
|
||
import numpy as np | ||
|
||
if TYPE_CHECKING: | ||
import torch | ||
from art.utils import CLASSIFIER_NEURALNETWORK_TYPE | ||
|
||
|
||
from art.defences.detector.evasion.evasion_detector import EvasionDetector | ||
|
||
|
||
class BeyondDetectorPyTorch(EvasionDetector): | ||
""" | ||
BEYOND detector for adversarial samples detection. | ||
This detector uses a combination of SSL and target model predictions to detect adversarial examples. | ||
| Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3 | ||
""" | ||
|
||
defence_params = ["target_model", "ssl_model", "augmentations", "aug_num", "alpha", "K", "percentile"] | ||
|
||
def __init__( | ||
self, | ||
target_classifier: "CLASSIFIER_NEURALNETWORK_TYPE", | ||
ssl_classifier: "CLASSIFIER_NEURALNETWORK_TYPE", | ||
augmentations: Callable | None, | ||
aug_num: int = 50, | ||
alpha: float = 0.8, | ||
K: int = 20, | ||
percentile: int = 5, | ||
) -> None: | ||
""" | ||
Initialize the BEYOND detector. | ||
:param target_classifier: The target model to be protected | ||
:param ssl_classifier: The self-supervised learning model used for feature extraction | ||
:param augmentations: data augmentations for generating neighborhoods | ||
:param aug_num: Number of augmentations to apply to each sample (default: 50) | ||
:param alpha: Weight factor for combining label and representation similarities (default: 0.8) | ||
:param K: Number of top similarities to consider (default: 20) | ||
:param percentile: using to calculate the threshold | ||
""" | ||
import torch | ||
|
||
super().__init__() | ||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
self.target_model = target_classifier.model.to(self.device) | ||
self.ssl_model = ssl_classifier.model.to(self.device) | ||
self.aug_num = aug_num | ||
self.alpha = alpha | ||
self.K = K | ||
|
||
self.backbone = self.ssl_model.backbone | ||
self.model_classifier = self.ssl_model.classifier | ||
self.projector = self.ssl_model.projector | ||
|
||
self.img_augmentations = augmentations | ||
|
||
self.percentile = percentile # determine the threshold | ||
self.threshold: float | None = None | ||
|
||
def _multi_transform(self, img: "torch.Tensor") -> "torch.Tensor": | ||
import torch | ||
|
||
return torch.stack([self.img_augmentations(img) for _ in range(self.aug_num)], dim=1) | ||
|
||
def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> np.ndarray: | ||
""" | ||
Calculate similarities that combining label consistency and representation similarity for given samples | ||
:param x: Input samples | ||
:param batch_size: Batch size for processing | ||
:return: A report similarities | ||
""" | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
samples = torch.from_numpy(x).to(self.device) | ||
|
||
self.target_model.eval() | ||
self.backbone.eval() | ||
self.model_classifier.eval() | ||
self.projector.eval() | ||
|
||
number_batch = int(math.ceil(len(samples) / batch_size)) | ||
|
||
similarities = [] | ||
|
||
with torch.no_grad(): | ||
for index in range(number_batch): | ||
start = index * batch_size | ||
end = min((index + 1) * batch_size, len(samples)) | ||
|
||
batch_samples = samples[start:end] | ||
b, c, h, w = batch_samples.shape | ||
|
||
trans_images = self._multi_transform(batch_samples).to(self.device) | ||
ssl_backbone_out = self.backbone(batch_samples) | ||
|
||
ssl_repre = self.projector(ssl_backbone_out) | ||
ssl_pred = self.model_classifier(ssl_backbone_out) | ||
ssl_label = torch.max(ssl_pred, -1)[1] | ||
|
||
aug_backbone_out = self.backbone(trans_images.reshape(-1, c, h, w)) | ||
aug_repre = self.projector(aug_backbone_out) | ||
aug_pred = self.model_classifier(aug_backbone_out) | ||
aug_pred = aug_pred.reshape(b, self.aug_num, -1) | ||
|
||
sim_repre = F.cosine_similarity( | ||
ssl_repre.unsqueeze(dim=1), aug_repre.reshape(b, self.aug_num, -1), dim=2 | ||
) | ||
|
||
sim_preds = F.cosine_similarity( | ||
F.one_hot(ssl_label, num_classes=ssl_pred.shape[-1]).unsqueeze(dim=1), | ||
aug_pred, | ||
dim=2, | ||
) | ||
|
||
similarities.append( | ||
(self.alpha * sim_preds + (1 - self.alpha) * sim_repre).sort(descending=True)[0].cpu().numpy() | ||
) | ||
|
||
similarities = np.concatenate(similarities, axis=0) | ||
|
||
return similarities | ||
|
||
def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None: | ||
""" | ||
Determine a threshold that covers 95% of clean samples. | ||
:param x: Clean sample data | ||
:param y: Clean sample labels (not used in this method) | ||
:param batch_size: Batch size for processing | ||
:param nb_epochs: Number of training epochs (not used in this method) | ||
""" | ||
clean_metrics = self._get_metrics(x=x, batch_size=batch_size) | ||
k_minus_one_metrics = clean_metrics[:, self.K - 1] | ||
self.threshold = np.percentile(k_minus_one_metrics, q=self.percentile) | ||
|
||
def detect(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[dict, np.ndarray]: | ||
""" | ||
Detect whether given samples are adversarial | ||
:param x: Input samples | ||
:param batch_size: Batch size for processing | ||
:return: (report, is_adversarial): | ||
where report containing detection results | ||
where is_adversarial is a boolean list indicating whether samples are adversarial or not | ||
""" | ||
if self.threshold is None: | ||
raise ValueError("Detector has not been fitted. Call fit() before detect().") | ||
|
||
similarities = self._get_metrics(x, batch_size) | ||
|
||
report = similarities[:, self.K - 1] | ||
is_adversarial = report < self.threshold | ||
|
||
return report, is_adversarial |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
178 changes: 178 additions & 0 deletions
178
tests/defences/detector/evasion/test_beyond_detector.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
# MIT License | ||
# | ||
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2024 | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated | ||
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the | ||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit | ||
# persons to whom the Software is furnished to do so, subject to the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the | ||
# Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE | ||
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, | ||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
# SOFTWARE. | ||
from __future__ import absolute_import, division, print_function, unicode_literals | ||
|
||
import pytest | ||
import numpy as np | ||
|
||
from art.attacks.evasion.fast_gradient import FastGradientMethod | ||
from art.defences.detector.evasion import BeyondDetectorPyTorch | ||
from art.estimators.classification import PyTorchClassifier | ||
from tests.utils import ARTTestException | ||
|
||
|
||
def get_ssl_model(weights_path): | ||
""" | ||
Loads the SSL model (SimSiamWithCls). | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
|
||
class SimSiamWithCls(nn.Module): | ||
""" | ||
SimSiam with Classifier | ||
""" | ||
|
||
def __init__(self, arch="resnet18", feat_dim=2048, num_proj_layers=2): | ||
from torchvision import models | ||
|
||
super(SimSiamWithCls, self).__init__() | ||
self.backbone = models.resnet18() | ||
out_dim = self.backbone.fc.weight.shape[1] | ||
self.backbone.conv1 = nn.Conv2d( | ||
in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=2, bias=False | ||
) | ||
self.backbone.maxpool = nn.Identity() | ||
self.backbone.fc = nn.Identity() | ||
self.classifier = nn.Linear(out_dim, out_features=10) | ||
|
||
pred_hidden_dim = int(feat_dim / 4) | ||
|
||
self.projector = nn.Sequential( | ||
nn.Linear(out_dim, feat_dim, bias=False), | ||
nn.BatchNorm1d(feat_dim), | ||
nn.ReLU(), | ||
nn.Linear(feat_dim, feat_dim, bias=False), | ||
nn.BatchNorm1d(feat_dim), | ||
nn.ReLU(), | ||
nn.Linear(feat_dim, feat_dim), | ||
nn.BatchNorm1d(feat_dim, affine=False), | ||
) | ||
self.projector[6].bias.requires_grad = False | ||
|
||
self.predictor = nn.Sequential( | ||
nn.Linear(feat_dim, pred_hidden_dim, bias=False), | ||
nn.BatchNorm1d(pred_hidden_dim), | ||
nn.ReLU(), | ||
nn.Linear(pred_hidden_dim, feat_dim), | ||
) | ||
|
||
def forward(self, img, im_aug1=None, im_aug2=None): | ||
|
||
r_ori = self.backbone(img) | ||
if im_aug1 is None and im_aug2 is None: | ||
cls = self.classifier(r_ori) | ||
rep = self.projector(r_ori) | ||
return {"cls": cls, "rep": rep} | ||
else: | ||
|
||
r1 = self.backbone(im_aug1) | ||
r2 = self.backbone(im_aug2) | ||
|
||
z1 = self.projector(r1) | ||
z2 = self.projector(r2) | ||
|
||
p1 = self.predictor(z1) | ||
p2 = self.predictor(z2) | ||
|
||
return {"z1": z1, "z2": z2, "p1": p1, "p2": p2} | ||
|
||
model = SimSiamWithCls() | ||
model.load_state_dict(torch.load(weights_path)) | ||
return model | ||
|
||
|
||
@pytest.mark.only_with_platform("pytorch") | ||
def test_beyond_detector(art_warning, get_default_cifar10_subset): | ||
try: | ||
import torch | ||
from torchvision import models, transforms | ||
|
||
# Load CIFAR10 data | ||
(x_train, y_train), (x_test, _) = get_default_cifar10_subset | ||
|
||
x_train = x_train[0:100] | ||
y_train = y_train[0:100] | ||
x_test = x_test[0:100] | ||
|
||
# Load models | ||
# Download pretrained weights from | ||
# https://drive.google.com/drive/folders/1ieEdd7hOj2CIl1FQfu4-3RGZmEj-mesi?usp=sharing | ||
target_model = models.resnet18() | ||
# target_model.load_state_dict(torch.load("../../../../utils/resources/models/resnet_c10.pth", map_location=torch.device('cpu'))) | ||
ssl_model = get_ssl_model(weights_path="../../../../utils/resources/models/simsiam_c10.pth") | ||
|
||
target_classifier = PyTorchClassifier( | ||
model=target_model, nb_classes=10, input_shape=(3, 32, 32), loss=torch.nn.CrossEntropyLoss() | ||
) | ||
ssl_classifier = PyTorchClassifier( | ||
model=ssl_model, nb_classes=10, input_shape=(3, 32, 32), loss=torch.nn.CrossEntropyLoss() | ||
) | ||
|
||
# Generate adversarial samples | ||
attack = FastGradientMethod(estimator=target_classifier, eps=0.05) | ||
x_test_adv = attack.generate(x_test) | ||
|
||
img_augmentations = transforms.Compose( | ||
[ | ||
transforms.RandomResizedCrop(32, scale=(0.2, 1.0)), | ||
transforms.RandomHorizontalFlip(), | ||
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8), # not strengthened | ||
transforms.RandomGrayscale(p=0.2), | ||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), | ||
] | ||
) | ||
|
||
# Initialize BeyondDetector | ||
detector = BeyondDetectorPyTorch( | ||
target_classifier=target_classifier, | ||
ssl_classifier=ssl_classifier, | ||
augmentations=img_augmentations, | ||
aug_num=50, | ||
alpha=0.8, | ||
K=20, | ||
percentile=5, | ||
) | ||
|
||
# Fit the detector | ||
detector.fit(x_train, y_train, batch_size=128) | ||
|
||
# Apply detector on clean and adversarial test data | ||
_, test_detection = detector.detect(x_test) | ||
_, test_adv_detection = detector.detect(x_test_adv) | ||
|
||
# Assert there is at least one true positive and negative | ||
nb_true_positives = np.sum(test_adv_detection) | ||
nb_true_negatives = len(test_detection) - np.sum(test_detection) | ||
|
||
assert nb_true_positives > 0 | ||
assert nb_true_negatives > 0 | ||
|
||
clean_accuracy = 1 - np.mean(test_detection) | ||
adv_accuracy = np.mean(test_adv_detection) | ||
|
||
assert clean_accuracy > 0.0 | ||
assert adv_accuracy > 0.0 | ||
|
||
except ARTTestException as e: | ||
art_warning(e) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
test_beyond_detector() |
Binary file not shown.
Binary file not shown.