Add attentive layer to Jepa #927

Merged: 40 commits, Dec 27, 2024
Changes from 24 commits

Commits (40)
7fca3ca  add init function to the builders (Dec 19, 2024)
7623e1b  add builder skeleton for the AttentivePooler (Dec 19, 2024)
7b959fe  refactor init_module function (Dec 19, 2024)
cef6687  refactor init_module function (Dec 19, 2024)
c7a9021  update cross attention layer (Dec 19, 2024)
ab93be9  update cross attn layer (Dec 19, 2024)
f745bea  Cosmetic updates (cbalioglu, Dec 19, 2024)
60ec4c1  add forward() function (Dec 19, 2024)
854b68e  Can's comments (Dec 19, 2024)
328f8ca  fix git rebase (Dec 19, 2024)
f57108d  fix git rebase (Dec 19, 2024)
2d27bab  lint (Dec 19, 2024)
8d7dfaf  lint (Dec 19, 2024)
4b69946  rebase (Dec 19, 2024)
e86afaa  flake8 (Dec 19, 2024)
2a76edd  remove commits remnant (Dec 19, 2024)
250f1ee  black (Dec 19, 2024)
f4aaf33  black (Dec 19, 2024)
e4d9a0a  add builder func (Dec 20, 2024)
b8cfd50  revert remnant codes (Dec 20, 2024)
a6987ab  revert remnant codes (Dec 20, 2024)
837cc6c  revert remnant codes (Dec 20, 2024)
25eb3f0  lint (Dec 20, 2024)
be3b6e2  lint (Dec 20, 2024)
8f90974  rebase (Dec 22, 2024)
2f685e8  nit import clean (Dec 22, 2024)
919b305  nit rename layers (Dec 22, 2024)
c771767  update factory (Dec 22, 2024)
a43f683  lint (Dec 22, 2024)
b76bbe5  fix typo (Dec 24, 2024)
88b2d95  allow unstricted model loading (Dec 26, 2024)
57987da  Feedback commit (cbalioglu, Dec 23, 2024)
4690e6b  update cross_attn build func (Dec 24, 2024)
a345c4c  lint (Dec 26, 2024)
f96344b  update AttentivePooler param names (Dec 26, 2024)
75472ab  decouple #938 (Dec 26, 2024)
8993698  lint (Dec 26, 2024)
4394a0a  lint (Dec 26, 2024)
05c979f  lint (Dec 26, 2024)
fc3b5c9  lint (Dec 26, 2024)
112 changes: 101 additions & 11 deletions src/fairseq2/models/jepa/factory.py
@@ -6,10 +6,13 @@

from __future__ import annotations

import math
from dataclasses import dataclass, field
from functools import partial
from typing import Final, cast

from torch.nn import GELU
import torch
from torch.nn import GELU, Module

from fairseq2.config_registry import ConfigRegistry
from fairseq2.models.jepa.model import JepaModel
@@ -23,6 +26,7 @@
from fairseq2.nn import (
InterpolatedPositionEncoder,
LayerNorm,
Linear,
Sinusoidal2dPositionEncoder,
Sinusoidal3dPositionEncoder,
StandardLayerNorm,
@@ -39,6 +43,7 @@
TransformerNormOrder,
create_default_sdpa,
)
from fairseq2.nn.transformer.residual import DropPathResidualConnect
from fairseq2.typing import DataType, Device

JEPA_FAMILY: Final = "jepa"
@@ -97,9 +102,21 @@ class JepaEncoderConfig:
feed-forward networks to :attr:`model_dim`.
"""

init_std: float = 0.02
"""
The standard deviation to initialize weights and biases of projection and
normalization layers.
"""

dropout_p: float = 0.0
"""The dropout probability on outputs of Transformer layers."""

droppath_p: float = 0.0
"""
The probability of dropping sequences from outputs of multi-head attention
and feed-forward network layers before adding residuals.
"""

uniform_power: bool = False
"""
If ``True``, each patch dimension will have equal representation in the
Expand Down Expand Up @@ -182,6 +199,10 @@ def build_frontend(self) -> TransformerFrontend:
def build_feature_extractor(self) -> PatchFeatureExtractor:
config = self._config

init_std = config.init_std

init_conv = partial(init_truncated_uniforma_weights_and_bias, std=init_std)

num_patch_dims = len(config.patch_dims)

if num_patch_dims == 3:
@@ -191,6 +212,7 @@ def build_feature_extractor(self) -> PatchFeatureExtractor:
config.num_input_channels,
config.model_dim,
patch_3d_dims,
init_fn=init_conv,
device=self._device,
dtype=self._dtype,
)
@@ -201,6 +223,7 @@ def build_feature_extractor(self) -> PatchFeatureExtractor:
config.num_input_channels,
config.model_dim,
patch_2d_dims,
init_fn=init_conv,
device=self._device,
dtype=self._dtype,
)
@@ -255,7 +278,7 @@ def build_encoder(self) -> TransformerEncoder:

num_layers = config.num_encoder_layers

layers = [self.build_encoder_layer() for _ in range(num_layers)]
layers = [self.build_encoder_layer(i) for i in range(num_layers)]

return StandardTransformerEncoder(
layers,
@@ -265,60 +288,110 @@
dtype=self._dtype,
)

def build_encoder_layer(self) -> TransformerEncoderLayer:
def build_encoder_layer(self, layer_idx: int) -> TransformerEncoderLayer:
config = self._config

self_attn = self.build_attention()
self_attn = self.build_attention(layer_idx)

ffn = self.build_ffn(layer_idx)

ffn = self.build_ffn()
drop_path = DropPathResidualConnect(drop_p=config.droppath_p)

return StandardTransformerEncoderLayer(
self_attn,
ffn,
dropout_p=config.dropout_p,
norm_order=TransformerNormOrder.PRE,
layer_norm_factory=self.build_layer_norm,
self_attn_residual=drop_path,
ffn_residual=drop_path,
device=self._device,
dtype=self._dtype,
)

def build_attention(self) -> MultiheadAttention:
def build_attention(self, layer_idx: int) -> MultiheadAttention:
config = self._config

sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p)

output_proj = self.build_mha_output_projection(layer_idx)

return StandardMultiheadAttention(
config.model_dim,
config.num_encoder_attn_heads,
sdpa=sdpa,
bias=config.qkv_bias,
output_proj_bias=True,
output_proj=output_proj,
device=self._device,
dtype=self._dtype,
)

def build_mha_output_projection(self, layer_idx: int) -> Linear:
config = self._config

init_std = config.init_std

def init_projection(proj: Linear) -> None:
init_truncated_uniforma_weights_and_bias(proj, std=init_std)

with torch.no_grad():
proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1)))

return Linear(
config.model_dim,
config.model_dim,
bias=True,
init_fn=init_projection,
device=self._device,
dtype=self._dtype,
)

def build_ffn(self) -> FeedForwardNetwork:
def build_ffn(self, layer_idx: int) -> FeedForwardNetwork:
config = self._config

init_std = config.init_std

def init_projection(proj: Linear) -> None:
init_truncated_uniforma_weights_and_bias(proj, std=init_std)

with torch.no_grad():
proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1)))

inner_dim = int(config.model_dim * config.ffn_inner_dim_ratio)

return StandardFeedForwardNetwork(
config.model_dim,
int(config.model_dim * config.ffn_inner_dim_ratio),
inner_dim,
bias=True,
inner_activation=GELU(),
proj_init_fn=init_projection,
norm_order=TransformerNormOrder.PRE,
device=self._device,
dtype=self._dtype,
)

@staticmethod
def build_layer_norm(
self,
model_dim: int,
*,
device: Device | None = None,
dtype: DataType | None = None,
) -> LayerNorm:
config = self._config

init_std = config.init_std

init_layer_norm = partial(
init_truncated_uniforma_weights_and_bias, std=init_std
)

return StandardLayerNorm(
model_dim, bias=True, eps=1e-6, device=device, dtype=dtype
model_dim,
bias=True,
eps=1e-6,
init_fn=init_layer_norm,
device=device,
dtype=dtype,
)


@@ -329,3 +402,20 @@ def create_jepa_model(
dtype: DataType | None = None,
) -> JepaModel:
return JepaBuilder(config, device=device, dtype=dtype).build_model()


def init_truncated_uniforma_weights_and_bias(
m: Module,
*,
mean: float = 0.0,
std: float = 1.0,
a: float = -2.0,
b: float = 2.0,
) -> None:
if not hasattr(m, "weight") or not hasattr(m, "bias"):
raise ValueError(f"Cannot initialize weights and bias of a {type(m)}")

with torch.no_grad():
torch.nn.init.trunc_normal_(m.weight, mean=mean, std=std, a=a, b=b)
if m.bias is not None:
torch.nn.init.zeros_(m.bias)
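
For reference, a minimal, self-contained sketch of the initialization scheme wired up above: truncated-normal weights with zeroed biases, followed by a depth-dependent rescaling of the attention and feed-forward output projections. The names init_trunc_normal_ and rescale_for_depth are local to this sketch, not fairseq2 API.

import math

import torch
from torch import nn


def init_trunc_normal_(proj: nn.Linear, *, std: float = 0.02) -> None:
    # Truncated-normal weights (cut off at -2.0/2.0, as in the helper above), zero bias.
    with torch.no_grad():
        nn.init.trunc_normal_(proj.weight, mean=0.0, std=std, a=-2.0, b=2.0)
        if proj.bias is not None:
            nn.init.zeros_(proj.bias)


def rescale_for_depth(proj: nn.Linear, layer_idx: int) -> None:
    # Deeper layers get smaller output-projection weights, mirroring the
    # weight.div_(sqrt(2 * (layer_idx + 1))) calls in the builder above.
    with torch.no_grad():
        proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1)))


proj = nn.Linear(768, 768)

init_trunc_normal_(proj, std=0.02)
rescale_for_depth(proj, layer_idx=11)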
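
The droppath_p option above enables stochastic depth on the residual branches. Below is a generic illustration of the idea, assuming per-sequence dropping with rescaling; it is not fairseq2's DropPathResidualConnect implementation.

import torch
from torch import nn


class DropPathResidual(nn.Module):
    """Adds a branch output to its residual, randomly dropping whole sequences in training."""

    def __init__(self, drop_p: float = 0.0) -> None:
        super().__init__()
        self.drop_p = drop_p

    def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_p == 0.0:
            return residual + x

        keep_p = 1.0 - self.drop_p
        # One keep/drop decision per sequence in the batch, broadcast over the rest.
        mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        mask = torch.rand(mask_shape, device=x.device, dtype=x.dtype).lt(keep_p)
        # Rescale survivors so the expected branch contribution is unchanged.
        return residual + x * mask / keep_p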
11 changes: 4 additions & 7 deletions src/fairseq2/models/jepa/model.py
@@ -6,7 +6,6 @@

from __future__ import annotations

from dataclasses import dataclass
from typing import final

from torch.nn import Module
@@ -43,11 +42,9 @@ def __init__(
self.encoder_frontend = encoder_frontend
self.encoder = encoder

def forward(self, batch: SequenceBatch) -> JepaOutput:
raise NotImplementedError()
def forward(self, batch: SequenceBatch) -> SequenceBatch:
seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.padding_mask)

seqs, padding_mask = self.encoder(seqs, padding_mask)

@final
@dataclass
class JepaOutput:
pass
return SequenceBatch(seqs, padding_mask)
27 changes: 27 additions & 0 deletions src/fairseq2/models/vit/feature_extractor.py
@@ -7,6 +7,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import final

from torch import Tensor
@@ -54,13 +55,15 @@ class Conv2dPatchFeatureExtractor(PatchFeatureExtractor):
"""Extracts patch features from 2-dimensional inputs using convolution."""

conv: Conv2d
init_fn: Callable[[Conv2d], None] | None

def __init__(
self,
num_channels: int,
feature_dim: int,
patch_dims: tuple[int, int],
*,
init_fn: Callable[[Conv2d], None] | None = None,
device: Device | None = None,
dtype: DataType | None = None,
) -> None:
@@ -81,6 +84,17 @@ def __init__(
dtype=dtype,
)

self.init_fn = init_fn

self.reset_parameters()

def reset_parameters(self) -> None:
"""Reset the parameters and buffers of the module."""
if self.init_fn is not None:
self.init_fn(self.conv)
else:
self.conv.reset_parameters()

@override
def forward(self, x: Tensor) -> Tensor:
# (N, C, H_inp, W_inp) -> (N, H_out, W_out, E)
@@ -92,13 +106,15 @@ class Conv3dPatchFeatureExtractor(PatchFeatureExtractor):
"""Extracts patch features from 3-dimensional inputs using convolution."""

conv: Conv3d
init_fn: Callable[[Conv3d], None] | None

def __init__(
self,
num_channels: int,
feature_dim: int,
patch_dims: tuple[int, int, int],
*,
init_fn: Callable[[Conv3d], None] | None = None,
device: Device | None = None,
dtype: DataType | None = None,
) -> None:
@@ -119,6 +135,17 @@ def __init__(
dtype=dtype,
)

self.init_fn = init_fn

self.reset_parameters()

def reset_parameters(self) -> None:
"""Reset the parameters and buffers of the module."""
if self.init_fn is not None:
self.init_fn(self.conv)
else:
self.conv.reset_parameters()

@override
def forward(self, x: Tensor) -> Tensor:
# (N, C, D_inp, H_inp, W_inp) -> (N, D_out, H_out, W_out, E)
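
A minimal sketch of the init_fn hook pattern introduced above, using a plain nn.Conv2d wrapper as a stand-in for fairseq2's Conv2dPatchFeatureExtractor; the class and function names here are illustrative only.

from __future__ import annotations

from collections.abc import Callable

import torch
from torch import nn


class PatchConv(nn.Module):
    def __init__(
        self,
        num_channels: int,
        feature_dim: int,
        patch_dims: tuple[int, int],
        *,
        init_fn: Callable[[nn.Conv2d], None] | None = None,
    ) -> None:
        super().__init__()

        self.conv = nn.Conv2d(
            num_channels, feature_dim, kernel_size=patch_dims, stride=patch_dims
        )

        self.init_fn = init_fn

        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Delegate to the caller-supplied initializer when given; otherwise fall
        # back to the convolution's default reset, as in the module above.
        if self.init_fn is not None:
            self.init_fn(self.conv)
        else:
            self.conv.reset_parameters()


def init_patch_conv(conv: nn.Conv2d) -> None:
    with torch.no_grad():
        nn.init.trunc_normal_(conv.weight, mean=0.0, std=0.02, a=-2.0, b=2.0)
        if conv.bias is not None:
            nn.init.zeros_(conv.bias)


extractor = PatchConv(3, 768, (16, 16), init_fn=init_patch_conv)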
17 changes: 12 additions & 5 deletions src/fairseq2/nn/normalization.py
@@ -7,7 +7,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from typing import Any, Literal, final

import torch
@@ -39,6 +39,7 @@ class LayerNorm(Module, ABC):
elementwise_affine: bool
weight: Parameter | None
bias: Parameter | None
init_fn: Callable[[LayerNorm], None] | None

def __init__(
self,
@@ -47,6 +48,7 @@ def __init__(
*,
eps: float = 1e-5,
elementwise_affine: bool = True,
init_fn: Callable[[LayerNorm], None] | None = None,
device: Device | None = None,
dtype: DataType | None = None,
) -> None:
@@ -88,15 +90,20 @@ def __init__(
else:
self.register_parameter("bias", None)

self.init_fn = init_fn

self.reset_parameters()

def reset_parameters(self) -> None:
"""Reset the parameters and buffers of the module."""
if self.weight is not None:
nn.init.ones_(self.weight)
if self.init_fn is not None:
self.init_fn(self)
else:
if self.weight is not None:
nn.init.ones_(self.weight)

if self.bias is not None:
nn.init.zeros_(self.bias)
if self.bias is not None:
nn.init.zeros_(self.bias)

@abstractmethod
def forward(self, x: Tensor) -> Tensor:
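
Likewise, a short sketch of the LayerNorm change: when an init_fn is supplied it takes over parameter initialization entirely, otherwise the weight defaults to ones and the bias to zeros. A plain nn.LayerNorm stands in for fairseq2's StandardLayerNorm.

import torch
from torch import nn


def init_layer_norm(ln: nn.LayerNorm, *, std: float = 0.02) -> None:
    # Custom initializer, analogous to what build_layer_norm() passes via init_fn above.
    with torch.no_grad():
        nn.init.trunc_normal_(ln.weight, mean=0.0, std=std, a=-2.0, b=2.0)
        nn.init.zeros_(ln.bias)


ln = nn.LayerNorm(768, eps=1e-6)  # default init: weight is all ones, bias all zeros

init_layer_norm(ln)  # custom init replaces the defaults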