Add AIM examples and docs (#1509)
* Add AIM examples
* Add AIM docs
* Add AIM to README
guarin authored Mar 6, 2024
1 parent 053327d commit e7f13a5
Showing 12 changed files with 406 additions and 27 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -35,6 +35,7 @@ and PyTorch Lightning distributed examples for all models to kickstart your proj
 
 **Models**:
 
+- AIM, 2024 [paper](https://arxiv.org/abs/2401.08541) [docs](https://docs.lightly.ai/self-supervised-learning/examples/aim.html)
 - Barlow Twins, 2021 [paper](https://arxiv.org/abs/2103.03230) [docs](https://docs.lightly.ai/self-supervised-learning/examples/barlowtwins.html)
 - BYOL, 2020 [paper](https://arxiv.org/abs/2006.07733) [docs](https://docs.lightly.ai/self-supervised-learning/examples/byol.html)
 - DCL & DCLW, 2021 [paper](https://arxiv.org/abs/2110.06848) [docs](https://docs.lightly.ai/self-supervised-learning/examples/dcl.html)
27 changes: 9 additions & 18 deletions benchmarks/imagenet/vitb16/aim.py
@@ -21,13 +21,9 @@ def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
         self.save_hyperparameters()
         self.batch_size_per_device = batch_size_per_device
 
-        img_size = 224
-        self.patch_size = 14
-        self.num_patches = (img_size // self.patch_size) ** 2
-
         vit = MaskedCausalVisionTransformer(
-            img_size=img_size,
-            patch_size=self.patch_size,
+            img_size=224,
+            patch_size=14,
             num_classes=num_classes,
             embed_dim=1536,
             depth=24,
@@ -36,14 +32,11 @@ def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
             class_token=False,
             no_embed_class=True,
         )
-        # Use absolute positional embedding.
-        pos_embed = get_2d_sincos_pos_embed(
-            embed_dim=vit.embed_dim,
-            grid_size=int(self.num_patches**0.5),
-            cls_token=False,
+        utils.initialize_2d_sine_cosine_positional_embedding(
+            pos_embedding=vit.pos_embed, has_class_token=vit.has_class_token
         )
-        vit.pos_embed.requires_grad = False
-        vit.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+        self.patch_size = vit.patch_embed.patch_size[0]
+        self.num_patches = vit.patch_embed.num_patches
 
         self.backbone = vit
         self.projection_head = AIMPredictionHead(
@@ -67,8 +60,8 @@ def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
     def training_step(
         self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
     ) -> Tensor:
-        images, targets = batch[0], batch[1]
-        images = images[0]  # images is a list containing only one view
+        views, targets = batch[0], batch[1]
+        images = views[0]  # AIM has only a single view
        batch_size = images.shape[0]
 
         mask = random_prefix_mask(
@@ -83,9 +76,7 @@ def training_step(
 
         # Convert images to patches and normalize them.
         patches = utils.patchify(images, self.patch_size)
-        mean = patches.mean(dim=-1, keepdim=True)
-        var = patches.var(dim=-1, keepdim=True)
-        patches = (patches - mean) / (var + 1.0e-6) ** 0.5
+        patches = utils.normalize_mean_var(patches, dim=-1)
 
         loss = self.criterion(predictions, patches)

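For reference, the target computation above splits each image into non-overlapping patches, flattens every patch to a vector of 3 * patch_size ** 2 pixel values, and normalizes each vector to zero mean and unit variance; utils.normalize_mean_var replaces exactly the mean/variance arithmetic shown in the removed lines. Below is a minimal plain-PyTorch sketch for illustration only: the pixel ordering inside each patch vector is an assumption, and lightly's actual utils signatures may differ.

import torch


def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    # (B, C, H, W) -> (B, num_patches, patch_size**2 * C), patches in row-major order.
    B, C, H, W = images.shape
    h, w = H // patch_size, W // patch_size
    x = images.reshape(B, C, h, patch_size, w, patch_size)
    x = torch.einsum("bchpwq->bhwpqc", x)
    return x.reshape(B, h * w, patch_size * patch_size * C)


def normalize_mean_var(x: torch.Tensor, dim: int = -1, eps: float = 1.0e-6) -> torch.Tensor:
    # Same arithmetic as the removed mean/var lines above.
    mean = x.mean(dim=dim, keepdim=True)
    var = x.var(dim=dim, keepdim=True)
    return (x - mean) / (var + eps) ** 0.5


images = torch.randn(2, 3, 224, 224)
targets = normalize_mean_var(patchify(images, patch_size=14), dim=-1)
print(targets.shape)  # torch.Size([2, 256, 588]): (224 // 14) ** 2 patches, 3 * 14 ** 2 values each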
53 changes: 53 additions & 0 deletions docs/source/examples/aim.rst
@@ -0,0 +1,53 @@
.. _aim:

AIM
===

Example implementation of the Autoregressive Image Model (AIM) architecture. AIM is a
transformer model based on the `Vision Transformer (ViT) <https://arxiv.org/abs/2010.11929>`_
architecture. It learns image representations by predicting pixel values for image
patches based on previous patches in the image. This is similar to the next word prediction
task in natural language processing. AIM demonstrates that it is possible to train
large-scale vision models using an autoregressive objective. The model is split into
an encoder and a decoder part. The encoder generates features for image patches and
the decoder predicts pixel values based on the features.

Reference:
`Scalable Pre-training of Large Autoregressive Image Models, 2024 <https://arxiv.org/abs/2401.08541>`_
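
The heart of the objective is a causal attention pattern over the patch sequence: the
prediction for a patch may only depend on the patches before it, and, as described in
the paper, a randomly sized prefix of patches is attended bidirectionally while only
the patches after the prefix are predicted. The snippet below is a minimal sketch of
such a mask in plain PyTorch, for illustration only; the exact mask format produced by
lightly's ``random_prefix_mask`` and consumed by ``MaskedCausalVisionTransformer`` may
differ::

    import torch

    num_patches = 16  # e.g. a 4 x 4 grid of patches, for illustration

    # Causal mask: position i may only attend to positions 0..i, so the output at
    # position i can be trained to predict the pixels of patch i + 1, analogous to
    # next-token prediction in language modeling.
    causal_mask = torch.ones(num_patches, num_patches).tril().bool()

    # Prefix variant: a random prefix of patches attends bidirectionally and only
    # the patches after the prefix are predicted.
    prefix_length = int(torch.randint(1, num_patches, size=(1,)))
    prefix_causal_mask = causal_mask.clone()
    prefix_causal_mask[:prefix_length, :prefix_length] = True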


.. tabs::
    .. tab:: PyTorch

        This example can be run from the command line with::

            python lightly/examples/pytorch/aim.py

        .. literalinclude:: ../../../examples/pytorch/aim.py

    .. tab:: Lightning

        This example can be run from the command line with::

            python lightly/examples/pytorch_lightning/aim.py

        .. literalinclude:: ../../../examples/pytorch_lightning/aim.py

    .. tab:: Lightning Distributed

        This example runs on multiple GPUs using Distributed Data Parallel (DDP)
        training with PyTorch Lightning. At least one GPU must be available on
        the system. The example can be run from the command line with::

            python lightly/examples/pytorch_lightning_distributed/aim.py

        The model differs from the non-distributed implementation in the
        following ways:

        - Distributed Data Parallel is enabled
        - Distributed Sampling is used in the dataloader

        Distributed Sampling makes sure that each distributed process sees only
        a subset of the data (see the sketch after the example below).

        .. literalinclude:: ../../../examples/pytorch_lightning_distributed/aim.py
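
        A minimal sketch of these two pieces, assuming a standard PyTorch Lightning
        setup (the real distributed example may be configured differently)::

            import pytorch_lightning as pl

            # Distributed Data Parallel: one training process per GPU, selected
            # via the Trainer strategy.
            trainer = pl.Trainer(
                max_epochs=10,
                accelerator="gpu",
                devices="auto",  # use all available GPUs
                strategy="ddp",
            )

            # Distributed sampling: with strategy="ddp", Lightning replaces the
            # dataloader's sampler with a DistributedSampler so that each process
            # sees only its own shard of the dataset. Written out by hand this
            # corresponds to:
            #
            #     sampler = torch.utils.data.distributed.DistributedSampler(dataset)
            #     dataloader = torch.utils.data.DataLoader(
            #         dataset, batch_size=256, sampler=sampler, drop_last=True
            #     )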
1 change: 1 addition & 0 deletions docs/source/examples/models.rst
@@ -10,6 +10,7 @@ for PyTorch and PyTorch Lightning to give you a headstart when implementing your
 .. toctree::
     :maxdepth: 1
 
+    aim.rst
     barlowtwins.rst
     byol.rst
     dcl.rst
98 changes: 98 additions & 0 deletions examples/pytorch/aim.py
@@ -0,0 +1,98 @@
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import torch
import torchvision
from torch import nn

from lightly.models import utils
from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer
from lightly.transforms import AIMTransform


class AIM(nn.Module):
    def __init__(self, vit):
        super().__init__()
        utils.initialize_2d_sine_cosine_positional_embedding(
            pos_embedding=vit.pos_embed, has_class_token=vit.has_class_token
        )
        self.patch_size = vit.patch_embed.patch_size[0]
        self.num_patches = vit.patch_embed.num_patches

        self.backbone = vit
        self.projection_head = AIMPredictionHead(
            input_dim=vit.embed_dim, output_dim=3 * self.patch_size**2, num_blocks=1
        )

    def forward(self, images):
        batch_size = images.shape[0]

        mask = utils.random_prefix_mask(
            size=(batch_size, self.num_patches),
            max_prefix_length=self.num_patches - 1,
            device=images.device,
        )
        features = self.backbone.forward_features(images, mask=mask)
        # Add positional embedding before head.
        features = self.backbone._pos_embed(features)
        predictions = self.projection_head(features)

        # Convert images to patches and normalize them.
        patches = utils.patchify(images, self.patch_size)
        patches = utils.normalize_mean_var(patches, dim=-1)

        return predictions, patches


vit = MaskedCausalVisionTransformer(
    img_size=224,
    patch_size=32,
    embed_dim=768,
    depth=12,
    num_heads=12,
    qk_norm=False,
    class_token=False,
    no_embed_class=True,
)
model = AIM(vit)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

transform = AIMTransform()
# we ignore object detection annotations by setting target_transform to return 0
dataset = torchvision.datasets.VOCDetection(
    "datasets/pascal_voc",
    download=True,
    transform=transform,
    target_transform=lambda t: 0,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)

print("Starting Training")
for epoch in range(10):
    total_loss = 0
    for batch in dataloader:
        views = batch[0]
        images = views[0].to(device)  # views contains only a single view
        predictions, targets = model(images)
        loss = criterion(predictions, targets)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
92 changes: 92 additions & 0 deletions examples/pytorch_lightning/aim.py
@@ -0,0 +1,92 @@
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import pytorch_lightning as pl
import torch
import torchvision
from torch import nn

from lightly.models import utils
from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer
from lightly.transforms import AIMTransform


class AIM(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        vit = MaskedCausalVisionTransformer(
            img_size=224,
            patch_size=32,
            embed_dim=768,
            depth=12,
            num_heads=12,
            qk_norm=False,
            class_token=False,
            no_embed_class=True,
        )
        utils.initialize_2d_sine_cosine_positional_embedding(
            pos_embedding=vit.pos_embed, has_class_token=vit.has_class_token
        )
        self.patch_size = vit.patch_embed.patch_size[0]
        self.num_patches = vit.patch_embed.num_patches

        self.backbone = vit
        self.projection_head = AIMPredictionHead(
            input_dim=vit.embed_dim, output_dim=3 * self.patch_size**2, num_blocks=1
        )

        self.criterion = nn.MSELoss()

    def training_step(self, batch, batch_idx):
        views, targets = batch[0], batch[1]
        images = views[0]  # AIM has only a single view
        batch_size = images.shape[0]

        mask = utils.random_prefix_mask(
            size=(batch_size, self.num_patches),
            max_prefix_length=self.num_patches - 1,
            device=images.device,
        )
        features = self.backbone.forward_features(images, mask=mask)
        # Add positional embedding before head.
        features = self.backbone._pos_embed(features)
        predictions = self.projection_head(features)

        # Convert images to patches and normalize them.
        patches = utils.patchify(images, self.patch_size)
        patches = utils.normalize_mean_var(patches, dim=-1)

        loss = self.criterion(predictions, patches)
        return loss

    def configure_optimizers(self):
        optim = torch.optim.AdamW(self.parameters(), lr=1.5e-4)
        return optim


model = AIM()

transform = AIMTransform()
# we ignore object detection annotations by setting target_transform to return 0
dataset = torchvision.datasets.VOCDetection(
    "datasets/pascal_voc",
    download=True,
    transform=transform,
    target_transform=lambda t: 0,
)
# or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

accelerator = "gpu" if torch.cuda.is_available() else "cpu"

trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
trainer.fit(model=model, train_dataloaders=dataloader)