diff --git a/README.md b/README.md
index c7351d16..f1989ef8 100644
--- a/README.md
+++ b/README.md
@@ -158,10 +158,10 @@ In this example, we are trying to know the reason why the model outputs a `cheet
 ```python
 import cv2
 import numpy as np
-import openvino.runtime as ov
+import openvino as ov
 import openvino_xai as xai
 
-# Load the model
+# Load the model: IR or ONNX
 ov_model: ov.Model = ov.Core().read_model("mobilenet_v3.xml")
 
 # Load the image to be analyzed
@@ -169,9 +169,9 @@ image: np.ndarray = cv2.imread("tests/assets/cheetah_person.jpg")
 image = cv2.resize(image, dsize=(224, 224))
 image = np.expand_dims(image, 0)
 
-# Create the Explainer object
+# Create the Explainer for the model
 explainer = xai.Explainer(
-    model=ov_model,
+    model=ov_model,  # accepts path arguments "mobilenet_v3.xml" or "mobilenet_v3.onnx" as well
     task=xai.Task.CLASSIFICATION,
 )
 
diff --git a/docs/source/user-guide.md b/docs/source/user-guide.md
index c59f9492..f914d36f 100644
--- a/docs/source/user-guide.md
+++ b/docs/source/user-guide.md
@@ -12,7 +12,7 @@ Content:
 
 - [OpenVINO™ Explainable AI Toolkit User Guide](#openvino-explainable-ai-toolkit-user-guide)
   - [OpenVINO XAI Architecture](#openvino-xai-architecture)
-  - [Explainer - interface to XAI algorithms](#explainer---interface-to-xai-algorithms)
+  - [`Explainer`: the main interface to XAI algorithms](#explainer-the-main-interface-to-xai-algorithms)
     - [Basic usage: Auto mode](#basic-usage-auto-mode)
       - [Running without `preprocess_fn`](#running-without-preprocess_fn)
       - [Specifying `preprocess_fn`](#specifying-preprocess_fn)
@@ -31,21 +31,53 @@ OpenVINO XAI provides the API to explain models, using two types of methods:
 
 - **White-box** - treats the model as a white box, making inner modifications and adding an extra XAI branch. This results in additional output from the model and relatively fast explanations.
 - **Black-box** - treats the model as a black box, working on a wide range of models. However, it requires many more inference runs.
 
-## Explainer - interface to XAI algorithms
+## `Explainer`: the main interface to XAI algorithms
 
 In a nutshell, the explanation call looks like this:
 
 ```python
 import openvino_xai as xai
 
+explainer = xai.Explainer(model=model, task=xai.Task.CLASSIFICATION)
+explanation = explainer(data)
+```
+
+There are a few options for the model format. The major use case is to load an OpenVINO IR model from a file and pass the `ov.Model` instance to the explainer.
+
+### Create Explainer for OpenVINO Model instance
+
+```python
+import openvino as ov
+model = ov.Core().read_model("model.xml")
+
 explainer = xai.Explainer(
-    model,
-    task=xai.Task.CLASSIFICATION,
-    preprocess_fn=preprocess_fn,
+    model=model,
+    task=xai.Task.CLASSIFICATION
 )
-explanation = explainer(data, explanation_parameters)
 ```
 
+### Create Explainer from OpenVINO IR file
+
+The Explainer also supports the OpenVINO IR (Intermediate Representation) file format (.xml) directly, as follows:
+
+```python
+explainer = xai.Explainer(
+    model="model.xml",
+    task=xai.Task.CLASSIFICATION
+)
+```
+
+### Create Explainer from ONNX model file
+
+[ONNX](https://onnx.ai/) is an open format built to represent machine learning models.
+The OpenVINO Runtime supports loading and inference of ONNX models, and so does OpenVINO XAI.
+
+```python
+explainer = xai.Explainer(
+    model="model.onnx",
+    task=xai.Task.CLASSIFICATION
+)
+```
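+
+Whichever of these formats is used, the created `Explainer` is called in the same way. As a minimal sketch (assuming `data` is an already preprocessed `np.ndarray`; see the `preprocess_fn` sections below):
+
+```python
+explanation = explainer(data)  # `targets` defaults to -1, which explains all class labels
+```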
 
 ## Basic usage: Auto mode
 
@@ -202,7 +234,7 @@ explanation = explainer(
     explain_method=xai.Method.RECIPROCAM,  # ReciproCAM is the default XAI method for CNNs
     label_names=voc_labels,
     target_explain_labels=[11, 14],  # target classes to explain, also ['dog', 'person'] is a valid input, since label_names are provided
-    overlay=True, # False by default
+    overlay=True,  # False by default
 )
 
 # Save saliency maps
@@ -253,7 +285,7 @@ explanation = explainer(
     explain_mode=ExplainMode.BLACKBOX,
     target_explain_labels=[11, 14],  # target classes to explain
     # target_explain_labels=-1,  # explain all classes
-    overlay=True, # False by default
+    overlay=True,  # False by default
     num_masks=1000,  # kwargs for the RISE algorithm
 )
 
@@ -303,6 +335,6 @@ pytest tests/test_classification.py
 
 # Run a bunch of classification examples
 # All outputs will be stored in the corresponding output directory
-python examples/run_classification.py .data/otx_models/mlc_mobilenetv3_large_voc.xml
+python examples/run_classification.py .data/otx_models/mlc_mobilenetv3_large_voc.xml tests/assets/cheetah_person.jpg --output output
 ```
 
diff --git a/openvino_xai/explainer/explainer.py b/openvino_xai/explainer/explainer.py
index 5f28e4ce..89bb9bbf 100644
--- a/openvino_xai/explainer/explainer.py
+++ b/openvino_xai/explainer/explainer.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from enum import Enum
+from os import PathLike
 from typing import Callable, List, Mapping, Tuple
 
 import numpy as np
@@ -42,10 +43,11 @@ class Explainer:
     Explainer creates methods and uses them to generate explanations.
 
     Usage:
-        explanation = explainer_object(data, explanation_parameters)
+        >>> explainer = Explainer("model.xml", Task.CLASSIFICATION)
+        >>> explanation = explainer(data)
 
-    :param model: Original model.
-    :type model: ov.Model
+    :param model: Original model object, OpenVINO IR file (.xml) or ONNX file (.onnx).
+    :type model: ov.Model | str | PathLike
     :param task: Type of the task: CLASSIFICATION or DETECTION.
     :type task: Task
     :param preprocess_fn: Preprocessing function, identity function by default
@@ -67,7 +69,7 @@ class Explainer:
 
     def __init__(
        self,
-        model: ov.Model,
+        model: ov.Model | str | PathLike,
        task: Task,
        preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
        postprocess_fn: Callable[[Mapping], np.ndarray] = None,
@@ -78,6 +80,9 @@ def __init__(
        device_name: str = "CPU",
        **kwargs,
     ) -> None:
+        if isinstance(model, (str, PathLike)):
+            model = ov.Core().read_model(model)
+
         self.model = model
         self.compiled_model: ov.CompiledModel | None = None
         self.task = task
@@ -131,7 +136,7 @@ def create_method(self, explain_mode: ExplainMode, task: Task) -> MethodBase:
     def __call__(
        self,
        data: np.ndarray,
-        targets: np.ndarray | List[int | str] | int | str,
+        targets: np.ndarray | List[int | str] | int | str = -1,
        original_input_image: np.ndarray | None = None,
        label_names: List[str] | None = None,
        output_size: Tuple[int, int] | None = None,
@@ -159,7 +164,7 @@ def __call__(
     def explain(
        self,
        data: np.ndarray,
-        targets: np.ndarray | List[int | str] | int | str,
+        targets: np.ndarray | List[int | str] | int | str = -1,
        original_input_image: np.ndarray | None = None,
        label_names: List[str] | None = None,
        output_size: Tuple[int, int] | None = None,
@@ -176,7 +181,7 @@ def explain(
         :param data: Input image.
         :type data: np.ndarray
         :param targets: List of custom labels to explain, optional.
             Can be list of integer indices (int),
-            or list of names (str) from label_names.
+            or list of names (str) from label_names. Defaults to -1, which means all class labels.
         :type targets: np.ndarray | List[int | str] | int | str
         :param label_names: List of all label names.
         :type label_names: List[str] | None
diff --git a/pyproject.toml b/pyproject.toml
index 3d7e47ea..789e6788 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,7 @@ dev = [
     "pytest",
     "pytest-cov",
     "pytest-csv",
+    "pytest-mock",
     "pytest-xdist",
     "pre-commit==3.7.0",
     "addict",
diff --git a/tests/intg/test_classification_timm.py b/tests/intg/test_classification_timm.py
index 30780b7c..ac0a123b 100644
--- a/tests/intg/test_classification_timm.py
+++ b/tests/intg/test_classification_timm.py
@@ -157,25 +157,11 @@ def test_classification_white_box(self, model_id, dump_maps=False):
         if model_id in NON_SUPPORTED_BY_WB_MODELS:
             pytest.skip(reason="Not supported yet")
 
-        timm_model, model_cfg = self.get_timm_model(model_id)
+        model_dir = self.data_dir / "timm_models" / "converted_models"
+        timm_model, model_cfg = self.get_timm_model(model_id, model_dir)
         self.update_report("report_wb.csv", model_id)
 
-        ir_path = self.data_dir / "timm_models" / "converted_models" / model_id / "model_fp32.xml"
-        if not ir_path.is_file():
-            output_model_dir = self.output_dir / "timm_models" / "converted_models" / model_id
-            output_model_dir.mkdir(parents=True, exist_ok=True)
-            ir_path = output_model_dir / "model_fp32.xml"
-            input_size = [1] + list(timm_model.default_cfg["input_size"])
-            dummy_tensor = torch.rand(input_size)
-            onnx_path = output_model_dir / "model_fp32.onnx"
-            set_dynamic_batch = model_id in LIMITED_DIVERSE_SET_OF_VISION_TRANSFORMER_MODELS
-            export_to_onnx(timm_model, onnx_path, dummy_tensor, set_dynamic_batch)
-            self.update_report("report_wb.csv", model_id, "True")
-            export_to_ir(onnx_path, output_model_dir / "model_fp32.xml")
-            self.update_report("report_wb.csv", model_id, "True", "True")
-        else:
-            self.update_report("report_wb.csv", model_id, "True", "True")
-
+        ir_path = model_dir / model_id / "model_fp32.xml"
         model = ov.Core().read_model(ir_path)
 
         if model_id in LIMITED_DIVERSE_SET_OF_CNN_MODELS:
@@ -254,24 +240,17 @@ def test_classification_white_box(self, model_id, dump_maps=False):
     @pytest.mark.parametrize("model_id", TEST_MODELS)
     def test_classification_black_box(self, model_id, dump_maps=False):
         # self.check_for_saved_map(model_id, "timm_models/maps_bb/")
+        if model_id == "nest_tiny_jx.goog_in1k":
+            pytest.xfail(
+                "[cpu]reshape: the shape of input data (1.27.27.192) conflicts with the reshape pattern (0.2.14.2.14.192)"
+            )
 
-        timm_model, model_cfg = self.get_timm_model(model_id)
+        model_dir = self.data_dir / "timm_models" / "converted_models"
+        timm_model, model_cfg = self.get_timm_model(model_id, model_dir)
         self.update_report("report_bb.csv", model_id)
 
-        onnx_path = self.data_dir / "timm_models" / "converted_models" / model_id / "model_fp32.onnx"
-        if not onnx_path.is_file():
-            output_model_dir = self.output_dir / "timm_models" / "converted_models" / model_id
-            output_model_dir.mkdir(parents=True, exist_ok=True)
-            onnx_path = output_model_dir / "model_fp32.onnx"
-            input_size = [1] + list(timm_model.default_cfg["input_size"])
-            dummy_tensor = torch.rand(input_size)
-            onnx_path = output_model_dir / "model_fp32.onnx"
-            export_to_onnx(timm_model, onnx_path, dummy_tensor, False)
-            self.update_report("report_bb.csv", model_id, "True", "True")
-        else:
-            self.update_report("report_bb.csv", model_id, "True", "True")
-
-        model = ov.Core().read_model(onnx_path)
+        ir_path = model_dir / model_id / "model_fp32.xml"
+        model = ov.Core().read_model(ir_path)
 
         mean_values = [(item * 255) for item in model_cfg["mean"]]
         scale_values = [(item * 255) for item in model_cfg["std"]]
@@ -330,7 +309,7 @@ def test_classification_black_box(self, model_id, dump_maps=False):
         ],
     )
     # @pytest.mark.parametrize("model_id", TEST_MODELS)
-    def test_ovc_ir_insertion(self, model_id):
+    def test_ovc_model_white_box(self, model_id):
         if model_id in NON_SUPPORTED_BY_WB_MODELS:
             pytest.skip(reason="Not supported yet")
 
@@ -339,7 +318,8 @@
                 reason="RuntimeError: Couldn't get TorchScript module by tracing."
             )  # Torch -> OV conversion error
 
-        timm_model, model_cfg = self.get_timm_model(model_id)
+        model_dir = self.data_dir / "timm_models" / "converted_models"
+        timm_model, model_cfg = self.get_timm_model(model_id, model_dir)
         input_size = list(timm_model.default_cfg["input_size"])
         dummy_tensor = torch.rand([1] + input_size)
         model = ov.convert_model(timm_model, example_input=dummy_tensor, input=(ov.PartialShape([-1] + input_size),))
@@ -383,6 +363,84 @@
         assert explanation.shape[-1] > 1 and explanation.shape[-2] > 1
         print(f"{model_id}: Generated classification saliency maps with shape {explanation.shape}.")
 
+    @pytest.mark.parametrize(
+        "model_id",
+        [
+            "resnet18.a1_in1k",
+            "vit_tiny_patch16_224.augreg_in21k",  # Downloads last month 15,345
+        ],
+    )
+    @pytest.mark.parametrize(
+        "explain_mode",
+        [
+            ExplainMode.WHITEBOX,
+            ExplainMode.BLACKBOX,
+        ],
+    )
+    @pytest.mark.parametrize(
+        "model_format",
+        ["xml", "onnx"],
+    )
+    def test_model_format(self, model_id, explain_mode, model_format):
+        if (
+            model_id == "vit_tiny_patch16_224.augreg_in21k"
+            and explain_mode == ExplainMode.WHITEBOX
+            and model_format == "onnx"
+        ):
+            pytest.xfail(
+                "RuntimeError: Failed to insert XAI into the model -> Only two outputs of the between block Add node supported, but got 3."
+            )
+
+        model_dir = self.data_dir / "timm_models" / "converted_models"
+        timm_model, model_cfg = self.get_timm_model(model_id, model_dir)
+        model_path = model_dir / model_id / ("model_fp32." + model_format)
+
+        mean_values = [(item * 255) for item in model_cfg["mean"]]
+        scale_values = [(item * 255) for item in model_cfg["std"]]
+        preprocess_fn = get_preprocess_fn(
+            change_channel_order=True,
+            input_size=model_cfg["input_size"][1:],
+            mean=mean_values,
+            std=scale_values,
+            hwc_to_chw=True,
+        )
+
+        explain_method = None
+        postprocess_fn = None
+        if explain_mode == ExplainMode.WHITEBOX:
+            if model_id in LIMITED_DIVERSE_SET_OF_CNN_MODELS:
+                explain_method = Method.RECIPROCAM
+            elif model_id in LIMITED_DIVERSE_SET_OF_VISION_TRANSFORMER_MODELS:
+                explain_method = Method.VITRECIPROCAM
+            else:
+                raise ValueError
+        else:  # explain_mode == ExplainMode.BLACKBOX:
+            postprocess_fn = get_postprocess_fn()
+
+        explainer = Explainer(
+            model=model_path,
+            task=Task.CLASSIFICATION,
+            preprocess_fn=preprocess_fn,
+            postprocess_fn=postprocess_fn,
+            explain_mode=explain_mode,
+            explain_method=explain_method,
+            embed_scaling=False,
+        )
+
+        target_class = self.supported_num_classes[model_cfg["num_classes"]]
+        image = cv2.imread("tests/assets/cheetah_person.jpg")
+        explanation = explainer(
+            image,
+            targets=[target_class],
+            resize=False,
+            colormap=False,
+            num_masks=2,  # minimal iterations for feature test
+        )
+
+        assert explanation is not None
+        assert explanation.shape[-1] > 1 and explanation.shape[-2] > 1
+        print(f"{model_id}: Generated classification saliency maps with shape {explanation.shape}.")
+
     def check_for_saved_map(self, model_id, directory):
         for target in self.supported_num_classes.values():
             map_name = model_id + "_target_" + str(target) + ".jpg"
@@ -396,7 +454,7 @@ def check_for_saved_map(self, model_id, directory):
             self.clear_cache()
             pytest.skip(f"Model {model_id} is already explained.")
 
-    def get_timm_model(self, model_id):
+    def get_timm_model(self, model_id: str, model_dir: Path):
         timm_model = timm.create_model(model_id, in_chans=3, pretrained=True, checkpoint_path="")
         timm_model.eval()
         model_cfg = timm_model.default_cfg
@@ -404,6 +462,17 @@ def get_timm_model(self, model_id):
         if num_classes not in self.supported_num_classes:
             self.clear_cache()
             pytest.skip(f"Number of model classes {num_classes} unknown")
+        model_dir = model_dir / model_id
+        ir_path = model_dir / "model_fp32.xml"
+        if not ir_path.is_file():
+            model_dir.mkdir(parents=True, exist_ok=True)
+            ir_path = model_dir / "model_fp32.xml"
+            input_size = [1] + list(timm_model.default_cfg["input_size"])
+            dummy_tensor = torch.rand(input_size)
+            onnx_path = model_dir / "model_fp32.onnx"
+            set_dynamic_batch = model_id in LIMITED_DIVERSE_SET_OF_VISION_TRANSFORMER_MODELS
+            export_to_onnx(timm_model, onnx_path, dummy_tensor, set_dynamic_batch)
+            export_to_ir(onnx_path, model_dir / "model_fp32.xml")
         return timm_model, model_cfg
 
     @staticmethod
diff --git a/tests/unit/explanation/test_explainer.py b/tests/unit/explanation/test_explainer.py
index c22de39b..ab90169a 100644
--- a/tests/unit/explanation/test_explainer.py
+++ b/tests/unit/explanation/test_explainer.py
@@ -7,6 +7,7 @@
 import numpy as np
 import openvino as ov
 import pytest
+from pytest_mock import MockerFixture
 
 from openvino_xai.api.api import insert_xai
 from openvino_xai.common.parameters import Task
@@ -228,3 +229,27 @@ def test_overlay_with_original_image(self):
             overlay=True,
         )
         assert explanation.saliency_map[0].shape == (100, 120, 3)
+
+    def test_init_with_file(self, mocker: MockerFixture):
+        ov_core = mocker.patch("openvino.Core")
+        create_method = mocker.patch("openvino_xai.Explainer.create_method")
+        # None
+        explainer = Explainer(
+            model=None,
+            task=Task.CLASSIFICATION,
+        )
+        ov_core.assert_not_called()
+        # str
+        path = "model.xml"
+        explainer = Explainer(
+            model=path,
+            task=Task.CLASSIFICATION,
+        )
+        ov_core.return_value.read_model.assert_called_with(path)
+        # Path
+        path = Path("model.xml")
+        explainer = Explainer(
+            model=path,
+            task=Task.CLASSIFICATION,
+        )
+        ov_core.return_value.read_model.assert_called_with(path)