diff --git a/openvino_xai/__init__.py b/openvino_xai/__init__.py index 7664848e..1982a2c6 100644 --- a/openvino_xai/__init__.py +++ b/openvino_xai/__init__.py @@ -5,8 +5,8 @@ """ +from .api.api import insert_xai from .common.parameters import Method, Task from .explainer.explainer import Explainer -from .inserter import insert_xai __all__ = ["Explainer", "insert_xai", "Method", "Task"] diff --git a/openvino_xai/api/__init__.py b/openvino_xai/api/__init__.py new file mode 100644 index 00000000..e391405c --- /dev/null +++ b/openvino_xai/api/__init__.py @@ -0,0 +1,10 @@ +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +""" +Finctional API. +""" +from openvino_xai.api.api import insert_xai + +__all__ = [ + "insert_xai", +] diff --git a/openvino_xai/api/api.py b/openvino_xai/api/api.py new file mode 100644 index 00000000..d1e89924 --- /dev/null +++ b/openvino_xai/api/api.py @@ -0,0 +1,51 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import openvino.runtime as ov + +from openvino_xai.common.parameters import Task +from openvino_xai.common.utils import IdentityPreprocessFN, has_xai, logger +from openvino_xai.inserter.parameters import InsertionParameters +from openvino_xai.methods.factory import WhiteBoxMethodFactory + + +def insert_xai( + model: ov.Model, + task: Task, + insertion_parameters: InsertionParameters | None = None, +) -> ov.Model: + """ + Function that inserts XAI branch into IR. + + Usage: + model_xai = openvino_xai.insert_xai(model, task=Task.CLASSIFICATION) + + :param model: Original IR. + :type model: ov.Model | str + :param task: Type of the task: CLASSIFICATION or DETECTION. + :type task: Task + :param insertion_parameters: Insertion parameters that parametrize white-box method, + that will be inserted into the model graph (optional). + :type insertion_parameters: InsertionParameters + :return: IR with XAI branch. + """ + + if has_xai(model): + logger.info("Provided IR model already contains XAI branch, return it as-is.") + return model + + method = WhiteBoxMethodFactory.create_method( + task=task, + model=model, + preprocess_fn=IdentityPreprocessFN(), + insertion_parameters=insertion_parameters, + prepare_model=False, + ) + + model_xai = method.prepare_model(load_model=False) + + if not has_xai(model_xai): + raise RuntimeError("Insertion of the XAI branch into the model was not successful.") + logger.info("Insertion of the XAI branch into the model was successful.") + + return model_xai diff --git a/openvino_xai/inserter/__init__.py b/openvino_xai/inserter/__init__.py index 109e8a05..1a5fca3a 100644 --- a/openvino_xai/inserter/__init__.py +++ b/openvino_xai/inserter/__init__.py @@ -3,7 +3,6 @@ """ Interface for inserting XAI branch into OV IR. """ -from openvino_xai.inserter.inserter import insert_xai from openvino_xai.inserter.parameters import ( ClassificationInsertionParameters, DetectionInsertionParameters, @@ -11,7 +10,6 @@ ) __all__ = [ - "insert_xai", "InsertionParameters", "ClassificationInsertionParameters", "DetectionInsertionParameters", diff --git a/openvino_xai/inserter/inserter.py b/openvino_xai/inserter/inserter.py index f34d0fee..b7f5c1d4 100644 --- a/openvino_xai/inserter/inserter.py +++ b/openvino_xai/inserter/inserter.py @@ -4,57 +4,7 @@ import openvino.runtime as ov from openvino.preprocess import PrePostProcessor -from openvino_xai import Task -from openvino_xai.common.utils import ( - SALIENCY_MAP_OUTPUT_NAME, - IdentityPreprocessFN, - has_xai, - logger, -) -from openvino_xai.inserter.parameters import InsertionParameters - - -def insert_xai( - model: ov.Model, - task: Task, - insertion_parameters: InsertionParameters | None = None, -) -> ov.Model: - """ - Function that inserts XAI branch into IR. - - Usage: - model_xai = openvino_xai.insert_xai(model, task=Task.CLASSIFICATION) - - :param model: Original IR. - :type model: ov.Model | str - :param task: Type of the task: CLASSIFICATION or DETECTION. - :type task: Task - :param insertion_parameters: Insertion parameters that parametrize white-box method, - that will be inserted into the model graph (optional). - :type insertion_parameters: InsertionParameters - :return: IR with XAI branch. - """ - from openvino_xai.methods.factory import WhiteBoxMethodFactory - - if has_xai(model): - logger.info("Provided IR model already contains XAI branch, return it as-is.") - return model - - method = WhiteBoxMethodFactory.create_method( - task=task, - model=model, - preprocess_fn=IdentityPreprocessFN(), - insertion_parameters=insertion_parameters, - prepare_model=False, - ) - - model_xai = method.prepare_model(load_model=False) - - if not has_xai(model_xai): - raise RuntimeError("Insertion of the XAI branch into the model was not successful.") - logger.info("Insertion of the XAI branch into the model was successful.") - - return model_xai +from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME def insert_xai_branch_into_model( diff --git a/tests/integration/test_classification.py b/tests/integration/test_classification.py index cbbfe972..f6ac8a67 100644 --- a/tests/integration/test_classification.py +++ b/tests/integration/test_classification.py @@ -8,7 +8,7 @@ import openvino.runtime as ov import pytest -import openvino_xai as xai +import openvino_xai.api.api as xai from openvino_xai.common.parameters import Method, Task from openvino_xai.common.utils import has_xai, retrieve_otx_model from openvino_xai.explainer.explainer import Explainer diff --git a/tests/unit/common/test_utils.py b/tests/unit/common/test_utils.py index 70f2fdc1..d5bec2a6 100644 --- a/tests/unit/common/test_utils.py +++ b/tests/unit/common/test_utils.py @@ -5,9 +5,9 @@ import openvino.runtime as ov +from openvino_xai.api.api import insert_xai from openvino_xai.common.parameters import Task from openvino_xai.common.utils import has_xai, retrieve_otx_model -from openvino_xai.inserter.inserter import insert_xai from tests.integration.test_classification import DEFAULT_CLS_MODEL DARA_DIR = Path(".data") diff --git a/tests/unit/explanation/test_explainer.py b/tests/unit/explanation/test_explainer.py index a1407bb7..ea2119b0 100644 --- a/tests/unit/explanation/test_explainer.py +++ b/tests/unit/explanation/test_explainer.py @@ -7,6 +7,7 @@ import openvino.runtime as ov import pytest +from openvino_xai.api.api import insert_xai from openvino_xai.common.parameters import Task from openvino_xai.common.utils import retrieve_otx_model from openvino_xai.explainer.explainer import Explainer @@ -16,7 +17,6 @@ TargetExplainGroup, ) from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn -from openvino_xai.inserter.inserter import insert_xai from openvino_xai.inserter.parameters import ClassificationInsertionParameters from tests.unit.explanation.test_explanation_utils import VOC_NAMES diff --git a/tests/unit/insertion/test_insertion.py b/tests/unit/insertion/test_insertion.py index ca46fc2a..c0a3e722 100644 --- a/tests/unit/insertion/test_insertion.py +++ b/tests/unit/insertion/test_insertion.py @@ -4,7 +4,7 @@ import pytest from openvino import runtime as ov -from openvino_xai import insert_xai +from openvino_xai.api.api import insert_xai from openvino_xai.common.parameters import Method, Task from openvino_xai.common.utils import has_xai, retrieve_otx_model from openvino_xai.inserter.parameters import DetectionInsertionParameters