diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b92a60e..9ab6120a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,21 +4,26 @@ ### Summary +* Support OpenVINO IR (.xml) / ONNX (.onnx) model file for `Explainer` model * Enable AISE: Adaptive Input Sampling for Explanation of Black-box Models. ### What's Changed +* Use OVC converted models in func tests by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/44 +* Update CodeCov action by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/46 +* Refactor OpenVINO imports by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/45 +* Support OV IR / ONNX model file for Explainer by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/47 +* Try CNN -> ViT assumption for IR insertion by @goodsong81 in https://github.com/openvinotoolkit/openvino_xai/pull/48 * Enable AISE: Adaptive Input Sampling for Explanation of Black-box Models by @negvet in https://github.com/openvinotoolkit/openvino_xai/pull/49 ### Known Issues -* OpenVINO IR branch insertion not working for models converted directly from torch models in https://github.com/openvinotoolkit/openvino_xai/issues/26 * Runtime error from ONNX / OpenVINO IR models while conversion or inference for XAI in https://github.com/openvinotoolkit/openvino_xai/issues/29 * Models not supported by white box XAI methods in https://github.com/openvinotoolkit/openvino_xai/issues/30 ### New Contributors -* +* N/A --- diff --git a/openvino_xai/methods/factory.py b/openvino_xai/methods/factory.py index 199729d3..07c6680d 100644 --- a/openvino_xai/methods/factory.py +++ b/openvino_xai/methods/factory.py @@ -107,14 +107,21 @@ def create_classification_method( if explain_method is None or explain_method == Method.RECIPROCAM: logger.info("Using ReciproCAM method (for CNNs).") - return ReciproCAM( - model, - preprocess_fn, - target_layer, - embed_scaling, - device_name, - **kwargs, - ) + try: + return ReciproCAM( + model, + preprocess_fn, + target_layer, + embed_scaling, + device_name, + **kwargs, + ) + except Exception as e: + if explain_method is None: + logger.info(f"Not successfull due to '{e}'. Trying another methods.") + explain_method = Method.VITRECIPROCAM + else: + raise e if explain_method == Method.VITRECIPROCAM: logger.info("Using ViTReciproCAM method (for vision transformers).") return ViTReciproCAM( diff --git a/tests/func/test_classification_timm_full.py b/tests/func/test_classification_timm_full.py index 6baa7913..e7bc6a5f 100644 --- a/tests/func/test_classification_timm_full.py +++ b/tests/func/test_classification_timm_full.py @@ -29,48 +29,6 @@ TEST_MODELS = timm.list_models(pretrained=True) -CNN_MODELS = [ - "bat_resnext", - "convnext", - "cs3", - "cs3darknet", - "darknet", - "densenet", - "dla", - "dpn", - "efficientnet", - "ese_vovnet", - "fbnet", - "gernet", - "ghostnet", - "hardcorenas", - "hrnet", - "inception", - "lcnet", - "legacy_", - "mixnet", - "mnasnet", - "mobilenet", - "nasnet", - "regnet", - "repvgg", - "res2net", - "res2next", - "resnest", - "resnet", - "resnext", - "rexnet", - "selecsls", - "semnasnet", - "senet", - "seresnext", - "spnasnet", - "tinynet", - "tresnet", - "vgg", - "xception", -] - SUPPORTED_BUT_FAILED_BY_BB_MODELS = {} NOT_SUPPORTED_BY_BB_MODELS = { @@ -82,7 +40,7 @@ "dm_nfnet": "openvino._pyopenvino.GeneralFailure: Check 'false' failed at src/frontends/onnx/frontend/src/frontend.cpp:144", "eca_nfnet": "openvino._pyopenvino.GeneralFailure: Check 'false' failed at src/frontends/onnx/frontend/src/frontend.cpp:144", "eva_giant": "RuntimeError: The serialized model is larger than the 2GiB limit imposed by the protobuf library.", - "halo": "torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator Unfold, input size not accessible.", + # "halo": "torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator Unfold, input size not accessible.", "nf_regnet": "RuntimeError: Exception from src/inference/src/cpp/core.cpp:90: Training mode of BatchNormalization is not supported.", "nf_resnet": "RuntimeError: Exception from src/inference/src/cpp/core.cpp:90: Training mode of BatchNormalization is not supported.", "nfnet_l0": "RuntimeError: Exception from src/inference/src/cpp/core.cpp:90: Training mode of BatchNormalization is not supported.", @@ -110,6 +68,7 @@ **NOT_SUPPORTED_BY_BB_MODELS, # Killed on WB "beit_large_patch16_512": "Failed to allocate 94652825600 bytes of memory", + "convmixer_1536_20": "OOM Killed", "eva_large_patch14_336": "OOM Killed", "eva02_base_patch14_448": "OOM Killed", "eva02_large_patch14_448": "OOM Killed", @@ -127,32 +86,23 @@ "xcit_small_12_p8_384": "OOM Killed", "xcit_small_24_p8_384": "OOM Killed", # Not expected to work for now - "botnet26t_256": "Only two outputs of the between block Add node supported, but got 1", - "caformer": "One (and only one) of the nodes has to be Add type. But got MVN and Multiply.", "cait_": "Cannot create an empty Constant. Please provide valid data.", "coat_": "Only two outputs of the between block Add node supported, but got 1.", - "coatn": "Cannot find output backbone_node in auto mode, please provide target_layer.", - "convmixer": "Cannot find output backbone_node in auto mode, please provide target_layer.", "crossvit": "One (and only one) of the nodes has to be Add type. But got StridedSlice and StridedSlice.", - "davit": "Only two outputs of the between block Add node supported, but got 1.", - "eca_botnext": "Only two outputs of the between block Add node supported, but got 1.", - "edgenext": "Only two outputs of the between block Add node supported, but got 1", - "efficientformer": "Cannot find output backbone_node in auto mode.", - "focalnet": "Cannot find output backbone_node in auto mode, please provide target_layer.", - "gcvit": "Cannot find output backbone_node in auto mode, please provide target_layer.", - "levit_": "Cannot find output backbone_node in auto mode, please provide target_layer.", - "maxvit": "Cannot find output backbone_node in auto mode, please provide target_layer.", - "maxxvit": "Cannot find output backbone_node in auto mode, please provide target_layer.", - "mobilevitv2": "Cannot find output backbone_node in auto mode, please provide target_layer.", - "nest_": "Cannot find output backbone_node in auto mode, please provide target_layer.", - "poolformer": "Cannot find output backbone_node in auto mode, please provide target_layer.", - "sebotnet": "Only two outputs of the between block Add node supported, but got 1.", + # work in CNN mode -> "davit": "Only two outputs of the between block Add node supported, but got 1.", + # work in CNN mode -> "efficientformer": "Cannot find output backbone_node in auto mode.", + # work in CNN mode -> "focalnet": "Cannot find output backbone_node in auto mode, please provide target_layer.", + # work in CNN mode -> "gcvit": "Cannot find output backbone_node in auto mode, please provide target_layer.", + "levit_": "Check 'TRShape::merge_into(output_shape, in_copy)' failed", + # work in CNN mode -> "maxvit": "Cannot find output backbone_node in auto mode, please provide target_layer.", + # work in CNN mode -> "maxxvit": "Cannot find output backbone_node in auto mode, please provide target_layer.", + # work in CNN mode -> "mobilevitv2": "Cannot find output backbone_node in auto mode, please provide target_layer.", + # work in CNN mode -> "nest_": "Cannot find output backbone_node in auto mode, please provide target_layer.", + # work in CNN mode -> "poolformer": "Cannot find output backbone_node in auto mode, please provide target_layer.", "sequencer2d": "Cannot find output backbone_node in auto mode, please provide target_layer.", "tnt_s_patch16_224": "Only two outputs of the between block Add node supported, but got 1.", - "tresnet": "Batch shape of the output should be dynamic, but it is static.", "twins": "One (and only one) of the nodes has to be Add type. But got ShapeOf and Transpose.", - "visformer": "Cannot find output backbone_node in auto mode, please provide target_layer", - "vit_relpos_base_patch32_plus_rpn_256": "Check 'TRShape::merge_into(output_shape, in_copy)' failed", + # work in CNN mode -> "visformer": "Cannot find output backbone_node in auto mode, please provide target_layer", "vit_relpos_medium_patch16_rpn_224": "ValueError in openvino_xai/methods/white_box/recipro_cam.py:215", } @@ -184,11 +134,7 @@ def test_classification_white_box(self, model_id, dump_maps=False): if failed_model in model_id: pytest.xfail(reason=SUPPORTED_BUT_FAILED_BY_WB_MODELS[failed_model]) - explain_method = Method.VITRECIPROCAM - for cnn_model in CNN_MODELS: - if cnn_model in model_id: - explain_method = Method.RECIPROCAM - break + explain_method = None timm_model, model_cfg = self.get_timm_model(model_id) input_size = list(timm_model.default_cfg["input_size"]) diff --git a/tests/unit/methods/white_box/test_create_method.py b/tests/unit/methods/test_factory.py similarity index 70% rename from tests/unit/methods/white_box/test_create_method.py rename to tests/unit/methods/test_factory.py index acb9d7a7..987b4bda 100644 --- a/tests/unit/methods/white_box/test_create_method.py +++ b/tests/unit/methods/test_factory.py @@ -5,11 +5,13 @@ import openvino as ov import pytest +from pytest_mock import MockerFixture from openvino_xai.common.parameters import Method, Task from openvino_xai.common.utils import retrieve_otx_model -from openvino_xai.explainer.utils import get_preprocess_fn -from openvino_xai.methods.factory import WhiteBoxMethodFactory +from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn +from openvino_xai.methods.black_box.aise import AISE +from openvino_xai.methods.factory import BlackBoxMethodFactory, WhiteBoxMethodFactory from openvino_xai.methods.white_box.activation_map import ActivationMap from openvino_xai.methods.white_box.det_class_probability_map import ( DetClassProbabilityMap, @@ -75,6 +77,40 @@ def test_create_wb_cls_vit_method(fxt_data_root: Path): assert isinstance(explain_method, ViTReciproCAM) +def test_create_wb_cls_guess_method(mocker: MockerFixture): + model = mocker.MagicMock() + # method=None -> ReciproCAM fail -> ViTReciproCAM + recipro_cam = mocker.patch("openvino_xai.methods.factory.ReciproCAM", side_effect=Exception("DUMMY REASON")) + vit_recipro_cam = mocker.patch("openvino_xai.methods.factory.ViTReciproCAM") + explain_method = WhiteBoxMethodFactory.create_method( + task=Task.CLASSIFICATION, + model=model, + preprocess_fn=PREPROCESS_FN, + explain_method=None, + ) + vit_recipro_cam.assert_called() + # method=ReciproCAM -> ReciproCAM fail -> Exception + recipro_cam = mocker.patch("openvino_xai.methods.factory.ReciproCAM", side_effect=Exception("DUMMY REASON")) + vit_recipro_cam = mocker.patch("openvino_xai.methods.factory.ViTReciproCAM") + with pytest.raises(Exception) as exc_info: + explain_method = WhiteBoxMethodFactory.create_method( + task=Task.CLASSIFICATION, + model=model, + preprocess_fn=PREPROCESS_FN, + explain_method=Method.RECIPROCAM, + ) + vit_recipro_cam.assert_not_called() + assert str(exc_info.value) == "DUMMY REASON" + + +def test_create_bb_cls_vit_method(fxt_data_root: Path): + retrieve_otx_model(fxt_data_root, VIT_MODEL) + model_path = fxt_data_root / "otx_models" / (VIT_MODEL + ".xml") + model_vit = ov.Core().read_model(model_path) + explain_method = BlackBoxMethodFactory.create_method(Task.CLASSIFICATION, model_vit, get_postprocess_fn()) + assert isinstance(explain_method, AISE) + + def test_create_wb_det_cnn_method(fxt_data_root: Path): retrieve_otx_model(fxt_data_root, DEFAULT_DET_MODEL) model_path = fxt_data_root / "otx_models" / (DEFAULT_DET_MODEL + ".xml")