From e7f13a53e68031ee0ae77124e76401b48381001f Mon Sep 17 00:00:00 2001 From: guarin <43336610+guarin@users.noreply.github.com> Date: Wed, 6 Mar 2024 11:59:19 +0100 Subject: [PATCH] Add AIM examples and docs (#1509) * Add AIM examples * Add AIM docs * Add AIM to README --- README.md | 1 + benchmarks/imagenet/vitb16/aim.py | 27 ++--- docs/source/examples/aim.rst | 53 ++++++++++ docs/source/examples/models.rst | 1 + examples/pytorch/aim.py | 98 +++++++++++++++++ examples/pytorch_lightning/aim.py | 92 ++++++++++++++++ examples/pytorch_lightning_distributed/aim.py | 100 ++++++++++++++++++ .../models/modules/masked_autoencoder_timm.py | 4 +- .../modules/masked_vision_transformer_timm.py | 10 +- .../masked_vision_transformer_torchvision.py | 3 +- lightly/models/utils.py | 29 ++++- tests/models/test_ModelUtils.py | 15 +++ 12 files changed, 406 insertions(+), 27 deletions(-) create mode 100644 docs/source/examples/aim.rst create mode 100644 examples/pytorch/aim.py create mode 100644 examples/pytorch_lightning/aim.py create mode 100644 examples/pytorch_lightning_distributed/aim.py diff --git a/README.md b/README.md index 9b1c4c213..8e1368db4 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ and PyTorch Lightning distributed examples for all models to kickstart your proj **Models**: +- AIM, 2024 [paper](https://arxiv.org/abs/2401.08541) [docs](https://docs.lightly.ai/self-supervised-learning/examples/aim.html) - Barlow Twins, 2021 [paper](https://arxiv.org/abs/2103.03230) [docs](https://docs.lightly.ai/self-supervised-learning/examples/barlowtwins.html) - BYOL, 2020 [paper](https://arxiv.org/abs/2006.07733) [docs](https://docs.lightly.ai/self-supervised-learning/examples/byol.html) - DCL & DCLW, 2021 [paper](https://arxiv.org/abs/2110.06848) [docs](https://docs.lightly.ai/self-supervised-learning/examples/dcl.html) diff --git a/benchmarks/imagenet/vitb16/aim.py b/benchmarks/imagenet/vitb16/aim.py index 32d81f41a..0fc2dc17b 100644 --- a/benchmarks/imagenet/vitb16/aim.py +++ b/benchmarks/imagenet/vitb16/aim.py @@ -21,13 +21,9 @@ def __init__(self, batch_size_per_device: int, num_classes: int) -> None: self.save_hyperparameters() self.batch_size_per_device = batch_size_per_device - img_size = 224 - self.patch_size = 14 - self.num_patches = (img_size // self.patch_size) ** 2 - vit = MaskedCausalVisionTransformer( - img_size=img_size, - patch_size=self.patch_size, + img_size=224, + patch_size=14, num_classes=num_classes, embed_dim=1536, depth=24, @@ -36,14 +32,11 @@ def __init__(self, batch_size_per_device: int, num_classes: int) -> None: class_token=False, no_embed_class=True, ) - # Use absolute positional embedding. 
-        pos_embed = get_2d_sincos_pos_embed(
-            embed_dim=vit.embed_dim,
-            grid_size=int(self.num_patches**0.5),
-            cls_token=False,
+        utils.initialize_2d_sine_cosine_positional_embedding(
+            pos_embedding=vit.pos_embed, has_class_token=vit.has_class_token
         )
-        vit.pos_embed.requires_grad = False
-        vit.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+        self.patch_size = vit.patch_embed.patch_size[0]
+        self.num_patches = vit.patch_embed.num_patches
 
         self.backbone = vit
         self.projection_head = AIMPredictionHead(
@@ -67,8 +60,8 @@ def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
     def training_step(
         self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
     ) -> Tensor:
-        images, targets = batch[0], batch[1]
-        images = images[0]  # images is a list containing only one view
+        views, targets = batch[0], batch[1]
+        images = views[0]  # AIM has only a single view
         batch_size = images.shape[0]
 
         mask = random_prefix_mask(
@@ -83,9 +76,7 @@ def training_step(
 
         # Convert images to patches and normalize them.
         patches = utils.patchify(images, self.patch_size)
-        mean = patches.mean(dim=-1, keepdim=True)
-        var = patches.var(dim=-1, keepdim=True)
-        patches = (patches - mean) / (var + 1.0e-6) ** 0.5
+        patches = utils.normalize_mean_var(patches, dim=-1)
 
         loss = self.criterion(predictions, patches)
 
diff --git a/docs/source/examples/aim.rst b/docs/source/examples/aim.rst
new file mode 100644
index 000000000..a24d4e1a4
--- /dev/null
+++ b/docs/source/examples/aim.rst
@@ -0,0 +1,53 @@
+.. _aim:
+
+AIM
+===
+
+Example implementation of the Autoregressive Image Model (AIM) architecture. AIM is a
+transformer model based on the `Vision Transformer (ViT) `_
+architecture. It learns image representations by predicting pixel values for image
+patches based on previous patches in the image. This is similar to the next word prediction
+task in natural language processing. AIM demonstrates that it is possible to train
+large-scale vision models using an autoregressive objective. The model is split into
+an encoder and a decoder part. The encoder generates features for image patches and
+the decoder predicts pixel values based on the features.
+
+Reference:
+    `Scalable Pre-training of Large Autoregressive Image Models, 2024 <https://arxiv.org/abs/2401.08541>`_
+
+
+.. tabs::
+    .. tab:: PyTorch
+
+        This example can be run from the command line with::
+
+            python lightly/examples/pytorch/aim.py
+
+        .. literalinclude:: ../../../examples/pytorch/aim.py
+
+    .. tab:: Lightning
+
+        This example can be run from the command line with::
+
+            python lightly/examples/pytorch_lightning/aim.py
+
+        .. literalinclude:: ../../../examples/pytorch_lightning/aim.py
+
+    .. tab:: Lightning Distributed
+
+        This example runs on multiple GPUs using Distributed Data Parallel (DDP)
+        training with PyTorch Lightning. At least one GPU must be available on
+        the system. The example can be run from the command line with::
+
+            python lightly/examples/pytorch_lightning_distributed/aim.py
+
+        The model differs in the following ways from the non-distributed
+        implementation:
+
+        - Distributed Data Parallel is enabled
+        - Distributed Sampling is used in the dataloader
+
+        Distributed Sampling makes sure that each distributed process sees only
+        a subset of the data.
+
+        .. literalinclude:: ../../../examples/pytorch_lightning_distributed/aim.py
diff --git a/docs/source/examples/models.rst b/docs/source/examples/models.rst
index d0eb62b99..177f99c6e 100644
--- a/docs/source/examples/models.rst
+++ b/docs/source/examples/models.rst
@@ -10,6 +10,7 @@ for PyTorch and PyTorch Lightning to give you a headstart when implementing your
 .. toctree::
    :maxdepth: 1
 
+   aim.rst
    barlowtwins.rst
    byol.rst
    dcl.rst
diff --git a/examples/pytorch/aim.py b/examples/pytorch/aim.py
new file mode 100644
index 000000000..43d3108d0
--- /dev/null
+++ b/examples/pytorch/aim.py
@@ -0,0 +1,98 @@
+# Note: The model and training settings do not follow the reference settings
+# from the paper. The settings are chosen such that the example can easily be
+# run on a small dataset with a single GPU.
+
+import torch
+import torchvision
+from torch import nn
+
+from lightly.models import utils
+from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer
+from lightly.transforms import AIMTransform
+
+
+class AIM(nn.Module):
+    def __init__(self, vit):
+        super().__init__()
+        utils.initialize_2d_sine_cosine_positional_embedding(
+            pos_embedding=vit.pos_embed, has_class_token=vit.has_class_token
+        )
+        self.patch_size = vit.patch_embed.patch_size[0]
+        self.num_patches = vit.patch_embed.num_patches
+
+        self.backbone = vit
+        self.projection_head = AIMPredictionHead(
+            input_dim=vit.embed_dim, output_dim=3 * self.patch_size**2, num_blocks=1
+        )
+
+    def forward(self, images):
+        batch_size = images.shape[0]
+
+        mask = utils.random_prefix_mask(
+            size=(batch_size, self.num_patches),
+            max_prefix_length=self.num_patches - 1,
+            device=images.device,
+        )
+        features = self.backbone.forward_features(images, mask=mask)
+        # Add positional embedding before head.
+        features = self.backbone._pos_embed(features)
+        predictions = self.projection_head(features)
+
+        # Convert images to patches and normalize them.
+ patches = utils.patchify(images, self.patch_size) + patches = utils.normalize_mean_var(patches, dim=-1) + + return predictions, patches + + +vit = MaskedCausalVisionTransformer( + img_size=224, + patch_size=32, + embed_dim=768, + depth=12, + num_heads=12, + qk_norm=False, + class_token=False, + no_embed_class=True, +) +model = AIM(vit) + +device = "cuda" if torch.cuda.is_available() else "cpu" +model.to(device) + +transform = AIMTransform() +# we ignore object detection annotations by setting target_transform to return 0 +dataset = torchvision.datasets.VOCDetection( + "datasets/pascal_voc", + download=True, + transform=transform, + target_transform=lambda t: 0, +) +# or create a dataset from a folder containing images or videos: +# dataset = LightlyDataset("path/to/folder") + +dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=256, + shuffle=True, + drop_last=True, + num_workers=8, +) + +criterion = nn.MSELoss() +optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4) + +print("Starting Training") +for epoch in range(10): + total_loss = 0 + for batch in dataloader: + views = batch[0] + images = views[0].to(device) # views contains only a single view + predictions, targets = model(images) + loss = criterion(predictions, targets) + total_loss += loss.detach() + loss.backward() + optimizer.step() + optimizer.zero_grad() + avg_loss = total_loss / len(dataloader) + print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}") diff --git a/examples/pytorch_lightning/aim.py b/examples/pytorch_lightning/aim.py new file mode 100644 index 000000000..c871357eb --- /dev/null +++ b/examples/pytorch_lightning/aim.py @@ -0,0 +1,92 @@ +# Note: The model and training settings do not follow the reference settings +# from the paper. The settings are chosen such that the example can easily be +# run on a small dataset with a single GPU. + +import pytorch_lightning as pl +import torch +import torchvision +from torch import nn + +from lightly.models import utils +from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer +from lightly.transforms import AIMTransform + + +class AIM(pl.LightningModule): + def __init__(self) -> None: + super().__init__() + vit = MaskedCausalVisionTransformer( + img_size=224, + patch_size=32, + embed_dim=768, + depth=12, + num_heads=12, + qk_norm=False, + class_token=False, + no_embed_class=True, + ) + utils.initialize_2d_sine_cosine_positional_embedding( + pos_embedding=vit.pos_embed, has_class_token=vit.has_class_token + ) + self.patch_size = vit.patch_embed.patch_size[0] + self.num_patches = vit.patch_embed.num_patches + + self.backbone = vit + self.projection_head = AIMPredictionHead( + input_dim=vit.embed_dim, output_dim=3 * self.patch_size**2, num_blocks=1 + ) + + self.criterion = nn.MSELoss() + + def training_step(self, batch, batch_idx): + views, targets = batch[0], batch[1] + images = views[0] # AIM has only a single view + batch_size = images.shape[0] + + mask = utils.random_prefix_mask( + size=(batch_size, self.num_patches), + max_prefix_length=self.num_patches - 1, + device=images.device, + ) + features = self.backbone.forward_features(images, mask=mask) + # Add positional embedding before head. + features = self.backbone._pos_embed(features) + predictions = self.projection_head(features) + + # Convert images to patches and normalize them. 
+ patches = utils.patchify(images, self.patch_size) + patches = utils.normalize_mean_var(patches, dim=-1) + + loss = self.criterion(predictions, patches) + return loss + + def configure_optimizers(self): + optim = torch.optim.AdamW(self.parameters(), lr=1.5e-4) + return optim + + +model = AIM() + +transform = AIMTransform() +# we ignore object detection annotations by setting target_transform to return 0 +dataset = torchvision.datasets.VOCDetection( + "datasets/pascal_voc", + download=True, + transform=transform, + target_transform=lambda t: 0, +) +# or create a dataset from a folder containing images or videos: +# dataset = LightlyDataset("path/to/folder") + +dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=256, + shuffle=True, + drop_last=True, + num_workers=8, +) + +accelerator = "gpu" if torch.cuda.is_available() else "cpu" + +trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator) +trainer.fit(model=model, train_dataloaders=dataloader) diff --git a/examples/pytorch_lightning_distributed/aim.py b/examples/pytorch_lightning_distributed/aim.py new file mode 100644 index 000000000..ef1dd74e8 --- /dev/null +++ b/examples/pytorch_lightning_distributed/aim.py @@ -0,0 +1,100 @@ +# Note: The model and training settings do not follow the reference settings +# from the paper. The settings are chosen such that the example can easily be +# run on a small dataset with a single GPU. + +import pytorch_lightning as pl +import torch +import torchvision +from torch import nn + +from lightly.models import utils +from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer +from lightly.transforms import AIMTransform + + +class AIM(pl.LightningModule): + def __init__(self) -> None: + super().__init__() + vit = MaskedCausalVisionTransformer( + img_size=224, + patch_size=32, + embed_dim=768, + depth=12, + num_heads=12, + qk_norm=False, + class_token=False, + no_embed_class=True, + ) + utils.initialize_2d_sine_cosine_positional_embedding( + pos_embedding=vit.pos_embed, has_class_token=vit.has_class_token + ) + self.patch_size = vit.patch_embed.patch_size[0] + self.num_patches = vit.patch_embed.num_patches + + self.backbone = vit + self.projection_head = AIMPredictionHead( + input_dim=vit.embed_dim, output_dim=3 * self.patch_size**2, num_blocks=1 + ) + + self.criterion = nn.MSELoss() + + def training_step(self, batch, batch_idx): + views, targets = batch[0], batch[1] + images = views[0] # AIM has only a single view + batch_size = images.shape[0] + + mask = utils.random_prefix_mask( + size=(batch_size, self.num_patches), + max_prefix_length=self.num_patches - 1, + device=images.device, + ) + features = self.backbone.forward_features(images, mask=mask) + # Add positional embedding before head. + features = self.backbone._pos_embed(features) + predictions = self.projection_head(features) + + # Convert images to patches and normalize them. 
+ patches = utils.patchify(images, self.patch_size) + patches = utils.normalize_mean_var(patches, dim=-1) + + loss = self.criterion(predictions, patches) + return loss + + def configure_optimizers(self): + optim = torch.optim.AdamW(self.parameters(), lr=1.5e-4) + return optim + + +model = AIM() + +transform = AIMTransform() +# we ignore object detection annotations by setting target_transform to return 0 +dataset = torchvision.datasets.VOCDetection( + "datasets/pascal_voc", + download=True, + transform=transform, + target_transform=lambda t: 0, +) +# or create a dataset from a folder containing images or videos: +# dataset = LightlyDataset("path/to/folder") + +dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=256, + shuffle=True, + drop_last=True, + num_workers=8, +) + +accelerator = "gpu" if torch.cuda.is_available() else "cpu" + +# Train with DDP on multiple gpus. Distributed sampling is also enabled with +# replace_sampler_ddp=True. +trainer = pl.Trainer( + max_epochs=10, + devices="auto", + accelerator="gpu", + strategy="ddp", + use_distributed_sampler=True, # or replace_sampler_ddp=True for PyTorch Lightning <2.0 +) +trainer.fit(model=model, train_dataloaders=dataloader) diff --git a/lightly/models/modules/masked_autoencoder_timm.py b/lightly/models/modules/masked_autoencoder_timm.py index 435ed8cf6..5f6dd9ceb 100644 --- a/lightly/models/modules/masked_autoencoder_timm.py +++ b/lightly/models/modules/masked_autoencoder_timm.py @@ -168,7 +168,9 @@ def predict(self, input: Tensor) -> Tensor: def _initialize_weights(self) -> None: torch.nn.init.normal_(self.mask_token, std=0.02) - utils.initialize_2d_sine_cosine_positional_embedding(self.decoder_pos_embed) + utils.initialize_2d_sine_cosine_positional_embedding( + pos_embedding=self.decoder_pos_embed, has_class_token=True + ) self.apply(_init_weights) diff --git a/lightly/models/modules/masked_vision_transformer_timm.py b/lightly/models/modules/masked_vision_transformer_timm.py index 446f339c3..e41ece8e9 100644 --- a/lightly/models/modules/masked_vision_transformer_timm.py +++ b/lightly/models/modules/masked_vision_transformer_timm.py @@ -191,9 +191,9 @@ def add_pos_embed(self, x: Tensor) -> Tensor: pos_embed = resample_abs_pos_embed( self.vit.pos_embed, (H, W), - num_prefix_tokens=0 - if self.vit.no_embed_class - else self.vit.num_prefix_tokens, + num_prefix_tokens=( + 0 if self.vit.no_embed_class else self.vit.num_prefix_tokens + ), ) x = x.view(B, -1, C) else: @@ -222,7 +222,9 @@ def _initialize_weights(self) -> None: # initialize nn.Linear and nn.LayerNorm self.apply(_init_weights) - utils.initialize_2d_sine_cosine_positional_embedding(self.vit.pos_embed) + utils.initialize_2d_sine_cosine_positional_embedding( + pos_embedding=self.vit.pos_embed, has_class_token=self.vit.has_class_token + ) def _init_weights(module: Module) -> None: diff --git a/lightly/models/modules/masked_vision_transformer_torchvision.py b/lightly/models/modules/masked_vision_transformer_torchvision.py index dcb7f1a05..4e6fffd03 100644 --- a/lightly/models/modules/masked_vision_transformer_torchvision.py +++ b/lightly/models/modules/masked_vision_transformer_torchvision.py @@ -202,7 +202,8 @@ def _initialize_weights(self) -> None: # Initialize positional encoding. utils.initialize_2d_sine_cosine_positional_embedding( - self.vit.encoder.pos_embedding + pos_embedding=self.vit.encoder.pos_embedding, + has_class_token=True, ) # Initialize linear layers. 
diff --git a/lightly/models/utils.py b/lightly/models/utils.py index ce5febc10..1fcba9acf 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -12,6 +12,7 @@ import torch.distributed as dist import torch.nn as nn from numpy.typing import NDArray +from torch import Tensor from torch.nn import Module, Sequential from torch.nn.modules import CrossMapLRN2d, GroupNorm, LayerNorm, LocalResponseNorm from torch.nn.modules.batchnorm import _NormBase @@ -638,13 +639,15 @@ def add_stochastic_depth_to_blocks(vit: Module, prob: float = 0.0, mode="row") - mod.mlp = Sequential(mod.mlp, StochasticDepth(p=prob, mode=mode)) -def initialize_2d_sine_cosine_positional_embedding(pos_embedding: Parameter) -> None: +def initialize_2d_sine_cosine_positional_embedding( + pos_embedding: Parameter, has_class_token: bool +) -> None: _, seq_length, hidden_dim = pos_embedding.shape - grid_size = int((seq_length - 1) ** 0.5) + grid_size = int((seq_length - int(has_class_token)) ** 0.5) sine_cosine_embedding = get_2d_sine_cosine_positional_embedding( embed_dim=hidden_dim, grid_size=grid_size, - cls_token=True, + cls_token=has_class_token, ) pos_embedding.data.copy_( torch.from_numpy(sine_cosine_embedding).float().unsqueeze(0) @@ -748,3 +751,23 @@ def get_1d_sine_cosine_positional_embedding_from_positions( emb = np.concatenate([emb_sin, emb_cos], axis=1) # (N*M, embed_dim) return emb + + +def normalize_mean_var(x: Tensor, dim: int = -1, eps: float = 1.0e-6) -> Tensor: + """Normalizes the input tensor to zero mean and unit variance. + + Args: + x: + Input tensor. + dim: + Dimension along which to compute mean and standard deviation. Takes last + dimension by default. + eps: + Epsilon value to avoid division by zero. + + Returns: + Normalized tensor. + """ + mean = x.mean(dim=dim, keepdim=True) + var = x.var(dim=dim, keepdim=True) + return (x - mean) / (var + eps).sqrt() diff --git a/tests/models/test_ModelUtils.py b/tests/models/test_ModelUtils.py index 16decca7c..2b7174993 100644 --- a/tests/models/test_ModelUtils.py +++ b/tests/models/test_ModelUtils.py @@ -1,6 +1,8 @@ import copy +import math import unittest +import pytest import torch import torch.nn as nn @@ -388,3 +390,16 @@ def test_get_named_leaf_modules() -> None: assert utils.get_named_leaf_modules(linear1) == {"": linear1} assert utils.get_named_leaf_modules(sequential1) == {"0": linear1, "1": linear2} assert utils.get_named_leaf_modules(sequential2) == {"0.0": linear1, "0.1": linear2} + + +def test_normalize_mean_var() -> None: + x = torch.tensor([1.0, 2.0, 3.0]) + norm = utils.normalize_mean_var(x).tolist() + assert norm[0] == pytest.approx(-1) + assert norm[1] == pytest.approx(0.0) + assert norm[2] == pytest.approx(1) + + x = torch.rand(2, 3, 4) + norm = utils.normalize_mean_var(x) + assert torch.allclose(norm.mean(dim=-1), torch.tensor(0.0), rtol=0.0001, atol=1e-5) + assert torch.allclose(norm.var(dim=-1), torch.tensor(1.0), rtol=0.0001, atol=1e-5)
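
A minimal usage sketch of the new normalize_mean_var helper together with the existing utils.patchify call, mirroring how the training steps above build AIM's regression targets. This assumes a lightly installation that already includes this patch; it is not part of the patch itself.

import torch

from lightly.models import utils

# Toy batch of 8 RGB images at the 224x224 resolution used by the examples.
images = torch.rand(8, 3, 224, 224)
patch_size = 32

# patchify flattens each non-overlapping patch, giving a tensor whose last
# dimension (3 * patch_size**2) matches the AIMPredictionHead output_dim.
patches = utils.patchify(images, patch_size)

# normalize_mean_var replaces the inline mean/variance normalization the
# benchmark previously did by hand: each patch gets zero mean and unit variance.
targets = utils.normalize_mean_var(patches, dim=-1)

print(targets.shape)                     # (8, 49, 3072) for 224 // 32 = 7 patches per side
print(targets.mean(dim=-1).abs().max())  # close to 0
print(targets.var(dim=-1).mean())        # close to 1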
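The initialize_2d_sine_cosine_positional_embedding helper now takes an explicit has_class_token argument, so models without a class token (such as the AIM ViT) get the correct grid size. Below is a short sketch of the new call, reusing the ViT configuration from the examples; the reduced depth is an assumption made only to keep the sketch cheap to run.

from lightly.models import utils
from lightly.models.modules import MaskedCausalVisionTransformer

# Same configuration as the examples, but with fewer blocks for illustration.
vit = MaskedCausalVisionTransformer(
    img_size=224,
    patch_size=32,
    embed_dim=768,
    depth=2,
    num_heads=12,
    qk_norm=False,
    class_token=False,
    no_embed_class=True,
)

# Writes 2d sine-cosine values into vit.pos_embed in place. With
# has_class_token=False (here taken from vit.has_class_token) the grid size is
# derived from the full sequence length instead of sequence length minus one.
utils.initialize_2d_sine_cosine_positional_embedding(
    pos_embedding=vit.pos_embed, has_class_token=vit.has_class_token
)

print(vit.pos_embed.shape)  # expected (1, 49, 768) for 224 // 32 = 7 patches per side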