Skip to content

Commit

Permalink
Add a flag to be used for marking the YOLOv8 model
Browse files Browse the repository at this point in the history
  • Loading branch information
CNOCycle committed Sep 14, 2024
1 parent 693e545 commit d27f500
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 7 deletions.
11 changes: 7 additions & 4 deletions art/estimators/object_detection/pytorch_object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
"loss_rpn_box_reg",
),
device_type: str = "gpu",
is_yolov8: bool = False,
):
"""
Initialization.
Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
'loss_objectness', and 'loss_rpn_box_reg'.
:param device_type: Type of device to be used for model and tensors, if `cpu` run on CPU, if `gpu` run on GPU
if available otherwise run on CPU.
:param is_yolov8: The flag to be used for marking the YOLOv8 model.
"""
import re
import torch
Expand Down Expand Up @@ -140,9 +142,10 @@ def __init__(

self._model: torch.nn.Module
self._model.to(self._device)
try:
self.is_yolov8 = is_yolov8
if self.is_yolov8:
self._model.model.eval()
except AttributeError:
else:
self._model.eval()

@property
Expand Down Expand Up @@ -406,9 +409,9 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> list[dict[s
from torch.utils.data import TensorDataset, DataLoader

# Set model to evaluation mode
try:
if self.is_yolov8:
self._model.model.eval()

Check warning on line 413 in art/estimators/object_detection/pytorch_object_detector.py

View check run for this annotation

Codecov / codecov/patch

art/estimators/object_detection/pytorch_object_detector.py#L413

Added line #L413 was not covered by tests
except AttributeError:
else:
self._model.eval()

# Apply preprocessing and convert to tensors
Expand Down
3 changes: 3 additions & 0 deletions art/estimators/object_detection/pytorch_yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(
"loss_rpn_box_reg",
),
device_type: str = "gpu",
is_yolov8: bool = False,
):
"""
Initialization.
Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
'loss_objectness', and 'loss_rpn_box_reg'.
:param device_type: Type of device to be used for model and tensors, if `cpu` run on CPU, if `gpu` run on GPU
if available otherwise run on CPU.
:param is_yolov8: The flag to be used for marking the YOLOv8 model.
"""
super().__init__(
model=model,
Expand All @@ -104,6 +106,7 @@ def __init__(
preprocessing=preprocessing,
attack_losses=attack_losses,
device_type=device_type,
is_yolov8=is_yolov8,
)

def _translate_labels(self, labels: list[dict[str, "torch.Tensor"]]) -> "torch.Tensor":
Expand Down
3 changes: 2 additions & 1 deletion notebooks/snal.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@
"model = YOLO('yolov8m')\n",
"py_model = PyTorchYolo(model=model,\n",
" input_shape=(3, 640, 640),\n",
" channels_first=True)\n",
" channels_first=True,\n",
" is_yolov8=True)\n",
"\n",
"# Define a custom function to collect patches from images\n",
"def collect_patches_from_images(model: \"torch.nn.Module\",\n",
Expand Down
4 changes: 2 additions & 2 deletions tests/attacks/evasion/test_steal_now_attack_later.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_generate(art_warning):
import requests

model = YOLO("yolov8m")
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True)
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True, is_yolov8=True)

# Define a custom function to collect patches from images
def collect_patches_from_images(model, imgs):
Expand Down Expand Up @@ -192,7 +192,7 @@ def _loader(self, path):
def test_check_params(art_warning):
try:
model = YOLO("yolov8m")
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True)
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True, is_yolov8=True)

def dummy_func(model, imags):
candidates_patch = []
Expand Down

0 comments on commit d27f500

Please sign in to comment.