Fix func / perf tests & doc (#77)
* Update arch image

* Fix perf efficiency test

* Pass optional input size for torch insert_xai()

* Fix spatial dim detection

* Update skipped models

* Fix pre-commit

* Fix minor

* Fix unit test

* Update version to 1.1.0
goodsong81 authored Sep 30, 2024
1 parent e1f0fb8 commit 3a58970
Showing 8 changed files with 53 additions and 85 deletions.
2 changes: 1 addition & 1 deletion docs/source/_static/ovxai-architecture.svg
(Image file; diff not rendered.)
6 changes: 3 additions & 3 deletions examples/run_torch_onnx.py
@@ -63,7 +63,7 @@ def run_insert_xai_torch(args: list[str]):
logger.info(f"Torch model prediction: classes ({probs.shape[-1]}) -> label ({label}) -> prob ({probs[0, label]})")

# Insert XAI head
-    model_xai: torch.nn.Module = insert_xai(model, Task.CLASSIFICATION)
+    model_xai: torch.nn.Module = insert_xai(model, Task.CLASSIFICATION, input_size=input_size)  # Optional input size arg to help insertion

# Torch XAI model inference
model_xai.eval()
@@ -121,7 +121,7 @@ def run_insert_xai_torch_to_onnx(args: list[str]):
image_norm = image_norm[None, :] # CxHxW -> 1xCxHxW

# Insert XAI head
-    model_xai: torch.nn.Module = insert_xai(model, Task.CLASSIFICATION)
+    model_xai: torch.nn.Module = insert_xai(model, Task.CLASSIFICATION, input_size=input_size)

# ONNX model conversion
model_path = Path(args.output_dir) / "model.onnx"
@@ -184,7 +184,7 @@ def run_insert_xai_torch_to_openvino(args: list[str]):
image_norm = image_norm[None, :] # CxHxW -> 1xCxHxW

# Insert XAI head
-    model_xai: torch.nn.Module = insert_xai(model, Task.CLASSIFICATION)
+    model_xai: torch.nn.Module = insert_xai(model, Task.CLASSIFICATION, input_size=input_size)

# OpenVINO model conversion
ov_model = ov.convert_model(
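All three hunks above thread the same new keyword through insert_xai. A minimal usage sketch, not part of the diff: it assumes the top-level insert_xai export and the Task enum used throughout these examples, with an illustrative timm model and 224x224 input.

# --- usage sketch (not part of the diff) ---
import timm
import torch

from openvino_xai import insert_xai  # assumed top-level export
from openvino_xai.common.parameters import Task

model = timm.create_model("resnet18.a1_in1k", pretrained=True)
input_size = (224, 224)  # assumption: the HxW the model expects

# input_size is forwarded to the white-box torch method, where it sizes the
# dummy tensor used for feature-module auto-detection (see torch.py below).
model_xai: torch.nn.Module = insert_xai(model, Task.CLASSIFICATION, input_size=input_size)
model_xai.eval()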
24 changes: 19 additions & 5 deletions openvino_xai/methods/white_box/torch.py
@@ -47,11 +47,13 @@ def __init__(
embed_scaling: bool = True,
device_name: str = "CPU",
prepare_model: bool = True,
+        input_size: tuple[int, int] = (224, 224),  # For fixed input size models like ViT
**kwargs,
):
super().__init__(model=model, preprocess_fn=preprocess_fn, device_name=device_name)
self._target_layer = target_layer
self._embed_scaling = embed_scaling
+        self._input_size = input_size

if prepare_model:
self.prepare_model()
@@ -66,6 +68,7 @@ def prepare_model(self, load_model: bool = True) -> torch.nn.Module:
return self._model_compiled

model = copy.deepcopy(self._model)
+        model.eval()

# Feature
if self._target_layer:
@@ -78,7 +81,6 @@ def prepare_model(self, load_model: bool = True) -> torch.nn.Module:
model.register_forward_hook(self._output_hook)

setattr(model, "has_xai", True)
-        model.eval()

if load_model:
self._model_compiled = model
@@ -119,17 +121,26 @@ def _find_feature_module_auto(self, module: torch.nn.Module) -> torch.nn.Module:
self._feature_module = None
self._num_modules = 0

+        def _has_spatial_dim(shape: torch.Size):
+            if len(shape) != 4:  # expect BxCxHxW
+                return False
+            if shape[2] <= 1 or shape[3] <= 1:  # expect H > 1 and W > 1
+                return False
+            if shape[1] <= shape[2] or shape[1] <= shape[3]:  # expect H < C and W < C for feature maps
+                return False
+            return True

def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None:
if isinstance(output, torch.Tensor):
module.index = self._num_modules
self._num_modules += 1
shape = output.shape
-            if len(shape) == 4 and shape[2] > 1 and shape[3] > 1:
+            if _has_spatial_dim(shape):
self._feature_module = module

global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
try:
-            module.forward(torch.zeros((1, 3, 128, 128)))
+            module.forward(torch.zeros((1, 3, *self._input_size)))
finally:
global_hook_handle.remove()
if self._feature_module is None:
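The heuristic added above can be exercised in isolation. A standalone sketch mirroring _has_spatial_dim (renamed here), with illustrative shapes:

# --- standalone sketch (not part of the diff) ---
import torch

def has_spatial_dim(shape: torch.Size) -> bool:
    """Accept only BxCxHxW activations with H, W > 1 and C larger than both."""
    if len(shape) != 4:
        return False
    if shape[2] <= 1 or shape[3] <= 1:
        return False
    if shape[1] <= shape[2] or shape[1] <= shape[3]:
        return False
    return True

assert has_spatial_dim(torch.Size([1, 512, 7, 7]))        # CNN feature map
assert not has_spatial_dim(torch.Size([1, 197, 768]))     # 3D ViT token tensor
assert not has_spatial_dim(torch.Size([1, 3, 224, 224]))  # raw image: C < H, W

The channel-dominance check is what keeps the raw input (and early high-resolution layers) from being mistaken for the final feature map.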
@@ -274,10 +285,13 @@ def __init__(
def _find_feature_module_auto(self, module: torch.nn.Module) -> torch.nn.Module:
"""Detect feature module in the model by finding the 3rd last LayerNorm module."""
self._feature_module = None
-        norm_modules = [m for _, m in module.named_modules() if isinstance(m, torch.nn.LayerNorm)]
+        norm_modules = []
+        for name, sub_module in module.named_modules():
+            if "LayerNorm" in type(sub_module).__name__ or "BatchNorm" in type(sub_module).__name__ or "norm1" in name:
+                norm_modules.append(sub_module)

if len(norm_modules) < 3:
-            raise RuntimeError("Feature modules with LayerNorm are less than 3 in the torch model")
+            raise RuntimeError("Fewer than 3 LayerNorm or BatchNorm feature modules found in the torch model")

self._feature_module = norm_modules[-3]
return self._feature_module
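For the ViT path, the scan above now matches normalization modules by class name rather than isinstance, so subclassed norms (e.g. timm variants of LayerNorm and BatchNorm) are also caught. A standalone sketch of the same matching rule against a toy model, illustrative only:

# --- standalone sketch (not part of the diff) ---
import torch

def find_norm_modules(model: torch.nn.Module) -> list[torch.nn.Module]:
    # Name-based matching, mirroring the diff above.
    found = []
    for name, sub in model.named_modules():
        type_name = type(sub).__name__
        if "LayerNorm" in type_name or "BatchNorm" in type_name or "norm1" in name:
            found.append(sub)
    return found

toy = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.BatchNorm2d(8),
    torch.nn.Flatten(),
    torch.nn.LayerNorm(8),
)
assert len(find_norm_modules(toy)) == 2  # BatchNorm2d and LayerNorm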
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "openvino_xai"
-version = "1.1.0rc0"
+version = "1.1.0"
dependencies = [
"openvino-dev==2024.4",
"opencv-python",
65 changes: 6 additions & 59 deletions tests/func/test_torch_onnx_timm_full.py
@@ -21,61 +21,12 @@

TEST_MODELS = timm.list_models(pretrained=True)

-NOT_SUPPORTED_BY_BB_MODELS = {
+SKIPPED_MODELS = {
"repvit": "urllib.error.HTTPError: HTTP Error 404: Not Found",
"tf_efficientnet_cc": "torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of convolution for kernel of unknown shape.",
"vit_base_r50_s16_224.orig_in21k": "RuntimeError: Error(s) in loading state_dict for VisionTransformer",
"vit_gigantic_patch16_224_ijepa.in22k": "RuntimeError: shape '[1, 13, 13, -1]' is invalid for input of size 274560",
"vit_huge_patch14_224.orig_in21k": "RuntimeError: Error(s) in loading state_dict for VisionTransformer",
"vit_large_patch32_224.orig_in21k": "RuntimeError: Error(s) in loading state_dict for VisionTransformer",
"volo_": "RuntimeError: Exception from src/core/src/dimension.cpp:227: Cannot get length of dynamic dimension",
}

-SUPPORTED_BUT_FAILED_BY_WB_MODELS = {
-    "swin": "Only two outputs of the between block Add node supported, but got 1. Try to use black-box.",
-    "vit_base_patch16_rpn_224": "Number of normalization outputs > 1",
-    "vit_relpos_medium_patch16_rpn_224": "ValueError in openvino_xai/methods/white_box/recipro_cam.py:215",
-}
-
-NOT_SUPPORTED_BY_WB_MODELS = {
-    **NOT_SUPPORTED_BY_BB_MODELS,
-    # Killed on WB
-    "beit_large_patch16_512": "Failed to allocate 94652825600 bytes of memory",
-    "convmixer_1536_20": "OOM Killed",
-    "eva_large_patch14_336": "OOM Killed",
-    "eva02_base_patch14_448": "OOM Killed",
-    "eva02_large_patch14_448": "OOM Killed",
-    "mobilevit_": "Segmentation fault",
-    "mobilevit_xxs": "Segmentation fault",
-    "mvitv2_base.fb_in1k": "Segmentation fault",
-    "mvitv2_large": "OOM Killed",
-    "mvitv2_small": "Segmentation fault",
-    "mvitv2_tiny": "Segmentation fault",
-    "pit_": "Segmentation fault",
-    "pvt_": "Segmentation fault",
-    "tf_efficientnet_l2.ns_jft_in1k": "OOM Killed",
-    "xcit_large": "Failed to allocate 81581875200 bytes of memory",
-    "xcit_medium_24_p8_384": "OOM Killed",
-    "xcit_small_12_p8_384": "OOM Killed",
-    "xcit_small_24_p8_384": "OOM Killed",
-    # Not expected to work for now
-    "cait_": "Cannot create an empty Constant. Please provide valid data.",
-    "coat_": "Only two outputs of the between block Add node supported, but got 1.",
-    "crossvit": "One (and only one) of the nodes has to be Add type. But got StridedSlice and StridedSlice.",
-    # work in CNN mode -> "davit": "Only two outputs of the between block Add node supported, but got 1.",
-    # work in CNN mode -> "efficientformer": "Cannot find output backbone_node in auto mode.",
-    # work in CNN mode -> "focalnet": "Cannot find output backbone_node in auto mode, please provide target_layer.",
-    # work in CNN mode -> "gcvit": "Cannot find output backbone_node in auto mode, please provide target_layer.",
-    "levit_": "Check 'TRShape::merge_into(output_shape, in_copy)' failed",
-    # work in CNN mode -> "maxvit": "Cannot find output backbone_node in auto mode, please provide target_layer.",
-    # work in CNN mode -> "maxxvit": "Cannot find output backbone_node in auto mode, please provide target_layer.",
-    # work in CNN mode -> "mobilevitv2": "Cannot find output backbone_node in auto mode, please provide target_layer.",
-    # work in CNN mode -> "nest_": "Cannot find output backbone_node in auto mode, please provide target_layer.",
-    # work in CNN mode -> "poolformer": "Cannot find output backbone_node in auto mode, please provide target_layer.",
-    "sequencer2d": "Cannot find output backbone_node in auto mode, please provide target_layer.",
-    "tnt_s_patch16_224": "Only two outputs of the between block Add node supported, but got 1.",
-    "twins": "One (and only one) of the nodes has to be Add type. But got ShapeOf and Transpose.",
-    # work in CNN mode -> "visformer": "Cannot find output backbone_node in auto mode, please provide target_layer",
-}


@@ -89,15 +40,10 @@ def setup(self, fxt_clear_cache):
self.clear_cache_converted_models = fxt_clear_cache

@pytest.mark.parametrize("model_id", TEST_MODELS)
-    # @pytest.mark.parametrize("model_id", ["resnet18.a1_in1k"])
def test_insert_xai(self, model_id, fxt_output_root: Path):
-        # for skipped_model in NOT_SUPPORTED_BY_WB_MODELS.keys():
-        #     if skipped_model in model_id:
-        #         pytest.skip(reason=NOT_SUPPORTED_BY_WB_MODELS[skipped_model])
-
-        # for failed_model in SUPPORTED_BUT_FAILED_BY_WB_MODELS.keys():
-        #     if failed_model in model_id:
-        #         pytest.xfail(reason=SUPPORTED_BUT_FAILED_BY_WB_MODELS[failed_model])
+        for skipped_model in SKIPPED_MODELS.keys():
+            if skipped_model in model_id:
+                pytest.skip(reason=SKIPPED_MODELS[skipped_model])

# Load Torch model from timm
model = timm.create_model(model_id, in_chans=3, pretrained=True)
@@ -114,7 +60,7 @@ def test_insert_xai(self, model_id, fxt_output_root: Path):
image_norm = image_norm[None, :] # CxHxW -> 1xCxHxW

# Insert XAI head
-        model_xai: torch.nn.Module = insert_xai(model, Task.CLASSIFICATION)
+        model_xai: torch.nn.Module = insert_xai(model, Task.CLASSIFICATION, input_size=input_size)

# Torch XAI model inference
model_xai.eval()
@@ -164,6 +110,7 @@ def test_insert_xai(self, model_id, fxt_output_root: Path):
assert saliency_map.dtype == np.uint8

# Clean up
+        model_path.unlink()
self.clear_cache()

def clear_cache(self):
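The consolidated skip list relies on substring matching, so one key such as "volo_" blankets a whole model family. A sketch of the pattern with illustrative entries:

# --- standalone sketch (not part of the diff) ---
import pytest

SKIPPED_MODELS = {
    "repvit": "urllib.error.HTTPError: HTTP Error 404: Not Found",
    "volo_": "RuntimeError: Cannot get length of dynamic dimension",
}

def maybe_skip(model_id: str) -> None:
    # Substring match against the timm model ID, as in test_insert_xai above.
    for key, reason in SKIPPED_MODELS.items():
        if key in model_id:
            pytest.skip(reason=reason)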
18 changes: 10 additions & 8 deletions tests/perf/conftest.py
@@ -27,10 +27,11 @@ def pytest_addoption(parser: pytest.Parser):
"Defaults to 10.",
)
parser.addoption(
"--num-masks",
"--preset",
action="store",
default=5000,
help="Number of masks for black box methods." "Defaults to 5000.",
default="speed",
choices=("speed", "balance", "quality"),
help="Efficiency preset for blackbox methods. Defaults to 'speed'.",
)
parser.addoption(
"--dataset-root",
@@ -57,13 +58,13 @@ def fxt_num_repeat(request: pytest.FixtureRequest) -> int:


@pytest.fixture(scope="session")
-def fxt_num_masks(request: pytest.FixtureRequest) -> int:
-    """Number of masks for black box methods."""
-    num_masks = int(request.config.getoption("--num-masks"))
-    msg = f"{num_masks = }"
+def fxt_preset(request: pytest.FixtureRequest) -> str:
+    """Efficiency preset for black box methods."""
+    preset = request.config.getoption("--preset")
+    msg = f"{preset = }"
     log.info(msg)
     print(msg)
-    return num_masks
+    return preset


@pytest.fixture(scope="session")
@@ -148,6 +149,7 @@ def fxt_perf_summary(
"Method.RECIPROCAM": "RECIPROCAM",
"Method.VITRECIPROCAM": "RECIPROCAM",
"Method.RISE": "RISE",
"Method.AISE": "AISE",
}
)
raw_data.to_csv(fxt_output_root / "perf-raw-all.csv", index=False)
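With --num-masks replaced by --preset, a typical run of the perf suite selects one of the three presets. A sketch of the invocation; the tests/perf path follows this conftest's location, and the --num-repeat flag is an assumption based on the fxt_num_repeat fixture above:

# --- usage sketch (not part of the diff) ---
import pytest

# Equivalent to running from the repository root:
#   pytest tests/perf --preset balance --num-repeat 10
pytest.main(["tests/perf", "--preset", "balance", "--num-repeat", "10"])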
19 changes: 12 additions & 7 deletions tests/perf/test_performance.py → tests/perf/test_efficiency.py
@@ -14,6 +14,7 @@
from openvino_xai.common.parameters import Method, Task
from openvino_xai.explainer.explainer import Explainer, ExplainMode
from openvino_xai.explainer.utils import get_postprocess_fn, get_preprocess_fn
+from openvino_xai.methods.black_box.base import Preset
from openvino_xai.utils.model_export import export_to_onnx
from tests.perf.perf_tests_utils import (
clear_cache,
@@ -38,7 +39,7 @@
)


-class TestPerfClassificationTimm:
+class TestEfficiency:
clear_cache_converted_models = False
clear_cache_hf_models = False
supported_num_classes = {
@@ -122,12 +123,15 @@ def test_classification_white_box(self, model_id: str, fxt_num_repeat: int, fxt_
records.append(record)

df = pd.DataFrame(records)
-        df.to_csv(self.output_dir / f"perf-raw-wb-{model_id}.csv")
+        df.to_csv(self.output_dir / f"perf-raw-wb-{model_id}-{explain_method}.csv")

clear_cache(self.data_dir, self.cache_dir, self.clear_cache_converted_models, self.clear_cache_hf_models)

@pytest.mark.parametrize("model_id", TEST_MODELS)
-    def test_classification_black_box(self, model_id, fxt_num_repeat: int, fxt_num_masks: int, fxt_tags: dict):
+    @pytest.mark.parametrize("method", [Method.AISE, Method.RISE])
+    def test_classification_black_box(
+        self, model_id: str, method: Method, fxt_num_repeat: int, fxt_preset: str, fxt_tags: dict
+    ):
timm_model, model_cfg = get_timm_model(model_id, self.supported_num_classes)

onnx_path = self.data_dir / "timm_models" / "converted_models" / model_id / "model_fp32.onnx"
@@ -163,9 +167,9 @@ def test_classification_black_box(self, model_id, fxt_num_repeat: int, fxt_num_m

record = fxt_tags.copy()
record["model"] = model_id
record["method"] = Method.RISE
record["method"] = method
record["seed"] = seed
record["num_masks"] = fxt_num_masks
record["preset"] = fxt_preset

start_time = time()

@@ -175,14 +179,15 @@ def test_classification_black_box(
preprocess_fn=preprocess_fn,
postprocess_fn=postprocess_fn,
explain_mode=ExplainMode.BLACKBOX, # defaults to AUTO
+            explain_method=method,  # defaults to AISE
)
explanation = explainer(
image,
targets=[target_class],
resize=True,
colormap=True,
overlay=True,
-            num_masks=fxt_num_masks,  # kwargs of the RISE algo
+            preset=Preset(fxt_preset),  # kwargs of the black box algo
)

explain_time = time() - start_time
@@ -194,6 +199,6 @@ def test_classification_black_box(
records.append(record)

df = pd.DataFrame(records)
-        df.to_csv(self.output_dir / f"perf-raw-bb-{model_id}.csv", index=False)
+        df.to_csv(self.output_dir / f"perf-raw-bb-{model_id}-{method}.csv", index=False)

clear_cache(self.data_dir, self.cache_dir, self.clear_cache_converted_models, self.clear_cache_hf_models)
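Putting the pieces together, a condensed sketch of the updated black-box call path; not part of the diff, the model path, image, and pre/post-processing stubs are placeholders, and the Explainer keyword set is assumed from the lines visible above:

# --- usage sketch (not part of the diff) ---
import numpy as np
import openvino as ov

from openvino_xai.common.parameters import Method, Task
from openvino_xai.explainer.explainer import Explainer, ExplainMode
from openvino_xai.methods.black_box.base import Preset

model = ov.Core().read_model("model_fp32.onnx")  # placeholder model path

def preprocess_fn(image: np.ndarray) -> np.ndarray:
    # Placeholder: HxWxC uint8 image -> 1xCxHxW float batch.
    return image.transpose(2, 0, 1)[None].astype(np.float32)

def postprocess_fn(output) -> np.ndarray:
    # Placeholder: pull a 1xN score array out of the raw inference result.
    return output[0]

explainer = Explainer(
    model=model,
    task=Task.CLASSIFICATION,
    preprocess_fn=preprocess_fn,
    postprocess_fn=postprocess_fn,
    explain_mode=ExplainMode.BLACKBOX,  # defaults to AUTO
    explain_method=Method.AISE,         # defaults to AISE
)
explanation = explainer(
    np.zeros((224, 224, 3), dtype=np.uint8),  # placeholder image
    targets=[0],
    overlay=True,
    preset=Preset("speed"),  # "speed" | "balance" | "quality", matching --preset
)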
2 changes: 1 addition & 1 deletion tests/unit/methods/white_box/test_torch.py
@@ -40,6 +40,7 @@ def __init__(self, num_classes: int = 2):
torch.nn.Identity(),
torch.nn.Identity(),
torch.nn.Identity(),
+            torch.nn.LazyConv2d(256, (1, 1)),
)
self.neck = torch.nn.AdaptiveAvgPool2d((1, 1))
self.output = torch.nn.LazyLinear(out_features=num_classes)
@@ -123,7 +124,6 @@ def _output_hook(
assert type(output) == dict
prediction = output["prediction"]
saliency_maps = output[SALIENCY_MAP_OUTPUT_NAME]
-    assert np.all(saliency_maps == prediction)


def test_prepare_model():
