Skip to content

Commit

Permalink
Align RISE with AISE (#50)
Browse files Browse the repository at this point in the history
* Preset for RISE + renaming

* align aise and rise

* rise to output dict

* test quality preset

* comments

* preset num_cells
  • Loading branch information
negvet authored Aug 1, 2024
1 parent 3ca6f09 commit c497c01
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 94 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ At the moment, *Image Classification* and *Object Detection* tasks are supported
| Computer Vision | Image Classification | White-Box | ReciproCAM | [arxiv](https://arxiv.org/abs/2209.14074) / [src](openvino_xai/methods/white_box/recipro_cam.py) |
| | | | VITReciproCAM | [arxiv](https://arxiv.org/abs/2310.02588) / [src](openvino_xai/methods/white_box/recipro_cam.py) |
| | | | ActivationMap | experimental / [src](openvino_xai/methods/white_box/activation_map.py) |
| | | Black-Box | RISE | [arxiv](https://arxiv.org/abs/1806.07421v3) / [src](openvino_xai/methods/black_box/rise.py) |
| | | | AISE | [src](openvino_xai/methods/black_box/aise.py) |
| | | Black-Box | AISE | [src](openvino_xai/methods/black_box/aise.py) |
| | | | RISE | [arxiv](https://arxiv.org/abs/1806.07421v3) / [src](openvino_xai/methods/black_box/rise.py) |
| | Object Detection | White-Box | ClassProbabilityMap | experimental / [src](openvino_xai/methods/white_box/det_class_probability_map.py) |

### Supported explainable models
Expand Down
8 changes: 4 additions & 4 deletions openvino_xai/explainer/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from openvino_xai.explainer.utils import (
convert_targets_to_numpy,
explains_all,
get_explain_target_indices,
get_target_indices,
)
from openvino_xai.explainer.visualizer import Visualizer
from openvino_xai.methods.base import MethodBase
Expand Down Expand Up @@ -205,16 +205,16 @@ def explain(
"""
targets = convert_targets_to_numpy(targets)

explain_target_indices = None
target_indices = None
if isinstance(self.method, BlackBoxXAIMethod) and not explains_all(targets):
explain_target_indices = get_explain_target_indices(
target_indices = get_target_indices(
targets,
label_names,
)

saliency_map = self.method.generate_saliency_map(
data,
explain_target_indices=explain_target_indices, # type: ignore
target_indices=target_indices, # type: ignore
**kwargs,
)

Expand Down
14 changes: 7 additions & 7 deletions openvino_xai/explainer/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from openvino_xai.explainer.utils import (
convert_targets_to_numpy,
explains_all,
get_explain_target_indices,
get_target_indices,
)


Expand Down Expand Up @@ -111,26 +111,26 @@ def _select_target_saliency_maps(
label_names: List[str] | None = None,
) -> Dict[int | str, np.ndarray]:
assert self.layout == Layout.MULTIPLE_MAPS_PER_IMAGE_GRAY
explain_target_indices = self._select_target_indices(
target_indices = self._select_target_indices(
targets=targets,
label_names=label_names,
)
saliency_maps_selected = {i: self._saliency_map[i] for i in explain_target_indices}
saliency_maps_selected = {i: self._saliency_map[i] for i in target_indices}
return saliency_maps_selected

def _select_target_indices(
self,
targets: np.ndarray | List[int | str],
label_names: List[str] | None = None,
) -> List[int] | np.ndarray:
explain_target_indices = get_explain_target_indices(targets, label_names)
target_indices = get_target_indices(targets, label_names)
if self.total_num_targets is not None:
if not all(0 <= target_index <= (self.total_num_targets - 1) for target_index in explain_target_indices):
if not all(0 <= target_index <= (self.total_num_targets - 1) for target_index in target_indices):
raise ValueError(f"All targets indices have to be in range 0..{self.total_num_targets - 1}.")
else:
if not all(target_index in self.saliency_map for target_index in explain_target_indices):
if not all(target_index in self.saliency_map for target_index in target_indices):
raise ValueError("Provided targer index {targer_index} is not available among saliency maps.")
return explain_target_indices
return target_indices

def save(self, dir_path: Path | str, name: str | None = None) -> None:
"""Dumps saliency map."""
Expand Down
2 changes: 1 addition & 1 deletion openvino_xai/explainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def convert_targets_to_numpy(targets):
return np.atleast_1d(targets)


def get_explain_target_indices(
def get_target_indices(
targets: np.ndarray | List[int | str],
label_names: List[str] | None = None,
) -> List[int]:
Expand Down
47 changes: 26 additions & 21 deletions openvino_xai/methods/black_box/aise.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

class AISE(BlackBoxXAIMethod):
"""AISE explains classification models in black-box mode using
AISE: Adaptive Input Sampling for Explanation of Black-box Models.
AISE: Adaptive Input Sampling for Explanation of Black-box Models
(TODO (negvet): add link to the paper.)
:param model: OpenVINO model.
Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(
def generate_saliency_map( # type: ignore
self,
data: np.ndarray,
explain_target_indices: List[int] | None,
target_indices: List[int] | None,
preset: Preset = Preset.BALANCE,
num_iterations_per_kernel: int | None = None,
kernel_widths: List[float] | np.ndarray | None = None,
Expand All @@ -81,9 +81,9 @@ def generate_saliency_map( # type: ignore
:param data: Input image.
:type data: np.ndarray
:param explain_target_indices: List of target indices to explain.
:type explain_target_indices: List[int]
:param preset: Speed-Quality preset, defines predefined configurations that manage speed-quality tradeoff.
:param target_indices: List of target indices to explain.
:type target_indices: List[int]
:param preset: Speed-Quality preset, defines predefined configurations that manage the speed-quality tradeoff.
:type preset: Preset
:param num_iterations_per_kernel: Number of iterations per kernel, defines compute budget.
:type num_iterations_per_kernel: int
Expand All @@ -98,13 +98,17 @@ def generate_saliency_map( # type: ignore
"""
self.data_preprocessed = self.preprocess_fn(data)

if explain_target_indices is None:
if target_indices is None:
num_classes = self.get_num_classes(self.data_preprocessed)
if num_classes > 10:
logger.info(f"num_classes = {num_classes}, which might take significant time to process.")
explain_target_indices = list(range(num_classes))
target_indices = list(range(num_classes))

self._preset_parameters(preset, num_iterations_per_kernel, kernel_widths)
self.num_iterations_per_kernel, self.kernel_widths = self._preset_parameters(
preset,
num_iterations_per_kernel,
kernel_widths,
)

self.solver_epsilon = solver_epsilon
self.locally_biased = locally_biased
Expand All @@ -113,7 +117,7 @@ def generate_saliency_map( # type: ignore
self._mask_generator = GaussianPerturbationMask(self.input_size)

saliency_maps = {}
for target in explain_target_indices:
for target in target_indices:
self.kernel_params_hist = collections.defaultdict(list)
self.pred_score_hist = collections.defaultdict(list)

Expand All @@ -124,28 +128,29 @@ def generate_saliency_map( # type: ignore
saliency_maps[target] = saliency_map_per_target
return saliency_maps

@staticmethod
def _preset_parameters(
self,
preset: Preset,
num_iterations_per_kernel: int | None,
kernel_widths: List[float] | np.ndarray | None,
) -> None:
) -> Tuple[int, np.ndarray]:
if preset == Preset.SPEED:
self.num_iterations_per_kernel = 25
self.kernel_widths = np.linspace(0.1, 0.25, 3)
iterations = 25
widths = np.linspace(0.1, 0.25, 3)
elif preset == Preset.BALANCE:
self.num_iterations_per_kernel = 50
self.kernel_widths = np.linspace(0.1, 0.25, 3)
iterations = 50
widths = np.linspace(0.1, 0.25, 3)
elif preset == Preset.QUALITY:
self.num_iterations_per_kernel = 85
self.kernel_widths = np.linspace(0.075, 0.25, 4)
iterations = 85
widths = np.linspace(0.075, 0.25, 4)
else:
raise ValueError(f"Preset {preset} is not supported.")

if num_iterations_per_kernel is not None:
self.num_iterations_per_kernel = num_iterations_per_kernel
if kernel_widths is not None:
self.kernel_widths = kernel_widths
if num_iterations_per_kernel is None:
num_iterations_per_kernel = iterations
if kernel_widths is None:
kernel_widths = widths
return num_iterations_per_kernel, kernel_widths

def _run_synchronous_explanation(self) -> np.ndarray:
for kernel_width in self.kernel_widths:
Expand Down
4 changes: 2 additions & 2 deletions openvino_xai/methods/black_box/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ class Preset(Enum):
Enum representing the different presets:
Contains the following values:
SPEED - Prioritize getting results faster.
SPEED - Prioritizes getting results faster.
BALANCE - Balance between speed and quality.
QUALITY - Prioritize high quality results.
QUALITY - Prioritizes getting high quality results.
"""

SPEED = "speed"
Expand Down
78 changes: 51 additions & 27 deletions openvino_xai/methods/black_box/rise.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Callable, List, Mapping, Tuple
from typing import Callable, Dict, List, Mapping, Tuple

import cv2
import numpy as np
import openvino.runtime as ov
from tqdm import tqdm

from openvino_xai.common.utils import IdentityPreprocessFN, scaling
from openvino_xai.methods.black_box.base import BlackBoxXAIMethod
from openvino_xai.common.utils import IdentityPreprocessFN, is_bhwc_layout, scaling
from openvino_xai.methods.black_box.base import BlackBoxXAIMethod, Preset


class RISE(BlackBoxXAIMethod):
"""RISE explains classification models in black-box mode using RISE (https://arxiv.org/abs/1806.07421).
"""RISE explains classification models in black-box mode using
'RISE: Randomized Input Sampling for Explanation of Black-box Models' paper
(https://arxiv.org/abs/1806.07421).
:param model: OpenVINO model.
:type model: ov.Model
Expand Down Expand Up @@ -45,20 +47,23 @@ def __init__(
def generate_saliency_map(
self,
data: np.ndarray,
explain_target_indices: List[int] | None = None,
num_masks: int = 5000,
target_indices: List[int] | None = None,
preset: Preset = Preset.BALANCE,
num_masks: int | None = None,
num_cells: int = 8,
prob: float = 0.5,
seed: int = 0,
scale_output: bool = True,
) -> np.ndarray:
) -> np.ndarray | Dict[int, np.ndarray]:
"""
Generates inference result of the RISE algorithm.
:param data: Input image.
:type data: np.ndarray
:param explain_target_indices: List of target indices to explain.
:type explain_target_indices: List[int]
:param target_indices: List of target indices to explain.
:type target_indices: List[int]
:param preset: Speed-Quality preset, defines predefined configurations that manage speed-quality tradeoff.
:type preset: Preset
:param num_masks: Number of generated masks to aggregate.
:type num_masks: int
:param num_cells: Number of cells for low-dimensional RISE
Expand All @@ -74,20 +79,45 @@ def generate_saliency_map(
"""
data_preprocessed = self.preprocess_fn(data)

num_masks = self._preset_parameters(preset, num_masks)

saliency_maps = self._run_synchronous_explanation(
data_preprocessed,
explain_target_indices,
target_indices,
num_masks,
num_cells,
prob,
seed,
)

if scale_output:
saliency_maps = scaling(saliency_maps)
saliency_maps = np.expand_dims(saliency_maps, axis=0)
if isinstance(saliency_maps, np.ndarray):
if scale_output:
saliency_maps = scaling(saliency_maps)
saliency_maps = np.expand_dims(saliency_maps, axis=0)
elif isinstance(saliency_maps, dict):
for target in saliency_maps:
if scale_output:
saliency_maps[target] = scaling(saliency_maps[target])
return saliency_maps

@staticmethod
def _preset_parameters(
preset: Preset,
num_masks: int | None = None,
) -> int:
# TODO (negvet): preset num_cells
if num_masks is not None:
return num_masks

if preset == Preset.SPEED:
return 2000
elif preset == Preset.BALANCE:
return 5000
elif preset == Preset.QUALITY:
return 8000
else:
raise ValueError(f"Preset {preset} is not supported.")

def _run_synchronous_explanation(
self,
data_preprocessed: np.ndarray,
Expand All @@ -97,8 +127,6 @@ def _run_synchronous_explanation(
prob: float,
seed: int,
) -> np.ndarray:
from openvino_xai.common.utils import is_bhwc_layout

input_size = data_preprocessed.shape[1:3] if is_bhwc_layout(data_preprocessed) else data_preprocessed.shape[2:4]

num_classes = self.get_num_classes(data_preprocessed)
Expand All @@ -110,7 +138,7 @@ def _run_synchronous_explanation(

rand_generator = np.random.default_rng(seed=seed)

sal_maps = np.zeros((num_targets, input_size[0], input_size[1]))
saliency_maps = np.zeros((num_targets, input_size[0], input_size[1]))
for _ in tqdm(range(0, num_masks), desc="Explaining in synchronous mode"):
mask = self._generate_mask(input_size, num_cells, prob, rand_generator)
# Add channel dimensions for masks
Expand All @@ -123,11 +151,11 @@ def _run_synchronous_explanation(
raw_scores = self.postprocess_fn(forward_output)

sal = self._get_scored_mask(raw_scores, mask, target_classes)
sal_maps += sal
saliency_maps += sal

if target_classes is not None:
sal_maps = self._reconstruct_sparce_saliency_map(sal_maps, num_classes, input_size, target_classes)
return sal_maps
saliency_maps = self._reformat_as_dict(saliency_maps, target_classes)
return saliency_maps

@staticmethod
def _get_scored_mask(raw_scores: np.ndarray, mask: np.ndarray, target_classes: List[int] | None) -> np.ndarray:
Expand All @@ -137,15 +165,11 @@ def _get_scored_mask(raw_scores: np.ndarray, mask: np.ndarray, target_classes: L
return raw_scores.reshape(-1, 1, 1) * mask

@staticmethod
def _reconstruct_sparce_saliency_map(
sal_maps: np.ndarray, num_classes: int, input_size, target_classes: List[int] | None
def _reformat_as_dict(
saliency_maps: np.ndarray,
target_classes: List[int] | None,
) -> np.ndarray:
# TODO: see if np.put() or other alternatives works faster (requires flatten array)
sal_maps_tmp = sal_maps
sal_maps = np.zeros((num_classes, input_size[0], input_size[1]))
for i, sal in enumerate(sal_maps_tmp):
sal_maps[target_classes[i]] = sal
return sal_maps
return {target_class: saliency_map for target_class, saliency_map in zip(target_classes, saliency_maps)}

@staticmethod
def _generate_mask(input_size: Tuple[int, int], num_cells: int, prob: float, rand_generator) -> np.ndarray:
Expand Down
Loading

0 comments on commit c497c01

Please sign in to comment.