From 04d5d526d4adf5744dbdb471bb22579abba59eb4 Mon Sep 17 00:00:00 2001
From: Songki Choi
Date: Tue, 10 Sep 2024 13:38:09 +0900
Subject: [PATCH] Auto-detect feature layer for PyTorch models (#64)

* Support basic substring match
* Add basic 4D feature map layer detection
* Add N-last LayerNorm module detection for ViTs
* Add integration test for auto layer detection
* Apply review comments
---
 openvino_xai/methods/white_box/torch.py    | 86 +++++++++++++++++++---
 tests/intg/test_classification_timm.py     | 10 ++-
 tests/unit/methods/test_factory.py         |  6 +-
 tests/unit/methods/white_box/test_torch.py | 71 ++++++++++++++++--
 4 files changed, 152 insertions(+), 21 deletions(-)

diff --git a/openvino_xai/methods/white_box/torch.py b/openvino_xai/methods/white_box/torch.py
index a1e8e80f..2163dca8 100644
--- a/openvino_xai/methods/white_box/torch.py
+++ b/openvino_xai/methods/white_box/torch.py
@@ -41,25 +41,37 @@ def __init__(
         target_layer: str | None = None,
         embed_scaling: bool = True,
         device_name: str = "CPU",
+        prepare_model: bool = True,
         **kwargs,
     ):
         super().__init__(model=model, preprocess_fn=preprocess_fn, device_name=device_name)
         self._target_layer = target_layer
         self._embed_scaling = embed_scaling
 
+        if prepare_model:
+            self.prepare_model()
+
     def prepare_model(self, load_model: bool = True) -> torch.nn.Module:
         """Return XAI inserted model."""
         if has_xai(self._model):
             if load_model:
                 self._model_compiled = self._model
             return self._model
+        if self._model_compiled is not None:
+            return self._model_compiled
 
         model = copy.deepcopy(self._model)
+
         # Feature
-        feature_layer = model.get_submodule(self._target_layer)
-        feature_layer.register_forward_hook(self._feature_hook)
+        if self._target_layer:
+            feature_module = self._find_feature_module_by_name(model, self._target_layer)
+        else:
+            feature_module = self._find_feature_module_auto(model)
+        feature_module.register_forward_hook(self._feature_hook)
+
         # Output
         model.register_forward_hook(self._output_hook)
+
         setattr(model, "has_xai", True)
         model.eval()
 
@@ -86,11 +98,51 @@ def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping:
                 output[name] = data.numpy(force=True)
         return output
 
+    def _find_feature_module_by_name(self, model: torch.nn.Module, target_name: str) -> torch.nn.Module:
+        """Search for the last module whose name contains the target substring."""
+        target_module = None
+        for name, module in model.named_modules():
+            if target_name in name:
+                target_module = module
+        if target_module is None:
+            raise ValueError(f"{target_name} is not found in the torch model")
+        return target_module
+
+    def _find_feature_module_auto(self, module: torch.nn.Module) -> torch.nn.Module:
+        """Detect the feature module in the model."""
+        # Find the last layer that outputs a 4D tensor during a temporary forward pass
+        self._feature_module = None
+        self._num_modules = 0
+
+        def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None:
+            if isinstance(output, torch.Tensor):
+                module.index = self._num_modules
+                self._num_modules += 1
+                shape = output.shape
+                if len(shape) == 4 and shape[2] > 1 and shape[3] > 1:
+                    self._feature_module = module
+
+        global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
+        try:
+            module.forward(torch.zeros((1, 3, 128, 128)))
+        finally:
+            global_hook_handle.remove()
+        if self._feature_module is None:
+            raise RuntimeError("Feature module with 4D output is not found in the torch model")
+        if self._feature_module.index / self._num_modules < 0.5:  # Heuristic check for ViT-like architectures
+            raise RuntimeError(
+                f"Modules with 4D output end in the early half of the model: {100 * self._feature_module.index / self._num_modules}%"
+            )
+
+        return self._feature_module
+
     def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
+        """Store the feature map for saliency map generation."""
         self._feature_map = output
         return output
 
     def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
+        """Return the model prediction with a placeholder saliency map."""
         return {
             "prediction": output,
             SALIENCY_MAP_OUTPUT_NAME: torch.empty_like(output),
@@ -137,8 +189,8 @@ class TorchReciproCAM(TorchWhiteBoxMethod):
     """
 
     def __init__(self, *args, optimize_gap: bool = False, **kwargs):
-        super().__init__(*args, **kwargs)
         self._optimize_gap = optimize_gap
+        super().__init__(*args, **kwargs)
 
     def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
         """feature_maps -> vertical stack of feature_maps + mosaic_feature_maps."""
@@ -153,16 +205,17 @@ def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tens
         return torch.cat(feature_maps)
 
     def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
-        batch_size, _, h, w = self._feature_shape
-        num_classes = output.shape[1]
-        predictions = output[:batch_size]
-        saliency_maps = output[batch_size:]
-        saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes])
-        saliency_maps = saliency_maps.transpose(1, 2)  # BxHWxC -> BxCxHW
+        """Split combined output B0xC into BxC prediction and BxCxHxW saliency map."""
+        batch_size, _, h, w = self._feature_shape  # B0xDxHxW
+        num_classes = output.shape[1]  # C
+        predictions = output[:batch_size]  # BxC
+        saliency_maps = output[batch_size:]  # BHWxC
+        saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes])  # BxHWxC
+        saliency_maps = saliency_maps.transpose(1, 2)  # BxCxHW
         if self._embed_scaling:
             saliency_maps = saliency_maps.reshape((batch_size * num_classes, h * w))
             saliency_maps = self._normalize_map(saliency_maps)
-            saliency_maps = saliency_maps.reshape([batch_size, num_classes, h, w])
+        saliency_maps = saliency_maps.reshape([batch_size, num_classes, h, w])  # BxCxHxW
         return {
             "prediction": predictions,
             SALIENCY_MAP_OUTPUT_NAME: saliency_maps,
@@ -209,9 +262,20 @@ def __init__(
         normalize: bool = True,
         **kwargs,
     ) -> None:
-        super().__init__(*args, **kwargs)
         self._use_gaussian = use_gaussian
         self._use_cls_token = use_cls_token
+        super().__init__(*args, **kwargs)
+
+    def _find_feature_module_auto(self, module: torch.nn.Module) -> torch.nn.Module:
+        """Detect the feature module in the model by finding the third-to-last LayerNorm module."""
+        self._feature_module = None
+        norm_modules = [m for _, m in module.named_modules() if isinstance(m, torch.nn.LayerNorm)]
+
+        if len(norm_modules) < 3:
+            raise RuntimeError("Fewer than 3 LayerNorm modules found in the torch model")
+
+        self._feature_module = norm_modules[-3]
+        return self._feature_module
 
     def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
         """feature_maps -> vertical stack of feature_maps + mosaic_feature_maps."""
diff --git a/tests/intg/test_classification_timm.py b/tests/intg/test_classification_timm.py
index d7667d61..5ce553b6 100644
--- a/tests/intg/test_classification_timm.py
+++ b/tests/intg/test_classification_timm.py
@@ -454,7 +454,8 @@ def test_model_format(self, model_id, explain_mode, model_format):
"deit_tiny_patch16_224.fb_in1k", ], ) - def test_torch_insert_xai_with_layer(self, model_id: str): + @pytest.mark.parametrize("detect", ["auto", "name"]) + def test_torch_insert_xai_with_layer(self, model_id: str, detect: str): xai_cfg = { "resnet18.a1_in1k": ("layer4", Method.RECIPROCAM), "efficientnet_b0.ra_in1k": ("bn2", Method.RECIPROCAM), @@ -465,6 +466,9 @@ def test_torch_insert_xai_with_layer(self, model_id: str): model_dir = self.data_dir / "timm_models" / "converted_models" model, model_cfg = self.get_timm_model(model_id, model_dir) + target_layer = xai_cfg[model_id][0] if detect == "name" else None + explain_method = xai_cfg[model_id][1] + image = cv2.imread("tests/assets/cheetah_person.jpg") image = cv2.resize(image, dsize=model_cfg["input_size"][1:]) image = cv2.cvtColor(image, code=cv2.COLOR_BGR2RGB) @@ -478,8 +482,8 @@ def test_torch_insert_xai_with_layer(self, model_id: str): xai_model: torch.nn.Module = insert_xai( model, task=Task.CLASSIFICATION, - target_layer=xai_cfg[model_id][0], - explain_method=xai_cfg[model_id][1], + target_layer=target_layer, + explain_method=explain_method, ) with torch.no_grad(): diff --git a/tests/unit/methods/test_factory.py b/tests/unit/methods/test_factory.py index dac63211..0a0b24d8 100644 --- a/tests/unit/methods/test_factory.py +++ b/tests/unit/methods/test_factory.py @@ -151,7 +151,7 @@ def test_create_wb_det_cnn_method(fxt_data_root: Path): assert str(exc_info.value) == "Requested explanation method abc is not implemented." -def test_create_torch_method(): +def test_create_torch_method(mocker: MockerFixture): model = {} with pytest.raises(ValueError): explain_method = BlackBoxMethodFactory.create_method(Task.CLASSIFICATION, model, get_postprocess_fn()) @@ -172,6 +172,10 @@ def test_create_torch_method(): Task.DETECTION, model, get_postprocess_fn(), target_layer="" ) + mocker.patch.object(torch_method.TorchActivationMap, "prepare_model") + mocker.patch.object(torch_method.TorchReciproCAM, "prepare_model") + mocker.patch.object(torch_method.TorchViTReciproCAM, "prepare_model") + model = torch.nn.Module() explain_method = WhiteBoxMethodFactory.create_method( Task.CLASSIFICATION, model, get_postprocess_fn(), explain_method=Method.ACTIVATIONMAP diff --git a/tests/unit/methods/white_box/test_torch.py b/tests/unit/methods/white_box/test_torch.py index fa374f6c..8bea7d4d 100644 --- a/tests/unit/methods/white_box/test_torch.py +++ b/tests/unit/methods/white_box/test_torch.py @@ -35,7 +35,12 @@ class DummyCNN(torch.nn.Module): def __init__(self, num_classes: int = 2): super().__init__() self.num_classes = num_classes - self.feature = torch.nn.Identity() + self.feature = torch.nn.Sequential( + torch.nn.Identity(), + torch.nn.Identity(), + torch.nn.Identity(), + torch.nn.Identity(), + ) self.neck = torch.nn.AdaptiveAvgPool2d((1, 1)) self.output = torch.nn.LazyLinear(out_features=num_classes) @@ -48,24 +53,46 @@ def forward(self, x: torch.Tensor): class DummyVIT(torch.nn.Module): - def __init__(self, num_classes: int = 2): + def __init__(self, num_classes: int = 2, dim: int = 3): super().__init__() self.num_classes = num_classes - self.feature = torch.nn.Identity() + self.dim = dim + self.pre = torch.nn.Sequential( + torch.nn.Identity(), + torch.nn.Identity(), + ) + self.feature = torch.nn.Sequential( + torch.nn.Identity(), + torch.nn.Identity(), + torch.nn.Identity(), + torch.nn.Identity(), + ) + self.norm1 = torch.nn.LayerNorm(dim) + self.norm2 = torch.nn.LayerNorm(dim) + self.norm3 = torch.nn.LayerNorm(dim) self.output = 
 
     def forward(self, x: torch.Tensor):
         b, c, h, w = x.shape
+        x = self.pre(x)
         x = x.reshape(b, c, h * w)
         x = x.transpose(1, 2)
         x = torch.cat([torch.rand((b, 1, c)), x], dim=1)
         x = self.feature(x)
+        x = x + self.norm1(x)
+        x = x + self.norm2(x)
+        x = x + self.norm3(x)
         x = self.output(x[:, 0])
         return torch.nn.functional.softmax(x, dim=1)
 
 
 def test_torch_method():
     model = DummyCNN()
+
+    with pytest.raises(ValueError):
+        method = TorchWhiteBoxMethod(model=model, target_layer="something_else")
+        model_xai = method.prepare_model()
+
     method = TorchWhiteBoxMethod(model=model, target_layer="feature")
     model_xai = method.prepare_model()
     assert has_xai(model_xai)
@@ -101,7 +128,9 @@ def _output_hook(
 
 def test_prepare_model():
     model = DummyCNN()
-    method = TorchWhiteBoxMethod(model=model, target_layer="feature")
+    method = TorchWhiteBoxMethod(model=model, target_layer="feature", prepare_model=False)
+    model_xai = method.prepare_model(load_model=False)
+    assert method._model_compiled is None
     model_xai = method.prepare_model(load_model=False)
     assert method._model_compiled is None
     assert model is not model_xai
@@ -116,6 +145,35 @@
     assert model_xai == model
 
 
+def test_detect_feature_layer():
+    model = DummyCNN()
+    method = TorchWhiteBoxMethod(model=model, target_layer=None)
+    model_xai = method.prepare_model()
+    assert has_xai(model_xai)
+    data = np.random.rand(1, 3, 5, 5)
+    output = method.model_forward(data)
+    assert type(output) == dict
+    assert method._feature_module is model_xai.feature
+    output = method.model_forward(data)
+    assert type(output) == dict  # still good for 2nd forward
+
+    model = DummyVIT()
+    with pytest.raises(RuntimeError):
+        # 4D feature map search should fail for ViTs
+        method = TorchWhiteBoxMethod(model=model, target_layer=None)
+
+    model = DummyVIT()
+    method = TorchViTReciproCAM(model=model, target_layer=None)
+    model_xai = method.prepare_model()
+    assert has_xai(model_xai)
+    data = np.random.rand(1, 3, 5, 5)
+    output = method.model_forward(data)
+    assert type(output) == dict
+    assert method._feature_module is model_xai.norm1
+    output = method.model_forward(data)
+    assert type(output) == dict  # still good for 2nd forward
+
+
 def test_activationmap() -> None:
     batch_size = 2
     num_classes = 3
@@ -156,13 +214,14 @@ def test_reciprocam(optimize_gap: bool) -> None:
 def test_vitreciprocam(use_gaussian: bool, use_cls_token: bool) -> None:
     batch_size = 2
     num_classes = 3
-    model = DummyVIT(num_classes=num_classes)
+    dim = 3
+    model = DummyVIT(num_classes=num_classes, dim=dim)
     method = TorchViTReciproCAM(
         model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token
     )
     model_xai = method.prepare_model()
    assert has_xai(model_xai)
-    data = np.random.rand(batch_size, 4, 5, 5)
+    data = np.random.rand(batch_size, dim, 5, 5)
     output = method.model_forward(data)
     assert type(output) == dict
     saliency_maps = output[SALIENCY_MAP_OUTPUT_NAME]
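
Usage note: with this change, target_layer may be omitted and the feature layer is auto-detected. Below is a minimal sketch of the new call path, not part of the patch, following the call pattern of test_torch_insert_xai_with_layer above. It assumes timm is installed; the model id, input size, and import paths mirror the test files and may differ in your version.

    import timm
    import torch

    from openvino_xai import insert_xai
    from openvino_xai.common.parameters import Method, Task

    # Pretrained weights are irrelevant for checking the XAI wiring itself.
    model = timm.create_model("resnet18.a1_in1k", pretrained=False)

    # target_layer=None (or simply omitted) triggers _find_feature_module_auto():
    # a temporary forward pass hooks the last module that emits a 4D BxCxHxW map.
    xai_model: torch.nn.Module = insert_xai(
        model,
        task=Task.CLASSIFICATION,
        target_layer=None,
        explain_method=Method.RECIPROCAM,
    )

    with torch.no_grad():
        # The output hook returns a dict holding the prediction plus the
        # SALIENCY_MAP_OUTPUT_NAME entry with the BxCxHxW saliency map.
        outputs = xai_model(torch.zeros(1, 3, 224, 224))

For ViT backbones, the same call with the ViT-specific explain method routes through TorchViTReciproCAM, whose _find_feature_module_auto override above picks the third-to-last LayerNorm module instead of searching for a 4D feature map.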