From baeec58d1c822b87dc6c9a8a8e554e04ebcc4852 Mon Sep 17 00:00:00 2001
From: Farhan Ahmed
Date: Wed, 8 Nov 2023 16:53:05 -0800
Subject: [PATCH] fix style checks

---
 .../object_detection/pytorch_detection_transformer.py | 6 ++++--
 .../object_detection/pytorch_object_detector.py       | 8 ++++----
 art/estimators/object_detection/pytorch_yolo.py       | 6 ++++--
 3 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/art/estimators/object_detection/pytorch_detection_transformer.py b/art/estimators/object_detection/pytorch_detection_transformer.py
index be7672350c..91ffbd3b1e 100644
--- a/art/estimators/object_detection/pytorch_detection_transformer.py
+++ b/art/estimators/object_detection/pytorch_detection_transformer.py
@@ -23,6 +23,8 @@
 import logging
 from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
 
+import numpy as np
+
 from art.estimators.object_detection.pytorch_object_detector import PyTorchObjectDetector
 
 if TYPE_CHECKING:
@@ -161,7 +163,7 @@ def _translate_labels(self, labels: List[Dict[str, "torch.Tensor"]]) -> List[Any
 
         return labels_translated
 
-    def _translate_predictions(self, predictions: Dict[str, "torch.Tensor"]) -> List[Dict[str, "torch.Tensor"]]:
+    def _translate_predictions(self, predictions: Dict[str, "torch.Tensor"]) -> List[Dict[str, np.ndarray]]:
         """
         Translate object detection predictions from the model format (DETR) to ART format (torchvision) and convert
         tensors to numpy arrays.
@@ -181,7 +183,7 @@ def _translate_predictions(self, predictions: Dict[str, "torch.Tensor"]) -> List
         pred_boxes = predictions["pred_boxes"]
         pred_logits = predictions["pred_logits"]
 
-        predictions_x1y1x2y2 = []
+        predictions_x1y1x2y2: List[Dict[str, np.ndarray]] = []
 
         for pred_box, pred_logit in zip(pred_boxes, pred_logits):
             boxes = rescale_bboxes(pred_box.detach().cpu(), (height, width)).numpy()
diff --git a/art/estimators/object_detection/pytorch_object_detector.py b/art/estimators/object_detection/pytorch_object_detector.py
index bdaea4dbcb..b2cf3027ae 100644
--- a/art/estimators/object_detection/pytorch_object_detector.py
+++ b/art/estimators/object_detection/pytorch_object_detector.py
@@ -120,8 +120,8 @@ def __init__(
         self._attack_losses = attack_losses
 
         # Parameters used for subclasses
-        self.weight_dict = None
-        self.criterion = None
+        self.weight_dict: Optional[Dict[str, float]] = None
+        self.criterion: Optional[torch.nn.Module] = None
 
         if self.clip_values is not None:
             if self.clip_values[0] != 0:
@@ -577,6 +577,6 @@ def compute_loss(  # type: ignore
         )
 
         if isinstance(x, torch.Tensor):
-            return loss
+            return loss  # type: ignore
 
-        return loss.detach().cpu().numpy()
+        return loss.detach().cpu().numpy()  # type: ignore
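The `_translate_predictions` hunk in the DETR file above (and its YOLO counterpart below) is where the new `List[Dict[str, np.ndarray]]` return type comes from: the method hands back detached numpy arrays in torchvision's x1y1x2y2 format rather than model-format tensors. A minimal sketch of that translation follows; the standalone signature (batched `pred_boxes`/`pred_logits` plus explicit `height`/`width`) and this version of `rescale_bboxes` (patterned on DETR's reference post-processing) are illustrative assumptions, not ART's exact code:

```python
from typing import Dict, List, Tuple

import numpy as np
import torch


def rescale_bboxes(boxes_cxcywh: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """Convert normalized (cx, cy, w, h) boxes to absolute (x1, y1, x2, y2) pixels."""
    # Illustrative helper; ART imports its own rescale_bboxes in the patched file.
    height, width = size
    cx, cy, w, h = boxes_cxcywh.unbind(-1)
    boxes = torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)
    return boxes * torch.tensor([width, height, width, height], dtype=boxes.dtype)


def translate_detr_predictions(
    pred_boxes: torch.Tensor,   # (batch, num_queries, 4), normalized cxcywh
    pred_logits: torch.Tensor,  # (batch, num_queries, num_classes + 1)
    height: int,
    width: int,
) -> List[Dict[str, np.ndarray]]:
    """Sketch of the DETR -> torchvision-format translation with numpy outputs."""
    predictions_x1y1x2y2: List[Dict[str, np.ndarray]] = []
    for pred_box, pred_logit in zip(pred_boxes, pred_logits):
        boxes = rescale_bboxes(pred_box.detach().cpu(), (height, width)).numpy()
        # Softmax over classes; the trailing column is DETR's "no object" class.
        probs = pred_logit.detach().cpu().softmax(-1)[:, :-1]
        scores, labels = probs.max(-1)
        predictions_x1y1x2y2.append(
            {"boxes": boxes, "labels": labels.numpy(), "scores": scores.numpy()}
        )
    return predictions_x1y1x2y2
```

Dropping the last logit column before the per-query argmax is what discards DETR's "no object" predictions, so each returned dict matches the `boxes`/`labels`/`scores` layout torchvision detectors produce.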
diff --git a/art/estimators/object_detection/pytorch_yolo.py b/art/estimators/object_detection/pytorch_yolo.py
index 10b7ec0e11..976d601465 100644
--- a/art/estimators/object_detection/pytorch_yolo.py
+++ b/art/estimators/object_detection/pytorch_yolo.py
@@ -23,6 +23,8 @@
 import logging
 from typing import List, Dict, Optional, Tuple, Union, TYPE_CHECKING
 
+import numpy as np
+
 from art.estimators.object_detection.pytorch_object_detector import PyTorchObjectDetector
 
 if TYPE_CHECKING:
@@ -142,7 +144,7 @@ def _translate_labels(self, labels: List[Dict[str, "torch.Tensor"]]) -> "torch.T
         labels_xcycwh = torch.vstack(labels_xcycwh_list)
         return labels_xcycwh
 
-    def _translate_predictions(self, predictions: "torch.Tensor") -> List[Dict[str, "torch.Tensor"]]:
+    def _translate_predictions(self, predictions: "torch.Tensor") -> List[Dict[str, np.ndarray]]:
         """
         Translate object detection predictions from the model format (YOLO) to ART format (torchvision) and convert
         tensors to numpy arrays.
@@ -159,7 +161,7 @@ def _translate_predictions(self, predictions: "torch.Tensor") -> List[Dict[str,
         height = self.input_shape[0]
         width = self.input_shape[1]
 
-        predictions_x1y1x2y2 = []
+        predictions_x1y1x2y2: List[Dict[str, np.ndarray]] = []
 
         for pred in predictions:
             boxes = torch.vstack(
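The final hunk is cut off at `torch.vstack(`, but the docstring states its job: convert YOLO's centre-format detection rows to x1y1x2y2 numpy arrays. A hedged sketch of that conversion, assuming predictions shaped `(batch, num_detections, 5 + num_classes)` with column 4 holding objectness; the function name and the exact clamping are illustrative, not ART's code:

```python
from typing import Dict, List

import numpy as np
import torch


def translate_yolo_predictions(
    predictions: torch.Tensor,  # (batch, num_detections, 5 + num_classes)
    height: int,
    width: int,
) -> List[Dict[str, np.ndarray]]:
    """Sketch of the YOLO -> torchvision-format translation with numpy outputs."""
    predictions_x1y1x2y2: List[Dict[str, np.ndarray]] = []
    for pred in predictions:
        # Convert centre format (xc, yc, w, h) to corner format (x1, y1, x2, y2),
        # clamped to the image bounds.
        boxes = torch.vstack(
            [
                torch.clamp(pred[:, 0] - pred[:, 2] / 2, min=0, max=width),
                torch.clamp(pred[:, 1] - pred[:, 3] / 2, min=0, max=height),
                torch.clamp(pred[:, 0] + pred[:, 2] / 2, min=0, max=width),
                torch.clamp(pred[:, 1] + pred[:, 3] / 2, min=0, max=height),
            ]
        ).permute(1, 0)
        # Final score = objectness * best per-class confidence.
        class_conf, labels = pred[:, 5:].max(dim=1)
        scores = pred[:, 4] * class_conf
        predictions_x1y1x2y2.append(
            {
                "boxes": boxes.detach().cpu().numpy(),
                "labels": labels.detach().cpu().numpy(),
                "scores": scores.detach().cpu().numpy(),
            }
        )
    return predictions_x1y1x2y2
```

As in the DETR sketch, the `.detach().cpu().numpy()` chain at the end is what justifies the patch's `List[Dict[str, np.ndarray]]` annotation and the new module-level `import numpy as np`.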