Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
negvet committed Aug 20, 2024
1 parent b7879b8 commit 738eeb9
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 10 deletions.
6 changes: 1 addition & 5 deletions openvino_xai/methods/black_box/aise.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def generate_saliency_map( # type: ignore
self.target_box = boxes[target]
self.target_label = labels[target]

if self.target_box[0] == self.target_box[2] or self.target_box[1] == self.target_box[3]:
if self.target_box[0] >= self.target_box[2] or self.target_box[1] >= self.target_box[3]:
continue

self.kernel_params_hist = collections.defaultdict(list)
Expand Down Expand Up @@ -430,10 +430,6 @@ def _process_box(self, padding_coef: float = 0.5) -> None:
x_to = min(target_box_scaled[2] + box_width * padding_coef, 1.0)
y_from = max(target_box_scaled[1] - box_height * padding_coef, 0.0)
y_to = min(target_box_scaled[3] + box_height * padding_coef, 1.0)

if x_from < x_to or y_from < y_to:
raise ValueError("Bounding box data is incorrect.")

self.bounds = Bounds([x_from, y_from], [x_to, y_to])

def _get_loss(self, data_perturbed: np.array) -> float:
Expand Down
114 changes: 109 additions & 5 deletions tests/intg/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from openvino_xai.common.utils import retrieve_otx_model
from openvino_xai.explainer.explainer import Explainer, ExplainMode
from openvino_xai.explainer.utils import get_preprocess_fn
from openvino_xai.methods.factory import WhiteBoxMethodFactory
from openvino_xai.methods.black_box.aise import AISEDetection
from openvino_xai.methods.factory import BlackBoxMethodFactory, WhiteBoxMethodFactory
from openvino_xai.methods.white_box.det_class_probability_map import (
DetClassProbabilityMap,
)
Expand Down Expand Up @@ -57,6 +58,7 @@
MODELS = list(MODEL_CONFIGS.keys())

DEFAULT_DET_MODEL = "det_mobilenetv2_atss_bccd"
FAST_DET_MODEL = "det_mobilenetv2_atss_bccd"

EXPLAIN_ALL_CLASSES = [
True,
Expand All @@ -66,11 +68,11 @@

class TestDetWB:
"""
Tests detection models in WB mode.
Tests detection models in white-box mode.
"""

image = cv2.imread("tests/assets/blood.jpg")
_ref_sal_maps_reciprocam = {
_ref_sal_maps = {
"det_mobilenetv2_atss_bccd": np.array([222, 243, 232, 229, 221, 217, 237, 246, 252, 255], dtype=np.uint8),
"det_mobilenetv2_ssd_bccd": np.array([83, 93, 61, 48, 110, 109, 78, 128, 158, 111], dtype=np.uint8),
"det_yolox_bccd": np.array([17, 13, 15, 60, 94, 52, 61, 47, 8, 40], dtype=np.uint8),
Expand Down Expand Up @@ -120,7 +122,7 @@ def test_detclassprobabilitymap(self, model_name, embed_scaling, explain_all_cla
assert explanation.saliency_map[0].shape == self._sal_map_size

actual_sal_vals = explanation.saliency_map[0][0, :10].astype(np.int16)
ref_sal_vals = self._ref_sal_maps_reciprocam[model_name].astype(np.uint8)
ref_sal_vals = self._ref_sal_maps[model_name].astype(np.uint8)
if embed_scaling:
# Reference values generated with embed_scaling=True
assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)
Expand Down Expand Up @@ -198,7 +200,7 @@ def test_two_sequential_norms(self):
)

actual_sal_vals = explanation.saliency_map[0][0, :10].astype(np.int16)
ref_sal_vals = self._ref_sal_maps_reciprocam[DEFAULT_DET_MODEL].astype(np.uint8)
ref_sal_vals = self._ref_sal_maps[DEFAULT_DET_MODEL].astype(np.uint8)
# Reference values generated with embed_scaling=True
assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)

Expand Down Expand Up @@ -234,6 +236,108 @@ def get_default_model(self):
return model


class TestDetBB:
"""
Tests detection models in black-box mode.
"""

image = cv2.imread("tests/assets/blood.jpg")

@pytest.fixture(autouse=True)
def setup(self, fxt_data_root):
self.data_dir = fxt_data_root

@pytest.mark.parametrize("model_name", MODELS)
def test_aisedetection(self, model_name):
retrieve_otx_model(self.data_dir, model_name)
model_path = self.data_dir / "otx_models" / (model_name + ".xml")
model = ov.Core().read_model(model_path)

preprocess_fn = get_preprocess_fn(
input_size=MODEL_CONFIGS[model_name].input_size,
hwc_to_chw=True,
)
explainer = Explainer(
model=model,
task=Task.DETECTION,
preprocess_fn=preprocess_fn,
postprocess_fn=self.postprocess_fn,
explain_mode=ExplainMode.BLACKBOX, # defaults to AUTO
num_iterations_per_kernel=5,
divisors=[5],
)

target_list = [1]
explanation = explainer(
self.image,
targets=target_list,
resize=False,
colormap=False,
)
assert explanation is not None

target_class = target_list[0]
assert target_class in explanation.saliency_map
assert len(explanation.saliency_map) == len(target_list)
assert explanation.saliency_map[target_class].ndim == 2

def test_detection_visualizing(self):
model = self.get_default_model()

preprocess_fn = get_preprocess_fn(
input_size=MODEL_CONFIGS[FAST_DET_MODEL].input_size,
hwc_to_chw=True,
)
explainer = Explainer(
model=model,
task=Task.DETECTION,
preprocess_fn=preprocess_fn,
postprocess_fn=self.postprocess_fn,
explain_mode=ExplainMode.BLACKBOX, # defaults to AUTO
num_iterations_per_kernel=5,
divisors=[5],
)

target_list = [1]
explanation = explainer(
self.image,
targets=target_list,
overlay=True,
)
assert explanation is not None
assert explanation.shape == (480, 640, 3)

target_class = target_list[0]
assert len(explanation.saliency_map) == len(target_list)
assert target_class in explanation.saliency_map

def test_create_aise_detection_method(self):
"""Test create_white_box_detection_method."""
model = self.get_default_model()

preprocess_fn = get_preprocess_fn(
input_size=MODEL_CONFIGS[FAST_DET_MODEL].input_size,
hwc_to_chw=True,
)
det_xai_method = BlackBoxMethodFactory.create_method(
Task.DETECTION,
model,
preprocess_fn,
)
assert isinstance(det_xai_method, AISEDetection)

def get_default_model(self):
retrieve_otx_model(self.data_dir, FAST_DET_MODEL)
model_path = self.data_dir / "otx_models" / (FAST_DET_MODEL + ".xml")
model = ov.Core().read_model(model_path)
return model

@staticmethod
def postprocess_fn(x) -> np.ndarray:
"""Returns boxes, scores, labels."""
return x["boxes"][0][:, :4], x["boxes"][0][:, 4], x["labels"][0]


class TestExample:
"""Test sanity of examples/run_detection.py."""

Expand Down

0 comments on commit 738eeb9

Please sign in to comment.