diff --git a/.github/actions/deepspeech-v2/Dockerfile b/.github/actions/deepspeech-v2/Dockerfile deleted file mode 100644 index a9b2c13810..0000000000 --- a/.github/actions/deepspeech-v2/Dockerfile +++ /dev/null @@ -1,45 +0,0 @@ -# Get base from a pytorch image -FROM pytorch/pytorch:1.5.1-cuda10.1-cudnn7-runtime - -# Set to install things in non-interactive mode -ENV DEBIAN_FRONTEND noninteractive - -# Install system wide softwares -RUN apt-get update \ - && apt-get install -y \ - libgl1-mesa-glx \ - libx11-xcb1 \ - git \ - gcc \ - mono-mcs \ - cmake \ - libavcodec-extra \ - ffmpeg \ - curl \ - && apt-get clean all \ - && rm -r /var/lib/apt/lists/* - -RUN /opt/conda/bin/conda install --yes \ - astropy \ - matplotlib \ - pandas \ - scikit-learn \ - scikit-image - -# Install necessary libraries for deepspeech v2 -RUN pip install torch -RUN pip install tensorflow -RUN pip install torchaudio==0.5.1 - -RUN git clone https://github.com/SeanNaren/warp-ctc.git -RUN cd warp-ctc && mkdir build && cd build && cmake .. && make -RUN cd warp-ctc/pytorch_binding && python setup.py install - -RUN git clone https://github.com/SeanNaren/deepspeech.pytorch.git -RUN cd deepspeech.pytorch && git checkout V2.1 -RUN cd deepspeech.pytorch && pip install -r requirements.txt -RUN cd deepspeech.pytorch && pip install -e . 
- -RUN pip install numba==0.50.0 -RUN pip install pytest-cov -RUN pip install pydub==0.25.1 diff --git a/.github/actions/deepspeech-v2/action.yml b/.github/actions/deepspeech-v2/action.yml deleted file mode 100644 index fbed446b8b..0000000000 --- a/.github/actions/deepspeech-v2/action.yml +++ /dev/null @@ -1,7 +0,0 @@ -name: 'Test DeepSpeech v2' -description: 'Run tests for DeepSpeech v2' -runs: - using: 'composite' - steps: - - run: $GITHUB_ACTION_PATH/run.sh - shell: bash diff --git a/.github/actions/deepspeech-v2/run.sh b/.github/actions/deepspeech-v2/run.sh deleted file mode 100755 index e8bf57f2e9..0000000000 --- a/.github/actions/deepspeech-v2/run.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash - -exit_code=0 - -pytest --cov-report=xml --cov=art --cov-append -q -vv tests/estimators/speech_recognition/test_pytorch_deep_speech.py --framework=pytorch --durations=0 -if [[ $? -ne 0 ]]; then exit_code=1; echo "Failed estimators/speech_recognition/test_pytorch_deep_speech tests"; fi -pytest --cov-report=xml --cov=art --cov-append -q -vv tests/attacks/evasion/test_imperceptible_asr_pytorch.py --framework=pytorch --durations=0 -if [[ $? 
-ne 0 ]]; then exit_code=1; echo "Failed attacks/evasion/test_imperceptible_asr_pytorch tests"; fi - -exit ${exit_code} diff --git a/.github/actions/deepspeech-v3/Dockerfile b/.github/actions/deepspeech-v3/Dockerfile index 2b83524703..89ecadb38e 100644 --- a/.github/actions/deepspeech-v3/Dockerfile +++ b/.github/actions/deepspeech-v3/Dockerfile @@ -1,5 +1,5 @@ -# Get base from a pytorch image -FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime +# Get base from a pytorch image +FROM pytorch/pytorch:2.1.1-cuda12.1-cudnn8-runtime # Set to install things in non-interactive mode ENV DEBIAN_FRONTEND noninteractive @@ -17,26 +17,19 @@ RUN apt-get update \ curl \ libsndfile-dev \ libsndfile1 \ + vim \ + curl \ && apt-get clean all \ && rm -r /var/lib/apt/lists/* -RUN /opt/conda/bin/conda install --yes \ - astropy \ - matplotlib \ - pandas \ - scikit-learn \ - scikit-image - # Install necessary libraries for deepspeech v3 -RUN pip install torch -RUN pip install tensorflow -RUN pip install torchaudio==0.6.0 -RUN pip install --no-build-isolation fairscale +RUN pip install --ignore-installed PyYAML torch==2.1.1 tensorflow==2.14.1 torchaudio==2.1.1 pytorch-lightning==2.1.2 scikit-learn==1.3.2 +RUN pip install --no-build-isolation fairscale==0.4.13 RUN git clone https://github.com/SeanNaren/deepspeech.pytorch.git -RUN cd deepspeech.pytorch && pip install -r requirements.txt -RUN cd deepspeech.pytorch && pip install -e . +RUN cd deepspeech.pytorch && sed -i '/^sklearn/d' requirements.txt && pip install -r requirements.txt && pip install -e . 
+ +RUN pip install numba==0.56.4 pytest-cov==4.1.0 pydub==0.25.1 +RUN pip list -RUN pip install numba==0.50.0 -RUN pip install pytest-cov -RUN pip install pydub==0.25.1 +RUN mkdir -p /root/.art/data && cd /root/.art/data && curl -LJO "https://github.com/SeanNaren/deepspeech.pytorch/releases/download/V3.0/librispeech_pretrained_v3.ckpt" diff --git a/.github/workflows/ci-deepspeech-v2.yml b/.github/workflows/ci-deepspeech-v2.yml deleted file mode 100644 index ec8a5c78e0..0000000000 --- a/.github/workflows/ci-deepspeech-v2.yml +++ /dev/null @@ -1,37 +0,0 @@ -name: CI PyTorchDeepSpeech v2 -on: - # Run on manual trigger - workflow_dispatch: - - # Run on pull requests - pull_request: - paths-ignore: - - '*.md' - - # Run on merge queue - merge_group: - - # Run when pushing to main or dev branches - push: - branches: - - main - - dev* - - # Run scheduled CI flow daily - schedule: - - cron: '0 8 * * 0' - -jobs: - test_deepspeech_v2: - name: PyTorchDeepSpeech v2 - runs-on: ubuntu-latest - container: adversarialrobustnesstoolbox/art_testing_envs:deepspeech_v2 - steps: - - name: Checkout Repo - uses: actions/checkout@v3 - - name: Run Test Action - uses: ./.github/actions/deepspeech-v2 - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 - with: - fail_ci_if_error: true diff --git a/.github/workflows/ci-deepspeech-v3.yml b/.github/workflows/ci-deepspeech-v3.yml index ff2bd88666..2ea4ecefd1 100644 --- a/.github/workflows/ci-deepspeech-v3.yml +++ b/.github/workflows/ci-deepspeech-v3.yml @@ -23,9 +23,9 @@ on: jobs: test_deepspeech_v3_torch_1_10: - name: PyTorchDeepSpeech v3 / PyTorch 1.10 + name: PyTorchDeepSpeech v3 / PyTorch 2.1.1 runs-on: ubuntu-latest - container: adversarialrobustnesstoolbox/art_testing_envs:deepspeech_v3_torch_1_10 + container: adversarialrobustnesstoolbox/art_testing_envs:deepspeech_v3_torch_2_1_1 steps: - name: Checkout Repo uses: actions/checkout@v3 diff --git a/.github/workflows/ci-pytorch.yml b/.github/workflows/ci-pytorch.yml index 
1cea82cded..2376a0647e 100644 --- a/.github/workflows/ci-pytorch.yml +++ b/.github/workflows/ci-pytorch.yml @@ -28,24 +28,18 @@ jobs: fail-fast: false matrix: include: - - name: PyTorch 1.12.1 (Python 3.9) - framework: pytorch - python: 3.9 - torch: 1.12.1+cpu - torchvision: 0.13.1+cpu - torchaudio: 0.12.1 - - name: PyTorch 1.13.1 (Python 3.9) - framework: pytorch - python: 3.9 - torch: 1.13.1+cpu - torchvision: 0.14.1+cpu - torchaudio: 0.13.1 - name: PyTorch 1.13.1 (Python 3.10) framework: pytorch python: '3.10' torch: 1.13.1+cpu torchvision: 0.14.1+cpu torchaudio: 0.13.1 + - name: PyTorch 2.1.1 (Python 3.10) + framework: pytorch + python: '3.10' + torch: 2.1.1 + torchvision: 0.16.1+cpu + torchaudio: 2.1.1 name: ${{ matrix.name }} steps: diff --git a/.github/workflows/ci-style-checks.yml b/.github/workflows/ci-style-checks.yml index 3a9955e2b9..636dc6a7f0 100644 --- a/.github/workflows/ci-style-checks.yml +++ b/.github/workflows/ci-style-checks.yml @@ -31,7 +31,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: '3.10' - name: Pre-install run: | sudo apt-get update @@ -39,7 +39,7 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip setuptools wheel - pip install -q pylint==2.12.2 mypy==0.931 pycodestyle==2.8.0 black==21.12b0 + pip install -q pylint==2.12.2 mypy==1.7.1 pycodestyle==2.8.0 black==21.12b0 pip install -q -r <(sed '/^numpy/d;/^pluggy/d;/^tensorflow/d;/^keras/d' requirements_test.txt) pip install numpy==1.22.4 pip install pluggy==0.13.1 diff --git a/.github/workflows/ci-tensorflow-v1.yml b/.github/workflows/ci-tensorflow-v1.yml index ada4b4a19d..f1654814a9 100644 --- a/.github/workflows/ci-tensorflow-v1.yml +++ b/.github/workflows/ci-tensorflow-v1.yml @@ -48,7 +48,7 @@ jobs: sudo apt-get update sudo apt-get -y -q install ffmpeg libavcodec-extra python -m pip install --upgrade pip setuptools wheel - pip install -q -r <(sed 
'/^pandas/d;/^scipy/d;/^matplotlib/d;/^xgboost/d;/^tensorflow/d;/^keras/d;/^jax/d' requirements_test.txt) + pip install -q -r <(sed '/^pandas/d;/^scipy/d;/^matplotlib/d;/^xgboost/d;/^tensorflow/d;/^keras/d;/^jax/d;/^torch/d' requirements_test.txt) pip install pandas==1.3.5 pip install scipy==1.7.2 pip install matplotlib==3.5.3 @@ -57,6 +57,9 @@ jobs: pip install tensorflow==${{ matrix.tensorflow }} pip install keras==${{ matrix.keras }} pip install numpy==1.20 + pip install torch==1.13.1 + pip install torchaudio==0.13.1 + pip install torchvision==0.14.1+cpu pip list - name: Run Tests run: ./run_tests.sh ${{ matrix.framework }} diff --git a/art/attacks/evasion/adversarial_patch/adversarial_patch_numpy.py b/art/attacks/evasion/adversarial_patch/adversarial_patch_numpy.py index 639c64bb94..f6af29d8ce 100644 --- a/art/attacks/evasion/adversarial_patch/adversarial_patch_numpy.py +++ b/art/attacks/evasion/adversarial_patch/adversarial_patch_numpy.py @@ -251,7 +251,11 @@ def generate( # type: ignore return self.patch, self._get_circular_patch_mask() def apply_patch( - self, x: np.ndarray, scale: float, patch_external: np.ndarray = None, mask: Optional[np.ndarray] = None + self, + x: np.ndarray, + scale: float, + patch_external: Optional[np.ndarray] = None, + mask: Optional[np.ndarray] = None, ) -> np.ndarray: """ A function to apply the learned adversarial patch to images or videos. diff --git a/art/attacks/evasion/dpatch.py b/art/attacks/evasion/dpatch.py index 20923f58de..52a3b9979c 100644 --- a/art/attacks/evasion/dpatch.py +++ b/art/attacks/evasion/dpatch.py @@ -264,7 +264,7 @@ def _augment_images_with_patch( random_location: bool, channels_first: bool, mask: Optional[np.ndarray] = None, - transforms: List[Dict[str, int]] = None, + transforms: Optional[List[Dict[str, int]]] = None, ) -> Tuple[np.ndarray, List[Dict[str, int]]]: """ Augment images with patch. 
diff --git a/art/attacks/evasion/imperceptible_asr/imperceptible_asr.py b/art/attacks/evasion/imperceptible_asr/imperceptible_asr.py index c03f84c1c6..0d933dd716 100644 --- a/art/attacks/evasion/imperceptible_asr/imperceptible_asr.py +++ b/art/attacks/evasion/imperceptible_asr/imperceptible_asr.py @@ -540,14 +540,17 @@ def _approximate_power_spectral_density_torch( # compute short-time Fourier transform (STFT) # pylint: disable=W0212 - stft_matrix = torch.stft( - perturbation, - n_fft=self._window_size, - hop_length=self._hop_size, - win_length=self._window_size, - center=False, - window=torch.hann_window(self._window_size).to(self.estimator._device), - ).to(self.estimator._device) + stft_matrix = torch.view_as_real( + torch.stft( + perturbation, + n_fft=self._window_size, + hop_length=self._hop_size, + win_length=self._window_size, + center=False, + window=torch.hann_window(self._window_size).to(self.estimator._device), + return_complex=True, + ).to(self.estimator._device) + ) # compute power spectral density (PSD) # note: fixes implementation of Qin et al. by also considering the square root of gain_factor diff --git a/art/attacks/evasion/imperceptible_asr/imperceptible_asr_pytorch.py b/art/attacks/evasion/imperceptible_asr/imperceptible_asr_pytorch.py index eeaf5432fe..cdb03fbbe2 100644 --- a/art/attacks/evasion/imperceptible_asr/imperceptible_asr_pytorch.py +++ b/art/attacks/evasion/imperceptible_asr/imperceptible_asr_pytorch.py @@ -399,7 +399,10 @@ class only supports targeted attack. 
loss.backward() # Get sign of the gradients - self.global_optimal_delta.grad = torch.sign(self.global_optimal_delta.grad) + if self.global_optimal_delta.grad is not None: + self.global_optimal_delta.grad = torch.sign(self.global_optimal_delta.grad) + else: + raise ValueError("Received None instead of gradient tensor.") # Do optimization self.optimizer_1.step() @@ -747,14 +750,17 @@ def _psd_transform(self, delta: "torch.Tensor", original_max_psd: np.ndarray) -> window_fn = torch.hann_window # type: ignore # Return STFT of delta - delta_stft = torch.stft( - delta, - n_fft=self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - center=False, - window=window_fn(self.win_length).to(self.estimator.device), - ).to(self.estimator.device) + delta_stft = torch.view_as_real( + torch.stft( + delta, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + center=False, + window=window_fn(self.win_length).to(self.estimator.device), + return_complex=True, + ).to(self.estimator.device) + ) # Take abs of complex STFT results transformed_delta = torch.sqrt(torch.sum(torch.square(delta_stft), -1)) diff --git a/art/attacks/evasion/over_the_air_flickering/over_the_air_flickering_pytorch.py b/art/attacks/evasion/over_the_air_flickering/over_the_air_flickering_pytorch.py index 22288b62b5..d61f8bb4ac 100644 --- a/art/attacks/evasion/over_the_air_flickering/over_the_air_flickering_pytorch.py +++ b/art/attacks/evasion/over_the_air_flickering/over_the_air_flickering_pytorch.py @@ -296,7 +296,10 @@ def _get_loss_gradients(self, x: "torch.Tensor", y: "torch.Tensor", perturbation # Compute gradients loss.backward() grads = eps.grad - grads_batch.append(grads[0, ...]) + if grads is not None: + grads_batch.append(grads[0, ...]) + else: + raise ValueError("Received None instead of gradient tensor.") grads_batch_tensor = torch.stack(grads_batch) diff --git a/art/attacks/inference/membership_inference/shadow_models.py 
b/art/attacks/inference/membership_inference/shadow_models.py index bded7129e6..92b37668bd 100644 --- a/art/attacks/inference/membership_inference/shadow_models.py +++ b/art/attacks/inference/membership_inference/shadow_models.py @@ -164,8 +164,8 @@ def _hill_climbing_synthesis( max_iterations: int = 40, max_rejections: int = 3, min_features_randomized: int = 1, - random_record_fn: Callable[[], np.ndarray] = None, - randomize_features_fn: Callable[[np.ndarray, int], np.ndarray] = None, + random_record_fn: Optional[Callable[[], np.ndarray]] = None, + randomize_features_fn: Optional[Callable[[np.ndarray, int], np.ndarray]] = None, ) -> np.ndarray: """ This method implements the hill climbing algorithm from R. Shokri et al. (2017) @@ -247,8 +247,8 @@ def generate_synthetic_shadow_dataset( member_ratio: float = 0.5, min_confidence: float = 0.4, max_retries: int = 6, - random_record_fn: Callable[[], np.ndarray] = None, - randomize_features_fn: Callable[[np.ndarray, int], np.ndarray] = None, + random_record_fn: Optional[Callable[[], np.ndarray]] = None, + randomize_features_fn: Optional[Callable[[np.ndarray, int], np.ndarray]] = None, ) -> Tuple[Tuple[np.ndarray, np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]]: """ Generates a shadow dataset (member and nonmember samples and their corresponding model predictions) by training diff --git a/art/attacks/poisoning/perturbations/audio_perturbations.py b/art/attacks/poisoning/perturbations/audio_perturbations.py index c7ae909380..1c8c723fe9 100644 --- a/art/attacks/poisoning/perturbations/audio_perturbations.py +++ b/art/attacks/poisoning/perturbations/audio_perturbations.py @@ -21,8 +21,10 @@ because loading the audio trigger from disk (librosa.load()) is very slow and should be done only once. 
""" -import numpy as np +from typing import Optional + import librosa +import numpy as np class CacheTrigger: @@ -89,7 +91,7 @@ def __init__( self, sampling_rate: int = 16000, backdoor_path: str = "../../../utils/data/backdoors/cough_trigger.wav", - duration: float = None, + duration: Optional[float] = None, **kwargs, ): """ diff --git a/art/attacks/poisoning/sleeper_agent_attack.py b/art/attacks/poisoning/sleeper_agent_attack.py index 505dff1554..b42bddad12 100644 --- a/art/attacks/poisoning/sleeper_agent_attack.py +++ b/art/attacks/poisoning/sleeper_agent_attack.py @@ -431,7 +431,7 @@ def _select_poison_indices( classifier.model.trainable = model_trainable else: raise NotImplementedError("SleeperAgentAttack is currently implemented only for PyTorch and TensorFlowV2.") - indices = sorted(range(len(grad_norms)), key=lambda k: grad_norms[k]) + indices = sorted(range(len(grad_norms)), key=lambda k: grad_norms[k]) # type: ignore indices = indices[-num_poison:] return np.array(indices) # this will get only indices for target class diff --git a/art/defences/trainer/adversarial_trainer_awp_pytorch.py b/art/defences/trainer/adversarial_trainer_awp_pytorch.py index 9a59ea0be6..1b95f0c8bb 100644 --- a/art/defences/trainer/adversarial_trainer_awp_pytorch.py +++ b/art/defences/trainer/adversarial_trainer_awp_pytorch.py @@ -89,7 +89,7 @@ def fit( validation_data: Optional[Tuple[np.ndarray, np.ndarray]] = None, batch_size: int = 128, nb_epochs: int = 20, - scheduler: "torch.optim.lr_scheduler._LRScheduler" = None, + scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None, **kwargs, ): # pylint: disable=W0221 """ @@ -198,7 +198,7 @@ def fit_generator( generator: DataGenerator, validation_data: Optional[Tuple[np.ndarray, np.ndarray]] = None, nb_epochs: int = 20, - scheduler: "torch.optim.lr_scheduler._LRScheduler" = None, + scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None, **kwargs, ): # pylint: disable=W0221 """ diff --git 
a/art/defences/trainer/adversarial_trainer_trades_pytorch.py b/art/defences/trainer/adversarial_trainer_trades_pytorch.py index c965635419..3763d571e8 100644 --- a/art/defences/trainer/adversarial_trainer_trades_pytorch.py +++ b/art/defences/trainer/adversarial_trainer_trades_pytorch.py @@ -69,7 +69,7 @@ def fit( validation_data: Optional[Tuple[np.ndarray, np.ndarray]] = None, batch_size: int = 128, nb_epochs: int = 20, - scheduler: "torch.optim.lr_scheduler._LRScheduler" = None, + scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None, **kwargs ): # pylint: disable=W0221 """ @@ -158,7 +158,7 @@ def fit_generator( self, generator: DataGenerator, nb_epochs: int = 20, - scheduler: "torch.optim.lr_scheduler._LRScheduler" = None, + scheduler: Optional["torch.optim.lr_scheduler._LRScheduler"] = None, **kwargs ): # pylint: disable=W0221 """ diff --git a/art/defences/transformer/poisoning/strip.py b/art/defences/transformer/poisoning/strip.py index 6bc7b04ed7..dcd1463c1d 100644 --- a/art/defences/transformer/poisoning/strip.py +++ b/art/defences/transformer/poisoning/strip.py @@ -23,7 +23,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals import logging -from typing import Optional, TypeVar, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING import numpy as np @@ -33,8 +33,6 @@ if TYPE_CHECKING: from art.utils import CLASSIFIER_TYPE - ClassifierWithStrip = TypeVar("ClassifierWithStrip", CLASSIFIER_TYPE, STRIPMixin) - logger = logging.getLogger(__name__) @@ -63,7 +61,7 @@ def __call__( # type: ignore self, num_samples: int = 20, false_acceptance_rate: float = 0.01, - ) -> "ClassifierWithStrip": + ) -> "CLASSIFIER_TYPE": """ Create a STRIP defense diff --git a/art/estimators/classification/deep_partition_ensemble.py b/art/estimators/classification/deep_partition_ensemble.py index 82dbb9cf40..e309db2bac 100644 --- a/art/estimators/classification/deep_partition_ensemble.py +++ 
b/art/estimators/classification/deep_partition_ensemble.py @@ -160,7 +160,7 @@ def fit( # pylint: disable=W0221 y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, - train_dict: Dict = None, + train_dict: Optional[Dict] = None, **kwargs ) -> None: """ diff --git a/art/estimators/object_detection/pytorch_detection_transformer.py b/art/estimators/object_detection/pytorch_detection_transformer.py index 9f1389398e..d7cb0f0398 100644 --- a/art/estimators/object_detection/pytorch_detection_transformer.py +++ b/art/estimators/object_detection/pytorch_detection_transformer.py @@ -51,7 +51,7 @@ class PyTorchDetectionTransformer(ObjectDetectorMixin, PyTorchEstimator): def __init__( self, - model: "torch.nn.Module" = None, + model: Optional["torch.nn.Module"] = None, input_shape: Tuple[int, ...] = (3, 800, 800), clip_values: Optional["CLIP_VALUES_TYPE"] = None, channels_first: Optional[bool] = True, @@ -289,7 +289,7 @@ def _get_losses( y_tensor.append(y_t) elif y is not None and isinstance(y[0]["boxes"], np.ndarray): y_tensor = [] - for y_i in y_preprocessed: + for y_i in y: y_t = { "boxes": torch.from_numpy(y_i["boxes"]).type(torch.float).to(self.device), "labels": torch.from_numpy(y_i["labels"]).type(torch.int64).to(self.device), diff --git a/art/estimators/regression/blackbox.py b/art/estimators/regression/blackbox.py index cdd3cc2844..4339692b4f 100644 --- a/art/estimators/regression/blackbox.py +++ b/art/estimators/regression/blackbox.py @@ -49,7 +49,7 @@ def __init__( self, predict_fn: Union[Callable, Tuple[np.ndarray, np.ndarray]], input_shape: Tuple[int, ...], - loss_fn: Callable = None, + loss_fn: Optional[Callable] = None, clip_values: Optional["CLIP_VALUES_TYPE"] = None, preprocessing_defences: Union["Preprocessor", List["Preprocessor"], None] = None, postprocessing_defences: Union["Postprocessor", List["Postprocessor"], None] = None, diff --git a/art/estimators/speech_recognition/pytorch_deep_speech.py 
b/art/estimators/speech_recognition/pytorch_deep_speech.py index 16d54ac8a1..0cdb2a134a 100644 --- a/art/estimators/speech_recognition/pytorch_deep_speech.py +++ b/art/estimators/speech_recognition/pytorch_deep_speech.py @@ -146,7 +146,10 @@ def __init__( # Check DeepSpeech version if str(DeepSpeech.__base__) == "<class 'torch.nn.modules.module.Module'>": self._version = 2 - elif str(DeepSpeech.__base__) == "<class 'pytorch_lightning.core.lightning.LightningModule'>": + elif str(DeepSpeech.__base__) in [ + "<class 'pytorch_lightning.core.lightning.LightningModule'>", + "<class 'pytorch_lightning.core.module.LightningModule'>", + ]: self._version = 3 else: raise NotImplementedError("Only DeepSpeech version 2 and DeepSpeech version 3 are currently supported.") @@ -381,7 +384,7 @@ def predict( # Call to DeepSpeech model for prediction with torch.no_grad(): - outputs, output_sizes = self._model( + outputs, output_sizes, _ = self._model( inputs[begin:end].to(self._device), input_sizes[begin:end].to(self._device) ) @@ -455,7 +458,7 @@ def loss_gradient(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray: input_sizes = input_rates.mul_(inputs.size()[-1]).int() # Call to DeepSpeech model for prediction - outputs, output_sizes = self._model(inputs.to(self._device), input_sizes.to(self._device)) + outputs, output_sizes, _ = self._model(inputs.to(self._device), input_sizes.to(self._device)) outputs = outputs.transpose(0, 1) if self._version == 2: @@ -566,7 +569,7 @@ def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: in self.optimizer.zero_grad() # Call to DeepSpeech model for prediction - outputs, output_sizes = self._model(inputs.to(self._device), input_sizes.to(self._device)) + outputs, output_sizes, _ = self._model(inputs.to(self._device), input_sizes.to(self._device)) outputs = outputs.transpose(0, 1) if self._version == 2: @@ -625,7 +628,7 @@ def compute_loss_and_decoded_output( input_sizes = input_rates.mul_(inputs.size()[-1]).int() # Call to DeepSpeech model for prediction - outputs, output_sizes = self.model(inputs.to(self.device), input_sizes.to(self.device)) + outputs, output_sizes, _ = self.model(inputs.to(self.device), 
input_sizes.to(self.device)) outputs_ = outputs.transpose(0, 1) if self._version == 2: diff --git a/art/metrics/metrics.py b/art/metrics/metrics.py index 473c4bfae1..97c84afa50 100644 --- a/art/metrics/metrics.py +++ b/art/metrics/metrics.py @@ -90,10 +90,10 @@ def get_crafter(classifier: "CLASSIFIER_TYPE", attack: str, params: Optional[Dic def adversarial_accuracy( classifier: "CLASSIFIER_TYPE", x: np.ndarray, - y: np.ndarray = None, - attack_name: str = None, + y: Optional[np.ndarray] = None, + attack_name: Optional[str] = None, attack_params: Optional[Dict[str, Any]] = None, - attack_crafter: EvasionAttack = None, + attack_crafter: Optional[EvasionAttack] = None, ) -> float: """ Compute the adversarial accuracy of a classifier object over the sample `x` for a given adversarial crafting diff --git a/requirements_test.txt b/requirements_test.txt index 67e30bab46..9fd6e5bb0f 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -31,9 +31,9 @@ mxnet-native==1.8.0.post0 # PyTorch --find-links https://download.pytorch.org/whl/cpu/torch_stable.html -torch==1.13.1 -torchaudio==0.13.1+cpu -torchvision==0.14.1+cpu +torch==2.1.1 +torchaudio==2.1.1 +torchvision==0.16.1+cpu # PyTorch image transformers timm==0.9.2 diff --git a/tests/estimators/speech_recognition/test_pytorch_deep_speech.py b/tests/estimators/speech_recognition/test_pytorch_deep_speech.py index dc49f214d2..b571c82fc2 100644 --- a/tests/estimators/speech_recognition/test_pytorch_deep_speech.py +++ b/tests/estimators/speech_recognition/test_pytorch_deep_speech.py @@ -170,7 +170,7 @@ def test_pytorch_deep_speech_preprocessor( # Test probability outputs probs, sizes = speech_recognizer.predict(x, batch_size=1, transcription_output=False) - np.testing.assert_array_almost_equal(probs[1][1], expected_probs, decimal=3) + np.testing.assert_array_almost_equal(probs[1][1], expected_probs, decimal=2) np.testing.assert_array_almost_equal(sizes, expected_sizes) # Test transcription outputs