-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move insert_xai into separate functional api module (#11)
* Move insert_xai into separate functional api module * fix import in samples * fix isort * Fix import * fix classification sample
- Loading branch information
Showing
9 changed files
with
67 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright (C) 2023-2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
""" | ||
Finctional API. | ||
""" | ||
from openvino_xai.api.api import insert_xai | ||
|
||
__all__ = [ | ||
"insert_xai", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import openvino.runtime as ov | ||
|
||
from openvino_xai.common.parameters import Task | ||
from openvino_xai.common.utils import IdentityPreprocessFN, has_xai, logger | ||
from openvino_xai.inserter.parameters import InsertionParameters | ||
from openvino_xai.methods.factory import WhiteBoxMethodFactory | ||
|
||
|
||
def insert_xai( | ||
model: ov.Model, | ||
task: Task, | ||
insertion_parameters: InsertionParameters | None = None, | ||
) -> ov.Model: | ||
""" | ||
Function that inserts XAI branch into IR. | ||
Usage: | ||
model_xai = openvino_xai.insert_xai(model, task=Task.CLASSIFICATION) | ||
:param model: Original IR. | ||
:type model: ov.Model | str | ||
:param task: Type of the task: CLASSIFICATION or DETECTION. | ||
:type task: Task | ||
:param insertion_parameters: Insertion parameters that parametrize white-box method, | ||
that will be inserted into the model graph (optional). | ||
:type insertion_parameters: InsertionParameters | ||
:return: IR with XAI branch. | ||
""" | ||
|
||
if has_xai(model): | ||
logger.info("Provided IR model already contains XAI branch, return it as-is.") | ||
return model | ||
|
||
method = WhiteBoxMethodFactory.create_method( | ||
task=task, | ||
model=model, | ||
preprocess_fn=IdentityPreprocessFN(), | ||
insertion_parameters=insertion_parameters, | ||
prepare_model=False, | ||
) | ||
|
||
model_xai = method.prepare_model(load_model=False) | ||
|
||
if not has_xai(model_xai): | ||
raise RuntimeError("Insertion of the XAI branch into the model was not successful.") | ||
logger.info("Insertion of the XAI branch into the model was successful.") | ||
|
||
return model_xai |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters