diff --git a/CHANGELOG.md b/CHANGELOG.md
index e623fb61..c0ead68e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,7 +21,8 @@
 * Upgrade OpenVINO to 2024.3.0 by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/52
 * Add saliency map visualization with explanation.plot() by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/53
 * Enable flexible naming for saved saliency maps and include confidence scores by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/51
-* Add [Pointing Game](https://arxiv.org/abs/1608.00507) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/54
+* Add [Pointing Game](https://link.springer.com/article/10.1007/s11263-017-1059-x) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/54
+* Add [Insertion Deletion AUC](https://arxiv.org/abs/1806.07421) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/56

 ### Known Issues

diff --git a/openvino_xai/explainer/utils.py b/openvino_xai/explainer/utils.py
index 6251c6b3..e4ec69c6 100644
--- a/openvino_xai/explainer/utils.py
+++ b/openvino_xai/explainer/utils.py
@@ -119,22 +119,29 @@ def get_preprocess_fn(
     )


-def postprocess_fn(x: Mapping, logit_name="logits") -> np.ndarray:
-    """Postprocess function."""
-    return x.get(logit_name, x[0])  # Models from OVC has no output names at times
-
-
-def get_postprocess_fn(logit_name="logits") -> Callable[[], np.ndarray]:
-    """Returns partially initialized postprocess_fn."""
-    return partial(postprocess_fn, logit_name=logit_name)
-
-
 class ActivationType(Enum):
     SIGMOID = "sigmoid"
     SOFTMAX = "softmax"
     NONE = "none"


+def postprocess_fn(x: Mapping, logit_name="logits", activation: ActivationType = ActivationType.NONE) -> np.ndarray:
+    """Postprocess function."""
+    x = x.get(logit_name, x[0])  # Models from OVC have no output names at times
+    if activation == ActivationType.SOFTMAX:
+        return softmax(x)
+    if activation == ActivationType.SIGMOID:
+        return sigmoid(x)
+    return x
+
+
+def get_postprocess_fn(
+    logit_name="logits", activation: ActivationType = ActivationType.NONE
+) -> Callable[[], np.ndarray]:
+    """Returns partially initialized postprocess_fn."""
+    return partial(postprocess_fn, logit_name=logit_name, activation=activation)
+
+
 def get_score(x: np.ndarray, index: int, activation: ActivationType = ActivationType.NONE):
     """Returns activated score at index."""
     if activation == ActivationType.SOFTMAX:
diff --git a/openvino_xai/methods/base.py b/openvino_xai/methods/base.py
index 6343c14e..2673491d 100644
--- a/openvino_xai/methods/base.py
+++ b/openvino_xai/methods/base.py
@@ -5,7 +5,7 @@
 from typing import Callable, Dict, Mapping

 import numpy as np
-import openvino.runtime as ov
+import openvino as ov

 from openvino_xai.common.utils import IdentityPreprocessFN

@@ -25,7 +25,7 @@ def __init__(
         self._device_name = device_name

     @property
-    def model_compiled(self) -> ov.ie_api.CompiledModel | None:
+    def model_compiled(self) -> ov.CompiledModel | None:
         return self._model_compiled

     @abstractmethod
diff --git a/openvino_xai/metrics/base.py b/openvino_xai/metrics/base.py
new file mode 100644
index 00000000..a26a8cbb
--- /dev/null
+++ b/openvino_xai/metrics/base.py
@@ -0,0 +1,36 @@
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List
+
+import numpy as np
+import openvino as ov
+
+from openvino_xai.common.utils import IdentityPreprocessFN
+from openvino_xai.explainer.explanation import Explanation
+
+
+class BaseMetric(ABC):
+    """Base class for XAI quality metric."""
+
+    def __init__(
+        self,
+        model_compiled: ov.CompiledModel = None,
+        preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
+        postprocess_fn: Callable[[np.ndarray], np.ndarray] = None,
+    ):
+        # Pass model_predict to class initialization directly?
+        self.model_compiled = model_compiled
+        self.preprocess_fn = preprocess_fn
+        self.postprocess_fn = postprocess_fn
+
+    def model_predict(self, input: np.ndarray) -> np.ndarray:
+        logits = self.model_compiled([self.preprocess_fn(input)])
+        logits = self.postprocess_fn(logits)[0]
+        return logits
+
+    @abstractmethod
+    def __call__(self, saliency_map, *args: Any, **kwargs: Any) -> Dict[str, float]:
+        """Calculate the metric for a single saliency map."""
+
+    @abstractmethod
+    def evaluate(self, explanations: List[Explanation], *args: Any, **kwargs: Any) -> Dict[str, float]:
+        """Evaluate the quality of saliency maps over a list of images."""
diff --git a/openvino_xai/metrics/insertion_deletion_auc.py b/openvino_xai/metrics/insertion_deletion_auc.py
new file mode 100644
index 00000000..2d69b2e1
--- /dev/null
+++ b/openvino_xai/metrics/insertion_deletion_auc.py
@@ -0,0 +1,109 @@
+from typing import Any, Dict, List, Tuple
+
+import numpy as np
+
+from openvino_xai.explainer.explanation import Explanation, Layout
+from openvino_xai.metrics.base import BaseMetric
+
+
+def AUC(arr: np.array) -> float:
+    """
+    Returns normalized Area Under Curve of the array.
+    """
+    return np.abs((arr.sum() - arr[0] / 2 - arr[-1] / 2) / (arr.shape[0] - 1))
+
+
+class InsertionDeletionAUC(BaseMetric):
+    """
+    Implementation of the Insertion and Deletion AUC by Petsiuk et al. 2018.
+
+    References:
+        Petsiuk, Vitali, Abir Das, and Kate Saenko. "Rise: Randomized input sampling
+        for explanation of black-box models." arXiv preprint arXiv:1806.07421 (2018).
+    """
+
+    @staticmethod
+    def step_image_insertion_deletion(
+        num_pixels: int, sorted_indices: Tuple[np.ndarray, np.ndarray], input_image: np.ndarray
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Return the insertion/deletion images based on the number of pixels to add/delete at this step.
+        """
+        # Starting values: an empty image for insertion, the full input image for deletion
+        image_insertion = np.full_like(input_image, 0)
+        image_deletion = input_image.copy()
+
+        x_indices = sorted_indices[0][:num_pixels]
+        y_indices = sorted_indices[1][:num_pixels]
+
+        # Insert the image on the places of the important pixels
+        image_insertion[x_indices, y_indices] = input_image[x_indices, y_indices]
+        # Remove image pixels on the places of the important pixels
+        image_deletion[x_indices, y_indices] = 0
+        return image_insertion, image_deletion
+
+    def __call__(
+        self, saliency_map: np.ndarray, class_idx: int, input_image: np.ndarray, steps: int = 100, **kwargs: Any
+    ) -> Dict[str, float]:
+        """
+        Calculate the Insertion and Deletion AUC metrics for a single saliency map of a single class.
+
+        Parameters:
+        :param saliency_map: Importance scores for each pixel (H, W).
+        :type saliency_map: np.ndarray
+        :param class_idx: The class of the saliency map to evaluate.
+        :type class_idx: int
+        :param input_image: The input image to the model (H, W, C).
+        :type input_image: np.ndarray
+        :param steps: Number of steps for inserting/deleting pixels.
+        :type steps: int
+
+        Returns:
+        :return: A dictionary containing the insertion and deletion AUC scores.
+        :rtype: Dict[str, float]
+        """
+        # Sort pixels by descending importance to find the most important pixels
+        sorted_indices = np.argsort(-saliency_map.flatten())
+        sorted_indices = np.unravel_index(sorted_indices, saliency_map.shape)
+
+        insertion_scores, deletion_scores = [], []
+        for i in range(steps + 1):
+            num_pixels = int(i * len(sorted_indices[0]) / steps)
+            step_image_insertion, step_image_deletion = self.step_image_insertion_deletion(
+                num_pixels, sorted_indices, input_image
+            )
+            # Predict on masked image
+            insertion_scores.append(self.model_predict(step_image_insertion)[class_idx])
+            deletion_scores.append(self.model_predict(step_image_deletion)[class_idx])
+        insertion = AUC(np.array(insertion_scores))
+        deletion = AUC(np.array(deletion_scores))
+        return {"insertion": insertion, "deletion": deletion}
+
+    def evaluate(
+        self, explanations: List[Explanation], input_images: List[np.ndarray], steps: int, **kwargs: Any
+    ) -> Dict[str, float]:
+        """
+        Evaluate the insertion and deletion AUC over a list of images and their saliency maps.
+
+        :param explanations: List of explanation objects containing saliency maps.
+        :type explanations: List[Explanation]
+        :param input_images: List of input images as numpy arrays.
+        :type input_images: List[np.ndarray]
+        :param steps: Number of steps for the insertion and deletion process.
+        :type steps: int
+
+        :return: A Dict containing the mean insertion AUC, mean deletion AUC, and their difference (delta) as values.
+        :rtype: Dict[str, float]
+        """
+        for explanation in explanations:
+            assert explanation.layout in [Layout.MULTIPLE_MAPS_PER_IMAGE_GRAY, Layout.MULTIPLE_MAPS_PER_IMAGE_COLOR]
+
+        results = []
+        for input_image, explanation in zip(input_images, explanations):
+            for class_idx, saliency_map in explanation.saliency_map.items():
+                metric_dict = self(saliency_map, int(class_idx), input_image, steps)
+                results.append([metric_dict["insertion"], metric_dict["deletion"]])
+
+        insertion, deletion = np.mean(np.array(results), axis=0)
+        delta = insertion - deletion
+        return {"insertion": insertion, "deletion": deletion, "delta": delta}
diff --git a/openvino_xai/metrics/pointing_game.py b/openvino_xai/metrics/pointing_game.py
index 855e107c..80b073c3 100644
--- a/openvino_xai/metrics/pointing_game.py
+++ b/openvino_xai/metrics/pointing_game.py
@@ -1,15 +1,16 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-from typing import Dict, List, Tuple
+from typing import Any, Dict, List, Tuple

 import numpy as np

 from openvino_xai.common.utils import logger
 from openvino_xai.explainer.explanation import Explanation
+from openvino_xai.metrics.base import BaseMetric


-class PointingGame:
+class PointingGame(BaseMetric):
     """
     Implementation of the Pointing Game by Zhang et al., 2018.

@@ -29,18 +30,21 @@ class PointingGame:
     """

     @staticmethod
-    def pointing_game(saliency_map: np.ndarray, image_gt_bboxes: List[Tuple[int, int, int, int]]) -> bool:
+    def __call__(saliency_map: np.ndarray, image_gt_bboxes: List[Tuple[int, int, int, int]]) -> Dict[str, float]:
        """
-        Implements the Pointing Game metric using a saliency map and bounding boxes of the same image and class.
-        Returns a boolean indicating if any of the most salient points fall within the ground truth bounding boxes.
+        Calculate the Pointing Game metric for a single saliency map of a single class.
+
+        This implementation uses a saliency map and bounding boxes of the same image and class.
+        Returns a dictionary with the result of the Pointing Game metric:
+        1.0 if any of the most salient points fall within the ground truth bounding boxes, 0.0 otherwise.

         :param saliency_map: A 2D numpy array representing the saliency map for the image.
         :type saliency_map: np.ndarray
         :param image_gt_bboxes: A list of tuples (x, y, w, h) representing the bounding boxes of the ground truth objects.
         :type image_gt_bboxes: List[Tuple[int, int, int, int]]
-        :return: True if any of the most salient points fall within any of the ground truth bounding boxes, False otherwise.
-        :rtype: bool
+        :return: A dictionary with the result of the Pointing Game metric.
+        :rtype: Dict[str, float]
         """
         # TODO: Optimize calculation by generating a mask from annotation and finding the intersection
         # Find the most salient points in the saliency map
@@ -51,12 +55,15 @@ def pointing_game(saliency_map: np.ndarray, image_gt_bboxes: List[Tuple[int, int
         for max_point_y, max_point_x in max_indices:
             # Check if this point is within the ground truth bounding box
             if x <= max_point_x <= x + w and y <= max_point_y <= y + h:
-                return True
-        return False
+                return {"pointing_game": 1.0}
+        return {"pointing_game": 0.0}

     def evaluate(
-        self, explanations: List[Explanation], gt_bboxes: List[Dict[str, List[Tuple[int, int, int, int]]]]
-    ) -> float:
+        self,
+        explanations: List[Explanation],
+        gt_bboxes: List[Dict[str, List[Tuple[int, int, int, int]]]],
+        **kwargs: Any,
+    ) -> Dict[str, float]:
         """
         Evaluates the Pointing Game metric over a set of images.
         Skips saliency maps if the gt bboxes for this class are absent.
@@ -65,15 +72,15 @@ def evaluate(
         :param gt_bboxes: A list of dictionaries {label_name: lists of bounding boxes} for each image.
         :type gt_bboxes: List[Dict[str, List[Tuple[int, int, int, int]]]]

-        :return: Pointing game score over a list of images
-        :rtype: float
+        :return: Dict with "pointing_game" as a key and the score over a list of images as a value.
+        :rtype: Dict[str, float]
         """
         assert len(explanations) == len(
             gt_bboxes
         ), "Number of explanations and ground truth bounding boxes must match and equal to number of images."

-        hits = 0
+        hits = 0.0
         num_sal_maps = 0
         for explanation, image_gt_bboxes in zip(explanations, gt_bboxes):
             label_names = explanation.label_names
@@ -90,7 +97,8 @@ def evaluate(
                 continue

             class_gt_bboxes = image_gt_bboxes[label_name]
-            hits += self.pointing_game(class_sal_map, class_gt_bboxes)
+            hits += self(class_sal_map, class_gt_bboxes)["pointing_game"]
             num_sal_maps += 1

-        return hits / num_sal_maps if num_sal_maps > 0 else 0.0
+        score = hits / num_sal_maps if num_sal_maps > 0 else 0.0
+        return {"pointing_game": score}
diff --git a/tests/regression/test_regression.py b/tests/regression/test_regression.py
index eef04816..072ff965 100644
--- a/tests/regression/test_regression.py
+++ b/tests/regression/test_regression.py
@@ -11,7 +11,12 @@
 from openvino_xai import Task
 from openvino_xai.common.utils import retrieve_otx_model
 from openvino_xai.explainer.explainer import Explainer, ExplainMode
-from openvino_xai.explainer.utils import get_preprocess_fn
+from openvino_xai.explainer.utils import (
+    ActivationType,
+    get_postprocess_fn,
+    get_preprocess_fn,
+)
+from openvino_xai.metrics.insertion_deletion_auc import InsertionDeletionAUC
 from openvino_xai.metrics.pointing_game import PointingGame
 from tests.unit.explanation.test_explanation_utils import VOC_NAMES

@@ -23,7 +28,6 @@ def load_gt_bboxes(json_coco_path: str) -> List[Dict[str, List[Tuple[int, int, int, int]]]]:
     """
     Loads ground truth bounding boxes from a COCO format JSON file.
-
     Returns a list of dictionaries, where each dictionary corresponds to an image.
     The key is the label name and the value is a list of bounding boxes for certain image.
     """
@@ -61,12 +65,16 @@ class TestDummyRegression:
         hwc_to_chw=True,
     )

+    postprocess_fn = get_postprocess_fn(activation=ActivationType.SIGMOID)
+
     @pytest.fixture(autouse=True)
     def setup(self, fxt_data_root):
         data_dir = fxt_data_root
         retrieve_otx_model(data_dir, MODEL_NAME)
         model_path = data_dir / "otx_models" / (MODEL_NAME + ".xml")
         model = ov.Core().read_model(model_path)
+        compiled_model = ov.Core().compile_model(model, "CPU")
+        self.auc = InsertionDeletionAUC(compiled_model, self.preprocess_fn, self.postprocess_fn)

         self.explainer = Explainer(
             model=model,
@@ -76,28 +84,41 @@ def setup(self, fxt_data_root):
         )

     def test_explainer_image(self):
-        explanation = self.explainer(
-            self.image,
-            targets=["person"],
-            label_names=VOC_NAMES,
-            colormap=False,
-        )
+        explanation = self.explainer(self.image, targets=["person"], label_names=VOC_NAMES, colormap=False)
         assert len(explanation.saliency_map) == 1

-        score = self.pointing_game.evaluate([explanation], self.gt_bboxes)
-        assert score == 1.0
+        pointing_game_score = self.pointing_game.evaluate([explanation], self.gt_bboxes)["pointing_game"]
+        assert pointing_game_score == 1.0
+
+        explanation = self.explainer(self.image, targets=["person"], label_names=VOC_NAMES, colormap=False)
+        assert len(explanation.saliency_map) == 1
+        auc_score = self.auc.evaluate([explanation], [self.image], steps=10).values()
+        insertion_auc_score, deletion_auc_score, delta_auc_score = auc_score
+        assert insertion_auc_score >= 0.9
+        assert deletion_auc_score >= 0.2
+        assert delta_auc_score >= 0.7
+
+        # Two classes for saliency maps
+        explanation = self.explainer(self.image, targets=["person", "cat"], label_names=VOC_NAMES, colormap=False)
+        assert len(explanation.saliency_map) == 2
+        auc_score = self.auc.evaluate([explanation], [self.image], steps=10).values()
+        insertion_auc_score, deletion_auc_score, delta_auc_score = auc_score
+        assert insertion_auc_score >= 0.5
+        assert deletion_auc_score >= 0.1
+        assert delta_auc_score >= 0.35

     def test_explainer_images(self):
         images = [self.image, self.image]
         explanations = []
         for image in images:
-            explanation = self.explainer(
-                image,
-                targets=["person"],
-                label_names=VOC_NAMES,
-                colormap=False,
-            )
+            explanation = self.explainer(image, targets=["person"], label_names=VOC_NAMES, colormap=False)
             explanations.append(explanation)
         dataset_gt_bboxes = self.gt_bboxes * 2

-        score = self.pointing_game.evaluate(explanations, dataset_gt_bboxes)
-        assert score == 1.0
+        pointing_game_score = self.pointing_game.evaluate(explanations, dataset_gt_bboxes)["pointing_game"]
+        assert pointing_game_score == 1.0
+
+        auc_score = self.auc.evaluate(explanations, images, steps=10).values()
+        insertion_auc_score, deletion_auc_score, delta_auc_score = auc_score
+        assert insertion_auc_score >= 0.9
+        assert deletion_auc_score >= 0.2
+        assert delta_auc_score >= 0.7
diff --git a/tests/unit/metrics/test_auc.py b/tests/unit/metrics/test_auc.py
new file mode 100644
index 00000000..02d51435
--- /dev/null
+++ b/tests/unit/metrics/test_auc.py
@@ -0,0 +1,69 @@
+import cv2
+import numpy as np
+import openvino as ov
+import pytest
+
+from openvino_xai import Task
+from openvino_xai.common.utils import retrieve_otx_model
+from openvino_xai.explainer.explainer import Explainer, ExplainMode
+from openvino_xai.explainer.explanation import Explanation
+from openvino_xai.explainer.utils import (
+    ActivationType,
+    get_postprocess_fn,
+    get_preprocess_fn,
+)
+from openvino_xai.metrics.insertion_deletion_auc import InsertionDeletionAUC
+
+MODEL_NAME = "mlc_mobilenetv3_large_voc"
+
+
+class TestAUC:
+    image = cv2.imread("tests/assets/cheetah_person.jpg")
+    preprocess_fn = get_preprocess_fn(
+        change_channel_order=True,
+        input_size=(224, 224),
+        hwc_to_chw=True,
+    )
+    postprocess_fn = get_postprocess_fn(activation=ActivationType.SIGMOID)
+    steps = 10
+
+    @pytest.fixture(autouse=True)
+    def setup(self, fxt_data_root):
+        self.data_dir = fxt_data_root
+        retrieve_otx_model(self.data_dir, MODEL_NAME)
+        model_path = self.data_dir / "otx_models" / (MODEL_NAME + ".xml")
+        core = ov.Core()
+        model = core.read_model(model_path)
+        compiled_model = core.compile_model(model=model, device_name="AUTO")
+        self.auc = InsertionDeletionAUC(compiled_model, self.preprocess_fn, self.postprocess_fn)
+
+        self.explainer = Explainer(
+            model=model,
+            task=Task.CLASSIFICATION,
+            preprocess_fn=self.preprocess_fn,
+            explain_mode=ExplainMode.WHITEBOX,
+        )
+
+    def test_insertion_deletion_auc(self):
+        class_idx = 1
+        input_image = np.random.rand(224, 224, 3)
+        saliency_map = np.random.rand(224, 224)
+
+        insertion_auc, deletion_auc = self.auc(saliency_map, class_idx, input_image, self.steps).values()
+
+        for value in [insertion_auc, deletion_auc]:
+            assert isinstance(value, float)
+            assert 0 <= value <= 1
+
+    def test_evaluate(self):
+        input_images = [np.random.rand(224, 224, 3) for _ in range(5)]
+        explanations = [
+            Explanation({0: np.random.rand(224, 224), 1: np.random.rand(224, 224)}, targets=[0, 1]) for _ in range(5)
+        ]
+
+        insertion, deletion, delta = self.auc.evaluate(explanations, input_images, self.steps).values()
+
+        for value in [insertion, deletion]:
+            assert isinstance(value, float)
+            assert 0 <= value <= 1
+        assert isinstance(delta, float)
diff --git a/tests/unit/metrics/test_pointing_game.py b/tests/unit/metrics/test_pointing_game.py
index 37c09393..b6de5403 100644
--- a/tests/unit/metrics/test_pointing_game.py
+++ b/tests/unit/metrics/test_pointing_game.py
@@ -17,12 +17,12 @@ def test_pointing_game(self):
         saliency_map[1, 1] = 1
         ground_truth_bbox = [(1, 1, 1, 1)]

-        score = self.pointing_game.pointing_game(saliency_map, ground_truth_bbox)
-        assert score == 1
+        score_result = self.pointing_game(saliency_map, ground_truth_bbox)
+        assert score_result["pointing_game"] == 1

         ground_truth_bbox = [(0, 0, 0, 0)]
-        score = self.pointing_game.pointing_game(saliency_map, ground_truth_bbox)
-        assert score == 0
+        score_result = self.pointing_game(saliency_map, ground_truth_bbox)
+        assert score_result["pointing_game"] == 0

     def test_pointing_game_evaluate(self, caplog):
         pointing_game = PointingGame()
@@ -34,25 +34,25 @@ def test_pointing_game_evaluate(self, caplog):
         explanations = [explanation]
         gt_bboxes = [{"cat": [(0, 0, 2, 2)], "dog": [(0, 0, 1, 1)]}]

-        score = pointing_game.evaluate(explanations, gt_bboxes)
-        assert score == 1.0
+        score_result = pointing_game.evaluate(explanations, gt_bboxes)
+        assert score_result["pointing_game"] == 1.0

         # No hit for dog class saliency map, hit for cat class saliency map
         gt_bboxes = [{"cat": [(0, 0, 2, 2), (0, 0, 1, 1)], "dog": [(0, 0, 0, 0)]}]
-        score = pointing_game.evaluate(explanations, gt_bboxes)
-        assert score == 0.5
+        score_result = pointing_game.evaluate(explanations, gt_bboxes)
+        assert score_result["pointing_game"] == 0.5

         # No ground truth bboxes for available saliency map classes
         gt_bboxes = [{"not-cat": [(0, 0, 2, 2)], "not-dog": [(0, 0, 0, 0)]}]
         with caplog.at_level(logging.INFO):
-            score = pointing_game.evaluate(explanations, gt_bboxes)
+            score_result = pointing_game.evaluate(explanations, gt_bboxes)
         assert "Skip pointing game evaluation for this saliency map." in caplog.text
-        assert score == 0.0
+        assert score_result["pointing_game"] == 0.0

         # Ground truth bboxes / saliency maps number mismatch
         gt_bboxes = []
         with pytest.raises(AssertionError):
-            score = pointing_game.evaluate(explanations, gt_bboxes)
+            score_result = pointing_game.evaluate(explanations, gt_bboxes)

         # No label names
         explanation = Explanation(
@@ -63,4 +63,4 @@ def test_pointing_game_evaluate(self, caplog):
         explanations = [explanation]
         gt_bboxes = [{"cat": [(0, 0, 2, 2)], "dog": [(0, 0, 1, 1)]}]
         with pytest.raises(AssertionError):
-            score = pointing_game.evaluate(explanations, gt_bboxes)
+            score_result = pointing_game.evaluate(explanations, gt_bboxes)
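
Usage sketch (illustrative, not part of the diff): the snippet below strings the new pieces together the way tests/regression/test_regression.py does — an Explainer produces an explanation, then InsertionDeletionAUC and PointingGame score its saliency maps. The model path, image path, label list, and ground-truth box are placeholders modeled on the tests, not values defined by this PR.

# Illustrative usage sketch only; mirrors tests/regression/test_regression.py.
# Paths, the label list, and the ground-truth box below are placeholders.
import cv2
import openvino as ov

from openvino_xai import Task
from openvino_xai.explainer.explainer import Explainer, ExplainMode
from openvino_xai.explainer.utils import ActivationType, get_postprocess_fn, get_preprocess_fn
from openvino_xai.metrics.insertion_deletion_auc import InsertionDeletionAUC
from openvino_xai.metrics.pointing_game import PointingGame

# The 20 Pascal VOC class names this model predicts (the tests import these as VOC_NAMES).
voc_labels = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]

preprocess_fn = get_preprocess_fn(change_channel_order=True, input_size=(224, 224), hwc_to_chw=True)
postprocess_fn = get_postprocess_fn(activation=ActivationType.SIGMOID)

core = ov.Core()
model = core.read_model("otx_models/mlc_mobilenetv3_large_voc.xml")  # placeholder path
compiled_model = core.compile_model(model, "CPU")

explainer = Explainer(
    model=model,
    task=Task.CLASSIFICATION,
    preprocess_fn=preprocess_fn,
    explain_mode=ExplainMode.WHITEBOX,
)
image = cv2.imread("tests/assets/cheetah_person.jpg")  # placeholder image
explanation = explainer(image, targets=["person"], label_names=voc_labels, colormap=False)

# Insertion/Deletion AUC: higher insertion, lower deletion, and a larger delta indicate better maps.
auc = InsertionDeletionAUC(compiled_model, preprocess_fn, postprocess_fn)
auc_scores = auc.evaluate([explanation], [image], steps=10)  # {"insertion": ..., "deletion": ..., "delta": ...}

# Pointing Game: fraction of saliency maps whose most salient point lands inside a ground-truth box.
gt_bboxes = [{"person": [(17, 160, 180, 230)]}]  # placeholder (x, y, w, h) box for this image
pointing_game_score = PointingGame().evaluate([explanation], gt_bboxes)["pointing_game"]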