From 934872ab5c52bbe4609cde1ee3d18c010aa1e130 Mon Sep 17 00:00:00 2001 From: Beat Buesser Date: Mon, 30 Sep 2024 22:26:03 +0200 Subject: [PATCH] Apply package.version.parse Signed-off-by: Beat Buesser --- .../evasion/adversarial_patch/adversarial_patch_pytorch.py | 5 +++-- art/attacks/evasion/pixel_threshold.py | 3 ++- art/estimators/object_detection/pytorch_object_detector.py | 5 +++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py b/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py index 891e872481..3d314b6c83 100644 --- a/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py +++ b/art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py @@ -25,6 +25,7 @@ import logging import math +from packaging.version import parse from typing import Any, TYPE_CHECKING import numpy as np @@ -121,8 +122,8 @@ def __init__( import torch import torchvision - torch_version = list(map(int, torch.__version__.lower().split("+", maxsplit=1)[0].split("."))) - torchvision_version = list(map(int, torchvision.__version__.lower().split("+", maxsplit=1)[0].split("."))) + torch_version = list(parse(torch.__version__.lower()).release) + torchvision_version = list(parse(torchvision.__version__.lower()).release) assert ( torch_version[0] >= 1 and torch_version[1] >= 7 or (torch_version[0] >= 2) ), "AdversarialPatchPyTorch requires torch>=1.7.0" diff --git a/art/attacks/evasion/pixel_threshold.py b/art/attacks/evasion/pixel_threshold.py index d994dcee4c..820af9151e 100644 --- a/art/attacks/evasion/pixel_threshold.py +++ b/art/attacks/evasion/pixel_threshold.py @@ -27,6 +27,7 @@ import logging from itertools import product +from packaging.version import parse from typing import TYPE_CHECKING import numpy as np @@ -42,7 +43,7 @@ import scipy from scipy._lib._util import check_random_state -scipy_version = list(map(int, scipy.__version__.lower().split("."))) +scipy_version = list(parse(scipy.__version__.lower()).release) if scipy_version[1] >= 8: from scipy.optimize._optimize import _status_message else: diff --git a/art/estimators/object_detection/pytorch_object_detector.py b/art/estimators/object_detection/pytorch_object_detector.py index 49bb14c15d..6f9ccd05c0 100644 --- a/art/estimators/object_detection/pytorch_object_detector.py +++ b/art/estimators/object_detection/pytorch_object_detector.py @@ -21,6 +21,7 @@ from __future__ import annotations import logging +from packaging.version import parse from typing import Any, TYPE_CHECKING import numpy as np @@ -96,8 +97,8 @@ def __init__( import torch import torchvision - torch_version = list(map(int, torch.__version__.lower().split("+", maxsplit=1)[0].split("."))) - torchvision_version = list(map(int, torchvision.__version__.lower().split("+", maxsplit=1)[0].split("."))) + torch_version = list(parse(torch.__version__.lower()).release) + torchvision_version = list(parse(torchvision.__version__.lower()).release) assert not (torch_version[0] == 1 and (torch_version[1] == 8 or torch_version[1] == 9)), ( "PyTorchObjectDetector does not support torch==1.8 and torch==1.9 because of " "https://github.com/pytorch/vision/issues/4153. Support will return for torch==1.10."