Skip to content

Commit

Permalink
Move insert_xai into separate functional api module (#11)
Browse files Browse the repository at this point in the history
* Move insert_xai into separate functional api module

* fix import in samples

* fix isort

* Fix import

* fix classification sample
  • Loading branch information
negvet authored Jun 11, 2024
1 parent 36fb14f commit bed7e4c
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 58 deletions.
2 changes: 1 addition & 1 deletion openvino_xai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
"""


from .api.api import insert_xai
from .common.parameters import Method, Task
from .explainer.explainer import Explainer
from .inserter import insert_xai

__all__ = ["Explainer", "insert_xai", "Method", "Task"]
10 changes: 10 additions & 0 deletions openvino_xai/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""
Finctional API.
"""
from openvino_xai.api.api import insert_xai

__all__ = [
"insert_xai",
]
51 changes: 51 additions & 0 deletions openvino_xai/api/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import openvino.runtime as ov

from openvino_xai.common.parameters import Task
from openvino_xai.common.utils import IdentityPreprocessFN, has_xai, logger
from openvino_xai.inserter.parameters import InsertionParameters
from openvino_xai.methods.factory import WhiteBoxMethodFactory


def insert_xai(
model: ov.Model,
task: Task,
insertion_parameters: InsertionParameters | None = None,
) -> ov.Model:
"""
Function that inserts XAI branch into IR.
Usage:
model_xai = openvino_xai.insert_xai(model, task=Task.CLASSIFICATION)
:param model: Original IR.
:type model: ov.Model | str
:param task: Type of the task: CLASSIFICATION or DETECTION.
:type task: Task
:param insertion_parameters: Insertion parameters that parametrize white-box method,
that will be inserted into the model graph (optional).
:type insertion_parameters: InsertionParameters
:return: IR with XAI branch.
"""

if has_xai(model):
logger.info("Provided IR model already contains XAI branch, return it as-is.")
return model

method = WhiteBoxMethodFactory.create_method(
task=task,
model=model,
preprocess_fn=IdentityPreprocessFN(),
insertion_parameters=insertion_parameters,
prepare_model=False,
)

model_xai = method.prepare_model(load_model=False)

if not has_xai(model_xai):
raise RuntimeError("Insertion of the XAI branch into the model was not successful.")
logger.info("Insertion of the XAI branch into the model was successful.")

return model_xai
2 changes: 0 additions & 2 deletions openvino_xai/inserter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
"""
Interface for inserting XAI branch into OV IR.
"""
from openvino_xai.inserter.inserter import insert_xai
from openvino_xai.inserter.parameters import (
ClassificationInsertionParameters,
DetectionInsertionParameters,
InsertionParameters,
)

__all__ = [
"insert_xai",
"InsertionParameters",
"ClassificationInsertionParameters",
"DetectionInsertionParameters",
Expand Down
52 changes: 1 addition & 51 deletions openvino_xai/inserter/inserter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,57 +4,7 @@
import openvino.runtime as ov
from openvino.preprocess import PrePostProcessor

from openvino_xai import Task
from openvino_xai.common.utils import (
SALIENCY_MAP_OUTPUT_NAME,
IdentityPreprocessFN,
has_xai,
logger,
)
from openvino_xai.inserter.parameters import InsertionParameters


def insert_xai(
model: ov.Model,
task: Task,
insertion_parameters: InsertionParameters | None = None,
) -> ov.Model:
"""
Function that inserts XAI branch into IR.
Usage:
model_xai = openvino_xai.insert_xai(model, task=Task.CLASSIFICATION)
:param model: Original IR.
:type model: ov.Model | str
:param task: Type of the task: CLASSIFICATION or DETECTION.
:type task: Task
:param insertion_parameters: Insertion parameters that parametrize white-box method,
that will be inserted into the model graph (optional).
:type insertion_parameters: InsertionParameters
:return: IR with XAI branch.
"""
from openvino_xai.methods.factory import WhiteBoxMethodFactory

if has_xai(model):
logger.info("Provided IR model already contains XAI branch, return it as-is.")
return model

method = WhiteBoxMethodFactory.create_method(
task=task,
model=model,
preprocess_fn=IdentityPreprocessFN(),
insertion_parameters=insertion_parameters,
prepare_model=False,
)

model_xai = method.prepare_model(load_model=False)

if not has_xai(model_xai):
raise RuntimeError("Insertion of the XAI branch into the model was not successful.")
logger.info("Insertion of the XAI branch into the model was successful.")

return model_xai
from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME


def insert_xai_branch_into_model(
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import openvino.runtime as ov
import pytest

import openvino_xai as xai
import openvino_xai.api.api as xai
from openvino_xai.common.parameters import Method, Task
from openvino_xai.common.utils import has_xai, retrieve_otx_model
from openvino_xai.explainer.explainer import Explainer
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/common/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import openvino.runtime as ov

from openvino_xai.api.api import insert_xai
from openvino_xai.common.parameters import Task
from openvino_xai.common.utils import has_xai, retrieve_otx_model
from openvino_xai.inserter.inserter import insert_xai
from tests.integration.test_classification import DEFAULT_CLS_MODEL

DARA_DIR = Path(".data")
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/explanation/test_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import openvino.runtime as ov
import pytest

from openvino_xai.api.api import insert_xai
from openvino_xai.common.parameters import Task
from openvino_xai.common.utils import retrieve_otx_model
from openvino_xai.explainer.explainer import Explainer
Expand All @@ -16,7 +17,6 @@
TargetExplainGroup,
)
from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn
from openvino_xai.inserter.inserter import insert_xai
from openvino_xai.inserter.parameters import ClassificationInsertionParameters
from tests.unit.explanation.test_explanation_utils import VOC_NAMES

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/insertion/test_insertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from openvino import runtime as ov

from openvino_xai import insert_xai
from openvino_xai.api.api import insert_xai
from openvino_xai.common.parameters import Method, Task
from openvino_xai.common.utils import has_xai, retrieve_otx_model
from openvino_xai.inserter.parameters import DetectionInsertionParameters
Expand Down

0 comments on commit bed7e4c

Please sign in to comment.