From 3ca6f09ebb29e5980dbfcd09a3a72417f809ede1 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Tue, 30 Jul 2024 17:49:55 +0200 Subject: [PATCH 1/2] Support AISE (#49) * basic AISE support * Productize + add tests * code quality * comments * add scipy * targets is none + tests * fix test * extend tests * Documentation * minor * Resolve comments --- CHANGELOG.md | 4 +- README.md | 1 + docs/source/user-guide.md | 8 +- examples/run_classification.py | 9 +- openvino_xai/common/parameters.py | 4 + openvino_xai/common/utils.py | 50 +++- openvino_xai/explainer/explainer.py | 8 +- openvino_xai/explainer/explanation.py | 30 ++- openvino_xai/explainer/utils.py | 44 +--- openvino_xai/explainer/visualizer.py | 3 +- openvino_xai/methods/__init__.py | 2 + openvino_xai/methods/base.py | 4 +- openvino_xai/methods/black_box/aise.py | 242 ++++++++++++++++++ openvino_xai/methods/black_box/base.py | 30 +++ openvino_xai/methods/black_box/rise.py | 19 +- openvino_xai/methods/factory.py | 18 +- pyproject.toml | 1 + tests/func/test_classification_timm_full.py | 3 +- tests/intg/test_classification.py | 49 +++- tests/intg/test_classification_timm.py | 7 +- tests/unit/explanation/test_explainer.py | 97 ++++++- .../explanation/test_explanation_utils.py | 2 +- .../black_box/test_black_box_method.py | 50 +++- 23 files changed, 570 insertions(+), 115 deletions(-) create mode 100644 openvino_xai/methods/black_box/aise.py diff --git a/CHANGELOG.md b/CHANGELOG.md index bf80d179..3b92a60e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,11 +4,11 @@ ### Summary -* +* Enable AISE: Adaptive Input Sampling for Explanation of Black-box Models. 
### What's Changed -* +* Enable AISE: Adaptive Input Sampling for Explanation of Black-box Models by @negvet in https://github.com/openvinotoolkit/openvino_xai/pull/49 ### Known Issues diff --git a/README.md b/README.md index f1989ef8..a7e1c3c7 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,7 @@ At the moment, *Image Classification* and *Object Detection* tasks are supported | | | | VITReciproCAM | [arxiv](https://arxiv.org/abs/2310.02588) / [src](openvino_xai/methods/white_box/recipro_cam.py) | | | | | ActivationMap | experimental / [src](openvino_xai/methods/white_box/activation_map.py) | | | | Black-Box | RISE | [arxiv](https://arxiv.org/abs/1806.07421v3) / [src](openvino_xai/methods/black_box/rise.py) | +| | | | AISE | [src](openvino_xai/methods/black_box/aise.py) | | | Object Detection | White-Box | ClassProbabilityMap | experimental / [src](openvino_xai/methods/white_box/det_class_probability_map.py) | ### Supported explainable models diff --git a/docs/source/user-guide.md b/docs/source/user-guide.md index f914d36f..93ee4cf8 100644 --- a/docs/source/user-guide.md +++ b/docs/source/user-guide.md @@ -247,7 +247,12 @@ explanation.save("output_path", "name") Black-box mode does not update the model (treating the model as a black box). Black-box approaches are based on the perturbation of the input data and measurement of the model's output change. -The process is repeated many times, requiring hundreds or thousands of forward passes and introducing **significant computational overhead**. +For black-box mode we support two algorithms: **AISE** (by default) and [**RISE**](https://arxiv.org/abs/1806.07421). AISE is more effective for generating saliency maps for a few specific classes. RISE is more effective for generating maps for all classes at once. + +Pros: +- **Flexible** - can be applied to any custom model. +Cons: +- **Computational overhead** - black-box mode requires hundreds or thousands of forward passes. 
`preprocess_fn` (or preprocessed images) and `postprocess_fn` are required to be provided by the user for black-box mode. @@ -286,7 +291,6 @@ explanation = explainer( target_explain_labels=[11, 14], # target classes to explain # target_explain_labels=-1, # explain all classes overlay=True, # False by default - num_masks=1000, # kwargs for the RISE algorithm ) # Save saliency maps diff --git a/examples/run_classification.py b/examples/run_classification.py index c9b03d67..7eec881a 100644 --- a/examples/run_classification.py +++ b/examples/run_classification.py @@ -53,7 +53,7 @@ def explain_auto(args): preprocess_fn=preprocess_fn, ) - # Prepare input image and explanation parameters, can be different for each explain call + # Prepare input image image = cv2.imread(args.image_path) # Generate explanation @@ -96,7 +96,7 @@ def explain_white_box(args): embed_scaling=True, # True by default. If set to True, saliency map scale (0 ~ 255) operation is embedded in the model ) - # Prepare input image and explanation parameters, can be different for each explain call + # Prepare input image and label names image = cv2.imread(args.image_path) voc_labels = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] @@ -139,7 +139,7 @@ def explain_black_box(args): explain_mode=ExplainMode.BLACKBOX, # defaults to AUTO ) - # Prepare input image and explanation parameters, can be different for each explain call + # Prepare input image and label names image = cv2.imread(args.image_path) voc_labels = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] @@ -150,7 +150,6 @@ def explain_black_box(args): targets=['dog', 'person'], # target classes to explain, also [11, 14] possible label_names=voc_labels, # 
optional names overlay=True, - num_masks=1000, # kwargs of the RISE algo ) logger.info( @@ -225,7 +224,7 @@ def explain_white_box_vit(args): # target_layer="/blocks/blocks.10/Add_1", # timm vit_base_patch8_224.augreg_in21k_ft_in1k ) - # Prepare input image and explanation parameters, can be different for each explain call + # Prepare input image image = cv2.imread(args.image_path) # Generate explanation diff --git a/openvino_xai/common/parameters.py b/openvino_xai/common/parameters.py index a0180550..535398c2 100644 --- a/openvino_xai/common/parameters.py +++ b/openvino_xai/common/parameters.py @@ -27,6 +27,7 @@ class Method(Enum): VITRECIPROCAM - VITReciproCAM method. DETCLASSPROBABILITYMAP - DetClassProbabilityMap method. RISE - RISE method. + AISE - AISE method. """ ACTIVATIONMAP = "activationmap" @@ -34,6 +35,7 @@ class Method(Enum): VITRECIPROCAM = "vitreciprocam" DETCLASSPROBABILITYMAP = "detclassprobabilitymap" RISE = "rise" + AISE = "aise" WhiteBoxXAIMethods = { @@ -43,11 +45,13 @@ class Method(Enum): } BlackBoxXAIMethods = { Method.RISE, + Method.AISE, } ClassificationXAIMethods = { Method.ACTIVATIONMAP, Method.RECIPROCAM, Method.RISE, + Method.AISE, } DetectionXAIMethods = { Method.DETCLASSPROBABILITYMAP, diff --git a/openvino_xai/common/utils.py b/openvino_xai/common/utils.py index 7aec2264..69cf979e 100644 --- a/openvino_xai/common/utils.py +++ b/openvino_xai/common/utils.py @@ -59,8 +59,8 @@ def retrieve_otx_model(data_dir: str | Path, model_name: str, dir_url=None) -> N def scaling(saliency_map: np.ndarray, cast_to_uint8: bool = True) -> np.ndarray: """Scaling saliency maps to [0, 255] range.""" - original_num_dims = saliency_map.shape - if len(original_num_dims) == 2: + original_num_dims = saliency_map.ndim + if original_num_dims == 2: # If input map is 2D array, add dim so that below code would work saliency_map = saliency_map[np.newaxis, ...] 
@@ -87,6 +87,52 @@ def get_min_max(saliency_map: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: return min_values, max_values +def sigmoid(x: np.ndarray) -> np.ndarray: + """Compute sigmoid values of x.""" + return 1 / (1 + np.exp(-x)) + + +def softmax(x: np.ndarray) -> np.ndarray: + """Compute softmax values of x.""" + e_x = np.exp(x - np.max(x)) + return e_x / e_x.sum() + + class IdentityPreprocessFN: def __call__(self, x: Any) -> Any: return x + + +def is_bhwc_layout(image: np.array) -> bool: + """Check whether layout of image is BHWC.""" + _, dim0, dim1, dim2 = image.shape + if dim0 > dim2 and dim1 > dim2: # bhwc layout + return True + return False + + +def format_to_bhwc(image: np.ndarray) -> np.ndarray: + """Format image to BHWC from ndim=3 or ndim=4.""" + if image.ndim == 3: + image = np.expand_dims(image, axis=0) + if not is_bhwc_layout(image): + # bchw layout -> bhwc + image = image.transpose((0, 2, 3, 1)) + return image + + +def infer_size_from_image(image: np.ndarray) -> Tuple[int, int]: + """Estimate image size.""" + if image.ndim not in [2, 3, 4]: + raise ValueError(f"Supports only two, three, and four dimensional image, but got {image.ndim}.") + + if image.ndim == 2: + return image.shape + + if image.ndim == 3: + image = np.expand_dims(image, axis=0) + _, dim0, dim1, dim2 = image.shape + if dim0 > dim2 and dim1 > dim2: # bhwc layout: + return dim0, dim1 + else: + return dim1, dim2 diff --git a/openvino_xai/explainer/explainer.py b/openvino_xai/explainer/explainer.py index 89bb9bbf..a342de65 100644 --- a/openvino_xai/explainer/explainer.py +++ b/openvino_xai/explainer/explainer.py @@ -10,13 +10,16 @@ from openvino_xai import Task from openvino_xai.common.parameters import Method -from openvino_xai.common.utils import IdentityPreprocessFN, logger +from openvino_xai.common.utils import ( + IdentityPreprocessFN, + infer_size_from_image, + logger, +) from openvino_xai.explainer.explanation import Explanation from openvino_xai.explainer.utils import ( 
convert_targets_to_numpy, explains_all, get_explain_target_indices, - infer_size_from_image, ) from openvino_xai.explainer.visualizer import Visualizer from openvino_xai.methods.base import MethodBase @@ -258,6 +261,7 @@ def _create_black_box_method(self, task: Task) -> MethodBase: model=self.model, postprocess_fn=self.postprocess_fn, preprocess_fn=self.preprocess_fn, + explain_method=self.explain_method, device_name=self.device_name, ) logger.info("Explaining the model in black-box mode.") diff --git a/openvino_xai/explainer/explanation.py b/openvino_xai/explainer/explanation.py index 2ce7b401..bbd411d6 100644 --- a/openvino_xai/explainer/explanation.py +++ b/openvino_xai/explainer/explanation.py @@ -20,7 +20,8 @@ class Explanation: """ Explanation selects target saliency maps, holds it and its layout. - :param saliency_map: Raw saliency map. + :param saliency_map: Raw saliency map, as a numpy array or as a dict. + :type saliency_map: np.ndarray | Dict[int | str, np.ndarray] :param targets: List of custom labels to explain, optional. Can be list of integer indices (int), or list of names (str) from label_names. 
:type targets: np.ndarray | List[int | str] | int | str @@ -30,14 +31,21 @@ class Explanation: def __init__( self, - saliency_map: np.ndarray, + saliency_map: np.ndarray | Dict[int | str, np.ndarray], targets: np.ndarray | List[int | str] | int | str, label_names: List[str] | None = None, ): targets = convert_targets_to_numpy(targets) - self._check_saliency_map(saliency_map) - self._saliency_map = self._format_sal_map_as_dict(saliency_map) + if isinstance(saliency_map, np.ndarray): + self._check_saliency_map(saliency_map) + self._saliency_map = self._format_sal_map_as_dict(saliency_map) + self.total_num_targets = len(self._saliency_map) + elif isinstance(saliency_map, dict): + self._saliency_map = saliency_map + self.total_num_targets = None + else: + raise ValueError(f"Expect saliency_map to be np.ndarray or dict, but got{type(saliency_map)}.") if "per_image_map" in self._saliency_map: self.layout = Layout.ONE_MAP_PER_IMAGE_GRAY @@ -74,8 +82,6 @@ def targets(self): def _check_saliency_map(saliency_map: np.ndarray): if saliency_map is None: raise RuntimeError("Saliency map is None.") - if not isinstance(saliency_map, np.ndarray): - raise ValueError(f"Raw saliency_map has to be np.ndarray, but got {type(saliency_map)}.") if saliency_map.size == 0: raise RuntimeError("Saliency map is zero size array.") if saliency_map.shape[0] > 1: @@ -107,21 +113,23 @@ def _select_target_saliency_maps( assert self.layout == Layout.MULTIPLE_MAPS_PER_IMAGE_GRAY explain_target_indices = self._select_target_indices( targets=targets, - total_num_targets=len(self._saliency_map), label_names=label_names, ) saliency_maps_selected = {i: self._saliency_map[i] for i in explain_target_indices} return saliency_maps_selected - @staticmethod def _select_target_indices( + self, targets: np.ndarray | List[int | str], - total_num_targets: int, label_names: List[str] | None = None, ) -> List[int] | np.ndarray: explain_target_indices = get_explain_target_indices(targets, label_names) - if not all(0 <= 
target_index <= (total_num_targets - 1) for target_index in explain_target_indices): - raise ValueError(f"All targets explanation indices have to be in range 0..{total_num_targets - 1}.") + if self.total_num_targets is not None: + if not all(0 <= target_index <= (self.total_num_targets - 1) for target_index in explain_target_indices): + raise ValueError(f"All targets indices have to be in range 0..{self.total_num_targets - 1}.") + else: + if not all(target_index in self.saliency_map for target_index in explain_target_indices): + raise ValueError("Provided targer index {targer_index} is not available among saliency maps.") return explain_target_indices def save(self, dir_path: Path | str, name: str | None = None) -> None: diff --git a/openvino_xai/explainer/utils.py b/openvino_xai/explainer/utils.py index 8aceb35f..b133f426 100644 --- a/openvino_xai/explainer/utils.py +++ b/openvino_xai/explainer/utils.py @@ -7,6 +7,8 @@ import cv2 import numpy as np +from openvino_xai.common.utils import sigmoid, softmax + def convert_targets_to_numpy(targets): targets = np.asarray(targets) @@ -127,17 +129,6 @@ def get_postprocess_fn(logit_name="logits") -> Callable[[], np.ndarray]: return partial(postprocess_fn, logit_name=logit_name) -def softmax(x: np.ndarray) -> np.ndarray: - """Compute softmax values of x.""" - e_x = np.exp(x - np.max(x)) - return e_x / e_x.sum() - - -def sigmoid(x: np.ndarray) -> np.ndarray: - """Compute sigmoid values of x.""" - return 1 / (1 + np.exp(-x)) - - class ActivationType(Enum): SIGMOID = "sigmoid" SOFTMAX = "softmax" @@ -154,34 +145,3 @@ def get_score(x: np.ndarray, index: int, activation: ActivationType = Activation assert x.shape[0] == 1 return x[0, index] return x[index] - - -def format_to_bhwc(image: np.ndarray) -> np.ndarray: - """Format image to BHWC from ndim=3 or ndim=4.""" - if image.ndim == 3: - image = np.expand_dims(image, axis=0) - if not is_bhwc_layout(image): - # bchw layout -> bhwc - image = image.transpose((0, 2, 3, 1)) - return 
image - - -def is_bhwc_layout(image: np.array) -> bool: - """Check whether layout of image is BHWC.""" - _, dim0, dim1, dim2 = image.shape - if dim0 > dim2 and dim1 > dim2: # bhwc layout - return True - return False - - -def infer_size_from_image(image: np.ndarray) -> Tuple[int, int]: - """Estimate image size.""" - if image.ndim == 2: - return image.shape - - image = format_to_bhwc(image) - if image.ndim == 4: - _, h, w, _ = image.shape - else: - raise ValueError(f"Supports only two, three, and four dimensional image, but got {image.ndim}.") - return h, w diff --git a/openvino_xai/explainer/visualizer.py b/openvino_xai/explainer/visualizer.py index b85a310f..3c7dda3a 100644 --- a/openvino_xai/explainer/visualizer.py +++ b/openvino_xai/explainer/visualizer.py @@ -6,7 +6,7 @@ import cv2 import numpy as np -from openvino_xai.common.utils import scaling +from openvino_xai.common.utils import format_to_bhwc, infer_size_from_image, scaling from openvino_xai.explainer.explanation import ( COLOR_MAPPED_LAYOUTS, GRAY_LAYOUTS, @@ -15,7 +15,6 @@ Explanation, Layout, ) -from openvino_xai.explainer.utils import format_to_bhwc, infer_size_from_image def resize(saliency_map: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray: diff --git a/openvino_xai/methods/__init__.py b/openvino_xai/methods/__init__.py index 0e63b285..58c3a114 100644 --- a/openvino_xai/methods/__init__.py +++ b/openvino_xai/methods/__init__.py @@ -3,6 +3,7 @@ """ XAI algorithms. 
""" +from openvino_xai.methods.black_box.aise import AISE from openvino_xai.methods.black_box.rise import RISE from openvino_xai.methods.white_box.activation_map import ActivationMap from openvino_xai.methods.white_box.base import WhiteBoxMethod @@ -23,4 +24,5 @@ "ViTReciproCAM", "DetClassProbabilityMap", "RISE", + "AISE", ] diff --git a/openvino_xai/methods/base.py b/openvino_xai/methods/base.py index efe8cc1d..6343c14e 100644 --- a/openvino_xai/methods/base.py +++ b/openvino_xai/methods/base.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Callable, Mapping +from typing import Callable, Dict, Mapping import numpy as np import openvino.runtime as ov @@ -41,7 +41,7 @@ def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping: return self._model_compiled(x) @abstractmethod - def generate_saliency_map(self, data: np.ndarray) -> np.ndarray: + def generate_saliency_map(self, data: np.ndarray) -> Dict[int, np.ndarray] | np.ndarray: """Saliency map generation.""" def load_model(self) -> None: diff --git a/openvino_xai/methods/black_box/aise.py b/openvino_xai/methods/black_box/aise.py new file mode 100644 index 00000000..bd89f0c4 --- /dev/null +++ b/openvino_xai/methods/black_box/aise.py @@ -0,0 +1,242 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import collections +import math +from typing import Callable, Dict, List, Tuple + +import numpy as np +import openvino.runtime as ov +from openvino.runtime.utils.data_helpers.wrappers import OVDict +from scipy.optimize import Bounds, direct + +from openvino_xai.common.utils import ( + IdentityPreprocessFN, + infer_size_from_image, + logger, + scaling, + sigmoid, +) +from openvino_xai.methods.black_box.base import BlackBoxXAIMethod, Preset + + +class AISE(BlackBoxXAIMethod): + """AISE explains classification models in black-box mode using + AISE: Adaptive Input Sampling for Explanation of Black-box Models. 
+ (TODO (negvet): add link to the paper.) + + :param model: OpenVINO model. + :type model: ov.Model + :param postprocess_fn: Post-processing function that extract scores from IR model output. + :type postprocess_fn: Callable[[OVDict], np.ndarray] + :param preprocess_fn: Pre-processing function, identity function by default + (assume input images are already preprocessed by user). + :type preprocess_fn: Callable[[np.ndarray], np.ndarray] + :param device_name: Device type name. + :type device_name: str + :param prepare_model: Loading (compiling) the model prior to inference. + :type prepare_model: bool + """ + + def __init__( + self, + model: ov.Model, + postprocess_fn: Callable[[OVDict], np.ndarray], + preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(), + device_name: str = "CPU", + prepare_model: bool = True, + ): + super().__init__(model=model, preprocess_fn=preprocess_fn, device_name=device_name) + self.postprocess_fn = postprocess_fn + + self.data_preprocessed = None + self.target: int | None = None + self.num_iterations_per_kernel: int | None = None + self.kernel_widths: List[float] | np.ndarray | None = None + self._current_kernel_width: float | None = None + self.solver_epsilon = 0.1 + self.locally_biased = False + self.kernel_params_hist: Dict = collections.defaultdict(list) + self.pred_score_hist: Dict = collections.defaultdict(list) + self.input_size: Tuple[int, int] | None = None + self._mask_generator: GaussianPerturbationMask | None = None + + if prepare_model: + self.prepare_model() + + def generate_saliency_map( # type: ignore + self, + data: np.ndarray, + explain_target_indices: List[int] | None, + preset: Preset = Preset.BALANCE, + num_iterations_per_kernel: int | None = None, + kernel_widths: List[float] | np.ndarray | None = None, + solver_epsilon: float = 0.1, + locally_biased: bool = False, + scale_output: bool = True, + ) -> Dict[int, np.ndarray]: + """ + Generates inference result of the AISE algorithm. 
+ Optimized for per class saliency map generation. Not efficient for a large number of classes. + + :param data: Input image. + :type data: np.ndarray + :param explain_target_indices: List of target indices to explain. + :type explain_target_indices: List[int] + :param preset: Speed-Quality preset, defines predefined configurations that manage speed-quality tradeoff. + :type preset: Preset + :param num_iterations_per_kernel: Number of iterations per kernel, defines compute budget. + :type num_iterations_per_kernel: int + :param kernel_widths: Kernel bandwidths. + :type kernel_widths: List[float] | np.ndarray + :param solver_epsilon: Solver epsilon of DIRECT optimizer. + :type solver_epsilon: float + :param locally_biased: Locally biased flag of DIRECT optimizer. + :type locally_biased: bool + :param scale_output: Whether to scale output or not. + :type scale_output: bool + """ + self.data_preprocessed = self.preprocess_fn(data) + + if explain_target_indices is None: + num_classes = self.get_num_classes(self.data_preprocessed) + if num_classes > 10: + logger.info(f"num_classes = {num_classes}, which might take significant time to process.") + explain_target_indices = list(range(num_classes)) + + self._preset_parameters(preset, num_iterations_per_kernel, kernel_widths) + + self.solver_epsilon = solver_epsilon + self.locally_biased = locally_biased + + self.input_size = infer_size_from_image(self.data_preprocessed) + self._mask_generator = GaussianPerturbationMask(self.input_size) + + saliency_maps = {} + for target in explain_target_indices: + self.kernel_params_hist = collections.defaultdict(list) + self.pred_score_hist = collections.defaultdict(list) + + self.target = target + saliency_map_per_target = self._run_synchronous_explanation() + if scale_output: + saliency_map_per_target = scaling(saliency_map_per_target) + saliency_maps[target] = saliency_map_per_target + return saliency_maps + + def _preset_parameters( + self, + preset: Preset, + 
num_iterations_per_kernel: int | None, + kernel_widths: List[float] | np.ndarray | None, + ) -> None: + if preset == Preset.SPEED: + self.num_iterations_per_kernel = 25 + self.kernel_widths = np.linspace(0.1, 0.25, 3) + elif preset == Preset.BALANCE: + self.num_iterations_per_kernel = 50 + self.kernel_widths = np.linspace(0.1, 0.25, 3) + elif preset == Preset.QUALITY: + self.num_iterations_per_kernel = 85 + self.kernel_widths = np.linspace(0.075, 0.25, 4) + else: + raise ValueError(f"Preset {preset} is not supported.") + + if num_iterations_per_kernel is not None: + self.num_iterations_per_kernel = num_iterations_per_kernel + if kernel_widths is not None: + self.kernel_widths = kernel_widths + + def _run_synchronous_explanation(self) -> np.ndarray: + for kernel_width in self.kernel_widths: + self._current_kernel_width = kernel_width + self._run_optimization() + return self._kernel_density_estimation() + + def _run_optimization(self): + """Run DIRECT optimizer by default.""" + _ = direct( + func=self._objective_function, + bounds=Bounds([0.0, 0.0], [1.0, 1.0]), + eps=self.solver_epsilon, + maxfun=self.num_iterations_per_kernel, + locally_biased=self.locally_biased, + ) + + def _objective_function(self, args) -> float: + """ + Objective function to optimize (to find a global minimum). 
+ Hybrid (dual) paradigm adopted with two sub-objectives: + - preservation + - deletion + """ + mh, mw = args + kernel_params = (mh, mw, self._current_kernel_width) + self.kernel_params_hist[self._current_kernel_width].append(kernel_params) + + kernel_mask = self._mask_generator.generate_kernel_mask(kernel_params) + kernel_mask = np.clip(kernel_mask, 0, 1) + + data_perturbed_preserve = self.data_preprocessed * kernel_mask + pred_score_preserve = self._get_score(data_perturbed_preserve) + + data_perturbed_delete = self.data_preprocessed * (1 - kernel_mask) + pred_score_delete = self._get_score(data_perturbed_delete) + + loss = pred_score_preserve - pred_score_delete + + self.pred_score_hist[self._current_kernel_width].append(pred_score_preserve - pred_score_delete) + + loss *= -1 # Objective: minimize + return loss + + def _get_score(self, data_perturbed: np.array) -> float: + """Get model prediction score for perturbed input.""" + x = self.model_forward(data_perturbed, preprocess=False) + x = self.postprocess_fn(x) + if np.max(x) > 1 or np.min(x) < 0: + x = sigmoid(x) + pred_scores = x.squeeze() # type: ignore + return pred_scores[self.target] + + def _kernel_density_estimation(self) -> np.ndarray: + """Aggregate the result per kernel with KDE.""" + saliency_map_per_kernel = np.zeros((len(self.kernel_widths), self.input_size[0], self.input_size[1])) + for kernel_index, kernel_width in enumerate(self.kernel_widths): + kernel_masks_weighted = np.zeros(self.input_size) + for i in range(self.num_iterations_per_kernel): + kernel_params = self.kernel_params_hist[kernel_width][i] + kernel_mask = self._mask_generator.generate_kernel_mask(kernel_params) + score = self.pred_score_hist[kernel_width][i] + kernel_masks_weighted += kernel_mask * score + kernel_masks_weighted = kernel_masks_weighted / kernel_masks_weighted.max() + saliency_map_per_kernel[kernel_index] = kernel_masks_weighted + + saliency_map = saliency_map_per_kernel.sum(axis=0) + saliency_map /= 
saliency_map.max() + return saliency_map + + +class GaussianPerturbationMask: + """ + Perturbation mask generator. + """ + + def __init__(self, input_size: Tuple[int, int]): + h = np.linspace(0, 1, input_size[0]) + w = np.linspace(0, 1, input_size[1]) + self.h, self.w = np.meshgrid(w, h) + + def _get_2d_gaussian(self, gauss_params: Tuple[float, float, float]) -> np.ndarray: + mh, mw, sigma = gauss_params + A = 1 / (2 * math.pi * sigma * sigma) + B = (self.h - mh) ** 2 / (2 * sigma**2) + C = (self.w - mw) ** 2 / (2 * sigma**2) + return A * np.exp(-(B + C)) + + def generate_kernel_mask(self, gauss_param: Tuple[float, float, float], scale: float = 1.0): + """ + Generates 2D gaussian mask. + """ + gaussian = self._get_2d_gaussian(gauss_param) + return (gaussian / gaussian.max()) * scale diff --git a/openvino_xai/methods/black_box/base.py b/openvino_xai/methods/black_box/base.py index b3aabfd8..7e8c888b 100644 --- a/openvino_xai/methods/black_box/base.py +++ b/openvino_xai/methods/black_box/base.py @@ -1,8 +1,38 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +from enum import Enum + +import openvino.runtime as ov + from openvino_xai.methods.base import MethodBase class BlackBoxXAIMethod(MethodBase): """Base class for methods that explain model in Black-Box mode.""" + + def prepare_model(self, load_model: bool = True) -> ov.Model: + if load_model: + self.load_model() + return self._model + + def get_num_classes(self, data_preprocessed): + forward_output = self.model_forward(data_preprocessed, preprocess=False) + logits = self.postprocess_fn(forward_output) + _, num_classes = logits.shape + return num_classes + + +class Preset(Enum): + """ + Enum representing the different presets: + + Contains the following values: + SPEED - Prioritize getting results faster. + BALANCE - Balance between speed and quality. + QUALITY - Prioritize high quality results. 
+ """ + + SPEED = "speed" + BALANCE = "balance" + QUALITY = "quality" diff --git a/openvino_xai/methods/black_box/rise.py b/openvino_xai/methods/black_box/rise.py index 1dc3588a..a8b04eb2 100644 --- a/openvino_xai/methods/black_box/rise.py +++ b/openvino_xai/methods/black_box/rise.py @@ -17,9 +17,9 @@ class RISE(BlackBoxXAIMethod): :param model: OpenVINO model. :type model: ov.Model - :param postprocess_fn: Preprocessing function that extract scores from IR model output. - :type postprocess_fn: Callable[[Mapping], np.ndarray] - :param preprocess_fn: Preprocessing function, identity function by default + :param postprocess_fn: Post-processing function that extract scores from IR model output. + :type postprocess_fn: Callable[[OVDict], np.ndarray] + :param preprocess_fn: Pre-processing function, identity function by default (assume input images are already preprocessed by user). :type preprocess_fn: Callable[[np.ndarray], np.ndarray] :param device_name: Device type name. @@ -42,11 +42,6 @@ def __init__( if prepare_model: self.prepare_model() - def prepare_model(self, load_model: bool = True) -> ov.Model: - if load_model: - self.load_model() - return self._model - def generate_saliency_map( self, data: np.ndarray, @@ -56,7 +51,7 @@ def generate_saliency_map( prob: float = 0.5, seed: int = 0, scale_output: bool = True, - ): + ) -> np.ndarray: """ Generates inference result of the RISE algorithm. 
@@ -102,13 +97,11 @@ def _run_synchronous_explanation( prob: float, seed: int, ) -> np.ndarray: - from openvino_xai.explainer.utils import is_bhwc_layout + from openvino_xai.common.utils import is_bhwc_layout input_size = data_preprocessed.shape[1:3] if is_bhwc_layout(data_preprocessed) else data_preprocessed.shape[2:4] - forward_output = self.model_forward(data_preprocessed, preprocess=False) - logits = self.postprocess_fn(forward_output) - _, num_classes = logits.shape + num_classes = self.get_num_classes(data_preprocessed) if target_classes is None: num_targets = num_classes diff --git a/openvino_xai/methods/factory.py b/openvino_xai/methods/factory.py index f59aecfd..199729d3 100644 --- a/openvino_xai/methods/factory.py +++ b/openvino_xai/methods/factory.py @@ -10,6 +10,7 @@ from openvino_xai.common.parameters import Method, Task from openvino_xai.common.utils import IdentityPreprocessFN, logger from openvino_xai.methods.base import MethodBase +from openvino_xai.methods.black_box.aise import AISE from openvino_xai.methods.black_box.base import BlackBoxXAIMethod from openvino_xai.methods.black_box.rise import RISE from openvino_xai.methods.white_box.activation_map import ActivationMap @@ -184,13 +185,18 @@ def create_method( model: ov.Model, postprocess_fn: Callable[[Mapping], np.ndarray], preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(), + explain_method: Method | None = None, device_name: str = "CPU", **kwargs, ) -> MethodBase: if task == Task.CLASSIFICATION: - return cls.create_classification_method(model, postprocess_fn, preprocess_fn, device_name, **kwargs) + return cls.create_classification_method( + model, postprocess_fn, preprocess_fn, explain_method, device_name, **kwargs + ) if task == Task.DETECTION: - return cls.create_detection_method(model, postprocess_fn, preprocess_fn, device_name, **kwargs) + return cls.create_detection_method( + model, postprocess_fn, preprocess_fn, explain_method, device_name, **kwargs + ) raise 
ValueError(f"Model type {task} is not supported in black-box mode.") @staticmethod @@ -198,10 +204,12 @@ def create_classification_method( model: ov.Model, postprocess_fn: Callable[[Mapping], np.ndarray], preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(), + explain_method: Method | None = None, device_name: str = "CPU", **kwargs, ) -> BlackBoxXAIMethod: """Generates instance of the classification black-box method class. + Using AISE as a default method. :param model: OV IR model. :type model: ov.Model @@ -213,7 +221,11 @@ def create_classification_method( :param device_name: Device type name. :type device_name: str """ - return RISE(model, postprocess_fn, preprocess_fn, device_name, **kwargs) + if explain_method is None or explain_method == Method.AISE: + return AISE(model, postprocess_fn, preprocess_fn, device_name, **kwargs) + elif explain_method == Method.RISE: + return RISE(model, postprocess_fn, preprocess_fn, device_name, **kwargs) + raise ValueError(f"Requested explanation method {explain_method} is not implemented.") @staticmethod def create_detection_method(*args, **kwargs) -> BlackBoxXAIMethod: diff --git a/pyproject.toml b/pyproject.toml index 789e6788..c852cea1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ version = "1.1.0rc0" dependencies = [ "openvino-dev==2024.2", "opencv-python", + "scipy", "numpy==1.*", "tqdm", ] diff --git a/tests/func/test_classification_timm_full.py b/tests/func/test_classification_timm_full.py index 46834529..6baa7913 100644 --- a/tests/func/test_classification_timm_full.py +++ b/tests/func/test_classification_timm_full.py @@ -268,8 +268,7 @@ def test_classification_black_box(self, model_id, dump_maps=False): explanation = explainer( image, targets=[target_class], - # num_masks=2000, # kwargs of the RISE algo - num_masks=2, # minimal iterations for feature test + num_iterations_per_kernel=2, # minimal iterations for feature test ) assert explanation is not None diff --git 
a/tests/intg/test_classification.py b/tests/intg/test_classification.py index 39a12c28..7242f758 100644 --- a/tests/intg/test_classification.py +++ b/tests/intg/test_classification.py @@ -14,6 +14,7 @@ from openvino_xai.common.utils import has_xai, retrieve_otx_model from openvino_xai.explainer.explainer import Explainer, ExplainMode from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn +from openvino_xai.methods.black_box.base import Preset MODELS = [ "mlc_mobilenetv3_large_voc", # verified @@ -309,6 +310,44 @@ class TestClsBB: def setup(self, fxt_data_root): self.data_dir = fxt_data_root + @pytest.mark.parametrize("model_name", MODELS) + @pytest.mark.parametrize("overlay", [True, False]) + @pytest.mark.parametrize("scaling", [True, False]) + def test_aise( + self, + model_name: str, + overlay: bool, + scaling: bool, + ): + retrieve_otx_model(self.data_dir, model_name) + model_path = self.data_dir / "otx_models" / (model_name + ".xml") + model = ov.Core().read_model(model_path) + + explainer = Explainer( + model=model, + task=Task.CLASSIFICATION, + preprocess_fn=self.preprocess_fn, # type: ignore + postprocess_fn=get_postprocess_fn(), # type: ignore + explain_mode=ExplainMode.BLACKBOX, + ) + target_class = 1 + explanation = explainer( + self.image, + targets=[target_class], + scaling=scaling, + overlay=overlay, + resize=False, + colormap=False, + num_iterations_per_kernel=2, + kernel_widths=[0.1], + ) + assert target_class in explanation.saliency_map + assert len(explanation.saliency_map) == len([target_class]) + if overlay: + assert explanation.saliency_map[target_class].ndim == 3 + else: + assert explanation.saliency_map[target_class].ndim == 2 + @pytest.mark.parametrize("model_name", MODELS) @pytest.mark.parametrize("overlay", [True, False]) @pytest.mark.parametrize( @@ -319,7 +358,7 @@ def setup(self, fxt_data_root): ], ) @pytest.mark.parametrize("scaling", [True, False]) - def test_classification_black_box_visualizing( + def 
test_rise( self, model_name: str, overlay: bool, @@ -336,6 +375,7 @@ def test_classification_black_box_visualizing( preprocess_fn=self.preprocess_fn, # type: ignore postprocess_fn=get_postprocess_fn(), # type: ignore explain_mode=ExplainMode.BLACKBOX, + explain_method=Method.RISE, ) if not explain_all_classes: @@ -380,12 +420,8 @@ def test_classification_black_box_visualizing( for map_ in explanation.saliency_map.values(): assert map_.min() == 0, f"{map_.min()}" assert map_.max() in {254, 255}, f"{map_.max()}" - if model_name in self._ref_sal_maps: - actual_sal_vals = explanation.saliency_map[0][0, :10].astype(np.int16) - ref_sal_vals = self._ref_sal_maps[DEFAULT_CLS_MODEL].astype(np.uint8) - assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1) - def test_classification_black_box_xai_model_as_input(self): + def test_rise_xai_model_as_input(self): retrieve_otx_model(self.data_dir, DEFAULT_CLS_MODEL) model_path = self.data_dir / "otx_models" / (DEFAULT_CLS_MODEL + ".xml") model = ov.Core().read_model(model_path) @@ -401,6 +437,7 @@ def test_classification_black_box_xai_model_as_input(self): preprocess_fn=self.preprocess_fn, postprocess_fn=get_postprocess_fn(), explain_mode=ExplainMode.BLACKBOX, + explain_method=Method.RISE, ) explanation = explainer( self.image, diff --git a/tests/intg/test_classification_timm.py b/tests/intg/test_classification_timm.py index ac0a123b..d2fb3d47 100644 --- a/tests/intg/test_classification_timm.py +++ b/tests/intg/test_classification_timm.py @@ -277,8 +277,9 @@ def test_classification_black_box(self, model_id, dump_maps=False): explanation = explainer( image, targets=[target_class], - # num_masks=2000, # kwargs of the RISE algo - num_masks=2, # minimal iterations for feature test + overlay=True, + num_iterations_per_kernel=2, + kernel_widths=[0.1], ) assert explanation is not None @@ -434,7 +435,7 @@ def test_model_format(self, model_id, explain_mode, model_format): targets=[target_class], resize=False, colormap=False, - 
num_masks=2, # minimal iterations for feature test + num_iterations_per_kernel=2, # minimal iterations for feature test ) assert explanation is not None diff --git a/tests/unit/explanation/test_explainer.py b/tests/unit/explanation/test_explainer.py index ab90169a..bdeceea6 100644 --- a/tests/unit/explanation/test_explainer.py +++ b/tests/unit/explanation/test_explainer.py @@ -10,10 +10,11 @@ from pytest_mock import MockerFixture from openvino_xai.api.api import insert_xai -from openvino_xai.common.parameters import Task +from openvino_xai.common.parameters import Method, Task from openvino_xai.common.utils import retrieve_otx_model from openvino_xai.explainer.explainer import Explainer, ExplainMode from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn +from openvino_xai.methods.black_box.base import Preset from tests.unit.explanation.test_explanation_utils import VOC_NAMES MODEL_NAME = "mlc_mobilenetv3_large_voc" @@ -39,15 +40,8 @@ def setup(self, fxt_data_root): ExplainMode.AUTO, ], ) - @pytest.mark.parametrize( - "explain_all_classes", - [ - True, - False, - ], - ) @pytest.mark.parametrize("with_xai_originally", [True, False]) - def test_explainer(self, explain_mode, explain_all_classes, with_xai_originally): + def test_explainer(self, explain_mode, with_xai_originally): retrieve_otx_model(self.data_dir, MODEL_NAME) model_path = self.data_dir / "otx_models" / (MODEL_NAME + ".xml") model = ov.Core().read_model(model_path) @@ -66,19 +60,96 @@ def test_explainer(self, explain_mode, explain_all_classes, with_xai_originally) explain_mode=explain_mode, ) + explanation = explainer( + self.image, + targets=[11, 14], + label_names=VOC_NAMES, # optional + preset=Preset.SPEED, + ) + + assert len(explanation.saliency_map) == 2 + + @pytest.mark.parametrize( + "explain_all_classes", + [ + True, + False, + ], + ) + def test_explainer_all_classes(self, explain_all_classes): + retrieve_otx_model(self.data_dir, MODEL_NAME) + model_path = self.data_dir / 
"otx_models" / (MODEL_NAME + ".xml") + model = ov.Core().read_model(model_path) + + explainer = Explainer( + model=model, + task=Task.CLASSIFICATION, + preprocess_fn=self.preprocess_fn, + postprocess_fn=get_postprocess_fn(), + ) + if explain_all_classes: explanation = explainer( self.image, targets=-1, label_names=VOC_NAMES, # optional - num_masks=10, ) else: explanation = explainer( self.image, targets=[11, 14], label_names=VOC_NAMES, # optional - num_masks=10, + ) + + if explain_all_classes: + assert len(explanation.saliency_map) == 20 + else: + assert len(explanation.saliency_map) == 2 + + @pytest.mark.parametrize( + "explain_all_classes", + [ + True, + False, + ], + ) + @pytest.mark.parametrize( + "explain_method", + [ + Method.AISE, + Method.RISE, + ], + ) + def test_explainer_all_classes_black_box(self, explain_all_classes, explain_method): + retrieve_otx_model(self.data_dir, MODEL_NAME) + model_path = self.data_dir / "otx_models" / (MODEL_NAME + ".xml") + model = ov.Core().read_model(model_path) + + explainer = Explainer( + model=model, + task=Task.CLASSIFICATION, + preprocess_fn=self.preprocess_fn, + postprocess_fn=get_postprocess_fn(), + explain_mode=ExplainMode.BLACKBOX, + explain_method=explain_method, + ) + + if explain_method == Method.AISE: + kwargs = {"num_iterations_per_kernel": 2} + else: + kwargs = {"num_masks": 2} + + if explain_all_classes: + explanation = explainer( + self.image, + targets=-1, + **kwargs, + ) + else: + explanation = explainer( + self.image, + targets=[11, 14], + **kwargs, ) if explain_all_classes: @@ -113,7 +184,7 @@ def test_explainer_wo_preprocessing(self): ) processed_data = self.preprocess_fn(self.image) - explanation = explainer(processed_data, targets=[11, 14], num_masks=10) + explanation = explainer(processed_data, targets=[11, 14], preset=Preset.SPEED) assert len(explanation.saliency_map) == 2 @@ -140,7 +211,7 @@ def test_auto_black_box_fallback(self): explain_mode=ExplainMode.AUTO, target_layer="some_wrong_name", ) - 
explanation = explainer(self.image, targets=-1, num_masks=10) + explanation = explainer(self.image, targets=-1, num_iterations_per_kernel=2) assert len(explanation.saliency_map) == 20 def test_two_explainers_with_same_model(self): diff --git a/tests/unit/explanation/test_explanation_utils.py b/tests/unit/explanation/test_explanation_utils.py index 3c38ae5a..71eed6d6 100644 --- a/tests/unit/explanation/test_explanation_utils.py +++ b/tests/unit/explanation/test_explanation_utils.py @@ -4,11 +4,11 @@ import numpy as np import pytest +from openvino_xai.common.utils import is_bhwc_layout from openvino_xai.explainer.utils import ( ActivationType, get_explain_target_indices, get_score, - is_bhwc_layout, ) VOC_NAMES = [ diff --git a/tests/unit/methods/black_box/test_black_box_method.py b/tests/unit/methods/black_box/test_black_box_method.py index 99e635c4..db8e0f22 100644 --- a/tests/unit/methods/black_box/test_black_box_method.py +++ b/tests/unit/methods/black_box/test_black_box_method.py @@ -10,11 +10,13 @@ from openvino_xai.common.utils import retrieve_otx_model from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn +from openvino_xai.methods.black_box.aise import AISE +from openvino_xai.methods.black_box.base import Preset from openvino_xai.methods.black_box.rise import RISE from tests.intg.test_classification import DEFAULT_CLS_MODEL -class TestRISE: +class InputSampling: image = cv2.imread("tests/assets/cheetah_person.jpg") preprocess_fn = get_preprocess_fn( change_channel_order=True, @@ -23,11 +25,47 @@ class TestRISE: ) postprocess_fn = get_postprocess_fn() - @pytest.mark.parametrize("explain_target_indices", [[0], None]) - def test_run(self, explain_target_indices, fxt_data_root: Path): + def get_model(self, fxt_data_root): retrieve_otx_model(fxt_data_root, DEFAULT_CLS_MODEL) model_path = fxt_data_root / "otx_models" / (DEFAULT_CLS_MODEL + ".xml") - model = ov.Core().read_model(model_path) + return ov.Core().read_model(model_path) + + 
+class TestAISE(InputSampling): + @pytest.mark.parametrize("explain_target_indices", [[0], [0, 1]]) + def test_run(self, explain_target_indices, fxt_data_root: Path): + model = self.get_model(fxt_data_root) + + aise_method = AISE(model, self.postprocess_fn, self.preprocess_fn) + saliency_map = aise_method.generate_saliency_map( + data=self.image, + explain_target_indices=explain_target_indices, + preset=Preset.SPEED, + num_iterations_per_kernel=10, + kernel_widths=[0.1], + ) + assert aise_method.num_iterations_per_kernel == 10 + assert aise_method.kernel_widths == [0.1] + + assert isinstance(saliency_map, dict) + assert len(saliency_map) == len(explain_target_indices) + for target in explain_target_indices: + assert target in saliency_map + + ref_target = 0 + assert saliency_map[ref_target].dtype == np.uint8 + assert saliency_map[ref_target].shape == (224, 224) + assert (saliency_map[ref_target] >= 0).all() and (saliency_map[ref_target] <= 255).all() + + actual_sal_vals = saliency_map[0][0, :10].astype(np.int16) + ref_sal_vals = np.array([68, 69, 69, 69, 70, 70, 71, 71, 72, 72], dtype=np.uint8) + assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1) + + +class TestRISE(InputSampling): + @pytest.mark.parametrize("explain_target_indices", [[0], None]) + def test_run(self, explain_target_indices, fxt_data_root: Path): + model = self.get_model(fxt_data_root) rise_method = RISE(model, self.postprocess_fn, self.preprocess_fn) saliency_map = rise_method.generate_saliency_map( @@ -39,3 +77,7 @@ def test_run(self, explain_target_indices, fxt_data_root: Path): assert saliency_map.dtype == np.uint8 assert saliency_map.shape == (1, 20, 224, 224) assert (saliency_map >= 0).all() and (saliency_map <= 255).all() + + actual_sal_vals = saliency_map[0][0][0, :10].astype(np.int16) + ref_sal_vals = np.array([246, 241, 236, 231, 226, 221, 216, 211, 205, 197], dtype=np.uint8) + assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1) From c497c01dbe1386700aa1e11a834e5792907f1904 
Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Thu, 1 Aug 2024 16:45:44 +0200 Subject: [PATCH 2/2] Align RISE with AISE (#50) * Preset for RISE + renaming * align aise and rise * rise to output dict * test quality preset * comments * preset num_cells --- README.md | 4 +- openvino_xai/explainer/explainer.py | 8 +- openvino_xai/explainer/explanation.py | 14 +-- openvino_xai/explainer/utils.py | 2 +- openvino_xai/methods/black_box/aise.py | 47 +++++----- openvino_xai/methods/black_box/base.py | 4 +- openvino_xai/methods/black_box/rise.py | 78 ++++++++++------ .../explanation/test_explanation_utils.py | 28 +++--- .../black_box/test_black_box_method.py | 89 ++++++++++++++++--- 9 files changed, 180 insertions(+), 94 deletions(-) diff --git a/README.md b/README.md index a7e1c3c7..db0b9fd3 100644 --- a/README.md +++ b/README.md @@ -72,8 +72,8 @@ At the moment, *Image Classification* and *Object Detection* tasks are supported | Computer Vision | Image Classification | White-Box | ReciproCAM | [arxiv](https://arxiv.org/abs/2209.14074) / [src](openvino_xai/methods/white_box/recipro_cam.py) | | | | | VITReciproCAM | [arxiv](https://arxiv.org/abs/2310.02588) / [src](openvino_xai/methods/white_box/recipro_cam.py) | | | | | ActivationMap | experimental / [src](openvino_xai/methods/white_box/activation_map.py) | -| | | Black-Box | RISE | [arxiv](https://arxiv.org/abs/1806.07421v3) / [src](openvino_xai/methods/black_box/rise.py) | -| | | | AISE | [src](openvino_xai/methods/black_box/aise.py) | +| | | Black-Box | AISE | [src](openvino_xai/methods/black_box/aise.py) | +| | | | RISE | [arxiv](https://arxiv.org/abs/1806.07421v3) / [src](openvino_xai/methods/black_box/rise.py) | | | Object Detection | White-Box | ClassProbabilityMap | experimental / [src](openvino_xai/methods/white_box/det_class_probability_map.py) | ### Supported explainable models diff --git a/openvino_xai/explainer/explainer.py b/openvino_xai/explainer/explainer.py index a342de65..4b7cfd22 100644 --- 
a/openvino_xai/explainer/explainer.py +++ b/openvino_xai/explainer/explainer.py @@ -19,7 +19,7 @@ from openvino_xai.explainer.utils import ( convert_targets_to_numpy, explains_all, - get_explain_target_indices, + get_target_indices, ) from openvino_xai.explainer.visualizer import Visualizer from openvino_xai.methods.base import MethodBase @@ -205,16 +205,16 @@ def explain( """ targets = convert_targets_to_numpy(targets) - explain_target_indices = None + target_indices = None if isinstance(self.method, BlackBoxXAIMethod) and not explains_all(targets): - explain_target_indices = get_explain_target_indices( + target_indices = get_target_indices( targets, label_names, ) saliency_map = self.method.generate_saliency_map( data, - explain_target_indices=explain_target_indices, # type: ignore + target_indices=target_indices, # type: ignore **kwargs, ) diff --git a/openvino_xai/explainer/explanation.py b/openvino_xai/explainer/explanation.py index bbd411d6..bd63b676 100644 --- a/openvino_xai/explainer/explanation.py +++ b/openvino_xai/explainer/explanation.py @@ -12,7 +12,7 @@ from openvino_xai.explainer.utils import ( convert_targets_to_numpy, explains_all, - get_explain_target_indices, + get_target_indices, ) @@ -111,11 +111,11 @@ def _select_target_saliency_maps( label_names: List[str] | None = None, ) -> Dict[int | str, np.ndarray]: assert self.layout == Layout.MULTIPLE_MAPS_PER_IMAGE_GRAY - explain_target_indices = self._select_target_indices( + target_indices = self._select_target_indices( targets=targets, label_names=label_names, ) - saliency_maps_selected = {i: self._saliency_map[i] for i in explain_target_indices} + saliency_maps_selected = {i: self._saliency_map[i] for i in target_indices} return saliency_maps_selected def _select_target_indices( @@ -123,14 +123,14 @@ def _select_target_indices( targets: np.ndarray | List[int | str], label_names: List[str] | None = None, ) -> List[int] | np.ndarray: - explain_target_indices = get_explain_target_indices(targets, 
label_names) + target_indices = get_target_indices(targets, label_names) if self.total_num_targets is not None: - if not all(0 <= target_index <= (self.total_num_targets - 1) for target_index in explain_target_indices): + if not all(0 <= target_index <= (self.total_num_targets - 1) for target_index in target_indices): raise ValueError(f"All targets indices have to be in range 0..{self.total_num_targets - 1}.") else: - if not all(target_index in self.saliency_map for target_index in explain_target_indices): + if not all(target_index in self.saliency_map for target_index in target_indices): raise ValueError("Provided targer index {targer_index} is not available among saliency maps.") - return explain_target_indices + return target_indices def save(self, dir_path: Path | str, name: str | None = None) -> None: """Dumps saliency map.""" diff --git a/openvino_xai/explainer/utils.py b/openvino_xai/explainer/utils.py index b133f426..6251c6b3 100644 --- a/openvino_xai/explainer/utils.py +++ b/openvino_xai/explainer/utils.py @@ -17,7 +17,7 @@ def convert_targets_to_numpy(targets): return np.atleast_1d(targets) -def get_explain_target_indices( +def get_target_indices( targets: np.ndarray | List[int | str], label_names: List[str] | None = None, ) -> List[int]: diff --git a/openvino_xai/methods/black_box/aise.py b/openvino_xai/methods/black_box/aise.py index bd89f0c4..722e9547 100644 --- a/openvino_xai/methods/black_box/aise.py +++ b/openvino_xai/methods/black_box/aise.py @@ -22,7 +22,7 @@ class AISE(BlackBoxXAIMethod): """AISE explains classification models in black-box mode using - AISE: Adaptive Input Sampling for Explanation of Black-box Models. + AISE: Adaptive Input Sampling for Explanation of Black-box Models (TODO (negvet): add link to the paper.) :param model: OpenVINO model. 
@@ -67,7 +67,7 @@ def __init__( def generate_saliency_map( # type: ignore self, data: np.ndarray, - explain_target_indices: List[int] | None, + target_indices: List[int] | None, preset: Preset = Preset.BALANCE, num_iterations_per_kernel: int | None = None, kernel_widths: List[float] | np.ndarray | None = None, @@ -81,9 +81,9 @@ def generate_saliency_map( # type: ignore :param data: Input image. :type data: np.ndarray - :param explain_target_indices: List of target indices to explain. - :type explain_target_indices: List[int] - :param preset: Speed-Quality preset, defines predefined configurations that manage speed-quality tradeoff. + :param target_indices: List of target indices to explain. + :type target_indices: List[int] + :param preset: Speed-Quality preset, defines predefined configurations that manage the speed-quality tradeoff. :type preset: Preset :param num_iterations_per_kernel: Number of iterations per kernel, defines compute budget. :type num_iterations_per_kernel: int @@ -98,13 +98,17 @@ def generate_saliency_map( # type: ignore """ self.data_preprocessed = self.preprocess_fn(data) - if explain_target_indices is None: + if target_indices is None: num_classes = self.get_num_classes(self.data_preprocessed) if num_classes > 10: logger.info(f"num_classes = {num_classes}, which might take significant time to process.") - explain_target_indices = list(range(num_classes)) + target_indices = list(range(num_classes)) - self._preset_parameters(preset, num_iterations_per_kernel, kernel_widths) + self.num_iterations_per_kernel, self.kernel_widths = self._preset_parameters( + preset, + num_iterations_per_kernel, + kernel_widths, + ) self.solver_epsilon = solver_epsilon self.locally_biased = locally_biased @@ -113,7 +117,7 @@ def generate_saliency_map( # type: ignore self._mask_generator = GaussianPerturbationMask(self.input_size) saliency_maps = {} - for target in explain_target_indices: + for target in target_indices: self.kernel_params_hist = 
collections.defaultdict(list) self.pred_score_hist = collections.defaultdict(list) @@ -124,28 +128,29 @@ def generate_saliency_map( # type: ignore saliency_maps[target] = saliency_map_per_target return saliency_maps + @staticmethod def _preset_parameters( - self, preset: Preset, num_iterations_per_kernel: int | None, kernel_widths: List[float] | np.ndarray | None, - ) -> None: + ) -> Tuple[int, np.ndarray]: if preset == Preset.SPEED: - self.num_iterations_per_kernel = 25 - self.kernel_widths = np.linspace(0.1, 0.25, 3) + iterations = 25 + widths = np.linspace(0.1, 0.25, 3) elif preset == Preset.BALANCE: - self.num_iterations_per_kernel = 50 - self.kernel_widths = np.linspace(0.1, 0.25, 3) + iterations = 50 + widths = np.linspace(0.1, 0.25, 3) elif preset == Preset.QUALITY: - self.num_iterations_per_kernel = 85 - self.kernel_widths = np.linspace(0.075, 0.25, 4) + iterations = 85 + widths = np.linspace(0.075, 0.25, 4) else: raise ValueError(f"Preset {preset} is not supported.") - if num_iterations_per_kernel is not None: - self.num_iterations_per_kernel = num_iterations_per_kernel - if kernel_widths is not None: - self.kernel_widths = kernel_widths + if num_iterations_per_kernel is None: + num_iterations_per_kernel = iterations + if kernel_widths is None: + kernel_widths = widths + return num_iterations_per_kernel, kernel_widths def _run_synchronous_explanation(self) -> np.ndarray: for kernel_width in self.kernel_widths: diff --git a/openvino_xai/methods/black_box/base.py b/openvino_xai/methods/black_box/base.py index 7e8c888b..8d26dbba 100644 --- a/openvino_xai/methods/black_box/base.py +++ b/openvino_xai/methods/black_box/base.py @@ -28,9 +28,9 @@ class Preset(Enum): Enum representing the different presets: Contains the following values: - SPEED - Prioritize getting results faster. + SPEED - Prioritizes getting results faster. BALANCE - Balance between speed and quality. - QUALITY - Prioritize high quality results. 
+ QUALITY - Prioritizes getting high quality results. """ SPEED = "speed" diff --git a/openvino_xai/methods/black_box/rise.py b/openvino_xai/methods/black_box/rise.py index a8b04eb2..afb3beaa 100644 --- a/openvino_xai/methods/black_box/rise.py +++ b/openvino_xai/methods/black_box/rise.py @@ -1,19 +1,21 @@ # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Mapping, Tuple +from typing import Callable, Dict, List, Mapping, Tuple import cv2 import numpy as np import openvino.runtime as ov from tqdm import tqdm -from openvino_xai.common.utils import IdentityPreprocessFN, scaling -from openvino_xai.methods.black_box.base import BlackBoxXAIMethod +from openvino_xai.common.utils import IdentityPreprocessFN, is_bhwc_layout, scaling +from openvino_xai.methods.black_box.base import BlackBoxXAIMethod, Preset class RISE(BlackBoxXAIMethod): - """RISE explains classification models in black-box mode using RISE (https://arxiv.org/abs/1806.07421). + """RISE explains classification models in black-box mode using + 'RISE: Randomized Input Sampling for Explanation of Black-box Models' paper + (https://arxiv.org/abs/1806.07421). :param model: OpenVINO model. :type model: ov.Model @@ -45,20 +47,23 @@ def __init__( def generate_saliency_map( self, data: np.ndarray, - explain_target_indices: List[int] | None = None, - num_masks: int = 5000, + target_indices: List[int] | None = None, + preset: Preset = Preset.BALANCE, + num_masks: int | None = None, num_cells: int = 8, prob: float = 0.5, seed: int = 0, scale_output: bool = True, - ) -> np.ndarray: + ) -> np.ndarray | Dict[int, np.ndarray]: """ Generates inference result of the RISE algorithm. :param data: Input image. :type data: np.ndarray - :param explain_target_indices: List of target indices to explain. - :type explain_target_indices: List[int] + :param target_indices: List of target indices to explain. 
+ :type target_indices: List[int] + :param preset: Speed-Quality preset, defines predefined configurations that manage speed-quality tradeoff. + :type preset: Preset :param num_masks: Number of generated masks to aggregate. :type num_masks: int :param num_cells: Number of cells for low-dimensional RISE @@ -74,20 +79,45 @@ def generate_saliency_map( """ data_preprocessed = self.preprocess_fn(data) + num_masks = self._preset_parameters(preset, num_masks) + saliency_maps = self._run_synchronous_explanation( data_preprocessed, - explain_target_indices, + target_indices, num_masks, num_cells, prob, seed, ) - if scale_output: - saliency_maps = scaling(saliency_maps) - saliency_maps = np.expand_dims(saliency_maps, axis=0) + if isinstance(saliency_maps, np.ndarray): + if scale_output: + saliency_maps = scaling(saliency_maps) + saliency_maps = np.expand_dims(saliency_maps, axis=0) + elif isinstance(saliency_maps, dict): + for target in saliency_maps: + if scale_output: + saliency_maps[target] = scaling(saliency_maps[target]) return saliency_maps + @staticmethod + def _preset_parameters( + preset: Preset, + num_masks: int | None = None, + ) -> int: + # TODO (negvet): preset num_cells + if num_masks is not None: + return num_masks + + if preset == Preset.SPEED: + return 2000 + elif preset == Preset.BALANCE: + return 5000 + elif preset == Preset.QUALITY: + return 8000 + else: + raise ValueError(f"Preset {preset} is not supported.") + def _run_synchronous_explanation( self, data_preprocessed: np.ndarray, @@ -97,8 +127,6 @@ def _run_synchronous_explanation( prob: float, seed: int, ) -> np.ndarray: - from openvino_xai.common.utils import is_bhwc_layout - input_size = data_preprocessed.shape[1:3] if is_bhwc_layout(data_preprocessed) else data_preprocessed.shape[2:4] num_classes = self.get_num_classes(data_preprocessed) @@ -110,7 +138,7 @@ def _run_synchronous_explanation( rand_generator = np.random.default_rng(seed=seed) - sal_maps = np.zeros((num_targets, input_size[0], 
input_size[1])) + saliency_maps = np.zeros((num_targets, input_size[0], input_size[1])) for _ in tqdm(range(0, num_masks), desc="Explaining in synchronous mode"): mask = self._generate_mask(input_size, num_cells, prob, rand_generator) # Add channel dimensions for masks @@ -123,11 +151,11 @@ def _run_synchronous_explanation( raw_scores = self.postprocess_fn(forward_output) sal = self._get_scored_mask(raw_scores, mask, target_classes) - sal_maps += sal + saliency_maps += sal if target_classes is not None: - sal_maps = self._reconstruct_sparce_saliency_map(sal_maps, num_classes, input_size, target_classes) - return sal_maps + saliency_maps = self._reformat_as_dict(saliency_maps, target_classes) + return saliency_maps @staticmethod def _get_scored_mask(raw_scores: np.ndarray, mask: np.ndarray, target_classes: List[int] | None) -> np.ndarray: @@ -137,15 +165,11 @@ def _get_scored_mask(raw_scores: np.ndarray, mask: np.ndarray, target_classes: L return raw_scores.reshape(-1, 1, 1) * mask @staticmethod - def _reconstruct_sparce_saliency_map( - sal_maps: np.ndarray, num_classes: int, input_size, target_classes: List[int] | None + def _reformat_as_dict( + saliency_maps: np.ndarray, + target_classes: List[int] | None, ) -> np.ndarray: - # TODO: see if np.put() or other alternatives works faster (requires flatten array) - sal_maps_tmp = sal_maps - sal_maps = np.zeros((num_classes, input_size[0], input_size[1])) - for i, sal in enumerate(sal_maps_tmp): - sal_maps[target_classes[i]] = sal - return sal_maps + return {target_class: saliency_map for target_class, saliency_map in zip(target_classes, saliency_maps)} @staticmethod def _generate_mask(input_size: Tuple[int, int], num_cells: int, prob: float, rand_generator) -> np.ndarray: diff --git a/tests/unit/explanation/test_explanation_utils.py b/tests/unit/explanation/test_explanation_utils.py index 71eed6d6..6460f98c 100644 --- a/tests/unit/explanation/test_explanation_utils.py +++ 
b/tests/unit/explanation/test_explanation_utils.py @@ -5,11 +5,7 @@ import pytest from openvino_xai.common.utils import is_bhwc_layout -from openvino_xai.explainer.utils import ( - ActivationType, - get_explain_target_indices, - get_score, -) +from openvino_xai.explainer.utils import ActivationType, get_score, get_target_indices VOC_NAMES = [ "aeroplane", @@ -38,25 +34,25 @@ LABELS_STR = ["bicycle", "bottle"] -def test_get_explain_target_indices_int(): - explain_target_indices = get_explain_target_indices(LABELS_INT, VOC_NAMES) - assert np.all(explain_target_indices == LABELS_INT) +def test_get_target_indices_int(): + target_indices = get_target_indices(LABELS_INT, VOC_NAMES) + assert np.all(target_indices == LABELS_INT) -def test_get_explain_target_indices_int_wo_names(): - explain_target_indices = get_explain_target_indices(LABELS_INT) - assert np.all(explain_target_indices == LABELS_INT) +def test_get_target_indices_int_wo_names(): + target_indices = get_target_indices(LABELS_INT) + assert np.all(target_indices == LABELS_INT) -def test_get_explain_target_indices_str(): - explain_target_indices = get_explain_target_indices(LABELS_STR, VOC_NAMES) - assert explain_target_indices == [1, 4] +def test_get_target_indices_str(): + target_indices = get_target_indices(LABELS_STR, VOC_NAMES) + assert target_indices == [1, 4] -def test_get_explain_target_indices_str_spelling(): +def test_get_target_indices_str_spelling(): LABELS_STR[0] = "bicycle_" with pytest.raises(Exception) as exc_info: - _ = get_explain_target_indices(LABELS_STR, VOC_NAMES) + _ = get_target_indices(LABELS_STR, VOC_NAMES) assert str(exc_info.value) == "No all label names found in label_names. Check spelling." 
diff --git a/tests/unit/methods/black_box/test_black_box_method.py b/tests/unit/methods/black_box/test_black_box_method.py index db8e0f22..32c5504e 100644 --- a/tests/unit/methods/black_box/test_black_box_method.py +++ b/tests/unit/methods/black_box/test_black_box_method.py @@ -1,6 +1,7 @@ # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import time from pathlib import Path import cv2 @@ -30,16 +31,23 @@ def get_model(self, fxt_data_root): model_path = fxt_data_root / "otx_models" / (DEFAULT_CLS_MODEL + ".xml") return ov.Core().read_model(model_path) + def _generate_with_preset(self, method, preset): + _ = method.generate_saliency_map( + data=self.image, + target_indices=[1], + preset=preset, + ) + class TestAISE(InputSampling): - @pytest.mark.parametrize("explain_target_indices", [[0], [0, 1]]) - def test_run(self, explain_target_indices, fxt_data_root: Path): + @pytest.mark.parametrize("target_indices", [[0], [0, 1]]) + def test_run(self, target_indices, fxt_data_root: Path): model = self.get_model(fxt_data_root) aise_method = AISE(model, self.postprocess_fn, self.preprocess_fn) saliency_map = aise_method.generate_saliency_map( data=self.image, - explain_target_indices=explain_target_indices, + target_indices=target_indices, preset=Preset.SPEED, num_iterations_per_kernel=10, kernel_widths=[0.1], @@ -48,8 +56,8 @@ def test_run(self, explain_target_indices, fxt_data_root: Path): assert aise_method.kernel_widths == [0.1] assert isinstance(saliency_map, dict) - assert len(saliency_map) == len(explain_target_indices) - for target in explain_target_indices: + assert len(saliency_map) == len(target_indices) + for target in target_indices: assert target in saliency_map ref_target = 0 @@ -61,23 +69,76 @@ def test_run(self, explain_target_indices, fxt_data_root: Path): ref_sal_vals = np.array([68, 69, 69, 69, 70, 70, 71, 71, 72, 72], dtype=np.uint8) assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1) + def test_preset(self, 
fxt_data_root: Path): + model = self.get_model(fxt_data_root) + method = AISE(model, self.postprocess_fn, self.preprocess_fn) + + tic = time.time() + self._generate_with_preset(method, Preset.SPEED) + toc = time.time() + time_speed = toc - tic + + tic = time.time() + self._generate_with_preset(method, Preset.BALANCE) + toc = time.time() + time_balance = toc - tic + + tic = time.time() + self._generate_with_preset(method, Preset.QUALITY) + toc = time.time() + time_quality = toc - tic + + assert time_speed < time_balance < time_quality + class TestRISE(InputSampling): - @pytest.mark.parametrize("explain_target_indices", [[0], None]) - def test_run(self, explain_target_indices, fxt_data_root: Path): + @pytest.mark.parametrize("target_indices", [[0], None]) + def test_run(self, target_indices, fxt_data_root: Path): model = self.get_model(fxt_data_root) rise_method = RISE(model, self.postprocess_fn, self.preprocess_fn) saliency_map = rise_method.generate_saliency_map( self.image, - explain_target_indices, + target_indices, num_masks=5, ) - assert saliency_map.dtype == np.uint8 - assert saliency_map.shape == (1, 20, 224, 224) - assert (saliency_map >= 0).all() and (saliency_map <= 255).all() + if target_indices == [0]: + assert isinstance(saliency_map, dict) + assert saliency_map[0].dtype == np.uint8 + assert saliency_map[0].shape == (224, 224) + assert (saliency_map[0] >= 0).all() and (saliency_map[0] <= 255).all() + + actual_sal_vals = saliency_map[0][0, :10].astype(np.int16) + ref_sal_vals = np.array([246, 241, 236, 231, 226, 221, 216, 211, 205, 197], dtype=np.uint8) + assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1) + else: + isinstance(saliency_map, np.ndarray) + assert saliency_map.dtype == np.uint8 + assert saliency_map.shape == (1, 20, 224, 224) + assert (saliency_map >= 0).all() and (saliency_map <= 255).all() + + actual_sal_vals = saliency_map[0][0][0, :10].astype(np.int16) + ref_sal_vals = np.array([246, 241, 236, 231, 226, 221, 216, 211, 205, 197], 
dtype=np.uint8) + assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1) + + def test_preset(self, fxt_data_root: Path): + model = self.get_model(fxt_data_root) + method = RISE(model, self.postprocess_fn, self.preprocess_fn) - actual_sal_vals = saliency_map[0][0][0, :10].astype(np.int16) - ref_sal_vals = np.array([246, 241, 236, 231, 226, 221, 216, 211, 205, 197], dtype=np.uint8) - assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1) + tic = time.time() + self._generate_with_preset(method, Preset.SPEED) + toc = time.time() + time_speed = toc - tic + + tic = time.time() + self._generate_with_preset(method, Preset.BALANCE) + toc = time.time() + time_balance = toc - tic + + tic = time.time() + self._generate_with_preset(method, Preset.QUALITY) + toc = time.time() + time_quality = toc - tic + + assert time_speed < time_balance < time_quality