From e2446467b16aa39d40fe475bf93d25d6f611bbfb Mon Sep 17 00:00:00 2001 From: guarin <43336610+guarin@users.noreply.github.com> Date: Wed, 24 Jul 2024 12:01:34 +0200 Subject: [PATCH] Add antialias option to masked vit (#1595) * Add antialias option for `MaskedVisionTransformer` classes * Rename `initialize_weights` to `weight_initialization` * Add MaskedVisionTransformer tests --- .../modules/masked_vision_transformer.py | 37 ++- .../modules/masked_vision_transformer_timm.py | 25 +- .../masked_vision_transformer_torchvision.py | 33 ++- tests/models/__init__.py | 0 tests/models/modules/__init__.py | 0 .../modules/masked_vision_transformer_test.py | 246 ++++++++++++++++++ .../modules/test_masked_autoencoder_timm.py | 101 +------ .../test_masked_vision_transformer_timm.py | 95 +++++++ ...t_masked_vision_transformer_torchvision.py | 123 +++++++++ 9 files changed, 546 insertions(+), 114 deletions(-) create mode 100644 tests/models/__init__.py create mode 100644 tests/models/modules/__init__.py create mode 100644 tests/models/modules/masked_vision_transformer_test.py create mode 100644 tests/models/modules/test_masked_vision_transformer_timm.py create mode 100644 tests/models/modules/test_masked_vision_transformer_torchvision.py diff --git a/lightly/models/modules/masked_vision_transformer.py b/lightly/models/modules/masked_vision_transformer.py index 566af7b9e..9df55a58c 100644 --- a/lightly/models/modules/masked_vision_transformer.py +++ b/lightly/models/modules/masked_vision_transformer.py @@ -1,10 +1,11 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import List, Optional, Tuple from torch import Tensor +from torch.nn import Module -class MaskedVisionTransformer(ABC): +class MaskedVisionTransformer(ABC, Module): """ Abstract base class for Masked Vision Transformer models. @@ -13,6 +14,11 @@ class MaskedVisionTransformer(ABC): tokenization of images, and various operations needed for the transformer. """ + @property + @abstractmethod + def sequence_length(self) -> int: + ... + @abstractmethod def forward( self, @@ -20,16 +26,35 @@ def forward( idx_mask: Optional[Tensor] = None, idx_keep: Optional[Tensor] = None, ) -> Tensor: - pass + ... + + @abstractmethod + def forward_intermediates( + self, + images: Tensor, + idx_mask: Optional[Tensor] = None, + idx_keep: Optional[Tensor] = None, + norm: bool = False, + ) -> Tuple[Tensor, List[Tensor]]: + ... + + @abstractmethod + def encode( + self, + images: Tensor, + idx_mask: Optional[Tensor] = None, + idx_keep: Optional[Tensor] = None, + ) -> Tensor: + ... @abstractmethod def images_to_tokens(self, images: Tensor) -> Tensor: - pass + ... @abstractmethod def add_prefix_tokens(self, x: Tensor) -> Tensor: - pass + ... @abstractmethod def add_pos_embed(self, x: Tensor) -> Tensor: - pass + ... diff --git a/lightly/models/modules/masked_vision_transformer_timm.py b/lightly/models/modules/masked_vision_transformer_timm.py index ff4c3fd95..4079023d4 100644 --- a/lightly/models/modules/masked_vision_transformer_timm.py +++ b/lightly/models/modules/masked_vision_transformer_timm.py @@ -12,7 +12,7 @@ from lightly.models.modules.masked_vision_transformer import MaskedVisionTransformer -class MaskedVisionTransformerTIMM(MaskedVisionTransformer, Module): +class MaskedVisionTransformerTIMM(MaskedVisionTransformer): """Masked Vision Transformer class using TIMM. Attributes: @@ -20,14 +20,21 @@ class MaskedVisionTransformerTIMM(MaskedVisionTransformer, Module): The VisionTransformer object of TIMM. mask_token: The mask token. 
+ weight_initialization:
+ The weight initialization method. Valid options are ['', 'skip']. '' uses
+ the default MAE weight initialization and 'skip' skips the weight
+ initialization.
+ antialias:
+ Whether to use antialiasing when resampling the positional embeddings.
"""

def __init__(
self,
vit: VisionTransformer,
- initialize_weights: bool = True,
mask_token: Optional[Parameter] = None,
+ weight_initialization: str = "",
+ antialias: bool = True,
) -> None:
super().__init__()
self.vit = vit
@@ -36,9 +43,17 @@ def __init__(
if mask_token is not None
else Parameter(torch.zeros(1, 1, self.vit.embed_dim))
)
- if initialize_weights:
+
+ if weight_initialization not in ("", "skip"):
+ raise ValueError(
+ f"Invalid weight initialization method: '{weight_initialization}'. "
+ "Valid options are: ['', 'skip']."
+ )
+ if weight_initialization != "skip":
self._initialize_weights()
+ self.antialias = antialias
+
@property
def sequence_length(self) -> int:
seq_len: int = self.vit.patch_embed.num_patches + self.vit.num_prefix_tokens
@@ -268,6 +283,7 @@ def add_pos_embed(self, x: Tensor) -> Tensor:
num_prefix_tokens=(
0 if self.vit.no_embed_class else self.vit.num_prefix_tokens
),
+ antialias=self.antialias,
)
x = x.view(B, -1, C)
else:
@@ -291,7 +307,8 @@ def _initialize_weights(self) -> None:
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

# Initialize the class token.
- torch.nn.init.normal_(self.vit.cls_token, std=0.02)
+ if self.vit.has_class_token:
+ torch.nn.init.normal_(self.vit.cls_token, std=0.02)

# initialize nn.Linear and nn.LayerNorm
self.apply(_init_weights)
diff --git a/lightly/models/modules/masked_vision_transformer_torchvision.py b/lightly/models/modules/masked_vision_transformer_torchvision.py
index 81969774a..7a24c196c 100644
--- a/lightly/models/modules/masked_vision_transformer_torchvision.py
+++ b/lightly/models/modules/masked_vision_transformer_torchvision.py
@@ -1,5 +1,5 @@
import math
-from typing import Optional
+from typing import List, Optional, Tuple

import torch
import torch.nn as nn
@@ -11,7 +11,7 @@
from lightly.models.modules.masked_vision_transformer import MaskedVisionTransformer


-class MaskedVisionTransformerTorchvision(MaskedVisionTransformer, Module):
+class MaskedVisionTransformerTorchvision(MaskedVisionTransformer):
"""Masked Vision Transformer class using Torchvision.

Attributes:
@@ -19,14 +19,20 @@
The VisionTransformer object of Torchvision.
mask_token:
The mask token.
+ weight_initialization:
+ The weight initialization method. Valid options are ['', 'skip']. '' uses
+ the default MAE weight initialization and 'skip' skips the weight
+ initialization.
+ antialias:
+ Whether to use antialiasing when resampling the positional embeddings.
"""

def __init__(
self,
vit: VisionTransformer,
- initialize_weights: bool = True,
mask_token: Optional[Parameter] = None,
+ weight_initialization: str = "",
+ antialias: bool = True,
) -> None:
super().__init__()
self.vit = vit
@@ -35,9 +41,16 @@
if mask_token is not None
else Parameter(torch.zeros(1, 1, self.vit.hidden_dim))
)
- if initialize_weights:
+ if weight_initialization not in ("", "skip"):
+ raise ValueError(
+ f"Invalid weight initialization method: '{weight_initialization}'. "
+ "Valid options are: ['', 'skip']."
+ ) + if weight_initialization != "skip": self._initialize_weights() + self.antialias = antialias + @property def sequence_length(self) -> int: seq_len: int = self.vit.seq_length @@ -70,6 +83,7 @@ def interpolate_pos_encoding(self, input: Tensor) -> Tensor: ), scale_factor=math.sqrt(npatch / N), mode="bicubic", + antialias=self.antialias, ) pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) @@ -104,6 +118,17 @@ def forward( class_token = out[:, 0] return class_token + def forward_intermediates( + self, + images: Tensor, + idx_mask: Optional[Tensor] = None, + idx_keep: Optional[Tensor] = None, + norm: bool = False, + ) -> Tuple[Tensor, List[Tensor]]: + raise NotImplementedError( + "forward_intermediates is not implemented for this model." + ) + def encode( self, images: Tensor, diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/models/modules/__init__.py b/tests/models/modules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/models/modules/masked_vision_transformer_test.py b/tests/models/modules/masked_vision_transformer_test.py new file mode 100644 index 000000000..ad87352ef --- /dev/null +++ b/tests/models/modules/masked_vision_transformer_test.py @@ -0,0 +1,246 @@ +from abc import ABC, abstractmethod +from typing import Optional + +import pytest +import torch +from pytest_mock import MockerFixture +from torch.nn import Parameter + +from lightly.models import utils +from lightly.models.modules.masked_vision_transformer import MaskedVisionTransformer + + +class MaskedVisionTransformerTest(ABC): + """Common tests for all classes implementing MaskedVisionTransformer.""" + + @abstractmethod + def get_masked_vit( + self, + patch_size: int, + depth: int, + num_heads: int, + embed_dim: int = 768, + class_token: bool = True, + reg_tokens: int = 0, + mask_token: Optional[Parameter] = None, + antialias: bool = True, + weight_initialization: str = "", + ) -> MaskedVisionTransformer: + ... + + @abstractmethod + def test__init__mask_token(self, mask_token: Optional[Parameter]) -> None: + ... + + @abstractmethod + def test__init__weight_initialization(self, mocker: MockerFixture) -> None: + ... + + @abstractmethod + def test__init__weight_initialization__skip(self, mocker: MockerFixture) -> None: + ... + + def test__init__weight_initialization__invalid(self) -> None: + with pytest.raises(ValueError): + self.get_masked_vit( + patch_size=32, depth=1, num_heads=1, weight_initialization="invalid" + ) + + def test_sequence_length(self) -> None: + # TODO(Guarin, 07/2024): Support reg_tokens > 0 and test the sequence length + # with reg_tokens > 0. 
+ model = self.get_masked_vit( + patch_size=32, + depth=1, + num_heads=1, + class_token=True, + reg_tokens=0, + ) + + # 49 for img_height//patch_size * img_width//patch_size + # 1 for the class token + # 0 for the reg tokens + expected_sequence_length = 49 + 1 + 0 + assert model.sequence_length == expected_sequence_length + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("mask_ratio", [None, 0.6]) + def test_forward(self, device: str, mask_ratio: Optional[float]) -> None: + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available.") + + torch.manual_seed(0) + batch_size = 8 + embed_dim = 768 + model = self.get_masked_vit( + patch_size=32, depth=2, num_heads=2, embed_dim=embed_dim + ).to(device) + images = torch.rand(batch_size, 3, 224, 224).to(device) + + idx_keep = None + if mask_ratio is not None: + idx_keep, _ = utils.random_token_mask( + size=(batch_size, model.sequence_length), + device=device, + mask_ratio=mask_ratio, + ) + + class_tokens = model(images=images, idx_keep=idx_keep) + + # output shape must be correct + assert class_tokens.shape == (batch_size, embed_dim) + # output must have reasonable numbers + assert torch.all(torch.isfinite(class_tokens)) + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize( + "mask_ratio, expected_sequence_length", [(None, 50), (0.6, 20)] + ) + def test_forward_intermediates( + self, device: str, mask_ratio: Optional[float], expected_sequence_length: int + ) -> None: + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available.") + + torch.manual_seed(0) + batch_size = 8 + embed_dim = 768 + model = self.get_masked_vit( + patch_size=32, depth=2, num_heads=2, embed_dim=embed_dim + ).to(device) + images = torch.rand(batch_size, 3, 224, 224).to(device) + + idx_keep = None + if mask_ratio is not None: + idx_keep, _ = utils.random_token_mask( + size=(batch_size, model.sequence_length), + device=device, + mask_ratio=mask_ratio, + ) + + output, intermediates = model.forward_intermediates( + images=images, idx_keep=idx_keep + ) + expected_output = model.encode(images=images, idx_keep=idx_keep) + + # output shape must be correct + assert output.shape == (batch_size, expected_sequence_length, embed_dim) + # output should be same as from encode + assert torch.allclose(output, expected_output) + # intermediates must have reasonable numbers + for intermediate in intermediates: + assert intermediate.shape == ( + batch_size, + expected_sequence_length, + embed_dim, + ) + assert torch.all(torch.isfinite(intermediate)) + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize( + "mask_ratio,expected_sequence_length", [(None, 50), (0.6, 20)] + ) + def test_encode( + self, + device: str, + mask_ratio: Optional[float], + expected_sequence_length: int, + ) -> None: + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available.") + + torch.manual_seed(0) + batch_size = 8 + embed_dim = 768 + model = self.get_masked_vit( + patch_size=32, depth=2, num_heads=2, embed_dim=embed_dim + ).to(device) + images = torch.rand(batch_size, 3, 224, 224).to(device) + + idx_keep = None + if mask_ratio is not None: + idx_keep, _ = utils.random_token_mask( + size=(batch_size, model.sequence_length), + device=device, + mask_ratio=mask_ratio, + ) + + tokens = model.encode(images=images, idx_keep=idx_keep) + assert tokens.shape == (batch_size, expected_sequence_length, embed_dim) + assert 
torch.all(torch.isfinite(tokens)) + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + def test_images_to_tokens(self, device: str) -> None: + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available.") + + torch.manual_seed(0) + model = self.get_masked_vit( + patch_size=32, depth=2, num_heads=2, embed_dim=768 + ).to(device) + images = torch.rand(2, 3, 224, 224).to(device) + tokens = model.images_to_tokens(images) + assert tokens.shape == (2, 49, 768) + assert torch.all(torch.isfinite(tokens)) + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize( + "class_token,reg_tokens,expected_sequence_length", + [ + (False, 0, 49), + (True, 0, 50), + # (False, 2, 51), TODO(Guarin, 07/2024): Support reg_tokens > 0 + # (True, 2, 52), TODO(Guarin, 07/2024): Support reg_tokens > 0 + ], + ) + def test_add_prefix_tokens( + self, + device: str, + class_token: bool, + reg_tokens: int, + expected_sequence_length: int, + ) -> None: + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available.") + + torch.manual_seed(0) + model = self.get_masked_vit( + patch_size=32, + depth=1, + num_heads=1, + class_token=class_token, + reg_tokens=reg_tokens, + ).to(device) + x = torch.rand(2, 49, 768).to(device) + assert model.add_prefix_tokens(x).shape == (2, expected_sequence_length, 768) + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize( + "class_token,reg_tokens,expected_sequence_length", + [ + (False, 0, 49), + (True, 0, 50), + # (False, 2, 51), TODO(Guarin, 07/2024): Support reg_tokens > 0 + # (True, 2, 52), TODO(Guarin, 07/2024): Support reg_tokens > 0 + ], + ) + def test_add_pos_embed( + self, + device: str, + class_token: bool, + reg_tokens: int, + expected_sequence_length: int, + ) -> None: + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available.") + + torch.manual_seed(0) + model = self.get_masked_vit( + patch_size=32, + depth=1, + num_heads=1, + class_token=class_token, + reg_tokens=reg_tokens, + ).to(device) + x = torch.rand(2, model.sequence_length, 768).to(device) + assert model.add_pos_embed(x).shape == (2, expected_sequence_length, 768) diff --git a/tests/models/modules/test_masked_autoencoder_timm.py b/tests/models/modules/test_masked_autoencoder_timm.py index 495a83b4f..8d9180963 100644 --- a/tests/models/modules/test_masked_autoencoder_timm.py +++ b/tests/models/modules/test_masked_autoencoder_timm.py @@ -11,106 +11,7 @@ pytest.skip("TIMM vision transformer is not available", allow_module_level=True) -from timm.models.vision_transformer import VisionTransformer, vit_base_patch32_224 - -from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM - - -class TestMaskedVisionTransformerTIMM(unittest.TestCase): - def _vit(self) -> VisionTransformer: - return vit_base_patch32_224() - - def test_from_vit(self) -> None: - MaskedVisionTransformerTIMM(vit=self._vit()) - - def _test_forward( - self, device: torch.device, batch_size: int = 8, seed: int = 0 - ) -> None: - torch.manual_seed(seed) - vit = self._vit() - backbone = MaskedVisionTransformerTIMM(vit=vit).to(device) - images = torch.rand( - batch_size, 3, vit.patch_embed.img_size[0], vit.patch_embed.img_size[0] - ).to(device) - _idx_keep, _ = utils.random_token_mask( - size=(batch_size, backbone.sequence_length), - device=device, - ) - for idx_keep in [None, _idx_keep]: - with self.subTest(idx_keep=idx_keep): - class_tokens = backbone(images=images, idx_keep=idx_keep) - - # output 
shape must be correct - expected_shape = [batch_size, vit.embed_dim] - self.assertListEqual(list(class_tokens.shape), expected_shape) - - # output must have reasonable numbers - self.assertTrue(torch.all(torch.not_equal(class_tokens, torch.inf))) - - def test_forward(self) -> None: - self._test_forward(torch.device("cpu")) - - @unittest.skipUnless(torch.cuda.is_available(), "Cuda not available.") - def test_forward_cuda(self) -> None: - self._test_forward(torch.device("cuda")) - - def _test_forward_intermediates( - self, device: torch.device, batch_size: int = 8, seed: int = 0 - ) -> None: - torch.manual_seed(seed) - vit = self._vit() - backbone = MaskedVisionTransformerTIMM(vit=vit).to(device) - expected_sequence_length = backbone.sequence_length - images = torch.rand( - batch_size, 3, vit.patch_embed.img_size[0], vit.patch_embed.img_size[0] - ).to(device) - _idx_keep, _ = utils.random_token_mask( - size=(batch_size, backbone.sequence_length), - device=device, - ) - for idx_keep, expected_sequence_length in [(None, 50), (_idx_keep, 20)]: - with self.subTest(idx_keep=idx_keep): - class_tokens = backbone(images=images, idx_keep=idx_keep) - - # output shape must be correct - expected_shape = [batch_size, vit.embed_dim] - self.assertListEqual(list(class_tokens.shape), expected_shape) - - output, intermediates = backbone.forward_intermediates( - images=images, idx_keep=idx_keep - ) - - self.assertTrue(len(intermediates) == len(vit.blocks)) - - # output shape must be correct - expected_intermediates_shape = [ - batch_size, - expected_sequence_length, - vit.embed_dim, - ] - for intermediate in intermediates: - self.assertListEqual( - list(intermediate.shape), expected_intermediates_shape - ) - - # output must have reasonable numbers - self.assertTrue(torch.all(torch.not_equal(class_tokens, torch.inf))) - - def test_forward_intermediates(self) -> None: - self._test_forward_intermediates(torch.device("cpu")) - - @unittest.skipUnless(torch.cuda.is_available(), "Cuda not available.") - def test_forward_intermediates_cuda(self) -> None: - self._test_forward_intermediates(torch.device("cuda")) - - def test_images_to_tokens(self) -> None: - torch.manual_seed(0) - vit = self._vit() - backbone = MaskedVisionTransformerTIMM(vit=vit) - images = torch.rand(2, 3, 224, 224) - assert torch.all( - vit.patch_embed(images) == backbone.images_to_tokens(images=images) - ) +from lightly.models.modules import MAEDecoderTIMM class TestMAEDecoderTIMM(unittest.TestCase): diff --git a/tests/models/modules/test_masked_vision_transformer_timm.py b/tests/models/modules/test_masked_vision_transformer_timm.py new file mode 100644 index 000000000..93f0b9408 --- /dev/null +++ b/tests/models/modules/test_masked_vision_transformer_timm.py @@ -0,0 +1,95 @@ +from typing import Optional + +import pytest +import torch +from pytest_mock import MockerFixture +from torch.nn import Parameter + +from lightly.utils import dependency + +if not dependency.timm_vit_available(): + # We do not use pytest.importorskip on module level because it makes mypy unhappy. 
+ pytest.skip("TIMM vision transformer is not available", allow_module_level=True) + + +from timm.models.vision_transformer import VisionTransformer + +from lightly.models.modules import masked_vision_transformer_timm +from lightly.models.modules.masked_vision_transformer_timm import ( + MaskedVisionTransformerTIMM, +) + +from .masked_vision_transformer_test import MaskedVisionTransformerTest + + +class TestMaskedVisionTransformerTIMM(MaskedVisionTransformerTest): + def get_masked_vit( + self, + patch_size: int, + depth: int, + num_heads: int, + embed_dim: int = 768, + class_token: bool = True, + reg_tokens: int = 0, + mask_token: Optional[Parameter] = None, + antialias: bool = True, + weight_initialization: str = "", + ) -> MaskedVisionTransformerTIMM: + vit = VisionTransformer( + patch_size=patch_size, + depth=depth, + num_heads=num_heads, + embed_dim=embed_dim, + class_token=class_token, + reg_tokens=reg_tokens, + global_pool="token" if class_token else "avg", + dynamic_img_size=True, + ) + return MaskedVisionTransformerTIMM( + vit=vit, + mask_token=mask_token, + weight_initialization=weight_initialization, + antialias=antialias, + ) + + @pytest.mark.parametrize("mask_token", [None, Parameter(torch.rand(1, 1, 768))]) + def test__init__mask_token(self, mask_token: Optional[Parameter]) -> None: + model = self.get_masked_vit( + patch_size=32, + depth=1, + num_heads=1, + mask_token=mask_token, + ) + assert isinstance(model.mask_token, Parameter) + + def test__init__weight_initialization(self, mocker: MockerFixture) -> None: + mock_init_weights = mocker.spy(masked_vision_transformer_timm, "_init_weights") + self.get_masked_vit( + patch_size=32, depth=1, num_heads=1, weight_initialization="" + ) + mock_init_weights.assert_called() + + def test__init__weight_initialization__skip(self, mocker: MockerFixture) -> None: + mock_init_weights = mocker.spy(masked_vision_transformer_timm, "_init_weights") + self.get_masked_vit( + patch_size=32, depth=1, num_heads=1, weight_initialization="skip" + ) + mock_init_weights.assert_not_called() + + @pytest.mark.parametrize("antialias", [False, True]) + def test_add_pos_embed__antialias( + self, antialias: bool, mocker: MockerFixture + ) -> None: + model = self.get_masked_vit( + patch_size=32, + depth=1, + num_heads=1, + antialias=antialias, + ) + mock_resample_abs_pos_embed = mocker.spy( + masked_vision_transformer_timm, "resample_abs_pos_embed" + ) + x = torch.rand(2, model.sequence_length, 768) + model.add_pos_embed(x) + _, call_kwargs = mock_resample_abs_pos_embed.call_args + assert call_kwargs["antialias"] == antialias diff --git a/tests/models/modules/test_masked_vision_transformer_torchvision.py b/tests/models/modules/test_masked_vision_transformer_torchvision.py new file mode 100644 index 000000000..5b06b8cb7 --- /dev/null +++ b/tests/models/modules/test_masked_vision_transformer_torchvision.py @@ -0,0 +1,123 @@ +from typing import Optional + +import pytest +import torch +from pytest_mock import MockerFixture +from torch.nn import Parameter + +from lightly.utils import dependency + +if not dependency.torchvision_vit_available(): + # We do not use pytest.importorskip on module level because it makes mypy unhappy. 
+ pytest.skip(
+ "Torchvision vision transformer is not available", allow_module_level=True
+ )
+
+
+ from torchvision.models import VisionTransformer
+
+ from lightly.models.modules import masked_vision_transformer_torchvision
+ from lightly.models.modules.masked_vision_transformer_torchvision import (
+ MaskedVisionTransformerTorchvision,
+ )
+
+ from .masked_vision_transformer_test import MaskedVisionTransformerTest
+
+
+ class TestMaskedVisionTransformerTorchvision(MaskedVisionTransformerTest):
+ def get_masked_vit(
+ self,
+ patch_size: int,
+ depth: int,
+ num_heads: int,
+ embed_dim: int = 768,
+ class_token: bool = True,
+ reg_tokens: int = 0,
+ mask_token: Optional[Parameter] = None,
+ antialias: bool = True,
+ weight_initialization: str = "",
+ ) -> MaskedVisionTransformerTorchvision:
+ assert class_token, "Torchvision ViT always has a class token"
+ assert reg_tokens == 0, "Torchvision ViT does not support reg tokens"
+ vit = VisionTransformer(
+ image_size=224,
+ patch_size=patch_size,
+ num_layers=depth,
+ num_heads=num_heads,
+ hidden_dim=embed_dim,
+ mlp_dim=embed_dim * 4,
+ )
+ return MaskedVisionTransformerTorchvision(
+ vit=vit,
+ mask_token=mask_token,
+ weight_initialization=weight_initialization,
+ antialias=antialias,
+ )
+
+ @pytest.mark.parametrize("mask_token", [None, Parameter(torch.rand(1, 1, 768))])
+ def test__init__mask_token(self, mask_token: Optional[Parameter]) -> None:
+ model = self.get_masked_vit(
+ patch_size=32,
+ depth=1,
+ num_heads=1,
+ mask_token=mask_token,
+ )
+ assert isinstance(model.mask_token, Parameter)
+
+ def test__init__weight_initialization(self, mocker: MockerFixture) -> None:
+ mock_initialize_linear_layers = mocker.spy(
+ masked_vision_transformer_torchvision, "_initialize_linear_layers"
+ )
+ self.get_masked_vit(
+ patch_size=32, depth=1, num_heads=1, weight_initialization=""
+ )
+ mock_initialize_linear_layers.assert_called()
+
+ def test__init__weight_initialization__skip(self, mocker: MockerFixture) -> None:
+ mock_initialize_linear_layers = mocker.spy(
+ masked_vision_transformer_torchvision, "_initialize_linear_layers"
+ )
+ self.get_masked_vit(
+ patch_size=32, depth=1, num_heads=1, weight_initialization="skip"
+ )
+ mock_initialize_linear_layers.assert_not_called()
+
+ @pytest.mark.skip(reason="Torchvision ViT does not support forward intermediates")
+ def test_forward_intermediates(
+ self, device: str, mask_ratio: Optional[float], expected_sequence_length: int
+ ) -> None:
+ ...
+
+ @pytest.mark.skip(reason="Torchvision ViT does not support reg tokens")
+ def test_add_prefix_tokens(
+ self,
+ device: str,
+ class_token: bool,
+ reg_tokens: int,
+ expected_sequence_length: int,
+ ) -> None:
+ ...
+
+ @pytest.mark.parametrize("device", ["cpu", "cuda"])
+ @pytest.mark.parametrize(
+ "class_token,reg_tokens,expected_sequence_length",
+ [
+ # (False, 0, 49), # Torchvision ViT always has a class token
+ (True, 0, 50),
+ # (False, 2, 51), TODO(Guarin, 07/2024): Support reg_tokens > 0
+ # (True, 2, 52), TODO(Guarin, 07/2024): Support reg_tokens > 0
+ ],
+ )
+ def test_add_pos_embed(
+ self,
+ device: str,
+ class_token: bool,
+ reg_tokens: int,
+ expected_sequence_length: int,
+ ) -> None:
+ super().test_add_pos_embed(
+ device=device,
+ class_token=class_token,
+ reg_tokens=reg_tokens,
+ expected_sequence_length=expected_sequence_length,
+ )
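
Usage sketch (not part of the diff): a minimal example of the new `weight_initialization` and `antialias` constructor options, assuming timm is installed and using `vit_base_patch32_224`, the same backbone the removed unittest-based tests used.

import torch
from timm.models.vision_transformer import vit_base_patch32_224

from lightly.models.modules import MaskedVisionTransformerTIMM

# "" applies the default MAE weight initialization; "skip" leaves the weights
# exactly as timm initialized them. Any other value raises a ValueError.
model = MaskedVisionTransformerTIMM(
    vit=vit_base_patch32_224(),
    weight_initialization="",
    antialias=True,  # antialias when resampling the positional embeddings
)

images = torch.rand(2, 3, 224, 224)
class_tokens = model(images=images)  # shape: (2, 768)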
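
A second sketch of the masked forward pass the new common tests exercise, combining `utils.random_token_mask` with `forward` and `forward_intermediates`; the token counts follow from a 224x224 image with patch size 32 (49 patches + 1 class token = 50 tokens).

import torch
from timm.models.vision_transformer import vit_base_patch32_224

from lightly.models import utils
from lightly.models.modules import MaskedVisionTransformerTIMM

model = MaskedVisionTransformerTIMM(vit=vit_base_patch32_224())
images = torch.rand(8, 3, 224, 224)

# Mask a random 60% of the 50 tokens; idx_keep indexes the kept tokens.
idx_keep, idx_mask = utils.random_token_mask(
    size=(8, model.sequence_length),
    mask_ratio=0.6,
)

class_tokens = model(images=images, idx_keep=idx_keep)  # shape: (8, 768)

# Final tokens plus one intermediate tensor per transformer block; the final
# output matches model.encode(images=images, idx_keep=idx_keep).
output, intermediates = model.forward_intermediates(images=images, idx_keep=idx_keep)
assert output.shape == (8, 20, 768)  # 50 tokens - 0.6 * 50 masked = 20 kept
assert len(intermediates) == len(model.vit.blocks)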