
Auto-detect feature layer for PyTorch models (#64)
* Support basic sub-string match

* Add basic 4D feature map layer detection

* Add N-last LayerNorm module detection for ViTs

* Add integration test for auto layer detection

* Apply review comments
goodsong81 authored Sep 10, 2024
1 parent 8b0ddf9 commit 04d5d52
Showing 4 changed files with 152 additions and 21 deletions.
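
In short, this makes target_layer optional for torch models: when it is left as None, the white-box method locates the feature module automatically, using the last module that emits a 4D feature map for CNNs and the 3rd-last LayerNorm for ViT ReciproCAM. Below is a minimal usage sketch in the spirit of the updated integration test; the timm model id, the explain method, and the import locations are assumptions and may need adjusting to your version of openvino_xai.

import timm
import torch

# Assumed import locations; adjust if these symbols are exposed elsewhere in your version.
from openvino_xai import insert_xai
from openvino_xai.common.parameters import Method, Task

# Illustrative backbone taken from the test config; any CNN with late 4D feature maps should work.
model = timm.create_model("resnet18.a1_in1k", pretrained=False).eval()

# target_layer=None (or simply omitted) triggers the auto-detection added in this commit.
xai_model: torch.nn.Module = insert_xai(
    model,
    task=Task.CLASSIFICATION,
    target_layer=None,
    explain_method=Method.RECIPROCAM,
)

with torch.no_grad():
    outputs = xai_model(torch.rand(1, 3, 224, 224))

# The XAI-wrapped model returns the prediction together with a per-class saliency map.
print(outputs["prediction"].shape)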
86 changes: 75 additions & 11 deletions openvino_xai/methods/white_box/torch.py
@@ -41,25 +41,37 @@ def __init__(
target_layer: str | None = None,
embed_scaling: bool = True,
device_name: str = "CPU",
prepare_model: bool = True,
**kwargs,
):
super().__init__(model=model, preprocess_fn=preprocess_fn, device_name=device_name)
self._target_layer = target_layer
self._embed_scaling = embed_scaling

if prepare_model:
self.prepare_model()

def prepare_model(self, load_model: bool = True) -> torch.nn.Module:
"""Return XAI inserted model."""
if has_xai(self._model):
if load_model:
self._model_compiled = self._model
return self._model
if self._model_compiled is not None:
return self._model_compiled

model = copy.deepcopy(self._model)

# Feature
feature_layer = model.get_submodule(self._target_layer)
feature_layer.register_forward_hook(self._feature_hook)
if self._target_layer:
feature_module = self._find_feature_module_by_name(model, self._target_layer)
else:
feature_module = self._find_feature_module_auto(model)
feature_module.register_forward_hook(self._feature_hook)

# Output
model.register_forward_hook(self._output_hook)

setattr(model, "has_xai", True)
model.eval()

@@ -86,11 +98,51 @@ def model_forward(self, x: np.ndarray, preprocess: bool = True) -> Mapping:
output[name] = data.numpy(force=True)
return output

def _find_feature_module_by_name(self, model: torch.nn.Module, target_name: str) -> torch.nn.Module:
"""Search the last layer by name sub string match."""
target_module = None
for name, module in model.named_modules():
if target_name in name:
target_module = module
if target_module is None:
raise ValueError(f"{target_name} is not found in the torch model")
return target_module

def _find_feature_module_auto(self, module: torch.nn.Module) -> torch.nn.Module:
"""Detect feature module in the model."""
# Find the last layer that outputs a 4D tensor during a temporary forward pass
self._feature_module = None
self._num_modules = 0

def _detect_hook(module: torch.nn.Module, inputs: Any, output: Any) -> None:
if isinstance(output, torch.Tensor):
module.index = self._num_modules
self._num_modules += 1
shape = output.shape
if len(shape) == 4 and shape[2] > 1 and shape[3] > 1:
self._feature_module = module

global_hook_handle = torch.nn.modules.module.register_module_forward_hook(_detect_hook)
try:
module.forward(torch.zeros((1, 3, 128, 128)))
finally:
global_hook_handle.remove()
if self._feature_module is None:
raise RuntimeError("Feature module with 4D output is not found in the torch model")
if self._feature_module.index / self._num_modules < 0.5: # Check if ViT-like architectures
raise RuntimeError(
f"Modules with 4D output end in early-half stages: {100 * self._feature_module.index / self._num_modules}%"
)

return self._feature_module

def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
"""Manipulate feature map for saliency map generation."""
self._feature_map = output
return output

def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
"""Split combined output B0xC into BxC precition and BxCxHxW saliency map."""
return {
"prediction": output,
SALIENCY_MAP_OUTPUT_NAME: torch.empty_like(output),
@@ -137,8 +189,8 @@ class TorchReciproCAM(TorchWhiteBoxMethod):
"""

def __init__(self, *args, optimize_gap: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self._optimize_gap = optimize_gap
super().__init__(*args, **kwargs)

def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
"""feature_maps -> vertical stack of feature_maps + mosaic_feature_maps."""
@@ -153,16 +205,17 @@ def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tens
return torch.cat(feature_maps)

def _output_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> Dict[str, torch.Tensor]:
batch_size, _, h, w = self._feature_shape
num_classes = output.shape[1]
predictions = output[:batch_size]
saliency_maps = output[batch_size:]
saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes])
saliency_maps = saliency_maps.transpose(1, 2) # BxHWxC -> BxCxHW
"""Split combined output B0xC into BxC precition and BxCxHxW saliency map."""
batch_size, _, h, w = self._feature_shape # B0xDxHxW
num_classes = output.shape[1] # C
predictions = output[:batch_size] # BxC
saliency_maps = output[batch_size:] # BHWxC
saliency_maps = saliency_maps.reshape([batch_size, h * w, num_classes]) # BxHWxC
saliency_maps = saliency_maps.transpose(1, 2) # BxCxHW
if self._embed_scaling:
saliency_maps = saliency_maps.reshape((batch_size * num_classes, h * w))
saliency_maps = self._normalize_map(saliency_maps)
saliency_maps = saliency_maps.reshape([batch_size, num_classes, h, w])
saliency_maps = saliency_maps.reshape([batch_size, num_classes, h, w]) # BxCxHxW
return {
"prediction": predictions,
SALIENCY_MAP_OUTPUT_NAME: saliency_maps,
@@ -209,9 +262,20 @@ def __init__(
normalize: bool = True,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self._use_gaussian = use_gaussian
self._use_cls_token = use_cls_token
super().__init__(*args, **kwargs)

def _find_feature_module_auto(self, module: torch.nn.Module) -> torch.nn.Module:
"""Detect feature module in the model by finding the 3rd last LayerNorm module."""
self._feature_module = None
norm_modules = [m for _, m in module.named_modules() if isinstance(m, torch.nn.LayerNorm)]

if len(norm_modules) < 3:
raise RuntimeError("Feature modules with LayerNorm are less than 3 in the torch model")

self._feature_module = norm_modules[-3]
return self._feature_module

def _feature_hook(self, module: torch.nn.Module, inputs: Any, output: torch.Tensor) -> torch.Tensor:
"""feature_maps -> vertical stack of feature_maps + mosaic_feature_maps."""
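
The same detection path can also be exercised directly through the method classes, as the new unit test test_detect_feature_layer does further down. A condensed sketch of that flow, assuming a plain torchvision ResNet-18 as a stand-in backbone (torchvision is not part of this commit; the private _feature_module attribute is read only for inspection, as in the test):

import numpy as np
import torchvision
from openvino_xai.methods.white_box.torch import TorchReciproCAM

model = torchvision.models.resnet18(weights=None).eval()

# With target_layer=None, _find_feature_module_auto runs a throwaway forward pass on zeros
# with a global forward hook and keeps the last module that produced a 4D feature map.
method = TorchReciproCAM(model=model, target_layer=None)
print(method._feature_module)  # expected to resolve to the layer4 stage of the ResNet

output = method.model_forward(np.random.rand(1, 3, 224, 224).astype(np.float32))
print(output.keys())  # "prediction" plus the saliency-map entry

For a ViT backbone the same call raises RuntimeError, since no late module produces a 4D feature map; that is why TorchViTReciproCAM overrides _find_feature_module_auto to pick the 3rd-last LayerNorm instead.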
10 changes: 7 additions & 3 deletions tests/intg/test_classification_timm.py
@@ -454,7 +454,8 @@ def test_model_format(self, model_id, explain_mode, model_format):
"deit_tiny_patch16_224.fb_in1k",
],
)
def test_torch_insert_xai_with_layer(self, model_id: str):
@pytest.mark.parametrize("detect", ["auto", "name"])
def test_torch_insert_xai_with_layer(self, model_id: str, detect: str):
xai_cfg = {
"resnet18.a1_in1k": ("layer4", Method.RECIPROCAM),
"efficientnet_b0.ra_in1k": ("bn2", Method.RECIPROCAM),
@@ -465,6 +466,9 @@ def test_torch_insert_xai_with_layer(self, model_id: str):
model_dir = self.data_dir / "timm_models" / "converted_models"
model, model_cfg = self.get_timm_model(model_id, model_dir)

target_layer = xai_cfg[model_id][0] if detect == "name" else None
explain_method = xai_cfg[model_id][1]

image = cv2.imread("tests/assets/cheetah_person.jpg")
image = cv2.resize(image, dsize=model_cfg["input_size"][1:])
image = cv2.cvtColor(image, code=cv2.COLOR_BGR2RGB)
@@ -478,8 +482,8 @@ def test_torch_insert_xai_with_layer(self, model_id: str):
xai_model: torch.nn.Module = insert_xai(
model,
task=Task.CLASSIFICATION,
target_layer=xai_cfg[model_id][0],
explain_method=xai_cfg[model_id][1],
target_layer=target_layer,
explain_method=explain_method,
)

with torch.no_grad():
6 changes: 5 additions & 1 deletion tests/unit/methods/test_factory.py
@@ -151,7 +151,7 @@ def test_create_wb_det_cnn_method(fxt_data_root: Path):
assert str(exc_info.value) == "Requested explanation method abc is not implemented."


def test_create_torch_method():
def test_create_torch_method(mocker: MockerFixture):
model = {}
with pytest.raises(ValueError):
explain_method = BlackBoxMethodFactory.create_method(Task.CLASSIFICATION, model, get_postprocess_fn())
@@ -172,6 +172,10 @@ def test_create_torch_method():
Task.DETECTION, model, get_postprocess_fn(), target_layer=""
)

mocker.patch.object(torch_method.TorchActivationMap, "prepare_model")
mocker.patch.object(torch_method.TorchReciproCAM, "prepare_model")
mocker.patch.object(torch_method.TorchViTReciproCAM, "prepare_model")

model = torch.nn.Module()
explain_method = WhiteBoxMethodFactory.create_method(
Task.CLASSIFICATION, model, get_postprocess_fn(), explain_method=Method.ACTIVATIONMAP
71 changes: 65 additions & 6 deletions tests/unit/methods/white_box/test_torch.py
@@ -35,7 +35,12 @@ class DummyCNN(torch.nn.Module):
def __init__(self, num_classes: int = 2):
super().__init__()
self.num_classes = num_classes
self.feature = torch.nn.Identity()
self.feature = torch.nn.Sequential(
torch.nn.Identity(),
torch.nn.Identity(),
torch.nn.Identity(),
torch.nn.Identity(),
)
self.neck = torch.nn.AdaptiveAvgPool2d((1, 1))
self.output = torch.nn.LazyLinear(out_features=num_classes)

@@ -48,24 +53,46 @@ def forward(self, x: torch.Tensor):


class DummyVIT(torch.nn.Module):
def __init__(self, num_classes: int = 2):
def __init__(self, num_classes: int = 2, dim: int = 3):
super().__init__()
self.num_classes = num_classes
self.feature = torch.nn.Identity()
self.dim = dim
self.pre = torch.nn.Sequential(
torch.nn.Identity(),
torch.nn.Identity(),
)
self.feature = torch.nn.Sequential(
torch.nn.Identity(),
torch.nn.Identity(),
torch.nn.Identity(),
torch.nn.Identity(),
)
self.norm1 = torch.nn.LayerNorm(dim)
self.norm2 = torch.nn.LayerNorm(dim)
self.norm3 = torch.nn.LayerNorm(dim)
self.output = torch.nn.LazyLinear(out_features=num_classes)

def forward(self, x: torch.Tensor):
b, c, h, w = x.shape
x = self.pre(x)
x = x.reshape(b, c, h * w)
x = x.transpose(1, 2)
x = torch.cat([torch.rand((b, 1, c)), x], dim=1)
x = self.feature(x)
x = x + self.norm1(x)
x = x + self.norm2(x)
x = x + self.norm3(x)
x = self.output(x[:, 0])
return torch.nn.functional.softmax(x, dim=1)


def test_torch_method():
model = DummyCNN()

with pytest.raises(ValueError):
method = TorchWhiteBoxMethod(model=model, target_layer="something_else")
model_xai = method.prepare_model()

method = TorchWhiteBoxMethod(model=model, target_layer="feature")
model_xai = method.prepare_model()
assert has_xai(model_xai)
@@ -101,7 +128,9 @@ def _output_hook(

def test_prepare_model():
model = DummyCNN()
method = TorchWhiteBoxMethod(model=model, target_layer="feature")
method = TorchWhiteBoxMethod(model=model, target_layer="feature", prepare_model=False)
model_xai = method.prepare_model(load_model=False)
assert method._model_compiled is None
model_xai = method.prepare_model(load_model=False)
assert method._model_compiled is None
assert model is not model_xai
@@ -116,6 +145,35 @@ def test_prepare_model():
assert model_xai == model


def test_detect_feature_layer():
model = DummyCNN()
method = TorchWhiteBoxMethod(model=model, target_layer=None)
model_xai = method.prepare_model()
assert has_xai(model_xai)
data = np.random.rand(1, 3, 5, 5)
output = method.model_forward(data)
assert type(output) == dict
assert method._feature_module is model_xai.feature
output = method.model_forward(data)
assert type(output) == dict # still good for 2nd forward

model = DummyVIT()
with pytest.raises(RuntimeError):
# 4D feature map search should fail for ViTs
method = TorchWhiteBoxMethod(model=model, target_layer=None)

model = DummyVIT()
method = TorchViTReciproCAM(model=model, target_layer=None)
model_xai = method.prepare_model()
assert has_xai(model_xai)
data = np.random.rand(1, 3, 5, 5)
output = method.model_forward(data)
assert type(output) == dict
assert method._feature_module is model_xai.norm1
output = method.model_forward(data)
assert type(output) == dict # still good for 2nd forward


def test_activationmap() -> None:
batch_size = 2
num_classes = 3
@@ -156,13 +214,14 @@ def test_reciprocam(optimize_gap: bool) -> None:
def test_vitreciprocam(use_gaussian: bool, use_cls_token: bool) -> None:
batch_size = 2
num_classes = 3
model = DummyVIT(num_classes=num_classes)
dim = 3
model = DummyVIT(num_classes=num_classes, dim=dim)
method = TorchViTReciproCAM(
model=model, target_layer="feature", use_gaussian=use_gaussian, use_cls_token=use_cls_token
)
model_xai = method.prepare_model()
assert has_xai(model_xai)
data = np.random.rand(batch_size, 4, 5, 5)
data = np.random.rand(batch_size, dim, 5, 5)
output = method.model_forward(data)
assert type(output) == dict
saliency_maps = output[SALIENCY_MAP_OUTPUT_NAME]
