Support OV IR / ONNX model file for Explainer (#47)
* Support IR/ONNX file for Explainer

* Add intg test for ir/onnx file support

* Fix intg test

* Update documentation

* Fix pre-commit

* Fix typo

* Update docs/source/user-guide.md

Co-authored-by: Galina Zalesskaya <[email protected]>

* Update README.md

Co-authored-by: Galina Zalesskaya <[email protected]>

* Update README.md

Co-authored-by: Galina Zalesskaya <[email protected]>

---------

Co-authored-by: Galina Zalesskaya <[email protected]>
goodsong81 and GalyaZalesskaya authored Jul 22, 2024
1 parent e18e809 commit b1fd326
Showing 6 changed files with 187 additions and 55 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -158,20 +158,20 @@ In this example, we are trying to know the reason why the model outputs a `cheetah` label.
```python
import cv2
import numpy as np
-import openvino.runtime as ov
+import openvino as ov
import openvino_xai as xai

-# Load the model
+# Load the model: IR or ONNX
ov_model: ov.Model = ov.Core().read_model("mobilenet_v3.xml")

# Load the image to be analyzed
image: np.ndarray = cv2.imread("tests/assets/cheetah_person.jpg")
image = cv2.resize(image, dsize=(224, 224))
image = np.expand_dims(image, 0)

-# Create the Explainer object
+# Create the Explainer for the model
explainer = xai.Explainer(
-model=ov_model,
+model=ov_model, # accepts path arguments "mobilenet_v3.xml" or "mobilenet_v3.onnx" as well
task=xai.Task.CLASSIFICATION,
)

50 changes: 41 additions & 9 deletions docs/source/user-guide.md
Expand Up @@ -12,7 +12,7 @@ Content:

- [OpenVINO™ Explainable AI Toolkit User Guide](#openvino-explainable-ai-toolkit-user-guide)
- [OpenVINO XAI Architecture](#openvino-xai-architecture)
-- [Explainer - interface to XAI algorithms](#explainer---interface-to-xai-algorithms)
+- [`Explainer`: the main interface to XAI algorithms](#explainer-the-main-interface-to-xai-algorithms)
- [Basic usage: Auto mode](#basic-usage-auto-mode)
- [Running without `preprocess_fn`](#running-without-preprocess_fn)
- [Specifying `preprocess_fn`](#specifying-preprocess_fn)
@@ -31,21 +31,53 @@ OpenVINO XAI provides the API to explain models, using two types of methods:
- **White-box** - treats the model as a white box, making inner modifications and adding an extra XAI branch. This results in additional output from the model and relatively fast explanations.
- **Black-box** - treats the model as a black box, working on a wide range of models. However, it requires many more inference runs.

-## Explainer - interface to XAI algorithms
+## `Explainer`: the main interface to XAI algorithms

In a nutshell, the explanation call looks like this:

```python
import openvino_xai as xai

explainer = xai.Explainer(model=model, task=xai.Task.CLASSIFICATION)
explanation = explainer(data)
```

There are a few options for the model format. The major use case is to load an OpenVINO IR model from file and pass the resulting `ov.Model` instance to the explainer.

### Create Explainer for OpenVINO Model instance

```python
import openvino as ov
model = ov.Core().read_model("model.xml")

explainer = xai.Explainer(
-model,
-task=xai.Task.CLASSIFICATION,
-preprocess_fn=preprocess_fn,
+model=model,
+task=xai.Task.CLASSIFICATION
)
-explanation = explainer(data, explanation_parameters)
```

### Create Explainer from OpenVINO IR file

The Explainer also supports the OpenVINO IR (Intermediate Representation) file format (`.xml`) directly, as follows:

```python
explainer = xai.Explainer(
model="model.xml",
task=xai.Task.CLASSIFICATION
)
```

### Create Explainer from ONNX model file

[ONNX](https://onnx.ai/) is an open format built to represent machine learning models.
The OpenVINO Runtime supports loading and inference of ONNX models, and so does OpenVINO XAI.

```python
explainer = xai.Explainer(
model="model.onnx",
task=xai.Task.CLASSIFICATION
)
```
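
Since the `model` argument is typed as `ov.Model | str | PathLike` in this change, a `pathlib.Path` should be accepted as well. Below is a minimal sketch under that assumption; the file name is hypothetical and not part of this diff.

```python
from pathlib import Path

import openvino_xai as xai

# Path objects implement os.PathLike, so this should behave the same as
# passing the plain string "model.onnx" (hypothetical file name).
explainer = xai.Explainer(
    model=Path("model.onnx"),
    task=xai.Task.CLASSIFICATION
)
```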

## Basic usage: Auto mode

@@ -202,7 +234,7 @@ explanation = explainer(
explain_method=xai.Method.RECIPROCAM, # ReciproCAM is the default XAI method for CNNs
label_names=voc_labels,
target_explain_labels=[11, 14], # target classes to explain, also ['dog', 'person'] is a valid input, since label_names are provided
overlay=True, # False by default
)

# Save saliency maps
@@ -253,7 +285,7 @@ explanation = explainer(
explain_mode=ExplainMode.BLACKBOX,
target_explain_labels=[11, 14], # target classes to explain
# target_explain_labels=-1, # explain all classes
overlay=True, # False by default
num_masks=1000, # kwargs for the RISE algorithm
)

@@ -303,6 +335,6 @@ pytest tests/test_classification.py

# Run a bunch of classification examples
# All outputs will be stored in the corresponding output directory
python examples/run_classification.py .data/otx_models/mlc_mobilenetv3_large_voc.xml
tests/assets/cheetah_person.jpg --output output
```
19 changes: 12 additions & 7 deletions openvino_xai/explainer/explainer.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from enum import Enum
from os import PathLike
from typing import Callable, List, Mapping, Tuple

import numpy as np
@@ -42,10 +43,11 @@ class Explainer:
Explainer creates methods and uses them to generate explanations.
Usage:
-explanation = explainer_object(data, explanation_parameters)
+>>> explainer = Explainer("model.xml", Task.CLASSIFICATION)
+>>> explanation = explainer(data)
-:param model: Original model.
-:type model: ov.Model
+:param model: Original model object, OpenVINO IR file (.xml) or ONNX file (.onnx).
+:type model: ov.Model | str | PathLike
:param task: Type of the task: CLASSIFICATION or DETECTION.
:type task: Task
:param preprocess_fn: Preprocessing function, identity function by default
@@ -67,7 +69,7 @@ class Explainer:

def __init__(
self,
-model: ov.Model,
+model: ov.Model | str | PathLike,
task: Task,
preprocess_fn: Callable[[np.ndarray], np.ndarray] = IdentityPreprocessFN(),
postprocess_fn: Callable[[Mapping], np.ndarray] = None,
@@ -78,6 +80,9 @@ def __init__(
device_name: str = "CPU",
**kwargs,
) -> None:
if isinstance(model, (str, PathLike)):
model = ov.Core().read_model(model)

self.model = model
self.compiled_model: ov.CompiledModel | None = None
self.task = task
@@ -131,7 +136,7 @@ def create_method(self, explain_mode: ExplainMode, task: Task) -> MethodBase:
def __call__(
self,
data: np.ndarray,
-targets: np.ndarray | List[int | str] | int | str,
+targets: np.ndarray | List[int | str] | int | str = -1,
original_input_image: np.ndarray | None = None,
label_names: List[str] | None = None,
output_size: Tuple[int, int] | None = None,
@@ -159,7 +164,7 @@ def __call__(
def explain(
self,
data: np.ndarray,
-targets: np.ndarray | List[int | str] | int | str,
+targets: np.ndarray | List[int | str] | int | str = -1,
original_input_image: np.ndarray | None = None,
label_names: List[str] | None = None,
output_size: Tuple[int, int] | None = None,
@@ -176,7 +181,7 @@ def explain(
:param data: Input image.
:type data: np.ndarray
:param targets: List of custom labels to explain, optional. Can be list of integer indices (int),
-or list of names (str) from label_names.
+or list of names (str) from label_names. Defaults to -1, which means all class labels.
:type targets: np.ndarray | List[int | str] | int | str
:param label_names: List of all label names.
:type label_names: List[str] | None
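
Taken together, the constructor change (a str or PathLike is read via ov.Core().read_model) and the new `targets` default allow the call pattern sketched below. This is an illustrative sketch, not part of the diff; the model file name and the 224x224 input size are assumptions borrowed from the README example above.

```python
import cv2
import numpy as np
import openvino_xai as xai

# The path is now read internally with ov.Core().read_model().
explainer = xai.Explainer(model="mobilenet_v3.xml", task=xai.Task.CLASSIFICATION)

# Manual preprocessing, since no preprocess_fn is passed (identity by default).
image = cv2.imread("tests/assets/cheetah_person.jpg")
image = np.expand_dims(cv2.resize(image, dsize=(224, 224)), 0)

# With the new default, omitting `targets` is equivalent to targets=-1,
# i.e. saliency maps are generated for all class labels.
explanation = explainer(image)
```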
1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ dev = [
"pytest",
"pytest-cov",
"pytest-csv",
"pytest-mock",
"pytest-xdist",
"pre-commit==3.7.0",
"addict",
139 changes: 104 additions & 35 deletions tests/intg/test_classification_timm.py
@@ -157,25 +157,11 @@ def test_classification_white_box(self, model_id, dump_maps=False):
if model_id in NON_SUPPORTED_BY_WB_MODELS:
pytest.skip(reason="Not supported yet")

-timm_model, model_cfg = self.get_timm_model(model_id)
+model_dir = self.data_dir / "timm_models" / "converted_models"
+timm_model, model_cfg = self.get_timm_model(model_id, model_dir)
self.update_report("report_wb.csv", model_id)

-ir_path = self.data_dir / "timm_models" / "converted_models" / model_id / "model_fp32.xml"
-if not ir_path.is_file():
-output_model_dir = self.output_dir / "timm_models" / "converted_models" / model_id
-output_model_dir.mkdir(parents=True, exist_ok=True)
-ir_path = output_model_dir / "model_fp32.xml"
-input_size = [1] + list(timm_model.default_cfg["input_size"])
-dummy_tensor = torch.rand(input_size)
-onnx_path = output_model_dir / "model_fp32.onnx"
-set_dynamic_batch = model_id in LIMITED_DIVERSE_SET_OF_VISION_TRANSFORMER_MODELS
-export_to_onnx(timm_model, onnx_path, dummy_tensor, set_dynamic_batch)
-self.update_report("report_wb.csv", model_id, "True")
-export_to_ir(onnx_path, output_model_dir / "model_fp32.xml")
-self.update_report("report_wb.csv", model_id, "True", "True")
-else:
-self.update_report("report_wb.csv", model_id, "True", "True")

+ir_path = model_dir / model_id / "model_fp32.xml"
model = ov.Core().read_model(ir_path)

if model_id in LIMITED_DIVERSE_SET_OF_CNN_MODELS:
@@ -254,24 +240,17 @@ def test_classification_white_box(self, model_id, dump_maps=False):
@pytest.mark.parametrize("model_id", TEST_MODELS)
def test_classification_black_box(self, model_id, dump_maps=False):
# self.check_for_saved_map(model_id, "timm_models/maps_bb/")
if model_id == "nest_tiny_jx.goog_in1k":
pytest.xfail(
"[cpu]reshape: the shape of input data (1.27.27.192) conflicts with the reshape pattern (0.2.14.2.14.192)"
)

-timm_model, model_cfg = self.get_timm_model(model_id)
+model_dir = self.data_dir / "timm_models" / "converted_models"
+timm_model, model_cfg = self.get_timm_model(model_id, model_dir)
self.update_report("report_bb.csv", model_id)

-onnx_path = self.data_dir / "timm_models" / "converted_models" / model_id / "model_fp32.onnx"
-if not onnx_path.is_file():
-output_model_dir = self.output_dir / "timm_models" / "converted_models" / model_id
-output_model_dir.mkdir(parents=True, exist_ok=True)
-onnx_path = output_model_dir / "model_fp32.onnx"
-input_size = [1] + list(timm_model.default_cfg["input_size"])
-dummy_tensor = torch.rand(input_size)
-onnx_path = output_model_dir / "model_fp32.onnx"
-export_to_onnx(timm_model, onnx_path, dummy_tensor, False)
-self.update_report("report_bb.csv", model_id, "True", "True")
-else:
-self.update_report("report_bb.csv", model_id, "True", "True")

-model = ov.Core().read_model(onnx_path)
+ir_path = model_dir / model_id / "model_fp32.xml"
+model = ov.Core().read_model(ir_path)

mean_values = [(item * 255) for item in model_cfg["mean"]]
scale_values = [(item * 255) for item in model_cfg["std"]]
@@ -330,7 +309,7 @@ def test_classification_black_box(self, model_id, dump_maps=False):
],
)
# @pytest.mark.parametrize("model_id", TEST_MODELS)
-def test_ovc_ir_insertion(self, model_id):
+def test_ovc_model_white_box(self, model_id):
if model_id in NON_SUPPORTED_BY_WB_MODELS:
pytest.skip(reason="Not supported yet")

@@ -339,7 +318,8 @@ def test_ovc_ir_insertion(self, model_id):
reason="RuntimeError: Couldn't get TorchScript module by tracing."
) # Torch -> OV conversion error

-timm_model, model_cfg = self.get_timm_model(model_id)
+model_dir = self.data_dir / "timm_models" / "converted_models"
+timm_model, model_cfg = self.get_timm_model(model_id, model_dir)
input_size = list(timm_model.default_cfg["input_size"])
dummy_tensor = torch.rand([1] + input_size)
model = ov.convert_model(timm_model, example_input=dummy_tensor, input=(ov.PartialShape([-1] + input_size),))
@@ -383,6 +363,84 @@ def test_ovc_ir_insertion(self, model_id):
assert explanation.shape[-1] > 1 and explanation.shape[-2] > 1
print(f"{model_id}: Generated classification saliency maps with shape {explanation.shape}.")

@pytest.mark.parametrize(
"model_id",
[
"resnet18.a1_in1k",
"vit_tiny_patch16_224.augreg_in21k", # Downloads last month 15,345
],
)
@pytest.mark.parametrize(
"explain_mode",
[
ExplainMode.WHITEBOX,
ExplainMode.BLACKBOX,
],
)
@pytest.mark.parametrize(
"model_format",
["xml", "onnx"],
)
def test_model_format(self, model_id, explain_mode, model_format):
if (
model_id == "vit_tiny_patch16_224.augreg_in21k"
and explain_mode == ExplainMode.WHITEBOX
and model_format == "onnx"
):
pytest.xfail(
"RuntimeError: Failed to insert XAI into the model -> Only two outputs of the between block Add node supported, but got 3."
)

model_dir = self.data_dir / "timm_models" / "converted_models"
timm_model, model_cfg = self.get_timm_model(model_id, model_dir)
model_path = model_dir / model_id / ("model_fp32." + model_format)

mean_values = [(item * 255) for item in model_cfg["mean"]]
scale_values = [(item * 255) for item in model_cfg["std"]]
preprocess_fn = get_preprocess_fn(
change_channel_order=True,
input_size=model_cfg["input_size"][1:],
mean=mean_values,
std=scale_values,
hwc_to_chw=True,
)

explain_method = None
postprocess_fn = None
if explain_mode == ExplainMode.WHITEBOX:
if model_id in LIMITED_DIVERSE_SET_OF_CNN_MODELS:
explain_method = Method.RECIPROCAM
elif model_id in LIMITED_DIVERSE_SET_OF_VISION_TRANSFORMER_MODELS:
explain_method = Method.VITRECIPROCAM
else:
raise ValueError
else: # explain_mode == ExplainMode.BLACKBOX:
postprocess_fn = get_postprocess_fn()

explainer = Explainer(
model=model_path,
task=Task.CLASSIFICATION,
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
explain_mode=explain_mode,
explain_method=explain_method,
embed_scaling=False,
)

target_class = self.supported_num_classes[model_cfg["num_classes"]]
image = cv2.imread("tests/assets/cheetah_person.jpg")
explanation = explainer(
image,
targets=[target_class],
resize=False,
colormap=False,
num_masks=2, # minimal iterations for feature test
)

assert explanation is not None
assert explanation.shape[-1] > 1 and explanation.shape[-2] > 1
print(f"{model_id}: Generated classification saliency maps with shape {explanation.shape}.")

def check_for_saved_map(self, model_id, directory):
for target in self.supported_num_classes.values():
map_name = model_id + "_target_" + str(target) + ".jpg"
@@ -396,14 +454,25 @@ def check_for_saved_map(self, model_id, directory):
self.clear_cache()
pytest.skip(f"Model {model_id} is already explained.")

-def get_timm_model(self, model_id):
+def get_timm_model(self, model_id: str, model_dir: Path):
timm_model = timm.create_model(model_id, in_chans=3, pretrained=True, checkpoint_path="")
timm_model.eval()
model_cfg = timm_model.default_cfg
num_classes = model_cfg["num_classes"]
if num_classes not in self.supported_num_classes:
self.clear_cache()
pytest.skip(f"Number of model classes {num_classes} unknown")
model_dir = model_dir / model_id
ir_path = model_dir / "model_fp32.xml"
if not ir_path.is_file():
model_dir.mkdir(parents=True, exist_ok=True)
ir_path = model_dir / "model_fp32.xml"
input_size = [1] + list(timm_model.default_cfg["input_size"])
dummy_tensor = torch.rand(input_size)
onnx_path = model_dir / "model_fp32.onnx"
set_dynamic_batch = model_id in LIMITED_DIVERSE_SET_OF_VISION_TRANSFORMER_MODELS
export_to_onnx(timm_model, onnx_path, dummy_tensor, set_dynamic_batch)
export_to_ir(onnx_path, model_dir / "model_fp32.xml")
return timm_model, model_cfg

@staticmethod
