Commit
* Add AIM examples
* Add AIM docs
* Add AIM to README
Showing 12 changed files with 406 additions and 27 deletions.
@@ -0,0 +1,53 @@
.. _aim:

AIM
===

Example implementation of the Autoregressive Image Model (AIM) architecture. AIM is a
transformer model based on the `Vision Transformer (ViT) <https://arxiv.org/abs/2010.11929>`_
architecture. It learns image representations by predicting pixel values for image
patches based on the previous patches in the image, similar to the next-word prediction
task in natural language processing. AIM demonstrates that it is possible to train
large-scale vision models with an autoregressive objective. The model is split into
an encoder and a decoder: the encoder generates features for image patches, and
the decoder predicts pixel values based on these features.
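
The training objective can be sketched in a few lines of plain PyTorch. The
sketch below is illustrative only (``patchify``, ``causal_transformer``, and
``pixel_head`` are hypothetical stand-ins, not the Lightly API); the runnable
examples using the actual Lightly modules follow below.

.. code-block:: python

    # Illustrative sketch of the autoregressive objective.
    patches = patchify(images)               # (batch, num_patches, patch_dim)
    features = causal_transformer(patches)   # each patch attends only to earlier patches
    predictions = pixel_head(features)       # regress pixel values per patch
    loss = mse_loss(predictions, patches)    # compare with the true patch pixels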

Reference:
    `Scalable Pre-training of Large Autoregressive Image Models, 2024 <https://arxiv.org/abs/2401.08541>`_


.. tabs::
    .. tab:: PyTorch

        This example can be run from the command line with::

            python lightly/examples/pytorch/aim.py

        .. literalinclude:: ../../../examples/pytorch/aim.py

    .. tab:: Lightning

        This example can be run from the command line with::

            python lightly/examples/pytorch_lightning/aim.py

        .. literalinclude:: ../../../examples/pytorch_lightning/aim.py

    .. tab:: Lightning Distributed

        This example runs on multiple GPUs using Distributed Data Parallel (DDP)
        training with PyTorch Lightning. At least one GPU must be available on
        the system. The example can be run from the command line with::

            python lightly/examples/pytorch_lightning_distributed/aim.py

        The model differs in the following ways from the non-distributed
        implementation:

        - Distributed Data Parallel is enabled
        - Distributed Sampling is used in the dataloader

        Distributed Sampling makes sure that each distributed process sees only
        a subset of the data.
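
        For reference, distributed sampling in plain PyTorch looks roughly like
        the sketch below (PyTorch Lightning sets this up automatically when DDP
        is enabled; shown here only for illustration):

        .. code-block:: python

            from torch.utils.data import DataLoader
            from torch.utils.data.distributed import DistributedSampler

            # Each process draws a disjoint shard of the dataset.
            sampler = DistributedSampler(dataset, shuffle=True)
            dataloader = DataLoader(dataset, batch_size=256, sampler=sampler)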

        .. literalinclude:: ../../../examples/pytorch_lightning_distributed/aim.py
@@ -0,0 +1,98 @@
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import torch
import torchvision
from torch import nn

from lightly.models import utils
from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer
from lightly.transforms import AIMTransform


class AIM(nn.Module):
    def __init__(self, vit):
        super().__init__()
        utils.initialize_2d_sine_cosine_positional_embedding(
            pos_embedding=vit.pos_embed, has_class_token=vit.has_class_token
        )
        self.patch_size = vit.patch_embed.patch_size[0]
        self.num_patches = vit.patch_embed.num_patches

        self.backbone = vit
        self.projection_head = AIMPredictionHead(
            input_dim=vit.embed_dim, output_dim=3 * self.patch_size**2, num_blocks=1
        )

    def forward(self, images):
        batch_size = images.shape[0]

        mask = utils.random_prefix_mask(
            size=(batch_size, self.num_patches),
            max_prefix_length=self.num_patches - 1,
            device=images.device,
        )
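        # random_prefix_mask marks a random number of leading patches per image
        # as a visible prefix; the remaining patches are processed causally
        # (prefix attention, following the AIM paper).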
        features = self.backbone.forward_features(images, mask=mask)
        # Add positional embedding before head.
        features = self.backbone._pos_embed(features)
        predictions = self.projection_head(features)

        # Convert images to patches and normalize them.
        patches = utils.patchify(images, self.patch_size)
        patches = utils.normalize_mean_var(patches, dim=-1)
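        # These normalized pixel values serve as the regression targets for the
        # MSE loss in the training loop below.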

        return predictions, patches


vit = MaskedCausalVisionTransformer(
    img_size=224,
    patch_size=32,
    embed_dim=768,
    depth=12,
    num_heads=12,
    qk_norm=False,
    class_token=False,
    no_embed_class=True,
)
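# This configuration matches a ViT-B/32 layout (12 blocks, width 768, patch
# size 32). AIM does not use a class token, hence class_token=False and
# no_embed_class=True.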
model = AIM(vit)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

transform = AIMTransform()
# We ignore object detection annotations by setting target_transform to return 0.
dataset = torchvision.datasets.VOCDetection(
    "datasets/pascal_voc",
    download=True,
    transform=transform,
    target_transform=lambda t: 0,
)
# Or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)

print("Starting Training")
for epoch in range(10):
    total_loss = 0
    for batch in dataloader:
        views = batch[0]
        images = views[0].to(device)  # views contains only a single view
        predictions, targets = model(images)
        loss = criterion(predictions, targets)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
@@ -0,0 +1,92 @@
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import pytorch_lightning as pl
import torch
import torchvision
from torch import nn

from lightly.models import utils
from lightly.models.modules import AIMPredictionHead, MaskedCausalVisionTransformer
from lightly.transforms import AIMTransform


class AIM(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        vit = MaskedCausalVisionTransformer(
            img_size=224,
            patch_size=32,
            embed_dim=768,
            depth=12,
            num_heads=12,
            qk_norm=False,
            class_token=False,
            no_embed_class=True,
        )
        utils.initialize_2d_sine_cosine_positional_embedding(
            pos_embedding=vit.pos_embed, has_class_token=vit.has_class_token
        )
        self.patch_size = vit.patch_embed.patch_size[0]
        self.num_patches = vit.patch_embed.num_patches

        self.backbone = vit
        self.projection_head = AIMPredictionHead(
            input_dim=vit.embed_dim, output_dim=3 * self.patch_size**2, num_blocks=1
        )

        self.criterion = nn.MSELoss()

    def training_step(self, batch, batch_idx):
        views, targets = batch[0], batch[1]
        images = views[0]  # AIM has only a single view
        batch_size = images.shape[0]

        mask = utils.random_prefix_mask(
            size=(batch_size, self.num_patches),
            max_prefix_length=self.num_patches - 1,
            device=images.device,
        )
        features = self.backbone.forward_features(images, mask=mask)
        # Add positional embedding before head.
        features = self.backbone._pos_embed(features)
        predictions = self.projection_head(features)

        # Convert images to patches and normalize them.
        patches = utils.patchify(images, self.patch_size)
        patches = utils.normalize_mean_var(patches, dim=-1)

        loss = self.criterion(predictions, patches)
        return loss

    def configure_optimizers(self):
        optim = torch.optim.AdamW(self.parameters(), lr=1.5e-4)
        return optim


model = AIM()

transform = AIMTransform()
# We ignore object detection annotations by setting target_transform to return 0.
dataset = torchvision.datasets.VOCDetection(
    "datasets/pascal_voc",
    download=True,
    transform=transform,
    target_transform=lambda t: 0,
)
# Or create a dataset from a folder containing images or videos:
# dataset = LightlyDataset("path/to/folder")

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

accelerator = "gpu" if torch.cuda.is_available() else "cpu"

trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)
trainer.fit(model=model, train_dataloaders=dataloader)
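
# The distributed example referenced in the docs is not shown in this diff; it
# differs mainly in the Trainer configuration. Hedged sketch (Lightning injects
# a DistributedSampler automatically when DDP is enabled):
#
# trainer = pl.Trainer(
#     max_epochs=10,
#     devices=2,  # assumes at least two GPUs
#     accelerator="gpu",
#     strategy="ddp",
# )
# trainer.fit(model=model, train_dataloaders=dataloader)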