Insertion Deletion AUC metric #56

Merged (9 commits) on Aug 26, 2024
Changes from 8 commits
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -21,7 +21,8 @@
 * Upgrade OpenVINO to 2024.3.0 by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/52
 * Add saliency map visualization with explanation.plot() by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/53
 * Enable flexible naming for saved saliency maps and include confidence scores by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/51
-* Add [Pointing Game](https://arxiv.org/abs/1608.00507) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/54
+* Add [Pointing Game](https://link.springer.com/article/10.1007/s11263-017-1059-x) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/54
+* Add [Insertion Deletion AUC](https://arxiv.org/abs/1806.07421) saliency map quality metric by @GalyaZalesskaya in https://github.com/openvinotoolkit/openvino_xai/pull/56
 
 ### Known Issues
 
27 changes: 17 additions & 10 deletions openvino_xai/explainer/utils.py
@@ -119,22 +119,29 @@
 )
 
 
-def postprocess_fn(x: Mapping, logit_name="logits") -> np.ndarray:
-    """Postprocess function."""
-    return x.get(logit_name, x[0])  # Models from OVC has no output names at times
-
-
-def get_postprocess_fn(logit_name="logits") -> Callable[[], np.ndarray]:
-    """Returns partially initialized postprocess_fn."""
-    return partial(postprocess_fn, logit_name=logit_name)
-
-
 class ActivationType(Enum):
     SIGMOID = "sigmoid"
     SOFTMAX = "softmax"
     NONE = "none"
 
 
+def postprocess_fn(x: Mapping, logit_name="logits", activation: ActivationType = ActivationType.NONE) -> np.ndarray:
+    """Postprocess function."""
+    x = x.get(logit_name, x[0])  # Models converted via OVC sometimes have no output names
+    if activation == ActivationType.SOFTMAX:
+        return softmax(x)
+    if activation == ActivationType.SIGMOID:
+        return sigmoid(x)
+    return x
+
+
+def get_postprocess_fn(
+    logit_name="logits", activation: ActivationType = ActivationType.NONE
+) -> Callable[..., np.ndarray]:
+    """Returns partially initialized postprocess_fn."""
+    return partial(postprocess_fn, logit_name=logit_name, activation=activation)
+
+
 def get_score(x: np.ndarray, index: int, activation: ActivationType = ActivationType.NONE):
     """Returns activated score at index."""
     if activation == ActivationType.SOFTMAX:
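The activation-aware postprocessing composes with `partial`, so callers can bake the activation in once. A minimal usage sketch (the raw output dict below is a hypothetical stand-in for a model's output mapping):

    import numpy as np

    raw_output = {"logits": np.array([[2.0, -1.0, 0.5]])}
    postprocess = get_postprocess_fn(activation=ActivationType.SOFTMAX)
    probabilities = postprocess(raw_output)  # applies softmax to the selected logits array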
33 changes: 33 additions & 0 deletions openvino_xai/metrics/base.py
@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Tuple

import numpy as np
import openvino.runtime as ov

from openvino_xai.common.utils import IdentityPreprocessFN
from openvino_xai.explainer.explanation import Explanation


class BaseMetric(ABC):
    """Base class for XAI quality metric."""

    def __init__(
        self,
        model_compiled: ov.CompiledModel | None = None,
        preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
        postprocess_fn: Callable[[Any], np.ndarray] | None = None,
    ):
        self.model_compiled = model_compiled
        self.preprocess_fn = preprocess_fn
        self.postprocess_fn = postprocess_fn

    def model_predict(self, input: np.ndarray) -> np.ndarray:
        logits = self.model_compiled([self.preprocess_fn(input)])
        logits = self.postprocess_fn(logits)[0]
        return logits

    @abstractmethod
    def evaluate(
        self, explanations: List[Explanation], *args: Any, **kwargs: Any
    ) -> float | Tuple[float, float, float]:
        """Evaluate the quality of saliency maps over the list of images."""
99 changes: 99 additions & 0 deletions openvino_xai/metrics/insertion_deletion_auc.py
@@ -0,0 +1,99 @@
from typing import Any, List, Tuple

import numpy as np

from openvino_xai.explainer.explanation import Explanation, Layout
from openvino_xai.metrics.base import BaseMetric


def auc(arr):
    """Returns normalized Area Under Curve of the array."""
    return np.abs((arr.sum() - arr[0] / 2 - arr[-1] / 2) / (arr.shape[0] - 1))
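The helper above is the trapezoidal rule on unit-spaced samples, normalized by the x-range so that a curve pinned at 1.0 scores an AUC of 1.0. A quick sanity check, assuming the definition above:

    scores = np.array([0.0, 0.5, 1.0])
    # (0.0 + 0.5 + 1.0 - 0.0 / 2 - 1.0 / 2) / (3 - 1) = 1.0 / 2 = 0.5
    assert np.isclose(auc(scores), 0.5)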


class InsertionDeletionAUC(BaseMetric):
    """
    Implementation of the Insertion and Deletion AUC by Petsiuk et al. 2018.

    References:
        Petsiuk, Vitali, Abir Das, and Kate Saenko. "RISE: Randomized input sampling
        for explanation of black-box models." arXiv preprint arXiv:1806.07421 (2018).
    """

    def insertion_deletion_auc(
        self, saliency_map: np.ndarray, class_idx: int, input_image: np.ndarray, steps: int = 100
    ) -> Tuple[float, float]:
        """
        Calculate the Insertion and Deletion AUC metrics for a single saliency map and class.

        :param saliency_map: Importance scores for each pixel (H, W).
        :type saliency_map: np.ndarray
        :param class_idx: The class of the saliency map to evaluate.
        :type class_idx: int
        :param input_image: The input image to the model (H, W, C).
        :type input_image: np.ndarray
        :param steps: Number of steps for inserting/deleting pixels.
        :type steps: int

        :return: A tuple of the insertion AUC and the deletion AUC for the saliency map.
        :rtype: Tuple[float, float]
        """
        # Baselines: an empty image for insertion, the original image for deletion
        baseline_insertion = np.full_like(input_image, 0)
        baseline_deletion = input_image

        # Sort pixel indices by descending saliency so the most important pixels come first
        sorted_indices = np.argsort(-saliency_map.flatten())
        sorted_indices = np.unravel_index(sorted_indices, saliency_map.shape)

        insertion_scores, deletion_scores = [], []
        for i in range(steps + 1):
            temp_image_insertion, temp_image_deletion = baseline_insertion.copy(), baseline_deletion.copy()

            num_pixels = int(i * len(sorted_indices[0]) / steps)
            x_indices = sorted_indices[0][:num_pixels]
            y_indices = sorted_indices[1][:num_pixels]

            # Insert original pixels at the most important locations
            temp_image_insertion[x_indices, y_indices] = input_image[x_indices, y_indices]
            # Delete (zero out) pixels at the most important locations
            temp_image_deletion[x_indices, y_indices] = 0

            # Predict on the partially inserted and partially deleted images
            temp_logits_insertion = self.model_predict(temp_image_insertion)
            temp_logits_deletion = self.model_predict(temp_image_deletion)

            insertion_scores.append(temp_logits_insertion[class_idx])
            deletion_scores.append(temp_logits_deletion[class_idx])
        return auc(np.array(insertion_scores)), auc(np.array(deletion_scores))
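The same insertion loop, sketched end to end with a stand-in predictor so the mechanics are visible without a compiled model (everything below, including `toy_predict`, is hypothetical):

    import numpy as np

    def toy_predict(image: np.ndarray) -> np.ndarray:
        # Stand-in for model_predict: class-0 score grows with mean intensity
        return np.array([image.mean(), 1.0 - image.mean()])

    image = np.random.rand(8, 8, 3)
    saliency = np.random.rand(8, 8)
    order = np.unravel_index(np.argsort(-saliency.flatten()), saliency.shape)

    steps = 4
    scores = []
    inserted = np.zeros_like(image)
    for i in range(steps + 1):
        n = int(i * len(order[0]) / steps)
        # Prefixes of the importance ordering nest, so we can accumulate in place
        inserted[order[0][:n], order[1][:n]] = image[order[0][:n], order[1][:n]]
        scores.append(toy_predict(inserted)[0])
    insertion_auc = auc(np.array(scores))  # scores rise toward the full-image prediction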

    def evaluate(
        self, explanations: List[Explanation], input_images: List[np.ndarray], steps: int, **kwargs: Any
    ) -> Tuple[float, float, float]:
        """
        Evaluate the insertion and deletion AUC over a list of images and their saliency maps.

        :param explanations: List of explanation objects containing saliency maps.
        :type explanations: List[Explanation]
        :param input_images: List of input images as numpy arrays.
        :type input_images: List[np.ndarray]
        :param steps: Number of steps for the insertion and deletion process.
        :type steps: int

        :return: A tuple containing the mean insertion AUC, mean deletion AUC, and their difference (delta).
        :rtype: Tuple[float, float, float]
        """
        for explanation in explanations:
            assert explanation.layout in [Layout.MULTIPLE_MAPS_PER_IMAGE_GRAY, Layout.MULTIPLE_MAPS_PER_IMAGE_COLOR]

        results = []
        for input_image, explanation in zip(input_images, explanations):
            for class_idx, saliency_map in explanation.saliency_map.items():
                insertion, deletion = self.insertion_deletion_auc(saliency_map, int(class_idx), input_image, steps)
                results.append([insertion, deletion])

        insertion, deletion = np.mean(np.array(results), axis=0)
        delta = insertion - deletion
        return insertion, deletion, delta
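Intended usage, sketched under the assumption that a compiled model, matching pre/postprocess functions, and explanations from the Explainer are already in hand (the model path is a placeholder):

    import openvino as ov

    core = ov.Core()
    model = core.read_model("model.xml")  # placeholder path
    compiled_model = core.compile_model(model, "CPU")

    auc_metric = InsertionDeletionAUC(compiled_model, preprocess_fn, postprocess_fn)
    insertion, deletion, delta = auc_metric.evaluate(explanations, input_images, steps=30)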
10 changes: 7 additions & 3 deletions openvino_xai/metrics/pointing_game.py
@@ -1,15 +1,16 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Dict, List, Tuple
+from typing import Any, Dict, List, Tuple
 
 import numpy as np
 
 from openvino_xai.common.utils import logger
 from openvino_xai.explainer.explanation import Explanation
+from openvino_xai.metrics.base import BaseMetric
 
 
-class PointingGame:
+class PointingGame(BaseMetric):
     """
     Implementation of the Pointing Game by Zhang et al., 2018.
 
@@ -55,7 +56,10 @@ def pointing_game(saliency_map: np.ndarray, image_gt_bboxes: List[Tuple[int, int
         return False
 
     def evaluate(
-        self, explanations: List[Explanation], gt_bboxes: List[Dict[str, List[Tuple[int, int, int, int]]]]
+        self,
+        explanations: List[Explanation],
+        gt_bboxes: List[Dict[str, List[Tuple[int, int, int, int]]]],
+        **kwargs: Any,
     ) -> float:
         """
         Evaluates the Pointing Game metric over a set of images. Skips saliency maps if the gt bboxes for this class are absent.
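For reference, a hypothetical `gt_bboxes` payload matching the signature above: one dict per image, mapping a label name to that image's four-integer boxes (the exact coordinate convention is defined by `pointing_game`):

    gt_bboxes = [
        {"person": [(17, 160, 180, 230)], "cat": [(10, 20, 60, 90)]},  # image 0
        {"person": [(5, 5, 100, 100)]},                                # image 1
    ]
    score = pointing_game.evaluate(explanations, gt_bboxes)  # hit rate in [0, 1]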
53 changes: 37 additions & 16 deletions tests/regression/test_regression.py
@@ -11,7 +11,12 @@
 from openvino_xai import Task
 from openvino_xai.common.utils import retrieve_otx_model
 from openvino_xai.explainer.explainer import Explainer, ExplainMode
-from openvino_xai.explainer.utils import get_preprocess_fn
+from openvino_xai.explainer.utils import (
+    ActivationType,
+    get_postprocess_fn,
+    get_preprocess_fn,
+)
+from openvino_xai.metrics.insertion_deletion_auc import InsertionDeletionAUC
 from openvino_xai.metrics.pointing_game import PointingGame
 from tests.unit.explanation.test_explanation_utils import VOC_NAMES
 
@@ -23,7 +28,6 @@
 def load_gt_bboxes(json_coco_path: str) -> List[Dict[str, List[Tuple[int, int, int, int]]]]:
     """
     Loads ground truth bounding boxes from a COCO format JSON file.
-
     Returns a list of dictionaries, where each dictionary corresponds to an image.
     The key is the label name and the value is a list of bounding boxes for that image.
     """
@@ -61,12 +65,16 @@ class TestDummyRegression:
         hwc_to_chw=True,
     )
 
+    postprocess_fn = get_postprocess_fn(activation=ActivationType.SIGMOID)
+
     @pytest.fixture(autouse=True)
     def setup(self, fxt_data_root):
         data_dir = fxt_data_root
         retrieve_otx_model(data_dir, MODEL_NAME)
         model_path = data_dir / "otx_models" / (MODEL_NAME + ".xml")
         model = ov.Core().read_model(model_path)
+        compiled_model = ov.Core().compile_model(model, "CPU")
+        self.auc = InsertionDeletionAUC(compiled_model, self.preprocess_fn, self.postprocess_fn)
 
         self.explainer = Explainer(
             model=model,
@@ -76,28 +84,41 @@ def setup(self, fxt_data_root):
         )
 
     def test_explainer_image(self):
-        explanation = self.explainer(
-            self.image,
-            targets=["person"],
-            label_names=VOC_NAMES,
-            colormap=False,
-        )
+        explanation = self.explainer(self.image, targets=["person"], label_names=VOC_NAMES, colormap=False)
         assert len(explanation.saliency_map) == 1
         score = self.pointing_game.evaluate([explanation], self.gt_bboxes)
         assert score == 1.0
 
+        explanation = self.explainer(self.image, targets=["person"], label_names=VOC_NAMES, colormap=False)
+        assert len(explanation.saliency_map) == 1
+        auc_score = self.auc.evaluate([explanation], [self.image], steps=10)
+        insertion_auc_score, deletion_auc_score, delta_auc_score = auc_score
+        assert insertion_auc_score >= 0.9
+        assert deletion_auc_score >= 0.2
+        assert delta_auc_score >= 0.7
+
+        # Two classes for saliency maps
+        explanation = self.explainer(self.image, targets=["person", "cat"], label_names=VOC_NAMES, colormap=False)
+        assert len(explanation.saliency_map) == 2
+        auc_score = self.auc.evaluate([explanation], [self.image], steps=10)
+        insertion_auc_score, deletion_auc_score, delta_auc_score = auc_score
+        assert insertion_auc_score >= 0.5
+        assert deletion_auc_score >= 0.1
+        assert delta_auc_score >= 0.35
+
     def test_explainer_images(self):
         images = [self.image, self.image]
         explanations = []
         for image in images:
-            explanation = self.explainer(
-                image,
-                targets=["person"],
-                label_names=VOC_NAMES,
-                colormap=False,
-            )
+            explanation = self.explainer(image, targets=["person"], label_names=VOC_NAMES, colormap=False)
             explanations.append(explanation)
         dataset_gt_bboxes = self.gt_bboxes * 2
 
-        score = self.pointing_game.evaluate(explanations, dataset_gt_bboxes)
-        assert score == 1.0
+        pointing_game_score = self.pointing_game.evaluate(explanations, dataset_gt_bboxes)
+        assert pointing_game_score == 1.0
+
+        auc_score = self.auc.evaluate(explanations, images, steps=10)
+        insertion_auc_score, deletion_auc_score, delta_auc_score = auc_score
+        assert insertion_auc_score >= 0.9
+        assert deletion_auc_score >= 0.2
+        assert delta_auc_score >= 0.7
69 changes: 69 additions & 0 deletions tests/unit/metrics/test_auc.py
@@ -0,0 +1,69 @@
import cv2
import numpy as np
import openvino as ov
import pytest

from openvino_xai import Task
from openvino_xai.common.utils import retrieve_otx_model
from openvino_xai.explainer.explainer import Explainer, ExplainMode
from openvino_xai.explainer.explanation import Explanation
from openvino_xai.explainer.utils import (
    ActivationType,
    get_postprocess_fn,
    get_preprocess_fn,
)
from openvino_xai.metrics.insertion_deletion_auc import InsertionDeletionAUC

MODEL_NAME = "mlc_mobilenetv3_large_voc"


class TestAUC:
    image = cv2.imread("tests/assets/cheetah_person.jpg")
    preprocess_fn = get_preprocess_fn(
        change_channel_order=True,
        input_size=(224, 224),
        hwc_to_chw=True,
    )
    postprocess_fn = get_postprocess_fn(activation=ActivationType.SIGMOID)
    steps = 10

    @pytest.fixture(autouse=True)
    def setup(self, fxt_data_root):
        self.data_dir = fxt_data_root
        retrieve_otx_model(self.data_dir, MODEL_NAME)
        model_path = self.data_dir / "otx_models" / (MODEL_NAME + ".xml")
        core = ov.Core()
        model = core.read_model(model_path)
        compiled_model = core.compile_model(model=model, device_name="AUTO")
        self.auc = InsertionDeletionAUC(compiled_model, self.preprocess_fn, self.postprocess_fn)

        self.explainer = Explainer(
            model=model,
            task=Task.CLASSIFICATION,
            preprocess_fn=self.preprocess_fn,
            explain_mode=ExplainMode.WHITEBOX,
        )

    def test_insertion_deletion_auc(self):
        class_idx = 1
        input_image = np.random.rand(224, 224, 3)
        saliency_map = np.random.rand(224, 224)

        insertion_auc, deletion_auc = self.auc.insertion_deletion_auc(saliency_map, class_idx, input_image, self.steps)

        for value in [insertion_auc, deletion_auc]:
            assert isinstance(value, float)
            assert 0 <= value <= 1

    def test_evaluate(self):
        input_images = [np.random.rand(224, 224, 3) for _ in range(5)]
        explanations = [
            Explanation({0: np.random.rand(224, 224), 1: np.random.rand(224, 224)}, targets=[0, 1]) for _ in range(5)
        ]

        insertion, deletion, delta = self.auc.evaluate(explanations, input_images, self.steps)

        for value in [insertion, deletion]:
            assert isinstance(value, float)
            assert 0 <= value <= 1
        assert isinstance(delta, float)