Add antialias option to masked vit (#1595)
* Add antialias option for `MaskedVisionTransformer` classes
* Rename `initialize_weights` to `weight_initialization`
* Add MaskedVisionTransformer tests
guarin authored Jul 24, 2024
1 parent 64ef30b commit e244646
Showing 9 changed files with 546 additions and 114 deletions.
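From a caller's perspective, the constructor changes below amount to one renamed argument and one new flag. A minimal usage sketch, assuming timm is installed and that "vit_base_patch16_224" (an illustrative model name) is created via timm.create_model; the import path follows the module file shown further down:

import timm
import torch

from lightly.models.modules.masked_vision_transformer_timm import (
    MaskedVisionTransformerTIMM,
)

# Any timm VisionTransformer can serve as the backbone.
vit = timm.create_model("vit_base_patch16_224", num_classes=0)

model = MaskedVisionTransformerTIMM(
    vit=vit,
    # Replaces the old initialize_weights flag: "" applies the default MAE
    # initialization, "skip" leaves the backbone weights untouched.
    weight_initialization="skip",
    # New in this commit: toggles antialiasing when positional embeddings
    # are resampled to a different grid size.
    antialias=False,
)

images = torch.rand(2, 3, 224, 224)
features = model(images)  # plain forward pass without masking

Passing any other string for weight_initialization raises a ValueError, as the diffs below show.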
37 changes: 31 additions & 6 deletions lightly/models/modules/masked_vision_transformer.py
@@ -1,10 +1,11 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import List, Optional, Tuple

from torch import Tensor
from torch.nn import Module


class MaskedVisionTransformer(ABC):
class MaskedVisionTransformer(ABC, Module):
"""
Abstract base class for Masked Vision Transformer models.
@@ -13,23 +14,47 @@ class MaskedVisionTransformer(ABC):
tokenization of images, and various operations needed for the transformer.
"""

@property
@abstractmethod
def sequence_length(self) -> int:
...

@abstractmethod
def forward(
self,
images: Tensor,
idx_mask: Optional[Tensor] = None,
idx_keep: Optional[Tensor] = None,
) -> Tensor:
pass
...

@abstractmethod
def forward_intermediates(
self,
images: Tensor,
idx_mask: Optional[Tensor] = None,
idx_keep: Optional[Tensor] = None,
norm: bool = False,
) -> Tuple[Tensor, List[Tensor]]:
...

@abstractmethod
def encode(
self,
images: Tensor,
idx_mask: Optional[Tensor] = None,
idx_keep: Optional[Tensor] = None,
) -> Tensor:
...

@abstractmethod
def images_to_tokens(self, images: Tensor) -> Tensor:
pass
...

@abstractmethod
def add_prefix_tokens(self, x: Tensor) -> Tensor:
pass
...

@abstractmethod
def add_pos_embed(self, x: Tensor) -> Tensor:
pass
...
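Because the base class now inherits from both ABC and Module, the concrete TIMM and Torchvision classes in the next two files drop their explicit Module parent. A stripped-down sketch of the pattern with toy names (not part of lightly's API), including the new forward_intermediates signature:

from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Module


class ToyMaskedBackbone(ABC, Module):
    # ABC + Module in one base class: subclasses get nn.Module behaviour for free.
    @abstractmethod
    def forward(self, images: Tensor, idx_mask: Optional[Tensor] = None) -> Tensor:
        ...

    @abstractmethod
    def forward_intermediates(
        self, images: Tensor, norm: bool = False
    ) -> Tuple[Tensor, List[Tensor]]:
        ...


class ToyImpl(ToyMaskedBackbone):
    # No longer needs to inherit from Module itself.
    def forward(self, images: Tensor, idx_mask: Optional[Tensor] = None) -> Tensor:
        return images.flatten(1)

    def forward_intermediates(
        self, images: Tensor, norm: bool = False
    ) -> Tuple[Tensor, List[Tensor]]:
        tokens = self.forward(images)
        return tokens, [tokens]


model = ToyImpl()
out = model(torch.rand(2, 3, 4, 4))  # Module.__call__ still dispatches to forward
final, intermediates = model.forward_intermediates(torch.rand(2, 3, 4, 4))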
25 changes: 21 additions & 4 deletions lightly/models/modules/masked_vision_transformer_timm.py
@@ -12,22 +12,29 @@
from lightly.models.modules.masked_vision_transformer import MaskedVisionTransformer


class MaskedVisionTransformerTIMM(MaskedVisionTransformer, Module):
class MaskedVisionTransformerTIMM(MaskedVisionTransformer):
"""Masked Vision Transformer class using TIMM.
Attributes:
vit:
The VisionTransformer object of TIMM.
mask_token:
The mask token.
weight_initialization:
The weight initialization method. Valid options are ['', 'skip']. '' uses
the default MAE weight initialization and 'skip' skips the weight
initialization.
antialias:
Whether to use antialiasing when resampling the positional embeddings.
"""

def __init__(
self,
vit: VisionTransformer,
initialize_weights: bool = True,
mask_token: Optional[Parameter] = None,
weight_initialization: str = "",
antialias: bool = True,
) -> None:
super().__init__()
self.vit = vit
@@ -36,9 +43,17 @@ def __init__(
if mask_token is not None
else Parameter(torch.zeros(1, 1, self.vit.embed_dim))
)
if initialize_weights:

if weight_initialization not in ("", "skip"):
raise ValueError(
f"Invalid weight initialization method: '{weight_initialization}'. "
"Valid options are: ['', 'skip']."
)
if weight_initialization != "skip":
self._initialize_weights()

self.antialias = antialias

@property
def sequence_length(self) -> int:
seq_len: int = self.vit.patch_embed.num_patches + self.vit.num_prefix_tokens
@@ -268,6 +283,7 @@ def add_pos_embed(self, x: Tensor) -> Tensor:
num_prefix_tokens=(
0 if self.vit.no_embed_class else self.vit.num_prefix_tokens
),
antialias=self.antialias,
)
x = x.view(B, -1, C)
else:
@@ -291,7 +307,8 @@ def _initialize_weights(self) -> None:
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

# Initialize the class token.
torch.nn.init.normal_(self.vit.cls_token, std=0.02)
if self.vit.has_class_token:
torch.nn.init.normal_(self.vit.cls_token, std=0.02)

# initialize nn.Linear and nn.LayerNorm
self.apply(_init_weights)
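Two behavioural details from the TIMM hunks above, restated as a standalone sketch (helper names are illustrative, not part of lightly's API): anything other than '' or 'skip' is rejected, and the class-token init only runs when the backbone actually has a class token.

import torch
from torch.nn import Parameter


def validate_weight_initialization(weight_initialization: str) -> None:
    # Mirrors the check added to __init__ above.
    if weight_initialization not in ("", "skip"):
        raise ValueError(
            f"Invalid weight initialization method: '{weight_initialization}'. "
            "Valid options are: ['', 'skip']."
        )


def init_class_token(cls_token: Parameter, has_class_token: bool) -> None:
    # Mirrors the has_class_token guard added to _initialize_weights above.
    if has_class_token:
        torch.nn.init.normal_(cls_token, std=0.02)


validate_weight_initialization("skip")  # ok
try:
    validate_weight_initialization("xavier")
except ValueError as err:
    print(err)  # Invalid weight initialization method: 'xavier'. ...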
33 changes: 29 additions & 4 deletions lightly/models/modules/masked_vision_transformer_torchvision.py
@@ -1,5 +1,5 @@
import math
from typing import Optional
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
@@ -11,22 +11,28 @@
from lightly.models.modules.masked_vision_transformer import MaskedVisionTransformer


class MaskedVisionTransformerTorchvision(MaskedVisionTransformer, Module):
class MaskedVisionTransformerTorchvision(MaskedVisionTransformer):
"""Masked Vision Transformer class using Torchvision.
Attributes:
vit:
The VisionTransformer object of Torchvision.
mask_token:
The mask token.
weight_initialization:
The weight initialization method. Valid options are ['', 'skip']. '' uses
the default MAE weight initialization and 'skip' skips the weight
initialization.
antialias:
Whether to use antialiasing when resampling the positional embeddings.
"""

def __init__(
self,
vit: VisionTransformer,
initialize_weights: bool = True,
mask_token: Optional[Parameter] = None,
weight_initialization: str = "",
antialias: bool = True,
) -> None:
super().__init__()
self.vit = vit
@@ -35,9 +41,16 @@ def __init__(
if mask_token is not None
else Parameter(torch.zeros(1, 1, self.vit.hidden_dim))
)
if initialize_weights:
if weight_initialization not in ("", "skip"):
raise ValueError(
f"Invalid weight initialization method: '{weight_initialization}'. "
"Valid options are: ['', 'skip']."
)
if weight_initialization != "skip":
self._initialize_weights()

self.antialias = antialias

@property
def sequence_length(self) -> int:
seq_len: int = self.vit.seq_length
@@ -70,6 +83,7 @@ def interpolate_pos_encoding(self, input: Tensor) -> Tensor:
),
scale_factor=math.sqrt(npatch / N),
mode="bicubic",
antialias=self.antialias,
)
pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1)
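The antialias flag ends up in the F.interpolate call above, which resamples the positional-embedding grid when the number of patches changes. A self-contained sketch of that resampling step (function name and shapes are illustrative; the class-token embedding is kept aside and concatenated back afterwards, as in the hunk above):

import math

import torch
import torch.nn.functional as F


def resample_pos_embedding(
    pos_embedding: torch.Tensor, npatch: int, antialias: bool = True
) -> torch.Tensor:
    # pos_embedding: (1, N, dim) patch embeddings without the class token; N must be a square.
    _, N, dim = pos_embedding.shape
    side = int(math.sqrt(N))
    grid = pos_embedding.reshape(1, side, side, dim).permute(0, 3, 1, 2)  # (1, dim, side, side)
    grid = F.interpolate(
        grid,
        scale_factor=math.sqrt(npatch / N),
        mode="bicubic",
        antialias=antialias,  # the option added by this commit
    )
    return grid.permute(0, 2, 3, 1).reshape(1, -1, dim)


pos = torch.randn(1, 196, 768)                       # 14x14 grid of 768-dim embeddings
resampled = resample_pos_embedding(pos, npatch=784)  # -> torch.Size([1, 784, 768]), a 28x28 grid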
@@ -104,6 +118,17 @@ def forward(
class_token = out[:, 0]
return class_token

def forward_intermediates(
self,
images: Tensor,
idx_mask: Optional[Tensor] = None,
idx_keep: Optional[Tensor] = None,
norm: bool = False,
) -> Tuple[Tensor, List[Tensor]]:
raise NotImplementedError(
"forward_intermediates is not implemented for this model."
)

def encode(
self,
images: Tensor,
Empty file added tests/models/__init__.py
